In [2]:
import random
import copy
from collections import namedtuple
from dataclasses import dataclass
import datetime
import typing
import functools
from pprint import pprint

import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
import optax
import haiku as hk
from jax.tree_util import tree_flatten
import jaxboard

import pyspiel
import open_spiel
import dm_env
import acme
import acme.wrappers
import acme.jax.utils
import acme.jax.variable_utils
from acme.agents import agent as acme_agent
from acme.agents import replay as acme_replay
from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop
from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper

from tqdm.notebook import tqdm
import numpy as np
import trueskill

import moozi as mz

In [2]:
seed = 0
key = jax.random.PRNGKey(seed)

In [3]:
raw_env = open_spiel.python.rl_environment.Environment('catch(columns=7,rows=5)')
env = acme.wrappers.open_spiel_wrapper.OpenSpielWrapper(raw_env)
env = acme.wrappers.SinglePrecisionWrapper(env)
env_spec = acme.specs.make_environment_spec(env)
max_game_length = env.environment.environment.game.max_game_length()
dim_action = env_spec.actions.num_values
dim_image = env_spec.observations.observation.shape[0]
dim_repr = 3
print(env_spec)

EnvironmentSpec(observations=OLT(observation=Array(shape=(35,), dtype=dtype('float32'), name=None), legal_actions=Array(shape=(3,), dtype=dtype('float32'), name=None), terminal=Array(shape=(1,), dtype=dtype('float32'), name=None)), actions=DiscreteArray(shape=(), dtype=int32, name=None, minimum=0, maximum=2, num_values=3), rewards=BoundedArray(shape=(), dtype=dtype('float32'), name=None, minimum=-1.0, maximum=1.0), discounts=BoundedArray(shape=(), dtype=dtype('float32'), name=None, minimum=0.0, maximum=1.0))


In [4]:
nn_spec = mz.nn.NeuralNetworkSpec(
    dim_image=dim_image,
    dim_repr=dim_repr,
    dim_action=dim_action
)
network = mz.nn.get_network(nn_spec)
learning_rate = 1e-4
optimizer = optax.adam(learning_rate)
print(nn_spec)

NeuralNetworkSpec(dim_image=35, dim_repr=3, dim_action=3, repr_net_sizes=(16, 16), pred_net_sizes=(16, 16), dyna_net_sizes=(16, 16))


In [5]:
batch_size = 32
n_steps=5
reverb_replay = acme_replay.make_reverb_prioritized_nstep_replay(
    env_spec, batch_size=batch_size, n_step=n_steps)

In [6]:
learner = mz.learner.MooZiLearner(
    network=network,
    loss_fn=mz.loss.n_step_prior_vanilla_policy_gradient_loss,
    optimizer=optimizer,
    data_iterator=reverb_replay.data_iterator,
    random_key=jax.random.PRNGKey(996),
)

In [7]:
key, new_key = jax.random.split(key)
variable_client = acme.jax.variable_utils.VariableClient(learner, None)

In [8]:
key, new_key = jax.random.split(key)
actor = mz.actor.PriorPolicyActor(
    environment_spec=env_spec,
    network=network,
    adder=reverb_replay.adder,
    variable_client=variable_client,
    random_key=new_key,
#     epsilon=0.1,
#     temperature=1
)

In [9]:
agent = acme_agent.Agent(
    actor=actor,
    learner=learner,
    min_observations=100,
    observations_per_step=1
)

In [10]:
loop = OpenSpielEnvironmentLoop(environment=env, actors=[agent])

In [11]:
num_steps = 5_000
loop.run(num_steps=num_steps)

Loss = -0.5049116611480713 | Steps = 1
Loss = -0.6460065841674805 | Steps = 41
Loss = -0.46696144342422485 | Steps = 81
Loss = -0.4846264123916626 | Steps = 123
Loss = -0.4565252959728241 | Steps = 166
Loss = -0.447548508644104 | Steps = 208
Loss = -0.57877117395401 | Steps = 250
Loss = -0.5818035006523132 | Steps = 289
Loss = -0.5456626415252686 | Steps = 326
Loss = -0.5138188004493713 | Steps = 364
Loss = -0.38858678936958313 | Steps = 401
Loss = -0.5282334685325623 | Steps = 440
Loss = -0.6908484697341919 | Steps = 482
Loss = -0.625980019569397 | Steps = 524
Loss = -0.3960024118423462 | Steps = 566
Loss = -0.4678385257720947 | Steps = 609
Loss = -0.5946560502052307 | Steps = 650
Loss = -0.5355611443519592 | Steps = 691
Loss = -0.7328777313232422 | Steps = 729
Loss = -0.24787819385528564 | Steps = 769
Loss = -0.4891621172428131 | Steps = 811
Loss = -0.4183257520198822 | Steps = 854
Loss = -1.1634336709976196 | Steps = 897
Loss = -1.6184744834899902 | Steps = 940
Loss = -1.88362753391

In [12]:
print(learner._jaxboard_logger._log_dir)

/tmp/moozi-log-52a4b607-d040-4c


In [19]:
logits = -np.ones(30)
logits[0] = 1
jax.nn.softmax(logits)

DeviceArray([0.20305708, 0.02748079, 0.02748079, 0.02748079, 0.02748079,
             0.02748079, 0.02748079, 0.02748079, 0.02748079, 0.02748079,
             0.02748079, 0.02748079, 0.02748079, 0.02748079, 0.02748079,
             0.02748079, 0.02748079, 0.02748079, 0.02748079, 0.02748079,
             0.02748079, 0.02748079, 0.02748079, 0.02748079, 0.02748079,
             0.02748079, 0.02748079, 0.02748079, 0.02748079, 0.02748079],            dtype=float32)