# Makhos (Thai Checkers) Neural Network Training

This notebook demonstrates the complete ML pipeline:
1. Generate training data from AI vs AI self-play
2. Train a neural network
3. Test the trained model on a sample position
4. Save and export the model

## Setup

In [None]:
# Install dependencies
!pip install torch numpy

In [None]:
# Clone repository (if not already cloned).
# Replace YOUR_USERNAME with the actual GitHub account before uncommenting.
# !git clone https://github.com/YOUR_USERNAME/makhos-expo.git
# %cd makhos-expo

In [None]:
# Install Node.js, npm, and tsx (needed for data generation).
# The commands below use apt-get, so they assume a Debian/Ubuntu-style
# environment such as Colab; on other systems install Node.js by other means.
# tsx is installed globally (-g) through npm.
!apt-get update && apt-get install -y nodejs npm
!npm install -g tsx

## Step 1: Generate Training Data

In [None]:
# Method 1: Direct function call (Recommended for Colab)
import sys

# Make the current directory importable so `gen_data` resolves.
# The membership check keeps this cell idempotent: re-running it no longer
# stacks duplicate '.' entries at the front of sys.path.
if '.' not in sys.path:
    sys.path.insert(0, '.')
from gen_data import generate_data

# Quick test: 100 games (5-10 minutes)
generate_data(total_games=100, batch_size=100, time_per_move=500)

# For full dataset, use:
# generate_data(total_games=5000, batch_size=1000, time_per_move=1000)

# Method 2: Command line (Alternative)
# !python gen_data.py --total_games 100 --batch_size 100 --time_per_move 500

In [None]:
# Load and inspect the generated data
import numpy as np

# np.load on an .npz file returns an archive object backed by an open file.
# The context manager closes it as soon as the four arrays have been read into
# memory; after this block, use the arrays below rather than `data`.
with np.load('training_data.npz') as data:
    states = data['states']
    policy_targets = data['policy_targets']
    legal_masks = data['legal_masks']
    values = data['values']

# Shapes of every array in the dataset.
print(f"States shape: {states.shape}")
print(f"Policy targets shape: {policy_targets.shape}")
print(f"Legal masks shape: {legal_masks.shape}")
print(f"Values shape: {values.shape}")

# Outcome balance: count how many samples carry each of the labels +1, 0, -1.
print("\nValue distribution:")
print(f"  P1 wins (+1): {(values == 1).sum()}")
print(f"  Draws (0): {(values == 0).sum()}")
print(f"  P2 wins (-1): {(values == -1).sum()}")

# Sum each sample's mask over axes 1 and 2, then average across samples.
print(f"\nAvg legal moves: {legal_masks.sum(axis=(1,2)).mean():.1f}")

## Step 2: Train Neural Network

In [None]:
# Method 1: Direct function call (Recommended for Colab)
from train import train_model

# Settings for the "simple" model run, gathered in one place so they are easy
# to scan and tweak before launching training.
train_config = {
    "data_path": "training_data.npz",
    "model_type": "simple",
    "hidden_size": 512,
    "epochs": 50,
    "batch_size": 32,
    "lr": 0.001,
    "output_dir": "checkpoints",
}

# Launch training with exactly the keyword arguments listed above.
train_model(**train_config)

# Method 2: Command line (Alternative)
# !python train.py --data training_data.npz --model_type simple --epochs 50

In [None]:
# Alternative: Train with ResNet model (slower, potentially better)
# from train import train_model
# 
# train_model(
#     data_path="training_data.npz",
#     model_type="resnet",
#     num_channels=128,
#     num_res_blocks=6,
#     epochs=50,
#     batch_size=32,
#     lr=0.001,
#     output_dir="checkpoints"
# )

## Step 3: Test the Model

In [None]:
import torch
from model import create_model

# Load the trained model
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Rebuild the "simple" architecture with the same hidden_size that the training
# cell in this notebook passes to train_model, then restore the saved weights.
model = create_model("simple", hidden_size=512)
checkpoint = torch.load('checkpoints/best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Move the weights to the chosen device and switch to evaluation mode.
model.to(device)
model.eval()

print("Model loaded successfully!")

In [None]:
# Test on a random position from the dataset
# Relies on `states`, `policy_targets`, `legal_masks`, `values` (data-loading
# cell) and `model`, `device` (model-loading cell) already being defined.
import numpy as np
import torch

# Get a random example
idx = np.random.randint(len(states))
state = states[idx]
true_policy = policy_targets[idx]
legal_mask = legal_masks[idx]
true_value = values[idx]

# Width of the last mask axis, used to split a flattened move index back into
# (from, to). Replaces the hard-coded 32.
# NOTE(review): assumes each mask is a 2-D (from, to) grid — confirm against gen_data.
num_targets = legal_mask.shape[-1]

# Make prediction
state_tensor = torch.from_numpy(state).unsqueeze(0).float().to(device)
with torch.no_grad():
    policy_logits, value_pred = model(state_tensor)

    # Apply legal move mask: push every illegal move's logit to -1e9 so that
    # softmax assigns it (numerically) zero probability.
    # reshape(-1) also works if the model returns a non-contiguous tensor.
    policy_logits_flat = policy_logits.reshape(-1)
    legal_mask_flat = torch.from_numpy(legal_mask.reshape(-1)).to(device)
    # masked_fill takes a plain scalar, so no extra tensor is built on the CPU
    # and copied to the device.
    policy_logits_masked = policy_logits_flat.masked_fill(~(legal_mask_flat > 0), -1e9)

    policy_probs = torch.softmax(policy_logits_masked, dim=0).cpu().numpy()
    value_pred = value_pred.item()

# Get top-5 predicted moves (only legal moves); fewer are listed when the
# position has fewer than five legal moves.
legal_indices = np.where(legal_mask.reshape(-1) > 0)[0]
legal_probs = policy_probs[legal_indices]
top_5_legal = legal_indices[np.argsort(legal_probs)[-5:][::-1]]

print("Top 5 predicted moves (from, to):")
for i, move_idx in enumerate(top_5_legal, 1):
    from_sq, to_sq = divmod(int(move_idx), num_targets)
    prob = policy_probs[move_idx]
    print(f"  {i}. ({from_sq} -> {to_sq}) p={prob:.4f}")

# Find true move: the highest-weighted entry of the flattened policy target.
true_move_idx = np.argmax(true_policy.reshape(-1))
true_from, true_to = divmod(int(true_move_idx), num_targets)

print(f"\nTrue move: ({true_from} -> {true_to})")
print(f"\nPredicted value: {value_pred:.3f}")
print(f"True value: {true_value:.0f}")

## Step 4: Download Model

In [None]:
# Download the trained model files
# google.colab is only importable inside a Colab runtime.
from google.colab import files

# Trigger one browser download per file, in this order.
for checkpoint_path in (
    'checkpoints/best_model.pt',
    'checkpoints/final_model.pt',
    'checkpoints/final_model_scripted.pt',
):
    files.download(checkpoint_path)

## Next Steps

1. **Generate more data**: Use `--total_games 1000` (or `generate_data(total_games=1000, ...)`) or more for better results
2. **Tune hyperparameters**: Experiment with different learning rates, model sizes, etc.
3. **Implement MCTS**: Combine the neural network with Monte Carlo Tree Search
4. **Self-play training**: Use the trained model to generate new training data (AlphaZero-style)