In [79]:
import random
import copy
from collections import namedtuple
from dataclasses import dataclass
import typing

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
import optax
import haiku as hk
from jax.tree_util import tree_flatten
import chex
from chex import assert_shape

from tqdm.notebook import tqdm
import pyspiel
import numpy as np
import trueskill

print(hk.__version__)

0.0.5.dev


In [80]:
%run bokeh_lib.ipynb

In [123]:
game = pyspiel.load_game('tic_tac_toe')
gamma = 0.99
entropy_reg_factor = 1e-1
key = jax.random.PRNGKey(0)
dim_board = game.observation_tensor_size()
dim_actions = game.num_distinct_actions()
dim_q = dim_board + dim_actions
all_actions = list(range(dim_actions))

In [124]:
class ReplayBuffer(object):
    def __init__(self, capacity):
        self._capacity = capacity
        self._data = []
        self._next_entry_index = 0

    def add(self, element):
        if len(self._data) < self._capacity:
            self._data.append(element)
        else:
            self._data[self._next_entry_index] = element
            self._next_entry_index += 1
            self._next_entry_index %= self._capacity
    
    def extend(self, elements):
        for ele in elements:
            self.add(ele)

    def sample(self, num_samples):
        if len(self._data) < num_samples:
            raise ValueError("{} elements could not be sampled from size {}".format(
              num_samples, len(self._data)))
        return random.sample(self._data, num_samples)

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

In [125]:
@dataclass
class Experience(object):
    s: np.array
    a: np.array
    g: float
    a_mask: np.array

#     def print(self):
#         print('r', self.g)
#         print(obs_tensor_to_board(self.s), '\n')
#         print(obs_tensor_to_board(self.s_next), '\n')

In [126]:
def policy_net_hk(x):
    mlp = hk.nets.MLP([64, 64, dim_actions])
    return mlp(x)

def q_net_hk(x):
    mlp = hk.nets.MLP([64, 64, 1])
    return mlp(x)

def v_net_hk(x):
    mlp = hk.nets.MLP([64, 64, 1])
    return mlp(x)

policy_net = hk.without_apply_rng(hk.transform(policy_net_hk))
q_net = hk.without_apply_rng(hk.transform(q_net_hk))
v_net = hk.without_apply_rng(hk.transform(v_net_hk))

In [127]:
def create_batch(samples):
    batch = {
        's': np.stack([sample.s for sample in samples]),
        'a': np.stack([sample.a for sample in samples]),
        'g': np.stack([sample.g for sample in samples]),
        'a_mask': np.stack([sample.a_mask for sample in samples]),
    }
    return batch

In [128]:
def obs_tensor_to_board(s):
    obs = s.reshape(3, 9).T
    obs = obs.argmax(axis=-1).reshape(3, 3).squeeze()
    return '\n'.join(map(str, obs)).replace('[', '').replace(']', '') \
                .replace('2', 'x').replace('1', 'o').replace('0', '.')

In [129]:
@jit
def l2_squared(pytree):
    leaves, _ = tree_flatten(pytree)
    return sum(jnp.vdot(x, x) for x in leaves)

In [130]:
def one_hot_action(action):
    vec = np.zeros(dim_actions)
    vec[action] = 1
    return vec

In [131]:
@jit
def q_forward(params, batch):
    q_repr = jnp.concatenate((batch['s'], batch['a']), axis=-1)
    q_val = q_net.apply(params, q_repr)
    return q_val

@jit
def q_loss(params, batch):
    q_val = q_forward(params, batch)
    g = batch['g']
    return jnp.mean(jnp.square(q_val - g))

@jit
def q_update(params, opt_state, batch):
    loss, grads = value_and_grad(q_loss)(params, batch)
    updates, opt_state = q_opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return loss, new_params, opt_state

In [132]:
@jit
def v_forward(params, batch):
    v_val = v_net.apply(params, batch['s'])
    return v_val

@jit
def v_loss(params, batch):
    v_val = v_forward(params, batch)
    g = batch['g']
    return jnp.mean(jnp.square(v_val - g))

