In [2]:
import random
import copy
from collections import namedtuple

import pyspiel
import numpy as np

import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax

import optax
import haiku as hk
from tqdm.notebook import tqdm

print(hk.__version__)

0.0.5.dev


In [3]:
game = pyspiel.load_game('tic_tac_toe')
gamma = 0.99
key = jax.random.PRNGKey(0)
dim_board = game.observation_tensor_size()



In [4]:
def state_to_repr(state, player):
    board_obs = np.asarray(state.observation_tensor(0))
    player_obs = np.asarray(player).reshape((1,))
    return np.concatenate((board_obs, player_obs))

In [5]:
class ReplayBuffer(object):
    def __init__(self, capacity):
        self._capacity = capacity
        self._data = []
        self._next_entry_index = 0

    def add(self, element):
        if len(self._data) < self._capacity:
            self._data.append(element)
        else:
            self._data[self._next_entry_index] = element
            self._next_entry_index += 1
            self._next_entry_index %= self._capacity

    def sample(self, num_samples):
        if len(self._data) < num_samples:
            raise ValueError("{} elements could not be sampled from size {}".format(
              num_samples, len(self._data)))
        return random.sample(self._data, num_samples)

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

In [24]:
Experience = namedtuple(
    'Experience', 
    ['s', 'a', 'g', 's_next']
)

def f(self):
    print('r', self.g)
    print(obs_tensor_to_board(self.s), '\n')
    print(obs_tensor_to_board(self.s_next), '\n')

Experience.print = f

In [25]:
key, value_net_params_init_key = jax.random.split(key)

In [75]:
def value_net_hk(x):
    mlp = hk.nets.MLP([64, 32, 16, 1])
    return jnp.tanh(mlp(x))

value_net = hk.without_apply_rng(hk.transform(value_net_hk))

In [76]:
print(hk.experimental.tabulate(value_net)(jnp.zeros((12, 28))))

+-------------------------+-----------------------------------------+-----------------+------------+------------+---------------+---------------+
| Module                  | Config                                  | Module params   | Input      | Output     |   Param count |   Param bytes |
| mlp (MLP)               | MLP(output_sizes=[64, 32, 16, 1])       |                 | f32[12,28] | f32[12,1]  |         4,481 |      17.92 KB |
+-------------------------+-----------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_0 (Linear) | Linear(output_size=64, name='linear_0') | w: f32[28,64]   | f32[12,28] | f32[12,64] |         1,856 |       7.42 KB |
|  └ mlp (MLP)            |                                         | b: f32[64]      |            |            |               |               |
+-------------------------+-----------------------------------------+-----------------+------------+------------+-----------

In [77]:
def state_to_exp(state):
    player = state.current_player()
    state_repr = state_to_repr(state, player)

    action = random.choice(state.legal_actions())
    state.apply_action(action)

    reward = (state.rewards()[0] + 1) / 2

    if state.is_terminal():
        next_state_repr = state_to_repr(state, player)
        v_next_mask = 0
        reward = state.rewards()[0]
    else:
        next_state_repr = state_to_repr(state, state.current_player())
        v_next_mask = 1
        reward = 0
    exp = Experience(state_repr, action, reward, next_state_repr, v_next_mask)
    return exp

In [78]:
def create_batch(samples):
    batch = {
        's': np.stack([sample.s for sample in samples]),
        'a': np.stack([sample.a for sample in samples]),
        'g': np.stack([sample.g for sample in samples]),
        's_next': np.stack([sample.s_next for sample in samples]),
#         'v_next_mask': np.stack([sample.v_next_mask for sample in samples]),
    }
    return batch

In [79]:
def obs_tensor_to_board(s):
    obs = s[:27].reshape(3, 9).T
    obs = obs.argmax(axis=-1).reshape(3, 3).squeeze()
    return '\n'.join(map(str, obs)).replace('[', '').replace(']', '') \
                .replace('2', 'x').replace('1', 'o').replace('0', '.')

In [80]:
def traj_to_exps():
    state = game.new_initial_state()
    obses = []
    actions = []
    rewards = []
    while not state.is_terminal():
        obses.append(np.array(state.observation_tensor()))
        action = random.choice(state.legal_actions())
        actions.append(action)
        state.apply_action(action)
        if state.is_terminal():
            reward = state.rewards()[0]
        else:
            reward = 0
        rewards.append(reward)
    obses.append(np.array(state.observation_tensor(0)))
    exps = []
    for i in range(len(actions)):
        action = actions[i]
        obs = obses[i]
        obs_next = obses[i + 1]
        rest_rewards = rewards[i:]
        G = 0
        for i, r in enumerate(rest_rewards):
            G += (gamma ** i) * r
        exp = Experience(obs, action, G, obs_next)
        exps.append(exp)
    return exps

