In [1]:
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter
import concurrent.futures
from torch import optim
import torch
import os
import copy


%load_ext autoreload
%autoreload 2
import env
import network
import player

BOARD_XSIZE=7
BOARD_YSIZE=6
DIMS=(BOARD_YSIZE,BOARD_XSIZE)

EPISODES_PER_AGENT = 50
TRAIN_EPOCHS = 500000
MODEL_SAVE_INTERVAL = 100
SUMMARY_STATS_INTERVAL = 10
RANDOM_SEED = 42

SUMMARY_DIR = './summary'
MODEL_DIR = './models'

# create result directory
if not os.path.exists(SUMMARY_DIR):
    os.makedirs(SUMMARY_DIR)


use_cuda = torch.cuda.is_available()
torch.manual_seed(RANDOM_SEED)

cuda = torch.device("cuda")
cpu = torch.device("cpu")

if use_cuda:
    device = cuda
else:
    device = cpu

device=cpu

In [2]:
actor = network.Actor(BOARD_XSIZE, BOARD_YSIZE).to(device)
critic = network.Critic(BOARD_XSIZE, BOARD_YSIZE).to(device)

actor_optimizer = optim.Adam(actor.parameters(), lr=network.ACTOR_LR)
critic_optimizer = optim.Adam(critic.parameters(), lr=network.CRITIC_LR)

# Get Writer
writer = SummaryWriter(log_dir=SUMMARY_DIR)

step=0

In [3]:
opponent_pool:list[player.Player] = [
    player.MinimaxPlayer(env.PLAYER2, 2, 0.3),
    player.MinimaxPlayer(env.PLAYER2, 2, 0.5),
]

rewards_vs: dict[str, list[float]] = {}

In [4]:
def play(actor:player.ActorPlayer, opponent: player.Player, actor_turn:bool) -> tuple[
    list[env.Observation],
    list[env.Action],
    list[env.Reward],
    list[env.Advantage],
    list[env.Reward],
]:
    e = env.Env(DIMS)

    s_t:list[env.Observation] = []
    a_t:list[env.Action] = []
    r_t:list[env.Reward] = []
    # play the game
    while not e.game_over():
        if actor_turn:
            obs, chosen_action, reward = actor.play(e)
            s_t += [obs]
            a_t += [chosen_action]
            r_t += [reward]
        else:
            opponent.play(e)

        # flip turn
        actor_turn = not actor_turn

    # compute advantage and value
    d_t = network.compute_advantage(actor.critic, s_t, r_t)
    v_t = network.compute_value(r_t)

    return s_t, a_t, r_t, d_t, v_t

In [5]:
# interrupt this cell when you're done training

for _ in range(TRAIN_EPOCHS):
    s_batch:list[env.Observation] = []
    a_batch:list[env.Action] = []
    p_batch:list[np.ndarray] = []
    d_batch:list[env.Advantage] = []
    v_batch:list[env.Value] = []
    
    # create actor player
    actor_player = player.ActorPlayer(actor, critic, step, env.PLAYER1)
    
    for _ in range(EPISODES_PER_AGENT):
        # pick a random opponent
        opponent_player = opponent_pool[np.random.randint(len(opponent_pool))]

        # whether we or our opponent goes first
        go_first = np.random.randint(2) == 0

        # play the game
        s_t, a_t, r_t, d_t, v_t = play(actor_player,opponent_player, go_first)

        # now update the minibatch
        s_batch += s_t
        a_batch += a_t
        d_batch += d_t
        v_batch += v_t

        # statistics
        opp_name = opponent_player.name()
        if opp_name in rewards_vs:
            rewards_vs[opp_name].append(float(v_t[-1]))
        else:
            rewards_vs[opp_name] = [float(v_t[-1])]

    actor_losses, critic_losses = network.train_policygradient(
        actor,
        critic,
        actor_optimizer,
        critic_optimizer,
        s_batch,
        a_batch,
        d_batch,
        v_batch
    )

    for actor_loss, critic_loss in zip(actor_losses, critic_losses):
        writer.add_scalar('actor_loss', actor_loss, step)
        writer.add_scalar('critic_loss', critic_loss, step)

        if step % SUMMARY_STATS_INTERVAL == 0:
            for opponent_name, rewards in rewards_vs.items():
                if len(rewards) > 50:
                    avg_reward = np.array(rewards).mean()
                    writer.add_scalar(f'reward_against_{opponent_name}', avg_reward, step)
                    rewards_vs[opponent_name] = []

        if step % MODEL_SAVE_INTERVAL == 0:
            # Save the neural net parameters to disk.
            torch.save(actor.state_dict(), f"{SUMMARY_DIR}/nn_model_ep_{step}_actor.ckpt")
            torch.save(critic.state_dict(), f"{SUMMARY_DIR}/nn_model_ep_{step}_critic.ckpt")
        
        step += 1

