# Behavioral Cloning: Train Your First Agent

Train a neural network to imitate an MCTS policy using supervised learning.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/christianwissmann85/ai-cardgame/blob/master/notebooks/04_behavioral_cloning.ipynb)

In [None]:
# Colab setup (uncomment if needed)
# !curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
# import os; os.environ['PATH'] = f"{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"
# !pip install git+https://github.com/christianwissmann85/ai-cardgame.git

In [None]:
# Setup: Change to repo root directory (required for data files)
import os
from pathlib import Path


def find_repo_root():
    """Locate the repository root by searching upward for ``data/cards``.

    Starts at the current working directory and checks it and every ancestor,
    including the filesystem root, returning the first directory that contains
    a ``data/cards`` entry.

    Returns:
        The matching directory as a ``Path``, or ``None`` when neither the
        working directory nor any of its ancestors contains ``data/cards``.
    """
    start = Path.cwd()
    # ``Path.parents`` ends at the filesystem root, so a checkout located
    # directly at the root (e.g. inside a container) is found as well.
    for candidate in (start, *start.parents):
        if (candidate / 'data' / 'cards').exists():
            return candidate
    return None

repo_root = find_repo_root()
if repo_root is None:
    print("Warning: Could not find repo root.")
else:
    # Data files are addressed relative to the repository root, so make it
    # the working directory for the rest of the notebook.
    os.chdir(repo_root)
    print(f"Working directory: {os.getcwd()}")

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from tqdm.auto import tqdm

## 1. Load Dataset

In [None]:
from essence_wars.data import MCTSDataset

# Load dataset (use your own path or download from HuggingFace)
dataset_path = Path("data/datasets")
# Sort the matches: glob order depends on the filesystem, and a stable order
# keeps the "first dataset" choice reproducible between runs and machines.
# Compressed files keep precedence over uncompressed ones.
datasets = sorted(dataset_path.glob("*.jsonl.gz")) + sorted(dataset_path.glob("*.jsonl"))

if datasets:
    dataset = MCTSDataset(str(datasets[0]))
    print(f"Loaded {len(dataset)} samples from {datasets[0].name}")
else:
    print("No dataset found. Generating demo data...")
    # Generate demo data in the game-level format: one JSON record per line,
    # each record holding a full game and its list of moves.
    import gzip
    import json
    import tempfile

    from essence_wars._core import PyGame

    # Use the platform temp directory instead of a hard-coded /tmp so the
    # fallback also works where /tmp does not exist (e.g. Windows).
    demo_path = Path(tempfile.gettempdir()) / "demo_dataset.jsonl.gz"
    with gzip.open(demo_path, 'wt', encoding='utf-8') as f:
        for game_id in range(500):  # More games for training
            game = PyGame()
            # Seeding with the game index makes the demo data reproducible.
            game.reset(seed=game_id)

            moves = []
            while not game.is_done():
                player = game.current_player()
                obs = list(map(float, game.observe()))
                mask = list(map(float, game.action_mask()))
                action = game.greedy_action()

                # Create MCTS-style policy (greedy = 100% on chosen action).
                # 256 matches the action_dim the model is built with.
                mcts_policy = [0.0] * 256
                mcts_policy[action] = 1.0

                move = {
                    "player": player,
                    "state_tensor": obs,
                    "action_mask": mask,
                    "action": action,
                    "mcts_policy": mcts_policy,
                    "mcts_value": 0.0,
                }
                moves.append(move)
                game.step(action)

            # Determine winner from game result: the player with a positive
            # reward wins; -1 marks a game where neither reward is positive.
            winner = 0 if game.get_reward(0) > 0 else (1 if game.get_reward(1) > 0 else -1)

            game_record = {
                "game_id": game_id,
                "deck1": "default",
                "deck2": "default",
                "winner": winner,
                "moves": moves,
            }
            f.write(json.dumps(game_record) + "\n")

    dataset = MCTSDataset(str(demo_path))
    print(f"Generated {len(dataset)} samples")

In [None]:
# Train/val split: 90% of the samples for training, the remainder for validation
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

# A fixed generator seed makes the split identical on every run.
split_rng = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_rng)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

# DataLoaders: both share batch size and worker count; only training shuffles
loader_options = {"batch_size": 64, "num_workers": 0}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

## 2. Define the Model

We use the same architecture as AlphaZero: a shared backbone with separate policy and value heads.

In [None]:
# Import the official AlphaZeroNetwork for compatibility with checkpoints
from essence_wars.agents.networks import AlphaZeroNetwork

