In [None]:
import random
import copy
from collections import namedtuple

import pyspiel
import numpy as np

import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax

import optax
import haiku as hk
from tqdm.notebook import tqdm

print(hk.__version__)

In [None]:
game = pyspiel.load_game('tic_tac_toe')
gamma = 0.99
key = jax.random.PRNGKey(0)
dim_board = game.observation_tensor_size()

In [None]:
def state_to_repr(state, player):
    board_obs = np.asarray(state.observation_tensor(0))
    player_obs = np.asarray(player).reshape((1,))
    return np.concatenate((board_obs, player_obs))

In [None]:
all_actions = np.arange(game.num_distinct_actions())
def choose_action(state, probs):
    probs = np.asarray(probs, dtype=np.float32)
    legal_mask = np.asarray(state.legal_actions_mask())
    legal_probs = probs * legal_mask
    legal_probs /= np.sum(legal_probs)
    return np.random.choice(all_actions, p=legal_probs)

In [None]:
class ReplayBuffer(object):
    def __init__(self, capacity):
        self._capacity = capacity
        self._data = []
        self._next_entry_index = 0

    def add(self, element):
        if len(self._data) < self._capacity:
            self._data.append(element)
        else:
            self._data[self._next_entry_index] = element
            self._next_entry_index += 1
            self._next_entry_index %= self._capacity

    def sample(self, num_samples):
        if len(self._data) < num_samples:
            raise ValueError("{} elements could not be sampled from size {}".format(
              num_samples, len(self._data)))
        return random.sample(self._data, num_samples)

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

In [45]:
Experience = namedtuple(
    'Experience', 
    ['s', 'a', 'r', 's_next', 'v_next_mask']
)

def f(self):
    print('r', self.r)
    print(obs_tensor_to_board(self.s), '\n')
    print(obs_tensor_to_board(self.s_next), '\n')
Experience.print = f

In [None]:
def mse(y_pred, y_true):
    return jnp.mean(jnp.square(y_pred - y_true))

In [None]:
key, value_net_params_init_key = jax.random.split(key)

In [None]:
# def value_net_hk(x):
#     mlp = hk.Sequential([
# #         hk.Linear(128), jax.nn.relu,
#         hk.Linear(256), jax.nn.relu,
#         hk.Linear(128), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(1), jnp.tanh
#     ])
#     return mlp(x)

def value_net_hk(x):
    mlp = hk.nets.MLP([64, 64, 1])
    return jnp.tanh(mlp(x))

value_net = hk.without_apply_rng(hk.transform(value_net_hk))

In [None]:
print(hk.experimental.tabulate(value_net)(jnp.zeros((12, 28))))

In [None]:
def state_to_exp(state):
    player = state.current_player()
    state_repr = state_to_repr(state, player)

    action = random.choice(state.legal_actions())
    state.apply_action(action)

    reward = (state.rewards()[0] + 1) / 2

    if state.is_terminal():
        next_state_repr = state_to_repr(state, player)
        v_next_mask = 0
        reward = state.rewards()[0]
    else:
        next_state_repr = state_to_repr(state, state.current_player())
        v_next_mask = 1
        reward = 0
    exp = Experience(state_repr, action, reward, next_state_repr, v_next_mask)
    return exp

In [62]:
def create_batch(samples):
    batch = {
        's': np.stack([sample.s for sample in samples]),
        'a': np.stack([sample.a for sample in samples]),
        'r': np.stack([sample.r for sample in samples]),
        's_next': np.stack([sample.s_next for sample in samples]),
        'v_next_mask': np.stack([sample.v_next_mask for sample in samples]),
    }
    return batch

In [63]:
def obs_tensor_to_board(s):
    obs = s[:27].reshape(3, 9).T
    obs = obs.argmax(axis=-1).reshape(3, 3).squeeze()
    return '\n'.join(map(str, obs)).replace('[', '').replace(']', '') \
                .replace('2', 'x').replace('1', 'o').replace('0', '.')

## Simi-gradient TD

In [65]:
@jit
def value_net_forward(value_params, x):
    v = value_net.apply(value_params, x)
    return v

@jit
def compute_v_next(value_params, batch):
    v_next = value_net_forward(value_params, batch['s_next']) * batch['v_next_mask']
    return v_next

@jit
def compute_target(batch):
    target = batch['r'] + gamma * batch['v_next']
    return target

@jit
def td_error(value_params, batch):
    v = value_net_forward(value_params, batch['s'])
    target = compute_target(batch)
    error = target - v
    return error

def value_net_loss(value_params, batch):
    error = td_error(value_params, batch)
    return jnp.mean(jnp.square(error))
    
@jit
def value_update(value_params, value_opt_state, batch):
    loss, grads = value_and_grad(value_net_loss)(value_params, batch)
    updates, value_opt_state = value_opt.update(grads, value_opt_state)
    new_params = optax.apply_updates(value_params, updates)
    return loss, new_params, value_opt_state

In [None]:
num_games = 10000
buffer = ReplayBuffer(capacity=5000)
batch_size = 1000
epochs = 100
value_net_shape = (batch_size, 28)
value_params_old = None
value_params = value_net.init(value_net_params_init_key, jnp.zeros(value_net_shape))
value_opt = optax.adam(1e-5)
value_opt_state = value_opt.init(value_params)

In [None]:
for i in tqdm(range(num_games)):
    state = game.new_initial_state()
    while not state.is_terminal():
        exp = state_to_exp(state)
        buffer.add(exp)

    if len(buffer) >= batch_size:
        samples = buffer.sample(batch_size)
        batch = create_batch(samples)
        batch['v_next'] = compute_v_next(value_params_old, batch)
        loss, value_params, value_opt_state = value_update(
            value_params, value_opt_state, batch)
        
