# Trainer scratchpad

Generate training data
```
python rgi/main.py --game connect4 --player1 random --player2 random --num_games 100 --save_trajectories
```

In [19]:
from typing import Any
import os
import importlib

In [37]:
import jax
import optax


assert jax.devices()[0].platform == 'gpu'

In [5]:
from rgi.core import game_registry
GAMES: dict[str, game_registry.RegisteredGame[Any, Any, Any]] = game_registry.GAME_REGISTRY
PLAYERS: dict[str, Any] = game_registry.PLAYER_REGISTRY

In [6]:
from rgi.core import trajectory
from rgi.players.zerozero.zerozero_trainer import ZeroZeroTrainer
from rgi.players.zerozero.zerozero_model import ZeroZeroModel

In [7]:
from rgi.games import connect4
from rgi import main

In [8]:
game_name = "connect4"
registered_game = GAMES[game_name]
serializer = registered_game.serializer_fn()
state_embedder = registered_game.state_embedder_fn()
action_embedder = registered_game.action_embedder_fn()
game = registered_game.game_fn()

trajectories_glob = os.path.join("..", "data", "trajectories", game_name, "*.trajectory.npy")
trajectories = trajectory.load_trajectories(trajectories_glob)

In [12]:
print(f'trajectories_glob: {trajectories_glob}')
print(f'num_trajectories: {len(trajectories)}')
print('actions:')
for t in trajectories[:5]: print(f'  {t.actions}')


