In [1]:
import random
import copy
from collections import namedtuple
from dataclasses import dataclass
import datetime
import typing
import functools
from pprint import pprint

In [2]:
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
import optax
import haiku as hk
from jax.tree_util import tree_flatten

import pyspiel
import open_spiel
import dm_env
import acme
import acme.wrappers
import acme.jax.utils
from acme.agents import agent as acme_agent
from acme.agents import replay as acme_replay
from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop
from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper

from tqdm.notebook import tqdm
import numpy as np
import trueskill

In [3]:
import moozi as mz

In [4]:
# %run hardware_sanity_check.ipynb

In [5]:
# Environment setup: OpenSpiel's 'cliff_walking' game, adapted for Acme with
# OpenSpielWrapper and then wrapped in SinglePrecisionWrapper.
# NOTE(review): an earlier note on this cell said the OpenSpiel environment was
# not supported by the latest released Acme, yet the cell does use it --
# confirm which Acme version this notebook expects.
# Alternative game string tried previously: 'catch(columns=8,rows=4)'.
raw_env = open_spiel.python.rl_environment.Environment('cliff_walking')
env = acme.wrappers.SinglePrecisionWrapper(OpenSpielWrapper(raw_env))
env_spec = acme.specs.make_environment_spec(env)

# Dimensions handed to the network spec.
dim_repr = 3  # hand-picked value, passed on as NeuralNetworkSpec.dim_repr
dim_image = env_spec.observations.observation.shape
dim_action = env_spec.actions.num_values
print(env_spec)

EnvironmentSpec(observations=OLT(observation=Array(shape=(400,), dtype=dtype('float32'), name=None), legal_actions=Array(shape=(4,), dtype=dtype('float32'), name=None), terminal=Array(shape=(1,), dtype=dtype('float32'), name=None)), actions=DiscreteArray(shape=(), dtype=int32, name=None, minimum=0, maximum=3, num_values=4), rewards=BoundedArray(shape=(), dtype=dtype('float32'), name=None, minimum=-199.0, maximum=-9.0), discounts=BoundedArray(shape=(), dtype=dtype('float32'), name=None, minimum=0.0, maximum=1.0))


In [6]:
# Optimizer: Adam with a fixed learning rate of 1e-4.
optimizer = optax.adam(learning_rate=1e-4)

# Describe the network dimensions, echo the spec, then build the network from it.
nn_spec = mz.nn.NeuralNetworkSpec(
    dim_image=dim_image, dim_repr=dim_repr, dim_action=dim_action)
print(nn_spec)
network = mz.nn.get_network(nn_spec)

NeuralNetworkSpec(dim_image=(400,), dim_repr=3, dim_action=4)


In [7]:
# Replay buffer, actor and learner wiring.
# The replay object supplies both the adder (given to the actor) and the
# data iterator (given to the learner).
batch_size = 128
# NOTE(review): n_steps is not referenced by any code visible in this notebook;
# the commented-out n-step replay below hard-codes n_step=5 instead.
n_steps=5
# reverb_replay = acme_replay.make_reverb_prioritized_nstep_replay(
#     env_spec, batch_size=batch_size, n_step=5)
# NOTE(review): the recorded output of the final `loop.run` cell is a
# rank-compatibility AssertionError on inputs of shape (128, 121). Possibly the
# sequence replay yields (batch, time) arrays while the loss expects rank-1
# inputs -- TODO confirm against mz.loss.initial_inference_value_loss.
reverb_replay = acme_replay.make_reverb_prioritized_sequence_replay(
    env_spec, batch_size=batch_size)
# The actor receives only the replay adder.
actor = mz.actor.RandomActor(reverb_replay.adder)
learner = mz.learner.MooZiLearner(
    network=network,
    loss_fn=mz.loss.initial_inference_value_loss,
    optimizer=optimizer,
    data_iterator=reverb_replay.data_iterator,
    random_key=jax.random.PRNGKey(996),  # fixed seed for the learner
)

In [8]:
# Combine the random actor and the learner into a single Acme agent.
# NOTE(review): min_observations / observations_per_step presumably control
# when and how often the learner steps -- confirm against acme.agents.agent.
agent = acme_agent.Agent(
    actor=actor,
    learner=learner,
    min_observations=100,
    observations_per_step=1,
)

In [9]:
# Build the OpenSpiel environment loop with the single agent, run one episode
# as a smoke test, and print the result dict that run_episode returns.
loop = OpenSpielEnvironmentLoop(environment=env, actors=[agent])
result = loop.run_episode()
print(result)

{'episode_length': 4, 'episode_return': array([-103.], dtype=float32), 'steps_per_second': 30.453421042122727, 'episodes': 1, 'steps': 4}


In [10]:
# Pull one item from the learner's (private) data iterator and print the shape
# of its observation array; the recorded output is (128, 121, 400).
# NOTE(review): reaches into the private attribute `_data_iterator`.
item = next(learner._data_iterator)
print(item.data.observation.observation.shape)

(128, 121, 400)


In [12]:
# Debug aid: show the attribute dictionary of the actor's (private) adder.
vars(actor._adder)