#     value_params_old = copy.deepcopy(value_params)
    value_params_old = value_params
    if i % 100 == 0:
        print(loss)

## REINFORCE

In [61]:
def traj_to_exps():
    state = game.new_initial_state()
    obses = []
    actions = []
    rewards = []
    while not state.is_terminal():
        obses.append(np.array(state.observation_tensor()))
        action = random.choice(state.legal_actions())
        actions.append(action)
        state.apply_action(action)
        if state.is_terminal():
            reward = state.rewards()[0]
        else:
            reward = 0
        rewards.append(reward)
    obses.append(np.array(state.observation_tensor(0)))
    exps = []
    for i in range(len(actions)):
        action = actions[i]
        obs = obses[i]
        obs_next = obses[i + 1]
        rest_rewards = rewards[i:]
        G = 0
        for i, r in enumerate(rest_rewards):
            G += (gamma ** i) * r
        exp = Experience(obs, action, G, obs_next, 0)
        exps.append(exp)
    for exp in exps:
        exp.print()
    return exps

In [None]:
@jit
def value_net_forward(value_params, x):
    v = value_net.apply(value_params, x)
    return v

@jit
def compute_v_next(value_params, batch):
    v_next = value_net_forward(value_params, batch['s_next']) * batch['v_next_mask']
    return v_next

@jit
def compute_target(batch):
    target = batch['r'] + gamma * batch['v_next']
    return target

@jit
def td_error(value_params, batch):
    v = value_net_forward(value_params, batch['s'])
    target = compute_target(batch)
    error = target - v
    return error

def value_net_loss(value_params, batch):
    error = td_error(value_params, batch)
    return jnp.mean(jnp.square(error))
    
@jit
def value_update(value_params, value_opt_state, batch):
    loss, grads = value_and_grad(value_net_loss)(value_params, batch)
    updates, value_opt_state = value_opt.update(grads, value_opt_state)
    new_params = optax.apply_updates(value_params, updates)
    return loss, new_params, value_opt_state

In [64]:
num_games = 1000
buffer = ReplayBuffer(capacity=5000)
batch_size = 1000
epochs = 100
value_net_shape = (batch_size, 28)
value_params = value_net.init(value_net_params_init_key, jnp.zeros(value_net_shape))
value_opt = optax.adam(1e-5)
value_opt_state = value_opt.init(value_params)

In [None]:
for i in tqdm(range(num_games)):
    exps = traj_to_exps()

    if len(buffer) >= batch_size:
        samples = buffer.sample(batch_size)
        batch = create_batch(samples)
        batch['v_next'] = compute_v_next(value_params_old, batch)
        loss, value_params, value_opt_state = value_update(
            value_params, value_opt_state, batch)
        
# #     value_params_old = copy.deepcopy(value_params)
#     value_params_old = value_params
#     if i % 100 == 0:
#         print(loss)

In [39]:
for exp in buffer.sample(20):
    if exp:
        batch = create_batch([exp])
        print('r', exp.r)
        print('v', value_net_forward(value_params, jnp.array(exp.s)))
        v_next = compute_v_next(value_params, batch)
        batch['v_next'] = v_next
        print('v_next', v_next)
        print('target', compute_target(batch))
        print('loss', value_net_loss(value_params, batch))
        print(obs_tensor_to_board(exp.s), '\n')
        print(obs_tensor_to_board(exp.s_next))
        print('\n\n')

r 0
v [0.45561114]
v_next [[0.44146043]]
target [[0.43704584]]
loss 0.00034467026
x . o
o o x
x . . 

x . o
o o x
x x .



r -1.0
v [0.46478873]
v_next [[0.]]
target [[-1.]]
loss 2.1456058
x o x
. x x
o o . 

x o x
. x x
o o o



r 0
v [0.42403108]
v_next [[0.46620712]]
target [[0.46154505]]
loss 0.001407298
. . o
o . x
. x x 

. o o
o . x
. x x



r 0
v [0.39510557]
v_next [[0.40234292]]
target [[0.39831948]]
loss 1.0329232e-05
x . .
. x o
. . o 

x x .
. x o
. . o



r 0
v [0.43726972]
v_next [[0.42403108]]
target [[0.41979077]]
loss 0.00030551344
. . o
o . .
. x x 

. . o
o . x
. x x



r 0
v [0.37565485]
v_next [[0.41991568]]
target [[0.41571653]]
loss 0.0016049384
. x .
o . .
. . . 

. x .
o x .
. . .



r 0
v [0.34631684]
v_next [[0.37933838]]
target [[0.375545]]
loss 0.0008542848
. . .
x . o
. x . 

. . .
x . o
. x o



r 0
v [0.3733909]
v_next [[0.4123707]]
target [[0.408247]]
loss 0.0012149464
. . x
. . .
. . o 

. . x
. . x
. . o



r 0
v [0.37727344]
v_next [[0.3758107]]
tar

In [None]:
state = game.new_initial_state()
for i in range(10):
    print(state)
    if not state.is_terminal():
        print(state.observation_tensor())
    print()
    exp = state_to_exp(state)
    print(state)
    if not state.is_terminal():
        print(state.observation_tensor())
    print()
    print(exp)
    print()
    print()

In [None]:
for sample in buffer.sample(10):
    print(sample)

In [None]:
state = game.new_initial_state()
while True:
    val = value_net.apply(value_params, state_to_repr(state, state.current_player()))
    print(state, val, '\n')
    if state.is_terminal():
        break
    state.apply_action(random.choice(state.legal_actions()))