In [None]:
import random
import copy
from collections import namedtuple
from dataclasses import dataclass
import typing

import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
import optax
import haiku as hk

from tqdm.notebook import tqdm
import pyspiel
import numpy as np
import trueskill

# Record the installed Haiku version; the cells below rely on hk.transform,
# hk.without_apply_rng and hk.nets.MLP.
print(hk.__version__)

In [None]:
# Game and learning constants shared by every cell below.
game = pyspiel.load_game('tic_tac_toe')
dim_board = game.observation_tensor_size()   # length of one flattened observation
dim_actions = game.num_distinct_actions()    # size of the action space
all_actions = [*range(dim_actions)]          # candidate action ids for sampling
gamma = 0.99                                 # discount factor for returns
key = jax.random.PRNGKey(0)                  # root JAX PRNG key

In [None]:
class ReplayBuffer(object):
    """Fixed-capacity experience store that overwrites its oldest entries once full.

    Elements are appended until ``capacity`` is reached. After that, each new
    element replaces the slot under a cursor that advances cyclically.
    """

    def __init__(self, capacity):
        self._capacity = capacity
        self._data = []
        self._next_entry_index = 0

    def add(self, element):
        """Insert one element, evicting the oldest one when the buffer is full."""
        if len(self._data) >= self._capacity:
            self._data[self._next_entry_index] = element
            self._next_entry_index = (self._next_entry_index + 1) % self._capacity
            return
        self._data.append(element)

    def extend(self, elements):
        """Insert every element of ``elements`` in order via :meth:`add`."""
        for element in elements:
            self.add(element)

    def sample(self, num_samples):
        """Return ``num_samples`` stored elements drawn uniformly without replacement.

        Raises:
            ValueError: if fewer than ``num_samples`` elements are stored.
        """
        stored = len(self._data)
        if num_samples > stored:
            raise ValueError("{} elements could not be sampled from size {}".format(
                num_samples, stored))
        return random.sample(self._data, num_samples)

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

In [None]:
@dataclass
class Experience(object):
    """One training sample recorded for a single decision."""
    s: np.ndarray       # observation vector the agent acted on
    a: int              # chosen action id
    g: float            # return credited to the action
    a_mask: np.ndarray  # legal-action mask at decision time

    def print(self):
        """Print the return and the board for this experience.

        The dataclass has no ``s_next`` field, so a successor board is only
        printed when a caller has attached one to the instance ad hoc.
        Previously this always raised AttributeError.
        """
        print('r', self.g)
        print(obs_tensor_to_board(self.s), '\n')
        s_next = getattr(self, 's_next', None)
        if s_next is not None:
            print(obs_tensor_to_board(s_next), '\n')

In [None]:
def policy_net_hk(x):
    """Map observation(s) to unnormalised action logits, one per distinct action."""
    layer_sizes = [128, 64, 32, dim_actions]   # three hidden layers + logits layer
    logits = hk.nets.MLP(layer_sizes)(x)
    return logits

# Transform into a pure (params, x) -> logits function whose apply() takes no rng argument.
policy_net = hk.without_apply_rng(hk.transform(policy_net_hk))

In [None]:
def create_batch(samples):
    """Stack per-sample fields into a dict of batched numpy arrays.

    Each element of ``samples`` must expose ``s``, ``a``, ``g`` and ``a_mask``
    attributes. The result maps those names, in that order, to arrays whose
    leading axis indexes the samples.
    """
    fields = ('s', 'a', 'g', 'a_mask')
    return {name: np.stack([getattr(sample, name) for sample in samples]) for name in fields}

In [None]:
def obs_tensor_to_board(s):
    """Render a tic-tac-toe observation vector as a 3x3 text board.

    ``s`` is viewed as 3 planes of 9 cells. Each cell shows the index of its
    largest plane value: 0 -> '.', 1 -> 'o', 2 -> 'x'. Cells within a row are
    separated by single spaces and rows by newlines.
    """
    glyphs = '.ox'
    planes_per_cell = s.reshape(3, 9).T          # (9 cells, 3 planes)
    board = planes_per_cell.argmax(axis=-1).reshape(3, 3)
    return '\n'.join(' '.join(glyphs[int(v)] for v in row) for row in board)

In [None]:
from jax.tree_util import tree_flatten

