# Hello World, TurboZero Backgammon 🏁

`turbozero` provides a vectorized implementation of AlphaZero. 

In a nutshell, this means we can massively speed up training by collecting many self-play games and running Monte Carlo Tree Search in parallel across one or more GPUs!

As the user, you just need to provide:
* environment dynamics functions (step and init) that adhere to the TurboZero spec
* a conversion function for environment state -> neural net input
* and a few hyperparameters!

TurboZero takes care of the rest. 😀 

## Getting Started

Follow the instructions in the repo readme to properly install dependencies and set up your environment.

## Environments

In order to take advantage of the batched implementation of AlphaZero, we need to pair it with a vectorized environment.

Fortunately, there are many great vectorized RL environment libraries. One I like in particular is [pgx](https://github.com/sotetsuk/pgx).

In [10]:
import pgx
import pgx.backgammon as bg

env = bg.Backgammon(simple_doubles=True)
print(env.simple_doubles)
print(env.num_actions)
print(env.stochastic_action_probs)


True
156
[0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667]
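To illustrate what "vectorized" means here, the sketch below compares a per-game loop with a single batched array operation on a toy dynamics function. The functions `step_one` and `step_batch` are hypothetical and use NumPy only for illustration; real vectorized JAX environments are usually batched with `jax.vmap` instead.

```python
import numpy as np

def step_one(position: int, roll: int) -> int:
    """Toy single-game dynamics: move a checker forward by the roll."""
    return position + roll

def step_batch(positions: np.ndarray, rolls: np.ndarray) -> np.ndarray:
    """Same dynamics applied to a whole batch of games in one array operation."""
    return positions + rolls

positions = np.array([0, 3, 7, 12])  # one position per parallel game
rolls = np.array([6, 1, 4, 2])       # one roll per parallel game

# Stepping the games one at a time...
looped = np.array([step_one(p, r) for p, r in zip(positions, rolls)])
# ...gives the same result as stepping them all at once.
batched = step_batch(positions, rolls)
```

The batched version is what lets an accelerator advance thousands of games per call.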


## Environment Dynamics

TurboZero needs to interface with the environment in order to build search trees and collect self-play episodes.

We can define this interface with the following functions:
* `env_step_fn`: given an environment state and an action, return the new environment state and its metadata
```python
    EnvStepFn = Callable[[chex.ArrayTree, int], Tuple[chex.ArrayTree, StepMetadata]]
```
* `env_init_fn`: given a key, initialize and return a new environment state and its metadata
```python
    EnvInitFn = Callable[[chex.PRNGKey], Tuple[chex.ArrayTree, StepMetadata]]
```
Fortunately, environment libraries implement these for us! We just need to extract a few key pieces of information 
from the environment state so that we can match the TurboZero specification. We store this in a StepMetadata object:

In [11]:
from core.types import StepMetadata
%psource StepMetadata

@chex.dataclass(frozen=True)
class StepMetadata:
    """Metadata for a step in the environment.
    - `rewards`: rewards received by the players
    - `action_mask`: mask of valid actions
    - `terminated`: whether the environment is terminated
    - `cur_player_id`: current player id
    - `step`: step number
    """
    rewards: chex.Array
    action_mask: chex.Array
    terminated: bool
    cur_player_id: int
    step: int


* `rewards` stores the rewards emitted for each player for the given timestep
* `action_mask` is a mask across all possible actions, where legal actions are set to `True`, and invalid/illegal actions are set to `False`
* `terminated`: `True` if the environment is terminated/completed
* `cur_player_id`: id of the current player
* `step`: step number
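To make the `(state, metadata)` contract concrete, here is a minimal sketch of a toy two-player counting game that follows the same shape. Everything in it (`ToyMetadata`, `toy_init_fn`, `toy_step_fn`, the game itself) is hypothetical and uses plain Python types instead of JAX arrays, so it is not a drop-in TurboZero environment.

```python
from dataclasses import dataclass
from typing import Tuple

# Hypothetical stand-in for core.types.StepMetadata, using plain Python types
# so the sketch runs without JAX or TurboZero installed.
@dataclass(frozen=True)
class ToyMetadata:
    rewards: Tuple[float, float]
    action_mask: Tuple[bool, ...]
    terminated: bool
    cur_player_id: int
    step: int

TARGET = 5  # the player who brings the running total to exactly TARGET wins

def _metadata(total, player, step, rewards=(0.0, 0.0)):
    # Actions are "add 1" (index 0) and "add 2" (index 1); overshooting is illegal.
    mask = tuple(total + inc <= TARGET for inc in (1, 2))
    return ToyMetadata(rewards, mask, total == TARGET, player, step)

def toy_init_fn(key):
    # `key` is unused here; a real environment would use it to seed randomness.
    state = (0, 0, 0)  # (total, current player, step count)
    return state, _metadata(0, 0, 0)

def toy_step_fn(state, action):
    total, player, step = state
    total += action + 1
    won = total == TARGET
    # The mover gets +1 on a win, the opponent gets -1, otherwise both get 0.
    rewards = tuple(
        (1.0 if p == player else -1.0) if won else 0.0 for p in (0, 1)
    )
    next_player = 1 - player
    new_state = (total, next_player, step + 1)
    return new_state, _metadata(total, next_player, step + 1, rewards)

# Play one episode, always taking the first legal action.
state, meta = toy_init_fn(key=0)
while not meta.terminated:
    action = meta.action_mask.index(True)
    state, meta = toy_step_fn(state, action)
```

Both functions return the new state together with the metadata extracted from it, which is the pattern TurboZero expects from an environment.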

We can define the environment interface for `Backgammon` as follows:

In [12]:
import jax.numpy as jnp  # used by stochastic_action_mask below

def step_fn(state, action):
    """Handle regular backgammon moves.
    
    Args:
        state: Current environment state
        action: The move to make (index into legal_action_mask)
        
    Returns:
        Tuple of (new_state, metadata)
    """
    new_state = env.step(state, action)
    

    
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
        step=new_state._step_count
    )

def stochastic_step_fn(state, action):
    """Handle stochastic dice rolls.
    
    Args:
        state: Current environment state
        action: The move to make (index into env.stochastic_action_probs)
        
    Returns:
        Tuple of (new_state, metadata)
    """
    
    new_state = env.stochastic_step(state, action)
    
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
        step=new_state._step_count
    )
    
def stochastic_action_mask():
    """Return the action mask for the stochastic step function.
    
    For backgammon we are going to break out all possible dice rolls
        
    Returns:
        jnp.array: action mask for the stochastic step function
    """
    # the original body built the mask but never returned it
    return jnp.ones((1, env.num_actions))
                       
def init_fn(key):
    """Initialize a new backgammon game.
    
    Args:
        key: Random key for initialization
        
    Returns:
        Tuple of (initial_state, metadata)
    """
    state = env.init(key)

    return state, StepMetadata(
        rewards=state.rewards,
        action_mask=state.legal_action_mask,  # same mask field that step_fn uses
        terminated=state.terminated,
        cur_player_id=state.current_player,
        step=state._step_count
    )
    


Pretty easy!

## Neural Network

Next, we'll need to define the architecture of the neural network.

A simple implementation of the residual neural network used in the _AlphaZero_ paper is included for your convenience. 

You can implement your own architecture using `flax.linen`.

In [13]:
from core.networks.azresnet import AZResnetConfig, AZResnet

resnet = AZResnet(AZResnetConfig(
    policy_head_out_size=env.num_actions,
    num_blocks=4,
    num_channels=32,
))

We also need a way to convert our environment's state into something our neural network can take as input (i.e. structured data -> Array). `pgx` conveniently includes this in `state.observation`, but for other environments you may need to perform the conversion yourself.

In [14]:
def state_to_nn_input(state):
    return state.observation
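For an environment that does not ship a ready-made observation, the conversion might look like the sketch below. The board encoding (`+1` for the current player's pieces, `-1` for the opponent's, `0` for empty) and the function name are assumptions made for illustration, and NumPy stands in for `jax.numpy`.

```python
import numpy as np

def toy_state_to_nn_input(board: np.ndarray) -> np.ndarray:
    """Encode a board of {-1, 0, +1} as two binary planes (channels last)."""
    own = (board == 1).astype(np.float32)   # plane 0: current player's pieces
    opp = (board == -1).astype(np.float32)  # plane 1: opponent's pieces
    return np.stack([own, opp], axis=-1)

board = np.array([[ 1, 0, -1],
                  [ 0, 1,  0],
                  [-1, 0,  0]])
planes = toy_state_to_nn_input(board)  # shape (3, 3, 2)
```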

## Evaluator

Next, we can initialize our evaluator, AlphaZero, which takes the following parameters:

* `eval_fn`: function used to evaluate a leaf node (returns a policy and value)
* `num_iterations`: number of MCTS iterations to run before returning the final policy
* `max_nodes`: maximum capacity of search tree
* `branching_factor`: branching factor of search tree == policy_size
* `action_selector`: the algorithm used to select an action to take at any given search node, choose between:
    * `PUCTSelector`: AlphaZero action selection algorithm
    * `MuZeroPUCTSelector`: MuZero action selection algorithm
    * or write your own! :)
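For intuition about what an action selector does, here is a sketch of the PUCT rule as it is usually stated for AlphaZero: each action's score is its value estimate plus an exploration bonus that grows with the prior and shrinks with the visit count. The function `puct_select` and its `c_puct` parameter are hypothetical names, and TurboZero's `PUCTSelector` may differ in its details.

```python
import numpy as np

def puct_select(q, prior, visits, mask, c_puct=1.0):
    """Pick the legal action maximizing Q + c_puct * P * sqrt(sum(N)) / (1 + N)."""
    q = np.asarray(q, dtype=np.float64)
    prior = np.asarray(prior, dtype=np.float64)
    visits = np.asarray(visits, dtype=np.float64)
    mask = np.asarray(mask, dtype=bool)
    explore = c_puct * prior * np.sqrt(visits.sum()) / (1.0 + visits)
    # Illegal actions get -inf so they can never be selected.
    scores = np.where(mask, q + explore, -np.inf)
    return int(np.argmax(scores))

best = puct_select(
    q=[0.5, 0.0, 0.9],
    prior=[0.2, 0.7, 0.1],
    visits=[10, 0, 5],
    mask=[True, True, False],
)
```

Here action 1 is chosen: it has never been visited and has a high prior, so its exploration bonus outweighs action 0's higher value estimate, while action 2 is masked out.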

There are also a few other optional parameters, a few of the important ones are:
* `temperature`: temperature applied to move probabilities prior to sampling (0.0 == argmax, ->inf == completely random sampling). I recommend setting this to 1.0 for training (default) and 0.0 for evaluation.
* `dirichlet_alpha`: magnitude of Dirichlet noise to add to root policy (default 0.3). Generally, the more actions are possible in a game, the smaller this value should be. 
* `dirichlet_epsilon`: proportion of root policy composed of Dirichlet noise (default 0.25)
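The effect of these three parameters can be sketched in a few lines. The code below follows the formulation commonly used for AlphaZero (visit counts raised to `1 / temperature`, and a root policy mixed with Dirichlet noise); the helper names are hypothetical and this is not necessarily how TurboZero implements it internally.

```python
import numpy as np

def apply_temperature(visit_counts, temperature):
    """Turn visit counts into a sampling distribution."""
    counts = np.asarray(visit_counts, dtype=np.float64)
    if temperature == 0.0:
        # Temperature 0 means "always pick the most-visited action".
        probs = np.zeros_like(counts)
        probs[np.argmax(counts)] = 1.0
        return probs
    scaled = counts ** (1.0 / temperature)
    return scaled / scaled.sum()

def add_root_noise(policy, rng, alpha=0.3, epsilon=0.25):
    """Mix Dirichlet(alpha) noise into the root policy with weight epsilon."""
    policy = np.asarray(policy, dtype=np.float64)
    noise = rng.dirichlet(np.full(policy.shape, alpha))
    return (1.0 - epsilon) * policy + epsilon * noise

counts = [10, 30, 60]
train_probs = apply_temperature(counts, 1.0)  # proportional to visit counts
eval_probs = apply_temperature(counts, 0.0)   # one-hot on the most-visited action

rng = np.random.default_rng(0)
noisy_policy = add_root_noise(train_probs, rng)
```

With `epsilon=0.25`, each entry of the noisy policy keeps at least 75% of its original probability, so the noise encourages exploration without overwhelming the network's prior.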


We use `make_nn_eval_fn` to create a leaf evaluation function that uses our neural network to generate a policy and a value for the given state. 

In [15]:
from core.evaluators.alphazero import AlphaZero
from core.evaluators.evaluation_fns import make_nn_eval_fn
from core.evaluators.mcts.action_selection import PUCTSelector
from core.evaluators.mcts.mcts import MCTS

# alphazero can take an arbitrary search `backend`
# here we use classic MCTS
az_evaluator = AlphaZero(MCTS)(
    eval_fn = make_nn_eval_fn(resnet, state_to_nn_input),
    num_iterations = 32,
    max_nodes = 40,
    branching_factor = env.num_actions,
    action_selector = PUCTSelector(),
    temperature = 1.0
)

We also define a separate evaluator with different parameters to use for testing purposes. We'll give this one a larger budget (num_iterations), and set the temperature to zero so it always chooses the most-visited action after search is complete.

In [16]:
az_evaluator_test = AlphaZero(MCTS)(
    eval_fn = make_nn_eval_fn(resnet, state_to_nn_input),
    num_iterations = 64,
    max_nodes = 80,
    branching_factor = env.num_actions,
    action_selector = PUCTSelector(),
    temperature = 0.0
)

## Baselines

We can also test our trained model periodically against baselines, in order to gauge improvement.

Conveniently, pgx offers pre-trained baseline models for certain environments. If we want to test against one, we can use `make_nn_eval_fn_no_params_callable`, which just returns an evaluation function that uses the baseline model to evaluate the game state.

We can combine this with the `AlphaZero` evaluator like we did before to create a competing AlphaZero instance to play against. Other than the `eval_fn`, it is important to use the same parameters that we give to our (test) evaluator, so as to give a true comparison between the strength of the policy/value estimates.

In [None]:
from core.evaluators.evaluation_fns import make_nn_eval_fn_no_params_callable

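# NOTE: `model` must be defined before the next statement runs.
# The call below loads pgx's pre-trained Othello baseline and is left commented out here.
# model = pgx.make_baseline_model('othello_v0')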

baseline_eval_fn = make_nn_eval_fn_no_params_callable(model, state_to_nn_input)

baseline_az = AlphaZero(MCTS)(
    eval_fn = baseline_eval_fn,
    num_iterations = 64,
    max_nodes = 80,
    branching_factor = env.num_actions,
    action_selector = PUCTSelector(),
    temperature = 0.0
)


To see baselines available for other pgx environments, see https://sotets.uk/pgx/api/#pgx.BaselineModelId

We can use similar ideas to write a greedy baseline evaluation function, one that doesn't use a neural network at all!

Instead, it simply counts the number of tiles for the active player and compares it to the number of tiles controlled by the other player, so the value is higher for states where the active player controls more tiles than the other player.

Using similar techniques as before, we can create another AlphaZero evaluator to test against.

In [9]:
import jax.numpy as jnp

def greedy_eval(obs):
    value = (obs[...,0].sum() - obs[...,1].sum()) / 64
    return jnp.ones((1,env.num_actions)), jnp.array([value])

greedy_baseline_eval_fn = make_nn_eval_fn_no_params_callable(greedy_eval, state_to_nn_input)


greedy_az = AlphaZero(MCTS)(
    eval_fn = greedy_baseline_eval_fn,
    num_iterations = 64,
    max_nodes = 80,
    branching_factor = env.num_actions,
    action_selector = PUCTSelector(),
    temperature = 0.0
)

## Replay Memory Buffer

Next, we'll initialize a replay memory buffer to hold self-play trajectories that we can sample from during training. This actually just defines an interface; the buffer state itself will be initialized and managed internally.

The replay buffer is batched: it retains a buffer of trajectories across a batch dimension. We specify a `capacity`: the number of samples stored in a single buffer. The total capacity of the entire replay buffer is then `batch_size * capacity`, where `batch_size` is the number of environments/self-play games being run in parallel.
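As a rough mental model, a batched buffer can be pictured as one fixed-size ring buffer per parallel environment. The class below is a hypothetical toy written with NumPy; the real `EpisodeReplayBuffer` stores full training samples and manages its state differently.

```python
import numpy as np

class ToyBatchedBuffer:
    """Hypothetical ring buffer with one row of `capacity` slots per environment."""

    def __init__(self, batch_size: int, capacity: int):
        self.data = np.zeros((batch_size, capacity), dtype=np.int64)
        self.next_idx = 0  # write position, shared by all rows
        self.size = 0      # number of filled slots per row

    def add(self, samples: np.ndarray) -> None:
        # `samples` has shape (batch_size,): one new sample per environment.
        capacity = self.data.shape[1]
        self.data[:, self.next_idx] = samples
        # Wrap around so the oldest slot is overwritten once the row is full.
        self.next_idx = (self.next_idx + 1) % capacity
        self.size = min(self.size + 1, capacity)

    @property
    def total_capacity(self) -> int:
        return self.data.size  # batch_size * capacity

buf = ToyBatchedBuffer(batch_size=4, capacity=3)
for step in range(5):
    buf.add(np.full(4, step))
```

After five writes into three slots, each row holds the three most recent samples. With the values used in this notebook (`capacity=1000` and a `batch_size` of 1024), the same arithmetic gives a total capacity of 1,024,000 samples.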

In [10]:
from core.memory.replay_memory import EpisodeReplayBuffer

replay_memory = EpisodeReplayBuffer(capacity=1000)

## Data Augmentation (Optional)

During self-play, we allow for any number of custom data augmentation functions, in order to create more training samples. 

In RL, it's common to take advantage of rotations or symmetries in order to generate additional training examples. 

In Othello, for example, we could rotate the board to generate a new training example, but we need to be careful to rotate the policy (and the action mask) in the same way.

In order to implement a data augmentation function, we must follow `DataTransformFn`:
```python
# (policy mask, policy weights, environment state) -> (transformed policy mask, transformed policy weights, transformed environment state)
DataTransformFn = Callable[[chex.Array, chex.Array, chex.ArrayTree], Tuple[chex.Array, chex.Array, chex.ArrayTree]]
```

We create rotational transform functions for rotating 90, 180, 270 degrees:

In [11]:
def make_rot_transform_fn(amnt: int):
    def rot_transform_fn(mask, policy, state):
        action_ids = jnp.arange(65) # 65 total actions, but only rotate the first 64! (65th is always do nothing action)
        # we only use state.observation, no need to update the rest of the state fields
        new_obs = jnp.rot90(state.observation, amnt, axes=(-3,-2))
        # map action ids to new action ids
        idxs = jnp.arange(64).reshape(8,8) # rotate first 64 actions
        new_idxs = jnp.rot90(idxs, amnt, axes=(0, 1)).flatten()
        action_ids = action_ids.at[:64].set(new_idxs)
        # get new mask and policy
        new_mask = mask[...,action_ids]
        new_policy = policy[...,action_ids]
        return new_mask, new_policy, state.replace(observation=new_obs)

    return rot_transform_fn

# make transform fns for rotating 90, 180, 270 degrees
transforms = [make_rot_transform_fn(i) for i in range(1,4)]        
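A quick way to convince yourself that the action remapping is consistent with the board rotation is to check it on a random policy. The sketch below re-implements only the index permutation from the transform above using NumPy (the helper `rotate_policy` is a hypothetical name), and compares it against rotating the 8x8 policy grid directly.

```python
import numpy as np

def rotate_policy(policy: np.ndarray, k: int) -> np.ndarray:
    """Permute a length-65 policy the same way the transform above does."""
    action_ids = np.arange(65)
    idxs = np.arange(64).reshape(8, 8)
    action_ids[:64] = np.rot90(idxs, k, axes=(0, 1)).flatten()
    return policy[..., action_ids]

rng = np.random.default_rng(0)
policy = rng.random(65)

for k in range(1, 4):
    rotated = rotate_policy(policy, k)
    # Board actions should match rotating the 8x8 policy grid directly...
    assert np.array_equal(
        rotated[:64].reshape(8, 8),
        np.rot90(policy[:64].reshape(8, 8), k),
    )
    # ...and the 65th (do nothing) action should be untouched.
    assert rotated[64] == policy[64]
```

If the observation and the policy were rotated by different amounts, the network would be trained on moves that point at the wrong squares, so a check like this is cheap insurance.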

## Rendering
We can optionally provide a `render_fn` that will record games played by our model against one of the baselines and save it as a `.gif`.

I've included a helper function that takes care of this. It depends on CairoSVG, which in turn depends on `cairo`, which you'll need to install on your system.

On Ubuntu, this can be done with:

In [None]:
! apt-get update && apt-get -y install libcairo2-dev

If you're on another OS, consult https://www.cairographics.org/download/ for installation guidelines.

In [13]:
from functools import partial
from core.testing.utils import render_pgx_2p
render_fn = partial(render_pgx_2p, p1_label='Black', p2_label='White', duration=900)

## Trainer Initialization
Now that we have all the proper pieces defined, we are ready to initialize a Trainer and start training!

The `Trainer` takes many parameters, so let's walk through them all:
* `batch_size`: # of parallel environments used to collect self-play games
* `train_batch_size`: size of minibatch used during training step
* `warmup_steps`: # of steps (per batch) to collect via self-play prior to entering the training loop. This is used to populate the replay memory with some initial samples
* `collection_steps_per_epoch`: # of steps (per batch) to collect via self-play per epoch
* `train_steps_per_epoch`: # of train steps per epoch
* `nn`: neural network (`linen.Module`)
* `loss_fn`: loss function used for training, we use a provided default loss which implements the loss function used in the `AlphaZero` paper
* `optimizer`: an `optax` optimizer used for training
* `evaluator`: the `Evaluator` to use during self-play, we initialized ours using `AlphaZero(MCTS)`
* `memory_buffer`: the memory buffer used to store samples from self-play games, we  initialized ours using `EpisodeReplayBuffer`
* `max_episode_steps`: maximum number of steps/turns to allow before truncating an episode
* `env_step_fn`: environment step function (we defined ours above)
* `env_init_fn`: environment init function (we defined ours above)
* `state_to_nn_input_fn`: function to convert environment state to nn input (we defined ours above)
* `testers`: any number of `Tester`s, each of which evaluates the current model and takes its own parameters. We'll use the two baseline evaluators defined above to initialize two Testers.
* `evaluator_test`: (Optional) Evaluator used within Testers. By default, `evaluator` is used, but sometimes you may want to test with a larger MCTS iteration budget, for example, or a different move sampling temperature
* `data_transform_fns`: (optional) list of data transform functions to apply to self-play experiences (e.g. rotation, reflection, etc.)
* `extract_model_params_fn`: (Optional) in special cases we need to define how to extract all model parameters from a flax `TrainState`. The default function handles BatchNorm, but if another special-case technique applied across batches is used (e.g. Dropout) we would need to define a function to extract the appropriate parameters. You usually won't need to define this!
* `wandb_project_name`: (Optional) Weights and Biases project name. You will be prompted to login if a name is provided. If a name is provided, a run will be initialized and loss and other metrics will be logged to the given wandb project.
* `ckpt_dir`: (Optional) directory to store checkpoints in, by default this is set to `/tmp/turbozero_checkpoints`
* `max_checkpoints`: (Optional) maximum number of most-recent checkpoints to retain (default: 2)
* `num_devices`: (Optional) number of hardware accelerators (GPUs/TPUs) to use. If not given, all available hardware accelerators are used
* `wandb_run`: (Optional) continues from an initialized `wandb` run if provided, otherwise a new one is initialized
* `extra_wandb_config`: (Optional) any extra metadata to store in the `wandb` run config

A training epoch is comprised of M collection steps, followed by N training steps sampling minibatches from replay memory. Optionally, any number of Testers evaluate the current model. At the end of each epoch, a checkpoint is saved.
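It can help to compare how many samples an epoch collects with how many it consumes. The sketch below does that arithmetic for the values used in this notebook's `Trainer` call, assuming each collection step yields one sample per parallel environment (data augmentation would add more). The helper name is hypothetical.

```python
def epoch_sample_budget(batch_size, collection_steps_per_epoch,
                        train_batch_size, train_steps_per_epoch):
    """Return (samples collected, samples consumed, consumed/collected ratio)."""
    collected = batch_size * collection_steps_per_epoch
    consumed = train_batch_size * train_steps_per_epoch
    return collected, consumed, consumed / collected

collected, consumed, ratio = epoch_sample_budget(
    batch_size=1024,
    collection_steps_per_epoch=256,
    train_batch_size=4096,
    train_steps_per_epoch=64,
)
print(collected, consumed, ratio)  # 262144 262144 1.0
```

A ratio of 1.0 means that, on average, each collected sample is drawn about once per epoch. Raising `train_steps_per_epoch` reuses samples more heavily, while raising `collection_steps_per_epoch` favors fresher data.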

If you are using one or more GPUs (recommended), TurboZero by default will run on all your available hardware.

In [None]:
from functools import partial
from core.testing.two_player_baseline import TwoPlayerBaseline
from core.training.loss_fns import az_default_loss_fn
from core.training.train import Trainer
import optax

trainer = Trainer(
    batch_size = 1024,
    train_batch_size = 4096,
    warmup_steps = 0,
    collection_steps_per_epoch = 256,
    train_steps_per_epoch = 64,
    nn = resnet,
    loss_fn = partial(az_default_loss_fn, l2_reg_lambda = 0.0),
    optimizer = optax.adam(1e-3),
    evaluator = az_evaluator,
    memory_buffer = replay_memory,
    max_episode_steps = 80,
    env_step_fn = step_fn,
    env_init_fn = init_fn,
    state_to_nn_input_fn=state_to_nn_input,
    testers = [
        TwoPlayerBaseline(num_episodes=128, baseline_evaluator=baseline_az, render_fn=render_fn, render_dir='.', name='pretrained'),
        TwoPlayerBaseline(num_episodes=128, baseline_evaluator=greedy_az, render_fn=render_fn, render_dir='.', name='greedy'),
    ],
    evaluator_test = az_evaluator_test,
    data_transform_fns=transforms
    # wandb_project_name = 'turbozero-othello' 
)

## Training

Now all that's left to do is to kick off the training loop! We need to pass an initial seed for reproducibility, and the number of epochs to run for!

If you've set up `wandb`, you can track metrics via plots in the run dashboard. Metrics will also be printed to the console. 

IMPORTANT: The first epoch will run very slowly! Nearly all of the training loop is JIT-compiled, and JAX traces and compiles each JIT-compiled function the first time it is run, which adds significant overhead. Expect epochs after the first to execute much more quickly. Typically, GPU utilization will also be low/zero during this compilation period.

It's also worth mentioning that the hyperparameters in this notebook are just here for example purposes. Regardless of the task, they will need to be tuned according to the characteristics of the environment as well as your available hardware and time/cost constraints.

In [15]:
output = trainer.train_loop(seed=0, num_epochs=100, eval_every=5)

Any GIFs generated will appear in the same directory as this notebook, and also on your `wandb` dashboard.