KeyboardInterrupt: 

In [None]:
actor.load_state_dict(torch.load('./summary/nn_model_ep_500_actor.ckpt'))
#critic.load_state_dict(torch.load('./summary/nn_model_ep_1500_critic.ckpt'))

<All keys matched successfully>

In [6]:
e = env.Env(DIMS)

e.step(env.Action(1), env.PLAYER1)
e.step(env.Action(1), env.PLAYER1)
e.step(env.Action(1), env.PLAYER1)

e.step(env.Action(5), env.PLAYER2)
e.step(env.Action(5), env.PLAYER2)
e.step(env.Action(5), env.PLAYER2)






o = e.observe(1)
print(e.legal_mask())
env.print_obs(o)
print('0 1 2 3 4 5 6 7')
print(actor.forward(network.obs_batch_to_tensor([o], device))[0])
print(critic.forward(network.obs_batch_to_tensor([o], device))[0])

[ True  True  True  True  True  True  True]
              
              
              
  #       O   
  #       O   
  #       O   

0 1 2 3 4 5 6 7
tensor([0.1330, 0.1411, 0.1726, 0.1745, 0.1982, 0.0890, 0.0917],
       grad_fn=<SelectBackward0>)
tensor(0.1529, grad_fn=<SelectBackward0>)


In [7]:
# use this cell to observe some games from the network

s_tensor = network.obs_batch_to_tensor(s_batch, device)
critic_guesses = critic.forward(s_tensor).to(cpu).detach().numpy()
actor_guesses = actor.forward(s_tensor).to(cpu).detach().numpy()
for v, obs, critic_guess, actor_guess in zip(v_batch, s_batch, critic_guesses, actor_guesses):
    print("real_value", v)
    print("pred_value", float(critic_guess))
    print("actor_probs", np.array(actor_guess))
    env.print_obs(obs)
    print('0 1 2 3 4 5 6 7')

real_value 0.0
pred_value 0.027175992727279663
actor_probs [0.1404474  0.12150193 0.13940413 0.1544029  0.17026801 0.13913056
 0.1348451 ]
              
              
              
              
              
              

0 1 2 3 4 5 6 7
real_value 0.0
pred_value 0.03009571135044098
actor_probs [0.14211562 0.12067111 0.14645967 0.15132162 0.1678718  0.13682544
 0.13473463]
              
              
              
              
              
O #           

0 1 2 3 4 5 6 7
real_value 0.0
pred_value 0.02339843660593033
actor_probs [0.1456281  0.12095633 0.15536818 0.15067069 0.16190518 0.13282353
 0.13264799]
              
              
              
O             
#             
O #           

0 1 2 3 4 5 6 7
real_value 0.0
pred_value 0.02304650843143463
actor_probs [0.15320058 0.1113179  0.16751769 0.1736629  0.15461431 0.12072457
 0.11896203]
              
              
              
O O           
# #           
O #           

0 1 2 3 4 5 6 7
real_value 0.0
pred