@jit
def l2_squared(pytree):
    """Return the squared L2 norm of ``pytree``: the sum of vdot(leaf, leaf) over its leaves."""
    total = 0
    for leaf in tree_flatten(pytree)[0]:
        total = total + jnp.vdot(leaf, leaf)
    return total

In [None]:
@jit
def policy_forward(params, batch):
    """Return the policy's action distribution restricted to legal actions.

    Args:
        params: Haiku parameters for ``policy_net``.
        batch: dict with ``'s'`` (one observation or a stacked batch) and
            ``'a_mask'`` (legal-action mask matching the logits' shape).

    Returns:
        Softmax probabilities with illegal actions zeroed and each row
        renormalised over its own legal actions. This works for a single 1-D
        observation as well as for a 2-D batch.
    """
    logits = policy_net.apply(params, batch['s'])
    action_probs = jax.nn.softmax(logits)
    legal_action_probs = action_probs * batch['a_mask']
    # Normalise per row. Summing over the whole batch (the previous behaviour)
    # shrank every probability by roughly the batch size.
    return legal_action_probs / jnp.sum(legal_action_probs, axis=-1, keepdims=True)

@jit
def illegal_actions_loss(params, batch):
    """Mean probability the raw policy assigns to illegal actions.

    The mean is taken over every (sample, action) entry, with legal entries
    contributing zero.
    """
    probs = jax.nn.softmax(policy_net.apply(params, batch['s']))
    illegal = jnp.logical_not(batch['a_mask'])
    return (probs * illegal).mean()

# @jit
# def policy_forward(params, batch):
#     logits = policy_net.apply(params, batch['s'])
# #     action_probs = jax.nn.softmax(logits - jnp.max(logits, axis=-1, keepdims=True))
# #     action_probs += 0.0000001  # epsilon
#     action_probs = jax.nn.softmax(logits)
# #     action_probs = jnp.clip(action_probs, a_min=0.05)  # epsilon
#     legal_action_probs = action_probs * batch['a_mask']
#     legal_action_probs /= jnp.sum(legal_action_probs)
#     return legal_action_probs

@jit
def policy_loss(params, batch):
    """REINFORCE loss: ``-mean(G * log pi(a|s))`` plus the illegal-action penalty.

    Args:
        params: Haiku parameters for ``policy_net``.
        batch: dict from ``create_batch`` with ``'s'``, ``'a'`` (chosen actions,
            shape ``(B,)`` or ``(B, 1)``), ``'g'`` (returns, shape ``(B,)``) and
            ``'a_mask'``.

    Returns:
        Scalar loss to minimise.
    """
    action_probs = policy_forward(params, batch)
    # Flatten the actions. If they are stored as shape-(1,) arrays, the gather
    # below would otherwise broadcast to a (B, B) matrix.
    action = jnp.reshape(batch['a'], (-1,))
    g = batch['g']
    selected_probs = action_probs[jnp.arange(action.shape[0]), action]
    # Take the log AFTER gathering. log() over the full matrix hits 0 on illegal
    # entries, and the VJP of log there (0 / 0) would put NaNs in the gradients.
    selected_log_probs = jnp.log(selected_probs)
    ia_loss = illegal_actions_loss(params, batch)
    return -jnp.mean(g * selected_log_probs) + ia_loss

@jit
def policy_update(params, opt_state, batch):
    """Take one optimiser step on ``policy_loss`` using the module-level ``opt``.

    Returns:
        ``(loss, new_params, new_opt_state)``.
    """
    loss_and_grad = value_and_grad(policy_loss)
    loss, grads = loss_and_grad(params, batch)
    updates, new_opt_state = opt.update(grads, opt_state)
    return loss, optax.apply_updates(params, updates), new_opt_state