In [81]:
@jit
def value_net_forward(value_params, x):
    return value_net.apply(value_params, x)

@jit
def value_net_error(value_params, batch):
    v = value_net_forward(value_params, batch['s'])
    target = batch['g']
    error = target - v
    return error

@jit
def value_net_loss(value_params, batch):
    error = value_net_error(value_params, batch)
    return jnp.mean(jnp.square(error))
    
@jit
def value_update(value_params, value_opt_state, batch):
    loss, grads = value_and_grad(value_net_loss)(value_params, batch)
    updates, value_opt_state = value_opt.update(grads, value_opt_state)
    new_params = optax.apply_updates(value_params, updates)
    return loss, new_params, value_opt_state

In [82]:
num_games = 10000
buffer = ReplayBuffer(capacity=50000)
batch_size = 10000
value_net_shape = (batch_size, 27)
value_params = value_net.init(value_net_params_init_key, jnp.zeros(value_net_shape))
value_opt = optax.adam(1e-5)
value_opt_state = value_opt.init(value_params)

In [83]:
for i in tqdm(range(num_games)):
    exps = traj_to_exps()
    for exp in exps:
        buffer.add(exp)

    if len(buffer) >= batch_size:
        samples = buffer.sample(batch_size)
        batch = create_batch(samples)
        loss, value_params, value_opt_state = value_update(
            value_params, value_opt_state, batch)
        if i % 100 == 0:
            print(loss)

  0%|          | 0/10000 [00:00<?, ?it/s]

1.0148703
0.96990997
0.93185925
0.89520526
0.8625934
0.83349895
0.81524986
0.7956222
0.7784736
0.77110577
0.76101637
0.7475865
0.7465021
0.74434775
0.7380743
0.7283181
0.7287822
0.7232114
0.72803485
0.7231534
0.72571677
0.7234094
0.72699136
0.72109526
0.7206454
0.7182596
0.7223592
0.71361727
0.7157363
0.71845025
0.7214123
0.7173888
0.72139233
0.71385425
0.7240179
0.72471005
0.7163072
0.71658295
0.71804607
0.7123363
0.7289148
0.7242673
0.71860015
0.71782094
0.7137435
0.7203043
0.72692895
0.71798325
0.7139135
0.7201965
0.7109823
0.71637493
0.72145087
0.72466135
0.7188185
0.7305182
0.7210333
0.72334933
0.72515905
0.7245476
0.7263353
0.73957455
0.7295589
0.72910696
0.7245447
0.71502465
0.7234984
0.7243146
0.73208576
0.7229909
0.72236717
0.7322102
0.72220933
0.72389066
0.72880167
0.7219267
0.71846545
0.7192258
0.7244987
0.7144807
0.72690594
0.72516406
0.7336955
0.7275932
0.73303455
0.7206732


In [84]:
for exp in buffer.sample(20):
    if exp:
        batch = create_batch([exp])
        print('g', exp.g)
        print('v', value_net_forward(value_params, jnp.array(exp.s)))
#         v_next = compute_v_next(value_params, batch)
#         batch['v_next'] = v_next
#         print('v_next', v_next)
#         print('target', compute_target(batch))
        print('loss', value_net_loss(value_params, batch))
        print(obs_tensor_to_board(exp.s), '\n')
        print(obs_tensor_to_board(exp.s_next))
        print('\n\n')

g 0.9509900498999999
v [0.27829367]
loss 0.45252037
x . .
. . .
. . . 

x . .
. . o
. . .



g 0.9801
v [0.2504881]
loss 0.5323335
x . o
o x o
x . . 

x . o
o x o
x x .



g 0.96059601
v [0.27224517]
loss 0.47382692
x . o
. . .
. . . 

x . o
. . x
. . .



g -0.99
v [0.2403401]
loss 1.5137368
. x o
x o x
. . o 

. x o
x o x
. x o



g 0.0
v [0.26352015]
loss 0.06944287
x x o
o o .
x x o 

x x o
o o x
x x o



g 0.9801
v [0.27168962]
loss 0.50184524
. x .
o o x
x o . 

x x .
o o x
x o .



g 0.96059601
v [0.2700978]
loss 0.4767878
. . .
. x .
. o . 

. . x
. x .
. o .



g -0.99
v [0.28320667]
loss 1.6210554
. . .
x x .
o o . 

. . x
x x .
o o .



g -0.9801
v [0.27063575]
loss 1.56434
. o x
. . x
x . o 

. o x
. o x
x . o



g -0.9509900498999999
v [0.25939992]
loss 1.4650439
. . .
. . .
. o x 

. . .
x . .
. o x



g -0.96059601
v [0.2761472]
loss 1.5295337
o x .
x . .
. . . 

o x .
x . .
. . o



g 0.970299
v [0.27763763]
loss 0.4797798
. x .
o . .
x . . 

. x .
o . .
x . o



g -0.9