# Trainer NNX scratchpad

Generate training data by playing 100 random-vs-random Connect 4 games and saving their trajectories:
```
python rgi/main.py --game connect4 --player1 random --player2 random --num_games 100 --save_trajectories
```

In [None]:
import os
from typing import Any

from flax import nnx  # The Flax NNX API.
from functools import partial

import jax
import jax.numpy as jnp

devices = jax.devices()
print(devices)
# This notebook expects a GPU as the default JAX device. Raise explicitly:
# an `assert` would be silently skipped when Python runs with -O.
if devices[0].platform != 'gpu':
    raise RuntimeError(
        f"expected a GPU as the default JAX device, got platform {devices[0].platform!r}"
    )

In [None]:
from rgi.core import trajectory
from rgi.players.zerozero.zerozero_trainer import ZeroZeroTrainer
from rgi.players.zerozero.zerozero_model import ZeroZeroModel

In [None]:
from rgi.core import game_registry
GAMES: dict[str, game_registry.RegisteredGame[Any, Any, Any]] = game_registry.GAME_REGISTRY
PLAYERS: dict[str, Any] = game_registry.PLAYER_REGISTRY

# Instantiate the Connect 4 game and its registered helper objects.
game_name = "connect4"
registered_game = GAMES[game_name]
serializer = registered_game.serializer_fn()
state_embedder = registered_game.state_embedder_fn()
action_embedder = registered_game.action_embedder_fn()
game = registered_game.game_fn()

# Saved trajectories are looked up relative to the current working directory.
trajectories_dir = os.path.join("..", "data", "trajectories", game_name)
trajectories_glob = os.path.join(trajectories_dir, "*.trajectory.npy")
trajectories = trajectory.load_trajectories(trajectories_glob)

In [None]:

