# Stochastic MuZero Training for 2048

This notebook trains a Stochastic MuZero agent to play 2048.

**Supported Hardware:**
- NVIDIA Tesla P100 (16GB VRAM)
- CPU (slower, but works anywhere)

**Paper:** "Planning in Stochastic Environments with a Learned Model" (ICLR 2022)

## 1. Environment Setup

In [None]:
# ##>: Clone repository and checkout the correct branch.
# ##!: Replace with your actual repository URL before running.
REPO_URL = 'https://github.com/YOUR_USERNAME/simulate_2048.git'  # CHANGE THIS
BRANCH = 'muzero'  # The branch containing the training code

!git clone {REPO_URL}
%cd simulate_2048
!git checkout {BRANCH}
!git log --oneline -3  # Verify we're on the correct branch

In [None]:
# ##>: Install dependencies.
# ##&: Kaggle has most dependencies pre-installed; only gymnasium needs installing.
!pip install -q gymnasium

In [None]:
import os
import sys
from pathlib import Path

# ##>: Add project root to path.
project_root = Path.cwd()
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

print(f'Project root: {project_root}')
print(f'Python version: {sys.version}')

## 2. Hardware Detection

Detects available GPUs and enables memory growth on each, so TensorFlow allocates GPU memory as needed instead of reserving all of it at startup.

In [None]:
import tensorflow as tf

print(f'TensorFlow version: {tf.__version__}')
print(f'Keras version: {tf.keras.__version__}')

In [None]:
import tensorflow as tf

# ##>: Detect available hardware.
print('=' * 60)
print('HARDWARE CONFIGURATION')
print('=' * 60)

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f'GPU detected: {len(gpus)} device(s)')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
        details = tf.config.experimental.get_device_details(gpu)
        print(f'  - {details.get("device_name", "Unknown GPU")}')
    ACCELERATOR = 'GPU'
else:
    print('No GPU detected, using CPU')
    ACCELERATOR = 'CPU'

print('=' * 60)

## 3. Training Configuration

Configure hyperparameters based on detected hardware and available runtime.

In [None]:
# ##>: Training configuration.
# ##!: Adjust these based on Kaggle session limits (12h for CPU/GPU sessions, 9h for TPU).

# ##>: Choose training mode.
TRAINING_MODE = 'medium'  # Options: 'small' (testing), 'medium' (balanced), 'full' (paper)

# ##>: Training duration.
# ##&: 50k steps is reasonable for a single Kaggle session (~4-6 hours on P100).
# ##&: For longer training, use checkpoints to resume across sessions.
NUM_TRAINING_STEPS = 50_000  # Adjust based on available time

# ##>: Checkpointing.
CHECKPOINT_DIR = '/kaggle/working/checkpoints'
CHECKPOINT_INTERVAL = 5_000

# ##>: Logging and evaluation.
LOG_INTERVAL = 100
EVAL_INTERVAL = 5_000
EVAL_GAMES = 10

# ##>: Game generation frequency.
# ##&: Higher values generate more diverse training data but slow training.
GAMES_PER_STEP = 2  # Self-play games generated per generation round (rounds run every 10 training steps)

print(f'Training mode: {TRAINING_MODE}')
print(f'Training steps: {NUM_TRAINING_STEPS:,}')
print(f'Checkpoint interval: {CHECKPOINT_INTERVAL:,}')
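
The 50k-step estimate above implies roughly 2.3 to 3.5 training steps per second on a P100 (50,000 steps in 4 to 6 hours), but actual throughput depends on the mode, batch size, and `GAMES_PER_STEP`. A minimal sketch for sizing `NUM_TRAINING_STEPS` from a throughput measured during a short trial run: it reserves part of the session for setup, evaluation, and archiving, and rounds down to a multiple of `CHECKPOINT_INTERVAL` so the run ends on a checkpoint. `budget_training_steps` is a hypothetical helper, not part of the repository.

```python
import math


def budget_training_steps(
    steps_per_second: float,
    session_hours: float,
    reserve_fraction: float = 0.2,
    round_to: int = 5_000,
) -> int:
    """Estimate how many training steps fit into one session.

    reserve_fraction keeps part of the session free for setup, evaluation and
    archiving; the result is rounded down to a multiple of round_to.
    """
    if steps_per_second <= 0 or session_hours <= 0:
        raise ValueError('steps_per_second and session_hours must be positive')
    if not 0.0 <= reserve_fraction < 1.0:
        raise ValueError('reserve_fraction must be in [0, 1)')
    usable_seconds = session_hours * 3600 * (1.0 - reserve_fraction)
    raw_steps = steps_per_second * usable_seconds
    return math.floor(raw_steps / round_to) * round_to


# Example: ~2.5 steps/s measured in a short trial, 12-hour GPU session.
print(budget_training_steps(2.5, 12))  # 85000
```

The 20% reserve is an arbitrary safety margin; self-play cost grows as the agent survives longer, so later steps may be slower than the trial run suggests.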

In [None]:
from reinforce.training.config import StochasticMuZeroConfig, default_2048_config, small_2048_config


def create_kaggle_config(mode: str, accelerator: str, num_steps: int) -> StochasticMuZeroConfig:
    """
    Create configuration optimized for Kaggle hardware.

    Parameters
    ----------
    mode : str
        Training mode: 'small', 'medium', or 'full'.
    accelerator : str
        Detected hardware: 'GPU' or 'CPU'.
    num_steps : int
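        Total training steps; sets training_steps and rescales the temperature schedule.

    Returns
    -------
    StochasticMuZeroConfig
        Optimized configuration.

    Raises
    ------
    ValueError
        If mode is not one of the supported options.
    """
    if mode == 'small':
        config = small_2048_config()
    elif mode == 'full':
        config = default_2048_config()
    elif mode == 'medium':
        config = StochasticMuZeroConfig(
            hidden_size=192,
            num_residual_blocks=7,
            num_simulations=75,
            replay_buffer_size=50_000,
            batch_size=512,
        )
    else:
        raise ValueError(f"Unknown mode {mode!r}; expected 'small', 'medium', or 'full'.")

    # ##>: Keep the run length in sync with num_steps for every mode.
    config.training_steps = num_steps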

    # ##>: Adjust batch sizes based on hardware.
    if accelerator == 'GPU':
        # ##&: P100 has 16GB; keep batch size reasonable.
        config.batch_size = min(config.batch_size, 512)
    else:
        # ##>: CPU uses smaller batches.
        config.batch_size = min(config.batch_size, 128)

    # ##>: Adjust temperature schedule for shorter training.
    # ##&: Default schedule expects 300k+ steps; scale down proportionally.
    if num_steps < 300_000:
        scale = num_steps / 300_000
        config.temperature_schedule = [
            (0, 1.0),
            (int(100_000 * scale), 0.5),
            (int(200_000 * scale), 0.1),
            (int(250_000 * scale), 0.0),
        ]

    return config


config = create_kaggle_config(TRAINING_MODE, ACCELERATOR, NUM_TRAINING_STEPS)

print('\nConfiguration:')
print(f'  Hidden size: {config.hidden_size}')
print(f'  Residual blocks: {config.num_residual_blocks}')
print(f'  Batch size: {config.batch_size}')
print(f'  MCTS simulations: {config.num_simulations}')
print(f'  Replay buffer size: {config.replay_buffer_size:,}')
print(f'  Temperature schedule: {config.temperature_schedule}')
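
The rescaled schedule is a list of `(step, temperature)` pairs. In MuZero-style training, the temperature typically controls how self-play turns MCTS visit counts into a move distribution: probabilities are proportional to N(a)^(1/T), and T = 0 means always picking the most-visited action. The sketch below assumes the trainer reads the schedule as piecewise constant (the last entry whose step has been reached applies); `scale_schedule`, `temperature_at`, and `visit_policy` are hypothetical helpers, not part of `reinforce`.

```python
import numpy as np


def scale_schedule(num_steps: int, base_steps: int = 300_000) -> list[tuple[int, float]]:
    """Shrink the 300k-step schedule to a shorter run, as create_kaggle_config does."""
    scale = num_steps / base_steps
    return [
        (0, 1.0),
        (int(100_000 * scale), 0.5),
        (int(200_000 * scale), 0.1),
        (int(250_000 * scale), 0.0),
    ]


def temperature_at(step: int, schedule: list[tuple[int, float]]) -> float:
    """Piecewise-constant lookup: the last entry whose start step has been reached wins."""
    temperature = schedule[0][1]
    for start, value in schedule:  # assumes entries are sorted by start step
        if step < start:
            break
        temperature = value
    return temperature


def visit_policy(visit_counts, temperature: float) -> np.ndarray:
    """Action probabilities proportional to N(a) ** (1 / T); T == 0 means greedy."""
    counts = np.asarray(visit_counts, dtype=np.float64)
    if temperature == 0.0:
        policy = np.zeros_like(counts)
        policy[np.argmax(counts)] = 1.0
        return policy
    # Divide by the max count first so large exponents (small T) cannot overflow.
    scaled = (counts / counts.max()) ** (1.0 / temperature)
    return scaled / scaled.sum()


schedule = scale_schedule(50_000)
print(schedule)                          # [(0, 1.0), (16666, 0.5), (33333, 0.1), (41666, 0.0)]
print(temperature_at(20_000, schedule))  # 0.5
print(visit_policy([10, 30, 0, 60], temperature_at(20_000, schedule)))
```

For a 50,000-step run, exploration drops to T = 0.5 after about 17k steps and action selection becomes fully greedy from step 41,666 onward. `visit_policy` assumes at least one action has a nonzero visit count, which holds whenever MCTS runs at least one simulation.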

## 4. Initialize Trainer

Create the trainer with the configured settings.

In [None]:
from reinforce.training.trainer import StochasticMuZeroTrainer

# ##>: Create checkpoint directory.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ##>: Initialize trainer.
trainer = StochasticMuZeroTrainer(
    config=config,
    checkpoint_dir=CHECKPOINT_DIR,
)

print('Trainer initialized successfully.')
print(f'Checkpoints will be saved to: {CHECKPOINT_DIR}')

## 5. Training with Monitoring

Run training with progress tracking and periodic evaluation.

In [None]:
import matplotlib.pyplot as plt
from IPython.display import clear_output, display


class TrainingMonitor:
    """
    Monitor training progress with live plots.

    Tracks losses, rewards, and max tiles achieved.
    """

    def __init__(self, update_interval: int = 500):
        self.update_interval = update_interval
        self.steps: list[int] = []
        self.losses: list[float] = []
        self.rewards: list[float] = []
        self.max_tiles: list[int] = []
        self.eval_steps: list[int] = []
        self.eval_rewards: list[float] = []
        self.eval_tiles: list[int] = []

    def update(self, step: int, metrics: dict) -> None:
        """Record metrics and update plots periodically."""
        self.steps.append(step)
        self.losses.append(metrics['losses']['total'])
        self.rewards.append(metrics['avg_reward'])

        if step > 0 and step % self.update_interval == 0:
            self._plot()

    def add_eval(self, step: int, reward: float, max_tile: int) -> None:
        """Record evaluation results."""
        self.eval_steps.append(step)
        self.eval_rewards.append(reward)
        self.eval_tiles.append(max_tile)

    def _plot(self) -> None:
        """Generate training progress plots."""
        clear_output(wait=True)

        fig, axes = plt.subplots(1, 3, figsize=(15, 4))

        # ##>: Loss plot.
        axes[0].plot(self.steps, self.losses, 'b-', alpha=0.7, linewidth=0.5)
        if len(self.losses) > 100:
            # ##>: Trailing moving average over the last `window` points.
            window = min(100, len(self.losses) // 10)
            smoothed = [
                sum(self.losses[max(0, i - window + 1) : i + 1]) / min(i + 1, window)
                for i in range(len(self.losses))
            ]
            axes[0].plot(self.steps, smoothed, 'r-', linewidth=2, label='Smoothed')
            axes[0].legend()
        axes[0].set_xlabel('Step')
        axes[0].set_ylabel('Total Loss')
        axes[0].set_title('Training Loss')
        axes[0].grid(True, alpha=0.3)

        # ##>: Reward plot.
        axes[1].plot(self.steps, self.rewards, 'g-', alpha=0.7, linewidth=0.5)
        if self.eval_steps:
            axes[1].scatter(self.eval_steps, self.eval_rewards, c='red', s=50, zorder=5, label='Eval')
        axes[1].set_xlabel('Step')
        axes[1].set_ylabel('Average Reward')
        axes[1].set_title('Episode Reward')
        axes[1].grid(True, alpha=0.3)
        if self.eval_steps:
            axes[1].legend()

        # ##>: Max tile plot (evaluation only).
        if self.eval_steps:
            axes[2].bar(range(len(self.eval_tiles)), self.eval_tiles, color='purple', alpha=0.7)
            axes[2].set_xticks(range(len(self.eval_tiles)))
            axes[2].set_xticklabels([f'{s // 1000}k' for s in self.eval_steps], rotation=45)
            axes[2].axhline(y=2048, color='r', linestyle='--', label='2048')
            axes[2].axhline(y=4096, color='g', linestyle='--', label='4096')
        axes[2].set_xlabel('Evaluation')
        axes[2].set_ylabel('Max Tile')
        axes[2].set_title('Best Tile Achieved')
        axes[2].grid(True, alpha=0.3)
        if self.eval_steps:
            axes[2].legend()

        plt.tight_layout()
        display(fig)
        plt.close(fig)

        # ##>: Print summary.
        print(f'\nStep {self.steps[-1]:,} / {NUM_TRAINING_STEPS:,}')
        print(f'  Loss: {self.losses[-1]:.4f}')
        print(f'  Avg Reward: {self.rewards[-1]:.1f}')
        if self.eval_tiles:
            print(f'  Best Tile: {max(self.eval_tiles)}')


monitor = TrainingMonitor(update_interval=500)
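
The monitor smooths the noisy loss curve with a trailing moving average over the last `window` points, recomputing every window sum on each redraw (O(n · window) per plot). Running sums give the same kind of average in linear time, which matters as the history grows. A minimal NumPy sketch; `trailing_moving_average` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def trailing_moving_average(values, window: int) -> np.ndarray:
    """Mean of the last `window` points at each index (fewer points at the start)."""
    if window < 1:
        raise ValueError('window must be >= 1')
    x = np.asarray(values, dtype=np.float64)
    csum = np.concatenate(([0.0], np.cumsum(x)))  # csum[k] = sum of x[:k]
    idx = np.arange(1, len(x) + 1)                # points seen so far at each index
    start = np.maximum(0, idx - window)           # first index inside the window
    return (csum[idx] - csum[start]) / (idx - start)


print(trailing_moving_average([1, 2, 3, 4, 5], 2))  # [1.  1.5 2.5 3.5 4.5]
```

Cumulative sums can accumulate floating-point error over very long float32 histories; converting to float64 first, as above, keeps this negligible at notebook scale.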

In [None]:
def training_callback(step: int, metrics: dict) -> None:
    """
    Callback for training loop.

    Updates monitor and runs periodic evaluation.
    """
    monitor.update(step, metrics)

    # ##>: Periodic evaluation.
    if step > 0 and step % EVAL_INTERVAL == 0:
        eval_results = trainer.evaluate(num_games=EVAL_GAMES)
        monitor.add_eval(step, eval_results['mean_reward'], eval_results['max_tile'])
        print(f'\n[Eval @ {step:,}] Reward: {eval_results["mean_reward"]:.1f}, Max Tile: {eval_results["max_tile"]}')
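
`training_callback` gates evaluation on `step % EVAL_INTERVAL == 0`, and `TrainingMonitor` redraws on `step % update_interval == 0`. If `trainer.train` invokes the callback only every `LOG_INTERVAL` steps (an assumption; check the trainer), both intervals should be multiples of `LOG_INTERVAL`, otherwise they only fire at common multiples. A hypothetical check:

```python
def triggered_steps(num_steps: int, callback_every: int, interval: int) -> list[int]:
    """Steps where `step % interval == 0` fires if the callback only sees every `callback_every`-th step."""
    return [
        step
        for step in range(callback_every, num_steps + 1, callback_every)
        if step % interval == 0
    ]


def intervals_compatible(callback_every: int, *intervals: int) -> bool:
    """True when every interval is a multiple of the callback period."""
    return all(interval % callback_every == 0 for interval in intervals)


print(len(triggered_steps(50_000, 100, 5_000)))  # 10
print(triggered_steps(50_000, 300, 5_000))       # [15000, 30000, 45000]
print(intervals_compatible(100, 5_000, 500))     # True
```

With the defaults (`LOG_INTERVAL = 100`, `EVAL_INTERVAL = 5_000`, `update_interval=500`) every trigger lines up; a `LOG_INTERVAL` of 300 would cut evaluations from ten to three.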

In [None]:
# ##>: Run training.
print('Starting training...')
print(f'Hardware: {ACCELERATOR}')
print(f'Steps: {NUM_TRAINING_STEPS:,}')
print()

history = trainer.train(
    num_steps=NUM_TRAINING_STEPS,
    log_interval=LOG_INTERVAL,
    checkpoint_interval=CHECKPOINT_INTERVAL,
    games_per_step=GAMES_PER_STEP,
    callback=training_callback,
)

print('\nTraining complete!')

## 6. Final Evaluation

In [None]:
print('=' * 60)
print('FINAL EVALUATION (20 games)')
print('=' * 60)

final_eval = trainer.evaluate(num_games=20)

print(f'Mean Reward: {final_eval["mean_reward"]:.1f}')
print(f'Max Reward: {final_eval["max_reward"]:.1f}')
print(f'Mean Max Tile: {final_eval["mean_max_tile"]:.1f}')
print(f'Best Tile Achieved: {final_eval["max_tile"]}')
print(f'Mean Game Length: {final_eval["mean_length"]:.1f} moves')
print('=' * 60)
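
The summary mixes per-game averages with best-of-all-games extremes: `mean_max_tile` averages each game's largest tile, while `max_tile` is the single best tile across every evaluation game. A sketch of how such a summary could be aggregated from per-game outcomes; `GameResult` and `summarize` are hypothetical, and how `trainer.evaluate` actually computes its keys is an assumption.

```python
from dataclasses import dataclass
from statistics import mean


@dataclass
class GameResult:
    """Outcome of one evaluation game (hypothetical record type)."""
    total_reward: float
    max_tile: int
    length: int


def summarize(results: list[GameResult]) -> dict[str, float]:
    """Aggregate per-game outcomes into the keys printed by the final evaluation."""
    if not results:
        raise ValueError('need at least one game')
    return {
        'mean_reward': mean(r.total_reward for r in results),
        'max_reward': max(r.total_reward for r in results),
        'mean_max_tile': mean(r.max_tile for r in results),  # average of each game's best tile
        'max_tile': max(r.max_tile for r in results),        # best tile over all games
        'mean_length': mean(r.length for r in results),
    }


games = [GameResult(20_000.0, 2048, 1000), GameResult(12_000.0, 1024, 700)]
print(summarize(games))
```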

## 7. Save Final Model

In [None]:
import glob

# ##>: Save final checkpoint.
final_checkpoint = f'{CHECKPOINT_DIR}/final'
trainer.learner.save_checkpoint(final_checkpoint)
print(f'Final model saved to: {final_checkpoint}')

# ##>: List all checkpoints.
checkpoints = sorted(glob.glob(f'{CHECKPOINT_DIR}/*'))
print(f'\nAll checkpoints ({len(checkpoints)}):')
for cp in checkpoints:
    print(f'  {cp}')
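
`sorted(glob.glob(...))` orders names lexicographically, so `step_10000` sorts before `step_5000`. When picking the newest checkpoint to resume from, a numeric sort on the step number is safer. A sketch assuming checkpoints are named `step_<N>`, as in the resume example in the Notes; `latest_step_checkpoint` is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed naming scheme: one entry per checkpoint called step_<N>.
STEP_PATTERN = re.compile(r'^step_(\d+)$')


def latest_step_checkpoint(paths: list[str]) -> Optional[str]:
    """Return the path named step_<N> with the largest N, or None if there is none."""
    best_step, best_path = -1, None
    for path in paths:
        match = STEP_PATTERN.match(Path(path).name)
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


paths = ['/ck/final', '/ck/step_10000', '/ck/step_5000', '/ck/training_summary.png']
print([p for p in sorted(paths) if 'step_' in p])  # ['/ck/step_10000', '/ck/step_5000']
print(latest_step_checkpoint(paths))               # /ck/step_10000
```

In this notebook the call would be `latest_step_checkpoint(checkpoints)`; if the trainer writes checkpoint files with extensions, the pattern needs adjusting.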

## 8. Training Summary

In [None]:

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# ##>: Total loss.
axes[0, 0].plot(history['total_loss'], 'b-', alpha=0.5, linewidth=0.5)
window = max(1, len(history['total_loss']) // 100)
smoothed = [
    sum(history['total_loss'][max(0, i - window + 1) : i + 1]) / min(i + 1, window)
    for i in range(len(history['total_loss']))
]
axes[0, 0].plot(smoothed, 'r-', linewidth=2)
axes[0, 0].set_title('Total Loss')
axes[0, 0].set_xlabel('Step')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].grid(True, alpha=0.3)

# ##>: Component losses.
axes[0, 1].plot(history['policy_loss'], label='Policy', alpha=0.7)
axes[0, 1].plot(history['value_loss'], label='Value', alpha=0.7)
axes[0, 1].plot(history['reward_loss'], label='Reward', alpha=0.7)
axes[0, 1].plot(history['chance_loss'], label='Chance', alpha=0.7)
axes[0, 1].set_title('Component Losses')
axes[0, 1].set_xlabel('Step')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# ##>: Episode reward.
if history['episode_reward']:
    axes[1, 0].plot(history['episode_reward'], 'g-', alpha=0.5, linewidth=0.5)
    window = max(1, len(history['episode_reward']) // 50)
    smoothed = [
        sum(history['episode_reward'][max(0, i - window + 1) : i + 1]) / min(i + 1, window)
        for i in range(len(history['episode_reward']))
    ]
    axes[1, 0].plot(smoothed, 'darkgreen', linewidth=2)
axes[1, 0].set_title('Episode Reward')
axes[1, 0].set_xlabel('Step')
axes[1, 0].set_ylabel('Reward')
axes[1, 0].grid(True, alpha=0.3)

# ##>: Max tile distribution.
if history['max_tile']:
    tile_counts = {}
    for tile in history['max_tile']:
        tile_counts[tile] = tile_counts.get(tile, 0) + 1
    tiles = sorted(tile_counts.keys())
    counts = [tile_counts[t] for t in tiles]
    axes[1, 1].bar([str(t) for t in tiles], counts, color='purple', alpha=0.7)
    axes[1, 1].set_title('Max Tile Distribution')
    axes[1, 1].set_xlabel('Tile Value')
    axes[1, 1].set_ylabel('Count')
    axes[1, 1].tick_params(axis='x', rotation=45)
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{CHECKPOINT_DIR}/training_summary.png', dpi=150)
plt.show()

print(f'\nSummary plot saved to: {CHECKPOINT_DIR}/training_summary.png')
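
The histogram counts how often each max tile occurred; the counting loop is equivalent to `collections.Counter`. A cumulative view (the share of games reaching at least a given tile) is often easier to compare across runs. A minimal sketch; `tile_reach_rates` is a hypothetical helper, and whether `history['max_tile']` holds one entry per self-play game is an assumption.

```python
from collections import Counter


def tile_reach_rates(max_tiles, thresholds=(512, 1024, 2048, 4096)) -> dict[int, float]:
    """Fraction of games whose largest tile is at least each threshold."""
    if not max_tiles:
        return {t: 0.0 for t in thresholds}
    n = len(max_tiles)
    return {t: sum(tile >= t for tile in max_tiles) / n for t in thresholds}


tiles = [512, 1024, 1024, 2048]
print(Counter(tiles))           # same counts as the tile_counts dict above
print(tile_reach_rates(tiles))  # {512: 1.0, 1024: 0.75, 2048: 0.25, 4096: 0.0}
```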

## 9. Download Checkpoints

Package checkpoints for download from Kaggle.

In [None]:
import shutil

# ##>: Create zip archive of checkpoints.
archive_path = '/kaggle/working/checkpoints_archive'
shutil.make_archive(archive_path, 'zip', CHECKPOINT_DIR)
print(f'Checkpoints archived to: {archive_path}.zip')

# ##>: Show file size.
archive_size = os.path.getsize(f'{archive_path}.zip') / (1024 * 1024)
print(f'Archive size: {archive_size:.1f} MB')

---

## Notes

**Hardware-specific optimizations:**

- **Tesla P100**: Runs under TensorFlow's default distribution strategy with GPU memory growth enabled; batch size is capped at 512 to fit within its 16GB of VRAM.
- **CPU**: Batch size is capped at 128 for memory efficiency.

**Training tips:**

1. Start with `TRAINING_MODE = 'small'` to verify everything works.
2. Use `TRAINING_MODE = 'medium'` for a balance of quality and speed.
3. `TRAINING_MODE = 'full'` uses the paper configuration. The paper trains for 20M steps, so matching it requires raising `NUM_TRAINING_STEPS` and resuming across many sessions.

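**Resuming training:**

Rebuild the trainer by re-running sections 1 to 4, then load a saved checkpoint before calling `trainer.train` again:
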
```python
trainer.load_checkpoint('/kaggle/working/checkpoints/step_10000')
```
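
To continue in a fresh Kaggle session when the checkpoint directory is not already present, the zip from section 9 can be attached to the notebook (for example, as a Kaggle dataset) and unpacked back into `CHECKPOINT_DIR` before loading. A standard-library sketch; `restore_checkpoints` is a hypothetical helper, and the demo uses temporary directories in place of Kaggle paths.

```python
import shutil
import tempfile
from pathlib import Path


def restore_checkpoints(archive_file: str, target_dir: str) -> list[str]:
    """Unpack a checkpoint zip (made with shutil.make_archive) into target_dir."""
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    shutil.unpack_archive(archive_file, extract_dir=str(target), format='zip')
    return sorted(p.name for p in target.iterdir())


# Self-contained demo: build a tiny archive, then restore it elsewhere.
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / 'checkpoints'
    (src / 'step_5000').mkdir(parents=True)
    (src / 'step_5000' / 'weights.txt').write_text('dummy')
    archive = shutil.make_archive(str(Path(tmp) / 'checkpoints_archive'), 'zip', str(src))
    restored = restore_checkpoints(archive, str(Path(tmp) / 'restored'))
    print(restored)  # ['step_5000']
```

In the notebook the call would look like `restore_checkpoints('/kaggle/input/<your-dataset>/checkpoints_archive.zip', CHECKPOINT_DIR)`, where the input path depends on how the archive was attached.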