In [1]:
import random
import copy
from collections import namedtuple
from dataclasses import dataclass
import datetime
import typing
import functools
from pprint import pprint

In [2]:
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
import optax
import haiku as hk
from jax.tree_util import tree_flatten

import pyspiel
import open_spiel
import dm_env
import acme
import acme.wrappers
import acme.jax.utils
from acme.agents import agent as acme_agent
from acme.agents import replay as acme_replay
from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop
from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper

from tqdm.notebook import tqdm
import numpy as np
import trueskill

In [3]:
import moozi as mz

In [4]:
# %run hardware_sanity_check.ipynb

In [5]:
# %run moozi/utils.py

In [6]:
# OpenSpiel environment setup.
# NOTE(review): an earlier note here said the OpenSpiel environment was not being
# used because the latest released Acme did not support it, yet the code below
# does build one -- confirm that note is obsolete.
# Alternative board size kept for reference:
# raw_env = open_spiel.python.rl_environment.Environment('catch(columns=8,rows=4)')
raw_env = open_spiel.python.rl_environment.Environment('catch(columns=7,rows=5)')
# Wrap the raw environment for Acme, then apply the single-precision wrapper.
env = acme.wrappers.open_spiel_wrapper.OpenSpielWrapper(raw_env)
env = acme.wrappers.SinglePrecisionWrapper(env)
env_spec = acme.specs.make_environment_spec(env)
# Two `.environment` hops step back through the wrappers applied above to reach
# the object that exposes `.game`.
max_game_length = env.environment.environment.game.max_game_length()
# Dimensions read off the spec; they are fed to the network spec in a later cell.
dim_action = env_spec.actions.num_values
dim_image = env_spec.observations.observation.shape
dim_repr = 3
print(env_spec)
# print_traj_in_env(env)

EnvironmentSpec(observations=OLT(observation=Array(shape=(25,), dtype=dtype('float32'), name=None), legal_actions=Array(shape=(3,), dtype=dtype('float32'), name=None), terminal=Array(shape=(1,), dtype=dtype('float32'), name=None)), actions=DiscreteArray(shape=(), dtype=int32, name=None, minimum=0, maximum=2, num_values=3), rewards=BoundedArray(shape=(), dtype=dtype('float32'), name=None, minimum=-1.0, maximum=1.0), discounts=BoundedArray(shape=(), dtype=dtype('float32'), name=None, minimum=0.0, maximum=1.0))


In [7]:
# Bundle the dimensions derived from the environment spec into a network spec.
nn_spec = mz.nn.NeuralNetworkSpec(
    dim_image=dim_image,
    dim_repr=dim_repr,
    dim_action=dim_action
)
print(nn_spec)
# Build the network from the spec; the Adam optimizer is created with 1e-4 as
# its learning-rate argument.
network = mz.nn.get_network(nn_spec)
optimizer = optax.adam(1e-4)

NeuralNetworkSpec(dim_image=(25,), dim_repr=3, dim_action=3)


In [8]:
# Replay configuration: these two values are forwarded to the replay factory as
# its `batch_size` and `n_step` arguments.
batch_size = 16
n_steps=5
reverb_replay = acme_replay.make_reverb_prioritized_nstep_replay(
    env_spec, batch_size=batch_size, n_step=n_steps)
# Alternative replay constructions, currently disabled:
# reverb_replay = acme_replay.make_reverb_prioritized_sequence_replay(
#     env_spec, batch_size=batch_size)
# reverb_replay = make_reverb_episode_replay(
#     env_spec, max_sequence_length=max_game_length
# )

In [9]:
# Display the signature the replay adder reports for this environment spec
# (bare last expression, so the notebook renders its value).
reverb_replay.adder.signature(env_spec)

