In [None]:
import jax
import jax.numpy as jnp

from reinforce.neural.network import (
    StochasticMuZeroNetwork,
    create_network,
    representation_forward,
    prediction_forward,
    afterstate_dynamics_forward,
    afterstate_prediction_forward,
    dynamics_forward,
    encoder_forward,
    count_parameters,
)

from twentyfortyeight.envs import TwentyFortyEight

In [None]:
# ##>: Build the Stochastic MuZero network (all 6 components) for the 2048 board.
# ##>: Constructor hyper-parameters live in one dict so they can be tweaked in a single place.
key = jax.random.PRNGKey(42)
network_kwargs = dict(
    observation_shape=(16,),  # ##>: 4x4 board, flattened.
    hidden_size=256,
    num_blocks=10,
    num_actions=4,
    codebook_size=32,
)
network = create_network(key=key, **network_kwargs)

print(f'Network config: {network.config}')
print(f'Total parameters: {count_parameters(network):,}')

In [None]:
# ##>: Initialize game environment (not encoded - we use raw 4x4 board).
# ##>: Start a fresh episode. The reset return value is not needed because the next cell reads
# ##>: `game.observation`. The trailing `;` suppresses the cell's display. It is used instead of
# ##>: `_ = ...`, because binding `_` stops IPython from updating its `_`/`__`/`___` output history.
game = TwentyFortyEight(encoded=False)
game.reset();

In [None]:
# ##>: Get observation and convert to JAX array with batch dimension.
# ##>: Flatten the 4x4 board and prepend a batch axis so it matches observation_shape=(16,).
obs = jnp.array(game.observation.flatten())[None, :]
# ##>: Fail here, with a readable message, rather than deep inside the network on a shape mismatch.
assert obs.shape == (1, 16), f'expected observation of shape (1, 16), got {obs.shape}'
print(f'Observation shape: {obs.shape}')

In [None]:
# ##>: Representation (h): map the batched observation to a hidden state.
hidden_state = representation_forward(network, obs)
hidden_preview = hidden_state[0, :5]  # ##>: First five features of the first batch element.
print(f'Hidden state shape: {hidden_state.shape}')
print(f'Hidden state sample: {hidden_preview}')

In [None]:
# ##>: Prediction (f): map the hidden state to (policy_logits, value).
policy_logits, value = prediction_forward(network, hidden_state)
policy_probs = jax.nn.softmax(policy_logits)  # ##>: Normalized over the last axis (softmax default).

print(f'Policy logits shape: {policy_logits.shape}')
print(f'Policy logits: {policy_logits}')
print(f'Policy probs: {policy_probs}')
print(f'Value: {value}')

In [None]:
# ##>: Test afterstate dynamics (φ): (state, action) -> afterstate.
# ##>: The one-hot width is read from the policy head's output (prediction cell), so it follows
# ##>: the `num_actions` given to `create_network` instead of repeating a literal 4 here.
action_index = 0  # ##>: Action 0 (left).
action = jax.nn.one_hot(action_index, policy_logits.shape[-1])[None, :]
afterstate = afterstate_dynamics_forward(network, hidden_state, action)
print(f'Afterstate shape: {afterstate.shape}')

In [None]:
# ##>: Test afterstate prediction (ψ): afterstate -> (Q-value, chance_logits).
q_value, chance_logits = afterstate_prediction_forward(network, afterstate)
chance_probs = jax.nn.softmax(chance_logits)
# ##>: Report the most likely chance codes (descending probability), not simply codes 0..4.
# ##>: `min` keeps k valid for codebooks smaller than 5, which `jax.lax.top_k` would reject.
num_top = min(5, chance_probs.shape[-1])
top_chance_probs, top_chance_codes = jax.lax.top_k(chance_probs[0], num_top)
print(f'Q-value: {q_value}')
print(f'Chance logits shape: {chance_logits.shape}')
print(f'Chance probs (top {num_top}): {top_chance_probs}')
print(f'Top chance codes: {top_chance_codes}')

In [None]:
# ##>: Test dynamics (g): (afterstate, chance_code) -> (next_state, reward).
# ##>: The one-hot width is read from the chance head's output (afterstate prediction cell), so it
# ##>: follows the `codebook_size` given to `create_network` instead of repeating a literal 32 here.
chance_code_index = 0  # ##>: Chance code 0.
chance_code = jax.nn.one_hot(chance_code_index, chance_logits.shape[-1])[None, :]
next_state, reward = dynamics_forward(network, afterstate, chance_code)
print(f'Next state shape: {next_state.shape}')
print(f'Reward: {reward}')

In [None]:
# ##>: Encoder (e): map the observation to a discrete chance code.
# ##>: Straight-through estimation is used, so the output should look (close to) one-hot.
encoded_chance = encoder_forward(network, obs)
encoded_argmax = jnp.argmax(encoded_chance, axis=-1)  # ##>: Selected code per batch element.

print(f'Encoded chance shape: {encoded_chance.shape}')
print(f'Encoded chance (should be one-hot-ish): {encoded_chance[0, :10]}')
print(f'Argmax of encoded chance: {encoded_argmax}')