class Connect4StateEmbedder(nnx.Module):
    """Embed encoded Connect 4 states as L2-normalized vectors.

    Pipeline: drop the final element of each encoded state, view the rest as a
    6x7 single-channel grid, apply two 3x3 conv + ReLU layers, flatten, then a
    ReLU hidden dense layer and a linear projection to ``embedding_dim``.
    Each output row is scaled to unit length.
    """

    embedding_dim: int = 64  # width of the output embedding
    hidden_dim: int = 256  # width of the hidden dense layer

    def __init__(self, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        # The flattened size assumes the convs keep the 6x7 grid
        # (nnx.Conv uses 'SAME' padding by default).
        flat_features = 6 * 7 * 64
        self.linear1 = nnx.Linear(flat_features, self.hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(self.hidden_dim, self.embedding_dim, rngs=rngs)

    def _state_to_array(self, encoded_state_batch: jax.Array) -> jax.Array:
        """Strip the trailing element of each row and reshape to (batch, 6, 7)."""
        # NOTE(review): the dropped trailing element presumably encodes whose turn it is;
        # confirm against the serializer.
        return encoded_state_batch[:, :-1].reshape([-1, 6, 7])

    def __call__(self, encoded_state_batch: jax.Array):
        boards = self._state_to_array(encoded_state_batch)[..., None]  # add a channel axis
        features = nnx.relu(self.conv2(nnx.relu(self.conv1(boards))))
        features = features.reshape(features.shape[0], -1)  # flatten, keep the batch axis
        embedding = self.linear2(nnx.relu(self.linear1(features)))

        # Shift every component by a tiny constant before normalizing so that an
        # all-zero output still has a non-zero norm (no division by zero).
        epsilon = 1e-8
        embedding = embedding + epsilon
        return embedding / jnp.linalg.norm(embedding, axis=1, keepdims=True)


state_0 = game.initial_state()
state_1 = game.next_state(state_0, 2)

# Serialize the initial state and the state returned by next_state(state_0, 2)
# into one stacked batch.
encoded_states_batch = jnp.array(
    [serializer.state_to_jax_array(game, s) for s in (state_0, state_1)]
)
print(f'encoded_states_batch.shape: {encoded_states_batch.shape}')

# Rebinds `state_embedder` to the trainable NNX module defined above.
state_embedder = Connect4StateEmbedder(rngs=nnx.Rngs(0))


In [None]:
class Connect4ActionEmbedder(nnx.Module):
    """Learned embedding table for Connect 4 actions numbered 1..num_actions."""

    embedding_dim: int = 64  # width of each action embedding
    num_actions: int = 7  # number of distinct actions (rows in the table)

    def __init__(self, *, rngs: nnx.Rngs):
        self.embedding = nnx.Embed(
            num_embeddings=self.num_actions,
            features=self.embedding_dim,
            rngs=rngs,
        )

    def __call__(self, action: jax.Array) -> jax.Array:
        # Actions are numbered from 1, table rows from 0.
        zero_based = action - 1
        return self.embedding(zero_based)

    def all_action_embeddings(self):
        """Return the embeddings of actions 1..num_actions, one row per action, in order."""
        every_action = jnp.arange(1, self.num_actions + 1)
        return self(every_action)


action_embedder = Connect4ActionEmbedder(rngs=nnx.Rngs(0))





In [None]:
print(f'trajectories_glob: {trajectories_glob}')
print(f'num_trajectories: {len(trajectories)}')

from collections import Counter

# Count games by (first action, final_rewards[0]). final_rewards[0] is presumably
# player 1's final reward; confirm against the trajectory format.
c = Counter((t.actions[0].item(), t.final_rewards[0].item()) for t in trajectories)

for action in range(1, 7 + 1):
    win_count, loss_count = c[(action, 1.0)], c[(action, -1.0)]
    # Only rewards of exactly +1 / -1 are counted (e.g. draws are ignored).
    # A column with no decisive games would otherwise raise ZeroDivisionError.
    decided = win_count + loss_count
    win_pct = f'{win_count / decided * 100:.2f}%' if decided else 'n/a'
    print(f'action: {action}, win: {win_count:3d}, loss: {loss_count:3d}, win_pct: {win_pct}')

# Example output from one run:
# num_trajectories: 100
# action: 1, win:  14, loss:   9, win_pct: 60.87%
# action: 2, win:   6, loss:   6, win_pct: 50.00%
# action: 3, win:   8, loss:   4, win_pct: 66.67%
# action: 4, win:  11, loss:   3, win_pct: 78.57%
# action: 5, win:  11, loss:   5, win_pct: 68.75%
# action: 6, win:   8, loss:   5, win_pct: 61.54%
# action: 7, win:   5, loss:   5, win_pct: 50.00%


In [None]:
from rgi.core.trajectory import EncodedTrajectory

# define trajectory NamedTuple
from typing import NamedTuple
class TrajectoryStep(NamedTuple):
    """One (state, action, next_state, reward) transition taken from a trajectory."""

    move_index: int  # index i of this step within its trajectory
    state: jax.Array  # states[i]
    action: jax.Array  # actions[i]
    next_state: jax.Array  # states[i + 1]
    reward: jax.Array  # state_rewards[i]


def unroll_trajectory(encoded_trajectories: list[EncodedTrajectory]):
    """Lazily yield a TrajectoryStep for every consecutive state pair.

    For each trajectory ``traj``, yields one step per ``i`` in
    ``range(traj.length - 1)``, pairing ``states[i]`` with ``states[i + 1]``.
    Trajectories are visited in order, steps in increasing ``i``.
    """
    for traj in encoded_trajectories:
        states = traj.states
        yield from (
            TrajectoryStep(
                move_index=i,
                state=states[i],
                action=traj.actions[i],
                next_state=states[i + 1],
                reward=traj.state_rewards[i],
            )
            for i in range(traj.length - 1)
        )

# Flatten every loaded trajectory into a single list of transitions.
t_steps_list = [*unroll_trajectory(trajectories)]
print(f'num_trajectory_steps: {len(t_steps_list)}')

In [None]:
# Stack the pre-move state of every transition into one batch array,
# then smoke-test the state embedder on the first two rows.
state_batch = jnp.array([ts.state for ts in t_steps_list])
state_embedder(state_batch[:2])

In [None]:
# Stack the action taken at every transition and embed the whole batch.
action_batch = jnp.array([ts.action for ts in t_steps_list])
action_embedder(action_batch)

In [None]:
import optax

learning_rate = 0.005
momentum = 0.9

# optax.adamw's second positional parameter is b1 (the first-moment decay rate);
# name it explicitly so "momentum" does not look like a separate setting.
# The optimizer wraps only state_embedder.
optimizer = nnx.Optimizer(state_embedder, optax.adamw(learning_rate, b1=momentum))
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)

In [None]:
tuple(map(len, (action_batch, state_batch)))  # (num actions, num states); should match

In [None]:
def _action_labels(batch):
  """Convert the batch's 1-indexed actions into 0-indexed class labels."""
  # Same convention as Connect4ActionEmbedder and loss_fn_2: action k -> class k - 1.
  return batch['action'] - 1


def loss_fn(model: Connect4StateEmbedder, batch):
  """Mean softmax cross-entropy using the model's output vector as class logits.

  Args:
    model: the state embedder; its output row is treated as logits.
    batch: dict with 'state' (encoded states) and 'action' (1-indexed actions).

  Returns:
    (loss, logits): the scalar mean loss and the raw model outputs (aux for metrics).
  """
  logits = model(batch['state'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=_action_labels(batch)
  ).mean()
  return loss, logits

@nnx.jit
def train_step(model: Connect4StateEmbedder, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Train for a single step: update metrics and apply one optimizer update in place."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  # Metric labels must match the labels used by the loss.
  metrics.update(loss=loss, logits=logits, labels=_action_labels(batch))  # In-place updates.
  optimizer.update(grads)  # In-place updates.

@nnx.jit
def eval_step(model: Connect4StateEmbedder, metrics: nnx.MultiMetric, batch):
  """Accumulate loss/accuracy metrics for one batch without updating the model."""
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=_action_labels(batch))  # In-place updates.




In [None]:
batch = dict(state=state_batch, action=action_batch)
loss, logits = loss_fn(state_embedder, batch)

(loss, logits)



In [None]:
# metrics_history = {
#   'train_loss': [],
#   'train_accuracy': [],
#   'test_loss': [],
#   'test_accuracy': [],
# }

# for _ in range(1000):
#     train_step(state_embedder, optimizer, metrics, batch)

# for metric, value in metrics.compute().items():  # Compute the metrics.
#     metrics_history[f'train_{metric}'].append(value)  # Record the metrics.
# metrics.reset()  # Reset the metrics for the test set.

# eval_step(state_embedder, metrics, batch)

# # Log the test metrics.
# for metric, value in metrics.compute().items():
#     metrics_history[f'test_{metric}'].append(value)
# metrics.reset()  # Reset the metrics for the next training epoch.

# print(
#     f"[train] step: {1}, "
#     f"loss: {metrics_history['train_loss'][-1]}, "
#     f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
# )
# print(
#     f"[test] step: {1}, "
#     f"loss: {metrics_history['test_loss'][-1]}, "
#     f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
# )

# # [train] step: 1, loss: 3.2814717292785645, accuracy: 87.06680297851562
# # [test] step: 1, loss: 3.252131223678589, accuracy: 90.15047454833984

# # [train] step: 1, loss: 3.252079963684082, accuracy: 90.15047454833984
# # [test] step: 1, loss: 3.2519733905792236, accuracy: 90.15047454833984

In [None]:
state_embedder(state_batch[0:2])  # embed the first two states again

In [None]:
batch_2 = {key: batch[key][:2] for key in batch}  # first two rows of every field

In [None]:
def loss_fn_2(state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, batch):
    """Cross-entropy of the taken action under state/action embedding dot-product scores.

    Args:
        state_embedder: maps batch['state'] to one embedding row per state.
        action_embedder: supplies one embedding row per action via all_action_embeddings().
        batch: dict with 'state' (encoded states) and 'action' (1-indexed actions).

    Returns:
        (loss, logits): the scalar mean loss, and the score matrix whose entry
        [b, a] is the dot product of state b's and action a's embeddings.
    """
    state_vecs = state_embedder(batch['state'])
    action_table = action_embedder.all_action_embeddings()
    scores = state_vecs @ action_table.T
    # Actions are 1-indexed; class indices are 0-indexed.
    targets = batch['action'] - 1
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits=scores, labels=targets)
    return per_example.mean(), scores

batch_1 = {'state': state_batch[:1], 'action': action_batch[:1]}
batch_2 = {'state': state_batch[:2], 'action': action_batch[:2]}
loss_fn_2(state_embedder, action_embedder, batch_2)



In [None]:
@nnx.jit
def train_step(state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Train for a single step: accumulate metrics and apply one optimizer update in place."""
  # NOTE: nnx.value_and_grad differentiates w.r.t. the first argument only
  # (argnums=0 by default). Only state_embedder receives gradients, so
  # action_embedder is not updated by this step.
  grad_fn = nnx.value_and_grad(loss_fn_2, has_aux=True)
  (loss, logits), grads = grad_fn(state_embedder, action_embedder, batch)
  # loss_fn_2 uses 0-indexed labels (action - 1); the metrics must match.
  metrics.update(loss=loss, logits=logits, labels=batch['action'] - 1)  # In-place updates.
  optimizer.update(grads)  # In-place updates.

@nnx.jit
def eval_step(state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, metrics: nnx.MultiMetric, batch):
  """Accumulate loss/accuracy metrics for one batch without updating any parameters."""
  loss, logits = loss_fn_2(state_embedder, action_embedder, batch)
  # Same 0-indexed labels as loss_fn_2 and train_step. Passing the raw
  # 1-indexed actions here made test accuracy off by one.
  metrics.update(loss=loss, logits=logits, labels=batch['action'] - 1)  # In-place updates.

In [None]:
# batch['state'].shape
# batch_2['state'].shape

# jnp.array([batch_2['state'], batch_2['state'], batch_2['state'], batch_2['state']]).shape

# Enlarge batch_2 by tiling it 1000 times along the batch (first) axis.
batch_repeated = {
    'state': jnp.tile(batch_2['state'], (1000, 1)),
    'action': jnp.tile(batch_2['action'], (1000,)),
}
(batch_repeated['state'].shape, batch_repeated['action'].shape)


In [None]:
metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'test_loss': [],
  'test_accuracy': [],
}

train_steps = 1200
eval_every = 200
batch_size = 32  # NOTE(review): not used in this cell; every step trains on all of batch_repeated.
# Set to True to run train_step op-by-op (un-jitted) while debugging; this is much slower.
debug_disable_jit = False

# Overfitting sanity check: every train and test "batch" is the same repeated batch.
# The schedule length comes from train_steps, so the `step == train_steps - 1`
# evaluation below really is the final step.
train_batches = [(i, batch_repeated) for i in range(train_steps)]
test_batches = [(i, batch_repeated) for i in range(100)]

# The loop variable is `train_batch` rather than `batch`, so running this cell
# does not overwrite the notebook-level `batch` dict.
for step, train_batch in train_batches:
  # One optimization step. It updates the model parameters, the optimizer
  # state and the accumulated training metrics in place.
  with jax.disable_jit(debug_disable_jit):
    train_step(state_embedder, action_embedder, optimizer, metrics, train_batch)

  # Evaluate every `eval_every` steps and after the final step.
  if step > 0 and (step % eval_every == 0 or step == train_steps - 1):
    # Record, then reset, the training metrics accumulated since the last evaluation.
    for metric, value in metrics.compute().items():
      metrics_history[f'train_{metric}'].append(value)
    metrics.reset()

    # Accumulate metrics over every test batch.
    for _, test_batch in test_batches:
      eval_step(state_embedder, action_embedder, metrics, test_batch)

    # Record the test metrics and reset before the next training interval.
    for metric, value in metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
    metrics.reset()

    print(
      f"[train] step: {step}, "
      f"loss: {metrics_history['train_loss'][-1]}, "
      f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
    )
    print(
      f"[test] step: {step}, "
      f"loss: {metrics_history['test_loss'][-1]}, "
      f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
    )

In [None]:
class StateToActionModel(nnx.Module):
    """Score actions for a batch of states via state/action embedding dot products.

    ``logits`` gives the raw scores. ``probs`` applies a softmax over the last
    (action) axis.
    """

    state_embedder: Connect4StateEmbedder
    action_embedder: Connect4ActionEmbedder

    embedding_dim: int = 64
    num_actions: int = 7

    def __init__(self, state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, *, rngs: nnx.Rngs):
        # Both sub-modules are passed in already constructed; `rngs` is accepted but unused.
        self.state_embedder, self.action_embedder = state_embedder, action_embedder

    def __call__(self, state: jax.Array) -> jax.Array:
        """Same as :meth:`logits`."""
        return self.logits(state)

    def logits(self, state_batch: jax.Array) -> jax.Array:
        """Return scores whose rows follow state_batch and whose columns follow all_action_embeddings()."""
        state_vecs = self.state_embedder(state_batch)
        action_vecs = self.action_embedder.all_action_embeddings()
        return state_vecs @ action_vecs.T

    def probs(self, state_batch: jax.Array) -> jax.Array:
        """Softmax of :meth:`logits` over the last (action) axis."""
        return jax.nn.softmax(self.logits(state_batch))


state_action_model = StateToActionModel(
    state_embedder, action_embedder, rngs=nnx.Rngs(0)
)

# Score the first two states: raw logits and their softmax probabilities.
logits = state_action_model.logits(state_batch[:2])
probs = state_action_model.probs(state_batch[:2])

print(f'logits: {logits}')
print(f'probs: {probs}')