# Hello World, TurboZero Backgammon 🏁

`turbozero` provides a vectorized implementation of AlphaZero. 

In a nutshell, this means we can massively speed up training by collecting many self-play games and running Monte Carlo Tree Search in parallel across one or more GPUs!

As the user, you just need to provide:
* environment dynamics functions (step and init) that adhere to the TurboZero spec
* a conversion function for environment state -> neural net input
* and a few hyperparameters!

TurboZero takes care of the rest. 😀 

## Getting Started

Follow the instructions in the repo readme to properly install dependencies and set up your environment.

## Environments

In order to take advantage of the batched implementation of AlphaZero, we need to pair it with a vectorized environment.

Fortunately, there are many great vectorized RL environment libraries; one I like in particular is [pgx](https://github.com/sotetsuk/pgx).

In [1]:
# Only if needed
#!pip install git+https://github.com/sile16/turbozero.git

import sys
print(f"Python version {sys.version}")

import jax
# jax.config.update('jax_platform_name', 'gpu')  # uncomment to force the GPU backend

import pgx
import pgx.backgammon as bg
import os

#print(f"Env: {os.environ}")
print(f"Jax Version: {jax.__version__}, Backend: {jax.default_backend()} ")
print(f"PGX Version {pgx.__version__}")




env = bg.Backgammon(short_game=True)

print(f"Backgammon version: {env.version} Num Actions: {env.num_actions} Simple Doubles: {env.simple_doubles}" )
print(env.stochastic_action_probs)

# create key
key = jax.random.PRNGKey(0)
state = env.init(key)
from IPython.display import HTML, display
display(HTML(state.to_svg()))





Python version 3.12.10 (main, May 21 2025, 10:26:13) [GCC 13.3.0]
Jax Version: 0.6.0, Backend: gpu 
PGX Version 2.5.10
Backgammon version: v2 Num Actions: 156 Simple Doubles: False
[0.02777778 0.02777778 0.02777778 0.02777778 0.02777778 0.02777778
 0.05555556 0.05555556 0.05555556 0.05555556 0.05555556 0.05555556
 0.05555556 0.05555556 0.05555556 0.05555556 0.05555556 0.05555556
 0.05555556 0.05555556 0.05555556]
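
The 21 probabilities printed above are consistent with the 21 distinct rolls of two dice: six doubles at 1/36 each, followed by fifteen non-doubles at 2/36 each. The snippet below rebuilds that distribution from first principles. The ordering (doubles first) is inferred from the printed array and is an assumption about how `pgx` indexes its chance outcomes.

```python
from itertools import combinations_with_replacement

# Enumerate unordered rolls of two six-sided dice: (1, 1), (1, 2), ..., (6, 6).
rolls = list(combinations_with_replacement(range(1, 7), 2))

def roll_prob(roll):
    """A double can occur one way out of 36; any other roll can occur two ways."""
    a, b = roll
    return (1 if a == b else 2) / 36

# Assumed ordering: doubles first, then non-doubles (matches the printed array).
doubles = [r for r in rolls if r[0] == r[1]]
non_doubles = [r for r in rolls if r[0] != r[1]]
dice_probs = [roll_prob(r) for r in doubles + non_doubles]

print(len(dice_probs))            # number of distinct rolls
print(round(sum(dice_probs), 6))  # the probabilities sum to one
```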


## Environment Dynamics

TurboZero needs to interface with the environment in order to build search trees and collect self-play episodes.

We can define this interface with the following functions:
* `env_step_fn`: given an environment state and an action, return the new environment state and its metadata
```python
    EnvStepFn = Callable[[chex.ArrayTree, int], Tuple[chex.ArrayTree, StepMetadata]]
```
* `env_init_fn`: given a key, initialize and return a new environment state and its metadata
```python
    EnvInitFn = Callable[[chex.PRNGKey], Tuple[chex.ArrayTree, StepMetadata]]
```
Fortunately, environment libraries implement these for us! We just need to extract a few key pieces of information from the environment state so that we can match the TurboZero specification. We store this in a `StepMetadata` object:

In [2]:
from core.types import StepMetadata
#%psource StepMetadata

* `rewards`: the rewards emitted for each player at the given timestep
* `action_mask`: a mask across all possible actions, where legal actions are set to `True` and illegal actions are set to `False`
* `terminated`: `True` if the environment has terminated
* `cur_player_id`: id of the current player
* `step`: step number
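
To illustrate how an `action_mask` is typically consumed, the sketch below removes illegal actions from a policy before normalizing. It is a generic plain-Python illustration, not TurboZero's internal code, and `masked_softmax` is a hypothetical helper name.

```python
import math

def masked_softmax(logits, action_mask):
    """Softmax over legal actions only; illegal actions get probability 0."""
    legal = [l for l, ok in zip(logits, action_mask) if ok]
    if not legal:
        raise ValueError("at least one action must be legal")
    peak = max(legal)  # subtract the max for numerical stability
    exps = [math.exp(l - peak) if ok else 0.0
            for l, ok in zip(logits, action_mask)]
    total = sum(exps)
    return [e / total for e in exps]

# Actions 1 and 3 are illegal, so they receive no probability mass
# even though action 3 has the largest logit.
policy = masked_softmax([2.0, 1.0, 0.5, 3.0], [True, False, True, False])
print(policy)
```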

We can define the environment interface for `Backgammon` as follows:

In [3]:
import chex
from core.bgcommon import bg_step_fn
from functools import partial
step_fn = partial(bg_step_fn, env)

def init_fn(key):
    """Initializes a new environment state."""
    state = env.init(key)
    # No need to force non-stochastic, let the environment handle it
    return state, StepMetadata(
        rewards=state.rewards,
        action_mask=state.legal_action_mask,
        terminated=state.terminated,
        cur_player_id=state.current_player,
        step=state._step_count
    )

Pretty easy!
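
The same contract applies to any environment. The sketch below wires up a deliberately tiny, hypothetical counting game (not part of TurboZero or `pgx`) to show the shape of the two functions: each returns the new state together with a metadata record. `ToyMetadata` is a stand-in for `StepMetadata` that mirrors the five fields listed above.

```python
from typing import NamedTuple, Tuple

class ToyMetadata(NamedTuple):
    """Stand-in for StepMetadata with the same five fields (illustrative only)."""
    rewards: Tuple[float, float]
    action_mask: Tuple[bool, ...]
    terminated: bool
    cur_player_id: int
    step: int

class ToyState(NamedTuple):
    total: int   # running total shared by both players
    player: int  # whose turn it is (0 or 1)
    step: int

TARGET = 5  # the player who brings the total to TARGET or beyond wins

def _metadata(state, rewards=(0.0, 0.0)):
    done = state.total >= TARGET
    return ToyMetadata(
        rewards=rewards,
        action_mask=(not done, not done),  # two actions: add 1 or add 2
        terminated=done,
        cur_player_id=state.player,
        step=state.step,
    )

def toy_init_fn(key=None):
    """Same shape as EnvInitFn: key -> (state, metadata)."""
    state = ToyState(total=0, player=0, step=0)
    return state, _metadata(state)

def toy_step_fn(state, action):
    """Same shape as EnvStepFn: (state, action) -> (state, metadata)."""
    total = state.total + action + 1  # action 0 adds 1, action 1 adds 2
    new_state = ToyState(total, 1 - state.player, state.step + 1)
    rewards = (0.0, 0.0)
    if total >= TARGET:  # the player who just moved wins
        rewards = (1.0, -1.0) if state.player == 0 else (-1.0, 1.0)
    return new_state, _metadata(new_state, rewards)

state, meta = toy_init_fn()
while not meta.terminated:
    state, meta = toy_step_fn(state, action=1)  # always add 2
print(meta)
```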

## Neural Network

Next, we'll need to define the architecture of the neural network.

A simple implementation of the residual neural network used in the _AlphaZero_ paper is included for your convenience. 

You can implement your own architecture using `flax.linen`.

In [4]:
import jax
import jax.numpy as jnp
import flax.linen as nn

# Pre‑activation ResNet‑V2 block
class ResBlockV2(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        r = x
        x = nn.LayerNorm()(x)
        x = nn.relu(x)
        x = nn.Dense(self.features, use_bias=False)(x)
        x = nn.LayerNorm()(x)
        x = nn.relu(x)
        x = nn.Dense(self.features, use_bias=False)(x)
        return x + r

class ResNetTurboZero(nn.Module):
    num_actions: int            # 156 here
    hidden_dim: int = 256
    num_blocks: int = 10

    @nn.compact
    def __call__(self, x, train: bool = False):
        # 1) ResNet tower
        x = nn.Dense(self.hidden_dim, use_bias=False)(x)
        x = nn.LayerNorm()(x)
        x = nn.relu(x)
        for _ in range(self.num_blocks):
            x = ResBlockV2(self.hidden_dim)(x)

        # 2) Policy head: single Dense into 156 logits
        policy_logits = nn.Dense(self.num_actions)(x)

        # 3) Value head
        v = nn.LayerNorm()(x)
        v = nn.relu(v)
        v = nn.Dense(1)(v)
        v = jnp.squeeze(v, -1)

        return policy_logits, v
    

resnet_model = ResNetTurboZero(
    num_actions=env.num_actions,  # i.e. micro_steps*(micro_src + micro_die)
    hidden_dim=256,
    num_blocks=10
)

from core.networks.mlp import MLPConfig, MLP

# Alternative: a small MLP network (defined for comparison; `model` is set to the ResNet below)
mlp_network = MLP(MLPConfig(
    hidden_dims=[128, 128, 64],  # Adjust layer sizes as needed
    policy_head_out_size=env.num_actions,
    value_head_out_size=1
))



We also need a way to convert our environment's state into something our neural network can take as input (i.e. structured data -> Array). `pgx` conveniently includes this in `state.observation`, but for other environments you may need to perform the conversion yourself.
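
For an environment that does not expose a ready-made observation, the conversion might look like the sketch below. The 3x3 board and its integer encoding are hypothetical, and a real implementation would return a JAX array rather than nested Python lists.

```python
def board_to_planes(board, current_player):
    """Encode a board as two binary planes: own pieces, then opponent pieces.

    `board` is a list of rows holding 0 (empty), 1 (player 0) or 2 (player 1).
    """
    own, opp = current_player + 1, 2 - current_player
    own_plane = [[1.0 if cell == own else 0.0 for cell in row] for row in board]
    opp_plane = [[1.0 if cell == opp else 0.0 for cell in row] for row in board]
    return [own_plane, opp_plane]

board = [
    [1, 0, 2],
    [0, 1, 0],
    [2, 0, 0],
]
planes = board_to_planes(board, current_player=0)
print(planes[0])  # where the current player has pieces
print(planes[1])  # where the opponent has pieces
```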

In [5]:
from core.networks.azresnet import AZResnetConfig, AZResnet

resnet_simple = AZResnet(AZResnetConfig(
    policy_head_out_size=env.num_actions,
    num_blocks=4,
    num_channels=32,
))



In [6]:
model = resnet_model

In [7]:
def state_to_nn_input(state):
    return state.observation

## Evaluator

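Next, we can initialize our evaluator. This notebook uses `StochasticMCTS`, an AlphaZero-style search that also accounts for the environment's chance outcomes (here, the dice rolls). It takes the following parameters:

* `eval_fn`: function used to evaluate a leaf node (returns a policy and value)
* `stochastic_action_probs`: probabilities of the environment's chance outcomes (`env.stochastic_action_probs` for backgammon)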
* `num_iterations`: number of MCTS iterations to run before returning the final policy
* `max_nodes`: maximum capacity of search tree
* `branching_factor`: branching factor of search tree == policy_size
* `action_selector`: the algorithm used to select an action to take at any given search node, choose between:
    * `PUCTSelector`: AlphaZero action selection algorithm
    * `MuZeroPUCTSelector`: MuZero action selection algorithm
    * or write your own! :)
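
For orientation, the PUCT rule from the AlphaZero paper scores each child as its mean value plus an exploration bonus that grows with the prior and shrinks with the child's visit count. The sketch below is a plain-Python illustration of that formula. The name `puct_scores` and the constant `c_puct=1.0` are assumptions made for the example; the defaults and implementation of `PUCTSelector` may differ.

```python
import math

def puct_scores(q_values, priors, visit_counts, c_puct=1.0):
    """Return Q + U for each child, where
    U = c_puct * prior * sqrt(total_visits) / (1 + child_visits)."""
    sqrt_total = math.sqrt(sum(visit_counts))
    return [
        q + c_puct * p * sqrt_total / (1 + n)
        for q, p, n in zip(q_values, priors, visit_counts)
    ]

# Three children: a well-explored good move, a barely explored move with a
# high prior, and an unexplored move with a low prior.
scores = puct_scores(
    q_values=[0.5, 0.0, 0.0],
    priors=[0.2, 0.7, 0.1],
    visit_counts=[8, 1, 0],
)
best = max(range(len(scores)), key=scores.__getitem__)
print(scores, best)
```

In this example the second child wins: its high prior and low visit count give it the largest exploration bonus, which outweighs the first child's better mean value.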

There are also a few other optional parameters, a few of the important ones are:
* `temperature`: temperature applied to move probabilities prior to sampling (0.0 == argmax, ->inf == completely random sampling). I recommend setting this to 1.0 for training (default) and 0.0 for evaluation.
* `dirichlet_alpha`: magnitude of Dirichlet noise to add to root policy (default 0.3). Generally, the more actions are possible in a game, the smaller this value should be. 
* `dirichlet_epsilon`: proportion of root policy composed of Dirichlet noise (default 0.25)
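
These knobs can be illustrated in a few lines. The sketch assumes the common AlphaZero formulation, in which visit counts are raised to the power `1 / temperature` before normalizing, and root priors are mixed with Dirichlet noise as `(1 - epsilon) * p + epsilon * noise`. The function names are hypothetical, and TurboZero's internals may differ in detail.

```python
import random

def apply_temperature(visit_counts, temperature):
    """Turn visit counts into a sampling distribution.

    temperature == 0 puts all mass on the most-visited action (argmax);
    larger temperatures flatten the distribution.
    """
    if temperature == 0.0:
        best = max(range(len(visit_counts)), key=visit_counts.__getitem__)
        return [1.0 if i == best else 0.0 for i in range(len(visit_counts))]
    scaled = [n ** (1.0 / temperature) for n in visit_counts]
    total = sum(scaled)
    return [s / total for s in scaled]

def add_dirichlet_noise(priors, alpha=0.3, epsilon=0.25, rng=random):
    """Mix root priors with a Dirichlet(alpha) sample."""
    # A Dirichlet sample is a vector of normalized Gamma(alpha, 1) draws.
    gammas = [rng.gammavariate(alpha, 1.0) for _ in priors]
    total = sum(gammas)
    noise = [g / total for g in gammas]
    return [(1 - epsilon) * p + epsilon * n for p, n in zip(priors, noise)]

visits = [10, 30, 60]
print(apply_temperature(visits, 1.0))  # proportional to visit counts
print(apply_temperature(visits, 0.0))  # one-hot on the most-visited action
print(add_dirichlet_noise([0.2, 0.3, 0.5], rng=random.Random(0)))
```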


We use `make_nn_eval_fn` to create a leaf evaluation function that uses our neural network to generate a policy and a value for the given state. 

In [8]:

from core.evaluators.evaluation_fns import make_nn_eval_fn
from core.evaluators.mcts.action_selection import PUCTSelector
from core.evaluators.mcts.stochastic_mcts import StochasticMCTS
import jax.numpy as jnp

# Training evaluator: StochasticMCTS using NN
evaluator = StochasticMCTS(   #Explores new moves
    eval_fn=make_nn_eval_fn(model, state_to_nn_input),
    stochastic_action_probs=env.stochastic_action_probs,
    num_iterations=1600,  
    max_nodes=2000,      
    branching_factor=env.num_actions,
    action_selector=PUCTSelector(),
    temperature=0.99,
)

We also define a separate evaluator with different parameters to use for testing purposes. In this notebook it gets a smaller budget than the training evaluator (300 iterations instead of 1600), which keeps evaluation cheap, and the temperature is set to zero so it always chooses the most-visited action after search is complete.

In [9]:

evaluator_test = StochasticMCTS(   #Use optimized moves, temperature=0.0
    eval_fn=make_nn_eval_fn(model, state_to_nn_input),
    stochastic_action_probs=env.stochastic_action_probs,
    num_iterations=300,  # Very few iterations
    max_nodes=1000,      # Very small tree
    branching_factor=env.num_actions,
    action_selector=PUCTSelector(),
    temperature=0.0,
)

We can use similar ideas to write a baseline evaluation function, one that doesn't use a neural network at all!

Instead of a learned policy and value, it scores positions with a hand-written backgammon heuristic, `bg_hit2_eval` from `core.bgcommon`.

Using similar techniques as before, we can create another MCTS evaluator to test against.

In [10]:
from core.evaluators.evaluation_fns import make_nn_eval_fn_no_params_callable
import chex
from core.bgcommon import bg_hit2_eval


# Test evaluator: Regular MCTS using pip count
bg_hit2_mcts_evaluator_test = StochasticMCTS(  # optimizes for moves
    eval_fn=bg_hit2_eval, # Use pip count eval fn
    stochastic_action_probs=env.stochastic_action_probs,
    num_iterations=30, # Give it slightly more iterations maybe
    max_nodes=100,
    branching_factor=env.num_actions,
    action_selector=PUCTSelector(),
    temperature=0.0 # Deterministic action selection for testing
)

## Rendering
We can optionally provide a `render_fn` that will record games played by our model against one of the baselines and save them as `.gif` files.

I've included a helper function that takes care of this. It depends on CairoSVG, which in turn depends on `cairo`, which you'll need to install on your system.

On Ubuntu, this can be done with:

In [11]:
#! apt-get update && apt-get -y install libcairo2-dev

If you're on another OS, consult https://www.cairographics.org/download/ for installation guidelines.

In [12]:
from functools import partial
from core.testing.utils import render_pgx_2p
render_fn = partial(render_pgx_2p, p1_label='Black', p2_label='White', duration=1200)


In [13]:
def get_temperature(train_steps) -> float:
    
    if train_steps < 1e5:
        return 1.0
    elif train_steps < 2e5:
        return 0.5
    elif train_steps < 3e5:
        return 0.1
    else:
        # Greedy selection.
        return 0.0

## Replay Memory Buffer

Next, we'll initialize a replay memory buffer to hold self-play trajectories that we can sample from during training. This actually just defines an interface; the buffer state itself will be initialized and managed internally.

The replay buffer is batched: it retains a buffer of trajectories across a batch dimension. We specify a `capacity`, the number of samples stored in a single buffer. The total capacity of the entire replay buffer is then `batch_size * capacity`, where `batch_size` is the number of environments/self-play games being run in parallel.

In [14]:
from core.memory.replay_memory import EpisodeReplayBuffer
collect_batch_size = 32
collection_steps_per_epoch = 1024

buffer_size = collect_batch_size * collection_steps_per_epoch * 1  # 32 * 1024 * 1 = 32768; passed as `capacity` below

replay_memory = EpisodeReplayBuffer(capacity=buffer_size )
print(f"replay buffer size: {buffer_size}")

replay buffer size: 32768
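
Since `capacity` applies to each parallel environment, the total number of samples the buffer can hold is larger than the printed value. The arithmetic below uses the values from this notebook. The fullness check uses the `buffer/populated` figure of 90760 reported in the training log further down, and the result is consistent with the `buffer/fullness_pct` value logged alongside it.

```python
collect_batch_size = 32            # parallel self-play environments
collection_steps_per_epoch = 1024
capacity = collect_batch_size * collection_steps_per_epoch  # value passed as `capacity`

# Total capacity across the batch dimension: batch_size * capacity.
total_capacity = collect_batch_size * capacity
print(total_capacity)

# The training log reports 90760 populated samples at epoch 3.
fullness_pct = 100 * 90760 / total_capacity
print(f"{fullness_pct:.4f}")
```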


## Trainer Initialization
Now that we have all the proper pieces defined, we are ready to initialize a Trainer and start training!

The `Trainer` takes many parameters, so let's walk through them all:
* `batch_size`: # of parallel environments used to collect self-play games
* `train_batch_size`: size of minibatch used during training step
* `warmup_steps`: # of steps (per batch) to collect via self-play prior to entering the training loop. This is used to populate the replay memory with some initial samples
* `collection_steps_per_epoch`: # of steps (per batch) to collect via self-play per epoch
* `train_steps_per_epoch`: # of train steps per epoch
* `nn`: neural network (`linen.Module`)
* `loss_fn`: loss function used for training, we use a provided default loss which implements the loss function used in the `AlphaZero` paper
* `optimizer`: an `optax` optimizer used for training
* `evaluator`: the `Evaluator` to use during self-play; we initialized ours using `StochasticMCTS`
* `memory_buffer`: the memory buffer used to store samples from self-play games; we initialized ours using `EpisodeReplayBuffer`
* `max_episode_steps`: maximum number of steps/turns to allow before truncating an episode
* `env_step_fn`: environment step function (we defined ours above)
* `env_init_fn`: environment init function (we defined ours above)
* `state_to_nn_input_fn`: function to convert environment state to nn input (we defined ours above)
* `testers`: any number of `Tester`s, used to evaluate a given model and take their own parameters. We'll use the baseline evaluator defined above to initialize one Tester.
* `evaluator_test`: (Optional) Evaluator used within Testers. By default this is `evaluator`, but sometimes you may want to test with a different MCTS iteration budget, for example, or a different move sampling temperature
* `data_transform_fns`: (optional) list of data transform functions to apply to self-play experiences (e.g. rotation, reflection, etc.)
* `extract_model_params_fn`: (Optional) in special cases we need to define how to extract all model parameters from a flax `TrainState`. The default function handles BatchNorm, but if another special-case technique applied across batches is used (e.g. Dropout) we would need to define a function to extract the appropriate parameters. You usually won't need to define this!
* `wandb_project_name`: (Optional) Weights and Biases project name. You will be prompted to login if a name is provided. If a name is provided, a run will be initialized and loss and other metrics will be logged to the given wandb project.
* `ckpt_dir`: (Optional) directory to store checkpoints in, by default this is set to `/tmp/turbozero_checkpoints`
* `max_checkpoints`: (Optional) maximum number of most-recent checkpoints to retain (default: 2)
* `num_devices`: (Optional) number of hardware accelerators (GPUs/TPUs) to use. If not given, all available hardware accelerators are used
* `wandb_run`: (Optional) continues from an initialized `wandb` run if provided, otherwise a new one is initialized
* `extra_wandb_config`: (Optional) any extra metadata to store in the `wandb` run config

A training epoch consists of M collection steps, followed by N training steps that sample minibatches from replay memory. Optionally, any number of Testers then evaluate the current model. At the end of each epoch, a checkpoint is saved.
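
Schematically, one epoch can be sketched as the loop below. This is a hypothetical outline of the control flow just described, with stub callables standing in for self-play, gradient updates, evaluation, and checkpointing. It is not the `Trainer`'s actual implementation.

```python
def run_epoch(collect_step, train_step, testers, save_checkpoint,
              collection_steps, train_steps):
    """One epoch: M collection steps, N train steps, optional tests, checkpoint."""
    for _ in range(collection_steps):
        collect_step()   # advance every parallel self-play game by one step
    for _ in range(train_steps):
        train_step()     # sample a minibatch from replay memory and update
    results = [test() for test in testers]  # evaluate the current model
    save_checkpoint()
    return results

# Stubs that only count how often they are called.
calls = {"collect": 0, "train": 0, "ckpt": 0}

def bump(name):
    def fn():
        calls[name] += 1
    return fn

results = run_epoch(
    collect_step=bump("collect"),
    train_step=bump("train"),
    testers=[lambda: "tester ran"],
    save_checkpoint=bump("ckpt"),
    collection_steps=1024,  # collection_steps_per_epoch in this notebook
    train_steps=32,         # train_steps_per_epoch in this notebook
)
print(calls, results)
```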

If you are using one or more GPUs (recommended), TurboZero by default will run on all your available hardware.

In [15]:
from functools import partial
from core.testing.two_player_baseline import TwoPlayerBaseline
from core.training.loss_fns import az_default_loss_fn
from core.training.stochastic_train import StochasticTrainer
import optax

train_ratio = 1
total_collected_per_epoch = collect_batch_size * collection_steps_per_epoch  # 32 * 1024 = 32768 samples collected per epoch
train_batch_size = 1024 
train_steps = int(train_ratio * total_collected_per_epoch / train_batch_size) # train steps vs collect steps

print(f"Collect_batch: {collect_batch_size} Total_Collected_steps_per_epoch: {total_collected_per_epoch}")
print(f"train_batch_size={train_batch_size} train_steps: {train_steps}")
print(f"Ratio of trainstep to collect steps: {train_ratio}")



trainer = StochasticTrainer(
    warmup_steps=0,
    batch_size=collect_batch_size,
    collection_steps_per_epoch=collection_steps_per_epoch,
    train_batch_size=train_batch_size,
    train_steps_per_epoch=train_steps,
    nn=model,
    loss_fn=partial(az_default_loss_fn, l2_reg_lambda=1e-4),
    optimizer=optax.adam(1e-4),
    # Use the stochastic evaluator for training
    evaluator=evaluator, 
    memory_buffer=replay_memory,
    max_episode_steps=500,  # should be enough for most games
    env_step_fn=step_fn,
    env_init_fn=init_fn,
    state_to_nn_input_fn=state_to_nn_input,
    ckpt_dir = "/tmp/ckpts",
    testers=[
        # Play against the heuristic baseline evaluator defined above
        TwoPlayerBaseline(
            num_episodes=30,
            baseline_evaluator=bg_hit2_mcts_evaluator_test,
            #render_fn=render_fn,
            #render_dir='training_eval/pip_count_baseline',
            name='hit2_eval'
        )
    ],
    # Evaluator used by our model inside the Testers (temperature 0.0)
    evaluator_test=evaluator_test, 
    data_transform_fns=[],  # no data transforms
    wandb_project_name="TurboZero-Backgammon-Notebook"
)

Collect_batch: 32 Total_Collected_steps_per_epoch: 32768
train_batch_size=1024 train_steps: 32
Ratio of trainstep to collect steps: 1


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msile16[0m ([33msile16-self[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Training

Now all that's left to do is to kick off the training loop! We need to pass an initial seed for reproducibility and the number of epochs to run.

If you've set up `wandb`, you can track metrics via plots in the run dashboard. Metrics will also be printed to the console. 

IMPORTANT: The first epoch will not execute quickly! Nearly all of the training loop is JIT-compiled, and JAX traces and compiles each function the first time it runs, which adds significant overhead. GPU utilization will typically be low or zero during this period. Expect epochs after the first to execute much more quickly.

It's also worth mentioning that the hyperparameters in this notebook are just here for example purposes. Regardless of the task, they will need to be tuned according to the characteristics of the environment as well as your available hardware and time/cost constraints.

In [16]:
#output = trainer.train_loop(seed=0, num_epochs=50, eval_every=5)


# Apply the temperature schedule defined above during self-play
trainer.set_temp_fn(get_temperature)

output = trainer.train_loop(seed=42, num_epochs=100, eval_every=5)

print("Training loop completed successfully.")

Temperature: 1.0
Collecting self-play games
Epoch 0: {'hit2_eval_avg_outcome': '-0.1333', 'perf/test_hit2_eval_time_sec': '347.9156'}
Temperature: 1.0
Collecting self-play games
Training
Training Done
Epoch 3: {'l2_reg': '1.1024', 'loss': '2.7417', 'policy_entropy': '0.9696', 'policy_loss': '1.1116', 'value_loss': '0.5277', 'train/train_time_sec': '6.1123', 'train/train_steps_per_sec': '5360.9985', 'collect/temperature': '1.0000', 'collect/collect_time_sec': '1059.5687', 'collect/collect_steps_per_sec': '0.9664', 'buffer/populated': '90760.0000', 'buffer/has_reward': '1046561.0000', 'buffer/fullness_pct': '8.6555', 'perf/epoch_time_sec': '1065.6863'}
Temperature: 0.5
Collecting self-play games
Training
Training Done
Epoch 4: {'l2_reg': '1.0977', 'loss': '2.7256', 'policy_entropy': '0.9731', 'policy_loss': '1.0908', 'value_loss': '0.5371', 'train/train_time_sec': '6.0423', 'train/train_steps_per_sec': '5423.1011', 'collect/temperature': '0.5000', 'collect/collect_time_sec': '1070.8238',

Any GIFs generated will appear in the same directory as this notebook, and also on your `wandb` dashboard.