<a href="https://colab.research.google.com/github/sile16/bgai/blob/main/bg_training_run.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hello World, TurboZero Backgammon 🏁

In [2]:
# Install the sile16 fork of pgx, which provides the pgx.backgammon environment used below.
# NOTE(review): this tracks the tip of `master`, so results may drift between runs; consider pinning a commit.

%pip install --upgrade pip
%pip install -U git+https://github.com/sile16/pgx.git@master



Collecting git+https://github.com/sile16/pgx.git@master
  Cloning https://github.com/sile16/pgx.git (to revision master) to /tmp/pip-req-build-44gnha60
  Running command git clone --filter=blob:none --quiet https://github.com/sile16/pgx.git /tmp/pip-req-build-44gnha60
  Running command git checkout -b master --track origin/master
  Switched to a new branch 'master'
  Branch 'master' set up to track remote branch 'master' from 'origin'.
  Resolved https://github.com/sile16/pgx.git to commit 224c8e556f3f33ad5e4b4d89a56ed85d660b6155
  Running command git submodule update --init --recursive -q
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting svgwrite (from pgx==2.5.9)
  Downloading svgwrite-1.4.3-py3-none-any.whl.metadata (8.8 kB)
Downloading svgwrite-1.4.3-py3-none-any.whl (67 kB)
Building wheels for collected packages: pgx
  Building wheel for pgx (pyprojec

In [1]:
import pgx

# Confirm which pgx build the kernel actually loaded.
print("PGX Version", pgx.__version__)

PGX Version 2.5.9


In [2]:
# Install TurboZero. Its requirements pull pgx from git+https://github.com/sile16/pgx.git@master,
# so pgx may be reinstalled by this command.
# NOTE(review): pgx was already imported in an earlier cell; restart the runtime after this
# install if the reinstalled pgx differs from the one already loaded.
%pip install git+https://github.com/sile16/turbozero.git

Collecting git+https://github.com/sile16/turbozero.git
  Cloning https://github.com/sile16/turbozero.git to /tmp/pip-req-build-o1ddt1z6
  Running command git clone --filter=blob:none --quiet https://github.com/sile16/turbozero.git /tmp/pip-req-build-o1ddt1z6
  Resolved https://github.com/sile16/turbozero.git to commit c0d1dfd2c67adc953c9363e5691b2714d4bae13f
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting pgx@ git+https://github.com/sile16/pgx.git@master (from turbozero==0.1.2)
  Cloning https://github.com/sile16/pgx.git (to revision master) to /tmp/pip-install-yc43ijdy/pgx_f3af9330203147939da19a49ee6fde80
  Running command git clone --filter=blob:none --quiet https://github.com/sile16/pgx.git /tmp/pip-install-yc43ijdy/pgx_f3af9330203147939da19a49ee6fde80
  Running command git checkout -b master --track origin/master
  Switched to a new branch 'master'
 

In [2]:
import jax
from jax.lib import xla_bridge  # NOTE(review): unused in the visible cells; kept in case other code relies on it
# IPython's HTML display object. The previously imported prompt_toolkit HTML is terminal
# formatted text, not a notebook display object, and was shadowed before use anyway.
from IPython.display import HTML, display

import pgx
import pgx.backgammon as bg

print(f"Jax Version: {jax.__version__}")
# jax.config.update('jax_platform_name', 'gpu')  # uncomment to force the GPU backend
print("Default backend:", jax.default_backend())

# Backgammon env from the pgx fork, with the simple_doubles option enabled.
env = bg.Backgammon(simple_doubles=True)
print(env.simple_doubles)
print(env.num_actions)
print(env.stochastic_action_probs)

# Create a PRNG key, build the initial state and render it as SVG.
key = jax.random.PRNGKey(0)
state = env.init(key)
display(HTML(state.to_svg()))





Jax Version:  0.5.2
Default backend: tpu
Jax Version 0.5.2
True
156
[0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667]


In [3]:
from core.types import StepMetadata
# Inspection only: print the source of StepMetadata to see its fields (safe to remove).
%psource StepMetadata

In [4]:
import chex
from typing import Tuple

def step_fn(state: bg.State, action: int, key: chex.PRNGKey) -> Tuple[bg.State, StepMetadata]:
    """Advance the backgammon env by one action, dispatching on chance vs. player nodes.

    If ``state.is_stochastic`` is true, ``action`` is treated as a chance outcome and
    passed to ``env.stochastic_step`` (``key`` is not used). Otherwise it is a player
    move passed to ``env.step`` together with ``key``. The dispatch uses
    ``jax.lax.cond``, so both branches must return states of the same structure.
    ``env`` is read from the enclosing (notebook) scope.

    Args:
        state: current environment state.
        action: action index (chance outcome or player move).
        key: PRNG key forwarded to ``env.step``.

    Returns:
        ``(next_state, metadata)``, where ``metadata`` mirrors the next state's
        rewards, legal-action mask, termination flag, current player and step count.
    """
    def _chance_step(args):
        cur_state, cur_action, _unused_key = args
        return env.stochastic_step(cur_state, cur_action)

    def _player_step(args):
        cur_state, cur_action, rng = args
        return env.step(cur_state, cur_action, rng)

    next_state = jax.lax.cond(
        state.is_stochastic, _chance_step, _player_step, (state, action, key)
    )

    return next_state, StepMetadata(
        rewards=next_state.rewards,
        action_mask=next_state.legal_action_mask,
        terminated=next_state.terminated,
        cur_player_id=next_state.current_player,
        step=next_state._step_count,
    )

In [6]:
def init_fn(key):
    """Reset the environment and pair the fresh state with its StepMetadata.

    Calls ``env.init(key)`` (``env`` comes from the notebook scope) and leaves any
    initial chance/decision handling to the environment itself.

    Returns:
        ``(initial_state, metadata)``, where ``metadata`` holds the state's rewards,
        legal-action mask, termination flag, current player and step count.
    """
    initial_state = env.init(key)
    initial_metadata = StepMetadata(
        rewards=initial_state.rewards,
        action_mask=initial_state.legal_action_mask,
        terminated=initial_state.terminated,
        cur_player_id=initial_state.current_player,
        step=initial_state._step_count,
    )
    return initial_state, initial_metadata

In [7]:
import jax
import jax.numpy as jnp
import flax.linen as nn

# Dense residual block for vector inputs: two Dense+LayerNorm layers with a skip connection.
class ResidualDenseBlock(nn.Module):
    """Dense -> LayerNorm -> ReLU -> Dense -> LayerNorm, plus a residual skip, then ReLU.

    Attributes:
        features: width of both Dense layers. The block input is added to the
            second layer's output, so the input must be shape-compatible with a
            ``features``-wide output.
    """
    features: int

    @nn.compact
    def __call__(self, x):
        residual = x  # block input, re-added after the second normalization
        x = nn.Dense(self.features)(x)
        x = nn.LayerNorm()(x)
        x = nn.relu(x)
        x = nn.Dense(self.features)(x)
        x = nn.LayerNorm()(x)
        return nn.relu(x + residual)  # ReLU is applied after the skip addition

# ResNet-style MLP producing (policy_logits, value); its __call__ accepts a 'train' keyword.
class ResNetTurboZero(nn.Module):
    """Dense ResNet with a policy head and a scalar value head.

    Input is projected to ``num_hidden`` features, passed through ``num_blocks``
    ``ResidualDenseBlock``s, then split into:
      * policy logits of size ``num_actions``;
      * a value, squeezed to drop the trailing size-1 axis. It is unbounded (no tanh).

    ``train`` is accepted but not used by this network (it has no BatchNorm/Dropout).
    """
    num_actions: int       # size of the policy output; env.num_actions (156 for this env)
    num_hidden: int = 128  # Hidden layer dimension
    num_blocks: int = 2    # Number of residual blocks

    @nn.compact
    def __call__(self, x, train: bool = False):
        # Initial projection.
        x = nn.Dense(self.num_hidden)(x)
        x = nn.LayerNorm()(x)
        x = nn.relu(x)

        # Process through a series of residual blocks.
        for _ in range(self.num_blocks):
            x = ResidualDenseBlock(self.num_hidden)(x)

        # Policy head: project features to logits over possible actions.
        policy_logits = nn.Dense(self.num_actions)(x)

        # Value head: project features to a single scalar.
        value = nn.Dense(1)(x)
        value = jnp.squeeze(value, axis=-1)
        return policy_logits, value


# The policy head width must equal the env's action count (156 here).
resnet_model = ResNetTurboZero(num_actions=env.num_actions, num_hidden=128, num_blocks=2)

In [8]:
def state_to_nn_input(state):
    """Return the network input for ``state``: its ``observation`` field, unchanged."""
    nn_input = state.observation
    return nn_input

In [9]:
from core.evaluators.evaluation_fns import make_nn_eval_fn
from core.evaluators.mcts.action_selection import PUCTSelector
from core.evaluators.mcts.stochastic_mcts import StochasticMCTS
import jax.numpy as jnp

# Self-play (training) evaluator: stochastic MCTS guided by the ResNet's policy and value.
# temperature=1.0 keeps move selection stochastic so self-play explores new moves.
evaluator = StochasticMCTS(
    eval_fn=make_nn_eval_fn(resnet_model, state_to_nn_input),
    stochastic_action_probs=env.stochastic_action_probs,  # chance-outcome probabilities from the env
    num_iterations=300,                                   # MCTS iterations per move
    max_nodes=500,                                        # tree capacity
    branching_factor=env.num_actions,
    action_selector=PUCTSelector(),
    temperature=1.0,
)

In [10]:
# Test-time evaluator: same NN-guided stochastic MCTS and budget as training,
# but temperature=0.0 so it plays its preferred move instead of sampling.
evaluator_test = StochasticMCTS(
    eval_fn=make_nn_eval_fn(resnet_model, state_to_nn_input),
    stochastic_action_probs=env.stochastic_action_probs,
    num_iterations=300,  # same iteration budget as the training evaluator
    max_nodes=500,       # same tree capacity as the training evaluator
    branching_factor=env.num_actions,
    action_selector=PUCTSelector(),
    temperature=0.0,
)

In [12]:
from core.evaluators.evaluation_fns import make_nn_eval_fn_no_params_callable
import chex

# --- Pip Count Eval Fn (for test evaluator) ---
def backgammon_pip_count_eval(state: "chex.ArrayTree", params: "chex.ArrayTree", key: "chex.PRNGKey"):
    """Heuristic evaluator: normalized pip-count difference from the current player's view.

    The board is always from the current player's perspective: the current player's
    checkers are positive counts and the opponent's are negative. ``params`` and
    ``key`` are ignored.

    Assumes the pgx board layout. Indices 0-23 are points, and the current player
    moves toward index 23. Index 24 is the current player's bar, 25 the opponent's
    bar, and 26/27 the borne-off counts. This is consistent with the pgx opening
    position, but NOTE(review): confirm it against the pgx fork in use.

    Pip counts under that layout:
      * a current-player checker on index i needs ``24 - i`` pips; on the bar it needs 25;
      * an opponent checker on index i needs ``i + 1`` pips; on the bar it needs 25;
      * borne-off checkers need 0 pips, so indices 26/27 are not read.

    Returns:
        ``(policy_logits, value)``.
          * ``policy_logits`` is uniform (0.0) over legal actions and ``-inf`` elsewhere.
          * ``value`` is ``(opponent_pips - current_pips) / 375``, a float32 in
            [-1, 1], where 375 is 15 checkers x 25 pips.
    """
    board = state._board
    points = board[:24]
    idx = jnp.arange(24, dtype=jnp.int32)

    own_checkers = jnp.maximum(points, 0)
    opp_checkers = jnp.maximum(-points, 0)

    current_pips = jnp.sum(own_checkers * (24 - idx)) + 25 * jnp.abs(board[24])
    opponent_pips = jnp.sum(opp_checkers * (idx + 1)) + 25 * jnp.abs(board[25])

    # Normalize by the maximum possible pip count. This keeps the heuristic bounded
    # instead of in the hundreds, so it does not swamp terminal rewards in the MCTS backup.
    max_pip_count = 15.0 * 25.0
    value = (opponent_pips - current_pips) / max_pip_count

    # Uniform policy over legal actions for the greedy baseline.
    policy_logits = jnp.where(state.legal_action_mask, 0.0, -jnp.inf)

    return policy_logits, jnp.asarray(value, dtype=jnp.float32)


# Baseline opponent for testing: stochastic MCTS driven by the pip-count heuristic
# instead of the network. It uses a small budget (30 iterations / 100 nodes) and
# temperature=0.0 so it plays its preferred move deterministically.
pip_count_mcts_evaluator_test = StochasticMCTS(
    eval_fn=backgammon_pip_count_eval,
    stochastic_action_probs=env.stochastic_action_probs,
    num_iterations=30,
    max_nodes=100,
    branching_factor=env.num_actions,
    action_selector=PUCTSelector(),
    temperature=0.0,
)

In [13]:
from core.memory.replay_memory import EpisodeReplayBuffer

# Replay buffer that stores self-play episodes for training (capacity passed straight through).
replay_memory = EpisodeReplayBuffer(capacity=2_000)

## Trainer Initialization
Now that we have all the proper pieces defined, we are ready to initialize a Trainer and start training!

The `Trainer` takes many parameters, so let's walk through them all:
* `batch_size`: # of parallel environments used to collect self-play games
* `train_batch_size`: size of minibatch used during training step
* `warmup_steps`: # of steps (per batch) to collect via self-play prior to entering the training loop. This is used to populate the replay memory with some initial samples
* `collection_steps_per_epoch`: # of steps (per batch) to collect via self-play per epoch
* `train_steps_per_epoch`: # of train steps per epoch
* `nn`: neural network (`linen.Module`)
* `loss_fn`: loss function used for training, we use a provided default loss which implements the loss function used in the `AlphaZero` paper
* `optimizer`: an `optax` optimizer used for training
* `evaluator`: the `Evaluator` to use during self-play; we initialized ours using `StochasticMCTS`
* `memory_buffer`: the memory buffer used to store samples from self-play games; we initialized ours using `EpisodeReplayBuffer`
* `max_episode_steps`: maximum number of steps/turns to allow before truncating an episode
* `env_step_fn`: environment step function (we defined ours above)
* `env_init_fn`: environment init function (we defined ours above)
* `state_to_nn_input_fn`: function to convert environment state to nn input (we defined ours above)
* `testers`: any number of `Tester`s, each taking its own parameters, used to evaluate the current model. We'll use the pip-count evaluator defined above to initialize a single `TwoPlayerBaseline` tester.
* `evaluator_test`: (Optional) Evaluator used within Testers. Defaults to `evaluator`, but sometimes you may want to test with a larger MCTS iteration budget, for example, or a different move sampling temperature
* `data_transform_fns`: (optional) list of data transform functions to apply to self-play experiences (e.g. rotation, reflection, etc.)
* `extract_model_params_fn`: (Optional) in special cases we need to define how to extract all model parameters from a flax `TrainState`. The default function handles BatchNorm, but if another special-case technique applied across batches is used (e.g. Dropout) we would need to define a function to extract the appropriate parameters. You usually won't need to define this!
* `wandb_project_name`: (Optional) Weights and Biases project name. You will be prompted to login if a name is provided. If a name is provided, a run will be initialized and loss and other metrics will be logged to the given wandb project.
* `ckpt_dir`: (Optional) directory to store checkpoints in, by default this is set to `/tmp/turbozero_checkpoints`
* `max_checkpoints`: (Optional) maximum number of most-recent checkpoints to retain (default: 2)
* `num_devices`: (Optional) number of hardware accelerators (GPUs/TPUs) to use. If not given, all available hardware accelerators are used
* `wandb_run`: (Optional) continues from an initialized `wandb` run if provided, otherwise a new one is initialized
* `extra_wandb_config`: (Optional) any extra metadata to store in the `wandb` run config

A training epoch is comprised of M collection steps, followed by N training steps sampling minibatches from replay memory. Optionally, any number of Testers evaluate the current model. At the end of each epoch, a checkpoint is saved.

If you are using one or more GPUs or TPUs (recommended), TurboZero will by default run on all your available hardware.

In [17]:
from functools import partial
from core.testing.two_player_baseline import TwoPlayerBaseline
from core.training.loss_fns import az_default_loss_fn
from core.training.stochastic_train import StochasticTrainer
from core.training.train import Trainer
import optax

import os

# Set the variable in the kernel's own environment. A `!export ...` line runs in a
# throwaway subshell, so the variable would never reach this Python process.
os.environ["WANDB_NOTEBOOK_NAME"] = "turbozero-notebooks"

trainer = StochasticTrainer(
    batch_size=64,                    # parallel environments used for self-play
    train_batch_size=32,              # minibatch size per training step
    warmup_steps=0,
    collection_steps_per_epoch=2048,  # number of steps played in the env per epoch
    train_steps_per_epoch=1024,       # training steps per epoch
    nn=resnet_model,
    loss_fn=partial(az_default_loss_fn, l2_reg_lambda=1e-4),
    optimizer=optax.adam(1e-4),
    # NN-guided stochastic MCTS (temperature 1.0) used for self-play
    evaluator=evaluator,
    memory_buffer=replay_memory,
    max_episode_steps=1000,  # should be enough for most games
    env_step_fn=step_fn,
    env_init_fn=init_fn,
    state_to_nn_input_fn=state_to_nn_input,
    ckpt_dir="/tmp/ckpts",
    testers=[
        # Evaluate the current model against the pip-count MCTS baseline
        TwoPlayerBaseline(
            num_episodes=32,
            baseline_evaluator=pip_count_mcts_evaluator_test,
            # render_fn=render_fn,
            # render_dir='training_eval/pip_count_baseline',
            name='pip_count_baseline',
        )
    ],
    # Evaluator used within Testers: NN-guided MCTS with temperature 0.0
    evaluator_test=evaluator_test,
    data_transform_fns=[],  # no data transforms
    wandb_project_name=None,  # no wandb project name given
)

In [None]:
# Run 150 training epochs from seed 0; eval_every=5 presumably runs the testers every 5th epoch.
output = trainer.train_loop(num_epochs=150, eval_every=5, seed=0)