In [None]:
class ExpBuilder(object):
    """Accumulates one player's decisions during an episode.

    Once the terminal reward is known, it turns them into ``Experience`` records.
    """

    def __init__(self):
        self.reset()

    def add(self, obs, action, action_mask):
        """Record the observation, chosen action and legal-action mask of one decision."""
        self.obses.append(obs)
        self.actions.append(action)
        self.action_masks.append(action_mask)

    def set_reward(self, reward):
        """Store the episode's terminal reward for this player."""
        self.reward = reward

    def to_exps(self):
        """Return one ``Experience`` per recorded decision, in recording order.

        The only reward is the terminal one, so the discounted return of the
        ``i``-th of ``T`` decisions is ``gamma ** (T - 1 - i) * reward``. Later
        decisions are discounted less. Previously every decision got
        ``gamma ** (T - 1) * reward``.

        Raises:
            ValueError: if decisions were recorded but ``set_reward`` was not called.
        """
        num_steps = len(self.actions)
        if num_steps and self.reward is None:
            raise ValueError("set_reward() must be called before to_exps()")
        exps = []
        for i, (obs, action, action_mask) in enumerate(
                zip(self.obses, self.actions, self.action_masks)):
            steps_to_end = num_steps - 1 - i
            G = gamma ** steps_to_end * self.reward
            exps.append(Experience(obs, action, G, a_mask=action_mask))
        return exps

    def reset(self):
        """Clear all recorded decisions and the reward, ready for the next episode."""
        self.obses = []
        self.actions = []
        self.action_masks = []
        self.reward = None

In [None]:
class Player(object):
    """Base class for game-playing agents; each one carries a TrueSkill rating."""

    def __init__(self):
        self.rating = trueskill.Rating()

    def step(self, state):
        """Return the action to play in ``state``; the base implementation returns None."""
        return None
    
class RandomPlayer(Player):
    """Agent that plays a uniformly random legal action."""

    def step(self, state):
        """Return a random legal action, or None once ``state`` is terminal."""
        if state.is_terminal():
            return None
        return random.choice(state.legal_actions())
    
class PolicyGradientPlayer(Player):
    """REINFORCE agent that acts by sampling ``policy_net`` and learns from a replay buffer.

    Args:
        player_id: seat this agent plays; used to select its observation and reward.
        key: JAX PRNG key used to initialise the network parameters.
    """

    def __init__(self, player_id, key):
        super().__init__()
        self.player_id = player_id
        self.params = policy_net.init(key, jnp.zeros((1, dim_board)))
        self.opt_state = opt.init(self.params)

        self.batch_size = 50
        self.builder = ExpBuilder()
        self.buffer = ReplayBuffer(1000)
        self.last_loss = None

    def _make_batch(self, state):
        """Build the single-observation batch dict (``'s'``, ``'a_mask'``) for ``state``."""
        obs = np.array(state.observation_tensor(self.player_id))
        a_mask = np.array(state.legal_actions_mask())
        return {'s': obs, 'a_mask': a_mask}

    def get_action_probs(self, state):
        """Return the policy's legal-action probabilities for ``state`` as a numpy array."""
        batch = self._make_batch(state)
        return np.array(policy_forward(self.params, batch).block_until_ready())

    def get_action_logits(self, state):
        """Return the raw, unmasked policy logits for ``state``."""
        return policy_net.apply(self.params, self._make_batch(state)['s'])

    def step(self, state):
        """Train on the buffer, then act; at a terminal state, flush the episode instead.

        Returns:
            The sampled action as a Python ``int`` for a non-terminal ``state``;
            ``None`` for a terminal one.
        """
        self.learn()

        if state.is_terminal():
            self.builder.set_reward(state.rewards()[self.player_id])
            self.buffer.extend(self.builder.to_exps())
            self.builder.reset()
            return None

        batch = self._make_batch(state)
        action_probs = np.array(policy_forward(self.params, batch).block_until_ready())
        # Renormalise in float64 so the float32 network output sums to 1, as
        # np.random.choice requires.
        p = action_probs.astype(np.float64)
        p /= p.sum()
        # Sample a scalar. A size-1 array would be stored in the buffer and later
        # make the action gather in policy_loss broadcast to (B, B).
        action = int(np.random.choice(all_actions, p=p))
        self.builder.add(batch['s'], action, batch['a_mask'])
        return action

    def learn(self):
        """Run one optimiser step on a random minibatch once the buffer holds ``batch_size`` samples."""
        if len(self.buffer) < self.batch_size:
            return
        batch = create_batch(self.buffer.sample(self.batch_size))
        self.last_loss, self.params, self.opt_state = policy_update(
            self.params, self.opt_state, batch)

In [None]:
# Optimiser shared by every PolicyGradientPlayer. It is read as the global
# `opt` in PolicyGradientPlayer.__init__ and in policy_update.
# key, new_key = jax.random.split(key)
# params = policy_net.init(new_key, jnp.zeros((1, dim_board)))
opt = optax.adam(learning_rate=1e-5)
# opt_state = opt.init(params)

