In [1]:
import random
import copy
from collections import namedtuple
from dataclasses import dataclass
import datetime
import typing
import functools
from pprint import pprint

In [2]:
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
import optax
import haiku as hk
from jax.tree_util import tree_flatten

import pyspiel
import open_spiel
from open_spiel.python import rl_environment
import dm_env
import acme
import acme.wrappers
from acme.agents import agent
from acme.agents import replay
# from acme import core
# from acme import specs
# from acme import types
# from acme import wrappers
from acme.jax.utils import prefetch
from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop
from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper
# import bsuite

from tqdm.notebook import tqdm
import numpy as np
import trueskill

In [6]:
import moozi as mz
# from moozi import Game, Action, ActionHistory, Player

In [8]:
# Sanity check: display which module backs `mz.nn`.
mz.nn

<module 'moozi.nerual_network' from '/workspace/moozi/nerual_network.py'>

In [4]:
# Run the companion hardware check notebook (prints CUDA library paths and nvidia-smi output).
%run hardware_sanity_check.ipynb

/usr/lib/x86_64-linux-gnu/libcuda.so.450.51.06
/usr/lib/x86_64-linux-gnu/libcuda.so
/usr/lib/x86_64-linux-gnu/libcuda.so.1
/usr/local/lib/python3.8/dist-packages/torch/lib/libcudart-80664282.so.10.2
/usr/local/cuda-11.0/compat/libcuda.so
/usr/local/cuda-11.0/compat/libcuda.so.1
/usr/local/cuda-11.0/compat/libcuda.so.450.119.03
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudart.so.11.0
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudart.so.11.0.221
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudadevrt.a
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudart.so
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudart_static.a
/usr/local/cuda-11.0/targets/x86_64-linux/lib/stubs/libcuda.so
/usr/local/cuda-11.0/extras/Debugger/include/libcudacore.h
/usr/local/cuda-11.0/extras/Debugger/lib64/libcudacore.a
Sun May 23 23:49:24 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.51.06    Driver Version: 450.51.06    CUDA Ver

In [6]:
class LossExtra(typing.NamedTuple):
    metrics: typing.Dict[str, jnp.DeviceArray]

In [7]:
def loss_fn(network, params, batch):
    """Mean-squared value-regression loss on a replay batch.

    Runs the network's initial inference on the batch observations and
    compares the predicted value with the batch reward.

    Args:
        network: object exposing ``initial_inference(params, observation)``
            whose result has a ``value`` attribute.
        params: network parameters, passed straight to ``initial_inference``.
        batch: replay sample; reads ``batch.data.observation.observation``
            and ``batch.data.reward``.

    Returns:
        ``(loss_scalar, LossExtra)``. The ``LossExtra.metrics`` dict starts empty.

    Raises:
        TypeError: if the value prediction and the reward hold a different
            number of elements (``jnp.reshape`` cannot align them).
    """
    inf_out = network.initial_inference(
        params, batch.data.observation.observation)
    reward = batch.data.reward
    # Align the prediction with the reward before subtracting. If the value
    # head emits a trailing singleton axis ((B, 1) vs reward (B,)), plain
    # subtraction broadcasts to (B, B) and silently averages every
    # value/reward pair. reshape is a no-op when the shapes already agree.
    value = jnp.reshape(inf_out.value, jnp.shape(reward))
    # TODO: use TD error instead
    loss_scalar = jnp.mean(jnp.square(value - reward))
    extra = LossExtra({})
    return loss_scalar, extra

In [8]:
# https://github.com/deepmind/acme/blob/a6b4162701542ed08b0b911ffac7c69fcb1bb3c7/acme/agents/jax/dqn/learning_lib.py#L68
class TrainingState(typing.NamedTuple):
    """Learner state threaded through each jitted SGD step.

    Modelled on the Acme DQN ``TrainingState`` linked above.
    """

    # Network parameters (pytree).
    params: typing.Any
    # Optimizer state for ``params``.
    opt_state: optax.OptState
    # SGD step count. It may start as a Python int and becomes a JAX array once
    # it has passed through a jitted update.
    steps: typing.Union[int, jnp.ndarray]
    # PRNG key. ``jax.random.PRNGKey`` is a constructor function rather than a
    # type, so annotate with the array type that it returns.
    rng_key: jnp.ndarray
        
