In [54]:
import random
import copy
from collections import namedtuple
from dataclasses import dataclass
import typing
import functools
from pprint import pprint

import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
import optax
import haiku as hk
from jax.tree_util import tree_flatten

import pyspiel
import open_spiel
import dm_env
import acme
from acme.agents import agent
from acme.agents import replay
# from acme import core
# from acme import specs
# from acme import types
# from acme import wrappers
from acme.jax.utils import prefetch
from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop
from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper

from tqdm.notebook import tqdm
import numpy as np
import trueskill

In [2]:
import moozi as mz
from moozi import Game, Action, ActionHistory, Player
from moozi.model import Model

In [3]:
# dim_actions = env.action_spec().num_values
# dim_repr = 7
# dim_obs = env.observation_spec().observation.shape

In [14]:
def get_model(rng, dim_obs, dim_actions, dim_repr):
    initial_inference = hk.without_apply_rng(hk.transform(
        lambda x: Model(dim_actions, dim_repr).initial_inference(x)))
    recurrent_inference = hk.without_apply_rng(hk.transform(
        lambda x, y: Model(dim_actions, dim_repr).recurrent_inference(x, y)))
    key_1, key_2 = jax.random.split(key)
    params = hk.data_structures.merge(
        initial_inference.init(key_1, jnp.ones(dim_obs)),
        recurrent_inference.init(key_2, jnp.ones(dim_repr), jnp.ones(dim_actions))
    )
    return initial_inference, recurrent_inference, params

In [15]:
class RandomActor(acme.core.Actor):
    def __init__(self, adder):
        self._adder = adder
        
    def select_action(self, observation: acme.wrappers.open_spiel_wrapper.OLT) -> int:
        legals = np.array(np.nonzero(observation.legal_actions), dtype=np.int32)
        return np.random.choice(legals[0])

    def observe_first(self, timestep: dm_env.TimeStep):
        self._adder.add_first(timestep)

    def observe(self, action: acme.types.NestedArray, next_timestep: dm_env.TimeStep):
        self._adder.add(action, next_timestep)

    def update(self, wait: bool = False):
        pass
    
class MuZeroLearner(acme.core.Learner):
    def __init__(self, rng, dim_obs, dim_actions, dim_repr):
        self.ini_ref, self.recur_ref, self.params = self.get_model(rng, dim_obs, dim_actions, dim_repr)
        
class RandomAgent(agent.Agent):
    def __init__(self, env_spec: acme.specs.EnvironmentSpec):
        reverb_replay = replay.make_reverb_prioritized_nstep_replay(env_spec, n_step=2)
        key = jax.random.PRNGKey(0)
        learner = MuZeroLearner(key, 10, 10, 10)
        actor = RandomActor(reverb_replay.adder)
        super().__init__(actor=actor)

In [55]:
raw_env = open_spiel.python.rl_environment.Environment('catch(columns=3,rows=4)')
env = acme.wrappers.open_spiel_wrapper.OpenSpielWrapper(raw_env)
env = acme.wrappers.SinglePrecisionWrapper(env)
env_spec = acme.make_environment_spec(env)

In [95]:
reverb_replay = replay.make_reverb_prioritized_sequence_replay(
    env_spec, batch_size=5
)
actor = RandomActor(reverb_replay.adder)
loop = OpenSpielEnvironmentLoop(environment=env, actors=[actor])

In [65]:
loop.run(num_episodes=10)

In [66]:
ditr = prefetch(reverb_replay.data_iterator)

In [67]:
batch = next(ditr)

In [92]:
batch.data.observation.observation.shape

(10, 12)

In [93]:
batch.data.next_observation.observation.shape

(10, 12)

In [23]:
for x in client.sample():
    print(x)

TypeError: NewSampler(): incompatible function arguments. The following argument types are supported:
    1. (self: reverb.libpybind.Client, arg0: str, arg1: int, arg2: int, arg3: int) -> reverb.libpybind.Sampler

Invoked with: <reverb.libpybind.Client object at 0x7f3dcf754670>, None, 1, 1, 3000