Transition(observation=OLT(observation=TensorSpec(shape=(25,), dtype=tf.float32, name='observation/observation'), legal_actions=TensorSpec(shape=(3,), dtype=tf.float32, name='observation/legal_actions'), terminal=TensorSpec(shape=(1,), dtype=tf.float32, name='observation/terminal')), action=TensorSpec(shape=(), dtype=tf.int32, name='action'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=TensorSpec(shape=(), dtype=tf.float32, name='discount'), next_observation=OLT(observation=TensorSpec(shape=(25,), dtype=tf.float32, name='next_observation/observation'), legal_actions=TensorSpec(shape=(3,), dtype=tf.float32, name='next_observation/legal_actions'), terminal=TensorSpec(shape=(1,), dtype=tf.float32, name='next_observation/terminal')), extras=())

In [10]:
# Actor: a moozi RandomActor constructed with the replay adder.
# NOTE(review): presumably it picks actions at random and records experience
# through the adder -- confirm against moozi.actor.
actor = mz.actor.RandomActor(reverb_replay.adder)
# Learner: receives the network, loss function, optimizer, the replay's data
# iterator, and a fixed PRNG key (996).
learner = mz.learner.MooZiLearner(
    network=network,
    loss_fn=mz.loss.initial_inference_value_loss,
    optimizer=optimizer,
    data_iterator=reverb_replay.data_iterator,
    random_key=jax.random.PRNGKey(996),
)

In [11]:
# Combine the actor and the learner into a single Acme agent.
# NOTE(review): min_observations / observations_per_step presumably govern when
# learner steps start and how often they run -- confirm against the Acme docs.
agent = acme_agent.Agent(
    actor=actor, learner=learner, min_observations=100, observations_per_step=1)

In [12]:
# Environment loop over the wrapped environment with the single agent as its
# only actor.
loop = OpenSpielEnvironmentLoop(environment=env, actors=[agent])

In [14]:
# Run the loop for 1000 episodes and print whatever `loop.run` returns.
# The captured output below consists of "Loss = ... | Steps = ..." lines.
result = loop.run(num_episodes=1000)
print(result)

Loss = 0.6425123810768127 | Steps = 302
Loss = 0.623607337474823 | Steps = 397
Loss = 0.5701141357421875 | Steps = 490
Loss = 0.4183977246284485 | Steps = 586
Loss = 0.4966594874858856 | Steps = 678
Loss = 0.4875529706478119 | Steps = 777
Loss = 0.6258119344711304 | Steps = 877
Loss = 0.3970840275287628 | Steps = 980
Loss = 0.46853935718536377 | Steps = 1077
Loss = 0.44663938879966736 | Steps = 1179
Loss = 0.2534307539463043 | Steps = 1285
Loss = 0.3672373294830322 | Steps = 1390
Loss = 0.3139341473579407 | Steps = 1493
Loss = 0.45155611634254456 | Steps = 1594
Loss = 0.42277291417121887 | Steps = 1695
Loss = 0.33315005898475647 | Steps = 1801
Loss = 0.42628541588783264 | Steps = 1904
Loss = 0.24865266680717468 | Steps = 2003
Loss = 0.41429194808006287 | Steps = 2106
Loss = 0.4365229904651642 | Steps = 2211
Loss = 0.3757643699645996 | Steps = 2313
Loss = 0.27186667919158936 | Steps = 2417
Loss = 0.5605866312980652 | Steps = 2519
Loss = 0.4893239140510559 | Steps = 2621
Loss = 0.3701498

In [86]:
# Replay signature: keep the adder's signature for this environment spec in
# `sig` so it can be pretty-printed below.
sig = reverb_replay.adder.signature(env_spec)
def recursive_as_dict(x):
    """Convert namedtuples anywhere inside ``x`` into plain dicts.

    Walks through namedtuple fields, dict values, and list/tuple items, so a
    namedtuple nested at any depth inside those containers is converted too.

    Args:
        x: Any object. Namedtuple instances (anything exposing ``_asdict``),
            dicts, lists and tuples are traversed; everything else is a leaf.

    Returns:
        A structure mirroring ``x`` in which every namedtuple instance has been
        replaced by a ``dict`` keyed by field name. Leaves are returned as-is.
    """
    # Namedtuples are tuples, so this check must come before the tuple branch.
    # Classes are excluded: a namedtuple *class* also has `_asdict`, but calling
    # it unbound would raise.
    if hasattr(x, '_asdict') and not isinstance(x, type):
        return {k: recursive_as_dict(v) for k, v in x._asdict().items()}
    if isinstance(x, dict):
        return {k: recursive_as_dict(v) for k, v in x.items()}
    if isinstance(x, list):
        return [recursive_as_dict(v) for v in x]
    if isinstance(x, tuple):
        return tuple(recursive_as_dict(v) for v in x)
    return x
# Pretty-print the signature after converting its namedtuples to dicts.
pprint(recursive_as_dict(sig))

{'action': TensorSpec(shape=(), dtype=tf.int32, name='action'),
 'discount': TensorSpec(shape=(), dtype=tf.float32, name='discount'),
 'extras': (),
 'next_observation': {'legal_actions': TensorSpec(shape=(3,), dtype=tf.float32, name='next_observation/legal_actions'),
                      'observation': TensorSpec(shape=(25,), dtype=tf.float32, name='next_observation/observation'),
                      'terminal': TensorSpec(shape=(1,), dtype=tf.float32, name='next_observation/terminal')},
 'observation': {'legal_actions': TensorSpec(shape=(3,), dtype=tf.float32, name='observation/legal_actions'),
                 'observation': TensorSpec(shape=(25,), dtype=tf.float32, name='observation/observation'),
                 'terminal': TensorSpec(shape=(1,), dtype=tf.float32, name='observation/terminal')},
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward')}


In [23]:
from sacred import Experiment
from sacred.observers import MongoObserver

# Sacred experiment named 'jupyter_ex', created with interactive=True since it
# is defined from a notebook session.
# NOTE(review): MongoObserver() is constructed with no arguments, so it relies
# on its default connection settings -- confirm a MongoDB instance is reachable.
ex = Experiment('jupyter_ex', interactive=True)
ex.observers.append(MongoObserver())

# Config function for the experiment: it defines `recipient` and a `message`
# built from it ("Hello %s!" formatted with the recipient).
# NOTE(review): Sacred treats the local variables of a config function as
# config entries, so keep these names stable -- confirm against Sacred docs.
@ex.config
def my_config():
    recipient = "world"
    message = "Hello %s!" % recipient

@ex.main
def my_main(message):
    """Experiment entry point: print ``message``.

    Args:
        message: The text to print. Presumably supplied by Sacred from the
            config entry of the same name -- confirm against Sacred docs.
    """
    print(message)

hello