class MyLearner(acme.Learner):
    """Acme learner that applies one jitted SGD step per call to ``step()``.

    Structured like Acme's JAX DQN learner: all training state lives in a
    ``TrainingState`` that a pure, jitted ``sgd_step`` maps to a new state.
    """

    def __init__(
        self,
        network,
        loss_fn,
        optimizer,
        data_iterator,
        random_key
    ):
        """Builds the jitted update and the initial training state.

        Args:
            network: object with ``init(key)`` returning params; it is also
                bound as the first argument of ``loss_fn``.
            loss_fn: callable ``(network, params, batch) -> (loss, extra)``
                where ``extra.metrics`` is a mutable dict.
            optimizer: optax gradient transformation (``init`` / ``update``).
            data_iterator: iterator of training batches; wrapped with
                ``prefetch``.
            random_key: JAX PRNG key, split into a params-init key and the
                learner's running key.
        """
        self.network = network
        # Bind the network so the loss takes (params, batch); value_and_grad
        # differentiates with respect to the first argument (params).
        self._loss = jax.jit(functools.partial(loss_fn, self.network))

        @jax.jit
        def sgd_step(training_state, batch):
            # Advance the key stream: only the second half of the split is
            # kept and carried into the next state.
            _, new_key = jax.random.split(training_state.rng_key)
            (loss, extra), grads = jax.value_and_grad(self._loss, has_aux=True)(
                training_state.params, batch)
            # Report the loss alongside whatever metrics loss_fn produced.
            extra.metrics.update({'loss': loss})
            updates, new_opt_state = optimizer.update(grads, training_state.opt_state)
            new_params = optax.apply_updates(training_state.params, updates)
            steps = training_state.steps + 1
            new_training_state = TrainingState(new_params, new_opt_state, steps, new_key)
            return new_training_state, extra

        self._sgd_step = sgd_step
        self._data_iterator = prefetch(data_iterator)

        key_params, key_state = jax.random.split(random_key, 2)
        params = self.network.init(key_params)
        self._state = TrainingState(
            params=params,
            opt_state=optimizer.init(params),
            steps=0,
            rng_key=key_state
        )
        self._counter = acme.utils.counting.Counter()
        # time_delta=1. presumably throttles output to about one line per second.
        self._logger = acme.utils.loggers.TerminalLogger(time_delta=1., print_fn=print)

    def step(self):
        """Runs one SGD step on the next batch, then logs counts and metrics."""
        batch = next(self._data_iterator)
        self._state, extra = self._sgd_step(self._state, batch)
        result = self._counter.increment(steps=1)
        result.update(extra.metrics)
        self._logger.write(result)

    def get_variables(self, names):
        """Returns ``[params]``; ``names`` is ignored."""
        return [self._state.params]

    def save(self):
        """Returns the whole ``TrainingState`` as the checkpointable state."""
        return self._state

    def restore(self, state):
        """Replaces the current ``TrainingState`` with ``state``."""
        self._state = state
        
class DoNothingLearner(acme.Learner):
    """Learner that never trains and holds no state.

    Every method is a no-op, so an agent built with it only runs its actor.
    """

    def __init__(self, *args, **kwargs):
        # Any constructor arguments are accepted and ignored.
        pass

    def step(self):
        """Performs no update."""

    def get_variables(self, names):
        """Returns ``None``: there are no variables to serve."""
        return None

    def save(self):
        """Returns ``None``: there is no state to save."""
        return None

    def restore(self, state):
        """Ignores ``state``: there is no state to restore."""

In [9]:
class MyAgentConfig(typing.NamedTuple):
    """Hyper-parameters consumed by ``MyAgent`` and ``RandomAgent``."""

    # Replay sample batch size.
    batch_size: int = 32
    # Adam learning rate.
    learning_rate: float = 1e-3
    # Forwarded to ``agent.Agent`` as ``min_observations``.
    min_observation: int = 0
    # Forwarded to ``agent.Agent`` as ``observations_per_step``.
    observations_per_step: float = 1

class MyAgent(agent.Agent):
    """Acme agent pairing ``RandomActor`` with ``MyLearner`` over prioritized n-step Reverb replay.

    Args:
        env_spec: Acme environment spec used to build the replay.
        config: ``MyAgentConfig``-like object providing ``batch_size``,
            ``learning_rate``, ``min_observation`` and ``observations_per_step``.
        network: network handed to ``MyLearner``.
        n_step: n-step horizon for the replay adder (defaults to 3).
        seed: seed for the learner's PRNG key (defaults to 0).
    """

    def __init__(self, env_spec, config, network, n_step=3, seed=0):
        reverb_replay = replay.make_reverb_prioritized_nstep_replay(
            env_spec, batch_size=config.batch_size,
            n_step=n_step
        )
        # Keep the Reverb server referenced from the agent: `reverb_replay` is a
        # local that goes out of scope when __init__ returns, and the server
        # presumably shuts down once nothing references it.
        self._server = reverb_replay.server
        optimizer = optax.adam(config.learning_rate)
        # NOTE(review): RandomActor is not defined or imported in any visible
        # cell (presumably from a removed cell); a fresh kernel raises NameError
        # here until it is defined or imported earlier in the notebook.
        actor = RandomActor(reverb_replay.adder)
        learner = MyLearner(
            network=network,
            loss_fn=loss_fn,
            optimizer=optimizer,
            data_iterator=reverb_replay.data_iterator,
            random_key=jax.random.PRNGKey(seed)
        )
        super().__init__(
            actor=actor,
            learner=learner,
            min_observations=config.min_observation,
            observations_per_step=config.observations_per_step
        )
        
class RandomAgent(agent.Agent):
    """Acme agent with ``RandomActor`` and a ``DoNothingLearner``.

    Writes experience to prioritized n-step Reverb replay, but its learner
    performs no updates. ``network`` is accepted so the constructor matches
    ``MyAgent``'s; it is not used.
    """

    def __init__(self, env_spec, config, network):
        reverb_replay = replay.make_reverb_prioritized_nstep_replay(
            env_spec, batch_size=config.batch_size
        )
        # Keep the Reverb server referenced from the agent: `reverb_replay` is a
        # local that goes out of scope when __init__ returns.
        self._server = reverb_replay.server
        # NOTE(review): RandomActor is not defined or imported in any visible
        # cell; define or import it before this cell.
        actor = RandomActor(reverb_replay.adder)
        super().__init__(
            actor=actor,
            learner=DoNothingLearner(),
            min_observations=config.min_observation,
            observations_per_step=config.observations_per_step
        )

In [10]:
# OpenSpiel "catch" environment (8 columns x 4 rows), adapted for Acme.
raw_env = rl_environment.Environment('catch(columns=8,rows=4)')
# OpenSpielWrapper adapts the OpenSpiel environment for Acme; SinglePrecisionWrapper
# converts it to single precision (sampled rewards later in the notebook are float32).
env = acme.wrappers.SinglePrecisionWrapper(OpenSpielWrapper(raw_env))
env_spec = acme.specs.make_environment_spec(env)
dim_action = env_spec.actions.num_values
dim_image = env_spec.observations.observation.shape
# Size of the network's hidden representation (hand-set hyper-parameter).
dim_repr = 3

In [13]:
# Network sizes: observation shape, hidden representation size, number of actions.
nn_spec = mz.nn.NeuralNetworkSpec(dim_image=dim_image, dim_repr=dim_repr, dim_action=dim_action)

In [14]:
# Build the moozi network from the spec; it is passed to MyAgent below.
network = mz.nn.get_network(nn_spec)

In [15]:
# Agent with default hyper-parameters from MyAgentConfig.
my_agent = MyAgent(env_spec, MyAgentConfig(), network)

In [16]:
# Environment loop with `my_agent` as the only actor.
loop = OpenSpielEnvironmentLoop(environment=env, actors=[my_agent])

In [17]:
# Run 100 episodes; the learner's logger prints loss/step lines during the run.
loop.run(num_episodes=100)

Loss = 0.00024332775501534343 | Steps = 1
Loss = 0.32159119844436646 | Steps = 106
Loss = 0.2494727522134781 | Steps = 222


In [38]:
# Debug: pull one batch directly from the learner's (private) prefetched replay iterator.
# NOTE(review): next() consumes this batch, so the learner will not train on it.
sample = next(my_agent._learner._data_iterator)

In [44]:
# Transition data of the sampled batch.
d = sample.data

In [49]:
# Rewards for each transition in the sampled batch.
d.reward

array([ 0.,  0., -1.,  1.,  0., -1.,  0.,  0.,  0.,  0., -1.,  0., -1.,
       -1., -1., -1., -1., -1.,  0., -1.,  0., -1., -1., -1.,  0., -1.,
       -1.,  0., -1., -1., -1., -1.], dtype=float32)

In [17]:
# Exploratory: a stand-alone replay built from the default config, with n_step
# left at the library default (MyAgent passes n_step explicitly).
# NOTE(review): nothing visible uses `reverb_replay`; it creates another replay
# alongside the one owned by `my_agent`.
config = MyAgentConfig()
reverb_replay = replay.make_reverb_prioritized_nstep_replay(env_spec, batch_size=config.batch_size)