In [1]:
import random
import copy
from collections import namedtuple
from dataclasses import dataclass
import typing
import functools
from pprint import pprint

import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
import optax
import haiku as hk
from jax.tree_util import tree_flatten

import pyspiel
import open_spiel
from open_spiel.python import rl_environment
import dm_env
import acme
from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper

from tqdm.notebook import tqdm
import numpy as np
import trueskill

In [2]:
import moozi as mz
from moozi import Game, Action, ActionHistory, Player
from moozi.model import Model

In [None]:
# Build the OpenSpiel "catch" game (3 columns, 4 rows) and expose it through
# acme's OpenSpielWrapper.
env = OpenSpielWrapper(
    rl_environment.Environment('catch(columns=3,rows=4)')
)

In [3]:
# Smoke test: play 10 episodes, choosing uniformly among the legal actions.
for _ in range(10):
    time_step = env.reset()
    episode_over = False
    # At least one step is always taken before the termination check.
    while not episode_over:
        # Non-zero entries of `legal_actions` are the playable action indices.
        legals = time_step.observation[env.current_player].legal_actions
        candidate_actions = np.nonzero(legals)[0]
        action = np.random.choice(candidate_actions)
        time_step = env.step([action])
        episode_over = time_step.last()

In [4]:
# Display the shape of a single observation as declared by the environment spec.
env.observation_spec().observation.shape

(12,)

In [5]:
# Problem sizes used when building the model.
# `dim_obs` is a shape tuple (not an int); it is later handed to `jnp.ones`.
dim_obs = env.observation_spec().observation.shape
# Number of discrete actions reported by the action spec.
dim_actions = env.action_spec().num_values
# Hand-picked size forwarded to `Model` as `dim_repr`
# (presumably the hidden-representation width; confirm against moozi.model).
dim_repr = 7

In [6]:
def get_model(rng, dim_obs, dim_actions, dim_repr):
    """Build the two Haiku-transformed inference functions and their parameters.

    Args:
        rng: JAX PRNG key used to initialise the parameters. It is split so
            that each transformed function is initialised with its own sub-key.
        dim_obs: Shape passed to `jnp.ones` to create the dummy observation
            used when initialising `initial_inference`.
        dim_actions: Forwarded to `Model`; also the shape of the dummy second
            input used when initialising `recurrent_inference`.
        dim_repr: Forwarded to `Model`; also the shape of the dummy first
            input used when initialising `recurrent_inference`.

    Returns:
        A tuple `(initial_inference, recurrent_inference, params)`. The first
        two are `hk.without_apply_rng(hk.transform(...))` objects and `params`
        is the merge of the parameters produced by initialising both.
    """
    initial_inference = hk.without_apply_rng(hk.transform(
        lambda x: Model(dim_actions, dim_repr).initial_inference(x)))
    recurrent_inference = hk.without_apply_rng(hk.transform(
        lambda x, y: Model(dim_actions, dim_repr).recurrent_inference(x, y)))
    # Split the key that was passed in. The previous code split a notebook-global
    # `key`, which ignored the `rng` argument and failed on a fresh kernel.
    key_1, key_2 = jax.random.split(rng)
    params = hk.data_structures.merge(
        initial_inference.init(key_1, jnp.ones(dim_obs)),
        recurrent_inference.init(key_2, jnp.ones(dim_repr), jnp.ones(dim_actions))
    )
    return initial_inference, recurrent_inference, params

In [8]:
# Fixed seed (0) for the JAX PRNG key.
key = jax.random.PRNGKey(0)
# Build the two transformed inference functions and their merged parameters.
initial_inference, recurrent_inference, params = get_model(key, dim_obs, dim_actions, dim_repr)

In [22]:
import acme
from acme import core
from acme import specs
from acme import types
from acme import wrappers
from acme.environment_loops import open_spiel_environment_loop
from acme.wrappers import open_spiel_wrapper
from open_spiel.python import rl_environment

class RandomActor(core.Actor):
    """Actor that picks a random legal action and never learns.

    The environment spec given at construction is stored but not consulted by
    any method. `num_updates` counts how many times `update` has been called.
    """

    def __init__(self, spec: specs.EnvironmentSpec):
        self.num_updates = 0
        self._spec = spec

    def select_action(self, observation: open_spiel_wrapper.OLT) -> int:
        """Draw one index at random from the non-zero entries of `legal_actions`."""
        # Indices along the first axis where the legal-action entry is non-zero,
        # stored as int32.
        legal_indices = np.array(
            np.nonzero(observation.legal_actions)[0], dtype=np.int32)
        return np.random.choice(legal_indices)

    def observe_first(self, timestep: dm_env.TimeStep):
        """Ignore the first timestep of an episode."""

    def observe(self, action: types.NestedArray,
                next_timestep: dm_env.TimeStep):
        """Ignore the transition; nothing is recorded."""

    def update(self, wait: bool = False):
        """Only bump the update counter; `wait` is unused."""
        self.num_updates = self.num_updates + 1

In [23]:
# Tic-tac-toe through acme: OpenSpiel wrapper first, then SinglePrecisionWrapper.
raw_env = rl_environment.Environment('tic_tac_toe')
env = wrappers.SinglePrecisionWrapper(
    open_spiel_wrapper.OpenSpielWrapper(raw_env)
)
environment_spec = acme.make_environment_spec(env)

# One RandomActor per player.
actors = [RandomActor(environment_spec) for _ in range(env.num_players)]

loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(env, actors)

# A single episode first, keeping the returned value.
result = loop.run_episode()

# Then a run bounded by episode count and one bounded by step count.
loop.run(num_episodes=10)
loop.run(num_steps=100)

{'episode_length': 5, 'episode_return': array([ 1., -1.], dtype=float32), 'steps_per_second': 1426.634013605442, 'episodes': 1, 'steps': 5}


In [29]:
# Run 300 more episodes on the same loop.
# NOTE(review): the recorded output of this cell is `None`, so `loop.run`
# apparently returns nothing here and `result` ends up holding None. Confirm
# whether per-episode statistics were intended (cf. `loop.run_episode()`).
result = loop.run(num_episodes=300)
print(result)

None


In [26]:
ls

CMakeLists.txt  archive/    moozi/            sandbox.ipynb
README.md       build/      moozi.egg-info/   setup.py
TODO            game.ipynb  pseudocode.py     test_model_supervised.ipynb
app.py          main.cpp    requirements.txt  tests/