{'_client': Client, server_address=localhost:16077,
 '_priority_fns': {'priority_table': <function acme.adders.reverb.base.ReverbAdder.__init__.<locals>.<lambda>(x)>},
 '_max_sequence_length': 122,
 '_delta_encoded': True,
 '_max_in_flight_items': 2,
 '_ReverbAdder__writer': <reverb.trajectory_writer.TrajectoryWriter at 0x7f2635615520>,
 '_get_signature_timeout_ms': 300000,
 '_period': 40,
 '_sequence_length': 121,
 '_end_of_episode_behavior': <EndOfEpisodeBehavior.ZERO_PAD: 'zero_pad_til_next_write'>}

In [147]:
# Run ten more episodes; after each one, echo the loop's private
# `_observed_first` value followed by that episode's result.
for _ in range(10):
    result = loop.run_episode()
    print(loop._observed_first, result, sep="\n")

[False]
{'episode_length': 1, 'episode_return': array([-100.], dtype=float32), 'steps_per_second': 60.585064278491984, 'episodes': 1, 'steps': 1}
[False]
{'episode_length': 1, 'episode_return': array([-100.], dtype=float32), 'steps_per_second': 79.88693979391654, 'episodes': 2, 'steps': 2}
[False]
{'episode_length': 1, 'episode_return': array([-100.], dtype=float32), 'steps_per_second': 85.5858142714306, 'episodes': 3, 'steps': 3}
[False]
{'episode_length': 1, 'episode_return': array([-100.], dtype=float32), 'steps_per_second': 75.56623727592108, 'episodes': 4, 'steps': 4}
[False]
{'episode_length': 2, 'episode_return': array([-101.], dtype=float32), 'steps_per_second': 96.88963836496148, 'episodes': 5, 'steps': 6}
[False]
{'episode_length': 7, 'episode_return': array([-106.], dtype=float32), 'steps_per_second': 119.45354310846383, 'episodes': 6, 'steps': 13}
[False]
{'episode_length': 31, 'episode_return': array([-130.], dtype=float32), 'steps_per_second': 140.50981176157836, 'episode

In [64]:
# IPython post-mortem debugger, left over from investigating a failure inside
# OpenSpielEnvironmentLoop.run_episode (see the recorded ipdb session).
# NOTE(review): interactive debugging does not belong in a finished notebook;
# remove this cell before sharing.
%debug

> [0;32m/usr/local/lib/python3.8/dist-packages/acme/environment_loops/open_spiel_environment_loop.py[0m(115)[0;36m_get_player_timestep[0;34m()[0m
[0;32m    113 [0;31m                           player: int) -> dm_env.TimeStep:
[0m[0;32m    114 [0;31m    return dm_env.TimeStep(observation=timestep.observation[player],
[0m[0;32m--> 115 [0;31m                           [0mreward[0m[0;34m=[0m[0mtimestep[0m[0;34m.[0m[0mreward[0m[0;34m[[0m[0mplayer[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    116 [0;31m                           [0mdiscount[0m[0;34m=[0m[0mtimestep[0m[0;34m.[0m[0mdiscount[0m[0;34m[[0m[0mplayer[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    117 [0;31m                           step_type=timestep.step_type)
[0m
ipdb> timestep
TimeStep(step_type=<StepType.FIRST: 0>, reward=None, discount=None, observation=[OLT(observation=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

ipdb> u
> [0;32m/usr/local/lib/python3.8/dist-packages/acme/environment_loops/open_spiel_environment_loop.py[0m(146)[0;36mrun_episode[0;34m()[0m
[0;32m    144 [0;31m[0;34m[0m[0m
[0m[0;32m    145 [0;31m    [0;31m# Make the first observation.[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 146 [0;31m    [0mself[0m[0;34m.[0m[0m_send_observation[0m[0;34m([0m[0mtimestep[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0m_environment[0m[0;34m.[0m[0mcurrent_player[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    147 [0;31m[0;34m[0m[0m
[0m[0;32m    148 [0;31m    [0;31m# Run an episode.[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> l
[1;32m    141 [0m                                        multiplayer_reward_spec)
[1;32m    142 [0m[0;34m[0m[0m
[1;32m    143 [0m    [0mtimestep[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_environment[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m    144 [0m[0;3

In [None]:
# Draw one item from the learner's data iterator, reached through the agent's
# private `_learner` attribute.
# NOTE(review): this cell has no execution count, yet the next cell reads
# `sample` -- run it before the shape check for a clean top-to-bottom run.
sample = next(agent._learner._data_iterator)

In [38]:
# Shape of the sampled observation array; the recorded output is (128, 121, 400).
# Presumably (batch, sequence, observation): batch_size is 128, the adder's
# `_sequence_length` is 121, and the observation spec has shape (400,) -- TODO confirm.
sample.data.observation.observation.shape

(128, 121, 400)

In [10]:
# Run 1000 episodes.
# NOTE(review): the recorded output of this cell is an AssertionError from a
# rank compatibility check (inputs of shape (128, 121), rank 1 expected), so
# the notebook does not currently run to completion.
loop.run(num_episodes=1000)

AssertionError: Error in rank compatibility check: input 0 has rank 2 (shape (128, 121)) but expected 1; input 1 has rank 2 (shape (128, 121)) but expected 1.