In [1]:
import random
import copy
from collections import namedtuple
from dataclasses import dataclass
import datetime
import typing
import functools
from pprint import pprint

import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
import optax
import haiku as hk
from jax.tree_util import tree_flatten

import pyspiel
import open_spiel
import dm_env
import acme
import acme.wrappers
import acme.jax.utils
import acme.jax.variable_utils
from acme.agents import agent as acme_agent
from acme.agents import replay as acme_replay
from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop
from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper

from tqdm.notebook import tqdm
import numpy as np
import trueskill

import moozi as mz

In [2]:
# Single seed for the whole run. Every JAX PRNG key used below is split from
# `key`; the global Python and NumPy generators are seeded too, so any library
# code that draws from them is reproducible as well.
seed = 0
random.seed(seed)
np.random.seed(seed)
key = jax.random.PRNGKey(seed)

In [3]:
# Build the OpenSpiel "catch" environment, wrap it for acme with
# single-precision (float32) values, and derive the sizes the network needs.
game_string = 'catch(columns=7,rows=5)'
raw_env = open_spiel.python.rl_environment.Environment(game_string)
env = OpenSpielWrapper(raw_env)
env = acme.wrappers.SinglePrecisionWrapper(env)
env_spec = acme.specs.make_environment_spec(env)
# Ask the raw environment for its game directly instead of unwrapping
# `env.environment.environment`, which silently breaks whenever wrappers are
# added, removed or reordered.
max_game_length = raw_env.game.max_game_length()
dim_action = env_spec.actions.num_values
dim_image = env_spec.observations.observation.shape
dim_repr = 3  # presumably the size of the learned hidden representation (hyperparameter)
print(env_spec)

EnvironmentSpec(observations=OLT(observation=Array(shape=(35,), dtype=dtype('float32'), name=None), legal_actions=Array(shape=(3,), dtype=dtype('float32'), name=None), terminal=Array(shape=(1,), dtype=dtype('float32'), name=None)), actions=DiscreteArray(shape=(), dtype=int32, name=None, minimum=0, maximum=2, num_values=3), rewards=BoundedArray(shape=(), dtype=dtype('float32'), name=None, minimum=-1.0, maximum=1.0), discounts=BoundedArray(shape=(), dtype=dtype('float32'), name=None, minimum=0.0, maximum=1.0))


In [4]:
# Network sizes come straight from the environment spec computed above,
# except `dim_repr`, which is a free choice.
learning_rate = 1e-4
nn_spec = mz.nn.NeuralNetworkSpec(dim_image=dim_image, dim_repr=dim_repr, dim_action=dim_action)
print(nn_spec)
network = mz.nn.get_network(nn_spec)
optimizer = optax.adam(learning_rate)

NeuralNetworkSpec(dim_image=(35,), dim_repr=3, dim_action=3)


In [5]:
# Replay shared by both halves of the agent: the actor writes through
# `reverb_replay.adder`, the learner reads through `reverb_replay.data_iterator`.
n_steps = 5
batch_size = 16
reverb_replay = acme_replay.make_reverb_prioritized_nstep_replay(
    env_spec,
    batch_size=batch_size,
    n_step=n_steps,
)

In [6]:
# The learner's PRNG key is split from the notebook-wide `key` (seeded in the
# configuration cell) rather than hard-coded, so changing `seed` changes *all*
# randomness of the run, consistent with how the actor gets its key.
key, learner_key = jax.random.split(key)
learner = mz.learner.MooZiLearner(
    network=network,
    loss_fn=mz.loss.initial_inference_value_loss,
    optimizer=optimizer,
    data_iterator=reverb_replay.data_iterator,
    random_key=learner_key,
)

In [7]:
# The second argument of acme's JAX `VariableClient` is the *name* of the
# variables it requests via `learner.get_variables(...)`. It is not a PRNG key,
# so no key is split here. '' follows acme's own JAX agents.
# NOTE(review): assumes MooZiLearner.get_variables ignores the requested names
# and returns [params], as acme's JAX learners do. Confirm.
variable_client = acme.jax.variable_utils.VariableClient(learner, '')

In [8]:
# Acting policy: samples from the network's prior policy, with exploration
# controlled by `epsilon` and `temperature`. Experience goes to the replay adder.
epsilon = 0.1
temperature = 1
key, actor_key = jax.random.split(key)
actor = mz.actor.PriorPolicyActor(
    network=network,
    adder=reverb_replay.adder,
    variable_client=variable_client,
    random_key=actor_key,
    epsilon=epsilon,
    temperature=temperature,
)

In [9]:
# Combine actor and learner into a single acme agent. Learner updates are
# gated by `min_observations` and paced by `observations_per_step`
# (see acme's `Agent`).
agent = acme_agent.Agent(
    actor=actor, learner=learner, min_observations=100, observations_per_step=1,
)

In [10]:
# Smoke test: play a single episode with the agent as the only player.
# FIXME: this currently raises a jax `vmap` ValueError whose rank tree contains
# the network params and an OLT observation. The failing vmap lives in project
# code (moozi) that is not visible here. Investigate before the full run.
loop = OpenSpielEnvironmentLoop(
    environment=env,
    actors=[agent],
)
loop.run_episode()

ValueError: vmap got arg 0 of rank 1 but axis to be mapped 1. The tree of ranks is:
((FlatMapping({
  '__neural_network_haiku/~dyna_net/dyna_reward/~/linear_0': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~dyna_net/dyna_reward/~/linear_1': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~dyna_net/dyna_reward/~/linear_2': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~dyna_net/dyna_trans/~/linear_0': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~dyna_net/dyna_trans/~/linear_1': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~dyna_net/dyna_trans/~/linear_2': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~pred_net/pred_p/~/linear_0': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~pred_net/pred_p/~/linear_1': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~pred_net/pred_p/~/linear_2': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~pred_net/pred_v/~/linear_0': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~pred_net/pred_v/~/linear_1': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~pred_net/pred_v/~/linear_2': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~repr_net/repr/~/linear_0': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~repr_net/repr/~/linear_1': FlatMapping({'b': 1, 'w': 2}),
  '__neural_network_haiku/~repr_net/repr/~/linear_2': FlatMapping({'b': 1, 'w': 2}),
}), OLT(observation=1, legal_actions=1, terminal=1), 1), {})

In [None]:
# Full training run; keep whatever the loop returns for inspection.
num_episodes = 1000
result = loop.run(
    num_episodes=num_episodes,
)

In [None]:
# Peek at the observations (OLT) of one sampled replay batch.
# NOTE: this advances the same iterator the learner trains from.
sampled_batch = next(reverb_replay.data_iterator)
olt = sampled_batch.data.observation