# Trainer NNX scratchpad

Generate training data by playing 100 random-vs-random Connect 4 games and saving their trajectories:
```
python rgi/main.py --game connect4 --player1 random --player2 random --num_games 100 --save_trajectories
```

In [1]:
import os
from typing import Any

from flax import nnx  # The Flax NNX API.
from functools import partial

import jax
import jax.numpy as jnp

devices = jax.devices()
print(devices)
# This scratchpad expects a GPU as the default JAX device; fail fast with a readable message otherwise.
assert devices[0].platform == 'gpu', f'Expected a GPU as the default JAX device, got: {devices}'

[CudaDevice(id=0)]


# Load Trajectories

In [2]:
from rgi.core import trajectory

game_name = "connect4"
# Saved trajectories live under ../data/trajectories/<game_name>/ as *.trajectory.npy files.
trajectories_glob = os.path.join(os.pardir, "data", "trajectories", game_name, "*.trajectory.npy")
trajectories = trajectory.load_trajectories(trajectories_glob)

print(f'num_trajectories: {len(trajectories)}')


num_trajectories: 100


In [3]:
print(f'trajectories_glob: {trajectories_glob}')
print(f'num_trajectories: {len(trajectories)}')

from collections import Counter

# Tally (first action, final_rewards[0]) over all games. final_rewards[0] is presumably
# the first player's outcome (+1 win / -1 loss) -- confirm against rgi.core.trajectory.
c = Counter((t.actions[0].item(), t.final_rewards[0].item()) for t in trajectories)

# Actions are 1-based column numbers 1..7. Games with any other reward (e.g. draws) are
# counted as neither a win nor a loss.
for action in range(1, 7 + 1):
    win_count, loss_count = c[(action, 1.0)], c[(action, -1.0)]
    decided = win_count + loss_count
    # Guard: a column never opened with (or only drawn) would otherwise divide by zero.
    win_pct = win_count / decided * 100 if decided else float('nan')
    print(f'action: {action}, win: {win_count:3d}, loss: {loss_count:3d}, win_pct: {win_pct:.2f}%')
    
# num_trajectories: 100
# action: 1, win:  14, loss:   9, win_pct: 60.87%
# action: 2, win:   6, loss:   6, win_pct: 50.00%
# action: 3, win:   8, loss:   4, win_pct: 66.67%
# action: 4, win:  11, loss:   3, win_pct: 78.57%
# action: 5, win:  11, loss:   5, win_pct: 68.75%
# action: 6, win:   8, loss:   5, win_pct: 61.54%
# action: 7, win:   5, loss:   5, win_pct: 50.00%