trajectories_glob: ../data/trajectories/connect4/*.trajectory.npy
num_trajectories: 100
actions:
  [6 7 4 5 5 7 1 7 1 1 4 5 2 4 5 1 1 2 2 4 7 7 5 7 3]
  [4 1 6 4 4 5 1 6 2 1 6 6 1 6 7 4 5 7 2 4 6 5 1 1 2 7 4 5 7]
  [3 4 2 5 5 2 5 6 7 4 4 1 4 5 7 7 4 2 2 5 7 4 5 3 6]
  [7 2 7 4 6 2 4 4 6 7 1 3 2 6 6 3 2 6 7 7 1 3 2 7 1 6 2]
  [5 4 1 5 6 3 4 3 5 3 7 6 7 7 3 7 6 7 5 6 5 2 7 3 6 6 2 5 4]


In [None]:
# from rgi.players.zerozero import zerozero_model
# importlib.reload(zerozero_model)
# from rgi.players.zerozero.zerozero_model import ZeroZeroModel

# from rgi.players.zerozero import zerozero_trainer
# importlib.reload(zerozero_trainer)
# from rgi.players.zerozero.zerozero_trainer import ZeroZeroTrainer

# model = ZeroZeroModel(
#     state_embedder=state_embedder,
#     action_embedder=action_embedder,
#     possible_actions=game.all_actions(),
#     embedding_dim=64,
#     hidden_dim=128,
#     shared_dim=256,
# )

# trainer = ZeroZeroTrainer(model, serializer, game)
# absolute_checkpoint_dir = os.path.abspath(os.path.join("data", "checkpoints", game_name))
# trainer.load_checkpoint(absolute_checkpoint_dir)

# trainer.train(trajectories, num_epochs=5, batch_size=32)

In [None]:


# # Save the trained model
# trainer.save_checkpoint(absolute_checkpoint_dir)
# print(f"Model training completed. Checkpoint saved to {args.checkpoint_dir}")


In [74]:
from flax.typing import FrozenVariableDict
import jax
import jax.numpy as jnp
from flax import linen as nn

import functools
import jax
import jax.numpy as jnp
from flax.training import train_state, checkpoints
import optax
from typing import Any, Tuple, List
import numpy as np
from tqdm import tqdm
import jax.tree_util as jtu
from rgi.core.base import GameSerializer, Game

from rgi.core.trajectory import load_trajectories, EncodedTrajectory
from rgi.players.zerozero.zerozero_model import ZeroZeroModel, zerozero_loss
from typing import Iterator

TStateEmbedder = connect4.Connect4StateEmbedder
TActionEmbedder = connect4.Connect4ActionEmbedder
TAction = int
TEmbedding = jax.Array
TEncodedGameState = jax.Array

class ZZActionModel(nn.Module):
    state_embedder: connect4.Connect4StateEmbedder
    action_embedder: connect4.Connect4ActionEmbedder
    possible_actions: list[TAction]
    embedding_dim: int = 64
    hidden_dim: int = 128
    shared_dim: int = 128

    def setup(self) -> None:
        self.shared_state_layer: nn.Module = nn.Sequential([nn.Dense(self.shared_dim), nn.relu])
        self.shared_action_layer: nn.Module = nn.Sequential([nn.Dense(self.shared_dim), nn.relu])

        self.policy_head: nn.Module = nn.Sequential(
            [nn.Dense(self.hidden_dim), nn.relu, nn.Dense(self.embedding_dim)])

    @nn.compact
    def __call__(self, state: TEncodedGameState) -> tuple[TEmbedding, jax.Array]:
        state_embedding = self.state_embedder(state)
        policy_embedding = self.policy_head(state_embedding)

        return policy_embedding

    def compute_action_logits(self, policy_embedding: TEmbedding) -> jax.Array:
        all_action_embeddings = jnp.array([self.action_embedder(action) for action in self.possible_actions])
        return jnp.dot(policy_embedding, jnp.transpose(all_action_embeddings))

    def compute_action_probabilities(self, policy_embedding: TEmbedding) -> jax.Array:
        return jax.nn.softmax(self.compute_action_logits(policy_embedding))


def zz_action_loss(
    model: ZZActionModel,
    params: dict[str, Any],
    state: TEncodedGameState,
    action: TAction,
    policy_target: jax.Array,
) -> tuple[float, dict[str, float]]:

    policy_embedding: jax.Array

    policy_embedding = model.apply(params, state)  # type: ignore
    action_probs = model.apply(
        params,
        method=model.compute_action_probabilities,
        policy_embedding=policy_embedding,
    )
    assert isinstance(action_probs, jax.Array)
    policy_loss = -jnp.sum(policy_target * jnp.log(action_probs + 1e-8))

    total_loss = jnp.mean(policy_loss)
    loss_dict = {
        "total_loss": total_loss,
        "policy_loss": jnp.mean(policy_loss),
    }

    return total_loss, loss_dict

class ZZActionTrainer:
    def __init__(
        self,
        model: ZZActionModel,
        serializer: connect4.Connect4Serializer,
        game: connect4.Connect4Game,
        learning_rate: float = 1e-4,
    ):
        self.model = model
        self.serializer = serializer
        self.game = game
        self.optimizer = optax.adam(learning_rate)
        self.state = None

    # TODO: Rename
    def create_train_state(self, rng: jax.random.PRNGKey) -> train_state.TrainState:
        dummy_state = self.serializer.state_to_jax_array(self.game, self.game.initial_state())
        dummy_action = self.serializer.action_to_jax_array(self.game, self.game.all_actions()[0])
        # Add batch dimension to dummy inputs
        dummy_state_batch = jnp.expand_dims(dummy_state, axis=0)
        dummy_action_batch = jnp.expand_dims(dummy_action, axis=0)

        params = self.model.init(rng, dummy_state_batch)
        return train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=self.optimizer)


    @jax.disable_jit()  # TODO: Remove this
    @functools.partial(jax.jit, static_argnums=0)
    def train_step(self, state: train_state.TrainState, batch: Tuple[Any, ...]) -> Tuple[train_state.TrainState, dict]:
        def loss_fn(params) -> tuple[float, dict[str, float]]:
            state_input, action, next_state, reward, policy_target = batch
            reward = jnp.asarray(reward)  # Ensure reward is a jax.Array
            loss, loss_dict = zz_action_loss(
                self.model,
                params,
                state_input,
                action,
                policy_target,
            )
            return loss, loss_dict

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (_, loss_dict), grads = grad_fn(state.params)
        return state.apply_gradients(grads=grads), loss_dict

    def create_batches(self, trajectories: List[EncodedTrajectory], batch_size: int) -> Iterator[Tuple[Any, ...]]:
        states, actions, next_states, rewards, policy_targets = [], [], [], [], []

        possible_actions = self.model.possible_actions
        for trajectory in trajectories:
            for i in range(trajectory.length - 1):
                states.append(trajectory.states[i])
                actions.append(trajectory.actions[i])
                next_states.append(trajectory.states[i + 1])
                rewards.append(trajectory.state_rewards[i])

                decoded_action = self.serializer.jax_array_to_action(self.game, trajectory.actions[i])
                decoded_action_index = possible_actions.index(decoded_action)
                one_hot_action = jax.nn.one_hot(decoded_action_index, num_classes=len(possible_actions))
                policy_targets.append(one_hot_action)

        dataset = list(zip(states, actions, next_states, rewards, policy_targets))
        np.random.shuffle(dataset)

        for i in range(0, len(dataset), batch_size):
            yield tuple(map(np.array, zip(*dataset[i : i + batch_size])))

    def train(self, trajectories: List[EncodedTrajectory], num_epochs: int, batch_size: int):
        if self.state is None:
            raise ValueError("TrainState is not initialized. Call create_train_state first.")

        for epoch in range(num_epochs):
            epoch_losses = []
            batches = self.create_batches(trajectories, batch_size)

            with tqdm(total=len(trajectories), desc=f"Epoch {epoch + 1}/{num_epochs}") as pbar:
                for batch in batches:
                    self.state, loss_dict = self.train_step(self.state, batch)
                    epoch_losses.append(loss_dict["total_loss"])
                    pbar.update(batch_size)
                    pbar.set_postfix({"loss": np.mean(epoch_losses)})

            print(f"Epoch {epoch + 1} - Average Loss: {np.mean(epoch_losses):.4f}")

    def save_checkpoint(self, checkpoint_dir: str) -> None:
        if self.state is None:
            raise ValueError("No state to save. Please train the model first.")
        checkpoints.save_checkpoint(checkpoint_dir, self.state, step=self.state.step, keep=3)

    def load_checkpoint(self, checkpoint_dir: str) -> None:
        """Load checkpoint from directory. If no checkpoint is found, create a new state."""
        if self.state is None:
            # We need a state of the correct type to restore from a checkpoint.
            rng = jax.random.PRNGKey(0)
            self.state = self.create_train_state(rng)
        self.state = checkpoints.restore_checkpoint(checkpoint_dir, target=self.state)


In [79]:
from rgi.games import connect4
importlib.reload(connect4)

# from rgi.players.zerozero import zerozero_model
# importlib.reload(zerozero_model)
# from rgi.players.zerozero.zerozero_model import ZeroZeroModel

# from rgi.players.zerozero import zerozero_trainer
# importlib.reload(zerozero_trainer)
# from rgi.players.zerozero.zerozero_trainer import ZeroZeroTrainer

state_embedder = connect4.Connect4StateEmbedder()
action_embedder = connect4.Connect4ActionEmbedder()
serializer = connect4.Connect4Serializer()
game = connect4.Connect4Game()
all_actions: list[int] = game.all_actions()  # type: ignore

model = ZZActionModel(
    state_embedder=state_embedder,
    action_embedder=action_embedder,
    possible_actions=all_actions,
    embedding_dim=64,
    shared_dim=256,
)

trainer = ZZActionTrainer(model, serializer, game)
absolute_checkpoint_dir = os.path.abspath(os.path.join("data", "checkpoints", game_name))
trainer.load_checkpoint(absolute_checkpoint_dir)

trainer.train(trajectories, num_epochs=5, batch_size=32)

Epoch 1/5:   0%|          | 0/100 [08:44<?, ?it/s]


ScopeParamNotFoundError: Could not find parameter named "action_embeddings" in scope "/action_embedder". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamNotFoundError)