# Dataset Exploration

Explore pre-generated MCTS (Monte Carlo Tree Search) datasets used to train agents.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/christianwissmann85/ai-cardgame/blob/master/notebooks/03_dataset_exploration.ipynb)

In [None]:
# Colab setup (uncomment if needed)
# !curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
# import os; os.environ['PATH'] = f"{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"
# !pip install git+https://github.com/christianwissmann85/ai-cardgame.git

In [None]:
# Setup: Change to repo root directory (required for data files)
import os
from pathlib import Path


def find_repo_root(start=None, marker=('data', 'cards')):
    """Return the nearest directory at or above *start* that contains the repo marker.

    The search begins at *start* (default: the current working directory) and
    walks up through every ancestor, including the filesystem root itself.

    Args:
        start: Directory to start searching from. Defaults to ``Path.cwd()``.
        marker: Path components, relative to a candidate directory, whose
            existence identifies the repository root (default ``data/cards``).

    Returns:
        The matching directory as a ``Path``, or ``None`` if no directory matches.
    """
    path = Path.cwd() if start is None else Path(start).resolve()
    # ``path.parents`` ends at the root, so the root is checked too.
    for candidate in (path, *path.parents):
        if candidate.joinpath(*marker).exists():
            return candidate
    return None

# Move into the repository root so relative paths like data/... resolve.
repo_root = find_repo_root()
if repo_root is None:
    print("Warning: Could not find repo root.")
else:
    os.chdir(repo_root)
    print(f"Working directory: {os.getcwd()}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

## Load Dataset

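Datasets are JSONL files (optionally gzip-compressed) with one game record per line. Each record stores the game's moves, and `MCTSDataset` exposes each move as a separate training sample.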

In [None]:
from essence_wars.data import MCTSDataset

# Option 1: Load from local file
# dataset = MCTSDataset("data/datasets/mcts_1k_sims100.jsonl.gz")

# Option 2: Load with convenience function
# dataset, loader = load_mcts_dataset("data/datasets/mcts_1k.jsonl.gz", batch_size=64)

# For this demo, use a dataset from data/datasets if one is present.
# Each glob is sorted so the pick is deterministic (glob order is filesystem-dependent);
# compressed files are still preferred over plain .jsonl, as before.
dataset_path = Path("data/datasets")
datasets = sorted(dataset_path.glob("*.jsonl.gz")) + sorted(dataset_path.glob("*.jsonl"))

if datasets:
    print(f"Found datasets: {[d.name for d in datasets]}")
    dataset = MCTSDataset(str(datasets[0]), max_games=1000)  # Limit for demo
else:
    print("No local dataset found. Creating a small synthetic one for demo...")
    # Generate synthetic data for demo in the game-level format:
    # one JSON game record per line, each holding a list of per-move dicts.
    import gzip
    import json
    import tempfile

    from essence_wars._core import PyGame

    NUM_ACTIONS = 256  # length of the policy vector written for every move

    # Use the platform temp dir instead of a hard-coded "/tmp" (absent on Windows).
    demo_path = Path(tempfile.gettempdir()) / "demo_dataset.jsonl.gz"
    with gzip.open(demo_path, 'wt', encoding='utf-8') as f:
        for game_id in range(100):
            game = PyGame()
            game.reset(seed=game_id)

            moves = []
            while not game.is_done():
                player = game.current_player()
                obs = list(map(float, game.observe()))
                mask = list(map(float, game.action_mask()))
                action = game.greedy_action()

                # Create MCTS-style policy (greedy = 100% on chosen action)
                mcts_policy = [0.0] * NUM_ACTIONS
                mcts_policy[action] = 1.0

                moves.append({
                    "player": player,
                    "state_tensor": obs,
                    "action_mask": mask,
                    "action": action,
                    "mcts_policy": mcts_policy,
                    "mcts_value": 0.0,  # Placeholder
                })
                game.step(action)

            # Winner is the player with positive reward, or -1 if neither has one
            # (presumably a draw / unfinished outcome).
            winner = 0 if game.get_reward(0) > 0 else (1 if game.get_reward(1) > 0 else -1)

            game_record = {
                "game_id": game_id,
                "deck1": "default",
                "deck2": "default",
                "winner": winner,
                "moves": moves,
            }
            f.write(json.dumps(game_record) + "\n")

    dataset = MCTSDataset(str(demo_path))

print(f"\nDataset loaded: {len(dataset)} samples")

## Dataset Structure

In [None]:
# Inspect a single sample: show shape/dtype for array-like fields,
# otherwise just the Python type name.
sample = dataset[0]

print("Sample keys:")
for field_name, field_value in sample.items():
    description = (
        f"shape={field_value.shape}, dtype={field_value.dtype}"
        if hasattr(field_value, 'shape')
        else type(field_value).__name__
    )
    print(f"  {field_name}: {description}")

In [None]:
# Detailed view of fields
print("=== Observation (state tensor) ===")
print(f"Shape: {sample['obs'].shape}")
print(f"Range: [{sample['obs'].min():.3f}, {sample['obs'].max():.3f}]")

print("\n=== Action Mask ===")
print(f"Shape: {sample['mask'].shape}")
print(f"Legal actions: {sample['mask'].sum().item():.0f}")

print("\n=== Policy Target (from MCTS) ===")
print(f"Shape: {sample['policy_target'].shape}")
print(f"Sum: {sample['policy_target'].sum().item():.3f} (should be ~1.0)")
top_actions = sample['policy_target'].argsort(descending=True)[:5]
print(f"Top 5 actions: {top_actions.tolist()}")

print("\n=== Value Target ===")
print(f"Value: {sample['value_target'].item():.3f}")

## Policy Distribution Analysis

In [None]:
# Gather per-sample statistics over the first n_samples entries of the dataset.
n_samples = min(1000, len(dataset))
policy_entropies = []
legal_action_counts = []
top1_probs = []

for idx in range(n_samples):
    entry = dataset[idx]
    pi = entry['policy_target'].numpy()
    legal = entry['mask'].numpy()

    legal_action_counts.append(legal.sum())  # number of legal actions
    top1_probs.append(pi.max())              # largest probability in the policy

    # Entropy of the policy restricted to legal actions (renormalised first);
    # skipped when there are no legal actions or they carry no probability mass.
    masked = pi[legal > 0.5]
    if not (len(masked) > 0 and masked.sum() > 0):
        continue
    masked = masked / masked.sum()
    policy_entropies.append(-np.sum(masked * np.log(masked + 1e-10)))

print(f"Analyzed {n_samples} samples")

In [None]:
# One histogram per statistic, each with a dashed red line at its mean.
# Columns: data, bar color (None = matplotlib default), x label, title, mean format.
panel_specs = [
    (legal_action_counts, None, 'Number of Legal Actions', 'Legal Actions per State', '.1f'),
    (policy_entropies, 'green', 'Policy Entropy', 'MCTS Policy Entropy', '.2f'),
    (top1_probs, 'orange', 'Top-1 Probability', 'MCTS Confidence', '.2f'),
]

fig, axes = plt.subplots(1, 3, figsize=(14, 4))
for ax, (values, bar_color, x_label, title, mean_fmt) in zip(axes, panel_specs):
    ax.hist(values, bins=30, edgecolor='black', alpha=0.7, color=bar_color)
    ax.set_xlabel(x_label)
    ax.set_ylabel('Count')
    ax.set_title(title)
    mean_value = np.mean(values)
    ax.axvline(mean_value, color='red', linestyle='--',
               label=f'Mean: {mean_value:{mean_fmt}}')
    ax.legend()

plt.tight_layout()
plt.show()

## Action Distribution

In [None]:
# What actions does MCTS prefer?
action_counts = np.zeros(256)
for i in range(n_samples):
    s = dataset[i]
    action = s['action'].item() if hasattr(s['action'], 'item') else s['action']
    action_counts[action] += 1

# Group by action type
play_card = action_counts[:100].sum()
attack = action_counts[100:150].sum()
ability = action_counts[150:250].sum()
end_turn = action_counts[255]

categories = ['PlayCard\n(0-99)', 'Attack\n(100-149)', 'Ability\n(150-249)', 'EndTurn\n(255)']
counts = [play_card, attack, ability, end_turn]

plt.figure(figsize=(8, 5))
plt.bar(categories, counts, color=['#2ecc71', '#e74c3c', '#9b59b6', '#3498db'])
plt.ylabel('Count')
plt.title('Action Type Distribution')
for i, (cat, count) in enumerate(zip(categories, counts)):
    plt.text(i, count + max(counts)*0.02, f'{count/sum(counts)*100:.1f}%',
             ha='center', fontsize=10)
plt.show()

## State Space Visualization

In [None]:
# Collect state features
turns = []
p1_lives = []
p2_lives = []
p1_essences = []

for i in range(n_samples):
    obs = dataset[i]['obs'].numpy()
    turns.append(obs[0] * 30)  # Denormalize
    p1_lives.append(obs[6] * 20)  # Player 1 life
    p2_lives.append(obs[6+75] * 20)  # Player 2 life
    p1_essences.append(obs[7] * 10)  # Player 1 essence

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Turn distribution
axes[0, 0].hist(turns, bins=30, edgecolor='black', alpha=0.7)
axes[0, 0].set_xlabel('Turn Number')
axes[0, 0].set_title('Turn Distribution')

# Life totals
axes[0, 1].hist(p1_lives, bins=20, alpha=0.7, label='Player 1')
axes[0, 1].hist(p2_lives, bins=20, alpha=0.7, label='Player 2')
axes[0, 1].set_xlabel('Life Total')
axes[0, 1].set_title('Life Distribution')
axes[0, 1].legend()

# Life difference
life_diff = np.array(p1_lives) - np.array(p2_lives)
axes[1, 0].hist(life_diff, bins=30, edgecolor='black', alpha=0.7, color='purple')
axes[1, 0].set_xlabel('Life Difference (P1 - P2)')
axes[1, 0].set_title('Life Advantage')
axes[1, 0].axvline(0, color='red', linestyle='--')

# Essence distribution
axes[1, 1].hist(p1_essences, bins=10, edgecolor='black', alpha=0.7, color='gold')
axes[1, 1].set_xlabel('Essence (Mana)')
axes[1, 1].set_title('Essence Distribution')

plt.tight_layout()
plt.show()

## Using with PyTorch DataLoader

In [None]:
from torch.utils.data import DataLoader

# Wrap the dataset in a shuffling DataLoader.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,  # Use 0 for notebooks
)

# Get a batch
batch = next(iter(dataloader))

print("Batch contents:")
for key, value in batch.items():
    # The default collate function leaves non-numeric fields (e.g. strings) as
    # plain Python lists, which have no .shape — guard like the single-sample
    # inspection earlier in the notebook.
    if hasattr(value, 'shape'):
        print(f"  {key}: shape={value.shape}, dtype={value.dtype}")
    else:
        print(f"  {key}: {type(value).__name__}")

In [None]:
# Print a pseudocode sketch of a training loop over the DataLoader.
training_loop_lines = [
    "Training loop example:",
    "",
    "for batch in dataloader:",
    "    obs = batch['obs'].to(device)",
    "    mask = batch['mask'].to(device)",
    "    policy_target = batch['policy_target'].to(device)",
    "    value_target = batch['value_target'].to(device)",
    "    ",
    "    policy, value = model(obs, mask)",
    "    policy_loss = kl_div(policy, policy_target)",
    "    value_loss = mse(value, value_target)",
    "    loss = policy_loss + value_loss",
]
print("\n".join(training_loop_lines))

## Loading from Hugging Face

For larger datasets, load them directly from the Hugging Face Hub (requires `huggingface_hub`):

In [None]:
# Example (requires huggingface_hub)
# from essence_wars.data import load_from_huggingface
#
# # Load 100k game dataset
# dataset = load_from_huggingface(
#     "christianwissmann85/essence-wars-mcts-100k",
#     split="train"
# )
# print(f"Loaded {len(dataset)} samples")

## Next Steps

- **04_behavioral_cloning.ipynb** - Train a neural network on this data
- **05_alphazero_training.ipynb** - Fine-tune with self-play