<a href="https://colab.research.google.com/github/sile16/bgai/blob/main/bg_training_run.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hello World, TurboZero Backgammon 🏁

In [1]:
# prompt: ERROR: Could not install packages due to an OSError: [Errno 2] No such file or directory: '/tmp/pgx'

#%pip install --upgrade pip
#%pip install -U git+https://github.com/sile16/pgx.git@master



In [2]:
# Sanity check: show which pgx build the kernel actually imports.
import pgx
print(f"PGX Version {pgx.__version__}")

PGX Version 2.5.10


In [3]:
#%pip install git+https://github.com/sile16/turbozero.git

In [4]:
import jax
import pgx
import pgx.backgammon as bg
from IPython.display import HTML, display
# NOTE(review): `xla_bridge` is not used in any visible cell; it is kept only in
# case a later cell relies on it — confirm and drop if unused.
from jax.lib import xla_bridge

# Report the JAX build and the backend it selected.
print("Jax Version: ",jax.__version__)
#jax.config.update('jax_platform_name', 'gpu')
print("Default backend:", jax.default_backend())

print(f"Jax Version {jax.__version__}")


# Shared backgammon environment (short_game=True) used by the cells below.
env = bg.Backgammon(short_game=True)
print(env.simple_doubles)
print(env.num_actions)
print(env.stochastic_action_probs)

# Build an initial state from a fixed seed and render the board as SVG.
key = jax.random.PRNGKey(0)
state = env.init(key)
display(HTML(state.to_svg()))





Jax Version:  0.6.2
Default backend: gpu
Jax Version 0.6.2
False
156
[0.02777778 0.02777778 0.02777778 0.02777778 0.02777778 0.02777778
 0.05555556 0.05555556 0.05555556 0.05555556 0.05555556 0.05555556
 0.05555556 0.05555556 0.05555556 0.05555556 0.05555556 0.05555556
 0.05555556 0.05555556 0.05555556]


In [5]:
from core.types import StepMetadata
%psource StepMetadata

@chex.dataclass(frozen=True)
class StepMetadata:
    """Metadata for a step in the environment.
    - `rewards`: rewards received by the players
    - `action_mask`: mask of valid actions
    - `terminated`: whether the environment is terminated
    - `cur_player_id`: current player id
    - `step`: step number
    """
    rewards: chex.Array
    action_mask: chex.Array
    terminated: bool
    cur_player_id: int
    step: int


In [6]:
import chex
from typing import Tuple

def step_fn(state: bg.State, action: int, key: chex.PRNGKey) -> Tuple[bg.State, StepMetadata]:
    """Advance the backgammon environment by one step.

    When `state._is_stochastic` is set the state is advanced with
    `env.stochastic_step(state, action)`; otherwise with
    `env.step(state, action, key)`. `env` is the module-level environment.

    Returns:
        A `(new_state, metadata)` pair, where `metadata` is a `StepMetadata`
        filled from fields of `new_state`.
    """
    # Both branches receive the same (state, action, key) operand tuple; only
    # the non-stochastic branch actually consumes the key.
    next_state = jax.lax.cond(
        state._is_stochastic,
        lambda args: env.stochastic_step(args[0], args[1]),
        lambda args: env.step(*args),
        (state, action, key),
    )

    return next_state, StepMetadata(
        rewards=next_state.rewards,
        action_mask=next_state.legal_action_mask,
        terminated=next_state.terminated,
        cur_player_id=next_state.current_player,
        step=next_state._step_count,
    )

In [7]:
def init_fn(key):
    """Create a fresh environment state and the `StepMetadata` describing it.

    The state returned by `env.init(key)` is used unchanged; its stochastic
    flag is left exactly as the environment set it.
    """
    initial_state = env.init(key)
    initial_metadata = StepMetadata(
        rewards=initial_state.rewards,
        action_mask=initial_state.legal_action_mask,
        terminated=initial_state.terminated,
        cur_player_id=initial_state.current_player,
        step=initial_state._step_count,
    )
    return initial_state, initial_metadata

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn

# Define a dense residual block for vector inputs.
class ResidualDenseBlock(nn.Module):
    """Two Dense + LayerNorm layers wrapped in a skip connection."""
    features: int  # output width of both Dense layers

    @nn.compact
    def __call__(self, x):
        # Keep the input around for the skip connection.
        skip = x

        # First layer: Dense -> LayerNorm -> ReLU.
        hidden = nn.Dense(self.features)(x)
        hidden = nn.LayerNorm()(hidden)
        hidden = nn.relu(hidden)

        # Second layer: Dense -> LayerNorm; the ReLU is applied after the add.
        hidden = nn.Dense(self.features)(hidden)
        hidden = nn.LayerNorm()(hidden)

        # NOTE(review): the add presumably needs the input's trailing dimension
        # to match `features` — confirm against callers.
        return nn.relu(hidden + skip)

# Updated ResNet-style network with 6-way value head for backgammon outcomes.
class ResNetTurboZero(nn.Module):
    """Dense residual tower with a policy head and a 6-way value head.

    The value head emits logits for six backgammon outcomes:
    [win, gammon_win, backgammon_win, loss, gammon_loss, backgammon_loss]
    """
    num_actions: int              # width of the policy head, e.g. 156 for backgammon
    num_hidden: int = 256         # width of the stem and of every residual block
    num_blocks: int = 6           # how many residual blocks are stacked
    value_head_out_size: int = 6  # number of value logits (outcome classes)

    @nn.compact
    def __call__(self, x, train: bool = False):
        # `train` is accepted for interface compatibility; nothing below reads it.

        # Stem: Dense -> LayerNorm -> ReLU up to the trunk width.
        trunk = nn.Dense(self.num_hidden)(x)
        trunk = nn.LayerNorm()(trunk)
        trunk = nn.relu(trunk)

        # Residual tower.
        for _block_index in range(self.num_blocks):
            trunk = ResidualDenseBlock(self.num_hidden)(trunk)

        # Heads: both read the same trunk features and both return raw logits.
        # The policy head is created first, then the value head.
        action_logits = nn.Dense(self.num_actions)(trunk)
        outcome_logits = nn.Dense(self.value_head_out_size)(trunk)
        return action_logits, outcome_logits


# Network instance shared by the evaluators and the trainer; the policy head is
# sized to the environment's action count.
resnet_model = ResNetTurboZero(
    num_actions=env.num_actions,
    num_hidden=256,
    num_blocks=6,
)

In [9]:
def state_to_nn_input(state):
    """Return `state.observation` unchanged, for use as the network input."""
    nn_input = state.observation
    return nn_input

In [10]:
from core.evaluators.evaluation_fns import make_nn_eval_fn
from core.evaluators.mcts.action_selection import PUCTSelector
from core.evaluators.mcts.stochastic_mcts import StochasticMCTS
import jax.numpy as jnp

# Self-play (training) evaluator: StochasticMCTS whose eval_fn wraps resnet_model.
# The search budget was cut (300 -> 100 iterations) for faster collection.
evaluator = StochasticMCTS(
    eval_fn=make_nn_eval_fn(resnet_model, state_to_nn_input),
    stochastic_action_probs=env.stochastic_action_probs,  # probabilities of the environment's stochastic actions
    num_iterations=100,  # Reduced from 300 for faster collection
    max_nodes=400,       # Reduced from 500
    branching_factor=env.num_actions,
    action_selector=PUCTSelector(),
    temperature=1.0,     # NOTE(review): presumably superseded by the schedule given to trainer.set_temp_fn — confirm
)

In [11]:
# Test evaluator: same search budget as the training evaluator (100 iterations,
# 400 nodes); the only difference is temperature=0.0.
evaluator_test = StochasticMCTS(
    eval_fn=make_nn_eval_fn(resnet_model, state_to_nn_input),
    stochastic_action_probs=env.stochastic_action_probs,
    num_iterations=100,  # Currently equal to the training evaluator; raise for a stronger test-time search
    max_nodes=400,
    branching_factor=env.num_actions,
    action_selector=PUCTSelector(),
    temperature=0.0,  # Greedy for testing
)

In [None]:
from core.evaluators.evaluation_fns import make_nn_eval_fn_no_params_callable
import chex
import sys
from pathlib import Path

# Make the parent of the current working directory importable so that the
# `bgai` package can be found.
bgai_root = Path().resolve().parent
if sys.path.count(str(bgai_root)) == 0:
    sys.path.insert(0, str(bgai_root))

# --- Import pip count eval from bgai (now returns 6-way value logits) ---
from bgai.bgevaluators import backgammon_pip_count_eval


