In [2]:
import random
import copy
from collections import namedtuple
from dataclasses import dataclass
import typing
import functools
from pprint import pprint

import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
import optax
import haiku as hk
from jax.tree_util import tree_flatten

from tqdm.notebook import tqdm
import pyspiel
import numpy as np
import trueskill

import moozi as mz

In [3]:
game = pyspiel.load_game('tic_tac_toe')
key = jax.random.PRNGKey(0)
dim_image = game.observation_tensor_size()
dim_repr = 4
dim_actions = game.num_distinct_actions()
all_actions = list(range(dim_actions))



In [4]:
def mse(pred, true):
    return jnp.mean(jnp.square(pred - true))

In [5]:
class Model(hk.Module):
    def repr_net(self, image):
        net = hk.nets.MLP(output_sizes=[16, 16, dim_repr], name='repr')
        return net(image)

    def pred_net(self, hidden_state):
        v_net = hk.nets.MLP(output_sizes=[16, 16, 1], name='pred_v')
        p_net = hk.nets.MLP(output_sizes=[16, 16, dim_actions], name='pred_p')
        return v_net(hidden_state), p_net(hidden_state)

    def dyna_net(self, hidden_state, action):
        state_action_repr = jnp.concatenate((hidden_state, action), axis=-1)
        transition_net = hk.nets.MLP(
            output_sizes=[16, 16, dim_repr], name='dyna_trans')
        reward_net = hk.nets.MLP(
            output_sizes=[16, 16, dim_repr], name='dyna_reward')
        return transition_net(state_action_repr), reward_net(state_action_repr)

    def initial_inference(self, image):
        hidden_state = self.repr_net(image)
        reward = 0
        value, policy_logits = self.pred_net(hidden_state)
        return NetworkOutput(
            value=value,
            reward=reward,
            policy_logits=policy_logits,
            hidden_state=hidden_state
        )

    def recurrent_inference(self, hidden_state, action):
        hidden_state, reward = self.dyna_net(hidden_state, action)
        value, policy_logits = self.pred_net(hidden_state)
        return NetworkOutput(
            value=value,
            reward=reward,
            policy_logits=policy_logits,
            hidden_state=hidden_state
        )

In [6]:
initial_inference = hk.without_apply_rng(hk.transform(lambda x: Model().initial_inference(x)))
recurrent_inference = hk.without_apply_rng(hk.transform(lambda x, y: Model().recurrent_inference(x, y)))

In [7]:
params = hk.data_structures.merge(
    initial_inference.init(key, jnp.ones(dim_image)),
    recurrent_inference.init(key, jnp.ones(dim_repr), jnp.ones(dim_actions))
)

NameError: name 'NetworkOutput' is not defined

In [None]:
buffer = ReplayBuffer(100)

In [None]:
game_state = game.new_initial_state()
image = np.array(game_state.observation_tensor(0))

In [None]:
state_repr = initial_inference.apply(params, image)