# Stochastic MuZero Training for 2048 (JAX)

This notebook trains a Stochastic MuZero agent to play 2048 using JAX.

**Supported Hardware:**
- NVIDIA Tesla P100/T4 (Kaggle GPU)
- TPU v3-8 (Kaggle TPU)
- CPU (slower, but works anywhere)

**Paper:** "Planning in Stochastic Environments with a Learned Model" (ICLR 2022)

## 1. Environment Setup

In [None]:
# ##>: Clone repository and checkout the correct branch.
# ##!: Replace with your actual repository URL before running.
REPO_URL = 'https://github.com/YOUR_USERNAME/simulate_2048.git'  # CHANGE THIS
BRANCH = 'muzero'  # The branch containing the training code

!git clone {REPO_URL}
%cd simulate_2048
!git checkout {BRANCH}
!git log --oneline -3  # Verify we're on the correct branch
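
If the clone cell runs with the placeholder URL still in place, it fails with a confusing git error. A small check run before cloning catches that early. This is a minimal sketch: `check_repo_url` is a hypothetical helper and is not part of the repository.

```python
def check_repo_url(url: str) -> str:
    """Return `url` unchanged, or raise if it still contains the placeholder.

    Hypothetical helper for this notebook; not part of the simulate_2048 repository.
    """
    if 'YOUR_USERNAME' in url:
        raise ValueError('REPO_URL still contains the YOUR_USERNAME placeholder; edit it first.')
    if not url.startswith(('https://', 'git@')):
        raise ValueError(f'Unexpected repository URL scheme: {url!r}')
    return url
```

Usage: call `check_repo_url(REPO_URL)` at the top of the clone cell, before `!git clone`.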

In [None]:
# ##>: Install JAX dependencies.
# ##&: Kaggle GPU sessions have CUDA, so install jax[cuda12]; on a TPU session install 'jax[tpu]' instead, and plain 'jax' on CPU.
!pip install -q 'jax[cuda12]' flax optax mctx chex orbax-checkpoint

In [None]:
import os
import sys
from pathlib import Path

# ##>: Add project root to path.
project_root = Path.cwd()
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

print(f'Project root: {project_root}')
print(f'Python version: {sys.version}')

## 2. Hardware Detection

Detect GPU/TPU availability and record the accelerator type (`ACCELERATOR`), which is used to cap the batch size in the training configuration.

In [None]:
import jax
import jax.numpy as jnp

print(f'JAX version: {jax.__version__}')
print(f'Available devices: {jax.devices()}')
print(f'Default backend: {jax.default_backend()}')

In [None]:
# ##>: Detect available hardware.
print('=' * 60)
print('HARDWARE CONFIGURATION')
print('=' * 60)

devices = jax.devices()
backend = jax.default_backend()

if backend == 'gpu':
    print(f'GPU detected: {len(devices)} device(s)')
    for i, dev in enumerate(devices):
        print(f'  - Device {i}: {dev}')
    ACCELERATOR = 'GPU'
elif backend == 'tpu':
    print(f'TPU detected: {len(devices)} core(s)')
    ACCELERATOR = 'TPU'
else:
    print('No GPU/TPU detected, using CPU')
    ACCELERATOR = 'CPU'

print('=' * 60)
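
On a TPU v3-8, `jax.devices()` lists 8 TPU cores. If the trainer splits each batch across devices (an assumption; this notebook does not show how `Trainer` uses multiple devices), the global batch size must divide evenly by the device count. A minimal check, with `per_device_batch` as a hypothetical helper:

```python
def per_device_batch(batch_size: int, num_devices: int) -> int:
    """Return the per-device slice of a global batch split evenly across devices.

    Hypothetical helper: assumes data-parallel training where every device
    receives the same number of examples.
    """
    if num_devices < 1:
        raise ValueError('num_devices must be at least 1')
    if batch_size % num_devices != 0:
        raise ValueError(
            f'batch_size={batch_size} is not divisible by {num_devices} devices'
        )
    return batch_size // num_devices
```

Once the configuration cell below has run, `per_device_batch(config.batch_size, len(devices))` would confirm the split, e.g. 1024 examples across 8 TPU cores gives 128 per core.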

## 3. Training Configuration

Configure hyperparameters based on detected hardware and available runtime.

In [None]:
# ##>: Training configuration.
# ##!: Adjust these based on Kaggle session limits (12h for GPU, 9h for TPU).

# ##>: Choose training mode.
TRAINING_MODE = 'small'  # Options: 'tiny' (testing), 'small' (balanced), 'full' (paper)

# ##>: Training duration.
# ##&: 10k steps is reasonable for a single Kaggle session (~2-4 hours on P100).
NUM_TRAINING_STEPS = 10_000  # Adjust based on available time

# ##>: Checkpointing.
CHECKPOINT_DIR = '/kaggle/working/checkpoints'

# ##>: Random seed for reproducibility.
SEED = 42

print(f'Training mode: {TRAINING_MODE}')
print(f'Training steps: {NUM_TRAINING_STEPS:,}')
print(f'Checkpoint directory: {CHECKPOINT_DIR}')
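
The session limits above bound how many steps fit in one run. Given a measured throughput (the `steps_per_second` value the training cell prints), the budget is throughput × usable seconds, where some time is set aside for buffer filling, evaluation, and archiving. A minimal sketch; `max_steps_for_session` is a hypothetical helper and the one-hour default reserve is an illustrative assumption, not a measured figure.

```python
def max_steps_for_session(
    steps_per_second: float,
    session_hours: float,
    reserve_hours: float = 1.0,
) -> int:
    """Training steps that fit in one session after setting aside `reserve_hours`.

    The reserve covers work outside the training loop (buffer filling,
    evaluation, checkpoint archiving); its default value is an assumption.
    """
    usable_seconds = max(session_hours - reserve_hours, 0.0) * 3600
    return int(steps_per_second * usable_seconds)


# Kaggle GPU session: 12 h limit. At 1.0 step/s, 11 usable hours allow 39,600 steps.
print(max_steps_for_session(1.0, session_hours=12))
```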

In [None]:
from reinforce.training.config import TrainConfig, default_config, small_config, tiny_config


def create_kaggle_config(mode: str, accelerator: str, num_steps: int) -> TrainConfig:
    """
    Create configuration optimized for Kaggle hardware.

    Parameters
    ----------
    mode : str
        Training mode: 'tiny', 'small', or 'full'.
    accelerator : str
        Detected hardware: 'GPU', 'TPU', or 'CPU'.
    num_steps : int
        Total training steps.

    Returns
    -------
    TrainConfig
        Optimized configuration.
    """
    if mode == 'tiny':
        base = tiny_config()
    elif mode == 'full':
        base = default_config()
    else:  # small
        base = small_config()

    # ##>: Adjust for hardware.
    if accelerator == 'GPU':
        batch_size = min(base.batch_size, 512)
    elif accelerator == 'TPU':
        batch_size = min(base.batch_size, 1024)
    else:
        batch_size = min(base.batch_size, 128)

    # ##>: Create new config with adjusted parameters.
    return TrainConfig(
        observation_shape=base.observation_shape,
        action_size=base.action_size,
        codebook_size=base.codebook_size,
        hidden_size=base.hidden_size,
        num_residual_blocks=base.num_residual_blocks,
        num_simulations=base.num_simulations,
        discount=base.discount,
        dirichlet_alpha=base.dirichlet_alpha,
        dirichlet_fraction=base.dirichlet_fraction,
        pb_c_init=base.pb_c_init,
        pb_c_base=base.pb_c_base,
        replay_buffer_size=base.replay_buffer_size,
        min_buffer_size=base.min_buffer_size,
        max_trajectory_length=base.max_trajectory_length,
        batch_size=batch_size,
        num_unroll_steps=base.num_unroll_steps,
        td_steps=base.td_steps,
        td_lambda=base.td_lambda,
        learning_rate=base.learning_rate,
        weight_decay=base.weight_decay,
        max_grad_norm=base.max_grad_norm,
        warmup_steps=base.warmup_steps,
        training_steps=num_steps,
        checkpoint_interval=base.checkpoint_interval,
        log_interval=base.log_interval,
        eval_interval=base.eval_interval,
        eval_games=base.eval_games,
        seed=SEED,
    )


config = create_kaggle_config(TRAINING_MODE, ACCELERATOR, NUM_TRAINING_STEPS)

print('\nConfiguration:')
print(f'  Hidden size: {config.hidden_size}')
print(f'  Residual blocks: {config.num_residual_blocks}')
print(f'  Batch size: {config.batch_size}')
print(f'  MCTS simulations: {config.num_simulations}')
print(f'  Replay buffer size: {config.replay_buffer_size:,}')
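
`create_kaggle_config` copies every field of the preset by hand, so a field later added to `TrainConfig` would fall back to its default instead of the preset's value. If `TrainConfig` is a dataclass (an assumption; its definition is not shown here), `dataclasses.replace` copies the preset and overrides only the changed fields. The sketch uses a hypothetical stand-in class, `DemoConfig`, with a few of the same field names, and a hypothetical `adapt_config` helper.

```python
import dataclasses


# Hypothetical stand-in for TrainConfig; the real class defines many more fields.
@dataclasses.dataclass(frozen=True)
class DemoConfig:
    hidden_size: int = 256
    batch_size: int = 2048
    training_steps: int = 1_000_000
    seed: int = 0


# Per-accelerator batch-size caps, matching create_kaggle_config.
BATCH_CAPS = {'GPU': 512, 'TPU': 1024, 'CPU': 128}


def adapt_config(base: DemoConfig, accelerator: str, num_steps: int, seed: int) -> DemoConfig:
    """Copy `base`, capping the batch size and overriding steps and seed."""
    cap = BATCH_CAPS.get(accelerator, BATCH_CAPS['CPU'])  # unknown hardware gets the CPU cap
    return dataclasses.replace(
        base,
        batch_size=min(base.batch_size, cap),
        training_steps=num_steps,
        seed=seed,
    )


cfg = adapt_config(DemoConfig(), 'GPU', num_steps=10_000, seed=42)
print(cfg)  # DemoConfig(hidden_size=256, batch_size=512, training_steps=10000, seed=42)
```

Every field not passed to `dataclasses.replace` (here `hidden_size`) keeps the preset's value automatically.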

## 4. Initialize Trainer

Create the JAX-based trainer with the configured settings.

In [None]:
from reinforce.training.trainer import Trainer

# ##>: Create checkpoint directory.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ##>: Initialize trainer.
trainer = Trainer(
    config=config,
    checkpoint_dir=CHECKPOINT_DIR,
)

# ##>: Initialize training state.
trainer.initialize(seed=SEED)

print('Trainer initialized successfully.')
print(f'Checkpoints will be saved to: {CHECKPOINT_DIR}')

## 5. Fill Replay Buffer

Generate initial games to populate the replay buffer before training.

In [None]:
print('Generating initial games for replay buffer...')
print(f'Target: {config.min_buffer_size} trajectories')
print()

trainer.fill_buffer(show_progress=True)
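
The fill step generates games until the buffer holds `config.min_buffer_size` trajectories, so training never samples from a nearly empty buffer. The sketch below illustrates that contract with a bounded FIFO buffer and uniform sampling for simplicity. It is not the repository's implementation (whose eviction and sampling rules are not shown here), and `TrajectoryBuffer` is a hypothetical name.

```python
import random
from collections import deque
from typing import Any, List


class TrajectoryBuffer:
    """Hypothetical FIFO replay buffer: the oldest trajectories are evicted first."""

    def __init__(self, capacity: int, min_size: int, seed: int = 0) -> None:
        if min_size > capacity:
            raise ValueError('min_size cannot exceed capacity')
        self._items: deque = deque(maxlen=capacity)
        self._min_size = min_size
        self._rng = random.Random(seed)

    def add(self, trajectory: Any) -> None:
        self._items.append(trajectory)  # deque drops the oldest item when full

    def __len__(self) -> int:
        return len(self._items)

    def ready(self) -> bool:
        """True once enough trajectories exist to start training."""
        return len(self._items) >= self._min_size

    def sample(self, batch_size: int) -> List[Any]:
        """Uniformly sample trajectories with replacement."""
        if not self.ready():
            raise RuntimeError('buffer below min_size; keep generating games')
        return self._rng.choices(list(self._items), k=batch_size)


buffer = TrajectoryBuffer(capacity=3, min_size=2)
for game_id in range(5):
    buffer.add(f'game-{game_id}')
print(len(buffer), buffer.ready())
```

After five additions to a capacity-3 buffer, only `game-2`, `game-3`, and `game-4` remain and can be sampled.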

## 6. Training

Run the main training loop with progress tracking.

In [None]:
print('Starting training...')
print(f'Hardware: {ACCELERATOR}')
print(f'Steps: {NUM_TRAINING_STEPS:,}')
print()

results = trainer.train(
    num_steps=NUM_TRAINING_STEPS,
    show_progress=True,
)

print('\n' + '=' * 50)
print('Training Complete')
print('=' * 50)
print(f'Total steps: {results["total_steps"]}')
print(f'Total time: {results["total_time_seconds"]:.1f}s')
print(f'Steps/second: {results["steps_per_second"]:.2f}')
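
The printed summary is enough to plan follow-up runs: dividing a target step count by the measured `steps_per_second` gives the wall-clock time it would need. A minimal sketch using only a key this notebook already reads from `results`; `hours_for_steps` is a hypothetical helper, and the projection assumes throughput stays constant across runs.

```python
def hours_for_steps(results: dict, target_steps: int) -> float:
    """Wall-clock hours to reach `target_steps` at the throughput in `results`.

    Uses the 'steps_per_second' key printed by the training cell; assumes the
    rate stays constant (same hardware, no extra warmup).
    """
    rate = results['steps_per_second']
    if rate <= 0:
        raise ValueError('steps_per_second must be positive')
    return target_steps / rate / 3600


# Illustrative numbers, not a measured result: 10,000 steps in 2 hours.
example = {'total_steps': 10_000, 'total_time_seconds': 7_200.0, 'steps_per_second': 10_000 / 7_200}
print(f'{hours_for_steps(example, 100_000):.1f} h')
```

With real output, `hours_for_steps(results, 20_000_000)` estimates how long the paper's 20M steps would take on the current hardware.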

## 7. Final Evaluation

In [None]:
print('=' * 60)
print('FINAL EVALUATION (20 games)')
print('=' * 60)

final_eval = trainer.evaluate(num_games=20, show_progress=True)

print(f'\nMean Reward: {final_eval["mean_reward"]:.1f}')
print(f'Max Reward: {final_eval["max_reward"]:.1f}')
print(f'Mean Max Tile: {final_eval["mean_max_tile"]:.1f}')
print(f'Best Tile Achieved: {final_eval["max_tile"]}')
print(f'Mean Game Length: {final_eval["mean_length"]:.1f} moves')
print('=' * 60)
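
`trainer.evaluate` returns summary statistics over the evaluation games. For reference, the sketch below shows one way such a summary can be computed from per-game records. The `(reward, board, length)` record layout and the `summarize_games` name are assumptions for illustration; the repository may compute these fields differently.

```python
from statistics import mean
from typing import Dict, List, Sequence, Tuple

# Hypothetical per-game record: (total reward, final 4x4 board, number of moves).
GameRecord = Tuple[float, Sequence[Sequence[int]], int]


def summarize_games(games: List[GameRecord]) -> Dict[str, float]:
    """Aggregate per-game results into the keys printed by the evaluation cell."""
    if not games:
        raise ValueError('need at least one game')
    rewards = [reward for reward, _, _ in games]
    max_tiles = [max(max(row) for row in board) for _, board, _ in games]
    lengths = [length for _, _, length in games]
    return {
        'mean_reward': mean(rewards),
        'max_reward': max(rewards),
        'mean_max_tile': mean(max_tiles),
        'max_tile': max(max_tiles),
        'mean_length': mean(lengths),
    }


board_a = [[2, 4, 8, 16], [0, 0, 2, 512], [0, 0, 0, 4], [0, 0, 0, 2]]
board_b = [[1024, 2, 0, 0], [4, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 2]]
print(summarize_games([(5000.0, board_a, 300), (9000.0, board_b, 500)]))
```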

## 8. Save Final Model

In [None]:
import glob

# ##>: Save final checkpoint.
trainer.save_checkpoint()
print('Final checkpoint saved.')

# ##>: List all checkpoints.
checkpoints = sorted(glob.glob(f'{CHECKPOINT_DIR}/*'))
print(f'\nAll checkpoints ({len(checkpoints)}):')
for cp in checkpoints:
    print(f'  {cp}')
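
To resume in a later session, the newest checkpoint has to be picked out of `CHECKPOINT_DIR`. Sorting paths as strings puts `step_1000` before `step_200`, so a numeric sort on the step is safer. The sketch assumes each entry's name ends in its step number (for example `1000` or `step_1000`); the trainer's actual naming is not shown here, so compare against the listing above before relying on it. `latest_checkpoint` is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Optional


def latest_checkpoint(checkpoint_dir: str) -> Optional[Path]:
    """Return the entry whose trailing integer (the step) is largest, or None.

    Assumes names end in the step number, e.g. '1000' or 'step_1000'.
    Entries without a trailing number are ignored.
    """
    best_step, best_path = -1, None
    for path in Path(checkpoint_dir).iterdir():
        match = re.search(r'(\d+)$', path.name)
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path
```

Usage: `latest_checkpoint(CHECKPOINT_DIR)` returns `None` when no numbered entry exists.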

## 9. Download Checkpoints

Package checkpoints for download from Kaggle.

In [None]:
import shutil

# ##>: Create zip archive of checkpoints.
archive_path = '/kaggle/working/checkpoints_archive'
shutil.make_archive(archive_path, 'zip', CHECKPOINT_DIR)
print(f'Checkpoints archived to: {archive_path}.zip')

# ##>: Show file size.
archive_size = os.path.getsize(f'{archive_path}.zip') / (1024 * 1024)
print(f'Archive size: {archive_size:.1f} MB')
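
Before downloading, it is worth confirming that the archive actually contains the checkpoint files. `zipfile` can list and CRC-check the members of the archive created by `shutil.make_archive`. A minimal sketch; `archive_members` is a hypothetical helper.

```python
import zipfile
from typing import List


def archive_members(zip_path: str) -> List[str]:
    """List file entries in a zip archive and raise if any member is corrupt."""
    with zipfile.ZipFile(zip_path) as archive:
        bad = archive.testzip()  # name of the first member with a bad CRC, or None
        if bad is not None:
            raise zipfile.BadZipFile(f'corrupt member: {bad}')
        # Directory entries end with '/'; keep only files.
        return [name for name in archive.namelist() if not name.endswith('/')]
```

Usage: `print(archive_members(f'{archive_path}.zip'))` prints paths relative to `CHECKPOINT_DIR`.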

---

## Notes

**Hardware-specific optimizations:**

- **Tesla P100/T4**: CUDA backend; batch size capped at 512.
- **TPU**: TPU backend; batch size capped at 1024 for better utilization of the cores.
- **CPU**: batch size capped at 128 to limit memory use.

**Training tips:**

1. Start with `TRAINING_MODE = 'tiny'` to verify everything works.
2. Use `TRAINING_MODE = 'small'` for a balance of quality and speed.
3. `TRAINING_MODE = 'full'` uses the paper's settings. The paper trains for 20M steps, which requires many Kaggle sessions.

**JAX advantages:**

- JIT compilation for faster execution after warmup
- Automatic vectorization with vmap
- Native TPU support
- Pure functions with explicit PRNG keys, which make runs easier to reproduce
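
Several of these points fit in a few lines. `jax.vmap` lifts a per-board function to a batch, `jax.jit` compiles it on the first call, and sampling twice from the same PRNG key yields identical boards. This is a small illustrative sketch: `max_tile` and `random_boards` are hypothetical helpers, not functions from the repository, and the boards are random rather than legal 2048 positions.

```python
import jax
import jax.numpy as jnp


def max_tile(board: jax.Array) -> jax.Array:
    """Largest tile on one 4x4 board (0 marks an empty cell)."""
    return jnp.max(board)


# vmap maps max_tile over the leading batch axis; jit compiles the result
# on the first call and reuses the compiled code for same-shaped inputs.
batched_max_tile = jax.jit(jax.vmap(max_tile))


def random_boards(key: jax.Array, n: int) -> jax.Array:
    """Sample n boards whose cells are 0 or a power of two up to 2048."""
    exponents = jax.random.randint(key, (n, 4, 4), minval=0, maxval=12)
    return jnp.where(exponents == 0, 0, 2 ** exponents)


key = jax.random.PRNGKey(42)
board_key, _ = jax.random.split(key)
boards = random_boards(board_key, 8)
same_boards = random_boards(board_key, 8)  # same key, same samples
print(batched_max_tile(boards))
```

Because randomness flows only through `board_key`, rerunning the cell with the same seed reproduces the same boards.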