# Architecture overview. AlphaZeroNetwork uses:
# - Input projection (Linear -> ReLU)
# - Residual tower (N residual blocks)
# - Policy head (Linear -> ReLU -> Linear)
# - Value head (Linear -> ReLU -> Linear -> Tanh)

In [None]:
# Create model on the GPU when one is available, otherwise on the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(f"Using device: {device}")

model = AlphaZeroNetwork(obs_dim=326, action_dim=256, hidden_dim=256, num_blocks=4)
model = model.to(device)

# Count parameters: total number of elements across all parameter tensors
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")

## 3. Training Loop

In [None]:
def compute_loss(model, batch, device):
    """Compute the behavioral-cloning loss and metrics for one batch.

    The policy term is the cross-entropy between the renormalized target
    policy and the network's softmax output. It equals the KL divergence up
    to the target's entropy, which does not depend on the model. The value
    term is the mean squared error against the value target.

    Args:
        model: Callable invoked as ``model(obs, mask)`` that returns
            ``(logits, value)``.
        batch: Mapping with the keys ``'obs'``, ``'mask'``,
            ``'policy_target'`` and ``'value_target'``, each holding a tensor.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        A tuple ``(total_loss, policy_loss, value_loss, accuracy)``.
        ``total_loss`` is a tensor suitable for ``backward()``; the other
        three are Python floats.
    """
    obs = batch['obs'].to(device)
    mask = batch['mask'].to(device).bool()
    policy_target = batch['policy_target'].to(device)
    value_target = batch['value_target'].to(device)

    # Forward pass
    logits, value = model(obs, mask)

    # Drop target mass on illegal actions and renormalize; the clamp avoids a
    # division by zero when a row has no mass on any legal action.
    masked_target = policy_target * mask.float()
    target_sum = masked_target.sum(dim=-1, keepdim=True).clamp(min=1e-8)
    masked_target = masked_target / target_sum

    # Zero the log-probabilities of illegal actions before multiplying: if the
    # model sets illegal logits to -inf, log_softmax is -inf there and
    # ``0 * -inf`` would turn the whole loss into NaN.
    log_probs = F.log_softmax(logits, dim=-1).masked_fill(~mask, 0.0)

    # Policy loss: cross-entropy over legal actions, averaged over the batch
    policy_loss = -(masked_target * log_probs).sum(dim=-1).mean()

    # Value loss: MSE
    value_loss = F.mse_loss(value.squeeze(-1), value_target)

    # Total loss: both terms weighted equally
    total_loss = policy_loss + value_loss

    # Metrics: fraction of rows whose top predicted action is the top target action
    with torch.no_grad():
        pred_actions = logits.argmax(dim=-1)
        target_actions = policy_target.argmax(dim=-1)
        accuracy = (pred_actions == target_actions).float().mean()

    return total_loss, policy_loss.item(), value_loss.item(), accuracy.item()

In [None]:
# Training config
epochs = 10
lr = 1e-3

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Cosine schedule spanning the whole run: T_max equals the number of epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

# Training history: one initially empty list per tracked metric
history = {
    metric: []
    for metric in (
        'train_loss', 'val_loss',
        'train_acc', 'val_acc',
        'policy_loss', 'value_loss',
    )
}

In [None]:
# Training loop: each epoch runs one optimisation pass and one validation pass
for epoch in range(epochs):
    # Optimisation pass (train mode)
    model.train()
    train_losses, train_accs, policy_losses, value_losses = [], [], [], []

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
    for batch in progress:
        optimizer.zero_grad()
        loss, pol_loss, val_loss, acc = compute_loss(model, batch, device)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        train_losses.append(loss.item())
        train_accs.append(acc)
        policy_losses.append(pol_loss)
        value_losses.append(val_loss)

    # Validation pass (eval mode, no gradient tracking)
    model.eval()
    val_losses, val_accs = [], []

    with torch.no_grad():
        for batch in val_loader:
            loss, _, _, acc = compute_loss(model, batch, device)
            val_losses.append(loss.item())
            val_accs.append(acc)

    # The learning-rate schedule advances once per epoch
    scheduler.step()

    # Record history: the epoch value of each metric is the mean of its
    # per-batch values
    epoch_means = {
        'train_loss': np.mean(train_losses),
        'val_loss': np.mean(val_losses),
        'train_acc': np.mean(train_accs),
        'val_acc': np.mean(val_accs),
        'policy_loss': np.mean(policy_losses),
        'value_loss': np.mean(value_losses),
    }
    for key, mean_value in epoch_means.items():
        history[key].append(mean_value)

    print(
        f"Epoch {epoch+1}: "
        f"train_loss={epoch_means['train_loss']:.4f}, "
        f"val_loss={epoch_means['val_loss']:.4f}, "
        f"train_acc={epoch_means['train_acc']:.1%}, "
        f"val_acc={epoch_means['val_acc']:.1%}"
    )