trajectories_glob: ../data/trajectories/connect4/*.trajectory.npy
num_trajectories: 100


2024-10-14 19:13:53.759849: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


action: 1, win:  14, loss:   9, win_pct: 60.87%
action: 2, win:   6, loss:   6, win_pct: 50.00%
action: 3, win:   8, loss:   4, win_pct: 66.67%
action: 4, win:  11, loss:   3, win_pct: 78.57%
action: 5, win:  11, loss:   5, win_pct: 68.75%
action: 6, win:   8, loss:   5, win_pct: 61.54%
action: 7, win:   5, loss:   5, win_pct: 50.00%


In [4]:
from rgi.core.trajectory import EncodedTrajectory

# define trajectory NamedTuple
from typing import NamedTuple
class TrajectoryStep(NamedTuple):
    """One (state, action, next_state, reward) transition unrolled from an encoded trajectory."""
    move_index: int  # 0-based position of this step within its trajectory
    state: jax.Array  # encoded state before `action` is applied
    action: jax.Array  # action taken from `state` (1-based column number in this notebook's usage)
    next_state: jax.Array  # encoded state at the following index of the trajectory
    reward: jax.Array  # value read from the trajectory's `state_rewards` at `move_index`


def unroll_trajectory(encoded_trajectories: list[EncodedTrajectory]):
    """Lazily yield every consecutive (state -> next_state) transition of every trajectory.

    A trajectory of `length` entries contributes `length - 1` steps; step `k` pairs
    `states[k]` with `actions[k]`, `states[k + 1]` and `state_rewards[k]`.
    """
    for traj in encoded_trajectories:
        states = traj.states
        for step_idx in range(traj.length - 1):
            yield TrajectoryStep(
                move_index=step_idx,
                state=states[step_idx],
                action=traj.actions[step_idx],
                next_state=states[step_idx + 1],
                reward=traj.state_rewards[step_idx],
            )

# Materialize all transitions across the loaded games.
t_steps_list = [*unroll_trajectory(trajectories)]
print(f'num_trajectory_steps: {len(t_steps_list)}')

num_trajectory_steps: 2193


In [15]:
# Stack every unrolled step into dense arrays, one row per step.
state_batch = jnp.array([step.state for step in t_steps_list])
action_batch = jnp.array([step.action for step in t_steps_list])

batch_full = dict(state=state_batch, action=action_batch)
# Tiny prefixes of the full batch, handy for smoke-testing the model and loss.
batch_1 = {key: arr[:1] for key, arr in batch_full.items()}
batch_2 = {key: arr[:2] for key, arr in batch_full.items()}

# The first two steps repeated 1000x: trivially memorizable, so 100% accuracy is easy.
batch_repeated = dict(
    state=jnp.tile(batch_2['state'], (1000, 1)),
    action=jnp.tile(batch_2['action'], (1000,)),
)


# Define models

In [6]:
from rgi.games import connect4
# Connect 4 game and serializer instances from rgi.games.connect4.
# NOTE(review): neither `game` nor `serializer` is referenced by the cells shown here.
game = connect4.Connect4Game()
serializer = connect4.Connect4Serializer()

In [7]:
class Connect4StateEmbedder(nnx.Module):
    """Map a batch of encoded Connect 4 states to `embedding_dim`-sized vectors.

    Architecture: board (batch, 6, 7, 1) -> Conv(1->32, 3x3) -> ReLU -> Conv(32->64, 3x3)
    -> ReLU -> flatten -> Linear(6*7*64 -> hidden_dim) -> ReLU -> Linear(hidden_dim -> embedding_dim).
    """

    # Class-level defaults; override per instance via the `__init__` keywords.
    embedding_dim: int = 64
    hidden_dim: int = 256

    def _state_to_array(self, encoded_state_batch: jax.Array) -> jax.Array:
        """Drop the last encoded element of each state and reshape the rest to (batch, 6, 7).

        NOTE(review): assumes each encoded state is 6*7 board cells followed by one extra
        element (presumably the player to move) that is discarded -- confirm against the
        connect4 encoder.
        """
        board_array = encoded_state_batch[:, :-1].reshape([-1, 6, 7])
        return board_array

    def __init__(
        self,
        *,
        rngs: nnx.Rngs,
        embedding_dim: int | None = None,
        hidden_dim: int | None = None,
    ):
        """Build the conv + MLP layers.

        Args:
            rngs: RNG streams used to initialize layer parameters.
            embedding_dim: Output embedding size; defaults to the class attribute (64).
                Must equal the action embedder's size, since logits are dot products.
            hidden_dim: Width of the hidden MLP layer; defaults to the class attribute (256).
        """
        # Only override the class defaults when explicitly requested, so existing callers
        # (and subclasses that change the class attributes) behave exactly as before.
        if embedding_dim is not None:
            self.embedding_dim = embedding_dim
        if hidden_dim is not None:
            self.hidden_dim = hidden_dim

        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        # 6*7*64 assumes both convs preserve the 6x7 spatial shape ('SAME' padding).
        self.linear1 = nnx.Linear(6 * 7 * 64, self.hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(self.hidden_dim, self.embedding_dim, rngs=rngs)

    def __call__(self, encoded_state_batch: jax.Array):
        """Return state embeddings of shape (batch, embedding_dim)."""
        x = self._state_to_array(encoded_state_batch)
        x = x[..., None]  # Add a trailing channel dimension for the convs.
        x = nnx.relu(self.conv1(x))
        x = nnx.relu(self.conv2(x))
        x = x.reshape(x.shape[0], -1)  # Flatten while preserving the batch dimension.
        x = nnx.relu(self.linear1(x))
        x = self.linear2(x)

        return x


state_embedder = Connect4StateEmbedder(rngs=nnx.Rngs(0))

print(state_embedder(trajectories[0].states[:1])[:,:5])



[[0. 0. 0. 0. 0.]]


In [8]:
class Connect4ActionEmbedder(nnx.Module):
    """Learned embedding table for Connect 4 actions given as 1-based column numbers."""

    # Class-level defaults; override per instance via the `__init__` keywords.
    embedding_dim: int = 64
    num_actions: int = 7

    def __init__(
        self,
        *,
        rngs: nnx.Rngs,
        embedding_dim: int | None = None,
        num_actions: int | None = None,
    ):
        """Create a (num_actions, embedding_dim) embedding table.

        Args:
            rngs: RNG streams used to initialize the table.
            embedding_dim: Embedding size; defaults to the class attribute (64). Must equal
                the state embedder's size, since logits are dot products.
            num_actions: Number of actions (board columns); defaults to the class attribute (7).
        """
        # Only override the class defaults when explicitly requested (backward compatible).
        if embedding_dim is not None:
            self.embedding_dim = embedding_dim
        if num_actions is not None:
            self.num_actions = num_actions
        self.embedding = nnx.Embed(num_embeddings=self.num_actions, features=self.embedding_dim, rngs=rngs)

    def __call__(self, action: jax.Array) -> jax.Array:
        """Embed 1-based actions in [1, num_actions]; returns shape action.shape + (embedding_dim,).

        NOTE(review): out-of-range actions are not checked here -- verify how nnx.Embed
        handles indices outside [0, num_actions).
        """
        # Shift 1-based actions to 0-based table rows.
        action = action - 1
        return self.embedding(action)

    def all_action_embeddings(self):
        """Return embeddings for actions 1..num_actions, shape (num_actions, embedding_dim)."""
        return self(jnp.arange(1, self.num_actions + 1))

action_embedder = Connect4ActionEmbedder(rngs=nnx.Rngs(0))


print(action_embedder(trajectories[0].actions[:1])[:,:5])


[[ 0.06664277  0.10936882 -0.18915226  0.08173456 -0.07859723]]


In [9]:
class StateToActionModel(nnx.Module):
    """Policy head: score each action by the dot product of state and action embeddings.

    logits[b, a] = dot(state_embedder(state_batch)[b], action_embedder.all_action_embeddings()[a])
    """
    state_embedder: Connect4StateEmbedder
    action_embedder: Connect4ActionEmbedder

    # NOTE(review): these class attributes are not read anywhere in this class.
    embedding_dim: int = 64
    num_actions: int = 7

    def __init__(self, state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, *, rngs: nnx.Rngs):
        # `rngs` is accepted but unused: this module only wraps the two (shared) embedders.
        self.state_embedder, self.action_embedder = state_embedder, action_embedder

    def __call__(self, state: jax.Array) -> jax.Array:
        """Alias for `logits`."""
        return self.logits(state)

    def logits(self, state_batch: jax.Array) -> jax.Array:
        """Return unnormalized action scores of shape (batch, num_actions)."""
        state_vecs = self.state_embedder(state_batch)
        action_table = self.action_embedder.all_action_embeddings()
        return jnp.matmul(state_vecs, action_table.T)

    def probs(self, state_batch: jax.Array) -> jax.Array:
        """Return action probabilities (softmax of `logits` over the action axis)."""
        return jax.nn.softmax(self.logits(state_batch), axis=-1)


state_action_model = StateToActionModel(state_embedder, action_embedder, rngs=nnx.Rngs(0))

# Sanity check on the first two steps: logits and their softmax.
logits = state_action_model.logits(state_batch[:2])
probs = state_action_model.probs(state_batch[:2])

print(f'logits:\n{logits}')
print(f'\nprobs:\n{probs}')

logits:
[[ 0.          0.          0.          0.          0.          0.
   0.        ]
 [-0.0733178  -0.04830146 -0.0082891  -0.00399575  0.04669977 -0.08175676
  -0.01695381]]

probs:
[[0.14285715 0.14285715 0.14285715 0.14285715 0.14285715 0.14285715
  0.14285715]
 [0.13621384 0.13966438 0.14536598 0.14599143 0.15358336 0.13506916
  0.14411187]]


# Train Model

In [10]:
import optax

learning_rate = 0.005
momentum = 0.9  # Used as AdamW's b1 (first-moment decay rate).

# NOTE(review): only `state_embedder` is registered with the optimizer, so the action
# embeddings keep their initial values during training.
optimizer = nnx.Optimizer(state_embedder, optax.adamw(learning_rate=learning_rate, b1=momentum))
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)


In [None]:
# def loss_fn(model: Connect4StateEmbedder, batch):
#   logits = model(batch['state'])
#   loss = optax.softmax_cross_entropy_with_integer_labels(
#     logits=logits, labels=batch['action']
#   ).mean()
#   return loss, logits

# @nnx.jit
# def train_step(model: Connect4StateEmbedder, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
#   """Train for a single step."""
#   grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
#   (loss, logits), grads = grad_fn(model, batch)
#   metrics.update(loss=loss, logits=logits, labels=batch['action'])  # In-place updates.
#   optimizer.update(grads)  # In-place updates.

# @nnx.jit
# def eval_step(model: Connect4StateEmbedder, metrics: nnx.MultiMetric, batch):
#   loss, logits = loss_fn(model, batch)
#   metrics.update(loss=loss, logits=logits, labels=batch['action'])  # In-place updates.




In [12]:
@nnx.jit
def l2_loss(model, alpha=0.001):
    """Return the sum over every nnx.Param leaf of `model` (weights and biases) of alpha * sum(w**2)."""
    param_leaves = jax.tree.leaves(nnx.state(model, nnx.Param))
    total = 0
    for w in param_leaves:
        total = total + alpha * (w ** 2).sum()
    return total

@nnx.jit
def loss_fn_2(state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, batch, l2_weight: float = 1e-4):
    """Policy cross-entropy plus an L2 penalty on both embedders.

    Args:
        state_embedder: Embeds `batch['state']`.
        action_embedder: Supplies one embedding per action.
        batch: Dict with 'state' and 'action' arrays; actions are 1-based.
        l2_weight: Multiplier on the summed `l2_loss` of both embedders. `l2_loss` already
            scales by its own `alpha` (default 0.001), so the effective coefficient on
            sum(w**2) is l2_weight * alpha.

    Returns:
        `(total_loss, logits)`; logits are returned as auxiliary output.
    """
    logits = state_embedder(batch['state']) @ action_embedder.all_action_embeddings().T

    # Shift 1-based actions onto the 0-based logit columns.
    per_example_ce = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch['action'] - 1)
    data_loss = per_example_ce.mean()

    reg = l2_loss(state_embedder) + l2_loss(action_embedder)
    return data_loss + l2_weight * reg, logits

loss_fn_2(state_embedder, action_embedder, batch_2)

(Array(1.94158, dtype=float32),
 Array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ],
        [-0.07331781, -0.04830145, -0.0082891 , -0.00399575,  0.04669978,
         -0.08175675, -0.01695381]], dtype=float32))

In [13]:
@nnx.jit
def train_step(state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Run one optimization step; updates `metrics` and `optimizer` (and its model) in place.

    NOTE(review): gradients are taken only w.r.t. the first argument (`state_embedder`),
    so `action_embedder` receives no update here even though its L2 term is in the loss.
    """
    loss_and_grad_fn = nnx.value_and_grad(loss_fn_2, argnums=0, has_aux=True)
    (loss, logits), state_grads = loss_and_grad_fn(state_embedder, action_embedder, batch)
    # Labels shifted to 0-based, matching the logit columns.
    metrics.update(loss=loss, logits=logits, labels=batch['action'] - 1)
    optimizer.update(state_grads)

@nnx.jit
def eval_step(state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, metrics: nnx.MultiMetric, batch):
    """Accumulate loss/accuracy on `batch` into `metrics` without updating any parameters.

    `batch['action']` holds 1-based actions while logit columns are 0-based. The labels are
    shifted by one here, exactly as in `loss_fn_2` and `train_step`; without the shift the
    reported accuracy is scored against the neighbouring column.
    """
    loss, logits = loss_fn_2(state_embedder, action_embedder, batch)
    metrics.update(loss=loss, logits=logits, labels=batch['action'] - 1)  # In-place update.

In [None]:
metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}

train_steps = 2000  # Total optimizer steps; `train_batches` below is sized from this.
eval_every = 200    # Log train metrics and run eval every this many steps (and on the last step).
batch_size = 32     # NOTE(review): not used yet -- every step currently trains on `batch_full`.
num_test_batches = 100

# TODO: How do we stream a bunch of different batches?
# For now every step sees the whole dataset (`batch_full`), and the "test" set is the same
# data, so test metrics measure fit rather than generalization. Swap in `batch_repeated`
# for a trivially memorizable sanity check.
train_batches = [(i, batch_full) for i in range(train_steps)]
test_batches = [(i, batch_full) for i in range(num_test_batches)]

# Derived from the actual batch list so the final-step log always fires.
last_step = len(train_batches) - 1

for step, batch in train_batches:
    # In-place update of the model parameters, the optimizer state and the running train
    # metrics. (Wrap in `with jax.disable_jit():` to debug eagerly.)
    train_step(state_embedder, action_embedder, optimizer, metrics, batch)

    # Checkpoint every `eval_every` steps, plus the final step.
    if step > 0 and (step % eval_every == 0 or step == last_step):
        # Log the training metrics accumulated since the last reset.
        for metric, value in metrics.compute().items():
            metrics_history[f'train_{metric}'].append(value)
        metrics.reset()  # Reset before accumulating test metrics.

        # Compute the metrics on the test set.
        for _, test_batch in test_batches:
            eval_step(state_embedder, action_embedder, metrics, test_batch)

        # Log the test metrics.
        for metric, value in metrics.compute().items():
            metrics_history[f'test_{metric}'].append(value)
        metrics.reset()  # Reset for the next stretch of training.

        print(
            f"[train] step: {step}, "
            f"loss: {metrics_history['train_loss'][-1]}, "
            f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
        )
        # print(
        #   f"[test] step: {step}, "
        #   f"loss: {metrics_history['test_loss'][-1]}, "
        #   f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
        # )


## Outputs with embedding normalization.
# [train] step: 200, loss: 1.1590911149978638, accuracy: 100.0
# [train] step: 400, loss: 1.1590913534164429, accuracy: 100.0
# [train] step: 600, loss: 1.1590913534164429, accuracy: 100.0
# [train] step: 800, loss: 1.1590913534164429, accuracy: 100.0
# [train] step: 1000, loss: 1.1590913534164429, accuracy: 100.0
# [train] step: 1199, loss: 1.1590930223464966, accuracy: 100.0
# [train] step: 1200, loss: 1.159092903137207, accuracy: 100.0
# [train] step: 1400, loss: 1.1590913534164429, accuracy: 100.0
# [train] step: 1600, loss: 1.1590913534164429, accuracy: 100.0
# [train] step: 1800, loss: 1.1590913534164429, accuracy: 100.0

# logits: [[-0.12943102 -0.18300387 -0.31407255 -0.26985747 -0.03317889  0.8657205  -0.06745078]
#          [-0.3108766  -0.11808072 -0.00877056 -0.12999986 -0.18233624 -0.03913596   0.86498916]]
# probs: [[0.11739378 0.11127014 0.09760144 0.10201374 0.12925485 0.31756592  0.1249001 ]
#         [0.09565745 0.11599759 0.12939627 0.1146232  0.10877853 0.12552617  0.31002077]]