# Baseline evaluator: StochasticMCTS driven by the pip-count eval function
# instead of the neural network.
pip_count_mcts_evaluator_test = StochasticMCTS(
    eval_fn=backgammon_pip_count_eval, # Pip count eval fn (now returns 6-way logits)
    stochastic_action_probs=env.stochastic_action_probs,
    num_iterations=100, # Same iteration count as the NN evaluators
    max_nodes=200,      # Half the node budget of the NN evaluators (400)
    branching_factor=env.num_actions,
    action_selector=PUCTSelector(),
    temperature=0.0 # Deterministic action selection for testing
)

# --- GnuBG Evaluator (for testing against strong baseline) ---
from bgai.gnubg_evaluator import GnubgEvaluator

# Built on the shared `env`; passed to the 'gnubg_baseline' tester when the trainer is constructed.
gnubg_evaluator = GnubgEvaluator(env=env)

In [13]:
from core.memory.replay_memory import EpisodeReplayBuffer

# Replay buffer handed to the trainer as `memory_buffer`; capacity was raised
# from 2000 to 4000 for more diverse training data.
replay_memory = EpisodeReplayBuffer(capacity=4000)  # Increased from 2000

## Trainer Initialization
Now that we have all the proper pieces defined, we are ready to initialize a Trainer and start training!

The `Trainer` takes many parameters, so let's walk through them all:
* `batch_size`: # of parallel environments used to collect self-play games
* `train_batch_size`: size of minibatch used during training step
* `warmup_steps`: # of steps (per batch) to collect via self-play prior to entering the training loop. This is used to populate the replay memory with some initial samples
* `collection_steps_per_epoch`: # of steps (per batch) to collect via self-play per epoch
* `train_steps_per_epoch`: # of train steps per epoch
* `nn`: neural network (`linen.Module`)
* `loss_fn`: loss function used for training, we use a provided default loss which implements the loss function used in the `AlphaZero` paper
* `optimizer`: an `optax` optimizer used for training
* `evaluator`: the `Evaluator` to use during self-play; we initialized ours using `StochasticMCTS`
* `memory_buffer`: the memory buffer used to store samples from self-play games; we initialized ours using `EpisodeReplayBuffer`
* `max_episode_steps`: maximum number of steps/turns to allow before truncating an episode
* `env_step_fn`: environment step function (we defined ours above)
* `env_init_fn`: environment init function (we defined ours above)
* `state_to_nn_input_fn`: function to convert environment state to nn input (we defined ours above)
* `testers`: any number of `Tester`s, each of which evaluates a given model and takes its own parameters. We'll use the two baseline evaluators defined above (pip count and GnuBG) to initialize two Testers.
* `evaluator_test`: (Optional) Evaluator used within Testers. By default this is `evaluator`, but sometimes you may want to test with, for example, a larger MCTS iteration budget or a different move-sampling temperature
* `data_transform_fns`: (optional) list of data transform functions to apply to self-play experiences (e.g. rotation, reflection, etc.)
* `extract_model_params_fn`: (Optional) in special cases we need to define how to extract all model parameters from a flax `TrainState`. The default function handles BatchNorm, but if another special-case technique applied across batches is used (e.g. Dropout) we would need to define a function to extract the appropriate parameters. You usually won't need to define this!
* `wandb_project_name`: (Optional) Weights and Biases project name. You will be prompted to login if a name is provided. If a name is provided, a run will be initialized and loss and other metrics will be logged to the given wandb project.
* `ckpt_dir`: (Optional) directory to store checkpoints in, by default this is set to `/tmp/turbozero_checkpoints`
* `max_checkpoints`: (Optional) maximum number of most-recent checkpoints to retain (default: 2)
* `num_devices`: (Optional) number of hardware accelerators (GPUs/TPUs) to use. If not given, all available hardware accelerators are used
* `wandb_run`: (Optional) continues from an initialized `wandb` run if provided, otherwise a new one is initialized
* `extra_wandb_config`: (Optional) any extra metadata to store in the `wandb` run config

A training epoch is comprised of M collection steps, followed by N training steps sampling minibatches from replay memory. Optionally, any number of Testers evaluate the current model. At the end of each epoch, a checkpoint is saved.

If you are using one or more GPUs (recommended), TurboZero will by default run on all of your available hardware.

In [None]:
from functools import partial
from core.testing.two_player_baseline import TwoPlayerBaseline
from core.training.loss_fns import az_default_loss_fn
from core.training.stochastic_train import StochasticTrainer
from core.training.train import Trainer
import optax

