In [1]:
import copy
import datetime
import functools
import random
import typing
from collections import namedtuple
from dataclasses import dataclass
from pprint import pprint

import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
import optax
import haiku as hk
from jax.tree_util import tree_flatten

import pyspiel
import open_spiel
from open_spiel.python import rl_environment
import dm_env
import acme
import acme.wrappers
import acme.jax.utils
import acme.jax.variable_utils
from acme.agents import agent as acme_agent
from acme.agents import replay as acme_replay
from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop
from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper

from tqdm.notebook import tqdm
import numpy as np
import trueskill

import moozi as mz

In [2]:
# Root of the notebook's JAX randomness: later cells split `key` to obtain
# fresh sub-keys instead of constructing new keys from scratch.
seed: int = 0
key = jax.random.PRNGKey(seed=seed)

In [3]:
# Build the environment: OpenSpiel's "catch" game on a 7-column x 5-row board.
# `rl_environment` comes from an explicit import in the imports cell; reaching
# through `open_spiel.python.rl_environment` only resolves when some other
# import has already loaded that submodule as a side effect.
raw_env = rl_environment.Environment('catch(columns=7,rows=5)')

# Adapt the raw OpenSpiel environment for Acme, then wrap it in Acme's
# SinglePrecisionWrapper.
env = OpenSpielWrapper(raw_env)
env = acme.wrappers.SinglePrecisionWrapper(env)
env_spec = acme.specs.make_environment_spec(env)

# Unwrap both layers (SinglePrecisionWrapper -> OpenSpielWrapper -> raw env)
# to reach the underlying game object and ask it for its maximum length.
max_game_length = env.environment.environment.game.max_game_length()

# Sizes derived from the spec: the number of discrete actions and the length
# of the first (only) axis of the observation array.
dim_action = env_spec.actions.num_values
dim_image = env_spec.observations.observation.shape[0]

# Passed to the network spec as `dim_repr` — presumably the width of the
# learned hidden representation; TODO confirm against moozi.
dim_repr = 3

print(env_spec)

EnvironmentSpec(observations=OLT(observation=Array(shape=(35,), dtype=dtype('float32'), name=None), legal_actions=Array(shape=(3,), dtype=dtype('float32'), name=None), terminal=Array(shape=(1,), dtype=dtype('float32'), name=None)), actions=DiscreteArray(shape=(), dtype=int32, name=None, minimum=0, maximum=2, num_values=3), rewards=BoundedArray(shape=(), dtype=dtype('float32'), name=None, minimum=-1.0, maximum=1.0), discounts=BoundedArray(shape=(), dtype=dtype('float32'), name=None, minimum=0.0, maximum=1.0))


In [4]:
# Describe the network using the dimensions taken from the environment spec,
# then build the network from that description.
nn_spec = mz.nn.NeuralNetworkSpec(
    dim_image=dim_image, dim_repr=dim_repr, dim_action=dim_action
)
network = mz.nn.get_network(nn_spec)

# Adam optimizer with a fixed step size.
learning_rate = 1e-4
optimizer = optax.adam(learning_rate=learning_rate)

print(nn_spec)

NeuralNetworkSpec(dim_image=35, dim_repr=3, dim_action=3, repr_net_sizes=(16, 16), pred_net_sizes=(16, 16), dyna_net_sizes=(16, 16))


In [5]:
# Replay configuration: sampled batch size and the n-step horizon handed to
# the replay builder.
batch_size = 32
n_steps = 5

# Prioritized n-step replay built from the environment spec. The actor and
# learner cells read `.adder` and `.data_iterator` from this object.
reverb_replay = acme_replay.make_reverb_prioritized_nstep_replay(
    env_spec,
    batch_size=batch_size,
    n_step=n_steps,
)

In [6]:
# Learner: consumes batches from the replay iterator and optimizes `network`
# with the n-step vanilla policy-gradient loss.
# NOTE: the learner's PRNG key is the fixed literal 996; it is not split from
# the notebook-level `key`, so changing `seed` does not affect it.
learner = mz.learner.MooZiLearner(
    network=network,
    optimizer=optimizer,
    loss_fn=mz.loss.n_step_prior_vanilla_policy_gradient_loss,
    data_iterator=reverb_replay.data_iterator,
    random_key=jax.random.PRNGKey(996),
)

In [7]:
# Advance the root PRNG key. The `new_key` half is not consumed in this cell;
# the split is kept so that later keys stay on the same random stream.
key, new_key = jax.random.split(key)

# Variable client built on top of the learner (presumably how the actor pulls
# the latest parameters — verify against Acme); the second positional
# argument is passed as None.
variable_client = acme.jax.variable_utils.VariableClient(
    learner,
    None,
)

In [8]:
# Draw a fresh sub-key for the actor from the root key.
key, new_key = jax.random.split(key)

# Actor: writes experience through the replay adder and reads variables
# through the variable client created above.
actor = mz.actor.PriorPolicyActor(
    network=network,
    environment_spec=env_spec,
    variable_client=variable_client,
    adder=reverb_replay.adder,
    random_key=new_key,
    # Disabled options kept for reference (validity not confirmed here):
    #     epsilon=0.1,
    #     temperature=1
)

In [9]:
# Combine the actor and learner into a single Acme agent.
# `min_observations` / `observations_per_step` presumably gate when and how
# often the learner steps relative to observed transitions — confirm against
# the Acme `Agent` documentation.
agent = acme_agent.Agent(
    learner=learner,
    actor=actor,
    observations_per_step=1,
    min_observations=100,
)

In [10]:
# Environment loop that drives the wrapped environment with a single agent.
loop = OpenSpielEnvironmentLoop(
    environment=env,
    actors=[agent],
)

In [11]:
# Smoke test: play one episode and display the returned result dict
# (episode length, return, throughput and running counters).
loop.run_episode()

{'episode_length': 4,
 'episode_return': array([-1.], dtype=float32),
 'steps_per_second': 7.618222834122224,
 'episodes': 1,
 'steps': 4}

In [12]:
# Main training run: hand the step budget to the environment loop.
num_steps = 10_000
loop.run(num_steps=num_steps)

Loss = -0.6219470500946045 | Steps = 1
Loss = -0.7859916090965271 | Steps = 89
Loss = -0.5349277257919312 | Steps = 175
Loss = -0.9377725720405579 | Steps = 261
Loss = -1.8939168453216553 | Steps = 345
Loss = -0.6036562919616699 | Steps = 433
Loss = -8.336690902709961 | Steps = 524
Loss = -12.977231979370117 | Steps = 611
Loss = -28.27684211730957 | Steps = 700
Loss = -42.90653991699219 | Steps = 788
Loss = -101.74295806884766 | Steps = 875
Loss = -268.8840637207031 | Steps = 969
Loss = -567.10791015625 | Steps = 1065
Loss = -578.9107666015625 | Steps = 1159
Loss = -2244.784423828125 | Steps = 1252
Loss = -2540.06494140625 | Steps = 1345
Loss = -10141.755859375 | Steps = 1439
Loss = -11093.265625 | Steps = 1529
Loss = -27739.451171875 | Steps = 1618
Loss = -49846.96484375 | Steps = 1709
Loss = -79981.6171875 | Steps = 1801
Loss = -51072.41015625 | Steps = 1893
Loss = -124585.546875 | Steps = 1985
Loss = -204476.5 | Steps = 2075
Loss = -354924.9375 | Steps = 2166
Loss = -368487.625 | St