## 4. Training Curves

In [None]:
import matplotlib.pyplot as plt

# One panel per view of the training history:
# (title, y-axis label, [(history key, legend label), ...])
panels = [
    ('Total Loss', 'Loss', [('train_loss', 'Train'), ('val_loss', 'Val')]),
    ('Policy Accuracy', 'Accuracy', [('train_acc', 'Train'), ('val_acc', 'Val')]),
    ('Loss Components', 'Loss', [('policy_loss', 'Policy'), ('value_loss', 'Value')]),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, (title, ylabel, curves) in zip(axes, panels):
    for key, label in curves:
        ax.plot(history[key], label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()

plt.tight_layout()
plt.show()

## 5. Evaluate Against Baselines

In [None]:
from essence_wars.benchmark import EssenceWarsBenchmark, NeuralAgent


# Create agent from trained model
class TrainedAgent:
    """Greedy agent that plays the highest-scoring legal action of a policy network.

    Args:
        model: Network invoked as ``model(obs, mask)``; its first return value
            is used as per-action logits. It is switched to eval mode here.
        device: Device on which inference inputs are placed.
        name: Display name exposed through the ``name`` property.
    """

    def __init__(self, model, device, name="BC-Agent"):
        self.model = model
        self.device = device
        self._name = name
        self.model.eval()

    @property
    def name(self):
        """Display name given at construction time."""
        return self._name

    def select_action(self, obs, mask):
        """Return the index of the best legal action for one observation.

        Args:
            obs: Observation for a single state, as a NumPy array or anything
                else ``torch.as_tensor`` accepts.
            mask: Action mask of the same kind; non-zero entries mark the
                actions that may be chosen.

        Returns:
            The chosen action index as a Python ``int``.
        """
        with torch.no_grad():
            # as_tensor accepts arrays, lists and tensors alike; unsqueeze
            # adds the batch dimension of size one.
            obs_t = torch.as_tensor(obs).float().unsqueeze(0).to(self.device)
            mask_t = torch.as_tensor(mask).bool().unsqueeze(0).to(self.device)
            logits, _ = self.model(obs_t, mask_t)
            # Mask explicitly so an illegal action can never win the argmax,
            # whether or not the network already masks its own logits.
            legal_logits = logits.masked_fill(~mask_t, float('-inf'))
            return legal_logits.argmax(dim=-1).item()

    def reset(self):
        """No-op: the agent keeps no state between games."""
        pass

# Wrap the trained network so it can be passed to the benchmark
agent = TrainedAgent(model=model, device=device, name="BC-Trained")

In [None]:
# Run benchmark (quick evaluation); games_per_opponent=20 keeps the run short
benchmark = EssenceWarsBenchmark(games_per_opponent=20, verbose=True)
baseline_names = ["random", "greedy"]
results = benchmark.evaluate(agent, baselines=baseline_names)

separator = "=" * 50
print(f"\n{separator}")
print(results.summary())

## 6. Save the Model

In [None]:
# Save checkpoint
# np.mean returns NumPy scalars; convert the history to built-in floats so the
# file contains only tensors and plain Python containers. torch.load in its
# weights_only mode (the default since PyTorch 2.6) refuses to unpickle NumPy
# scalars unless they are explicitly allow-listed.
plain_history = {
    key: [float(value) for value in values]
    for key, values in history.items()
}

checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # Must match the hidden_dim / num_blocks the model was constructed with.
    'args': {
        'hidden_dim': 256,
        'num_blocks': 4,
    },
    'history': plain_history,
    'final_val_accuracy': plain_history['val_acc'][-1],
}

save_path = Path("bc_model.pt")
torch.save(checkpoint, save_path)
print(f"Model saved to {save_path}")

In [None]:
# Load and verify: rebuild an agent from the checkpoint file and print its name
# NOTE(review): presumably from_checkpoint reads the checkpoint's 'args' entry
# to size the network; confirm against NeuralAgent.
loaded_agent = NeuralAgent.from_checkpoint(str(save_path), device=str(device))
print(f"Loaded agent: {loaded_agent.name}")

## Next Steps

Your BC model is a great starting point! To improve further:

1. **More data**: Train on larger datasets (10k-100k games)
2. **Fine-tuning**: Use AlphaZero self-play to improve beyond the MCTS teacher
3. **Hyperparameters**: Try larger models, different learning rates

Continue with **05_alphazero_training.ipynb** for self-play training!