# =============================================================================
# TRAINING CONFIGURATION FOR RTX 4090 - 20 EPOCHS (~6-12 hours)
# =============================================================================
# Key changes from original:
# - batch_size: 64 -> 256 (more parallelism, 4090 has 24GB)
# - train_batch_size: 32 -> 128 (better GPU utilization)
# - collection_steps_per_epoch: 2048 -> 4096 (more data per epoch)
# - train_steps_per_epoch: 1024 -> 2048 (more training per epoch)
# - MCTS iterations: 300 -> 100 (faster collection, still good policy)
# - Neural net: 2 blocks -> 6 blocks, 128 -> 256 hidden (more capacity)
# - Replay buffer: 2000 -> 4000 (more diversity)
# - Temperature decay: 1.0 -> 0.2 over training (exploration -> exploitation)
# =============================================================================

# Build the trainer from the pieces defined in the cells above.
# NOTE: batch_size and collection_steps_per_epoch are repeated further down as
# BATCH_SIZE / COLLECTION_STEPS_PER_EPOCH for the temperature schedule — keep
# both places in sync when tuning.
trainer = StochasticTrainer(
    batch_size=256,                   # Increased from 64 (4090 can handle it)
    train_batch_size=128,             # Increased from 32
    warmup_steps=0,                   # No warm-up collection before the training loop
    collection_steps_per_epoch=4096,  # Increased from 2048
    train_steps_per_epoch=2048,       # Increased from 1024
    nn=resnet_model,
    loss_fn=partial(az_default_loss_fn, l2_reg_lambda=1e-4),
    optimizer=optax.adam(3e-4),       # Slightly higher LR for faster convergence
    evaluator=evaluator,
    memory_buffer=replay_memory,
    max_episode_steps=500,            # Reduced from 1000 (games rarely go this long)
    env_step_fn=step_fn,
    env_init_fn=init_fn,
    state_to_nn_input_fn=state_to_nn_input,
    ckpt_dir="/tmp/ckpts",            # NOTE(review): /tmp is usually not persistent — confirm checkpoints need not survive a reboot
    testers=[
        # Each tester evaluates against one of the baseline evaluators.
        TwoPlayerBaseline(
            num_episodes=64,          # More episodes for reliable eval
            baseline_evaluator=pip_count_mcts_evaluator_test,
            name='pip_count_baseline'
        ),
        TwoPlayerBaseline(
            num_episodes=16,          # Fewer episodes (gnubg is slower, not batched)
            baseline_evaluator=gnubg_evaluator,
            name='gnubg_baseline'
        ),
    ],
    evaluator_test=evaluator_test,
    data_transform_fns=[],            # No data transforms applied to self-play experiences
    wandb_project_name="bgai-training"  # Enable wandb logging
)

# =============================================================================
# TEMPERATURE DECAY SCHEDULE
# =============================================================================
# Decay from 1.0 (exploration) to 0.2 (mostly exploitation) over training
# This helps early training explore diverse moves, then focus on best moves later
# =============================================================================
# Run-length constants used to size the temperature decay horizon. They mirror
# the batch_size / collection_steps_per_epoch values given to the trainer above.
NUM_EPOCHS = 20
COLLECTION_STEPS_PER_EPOCH = 4096
BATCH_SIZE = 256
TOTAL_STEPS = NUM_EPOCHS * COLLECTION_STEPS_PER_EPOCH * BATCH_SIZE

def temperature_schedule(step):
    """Linear decay from 1.0 to 0.2 over training.

    Args:
        step: Step count supplied by the trainer. Progress through training is
            measured as ``step / TOTAL_STEPS``.

    Returns:
        The sampling temperature: 1.0 at (or before) step 0, falling linearly
        to 0.2 at ``TOTAL_STEPS`` and staying there afterwards.
    """
    start_temp = 1.0
    end_temp = 0.2
    # Clamp on both sides: a step past the horizon must not push the
    # temperature below end_temp, and a negative step must not push it above
    # start_temp.
    progress = min(max(step / TOTAL_STEPS, 0.0), 1.0)
    return start_temp - (start_temp - end_temp) * progress