In [None]:
# Advance the root key and derive two fresh subkeys for agent initialisation.
key, new_key, new_key_2 = jax.random.split(key, 3)

In [None]:
# Seat 0: uniform random baseline; seat 1: the learning agent.
agents = [RandomPlayer(), PolicyGradientPlayer(player_id=1, key=new_key)]

In [None]:
# agents = [
#     PolicyGradientPlayer(0, new_key),
#     RandomPlayer(),
# ]

In [None]:
# agents = [
#     PolicyGradientPlayer(0, new_key),
#     PolicyGradientPlayer(1, new_key_2),
# ]

In [None]:
num_games = 50000
wins = {0: 0, 1: 0, 'drawn': 0}
for i in tqdm(range(num_games)):
    # Play one full game; the seat to move chooses each action.
    state = game.new_initial_state()
    while not state.is_terminal():
        player_id = state.current_player()
        agent = agents[player_id]
        action = agent.step(state)
        state.apply_action(action)
    # Show the terminal state to every agent so learners can record the outcome.
    for agent in agents:
        agent.step(state)

    # Tally the result. On a draw, seat 0 is passed first to TrueSkill.
    r0, r1 = state.rewards()[0], state.rewards()[1]
    if r0 > r1:
        winner, loser, drawn = 0, 1, False
    elif r1 > r0:
        winner, loser, drawn = 1, 0, False
    else:
        winner, loser, drawn = 0, 1, True
    wins['drawn' if drawn else winner] += 1

    agents[winner].rating, agents[loser].rating = trueskill.rate_1vs1(
        agents[winner].rating, agents[loser].rating, drawn=drawn)

    # Every 500 games: report ratings and win counts, then reset both.
    if i > 1 and i % 500 == 0:
        print(agents[0].rating.mu, agents[1].rating.mu)
        print(wins)
        print()
        for agent in agents:
            agent.rating = trueskill.Rating()
        wins = {0: 0, 1: 0, 'drawn': 0}

In [None]:
# logits = policy_net.apply(self.params, batch['s'])
# action_probs = jax.nn.softmax(logits - jnp.max(logits, axis=-1, keepdims=True))
# legal_action_probs = action_probs * batch['a_mask']
# legal_action_probs = legal_action_probs / jnp.sum(legal_action_probs)
# return legal_action_probs

In [None]:
# Watch a few games, printing the learning agent's per-move probabilities.
# The learner's move is sampled from get_action_probs() rather than step(),
# for two reasons:
#   - step() would run a training update.
#   - step() would record this game in its ExpBuilder. Without a terminal
#     step() the builder is never flushed here, so those moves would leak into
#     the next training game's experiences.
num_games = 20
for i in tqdm(range(num_games)):
    state = game.new_initial_state()
    while not state.is_terminal():
        print(state, '\n')
        player_id = state.current_player()
        agent = agents[player_id]
        print(player_id)
        if isinstance(agent, PolicyGradientPlayer):
            action_probs = agent.get_action_probs(state)
            print(np.round(action_probs.reshape(3, 3), 2))
            p = action_probs.astype(np.float64)
            action = int(np.random.choice(all_actions, p=p / p.sum()))
        else:
            action = agent.step(state)
        state.apply_action(action)
    print(state, '\n')

## Archive

In [None]:
def traj_to_exps():
    """Play one game with uniformly random moves and return an ``Experience`` per move.

    Every move, whichever player made it, is credited with player 0's
    discounted return ``G_t = sum_k gamma**k * r_{t+k}``. The only non-zero
    reward is player 0's terminal reward.
    """
    state = game.new_initial_state()
    obses, actions, action_masks, rewards = [], [], [], []
    while not state.is_terminal():
        obses.append(np.array(state.observation_tensor()))
        action = random.choice(state.legal_actions())
        actions.append(action)
        action_masks.append(np.array(state.legal_actions_mask()))
        state.apply_action(action)
        rewards.append(state.rewards()[0] if state.is_terminal() else 0)

    # Accumulate returns backwards in one pass (G_t = r_t + gamma * G_{t+1})
    # instead of re-summing the reward tail for every step. This also avoids
    # the old inner loop reusing (shadowing) the outer index `i`.
    returns = [0.0] * len(rewards)
    G = 0
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G
        returns[t] = G

    return [Experience(obs, action, G_t, a_mask=mask)
            for obs, action, G_t, mask in zip(obses, actions, returns, action_masks)]