@jit
def v_update(params, opt_state, batch):
    loss, grads = value_and_grad(v_loss)(params, batch)
    updates, opt_state = v_opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return loss, new_params, opt_state

In [133]:
@jit
def policy_logits(params, batch):
    logits = policy_net.apply(params, batch['s'])
    return logits

@jit
def policy_logits_masked(params, batch):
    logits = policy_net.apply(params, batch['s'])
    inf_mask = jnp.log(batch['a_mask'])
    return logits + inf_mask

@jit
def policy_probs(params, batch):
    logits = policy_logits_masked(params, batch)
    action_probs = jax.nn.softmax(logits - jnp.max(logits, axis=-1, keepdims=True))
    return action_probs

In [134]:
def compute_action_entropy(action_probs):
    return jnp.sum(jsp.special.entr(action_probs), axis=-1, keepdims=True)

In [135]:
@jit
def policy_loss_with_a_and_entropy(params, batch):
    action_probs = policy_probs(params, batch)
    log_probs = jnp.log(action_probs)
    q = batch['q']
    v = batch['v']
    selected_log_probs = action_probs * batch['a']
    action_entropy = compute_action_entropy(action_probs)
    action_entropy *= entropy_reg_factor
    ret = -jnp.mean((q - v + action_entropy) * selected_log_probs)
    
    batch_size = q.shape[0]
    chex.assert_shape(q, (batch_size, 1))
    chex.assert_shape(v, (batch_size, 1))
    chex.assert_shape(action_entropy, (batch_size, 1))
    
    aux = dict(action_entropy=action_entropy)
    return ret, aux

@jit
def policy_update_with_a_and_entropy(params, opt_state, batch):
    (loss, aux), grads = value_and_grad(policy_loss_with_a_and_entropy, has_aux=True)(params, batch)
    updates, opt_state = policy_opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return (loss, aux), new_params, opt_state

### Sanity Check

In [136]:
key, new_key = jax.random.split(key)
batch_size = 32
toy_policy_params = policy_net.init(new_key, jnp.ones((batch_size, dim_board)))
batch = {
    's': np.random.randn(batch_size, dim_board),
    'v': np.random.randn(batch_size, 1),
    'a': np.random.choice(np.arange(dim_actions), (batch_size, 1)),
    'q': np.random.randn(batch_size, 1),
    'a_mask': np.ones((batch_size, dim_actions)),
}
policy_loss_with_a_and_entropy(toy_policy_params, batch)

