In [None]:
from kaggle_environments import make, evaluate
from kaggle_environments.envs.connectx.connectx import play,is_win

import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm

import numpy as np
import random
import math
from datetime import datetime
import copy
import time
import pickle

# TensorBoard log directory; created once, left alone if it already exists.
if not Path("logs").exists():
  Path("logs").mkdir()

from torch.utils.tensorboard import SummaryWriter
# IPython magic: loads the TensorBoard extension so the %tensorboard viewer can be used later.
%load_ext tensorboard

# Module-level TensorBoard writer that the training and evaluation code log scalars to.
tb = SummaryWriter(log_dir="logs")

In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: seed every RNG the notebook draws from -- `random` (random move
# selection), NumPy (move sampling via np.random.choice) and torch (weight init and
# DataLoader shuffling). GPU kernels may still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Board geometry.
ROWS = 6
COLUMNS = 7

# Network size: number of residual blocks and channels per conv layer.
N_BLOCKS = 12
FILTERS = 128

# Self-play: MCTS simulations per move, games per training loop, move-sampling temperature.
PV_EVALUATE_COUNT = 50
SP_GAME_COUNT = 500
SP_TEMPERATURE = 1.0

# Training: minibatch size, outer self-play/train loops, epochs per loop,
# samples held out as the test split, Adam learning rate.
BATCH_SIZE = 128
LOOPS = 10
RN_EPOCHS = 100
TEST_DATA_NUM = 100
LEARNING_RATE = 0.01

In [None]:
class Block(nn.Module):
  """Residual unit: conv-BN-ReLU-conv-BN, add the input back, then ReLU.

  Both convs are 3x3 with padding 1 and ``filter`` in/out channels, so the output has
  the same shape as the input (required for the residual addition).
  """
  def __init__(self, filter):
    super().__init__()
    # Keep this registration order: it fixes the state_dict key order and the sequence of
    # RNG draws used for weight init.
    self.conv1 = nn.Conv2d(filter, filter, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(filter, filter, kernel_size=3, padding=1)
    self.bn1 = nn.BatchNorm2d(filter)
    self.bn2 = nn.BatchNorm2d(filter)
    self.relu1 = nn.ReLU()
    self.relu2 = nn.ReLU()

  def forward(self, x):
    shortcut = x
    out = self.relu1(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    return self.relu2(out + shortcut)

class ResNet(nn.Module):
  """Dual-head residual network for ConnectX.

  Input: a ``(N, 2, ROWS, COLUMNS)`` float tensor (2 channels from the first conv; the
  spatial size is fixed by the flatten size of the linear heads).
  Output: ``(value, policy)`` where ``value`` is ``(N, 1)`` in [-1, 1] (tanh) and
  ``policy`` is ``(N, COLUMNS)`` softmax probabilities.
  """
  def __init__(self):
    super().__init__()
    # Stem: 2 input planes -> FILTERS channels.
    self.conv = nn.Conv2d(2, FILTERS, kernel_size=3, padding=1)
    self.bn = nn.BatchNorm2d(FILTERS)
    self.relu = nn.ReLU()

    # Residual tower.
    self.blocks = nn.ModuleList(Block(FILTERS) for _ in range(N_BLOCKS))

    # Policy head: 1x1 conv to 8 channels, then a linear layer over all board cells.
    self.conv_p = nn.Conv2d(FILTERS, 8, kernel_size=1)
    self.bn_p = nn.BatchNorm2d(8)
    self.relu_p = nn.ReLU()
    self.linear_p = nn.Linear(8 * ROWS*COLUMNS, COLUMNS)
    self.softmax_p = nn.Softmax(dim=1)

    # Value head: 1x1 conv to 4 channels, two linear layers, tanh.
    self.conv_v = nn.Conv2d(FILTERS, 4, kernel_size=1)
    self.bn_v = nn.BatchNorm2d(4)
    self.relu_v = nn.ReLU()
    self.linear_v1 = nn.Linear(4 * ROWS*COLUMNS, 64)
    self.linear_v2 = nn.Linear(64, 1)
    self.tanh_v = nn.Tanh()

  def forward(self, x):
    trunk = self.relu(self.bn(self.conv(x)))
    for block in self.blocks:
      trunk = block(trunk)

    pol = self.relu_p(self.bn_p(self.conv_p(trunk)))
    pol = pol.view(pol.size(0), -1)
    policy = self.softmax_p(self.linear_p(pol))

    val = self.relu_v(self.bn_v(self.conv_v(trunk)))
    val = val.view(val.size(0), -1)
    # NOTE(review): there is no activation between linear_v1 and linear_v2, so the two
    # layers compose to a single affine map. Any change here must be mirrored in the
    # submission template's copy of this class, or the saved weights will not match.
    value = self.tanh_v(self.linear_v2(self.linear_v1(val)))

    return value, policy

In [None]:
def predict(model, obs):
  """Evaluate position ``obs`` with ``model``.

  Returns:
    (value, policy): ``value`` is the value-head scalar for the position as seen by the
    player to move (``get_state`` puts that player's stones in the first plane).
    ``policy`` lists the policy-head probabilities of the moves in
    ``get_legal_actions(obs)``, in that order, renormalised to sum to 1 (left unscaled
    if they sum to 0).
  """
  state_pieces, state_enemy_pieces = get_state(obs)
  x = torch.tensor([state_pieces, state_enemy_pieces], dtype=torch.float).reshape(1, 2, ROWS, COLUMNS).to(DEVICE)
  # Inference only (called once per MCTS expansion): skip building the autograd graph.
  with torch.no_grad():
    value, policy = model(x)
  value = value[0][0].item()
  probs = policy[0].tolist()  # convert once instead of once per legal move
  legal = np.array([probs[c] for c in get_legal_actions(obs)])
  total = legal.sum()
  if total:
    legal /= total
  return value, legal.tolist()

def playout(observation, configuration, action, turn):
    """Play ``action`` for player ``turn``, then finish the game with random moves.

    Mutates ``observation['board']`` in place (mcs_action passes a deep copy).
    Returns +1 if the player who played ``action`` wins, -1 if the opponent wins,
    0 if the board fills up without a winner.
    """
    sign = 1  # +1 while the current mover is the player who played the first action
    while True:
        play(observation['board'], action, turn, configuration)
        if is_win(observation['board'], action, turn, configuration, True):
            return sign
        if len(get_legal_actions(observation)) == 0:
            return 0
        turn = sum([1 if p != 0 else 0 for p in observation['board']]) % 2 + 1
        action = random_action(observation)
        sign = -sign

def mcs_action(observation, configuration):
    """Baseline agent: pure Monte-Carlo search.

    Scores each legal column by the sum of 10 random playouts started with that move
    and returns the column with the highest total (first one on ties).
    """
    mover = sum(1 for cell in observation['board'] if cell != 0) % 2 + 1
    candidates = get_legal_actions(observation)
    totals = []
    for candidate in candidates:
        total = 0
        for _ in range(10):
            total += playout(copy.deepcopy(observation), configuration, candidate, mover)
        totals.append(total)
    return candidates[argmax(totals)]

def get_legal_actions(observation):
  """Columns whose top cell (row 0 of the flat board) is still empty."""
  board = observation['board']
  return [col for col in range(COLUMNS) if not board[col] != 0]

def random_action(observation):
  """Pick a uniformly random legal column, returned as a plain ``int``."""
  choices = get_legal_actions(observation)
  picked = random.choice(choices)
  return int(picked)

def argmax(collection, key=None):
    """Index of the first maximal element of ``collection``.

    Args:
        collection: a non-empty indexable sequence (list, tuple, ...).
        key: optional one-argument function applied to each element before comparison,
            as in ``max(key=...)``. (Previously accepted but silently ignored.)

    Raises:
        ValueError: if ``collection`` is empty.
    """
    score = key if key is not None else (lambda item: item)
    # max() returns the first maximal index, matching list.index(max(...)) on ties.
    return max(range(len(collection)), key=lambda i: score(collection[i]))

def get_turn(obs):
  """Return ``(next_turn, prev_turn)``: the player to move and the player who just moved.

  Player 1 moves when an even number of stones is on the board, player 2 otherwise.
  """
  stones = sum(1 for cell in obs['board'] if cell != 0)
  next_turn = 1 if stones % 2 == 0 else 2
  return next_turn, 3 - next_turn

def get_state(obs):
  """Encode ``obs['board']`` as two binary planes from the mover's point of view.

  Returns ``(state_pieces, state_enemy_pieces)``: flat lists of length ROWS*COLUMNS with
  1 where the player to move (resp. the other player) has a stone, else 0.
  """
  next_turn, prev_turn = get_turn(obs)
  # The board is only read here, so no copy is needed (this runs on every MCTS expansion).
  board = obs['board']
  cells = range(ROWS * COLUMNS)
  state_pieces = [1 if board[i] == next_turn else 0 for i in cells]
  state_enemy_pieces = [1 if board[i] == prev_turn else 0 for i in cells]
  return state_pieces, state_enemy_pieces

In [None]:
def nodes_to_scores(nodes):
    """Visit counts (``.n``) of the given MCTS nodes, in order."""
    return [node.n for node in nodes]

def pv_mcts_scores(observation, configuration, model, temperature):
  """Run PV-MCTS from ``observation`` and return a distribution over its legal moves.

  Performs ``PV_EVALUATE_COUNT`` simulations guided by ``predict(model, ...)`` and turns
  the root children's visit counts into probabilities: one-hot on the most-visited child
  when ``temperature == 0``, otherwise ``boltzman(counts, temperature)``. Entries are
  aligned with ``get_legal_actions(observation)``.

  Value convention: ``Node.w`` and the value returned by ``Node.evaluate`` are from the
  point of view of the player who made the move leading into that node. A node reached by
  a winning move is therefore worth +1, a parent picks the child with the best ``w / n``
  (plus exploration bonus), and values change sign once per ply when backed up.
  """
  class Node:
    def __init__(self, obs, conf, p, is_win):
      self.obs = obs
      self.conf = conf
      # Player to move at this node (1 or 2), from the number of stones played.
      self.next_turn = sum(1 for cell in self.obs['board'] if cell != 0) % 2 + 1
      self.p = p            # prior from the parent's policy head
      self.is_win = is_win  # True if the move into this node won the game
      self.w = 0            # sum of backed-up values
      self.n = 0            # visit count
      self.child_nodes = None

    def evaluate(self):
      # One simulation through this node; returns the value added to self.w.
      if self.is_win:
        value = 1
      elif len(get_legal_actions(self.obs)) == 0:
        value = 0  # board full without a winner: draw
      elif not self.child_nodes:
        net_value, policies = predict(model, self.obs)
        # predict() scores the position for the player to move here (get_state puts that
        # player's stones first). Convert it to the perspective of the player who moved
        # into this node; without this flip the parent would favour the opponent's moves.
        value = -net_value
        self.expand(policies)
      else:
        value = -self.next_child_node().evaluate()
      self.w += value
      self.n += 1
      return value

    def expand(self, policies):
      # One child per legal move, in get_legal_actions order (the order of `policies`).
      legal_actions = get_legal_actions(self.obs)
      self.child_nodes = []
      for action, policy in zip(legal_actions, policies):
        next_obs = copy.deepcopy(self.obs)
        play(next_obs['board'], action, self.next_turn, self.conf)
        is_done_and_win = is_win(next_obs['board'], action, self.next_turn, self.conf, True)
        self.child_nodes.append(Node(next_obs, self.conf, policy, is_done_and_win))

    def next_child_node(self):
      # PUCT selection: mean value (0 if unvisited) + prior-weighted exploration bonus.
      C_PUCT = 1.0
      t = sum(nodes_to_scores(self.child_nodes))
      pucb_values = []
      for child_node in self.child_nodes:
        pucb_values.append((child_node.w / child_node.n if child_node.n else 0.0) + C_PUCT * child_node.p * math.sqrt(t) / (1 + child_node.n))
      return self.child_nodes[np.argmax(pucb_values)]

  root_node = Node(copy.deepcopy(observation), configuration, [], False)

  for _ in range(PV_EVALUATE_COUNT):
    root_node.evaluate()

  scores = nodes_to_scores(root_node.child_nodes)
  if temperature == 0:
    action = np.argmax(scores)
    scores = np.zeros(len(scores))
    scores[action] = 1
  else:
    scores = boltzman(scores, temperature)
  return scores

def pv_mcts_action(model, temperature=0):
  """Build a Kaggle-style agent ``(observation, configuration) -> int`` backed by PV-MCTS.

  The agent samples a legal column from ``pv_mcts_scores`` at the given temperature
  (temperature 0 makes the choice deterministic: the most-visited move).
  """
  def mcts_agent(observation, configuration):
    probs = pv_mcts_scores(observation, configuration, model, temperature)
    choices = get_legal_actions(observation)
    return int(np.random.choice(choices, p=probs))
  return mcts_agent

def boltzman(xs, temperature):
    """Boltzmann-style normalisation: ``x**(1/temperature)`` scaled to sum to 1.

    Args:
        xs: non-negative scores (MCTS visit counts).
        temperature: > 0; lower values sharpen the distribution.

    Returns:
        A list of probabilities the same length as ``xs``. If every weight is 0 the result
        is uniform (instead of raising ZeroDivisionError); an empty input gives ``[]``.
    """
    weights = [x ** (1 / temperature) for x in xs]
    total = sum(weights)  # computed once, not once per element
    if total == 0:
        return [1 / len(weights)] * len(weights) if weights else []
    return [w / total for w in weights]

In [None]:
def first_player_value(obs, config, prev_turn, first_turn, prev_action):
  """Final game result from the first player's perspective.

  Returns +1 if ``prev_turn`` (the last mover) won with ``prev_action`` and is the first
  player, -1 if the last mover won and is the second player, and 0 if that move did not win.
  """
  if not is_win(obs['board'], prev_action, prev_turn, config, True):
    return 0
  return 1 if prev_turn == first_turn else -1

def write_data(history):
  """Pickle ``history`` to ``./data/<YYYYmmddHHMMSS>.history``, creating ./data if needed."""
  stamp = datetime.now().strftime('%Y%m%d%H%M%S')
  os.makedirs('./data/', exist_ok=True)
  target = './data/' + stamp + '.history'
  with open(target, mode='wb') as out_file:
    pickle.dump(history, out_file)

def one_play(model):
  """Play one self-play game, with PV-MCTS choosing moves for both sides.

  Returns a list of samples ``[[state_pieces, state_enemy_pieces], action, value]``:
  the state planes come from ``get_state`` (player to move first), ``action`` is the
  column sampled from the MCTS distribution at ``SP_TEMPERATURE``, and ``value`` is the
  final result for the player to move in that state (+1 win, -1 loss, 0 draw).
  """
  history = []

  env = make("connectx", debug=True)
  trainer = env.train([None, "random"])
  observation = trainer.reset()  # only used to obtain an initial observation

  first_turn, _ = get_turn(observation)
  mover = None
  last_action = None
  is_done = False

  while not is_done:
    mover, _ = get_turn(observation)

    scores = pv_mcts_scores(observation, env.configuration, model, SP_TEMPERATURE)
    legal_actions = get_legal_actions(observation)
    state_pieces, state_enemy_pieces = get_state(observation)
    action = int(np.random.choice(legal_actions, p=scores))
    history.append([[state_pieces, state_enemy_pieces], action, None])
    play(observation['board'], action, mover, env.configuration)
    is_done = (is_win(observation['board'], action, mover, env.configuration, True)
               or len(get_legal_actions(observation)) == 0)
    last_action = action

  # Result for the first player; samples alternate movers, so flip the sign each ply.
  value = first_player_value(observation, env.configuration, mover, first_turn, last_action)
  for sample in history:
    sample[2] = value
    value = -value
  return history

def self_play(model):
  """Generate ``SP_GAME_COUNT`` self-play games with ``model`` and save them via write_data."""
  # Inference only: BatchNorm must use its running statistics (a freshly constructed model
  # is still in train mode on the first loop), and search needs no autograd graph.
  model.eval()
  history = []
  with torch.no_grad():
    for i in range(SP_GAME_COUNT):
      history.extend(one_play(model))
      print('\rSelfPlay {}/{}'.format(i+1, SP_GAME_COUNT), end='')
  print('')
  write_data(history)

In [None]:
def load_data():
    """Load the newest self-play history from ``./data/*.history``.

    History files are named by timestamp, so the lexicographically last one is the newest.

    Raises:
        FileNotFoundError: if no history file exists yet (run self_play first).
    """
    history_paths = sorted(Path('./data').glob('*.history'))
    if not history_paths:
        raise FileNotFoundError("no self-play history found in ./data (run self_play first)")
    # NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
    with history_paths[-1].open(mode='rb') as fh:
        return pickle.load(fh)

def train_network(model, times):
  """Train ``model`` on the newest self-play history and log metrics to TensorBoard.

  Samples from ``load_data`` are ``[state, action, value]``. The first ``TEST_DATA_NUM``
  samples are held out as a test split, and the rest are trained on for ``RN_EPOCHS`` epochs
  with Adam. Per-epoch losses and final policy/value accuracies (percent of samples) are
  written to the module-level writer ``tb``.

  Args:
    model: network whose forward returns ``(value, policy)`` with ``policy`` already
      softmax-normalised. Updated in place and left in eval mode.
    times: outer training-loop index, used as the TensorBoard step and in the loss tags.
  """
  from torch.utils.data import TensorDataset

  history = load_data()
  xs, y_policies, y_values = zip(*history)
  n_samples = len(xs)
  xs = np.array(xs).reshape(n_samples, 2, ROWS, COLUMNS)

  tensor_state = torch.tensor(xs, dtype=torch.float).to(DEVICE)
  tensor_policy = torch.tensor(y_policies, dtype=torch.long).to(DEVICE)
  tensor_value = torch.tensor(y_values, dtype=torch.float).to(DEVICE)

  train_state = tensor_state[TEST_DATA_NUM:]
  train_policy = tensor_policy[TEST_DATA_NUM:]
  train_value = tensor_value[TEST_DATA_NUM:]
  test_state = tensor_state[:TEST_DATA_NUM]
  test_policy = tensor_policy[:TEST_DATA_NUM]
  test_value = tensor_value[:TEST_DATA_NUM]
  print("train Data", train_state.shape, train_policy.shape, train_value.shape)
  print("test Data", test_state.shape, test_policy.shape, test_value.shape)

  train_loader = DataLoader(TensorDataset(train_state, train_policy, train_value), batch_size=BATCH_SIZE, shuffle=True)
  test_loader = DataLoader(TensorDataset(test_state, test_policy, test_value), batch_size=100, shuffle=False)

  optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
  # The policy head already outputs probabilities, while CrossEntropyLoss applies its own
  # log-softmax and would therefore softmax them a second time. Use NLLLoss on
  # log-probabilities instead; the clamp keeps log() finite.
  policy_loss_fn = nn.NLLLoss()
  value_loss_fn = nn.MSELoss()

  for epoch in range(RN_EPOCHS):
    losses = [0.0, 0.0, 0.0]  # total, policy, value -- summed over batches
    model.train()
    for X, Y_policy, Y_value in tqdm(train_loader):
      optimizer.zero_grad()

      v, p = model(X)
      p_loss = policy_loss_fn(torch.log(p.clamp(min=1e-8)), Y_policy)
      v_loss = value_loss_fn(v.view(-1), Y_value)
      loss = p_loss + v_loss
      loss.backward()
      optimizer.step()

      losses[0] += loss.item()
      losses[1] += p_loss.item()
      losses[2] += v_loss.item()
    tb.add_scalar('Loss_{}'.format(times), losses[0], epoch)
    tb.add_scalar('Loss_Policy_{}'.format(times), losses[1], epoch)
    tb.add_scalar('Loss_Value_{}'.format(times), losses[2], epoch)

  model.eval()

  def _accuracy(loader):
    # Percent of samples whose argmax policy equals the stored action, and whose rounded
    # value equals the stored result. Counts are summed over batches and divided once by
    # the number of samples; an empty loader gives 0.
    policy_hits = value_hits = seen = 0
    with torch.no_grad():
      for X, Y_policy, Y_value in loader:
        v, p = model(X)
        policy_hits += (p.argmax(dim=1) == Y_policy).sum().item()
        value_hits += (v.view(-1).round() == Y_value).sum().item()
        seen += Y_value.shape[0]
    if seen == 0:
      return 0.0, 0.0
    return 100.0 * policy_hits / seen, 100.0 * value_hits / seen

  policy_acc_train, value_acc_train = _accuracy(train_loader)
  policy_acc_test, value_acc_test = _accuracy(test_loader)

  tb.add_scalar('Accuracy_Policy_Train', policy_acc_train, times)
  tb.add_scalar('Accuracy_Value_Train', value_acc_train, times)

  tb.add_scalar('Accuracy_Policy_Test', policy_acc_test, times)
  tb.add_scalar('Accuracy_Value_Test', value_acc_test, times)

In [None]:
def get_win_percentages(agent1, agent2, n_rounds=20):
    """Play ``n_rounds`` ConnectX games (half with each seat order), print a summary,
    and return agent1's win rate rounded to 2 decimals.
    """
    config = {'rows': 6, 'columns': 7, 'inarow': 4}
    first_half = n_rounds // 2
    outcomes = evaluate("connectx", [agent1, agent2], config, [], first_half)
    swapped = evaluate("connectx", [agent2, agent1], config, [], n_rounds - first_half)
    # Re-order swapped results so index 0 is always agent1's reward.
    outcomes += [[b, a] for [a, b] in swapped]

    agent1_rate = np.round(outcomes.count([1, -1]) / len(outcomes), 2)
    agent2_rate = np.round(outcomes.count([-1, 1]) / len(outcomes), 2)
    print("Agent 1 Win Percentage:", agent1_rate)
    print("Agent 2 Win Percentage:", agent2_rate)
    print("Number of Invalid Plays by Agent 1:", outcomes.count([None, 0]))
    print("Number of Invalid Plays by Agent 2:", outcomes.count([0, None]))
    return agent1_rate

In [None]:
def evaluate_model(model, times, temperature=None):
  """Log the model's win rate against the random agent and the Monte-Carlo agent.

  Args:
    model: network to evaluate; switched to eval mode here.
    times: outer-loop index, used as the TensorBoard step.
    temperature: MCTS move-sampling temperature for the evaluated agent. ``None`` (the
      default) keeps the previous behaviour of using ``SP_TEMPERATURE``. Pass 0 to measure
      deterministic most-visited-move play.
  """
  if temperature is None:
    temperature = SP_TEMPERATURE
  model.eval()
  agent_func = pv_mcts_action(model, temperature)
  with torch.no_grad():
    vs_random = get_win_percentages(agent_func, 'random')
    vs_monte = get_win_percentages(agent_func, mcs_action)

  tb.add_scalar('VS_Random', vs_random, times)
  tb.add_scalar('VS_Monte', vs_monte, times)

In [None]:
# Build the policy/value network directly on the selected device.
model = ResNet().to(DEVICE)

In [None]:
for loop_index in range(LOOPS):
    print('Train', loop_index, '====================')
    # Self-play phase: generate and save games with the current network.
    self_play(model)

    # Parameter-update phase: train on the newest self-play data.
    train_network(model, loop_index)

    # Evaluation phase: log win rates against baseline agents.
    evaluate_model(model, loop_index)

In [None]:
# Flush and close the TensorBoard writer, then open the viewer on the log directory.
tb.close()
%tensorboard --logdir logs

In [None]:
out = """
from kaggle_environments import make, evaluate
from kaggle_environments.envs.connectx.connectx import play,is_win

import os
import sys
import io
import base64

import torch
import torch.nn as nn
import torch.nn.functional as f

import numpy as np
import random
import math

import copy

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ROWS = 6
COLUMNS = 7

N_BLOCKS = 12
FILTERS = 128

PV_EVALUATE_COUNT = 50
TEMPERATURE = 0.0

encoded_weights = \"\"\"BASE64_PARAMS\"\"\"

# The network is built and its weights decoded on the first call to agent_func, then
# reused on later calls instead of being decoded and reloaded on every move.
_MODEL_CACHE = {}

def agent_func(observation, configuration):
  # Block and ResNet must match the training notebook exactly so load_state_dict succeeds.
  class Block(nn.Module):
    def __init__(self, filter):
      super(Block,self).__init__()
      self.conv1 = nn.Conv2d(filter, filter, kernel_size=3, padding=1)
      self.conv2 = nn.Conv2d(filter, filter, kernel_size=3, padding=1)

      self.bn1 = nn.BatchNorm2d(filter)
      self.bn2 = nn.BatchNorm2d(filter)

      self.relu1 = nn.ReLU()
      self.relu2 = nn.ReLU()

    def forward(self, x):
      input = x
      x = self.conv1(x)
      x = self.bn1(x)
      x = self.relu1(x)
      x = self.conv2(x)
      x = self.bn2(x)
      x = x + input
      x = self.relu2(x)
      return x

  class ResNet(nn.Module):
    def __init__(self):
      super(ResNet,self).__init__()
      self.conv = nn.Conv2d(2, FILTERS, kernel_size=3, padding=1)
      self.bn = nn.BatchNorm2d(FILTERS)
      self.relu = nn.ReLU()

      self.blocks = nn.ModuleList([Block(FILTERS) for _ in range(N_BLOCKS)])

      self.conv_p = nn.Conv2d(FILTERS, 8, kernel_size=1)
      self.bn_p = nn.BatchNorm2d(8)
      self.relu_p = nn.ReLU()
      self.linear_p = nn.Linear(8 * ROWS*COLUMNS, COLUMNS)
      self.softmax_p = nn.Softmax(dim=1)

      self.conv_v = nn.Conv2d(FILTERS, 4, kernel_size=1)
      self.bn_v = nn.BatchNorm2d(4)
      self.relu_v = nn.ReLU()
      self.linear_v1 = nn.Linear(4 * ROWS*COLUMNS, 64)
      self.linear_v2 = nn.Linear(64, 1)
      self.tanh_v = nn.Tanh()

    def forward(self, x):
      x = self.conv(x)
      x = self.bn(x)
      x = self.relu(x)

      for block in self.blocks:
        x = block(x)

      x_p = self.conv_p(x)
      x_p = self.bn_p(x_p)
      x_p = self.relu_p(x_p)
      x_p = x_p.view(x_p.size(0), -1)
      x_p = self.linear_p(x_p)
      policy = self.softmax_p(x_p)


      x_v = self.conv_v(x)
      x_v = self.bn_v(x_v)
      x_v = self.relu_v(x_v)
      x_v = x_v.view(x_v.size(0), -1)
      x_v = self.linear_v1(x_v)
      x_v = self.linear_v2(x_v)
      value = self.tanh_v(x_v)

      return value, policy

  def predict(model, obs):
    state_pieces, state_enemy_pieces = get_state(obs)
    x = torch.tensor([state_pieces,state_enemy_pieces], dtype=torch.float).reshape(1, 2, ROWS, COLUMNS).to(DEVICE)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
      value, policy = model(x)
    value = value[0][0].item()
    probs = policy[0].tolist()
    policy = np.array([probs[c] for c in get_legal_actions(obs)])
    policy /= sum(policy) if sum(policy) else 1
    return value, policy.tolist()

  def get_legal_actions(observation):
    return [c for c in range(COLUMNS) if observation['board'][c] == 0]

  def argmax(collection, key=None):
      return collection.index(max(collection))

  def get_turn(obs):
    next_turn = sum([1 if p != 0 else 0 for p in obs['board']])%2 + 1
    prev_turn = 2 if next_turn == 1 else 1
    return next_turn, prev_turn

  def get_state(obs):
    next_turn, prev_turn = get_turn(obs)
    board = obs['board']
    state_pieces = [1 if board[i] == next_turn else 0 for i in range(ROWS*COLUMNS)]
    state_enemy_pieces = [1 if board[i] == prev_turn else 0 for i in range(ROWS*COLUMNS)]
    return state_pieces, state_enemy_pieces

  def nodes_to_scores(nodes):
      scores = []
      for c in nodes:
          scores.append(c.n)
      return scores

  def pv_mcts_scores(observation, configuration, model, temperature):
    # Node.w and Node.evaluate values are from the perspective of the player who made
    # the move into that node: a winning move is worth +1 and parents maximise w / n.
    class Node:
      def __init__(self, obs, conf, p, is_win):
        self.obs = obs
        self.conf = conf
        self.next_turn = sum([1 if p != 0 else 0 for p in self.obs['board']])%2 + 1
        self.p = p
        self.is_win = is_win
        self.w = 0
        self.n = 0
        self.child_nodes = None

      def evaluate(self):
        if self.is_win:
          value = 1
          self.w += value
          self.n += 1
          return value
        elif len(get_legal_actions(self.obs)) == 0:
          value = 0
          self.w += value
          self.n += 1
          return value

        if not self.child_nodes:
          value, policies = predict(model, self.obs)
          # predict() scores the position for the player to move here; flip it to the
          # perspective of the player who moved into this node.
          value = -value
          self.w += value
          self.n += 1
          self.expand(policies)
          return value
        else:
          value = -self.next_child_node().evaluate()
          self.w += value
          self.n += 1
          return value

      def expand(self, policies):
        legal_actions = get_legal_actions(self.obs)
        self.child_nodes = []
        for action, policy in zip(legal_actions, policies):
          next_obs = copy.deepcopy(self.obs)
          play(next_obs['board'], action, self.next_turn, self.conf)
          is_done_and_win = is_win(next_obs['board'], action, self.next_turn, self.conf, True)
          self.child_nodes.append(Node(next_obs, self.conf, policy, is_done_and_win))

      def next_child_node(self):
        C_PUCT = 1.0
        t = sum(nodes_to_scores(self.child_nodes))
        pucb_values = []
        for child_node in self.child_nodes:
          pucb_values.append((child_node.w / child_node.n if child_node.n else 0.0) + C_PUCT * child_node.p * math.sqrt(t) / (1 + child_node.n))
        return self.child_nodes[np.argmax(pucb_values)]

    root_node = Node(copy.deepcopy(observation), configuration, [], False)

    for _ in range(PV_EVALUATE_COUNT):
      root_node.evaluate()

    scores = nodes_to_scores(root_node.child_nodes)
    if temperature == 0:
      action = np.argmax(scores)
      scores = np.zeros(len(scores))
      scores[action] = 1
    else:
      scores = boltzman(scores, temperature)
    return scores

  def pv_mcts_action(model, temperature=0):
    def pv_mcts_action(observation, configuration):
      scores = pv_mcts_scores(observation, configuration, model, temperature)
      legal_actions = get_legal_actions(observation)
      return np.random.choice(legal_actions, p=scores)
    return pv_mcts_action

  def boltzman(xs, temperature):
      xs = [x ** (1 / temperature) for x in xs]
      total = sum(xs)
      return [x / total for x in xs]

  model = _MODEL_CACHE.get('model')
  if model is None:
    model = ResNet().to(DEVICE)
    buffer = io.BytesIO(base64.b64decode(encoded_weights))
    model.load_state_dict(torch.load(buffer, map_location=DEVICE))
    model.eval()
    _MODEL_CACHE['model'] = model
  next_action = pv_mcts_action(model, TEMPERATURE)

  return int(next_action(observation, configuration))
"""

In [None]:
import base64
import sys

# Save the trained weights, embed them (base64) into the submission template `out`, and
# write the self-contained agent file.
dict_path = "net_params.pth"
torch.save(model.state_dict(), dict_path)

INPUT_PATH = dict_path
OUTPUT_NO_PARAM_PATH = 'submission_no_param.py'
OUTPUT_PATH = 'submission.py'

# File handles are not named `f`: at module level that would overwrite the
# `torch.nn.functional as f` alias imported in the first cell.
with open(INPUT_PATH, 'rb') as weights_file:
    encoded_weights = base64.encodebytes(weights_file.read()).decode()

# Template without weights, kept for inspection.
with open(OUTPUT_NO_PARAM_PATH, 'w') as template_file:
    template_file.write(out)

data = out.replace('BASE64_PARAMS', encoded_weights)

with open(OUTPUT_PATH, 'w') as submission_file:
    submission_file.write(data)
print('written agent with net parameters to', OUTPUT_PATH)

In [None]:
import importlib

import submission

# Reload so that re-running this cell after submission.py is regenerated uses the new file
# rather than the module object cached by an earlier import.
submission = importlib.reload(submission)
agent_func = submission.agent_func

# Smoke test: one game against the built-in random agent (our agent takes the `None` seat).
env = make("connectx", debug=True)
trainer = env.train(["random", None])

observation = trainer.reset()

while not env.done:
    my_action = agent_func(observation, env.configuration)
    print("My Action", my_action)
    observation, reward, done, info = trainer.step(my_action)

env.render(mode="ipython")