In [None]:
# Archived: play num_games games with the policy as player 0 against a random player 1.
# NOTE(review): `params` is not defined by this cell. It presumably comes from the
# commented-out policy_net.init in the optimiser cell; TODO confirm before running.
for i in tqdm(range(num_games)):
    state = game.new_initial_state()
    while not state.is_terminal():
        if state.current_player() == 0:
            a_mask = np.array(state.legal_actions_mask())
            batch = {
                # `state_to_repr` is not defined anywhere in the notebook. Use
                # the same observation encoding PolicyGradientPlayer uses.
                's': np.array(state.observation_tensor(0)),
                'a_mask': a_mask,
            }
            action_probs = np.array(policy_forward(params, batch).block_until_ready())
            # Sample a scalar action id rather than a size-1 array.
            action = int(np.random.choice(all_actions, p=action_probs))
        else:
            action = random.choice(state.legal_actions())
        state.apply_action(action)

In [None]:
# Archived: fill the replay buffer with experiences from random-play games.
for i in tqdm(range(num_games)):
    buffer.extend(traj_to_exps())

In [None]:
# Archived: run 1000 policy-gradient updates on minibatches from the buffer.
for _ in range(1000):
    minibatch_samples = buffer.sample(batch_size)
    batch = create_batch(minibatch_samples)
    loss, params, opt_state = policy_update(params, opt_state, batch)
    print(loss)

In [None]:
# Archived: inspect the policy on 20 random buffered positions.
for exp in buffer.sample(20):
    if not exp:
        continue
    batch = create_batch([exp])
    print('loss', policy_loss(params, batch))
    print(obs_tensor_to_board(exp.s), '\n')
    action_probs = policy_forward(params, batch)
    print(np.round(action_probs.reshape(3,3), 2))
    print('\n\n')

In [None]:
# @jit
# def value_net_forward(value_params, x):
#     return value_net.apply(value_params, x)

# @jit
# def value_net_error(value_params, batch):
#     v = value_net_forward(value_params, batch['s'])
#     target = batch['g']
#     error = target - v
#     return error

# @jit
# def value_net_loss(value_params, batch):
#     error = value_net_error(value_params, batch)
#     return jnp.mean(jnp.square(error))
    
# @jit
# def value_update(value_params, value_opt_state, batch):
#     loss, grads = value_and_grad(value_net_loss)(value_params, batch)
#     updates, value_opt_state = value_opt.update(grads, value_opt_state)
#     new_params = optax.apply_updates(value_params, updates)
#     return loss, new_params, value_opt_state

In [None]:
# Archived value-network setup.
# NOTE(review): `value_net` and `value_net_params_init_key` are not defined
# anywhere in this notebook, so this cell raises NameError as written.
# TODO: restore those definitions.
num_games = 10000
buffer = ReplayBuffer(capacity=50000)
batch_size = 10000
# Use the game's observation size instead of the hard-coded 27 (which matched
# the 3x9 layout assumed by obs_tensor_to_board).
value_net_shape = (batch_size, dim_board)
value_params = value_net.init(value_net_params_init_key, jnp.zeros(value_net_shape))
value_opt = optax.adam(1e-5)
value_opt_state = value_opt.init(value_params)

In [None]:
# Archived: train the value network on random-play returns.
for i in tqdm(range(num_games)):
    buffer.extend(traj_to_exps())
    if len(buffer) < batch_size:
        continue
    batch = create_batch(buffer.sample(batch_size))
    loss, value_params, value_opt_state = value_update(
        value_params, value_opt_state, batch)
    if i % 100 == 0:
        print(loss)

In [None]:
# Archived: inspect value predictions on 20 random buffered experiences.
for exp in buffer.sample(20):
    if exp:
        batch = create_batch([exp])
        print('g', exp.g)
        print('v', value_net_forward(value_params, jnp.array(exp.s)))
        print('loss', value_net_loss(value_params, batch))
        print(obs_tensor_to_board(exp.s), '\n')
        # Experience has no `s_next` field (it is also commented out of
        # create_batch). Only print a successor board when one is attached,
        # instead of raising AttributeError.
        s_next = getattr(exp, 's_next', None)
        if s_next is not None:
            print(obs_tensor_to_board(s_next))
        print('\n\n')