# Trainer NNX scratchpad

Generate training data by running:
```
python rgi/main.py --game connect4 --player1 random --player2 random --num_games 100 --save_trajectories
```

In [46]:
import os
from typing import Any
from pprint import pprint

from collections import Counter


import jax
import jax.numpy as jnp
from flax import nnx

# Show the devices JAX can see; this scratchpad refuses to continue unless the first one is a GPU.
available_devices = jax.devices()
print(available_devices)
assert available_devices[0].platform == 'gpu'

# Print wider lines so arrays wrap later (the numpy default is 75).
jnp.set_printoptions(linewidth=150)

[CudaDevice(id=0)]


# Load Trajectories

In [47]:
from rgi.core import trajectory

# Location of the saved trajectories, relative to the notebook's working directory.
game_name = "connect4"
trajectories_glob = os.path.join(os.pardir, "data", "trajectories", game_name, "*.trajectory.npy")
trajectories = trajectory.load_trajectories(trajectories_glob)

# Report what was matched and how many trajectories were loaded.
print(
    f'trajectories_glob: {trajectories_glob}',
    f'num_trajectories: {len(trajectories)}',
    sep='\n',
)

trajectories_glob: ../data/trajectories/connect4/*.trajectory.npy
num_trajectories: 1100


In [48]:
def print_trajectory_stats(trajectory_list=None, num_actions=7):
    """Print win/loss counts and win percentage grouped by the first action of each trajectory.

    Only games whose first final reward is exactly 1.0 (win) or -1.0 (loss) are counted;
    any other outcome is ignored.

    Args:
        trajectory_list: iterable of trajectories exposing `actions` and `final_rewards`.
            Defaults to the notebook-level `trajectories`.
        num_actions: actions are reported for 1..num_actions inclusive.
    """
    if trajectory_list is None:
        trajectory_list = trajectories

    outcome_counts = Counter(
        (t.actions[0].item(), t.final_rewards[0].item()) for t in trajectory_list
    )
    for action in range(1, num_actions + 1):
        win_count, loss_count = outcome_counts[(action, 1.0)], outcome_counts[(action, -1.0)]
        decided = win_count + loss_count
        # An action with no decided games would otherwise divide by zero.
        win_pct = f'{win_count / decided * 100:.2f}%' if decided else 'n/a'
        print(f'action: {action}, win: {win_count:3d}, loss: {loss_count:3d}, win_pct: {win_pct}')

# Summarise first-move outcomes for the trajectories loaded above.
print_trajectory_stats()

# num_trajectories: 100
# action: 1, win:  14, loss:   9, win_pct: 60.87%
# action: 2, win:   6, loss:   6, win_pct: 50.00%
# action: 3, win:   8, loss:   4, win_pct: 66.67%
# action: 4, win:  11, loss:   3, win_pct: 78.57%
# action: 5, win:  11, loss:   5, win_pct: 68.75%
# action: 6, win:   8, loss:   5, win_pct: 61.54%
# action: 7, win:   5, loss:   5, win_pct: 50.00%

action: 1, win: 102, loss:  86, win_pct: 54.26%
action: 2, win:  85, loss:  75, win_pct: 53.12%
action: 3, win:  84, loss:  64, win_pct: 56.76%
action: 4, win:  98, loss:  52, win_pct: 65.33%
action: 5, win:  89, loss:  49, win_pct: 64.49%
action: 6, win:  80, loss:  75, win_pct: 51.61%
action: 7, win:  83, loss:  78, win_pct: 51.55%


In [49]:
from rgi.core.trajectory import EncodedTrajectory

# define trajectory NamedTuple
from typing import NamedTuple
class TrajectoryStep(NamedTuple):
    """One transition of a recorded game, as produced by `unroll_trajectory`."""
    move_index: int  # index of this step within its source trajectory
    state: jax.Array  # encoded state before the action
    action: jax.Array  # action taken at `state`
    next_state: jax.Array  # encoded state immediately after the action
    reward: jax.Array  # rescaled final reward of the whole trajectory (same for every step)


# TODO: hack fixup to make reward (0,1) instead of (-1,1)
def fixup_reward(x):
    """Shift and halve a reward so that -1 maps to 0 and +1 maps to 1."""
    shifted = x + 1
    return shifted / 2

def unroll_trajectory(encoded_trajectories: list[EncodedTrajectory]):
    """Yield one `TrajectoryStep` per consecutive state pair of every trajectory.

    A trajectory contributes `length - 1` steps. Every step carries the same reward:
    the trajectory's `final_rewards[0]` passed through `fixup_reward`.
    NOTE(review): `final_rewards[0]` is presumably the first player's outcome, so the
    reward is not relative to the player making the move — confirm this is intended.
    """
    for encoded in encoded_trajectories:
        yield from (
            TrajectoryStep(
                move_index=step,
                state=encoded.states[step],
                action=encoded.actions[step],
                next_state=encoded.states[step + 1],
                reward=fixup_reward(encoded.final_rewards[0]),
            )
            for step in range(encoded.length - 1)
        )

# Flatten every loaded trajectory into a single list of per-move steps.
all_trajectory_steps = list(unroll_trajectory(trajectories))
# Alias keeps the original misspelled name working for cells that still reference it.
all_trajecoty_steps = all_trajectory_steps
print(f'num_trajectory_steps: {len(all_trajectory_steps)}')

num_trajectory_steps: 22835


In [None]:
# Stack the per-step fields into three parallel arrays (one row per trajectory step).
state_batch = jnp.array([step.state for step in all_trajecoty_steps])
action_batch = jnp.array([step.action for step in all_trajecoty_steps])
reward_batch = jnp.array([step.reward for step in all_trajecoty_steps])
batch_full = dict(state=state_batch, action=action_batch, reward=reward_batch)


def _head_of_batch(num_rows):
    """Return a batch holding only the first `num_rows` rows of `batch_full`."""
    return {field: values[:num_rows] for field, values in batch_full.items()}


# batch with 1 or 2 steps for testing
batch_1 = _head_of_batch(1)
batch_2 = _head_of_batch(2)

# Super easy to get 100% accuracy: the same two steps repeated 1000 times.
batch_repeated = {
    'state': jnp.tile(batch_2['state'], (1000, 1)),
    'action': jnp.tile(batch_2['action'], (1000,)),
    'reward': jnp.tile(batch_2['reward'], (1000,)),
}

# Define models

In [None]:
from rgi.games import connect4
# Game rules and state serializer, used later to build states by hand and encode them for the model.
game = connect4.Connect4Game()
serializer = connect4.Connect4Serializer()

In [None]:
class Connect4StateEmbedder(nnx.Module):
    """Embed a batch of encoded Connect 4 states with two 3x3 convolutions and two dense layers."""
    embedding_dim: int = 64
    hidden_dim: int = 256

    def __init__(self, *, rngs: nnx.Rngs):
        # Layers are created in this order so that `rngs` is consumed deterministically.
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        # linear1 expects 6*7*64 inputs, i.e. the convolutions are assumed to keep the 6x7 grid size.
        self.linear1 = nnx.Linear(6 * 7 * 64, self.hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(self.hidden_dim, self.embedding_dim, rngs=rngs)

    def _state_to_array(self, encoded_state_batch: jax.Array) -> jax.Array:
        """Drop the last entry of every encoded state and reshape the rest to (batch, 6, 7)."""
        # NOTE(review): the dropped trailing entry is presumably the player-to-move flag — confirm.
        return encoded_state_batch[:, :-1].reshape([-1, 6, 7])

    def __call__(self, encoded_state_batch: jax.Array):
        """Return an embedding of size `embedding_dim` for each state in the batch."""
        board = self._state_to_array(encoded_state_batch)
        features = board[..., None]  # Add channel dimension
        for conv in (self.conv1, self.conv2):
            features = nnx.relu(conv(features))
        flattened = features.reshape(features.shape[0], -1)  # Flatten while preserving batch dimension
        hidden = nnx.relu(self.linear1(flattened))
        return self.linear2(hidden)


In [None]:
def test_state_embedder():
    """Smoke test: embed the first two states of the first trajectory and print 5 features of each."""
    embedder = Connect4StateEmbedder(rngs=nnx.Rngs(0))
    embeddings = embedder(trajectories[0].states[:2])
    print(embeddings[:, :5])

# Run the state-embedder smoke test.
test_state_embedder()

In [None]:
class Connect4ActionEmbedder(nnx.Module):
    """Learned embedding table with one row per Connect 4 action (actions are 1-indexed)."""
    embedding_dim: int = 64
    num_actions: int = 7

    def __init__(self, *, rngs: nnx.Rngs):
        self.embedding = nnx.Embed(
            num_embeddings=self.num_actions,
            features=self.embedding_dim,
            rngs=rngs,
        )

    def __call__(self, action: jax.Array) -> jax.Array:
        """Look up embeddings for 1-indexed actions."""
        # The table is 0-indexed, so shift the action ids down by one.
        zero_based_action = action - 1
        return self.embedding(zero_based_action)

    def all_action_embeddings(self):
        """Return the raw embedding table (row k belongs to action k + 1)."""
        # Alternative kept from the original: self(jnp.arange(1, self.num_actions + 1))
        table = self.embedding.embedding
        return table.value

In [None]:
def test_action_embedder():
    """Smoke test: embed the first action of the first trajectory, then print 5 columns of the table."""
    embedder = Connect4ActionEmbedder(rngs=nnx.Rngs(0))
    first_action = trajectories[0].actions[:1]
    print(embedder(first_action))
    embedding_table = embedder.all_action_embeddings()
    print(embedding_table[:, :5])

# Run the action-embedder smoke test.
test_action_embedder()

In [None]:
class PredictionModel(nnx.Module):
    """Predict action probabilities and the reward for a batch of encoded states.

    Action logits are dot products between each state embedding and every action
    embedding; the reward is a linear head on top of the state embedding.
    """
    state_embedder: Connect4StateEmbedder
    action_embedder: Connect4ActionEmbedder

    embedding_dim: int = 64
    num_actions: int = 7

    def __init__(self, state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, *, rngs: nnx.Rngs):
        self.state_embedder = state_embedder
        self.action_embedder = action_embedder
        self.reward_head = nnx.Linear(self.embedding_dim, 1, rngs=rngs)

    def action_logits(self, state_batch: jax.Array) -> jax.Array:
        """Return one logit per action for each state: state embeddings times the action table."""
        action_table = self.action_embedder.all_action_embeddings()
        return self.state_embedder(state_batch) @ action_table.T

    def action_probs(self, state_batch: jax.Array) -> jax.Array:
        """Return the softmax of `action_logits`."""
        return jax.nn.softmax(self.action_logits(state_batch))

    def reward_pred(self, state_batch: jax.Array) -> jax.Array:
        """Return the reward head's output for each state (one value per row)."""
        return self.reward_head(self.state_embedder(state_batch))



In [None]:
def test_state_action_model():
    """Smoke test: print logits, probabilities and reward predictions for the first two states."""
    model = PredictionModel(
        Connect4StateEmbedder(rngs=nnx.Rngs(0)),
        Connect4ActionEmbedder(rngs=nnx.Rngs(1)),
        rngs=nnx.Rngs(2),
    )
    sample_states = state_batch[:2]
    print('action_logits:\n', model.action_logits(sample_states))
    print('\naction_probs:\n', model.action_probs(sample_states))
    print('\nreward_pred:\n', model.reward_pred(sample_states))

# Run the prediction-model smoke test.
test_state_action_model()


# Train Model

In [None]:
import optax

@nnx.jit
def l2_loss(model, alpha=0.001):
    """Return the sum over all `nnx.Param` leaves of `alpha * sum(w ** 2)`."""
    penalty = 0
    for weights in jax.tree.leaves(nnx.state(model, nnx.Param)):
        # Scale each leaf separately, matching the accumulation order of the original sum().
        penalty = penalty + alpha * (weights ** 2).sum()
    return penalty

@nnx.jit
def loss_fn(prediction_model: PredictionModel, batch, l2_weight: float = 1e-4):
    """Combined training loss for a batch with 'state', 'action' and 'reward' entries.

    The total is action cross-entropy + reward mean squared error + weighted L2 penalty.

    Returns:
        (total_loss, (action_logits, reward_pred)); the auxiliary pair is what the
        train/eval steps feed into the metrics.
    """
    states = batch['state']

    # Actions are stored 1-indexed; the classifier targets are 0-indexed.
    action_logits = prediction_model.action_logits(states)
    action_data_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=action_logits, labels=batch['action'] - 1
    ).mean()

    reward_pred = prediction_model.reward_pred(states)
    reward_error = batch['reward'] - reward_pred.squeeze()
    reward_data_loss = (reward_error ** 2).mean()

    # NOTE(review): l2_loss already scales by its own alpha (default 0.001), so the
    # effective L2 coefficient is l2_weight * alpha — confirm this is intended.
    parameter_loss_l2 = l2_weight * l2_loss(prediction_model)

    total_loss = action_data_loss + reward_data_loss + parameter_loss_l2
    return total_loss, (action_logits, reward_pred)



In [None]:
def test_loss_fn():
    """Smoke test: evaluate `loss_fn` on the two-step batch and print the loss and model outputs."""
    model = PredictionModel(
        Connect4StateEmbedder(rngs=nnx.Rngs(0)),
        Connect4ActionEmbedder(rngs=nnx.Rngs(1)),
        rngs=nnx.Rngs(2),
    )
    loss, aux = loss_fn(model, batch_2)
    action_logits, reward_pred = aux
    print('loss: ', loss)
    print('action_logits: ', action_logits)
    print('reward_pred: ', reward_pred)

# Run the loss smoke test with JIT compilation disabled.
with jax.disable_jit():
    test_loss_fn()


# loss:  1.927142
# logits:  [[ 0.          0.          0.          0.          0.          0.          0.        ]
#  [-0.05365274  0.04191583 -0.05556455 -0.0902497  -0.04111161  0.00694049  0.01304064]]

In [None]:
@nnx.jit
def train_step(prediction_model: PredictionModel, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Train for a single step.

    Computes the loss and gradients for `batch`, records loss / action / reward
    values in `metrics`, and applies the gradients through `optimizer`.
    Both `metrics` and `optimizer` are updated in place; nothing is returned.
    """
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits, reward_pred)), grads = grad_fn(prediction_model, batch)
    metrics.update(
        loss=loss,
        action_logits=logits,
        # Stored actions are 1-indexed; logits are indexed from 0.
        action_labels=batch['action'] - 1,
        # Drop only the trailing size-1 axis so a batch of one keeps shape (1,)
        # and still matches reward_labels in shape-checked metrics.
        reward_pred=reward_pred.squeeze(-1),
        reward_labels=batch['reward'],
    )  # In-place updates.
    optimizer.update(grads)  # In-place updates.

@nnx.jit
def eval_step(prediction_model: PredictionModel, metrics: nnx.MultiMetric, batch):
    """Evaluate a single batch and record the results in `metrics` (updated in place).

    The same keyword names as `train_step` are passed so one `MultiMetric` can serve
    both steps. The generic `logits` / `labels` names are passed as well for metrics
    configured with those argument names.
    """
    loss, (logits, reward_pred) = loss_fn(prediction_model, batch)
    # Stored actions are 1-indexed; logits are indexed from 0 (same shift as loss_fn).
    action_labels = batch['action'] - 1
    metrics.update(
        loss=loss,
        logits=logits,
        labels=action_labels,
        action_logits=logits,
        action_labels=action_labels,
        # Drop only the trailing size-1 axis so a batch of one keeps shape (1,).
        reward_pred=reward_pred.squeeze(-1),
        reward_labels=batch['reward'],
    )  # In-place updates.


In [None]:
from flax import nnx
import jax.numpy as jnp
from abc import ABC, abstractmethod
from typing_extensions import override

class TwoValueAverageMetric(nnx.Metric, ABC):
    """Base class for metrics that average a statistic computed from two named keyword arguments.

    `update(**kwargs)` picks the two values named at construction time and passes them to
    `update_pair`, which subclasses implement by accumulating into `total` and `count`.
    """
    total: nnx.metrics.MetricState
    count: nnx.metrics.MetricState

    def __init__(self, argname_1: str, argname_2: str):
        """Remember which keyword arguments to read in `update` and zero the accumulators."""
        self.argname_1 = argname_1
        self.argname_2 = argname_2
        self.total = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.float32))
        self.count = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))

    def reset(self) -> None:
        """Zero the running total and count."""
        self.total.value = jnp.array(0, dtype=jnp.float32)
        self.count.value = jnp.array(0, dtype=jnp.int32)

    def update(self, **kwargs) -> None:
        """Extract the two configured keyword arguments and pass them to `update_pair`.

        Other keyword arguments are ignored.

        Raises:
            TypeError: if either configured keyword argument is missing.
        """
        for argname in (self.argname_1, self.argname_2):
            if argname not in kwargs:
                raise TypeError(f"Expected keyword argument '{argname}'")
        self.update_pair(kwargs[self.argname_1], kwargs[self.argname_2])

    @abstractmethod
    def update_pair(self, v1: jax.Array, v2: jax.Array) -> None:
        """Accumulate statistics for one pair of values."""
        raise NotImplementedError("Must be implemented by subclass")

    @abstractmethod
    def compute(self) -> jax.Array:
        """Return the metric value from the accumulated statistics."""
        raise NotImplementedError("Must be implemented by subclass")

class MeanSquaredError(TwoValueAverageMetric):
    """Running mean of squared differences between two equally shaped arrays."""

    @override
    def __init__(self, argname_1: str = 'mse_1', argname_2: str = 'mse_2'):
        super().__init__(argname_1, argname_2)

    @override
    def update_pair(self, v1: jax.Array, v2: jax.Array) -> None:
        """Add the squared differences and element count of one pair.

        Raises:
            ValueError: if the two arrays have different shapes.
        """
        if v1.shape != v2.shape:
            raise ValueError(f"Expected shapes {v1.shape} and {v2.shape} to be equal")
        squared_error = (v1 - v2) ** 2
        self.total.value = self.total.value + squared_error.sum()
        self.count.value = self.count.value + v1.size

    def compute(self) -> jax.Array:
        """Compute and return the average."""
        accumulated, num_elements = self.total.value, self.count.value
        return accumulated / num_elements

class LogitAccuracy(TwoValueAverageMetric):
    """Running fraction of rows whose argmax logit equals the integer label."""

    @override
    def __init__(self, argname_1: str = 'logits', argname_2: str = 'labels'):
        super().__init__(argname_1, argname_2)

    @override
    def update_pair(self, logits: jax.Array, labels: jax.Array) -> None:
        """Add the number of correct predictions and the number of labels in one pair.

        Raises:
            ValueError: if `logits` does not have exactly one more dimension than `labels`.
        """
        if logits.ndim != labels.ndim + 1:
            raise ValueError(f'Expected logits.ndim==labels.ndim+1, got {logits.ndim} and {labels.ndim}')

        predicted = logits.argmax(axis=-1)
        num_correct = (predicted == labels).sum()
        self.total.value = self.total.value + num_correct
        self.count.value = self.count.value + labels.size

    def compute(self) -> jax.Array:
        """Compute and return the average."""
        accumulated, num_labels = self.total.value, self.count.value
        return accumulated / num_labels

In [None]:
def test_mean_squared_error():
    """Smoke test: squared errors 0, 0 and 4 over three elements should print 4/3."""
    predictions = jnp.array([1, 2, 5])
    targets = jnp.array([1, 2, 3])
    metric = MeanSquaredError('reward_pred', 'reward_labels')
    metric.update(reward_pred=predictions, reward_labels=targets)
    print(metric.compute())
    metric.reset()

# Run the MeanSquaredError smoke test.
test_mean_squared_error()

In [None]:
def test_multi_metric():
    """Smoke test: feed one update to a MultiMetric mixing built-in and custom metrics."""
    sample_logits = jnp.array([[1, 2, 3], [4, 5, 6], [3, 1, 1]])
    sample_labels = jnp.array([1, 2, 0])

    # metrics = nnx.MultiMetric(action=nnx.metrics.Accuracy(), loss=nnx.metrics.Average('loss'))
    metrics = nnx.MultiMetric(
        accuracy=nnx.metrics.Accuracy(),
        action_accuracy=LogitAccuracy("action_logits", "action_labels"),
        reward_mse=MeanSquaredError('reward_pred', 'reward_labels'),
        loss=nnx.metrics.Average('loss'),
    )
    # Every metric receives all keyword arguments and picks out the ones it was configured for.
    metrics.update(
        loss=1.0,
        logits=sample_logits,
        labels=sample_labels,
        action_logits=sample_logits,
        action_labels=sample_labels,
        reward_pred=jnp.array([1, 2, 5]),
        reward_labels=jnp.array([1, 2, 3]),
    )
    pprint(metrics.compute())

# Run the MultiMetric smoke test.
test_multi_metric()

In [None]:
def test_train_step(print_logits=False, num_steps=400, log_every=200):
    """Train a fresh PredictionModel on `batch_full` and return it.

    Metrics are never reset, so each printed value is a running average over all
    steps taken so far.

    Args:
        print_logits: when True, also print predictions for the first few rows of
            `batch_full`, next to the empirical action/reward counts of every row
            with an identical state.
        num_steps: number of full-batch training steps.
        log_every: metrics are printed on step 1, every `log_every` steps, and on
            the final step.

    Returns:
        The trained PredictionModel.
    """
    num_probe_states = 4

    state_embedder = Connect4StateEmbedder(rngs=nnx.Rngs(10))
    action_embedder = Connect4ActionEmbedder(rngs=nnx.Rngs(1))
    prediction_model = PredictionModel(state_embedder, action_embedder, rngs=nnx.Rngs(2))
    optimizer = nnx.Optimizer(prediction_model, optax.adamw(learning_rate=0.0005))
    metrics = nnx.MultiMetric(
        action_accuracy=LogitAccuracy('action_logits', 'action_labels'),
        reward_mse=MeanSquaredError('reward_pred', 'reward_labels'),
        loss=nnx.metrics.Average('loss'))

    state_action_counts = {}
    state_reward_counts = {}
    if print_logits:
        for probe in range(num_probe_states):
            # One vectorised comparison replaces a Python loop over every row of the batch.
            same_state = (batch_full['state'] == batch_full['state'][probe]).all(axis=1)
            state_action_counts[probe] = Counter(batch_full['action'][same_state].tolist())
            state_reward_counts[probe] = Counter(batch_full['reward'][same_state].tolist())

    for step in range(1, num_steps + 1):
        train_step(prediction_model, optimizer, metrics, batch_full)
        if step == 1 or step % log_every == 0 or step == num_steps:
            print(f'train_step: {step}: {metrics.compute()}')
            if print_logits:
                # A separate index name keeps the outer step counter from being shadowed.
                for probe in range(num_probe_states):
                    state = batch_full['state'][probe]
                    action_probs = prediction_model.action_probs(jnp.array([state]))
                    reward_pred = prediction_model.reward_pred(jnp.array([state])).item()
                    reward_true = batch_full['reward'][probe].item()
                    print(probe, f'r={reward_true} p={reward_pred}', jnp.array_str(action_probs * 100, precision=4, max_line_width=100, suppress_small=True), state_action_counts[probe], state_reward_counts[probe])
                print()
    return prediction_model

# Recorded wall-clock timings for a few configurations; only the last line is active.
# test_train_step(print_logits=True, num_steps=2000)   # 4m31s
# test_train_step(print_logits=False, num_steps=200)     # 23s
# The returned model is kept at notebook level because the cells below query it.
prediction_model = test_train_step(print_logits=False, num_steps=1000)     # 1m52s



In [None]:
# prediction_model.reward_pred(jnp.array([batch_full['state'][0]]))

# prediction_model.reward_pred(batch_full['state'][:1000])
# Pair every true reward with the model's prediction. Both arrays are converted to
# Python floats in bulk instead of calling .item() once per row.
reward_labels_list = batch_full['reward'].tolist()
reward_preds_list = prediction_model.reward_pred(batch_full['state']).squeeze(-1).tolist()
reward_zip = list(zip(reward_labels_list, reward_preds_list))
reward_zip[:10]

In [None]:
# Distribution and sum of the true rewards, sum of the predicted rewards, and the number of pairs.
print(Counter(label for label, _ in reward_zip))
print(sum(label for label, _ in reward_zip))
print(sum(pred for _, pred in reward_zip))
print(len(reward_zip))

s_0 = game.initial_state()
s_1 = game.next_state(s_0, 1)

# The initial state followed by the state reached by each possible first move.
all_moves = [s_0]
all_moves.extend(game.next_state(s_0, a) for a in game.all_actions())
all_moves_jax = [serializer.state_to_jax_array(game, s) for s in all_moves]
# serializer.state_to_jax_array(game, s_0)

prediction_model.reward_pred(jnp.array(all_moves_jax))

In [None]:
def _print_reward_tree(state, depth, actions_so_far=()):
    """Print the predicted reward of every state reachable from `state` in exactly `depth` moves.

    Each line shows the actions taken, the encoded state and the model's reward prediction.
    Actions are tried in the order given by `game.all_actions()`.

    Returns:
        The encoded array of the last state printed, or None if nothing was printed.
    """
    if depth == 0:
        encoded = serializer.state_to_jax_array(game, state)
        print(*actions_so_far, encoded, prediction_model.reward_pred(jnp.array([encoded])))
        return encoded
    last_encoded = None
    for action in game.all_actions():
        last_encoded = _print_reward_tree(
            game.next_state(state, action), depth - 1, actions_so_far + (action,)
        )
    return last_encoded


s_0 = game.initial_state()

print('Move 0:')
j_0 = _print_reward_tree(s_0, 0)

print('\nMove 1:')
j_1 = _print_reward_tree(s_0, 1)

print('\nMove 2:')
j_2 = _print_reward_tree(s_0, 2)

print('\nMove 3:')
# A later comparison cell reads `j_3` (the last depth-3 state), so it is bound explicitly here.
j_3 = _print_reward_tree(s_0, 3)



In [None]:
## 100 trajectories.
# train_step: 1: {'action_accuracy': Array(0.14363885, dtype=float32), 'reward_mse': Array(1.2103255, dtype=float32), 'loss': Array(3.1888344, dtype=float32)}
# 0 r=1.0 p=0.009542055428028107 [[14.301  14.3019 14.2707 14.2778 14.2718 14.2823 14.2946]] Counter({1: 23, 5: 16, 4: 14, 6: 13, 3: 12, 2: 12, 7: 10})
# 1 r=1.0 p=0.239242821931839 [[15.7015 14.4646 14.3248 13.5308 13.8753 13.7428 14.3602]] Counter({7: 3, 6: 3, 3: 3, 5: 2, 1: 1, 2: 1})
# 2 r=1.0 p=0.38645169138908386 [[16.3885 15.3902 13.6702 13.2374 14.175  12.9441 14.1946]] Counter({4: 1, 3: 1, 1: 1})
# 3 r=1.0 p=0.5279813408851624 [[16.5539 16.2118 13.3792 13.607  13.9671 12.0987 14.1822]] Counter({5: 1})

# train_step: 200: {'action_accuracy': Array(0.39684907, dtype=float32), 'reward_mse': Array(0.3694912, dtype=float32), 'loss': Array(1.9460464, dtype=float32)}
# 0 r=1.0 p=0.2352880835533142 [[23.4505 11.7681 12.0107 15.0266 15.1127 12.3827 10.2487]] Counter({1: 23, 5: 16, 4: 14, 6: 13, 3: 12, 2: 12, 7: 10})
# 1 r=1.0 p=0.1304871290922165 [[ 8.4951  9.1253 24.1409  2.3545 11.7189 24.5514 19.6139]] Counter({7: 3, 6: 3, 3: 3, 5: 2, 1: 1, 2: 1})
# 2 r=1.0 p=-0.51841139793396 [[12.4069  5.542  19.5634 17.891  13.2847 15.4737 15.8383]] Counter({4: 1, 3: 1, 1: 1})
# 3 r=1.0 p=0.5774312019348145 [[21.8451  4.6146  3.4356 13.3274 22.2893  7.1891 27.2988]] Counter({5: 1})

# train_step: 400: {'action_accuracy': Array(0.6348518, dtype=float32), 'reward_mse': Array(0.2586214, dtype=float32), 'loss': Array(1.2657048, dtype=float32)}
# 0 r=1.0 p=0.24247753620147705 [[22.9135 12.1181 12.1923 13.9984 15.749  13.0248 10.0039]] Counter({1: 23, 5: 16, 4: 14, 6: 13, 3: 12, 2: 12, 7: 10})
# 1 r=1.0 p=0.17915339767932892 [[ 7.1621  7.4489 23.7718  0.0898 14.4461 23.1061 23.9753]] Counter({7: 3, 6: 3, 3: 3, 5: 2, 1: 1, 2: 1})
# 2 r=1.0 p=-0.3926801085472107 [[32.4896  0.9777 29.2641 33.8393  0.1045  1.6182  1.7066]] Counter({4: 1, 3: 1, 1: 1})
# 3 r=1.0 p=0.8595808744430542 [[17.7365  0.0774  0.0108  3.722  75.0556  1.3197  2.0781]] Counter({5: 1})


In [None]:
# # TODO: How do we stream a bunch of different batches?
# # train_batches = [(i, batch) for i in range(2000)]
# # test_batches = [(i, batch) for i in range(100)]

# #train_batches = [(i, batch_repeated) for i in range(2000)]
# #test_batches = [(i, batch_repeated) for i in range(100)]

# # NUM_TRAIN_BATCHES = 2000
# NUM_TRAIN_BATCHES = 1
# train_batches = [(i, batch_full) for i in range(NUM_TRAIN_BATCHES)]
# test_batches = [(i, batch_full) for i in range(100)]


# metrics_history = {
#   'train_loss': [],
#   'train_accuracy': [],
#   'test_loss': [],
#   'test_accuracy': [],
# }

# train_steps = 10000
# eval_every = 200
# batch_size = 32

# # for step, batch in enumerate(train_ds.as_numpy_iterator()):
# for step, batch in train_batches:
#   # with jax.disable_jit():
#   #   train_step(state_embedder, action_embedder, optimizer, metrics, batch)
#   train_step(state_embedder, action_embedder, optimizer, metrics, batch)

#   if step > 0 and (step % eval_every == 0 or step == train_steps - 1):  # One training epoch has passed.
#     # Log the training metrics.
#     for metric, value in metrics.compute().items():  # Compute the metrics.
#       metrics_history[f'train_{metric}'].append(value)  # Record the metrics.
#     metrics.reset()  # Reset the metrics for the test set.

#     # Compute the metrics on the test set after each training epoch.
#     for _, test_batch in test_batches:
#       eval_step(state_embedder, action_embedder, metrics, test_batch)

#     # Log the test metrics.
#     for metric, value in metrics.compute().items():
#       metrics_history[f'test_{metric}'].append(value)
#     metrics.reset()  # Reset the metrics for the next training epoch.

#     print(
#       f"[train] step: {step}, "
#       f"loss: {metrics_history['train_loss'][-1]}, "
#       f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
#     )
#     # print(
#     #   f"[test] step: {step}, "
#     #   f"loss: {metrics_history['test_loss'][-1]}, "
#     #   f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
#     # )


In [None]:
from flax import nnx
import orbax.checkpoint as ocp

# NOTE(review): absolute, environment-specific path — the notebook only runs as-is where
# /workspaces/rgi exists; consider deriving it from a configurable data directory.
checkpoint_base_dir = '/workspaces/rgi/data/models/connect4'
os.makedirs(checkpoint_base_dir, exist_ok=True)
checkpoint_dir = os.path.join(checkpoint_base_dir, 'foo')
# NOTE(review): judging by its name this helper deletes any existing contents of
# `checkpoint_dir` before recreating it — confirm, since re-running this cell would
# then discard earlier checkpoints.
ckpt_dir = ocp.test_utils.erase_and_create_empty(checkpoint_dir)

In [None]:
# Only the state half of the split (the second element) is saved; the graph definition is dropped.
model_state = nnx.split(prediction_model)[1]
model_state_path = ckpt_dir / 'model_state'
with ocp.StandardCheckpointer() as checkpointer:
    checkpointer.save(model_state_path, model_state)

In [None]:
from flax import nnx
import orbax.checkpoint as ocp


def _build_prediction_model():
    """Construct a PredictionModel with the same architecture as the one that was saved."""
    return PredictionModel(
        Connect4StateEmbedder(rngs=nnx.Rngs(0)),
        Connect4ActionEmbedder(rngs=nnx.Rngs(1)),
        rngs=nnx.Rngs(2),
    )


# Build the model through nnx.eval_shape, then split it into graph definition and state
# so the state can serve as the template for restoring the checkpoint.
abstract_model = nnx.eval_shape(_build_prediction_model)
graphdef, abstract_state = nnx.split(abstract_model)

with ocp.StandardCheckpointer() as checkpointer:
    state_restored = checkpointer.restore(ckpt_dir / 'model_state', abstract_state)
model_restored = nnx.merge(graphdef, state_restored)

In [None]:
# Print the trained and the restored model side by side; matching lines mean the
# checkpoint round-trip preserved the parameters.
print('Move 0:')
s_0 = game.initial_state()
j_0 = serializer.state_to_jax_array(game, s_0)
for model in (prediction_model, model_restored):
    print(j_0, model.reward_pred(jnp.array([j_0])), model.action_probs(jnp.array([j_0])))

print('Move 3:')
# `j_3` is the last depth-3 state encoded by the earlier move-exploration cell.
for model in (prediction_model, model_restored):
    print(j_3, model.reward_pred(jnp.array([j_3])), model.action_probs(jnp.array([j_3])))

# Move 0:
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1] [[0.5674922]] [[0.17131999 0.14294909 0.12640913 0.1492165  0.13335824 0.13016856 0.14657846]]
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1] [[0.5674922]] [[0.17131999 0.14294909 0.12640913 0.1492165  0.13335824 0.13016856 0.14657846]]
# Move 3:
# [0 0 0 0 0 0 1 0 0 0 0 0 0 2 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2] [[0.62254393]] [[3.2540622e-01 5.2950774e-05 3.6740914e-02 3.4891254e-01 1.9822055e-02 4.7829631e-03 2.6428244e-01]]
# [0 0 0 0 0 0 1 0 0 0 0 0 0 2 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2] [[0.62254393]] [[3.2540622e-01 5.2950774e-05 3.6740914e-02 3.4891254e-01 1.9822055e-02 4.7829631e-03 2.6428244e-01]]

In [None]:
# model_restored.__class__ == prediction_model.__class__
# Compare the two models' state leaf by leaf. `model_restored == prediction_model` compares
# the module objects themselves (NOTE(review): assuming nnx modules define no value-based
# __eq__), which says nothing about whether the restored parameters match.
restored_leaves = jax.tree.leaves(nnx.state(model_restored))
original_leaves = jax.tree.leaves(nnx.state(prediction_model))
len(restored_leaves) == len(original_leaves) and all(
    bool(jnp.array_equal(restored, original))
    for restored, original in zip(restored_leaves, original_leaves)
)