(DeviceArray(-0.30369875, dtype=float32),
 {'action_entropy': DeviceArray([[0.21647826],
               [0.21619618],
               [0.21892162],
               [0.21765251],
               [0.21817921],
               [0.21699464],
               [0.21378298],
               [0.2169377 ],
               [0.21615648],
               [0.21769881],
               [0.21534534],
               [0.21805918],
               [0.21674685],
               [0.21698074],
               [0.21640043],
               [0.21943496],
               [0.20835844],
               [0.21414891],
               [0.21820986],
               [0.21531406],
               [0.21726075],
               [0.21820961],
               [0.21816781],
               [0.21743324],
               [0.21836376],
               [0.21788768],
               [0.21911626],
               [0.21714464],
               [0.21794598],
               [0.21725321],
               [0.21866937],
               [0.21836133]], dtype=float

In [137]:
# logits = np.random.randn(10) * 5
logits = np.ones(dim_actions)
action_probs = jax.nn.softmax(logits)
action_entropy = compute_action_entropy(action_probs)
print(np.round(action_probs, 2))
print(action_entropy)
print(action_entropy * entropy_reg_factor)

[0.11 0.11 0.11 0.11 0.11 0.11 0.11 0.11 0.11]
[2.1972246]
[0.21972246]


In [138]:
class ExpBuilder(object):
    def __init__(self):
        self.obses = []
        self.actions = []
        self.action_masks = []
        self.reward = None
        
    def add(self, obs, action, action_mask):
        self.actions.append(action)
        self.obses.append(obs)
        self.action_masks.append(action_mask)
        
    def set_reward(self, reward):
        self.reward = reward
            
    def to_exps(self):
        exps = []
        for i in range(len(self.actions)):
            action = one_hot_action(self.actions[i])
            action_mask = self.action_masks[i]
            obs = self.obses[i]
            G = gamma ** (len(self.actions) - 1) * self.reward
            exp = Experience(obs, action, G, a_mask=action_mask)
            exps.append(exp)
        return exps
    
    def reset(self):
        self.obses = []
        self.actions = []
        self.action_masks = []
        self.reward = None

In [153]:
class Player(object):
    def __init__(self):
        self.rating = trueskill.Rating()
        
    def step(self, state):
        pass
    
class RandomPlayer(Player):
    def step(self, state):
        if not state.is_terminal():
            actions = state.legal_actions()
            action = random.choice(actions)
            return action
    
class PolicyGradientPlayer(Player):
    def __init__(self, player_id, key):
        super().__init__()
        self.player_id = player_id
        
        keys = jax.random.split(key, num=3)
        
        self.policy_params = policy_net.init(keys[0], jnp.zeros((1, dim_board)))
        self.policy_params_old = copy.deepcopy(self.policy_params)
        self.policy_opt_state = policy_opt.init(self.policy_params)
        
        self.q_params = q_net.init(keys[1], jnp.zeros((1, dim_q)))
        self.q_opt_state = q_opt.init(self.q_params)
        
        self.v_params = v_net.init(keys[2], jnp.zeros((1, dim_board)))
        self.v_opt_state = v_opt.init(self.v_params)

        self.batch_size = 200
        self.builder = ExpBuilder()
        self.buffer = ReplayBuffer(20000)
        self.info = None
        self.steps = 0
        self.sync_per_n_steps = 50000
        
    def get_q_vals(self, state):
        obs = np.array(state.observation_tensor(self.player_id))
        q_vals = np.zeros(9)
        for action in range(dim_actions):
            batch = {'s': obs, 'a': one_hot_action(action)}
            q_vals[action] = np.array(q_forward(self.q_params, batch).block_until_ready())
        return q_vals.reshape(3, 3)
    
    def get_v(self, state):
        obs = np.array(state.observation_tensor(self.player_id))
        batch = {'s': obs}
        return np.array(v_forward(self.v_params, batch).block_until_ready())
        
    def get_action_probs(self, state):
        obs = np.array(state.observation_tensor(self.player_id))
        a_mask = np.array(state.legal_actions_mask())
        batch = {'s': obs, 'a_mask': a_mask}
        action_probs = np.array(policy_probs(self.policy_params, batch).block_until_ready())
        return action_probs
    
    def get_action_logits(self, state):
        obs = np.array(state.observation_tensor(self.player_id))
        a_mask = np.array(state.legal_actions_mask())
        batch = {'s': obs, 'a_mask': a_mask}
        logits = policy_net.apply(self.policy_params, batch['s'])
        return logits
        
    def step(self, state):
        self.steps += 1
        self.learn()
        
        if self.steps % self.sync_per_n_steps == 0:
            self.sync()
        
        if state.is_terminal():
            self.builder.set_reward(state.rewards()[self.player_id])
            exps = self.builder.to_exps()
            self.buffer.extend(exps)
            self.builder.reset()
        else:
            obs = np.array(state.observation_tensor(self.player_id))
            a_mask = np.array(state.legal_actions_mask())
            batch = {'s': obs, 'a_mask': a_mask}
            action_probs = np.array(policy_probs(self.policy_params_old, batch).block_until_ready())
            action = np.random.choice(all_actions, 1, p=action_probs)
            self.builder.add(obs, action, a_mask)
            return action
        
    def learn(self):
        if len(self.buffer) >= self.batch_size:
            samples = self.buffer.sample(self.batch_size)
            batch = create_batch(samples)
            batch['q'] = q_forward(self.q_params, batch)
            batch['v'] = v_forward(self.v_params, batch)
            (policy_loss, aux), self.policy_params, self.policy_opt_state = policy_update_with_a_and_entropy(
                self.policy_params, self.policy_opt_state, batch)
            q_loss, self.q_params, self.q_opt_state = q_update(
                self.q_params, self.q_opt_state, batch)
            v_loss, self.v_params, self.v_opt_state = v_update(
                self.v_params, self.v_opt_state, batch)
            
            self.info = {
                'policy_loss': float(policy_loss.block_until_ready()),
                'q_loss': float(q_loss.block_until_ready()),
                'v_loss': float(v_loss.block_until_ready()),
                'entropy': round(float(np.mean(aux['action_entropy'].block_until_ready())), 3),
            }
            
    def sync(self):
        self.policy_params_old = copy.deepcopy(self.policy_params)

In [154]:
vis = VisDashboard([
    LineFigure('policy_loss'),
    LineFigure('q_loss'),
    LineFigure('v_loss'),
    LineFigure('win_rate'),
    LineFigure('trueskill'),
    LineFigure('entropy')
])
vis.show()



In [155]:
policy_opt = optax.adam(1e-6)
q_opt = optax.adam(3e-4)
v_opt = optax.adam(3e-4)

In [156]:
key, new_key, new_key_2 = jax.random.split(key, num=3)

In [157]:
agents = [
    RandomPlayer(),
    PolicyGradientPlayer(1, new_key)
]

In [158]:
# agents = [
#     PolicyGradientPlayer(0, new_key),
#     RandomPlayer(),
# ]

In [159]:
# agents = [
#     PolicyGradientPlayer(0, new_key),
#     PolicyGradientPlayer(1, new_key_2),
# ]

In [160]:
# vis.saver.reset()
num_games = 20000
wins = {0: 0, 1: 0, 'drawn': 0}
for i in tqdm(range(num_games)):
    state = game.new_initial_state()
    while not state.is_terminal():
        player_id = state.current_player()
        agent = agents[player_id]
        action = agent.step(state)
        state.apply_action(action)
    for agent in agents:
        agent.step(state)
        
    drawn = False
    if state.rewards()[0] > state.rewards()[1]:
        winner, loser = 0, 1
        wins[winner] += 1
    elif state.rewards()[0] < state.rewards()[1]:
        winner, loser = 1, 0
        wins[winner] += 1
    else:
        winner, loser = 0, 1
        drawn = True
        wins['drawn'] += 1
    
    agents[winner].rating, agents[loser].rating = \
        trueskill.rate_1vs1(agents[winner].rating, agents[loser].rating, drawn=drawn)
    
    for agent in agents:
        if isinstance(agent, PolicyGradientPlayer) and agent.info:
            for fig_name, val in agent.info.items():
                vis.saver.update({fig_name: {'pg': val}})
    
    vis.saver.update({
        'trueskill': {
            'random': agents[0].rating.mu,
            'pg': agents[1].rating.mu
        },
        'win_rate': {
            'random': wins[0] / sum(wins.values()),
            'pg': wins[1] / sum(wins.values()),
        }
    })
                
    if i > 1 and i % 100 == 0:
#         print(agents[0].rating.mu, agents[1].rating.mu)
#         print(wins)
#         print()
        vis.push()
        for agent in agents:
            agent.rating = trueskill.Rating()
        wins = {0: 0, 1: 0, 'drawn': 0}

  0%|          | 0/20000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [106]:
num_games = 20
for i in tqdm(range(num_games)):
    state = game.new_initial_state()
    while not state.is_terminal():
        print(state, '\n')
        player_id = state.current_player()
        agent = agents[player_id]
        print(player_id)
        if isinstance(agent, PolicyGradientPlayer):
            print('policy\n', np.round(agent.get_action_probs(state).reshape(3, 3), 2))
            q = np.round(agent.get_q_vals(state), 2)
            v = np.round(agent.get_v(state), 2)
            a = np.around(q - v, 2)
            print('q\n', q)
            print('v\n', v)
            print('a\n', a)
        action = agent.step(state)
#         if player_id == 0:
#             print(np.round(agent.get_action_probs(state).reshape(3, 3), 2))
#         print(np.round(agent.get_action_probs(state).reshape(3, 3), 2))
        state.apply_action(action)
    print(state, '\n')

  0%|          | 0/20 [00:00<?, ?it/s]

...
...
... 

0
...
...
..x 

1
policy
 [[0.06 0.05 0.09]
 [0.05 0.06 0.07]
 [0.05 0.57 0.  ]]
q
 [[-0.25 -0.25 -0.25]
 [-0.25 -0.25 -0.25]
 [-0.25 -0.25 -0.25]]
v
 [-0.25]
a
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
...
...
.ox 

0
.x.
...
.ox 

1
policy
 [[0.12 0.   0.3 ]
 [0.08 0.15 0.17]
 [0.17 0.   0.  ]]
q
 [[-0.25 -0.25 -0.25]
 [-0.25 -0.25 -0.25]
 [-0.25 -0.25 -0.25]]
v
 [-0.25]
a
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
ox.
...
.ox 

0
ox.
x..
.ox 

1
policy
 [[0.   0.   0.36]
 [0.   0.19 0.2 ]
 [0.26 0.   0.  ]]
q
 [[-0.25 -0.25 -0.25]
 [-0.25 -0.25 -0.25]
 [-0.25 -0.25 -0.25]]
v
 [-0.25]
a
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
ox.
xo.
.ox 

0
oxx
xo.
.ox 

1
policy
 [[0.   0.   0.  ]
 [0.   0.   0.45]
 [0.55 0.   0.  ]]
q
 [[-0.25 -0.25 -0.25]
 [-0.25 -0.25 -0.25]
 [-0.25 -0.25 -0.25]]
v
 [-0.25]
a
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
oxx
xo.
oox 

0
oxx
xox
oox 

...
...
... 

0
...
...
x.. 

1
policy
 [[0.06 0.05 0.09]
 [0.05 0.07 0.06]
 [0.   0.57 0.03]]
q
 [[-0.25 -0.25 -0.25]

..o
...
.x. 

0
..o
.x.
.x. 

1
policy
 [[0.17 0.1  0.  ]
 [0.14 0.   0.24]
 [0.21 0.   0.14]]
q
 [[-0.26 -0.26 -0.26]
 [-0.26 -0.26 -0.26]
 [-0.26 -0.26 -0.26]]
v
 [-0.26]
a
 [[-0. -0. -0.]
 [-0. -0. -0.]
 [-0. -0. -0.]]
..o
.x.
.xo 

0
..o
.xx
.xo 

1
policy
 [[0.22 0.17 0.  ]
 [0.25 0.   0.  ]
 [0.36 0.   0.  ]]
q
 [[-0.26 -0.26 -0.26]
 [-0.26 -0.26 -0.26]
 [-0.26 -0.26 -0.26]]
v
 [-0.26]
a
 [[-0. -0. -0.]
 [-0. -0. -0.]
 [-0. -0. -0.]]
..o
oxx
.xo 

0
..o
oxx
xxo 

1
policy
 [[0.58 0.42 0.  ]
 [0.   0.   0.  ]
 [0.   0.   0.  ]]
q
 [[-0.26 -0.25 -0.26]
 [-0.26 -0.26 -0.25]
 [-0.26 -0.25 -0.26]]
v
 [-0.26]
a
 [[-0.    0.01 -0.  ]
 [-0.   -0.    0.01]
 [-0.    0.01 -0.  ]]
o.o
oxx
xxo 

0
oxo
oxx
xxo 

...
...
... 

0
.x.
...
... 

1
policy
 [[0.07 0.   0.14]
 [0.03 0.06 0.06]
 [0.07 0.53 0.04]]
q
 [[-0.26 -0.26 -0.26]
 [-0.26 -0.26 -0.26]
 [-0.26 -0.26 -0.26]]
v
 [-0.26]
a
 [[-0. -0. -0.]
 [-0. -0. -0.]
 [-0. -0. -0.]]
ox.
...
... 

0
ox.
..x
... 

1
policy
 [[0.   0.   0.18]
 [0.05