# Register the schedule with the trainer and report the decay horizon.
trainer.set_temp_fn(temperature_schedule)
print(f"Temperature schedule: 1.0 -> 0.2 over {TOTAL_STEPS:,} steps")

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msile16[0m ([33msile16-self[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Run training for NUM_EPOCHS (20) epochs, evaluating every 2 epochs for frequent feedback.
# NUM_EPOCHS also sizes TOTAL_STEPS for the temperature schedule, so reusing it
# here keeps the decay horizon and the actual run length from drifting apart.
output = trainer.train_loop(seed=0, num_epochs=NUM_EPOCHS, eval_every=2)

Training


2025-11-29 15:59:31.798522: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-11-29 15:59:31.798564: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-11-29 15:59:31.798581: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-11-29 15:59:31.798604: W external/xla/xla/service/gpu/au

Training Done
Epoch 0: {'grad_norm': '8.3523', 'loss': '2.4573', 'policy_accuracy': '0.4971', 'policy_loss': '1.1759', 'value_loss': '0.6902'}
Testing
Epoch 0: {'pip_count_baseline_avg_outcome': '-0.2188'}


ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/home/sile/.pyenv/versions/3.12.10/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x7b852af31400> is already entered
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/home/sile/.pyenv/versions/3.12.10/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x7b852af31400> is already entered
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-2' coro=<_async_in_context.<locals>.run_in_context() done, defined at /home/sile/.pyenv/versions/3.12.10/lib/python3.12/site-packages/ipykernel/utils.py:57> wait_for=<Task pending name='T

Epoch 0: {'gnubg_baseline_avg_outcome': '0.6250'}
Temperature: 0.96
Collecting self-play games
Training


ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/home/sile/.pyenv/versions/3.12.10/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x7b852af31400> is already entered
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/home/sile/.pyenv/versions/3.12.10/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x7b852af31400> is already entered
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/home/sile/.pyenv/versions/3.12.10/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeE

Training Done
Epoch 1: {'grad_norm': '5.5510', 'loss': '2.1993', 'policy_accuracy': '0.5562', 'policy_loss': '1.0733', 'value_loss': '0.6478'}
Temperature: 0.9199999999999999
Collecting self-play games
Training


ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-31' coro=<Kernel.shell_main() running at /home/sile/.pyenv/versions/3.12.10/lib/python3.12/site-packages/ipykernel/kernelbase.py:590> cb=[Task.__wakeup()]>
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-32' coro=<_async_in_context.<locals>.run_in_context() running at /home/sile/.pyenv/versions/3.12.10/lib/python3.12/site-packages/ipykernel/utils.py:60> wait_for=<Task pending name='Task-980' coro=<Kernel.shell_main() running at /home/sile/.pyenv/versions/3.12.10/lib/python3.12/site-packages/ipykernel/kernelbase.py:590> cb=[Task.__wakeup()]> cb=[ZMQStream._run_callback.<locals>._log_error() at /home/sile/.pyenv/versions/3.12.10/lib/python3.12/site-packages/zmq/eventloop/zmqstream.py:563]>
Exception ignored in: <coroutine object Kernel.shell_main at 0x7b7c63484a40>
Traceback (most recent call last):
  File "<string>", line 1, in <lambda>
KeyError: '__import__'
Exception i

Training Done
Epoch 2: {'grad_norm': '3.7136', 'loss': '2.0415', 'policy_accuracy': '0.6159', 'policy_loss': '1.0025', 'value_loss': '0.6460'}
Testing
Epoch 2: {'pip_count_baseline_avg_outcome': '-0.0781'}
Epoch 2: {'gnubg_baseline_avg_outcome': '-0.8125'}
Temperature: 0.88
Collecting self-play games
Training
Training Done
Epoch 3: {'grad_norm': '2.7637', 'loss': '1.9250', 'policy_accuracy': '0.6197', 'policy_loss': '1.0007', 'value_loss': '0.6051'}
Temperature: 0.84
Collecting self-play games
Training
Training Done
Epoch 4: {'grad_norm': '2.1672', 'loss': '1.8159', 'policy_accuracy': '0.6406', 'policy_loss': '0.9832', 'value_loss': '0.5779'}
Testing
Epoch 4: {'pip_count_baseline_avg_outcome': '-0.0469'}
Epoch 4: {'gnubg_baseline_avg_outcome': '0.4375'}
Temperature: 0.8
Collecting self-play games
Training
Training Done
Epoch 5: {'grad_norm': '1.8031', 'loss': '1.7868', 'policy_accuracy': '0.6401', 'policy_loss': '0.9973', 'value_loss': '0.5869'}
Temperature: 0.76
Collecting self-play g