# Train Agent

## Imports

In [1]:
%matplotlib inline
# NOTE: import order is kept as in the original on purpose — the star-imports
# below shadow names (e.g. ipywidgets exports HTML, which IPython.display then
# overrides), so reordering could change which object a name refers to.
import numpy as np
import matplotlib.pyplot as plt
import torch  # used explicitly for torch.save / torch.load below
from tqdm import trange
from collections import namedtuple, Counter
from ipywidgets import *
from IPython.display import display, HTML

from santorinigo.environment import Santorini
from santorinigo.qnetwork import *
from santorinigo.replay_memory import *
from santorinigo.agent import Agent

# Output locations. Model checkpoints written with torch.save later in this
# notebook go under MODEL_PATH.
# NOTE(review): nothing visible here creates the directory — confirm it exists
# before training.
DATA_PATH = 'data/'
MODEL_PATH = f'{DATA_PATH}models/'

## Environment

In [2]:
# Create a fresh game and print its starting position
# (the printout has three grids: Buildings, Workers, Parts).
env = Santorini()
env.print_board()

Buildings:
 [[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]
Workers:
 [[ 0  0 -1  0  0]
 [ 0  0  0  0  0]
 [ 1  0  0  0  2]
 [ 0  0  0  0  0]
 [ 0  0 -2  0  0]]
Parts:
 [[ 0  0  0  0  0]
 [ 0 22  0  0  0]
 [ 0  0 18  0  0]
 [ 0  0  0 14  0]
 [ 0  0  0  0 18]]


In [3]:
# Play action index 27. step() returns a 4-tuple, unpacked in the training
# loop as (next_state, reward, done, next_player); next_state is a flat array.
env.step(27)

(array([ 1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0, -1,  0,  0,  0,  0,  0,  0,  0,
         0,  1,  0,  0,  0,  2,  0,  0,  0,  0,  0,  0,  0, -2,  0,  0,  0,
        21, 18, 14, 18]), -0.001, False, 1)

In [4]:
# Show the position after the move played in the previous cell.
env.print_board()

Buildings:
 [[1 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]
Workers:
 [[ 0 -1  0  0  0]
 [ 0  0  0  0  0]
 [ 1  0  0  0  2]
 [ 0  0  0  0  0]
 [ 0  0 -2  0  0]]
Parts:
 [[ 0  0  0  0  0]
 [ 0 21  0  0  0]
 [ 0  0 18  0  0]
 [ 0  0  0 14  0]
 [ 0  0  0  0 18]]


In [5]:
# Peek at the first ten entries of the current legal-move list (action indices).
env.legal_moves()[:10]

[9, 12, 14, 15, 16, 18, 19, 20, 21, 22]

In [6]:
# Look up the action index for an action key via env.atoi and display it.
# NOTE(review): keys look like (worker, move direction, build direction) with
# keyboard-compass letters — confirm against Santorini.atoi. The first element
# is -1 even though the recorded board shows worker 1 moving, so worker ids in
# keys are presumably relative to the side to move; verify.
p1_action = env.atoi[(-1,'x','d')]; p1_action

52

In [7]:
# Apply the move, then print the board. The tuple is evaluated left to right,
# so the printed board is the post-move position; the cell's displayed value is
# (step result, None) because print_board() returns None.
env.step(p1_action), env.print_board()

Buildings:
 [[1 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 1 0 0 0]
 [0 0 0 0 0]]
Workers:
 [[ 0 -1  0  0  0]
 [ 0  0  0  0  0]
 [ 0  0  0  0  2]
 [ 1  0  0  0  0]
 [ 0  0 -2  0  0]]
Parts:
 [[ 0  0  0  0  0]
 [ 0 20  0  0  0]
 [ 0  0 18  0  0]
 [ 0  0  0 14  0]
 [ 0  0  0  0 18]]


((array([ 1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0, -2, -1,  0,  0,  0,  0,  0,  0,  2,  0,  0,  0,
         20, 18, 14, 18]), -0.001, False, -1), None)

In [8]:
# Look up the action index for the reply move's key and display it.
# NOTE(review): key layout presumably (worker, move direction, build direction) — confirm.
p2_action = env.atoi[(-1,'d','d')]; p2_action

36

In [9]:
# Apply the reply move; the displayed value is step()'s 4-tuple.
env.step(p2_action)

(array([ 1,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0, -1,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  2,  1,  0,  0,  0,  0,  0,  0, -2,  0,  0,  0,
        19, 18, 14, 18]), -0.001, False, 1)

In [10]:
# Sizes used to build the agent below: number of action indices and length of
# the flat state vector.
env.action_dim, env.state_dim_flat

(128, 55)

## Agent

In [14]:
# Replay memory shared with the agent (prioritized, capacity=1000).
mem = PrioritizedMemory(capacity=1000)

# Agent sized from the environment's flat state vector and action count.
# The hyperparameters below are passed straight to Agent.
a = Agent(
    state_size=env.state_dim_flat,
    action_size=env.action_dim,
    replay_memory=mem,
    seed=1412,
    lr=1e-3 / 4,
    bs=64,
    nb_hidden=128,
    gamma=0.99,
    tau=1/100,
    update_interval=5,
)

In [15]:
# Self-play training loop.
#
# For each episode: reset the environment, let the agent propose a ranked list
# of actions, play the first legal one, feed the transition to the agent, and
# repeat until the game ends. Per-game winner, length and board history are
# collected for the analysis cells below.
N_EPISODES = 100000
CHECKPOINT_EVERY = 10000

winners = []
timesteps = []
game_records = []

env = Santorini()
for i in trange(N_EPISODES):
    state = env.reset()
    timestep = 0
    # Store a copy so that any later in-place update of env.board cannot
    # rewrite earlier snapshots in the record.
    game_record = [np.copy(env.board)]
    while True:
        actions = a.act(state, i, return_list=True)

        # Pick the agent's highest-ranked proposal that is legal. Failing
        # loudly here avoids silently replaying a stale `action` left over
        # from the previous turn.
        legal_moves = set(env.legal_moves())
        action = next((cand for cand in actions if cand in legal_moves), None)
        if action is None:
            raise RuntimeError(
                f'no legal action among the agent proposals (episode {i}, timestep {timestep})'
            )

        # step action
        next_state, reward, done, next_player = env.step(action)
        game_record.append(np.copy(env.board))
        timestep += 1

        # step agent; store transition and train
        a.step(state, action, reward, next_state, done, i)

        # break if done
        if done:
            break

        # Advance the observation; without this the agent would keep acting
        # on (and storing transitions from) the reset state for the whole game.
        state = next_state

    # record
    # NOTE(review): assumes the player who made the final move is the winner;
    # env.step returns the NEXT player, hence the negation — confirm.
    winners.append(-next_player)
    timesteps.append(timestep)
    game_records.append(game_record)
    if i % CHECKPOINT_EVERY == 0:
        torch.save(a.qnetwork_local.state_dict(), f'{MODEL_PATH}half_rainbow_{i}.m')

  2%|▏         | 1894/100000 [28:28<24:34:49,  1.11it/s]

KeyboardInterrupt: 

In [None]:
# Tally of game winners; each entry is -next_player from that game's final step.
Counter(winners)

In [None]:
# Last recorded board of the longest game (first game with the maximum timestep count).
game_records[np.argmax(timesteps)][-1]

In [None]:
# Distribution of game lengths; the cell's displayed value is (mean, min, max).
plt.hist(timesteps)
np.mean(timesteps),np.min(timesteps),np.max(timesteps)

In [None]:
# Last recorded board of the most recently completed game.
game_records[-1][-1]

In [None]:
# Plot the agent's recorded losses. The agent in this notebook is `a`; the
# previous `a1` was never defined and raised NameError on a fresh kernel.
# NOTE(review): assumes Agent exposes a `losses` sequence — confirm.
plt.plot(a.losses);

In [17]:
# Save the local Q-network's weights, then load that checkpoint into both the
# local and the target network so the two start from identical parameters.
fname = f'{MODEL_PATH}half_rainbow_temp.m'
torch.save(a.qnetwork_local.state_dict(), fname)
# load: read the file once and reuse the state dict for both networks
# (load_state_dict copies values into each module's own parameters).
saved_state = torch.load(fname)
a.qnetwork_local.load_state_dict(saved_state)
a.qnetwork_target.load_state_dict(saved_state)

## Versus Human

In [18]:
# Start a new game for the human-vs-agent session and show the opening position.
env = Santorini()
env.print_board()

Buildings:
 [[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]
Workers:
 [[ 0  0 -1  0  0]
 [ 0  0  0  0  0]
 [ 1  0  0  0  2]
 [ 0  0  0  0  0]
 [ 0  0 -2  0  0]]
Parts:
 [[ 0  0  0  0  0]
 [ 0 22  0  0  0]
 [ 0  0 18  0  0]
 [ 0  0  0 14  0]
 [ 0  0  0  0 18]]


### Human's Turn

In [35]:
# Pick the human's move by key, translate it to an action index via env.atoi,
# and display who is to move together with the key and index.
# This cell and the ones below are re-run once per turn, so the recorded
# outputs come from a mid-game position rather than the first move.
# NOTE(review): key layout presumably (worker, move direction, build direction) — confirm.
human_key = (-1,'d','a')
human_action = env.atoi[human_key]
env.current_player, human_key, human_action

(-1, (-1, 'd', 'a'), 35)

In [36]:
# Apply the human's move and show the resulting position.
env.step(human_action)
env.print_board()

Buildings:
 [[0 0 3 2 0]
 [1 0 0 0 0]
 [0 0 0 1 1]
 [0 0 0 0 1]
 [0 0 0 0 0]]
Workers:
 [[ 0  0  0 -1  0]
 [ 0  1  0  0  0]
 [ 0  0  0  0  0]
 [ 0  0  0  2  0]
 [ 0  0 -2  0  0]]
Parts:
 [[ 0  0  0  0  0]
 [ 0 16  0  0  0]
 [ 0  0 16  0  0]
 [ 0  0  0 13  0]
 [ 0  0  0  0 18]]


### Agent's Turn

In [37]:
# Ask the agent for a ranked list of actions in the current position and take
# the first one that is legal.
# NOTE(review): the second argument to act() is the episode index during
# training; 1000 here is presumably chosen to make play mostly greedy — confirm.
state = env.get_state()
actions = a.act(state, 1000, return_list=True)
# check legality; fail loudly instead of silently reusing an agent_action left
# over from a previous run of this cell
legal_moves = set(env.legal_moves())
agent_action = next((cand for cand in actions if cand in legal_moves), None)
if agent_action is None:
    raise RuntimeError('no legal action among the agent proposals for the current position')
env.current_player, env.itoa[agent_action], agent_action

(1, (-2, 'x', 'w'), 113)

In [38]:
# Apply the agent's chosen move and show the resulting position.
env.step(agent_action)
env.print_board()

Buildings:
 [[0 0 3 2 0]
 [1 0 0 0 0]
 [0 0 0 1 1]
 [0 0 0 1 1]
 [0 0 0 0 0]]
Workers:
 [[ 0  0  0 -1  0]
 [ 0  1  0  0  0]
 [ 0  0  0  0  0]
 [ 0  0  0  0  0]
 [ 0  0 -2  2  0]]
Parts:
 [[ 0  0  0  0  0]
 [ 0 15  0  0  0]
 [ 0  0 16  0  0]
 [ 0  0  0 13  0]
 [ 0  0  0  0 18]]
