<a href="https://colab.research.google.com/github/sheric98/Connect-Four-AI/blob/master/Connect_Four_AI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive inside the Colab runtime at '/ME'; the checkpoint and
# training-data paths defined later in this notebook live under that root.
from google.colab import drive
drive.mount('/ME')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /ME


In [None]:
import numpy as np
import copy

In [None]:
# connect four board
class Game:
    """Connect Four position with move bookkeeping and win/draw detection.

    Tiles are addressed as ``(x, y)`` = ``(column, row)``; the arrays are
    indexed ``[row, column]`` and row 0 is the bottom of the board (it is
    printed last by ``display_board``). Players are ``1`` and ``-1``.
    ``state`` is ``2`` while the game is running, ``0`` after a draw and
    otherwise the value of the winning player.
    """

    def __init__(self, nrows=6, ncols=7):
        """Create an empty ``nrows`` x ``ncols`` board with player 1 to move."""
        shape = (nrows, ncols)
        self.nrows = nrows
        self.ncols = ncols
        self.player = 1
        self.moves = 0
        self.state = 2  # 2 represents not ended
        # Numeric board: 0 empty, otherwise the owning player's value.
        self.board = np.zeros(shape, dtype=int)
        # Printable board, blank until a piece is dropped.
        self.display = np.full(shape, ' ', dtype=str)
        # Planes 0/1 mark the pieces of player 1 / -1; plane 2 is a constant
        # plane rewritten on every move (see update_board).
        self.encode = np.zeros(shape + (3,), dtype=int)
        # Next free row per column, and whether the column still has room.
        self.col_pos = np.zeros(ncols, dtype=int)
        self.valid_cols = np.ones(ncols, dtype=bool)
        # Opposite direction pairs: two diagonals, horizontal, vertical.
        self.neighbors = [
            ((1, 1), (-1, -1)),
            ((1, -1), (-1, 1)),
            ((1, 0), (-1, 0)),
            ((0, 1), (0, -1)),
        ]
        self.piece_map = {1: 'X', -1: 'O'}
        self.encode_map = {1: 0, -1: 1}

    def get_actual_player(self):
        """Return 1 when an even number of moves was played, else -1."""
        return 1 if self.moves % 2 == 0 else -1

    def get_key(self):
        """Return the raw bytes of ``board``, usable as a dictionary key."""
        return self.board.tobytes()

    def get_piece(self, tile):
        """Return the board value stored at ``tile`` = ``(x, y)``."""
        col, row = tile[0], tile[1]
        return self.board[row, col]

    def check_valid_tile(self, tile):
        """Tell whether ``tile`` = ``(x, y)`` lies inside the board."""
        col, row = tile[0], tile[1]
        inside_cols = 0 <= col < self.ncols
        inside_rows = 0 <= row < self.nrows
        return inside_cols and inside_rows

    def move_in_dir(self, direc, tile):
        """Return ``tile`` shifted by one step of ``direc``."""
        dx, dy = direc[0], direc[1]
        return (tile[0] + dx, tile[1] + dy)

    def check_dir(self, direc, tile, chain):
        """Count the current player's consecutive pieces along ``direc``.

        Counting starts at ``tile`` (inclusive) on top of the initial
        ``chain`` value. Returns 4 as soon as a fourth piece is reached.
        """
        length = chain
        pos = tile
        while self.check_valid_tile(pos) and self.get_piece(pos) == self.player:
            if length == 3:  # this means player won
                return 4
            length += 1
            pos = self.move_in_dir(direc, pos)
        return length

    # check win (four consecutive) of
    # current player for current move
    def check_win(self, move):
        """Tell whether the piece at ``move`` completes four in a row."""
        for forward, backward in self.neighbors:
            ahead = self.check_dir(forward, move, 0)
            behind = self.check_dir(backward, move, 0)
            # Both walks count the tile at ``move`` itself, hence the -1.
            if ahead + behind - 1 >= 4:
                return True
        return False

    def check_full(self):
        """Tell whether every cell of the board has been played."""
        return self.moves == (self.nrows * self.ncols)

    def check_valid_move(self, col):
        """Tell whether ``col`` is on the board and still has room."""
        return (0 <= col < self.ncols) and self.valid_cols[col]

    def update_col(self, col):
        """Advance the fill height of ``col`` and flag it once it is full."""
        self.col_pos[col] += 1
        if self.col_pos[col] == self.nrows:
            # column is full
            self.valid_cols[col] = False

    def update_board(self, x, y):
        """Record the current player's piece at column ``x``, row ``y``."""
        mover = self.player
        plane = self.encode_map[mover]
        self.board[y, x] = mover
        self.display[y, x] = self.piece_map[mover]
        self.encode[y, x, plane] = 1
        # Plane 2 holds the plane index of the other player everywhere.
        self.encode[:, :, 2] = 1 - plane

    # make move
    # returns 1 for win, 2 for draw,
    # 0 for move made, and -1 for invalid move
    def make_move(self, col):
        """Drop the current player's piece into ``col``.

        Returns 1 for a win, 2 for a draw, 0 for an ordinary move and -1
        (leaving the game untouched) for an invalid column. After any valid
        move the side to move is flipped.
        """
        if not self.check_valid_move(col):
            return -1

        row = self.col_pos[col]
        self.update_board(col, row)
        self.moves += 1
        self.update_col(col)

        if self.check_win((col, row)):
            self.state = self.player
            outcome = 1
        elif self.check_full():
            self.state = 0
            outcome = 2
        else:
            outcome = 0
        self.player *= -1
        return outcome

    def display_board(self):
        """Print the board row by row, top row first."""
        for row in reversed(range(self.nrows)):
            print(self.display[row, :])

    def make_copy(self):
        """Return an independent deep copy of this game."""
        return copy.deepcopy(self)

    def normalize_board(self):
        """Re-express the position from the side to move's point of view.

        ``board`` (and ``state`` when the game is over) is multiplied by the
        current player and ``player`` becomes 1. ``display`` and ``encode``
        are left as they are.
        """
        sign = self.player
        self.board *= sign
        if self.state != 2:
            self.state *= sign
        self.player = sign * sign

    def apply_and_normalize(self, move):
        """Return a normalized copy with ``move`` played; self is unchanged."""
        successor = self.make_copy()
        successor.make_move(move)
        successor.normalize_board()
        return successor

In [None]:
# Smoke test of Game: X plays columns 0, 1, 3 while O stacks column 6.
game = Game()
game.make_move(0)
game.make_move(6)
game.make_move(1)
game.make_move(6)
game.make_move(3)
game.make_move(6)
# No four-in-a-row yet, so state is still 2 (not ended).
print(game.state)
# X fills column 2 and completes the bottom row 0-3.
game.make_move(2)
game.display_board()
print(game.state)
print(game.player)
# Normalizing multiplies board/state by the side to move and resets player
# to 1; the printable display is not changed by it.
game.normalize_board()
game.display_board()
print(game.state)
print(game.player)

2
[' ' ' ' ' ' ' ' ' ' ' ' ' ']
[' ' ' ' ' ' ' ' ' ' ' ' ' ']
[' ' ' ' ' ' ' ' ' ' ' ' ' ']
[' ' ' ' ' ' ' ' ' ' ' ' 'O']
[' ' ' ' ' ' ' ' ' ' ' ' 'O']
['X' 'X' 'X' 'X' ' ' ' ' 'O']
1
-1
[' ' ' ' ' ' ' ' ' ' ' ' ' ']
[' ' ' ' ' ' ' ' ' ' ' ' ' ']
[' ' ' ' ' ' ' ' ' ' ' ' ' ']
[' ' ' ' ' ' ' ' ' ' ' ' 'O']
[' ' ' ' ' ' ' ' ' ' ' ' 'O']
['X' 'X' 'X' 'X' ' ' ' ' 'O']
-1
1


In [None]:
# define network
import torch
import torch.nn as nn
import torch.nn.functional as F
# Use the first CUDA GPU when one is available, otherwise fall back to CPU.
# This module-level name is read by NNet.predict, the training loop and
# load_net further down.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

cuda:0


In [None]:
class NNet(nn.Module):
    """Convolutional policy/value network for a ``nrows`` x ``ncols`` board.

    ``forward`` maps a batch of boards to ``(v, p)``: ``v`` is a tanh value
    of shape ``(batch, 1)`` and ``p`` holds log-probabilities over the
    ``ncols`` columns, shape ``(batch, ncols)``.

    Args:
        nrows: board height.
        ncols: board width (also the size of the policy head).
        nres: stored on the instance; not used by the layers defined here.
        dropout: dropout probability applied after each hidden FC layer.
        step_size: learning rate of the Adam optimizer owned by the module.
    """

    def __init__(self, nrows=6, ncols=7, nres=2, dropout=0.3,
                 step_size=0.001):
        super(NNet, self).__init__()
        self.nrows = nrows
        self.ncols = ncols
        self.nres = nres
        self.dropout = dropout
        # Three padded 3x3 convolutions keep the board size ...
        self.conv = nn.Conv2d(1, 512, 3, padding=1)
        self.bn = nn.BatchNorm2d(512)
        self.conv1 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv2 = nn.Conv2d(512, 512, 3, padding=1)
        # ... and two unpadded 3x3 convolutions each trim 2 rows and 2
        # columns, which is what fc1's (nrows-4)*(ncols-4) input expects.
        # (torch.nn has no Conv4d, and Conv3d cannot take these 4-D
        # activations, so both layers must be Conv2d.)
        self.conv3 = nn.Conv2d(512, 512, 3)
        self.conv4 = nn.Conv2d(512, 512, 3)
        self.bn1 = nn.BatchNorm2d(512)
        self.bn2 = nn.BatchNorm2d(512)
        self.bn3 = nn.BatchNorm2d(512)
        self.bn4 = nn.BatchNorm2d(512)
        self.fc1 = nn.Linear(512*(nrows-4)*(ncols-4), 1024)
        self.fc1_bn = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, ncols)  # policy head
        self.fc4 = nn.Linear(512, 1)      # value head
        # Kept for compatibility; calc_loss below computes the loss directly.
        self.criterion1 = nn.MSELoss()
        self.criterion2 = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=step_size)

    def forward(self, x):
        """Return ``(v, p)`` for a batch of boards.

        ``x`` is reshaped to ``(-1, 1, nrows, ncols)``, so a single board or
        a stacked batch are both accepted.
        """
        # apply initial convolutional layer
        x = x.view(-1, 1, self.nrows, self.ncols)
        x = F.relu(self.bn(self.conv(x)))
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        # apply final output layers. Outputs value and policy
        x = x.view(-1, 512*(self.nrows-4)*(self.ncols-4))
        x = F.relu(self.fc1_bn(self.fc1(x)))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.fc2_bn(self.fc2(x)))
        x = F.dropout(x, p=self.dropout, training=self.training)
        p = F.log_softmax(self.fc3(x), dim=1)
        v = torch.tanh(self.fc4(x))
        return v, p

    def predict(self, game):
        """Evaluate one game in eval mode without gradients.

        Reads ``game.board`` and returns ``(value, probs)`` where ``value``
        is a scalar and ``probs`` is a numpy array of column probabilities
        (the exponentiated log-softmax output).
        """
        self.eval()
        inp = torch.from_numpy(game.board).float().to(device)
        with torch.no_grad():
            v, p = self(inp)

        value = v.detach().cpu().numpy()[0][0]
        probs = p.exp().detach().cpu().numpy()[0]
        return value, probs

    def calc_loss(self, v, v_targs, p, p_targs):
        """Return mean squared value error plus mean policy cross-entropy.

        ``p`` holds log-probabilities and ``p_targs`` target distributions,
        both ``(batch, ncols)``; ``v`` is ``(batch, 1)`` and ``v_targs`` is
        ``(batch,)``.
        """
        batch = v_targs.size()[0]
        v_loss = torch.sum((v.view(-1) - v_targs)**2) / batch
        # Sum of per-sample dot products <p_i, target_i>; equal to the trace
        # of p @ p_targs.T without building the batch x batch matrix.
        p_loss = -torch.sum(p * p_targs) / p_targs.size()[0]
        return v_loss + p_loss

    def get_loss(self, games, v_targs, p_targs):
        """Run ``forward`` on ``games`` and return the combined loss."""
        v, p = self.forward(games)
        return self.calc_loss(v, v_targs, p, p_targs)

    def run_grad(self, games, v_targs, p_targs):
        """Perform one optimizer step on a batch and return its loss tensor."""
        # Compute loss
        loss = self.get_loss(games, v_targs, p_targs)
        # Zero out gradients
        self.optimizer.zero_grad()
        # Compute gradients
        loss.backward()
        # Update parameters based on gradients
        self.optimizer.step()

        return loss

In [None]:
def train_one_epoch(net, training, batch_size=64, thresh=10):
    """Run one shuffled pass of mini-batch updates over ``training``.

    Args:
        net: model exposing ``train()`` and ``run_grad(games, v, p)``.
        training: list of ``(board, value_target, policy_target)`` tuples;
            it is shuffled in place.
        batch_size: number of examples per optimizer step.
        thresh: a batch with fewer than this many examples is skipped (and
            ends the epoch).

    Returns:
        The example-weighted mean loss as a Python float, or ``nan`` when no
        batch was large enough to train on.
    """
    net.train()
    np.random.shuffle(training)
    seen = 0
    tot_loss = 0.0
    for j in range(0, len(training), batch_size):
        batch = training[j:j+batch_size]
        B = len(batch)
        if B < thresh:
            break
        boards, v_targs, p_targs = zip(*batch)
        boards = np.stack(boards)
        p_targs = np.stack(p_targs)
        games = torch.FloatTensor(boards.astype(np.float64)).to(device)
        v_targs = torch.FloatTensor(np.array(v_targs).astype(np.float64)).to(device)
        p_targs = torch.FloatTensor(p_targs).to(device)
        loss = net.run_grad(games, v_targs, p_targs)
        # Accumulate a plain float so no tensors are kept alive across steps.
        tot_loss += float(loss) * B
        seen += B
    if seen == 0:
        # Nothing was trained on; avoid dividing by zero.
        return float('nan')
    return tot_loss / seen

In [None]:
def train_mult_epochs(net, training, num_epochs=20):
    """Train ``net`` on ``training`` for ``num_epochs`` epochs.

    Prints the loss of the final epoch. With ``num_epochs <= 0`` nothing is
    trained and nothing is printed.
    """
    loss = None
    for _ in range(num_epochs):
        loss = train_one_epoch(net, training)
    if loss is not None:
        print('Loss is %f' % (loss))

# Monte-Carlo Search Tree


In [None]:
class MCST:
    """Monte Carlo tree search guided by a policy/value network.

    Statistics are stored in dictionaries keyed by ``game.get_key()`` (per
    position) or by ``(key, move)`` pairs (per edge). Games passed in are
    expected to expose ``get_key``, ``valid_cols``, ``ncols``, ``state`` and
    ``apply_and_normalize``.

    Args:
        net: object with ``predict(game) -> (value, probs)``.
        cpuct: weight of the prior/visit-count exploration term.
        num_sims: number of ``search`` calls made per ``get_probs`` call.
    """

    def __init__(self, net, cpuct=1, num_sims=25):
        self.net = net
        self.cpuct = cpuct
        self.num_sims = num_sims
        self.Qs = {}          # (key, move) -> running mean value
        self.Ns = {}          # key -> visit count of the position
        self.N_pairs = {}     # (key, move) -> visit count of the edge
        self.masks = {}       # key -> 0/1 array of playable columns
        self.board_vals = {}  # key -> cached game.state
        self.policies = {}    # key -> masked, renormalized prior

    def is_expanded(self, game):
        """Tell whether ``expand`` has already been run for this position."""
        return game.get_key() in self.policies

    def expand(self, game):
        """Query the network for a new position and store its prior.

        The network policy is masked to playable columns and renormalized.
        Returns the network's value estimate for the position.
        """
        key = game.get_key()
        v, p = self.net.predict(game)
        mask = game.valid_cols.astype(int)
        self.masks[key] = mask
        p *= mask
        div = np.sum(p)
        if div == 0:
            print('Mask is all 0 in Expand')
            print('Backup to uniform distribution')
            # Spread the mass over playable columns only; if there are none,
            # fall back to all columns so the division stays defined.
            p = mask.astype(float) if mask.any() else np.ones(game.ncols)
            div = np.sum(p)
        p /= div
        self.policies[key] = p
        self.Ns[key] = 0
        return v

    def get_value(self, game):
        """Return ``game.state`` for this position, cached by board key."""
        key = game.get_key()
        if key not in self.board_vals:
            self.board_vals[key] = game.state
        return self.board_vals[key]

    def calc_U(self, game, move):
        """Return the selection score ``Q + cpuct * prior * sqrt(N) / (1 + n)``.

        Unplayable columns score ``-inf`` so that they can never tie with a
        legal move (legal scores can be as low as -1).
        """
        key = game.get_key()
        if self.masks[key][move] == 0:
            return -np.inf
        pair = (key, move)
        prob = self.policies[key][move]
        N = self.Ns[key] if self.Ns[key] != 0 else 1e-8  # small default value
        Q = self.Qs.get(pair, 0)
        N_pair = self.N_pairs.get(pair, 0)
        return Q + self.cpuct * prob * np.sqrt(N) / (1 + N_pair)

    def get_best_move(self, game):
        """Return the column with the highest ``calc_U`` score."""
        scores = [self.calc_U(game, move) for move in range(game.ncols)]
        return np.argmax(np.array(scores))

    def update_pair(self, game, move, v):
        """Fold ``v`` into the running mean and visit count of an edge."""
        pair = (game.get_key(), move)
        Q = self.Qs.get(pair, 0)
        N_pair = self.N_pairs.get(pair, 0)
        self.Qs[pair] = (Q * N_pair + v) / (1 + N_pair)
        self.N_pairs[pair] = N_pair + 1

    def search(self, game):
        """Run one simulation from ``game`` and back its value up the path.

        The return value is negated at every level because consecutive
        positions are seen from alternating players' points of view.
        """
        end_val = game.state
        if end_val != 2:
            # Finished game: report the stored result, sign-flipped.
            return -end_val

        if not self.is_expanded(game):
            # Leaf: evaluate with the network and stop descending.
            return -self.expand(game)

        best_move = self.get_best_move(game)
        next_game = game.apply_and_normalize(best_move)
        v = self.search(next_game)
        self.update_pair(game, best_move, v)

        self.Ns[game.get_key()] += 1

        return -v

    def get_probs(self, game, temp=1):
        """Run ``num_sims`` simulations and return a move distribution.

        With ``temp == 0`` the result is one-hot on a most-visited column
        (ties broken at random); any other ``temp`` yields visit counts
        normalized to sum to 1. If no column was visited at all, a uniform
        distribution over playable columns is returned instead of NaNs.
        """
        key = game.get_key()

        for _ in range(self.num_sims):
            self.search(game)

        counts = np.asarray(
            [self.N_pairs.get((key, move), 0) for move in range(game.ncols)]
        ).astype(float)

        if temp == 0:
            best_moves = np.array(np.argwhere(counts == np.max(counts))).flatten()
            best = np.random.choice(best_moves)
            probs = np.zeros(game.ncols)
            probs[best] = 1
            return probs

        div = np.sum(counts)
        if div == 0:
            print('All Moves 0 in Get Prob')
            # No visit statistics to normalize: fall back to the playable
            # columns (or to every column when none is playable).
            counts = np.asarray(game.valid_cols).astype(float)
            if np.sum(counts) == 0:
                counts = np.ones(game.ncols)
            div = np.sum(counts)

        return counts / div

    def get_move(self, game):
        """Return the column chosen greedily from the search results."""
        return np.argmax(self.get_probs(game, temp=0))

# Make Training Examples

In [None]:
def play_through_examples(net, thresh=15):
    """Self-play one game with ``net`` and return its training examples.

    Each example is ``(board, value_target, move_probs)``: the board is the
    normalized board before the move, ``move_probs`` the search
    distribution for it, and ``value_target`` the final ``game.state``
    multiplied by the player who was to move. The first ``thresh`` moves
    use ``temp=1``; later moves use ``temp=0``.
    """
    game = Game()
    tree = MCST(net)
    history = []
    mover = 1
    turn = 0
    result = 2

    while result == 2:
        view = game.make_copy()
        view.normalize_board()
        temperature = 0 if turn >= thresh else 1
        pi = tree.get_probs(view, temperature)
        history.append((view.board, pi, mover))
        # Sample the move actually played from the search distribution.
        chosen = np.random.choice(len(pi), p=pi)
        game.make_move(chosen)
        mover = -mover
        turn += 1
        result = game.state

    return [(board, result * who, pi) for board, pi, who in history]


In [None]:
# Smoke test: build a fresh network, move it to the selected device, play
# one self-play game and report how many training examples it produced.
nnet = NNet()
nnet.to(device)
train = play_through_examples(nnet)
print(len(train))


20


# Self Play


In [None]:
# tree 1 goes first
def play_single(tree1, tree2):
    """Play one game between two searchers and return the final state.

    ``tree1`` moves whenever ``game.player`` is 1 (so it moves first) and
    ``tree2`` otherwise. Each searcher is shown a normalized copy of the
    game.
    """
    game = Game()
    while game.state == 2:
        view = game.make_copy()
        view.normalize_board()
        searcher = tree2 if game.player != 1 else tree1
        game.make_move(searcher.get_move(view))
    return game.state

In [None]:
def play_mult_games(tree1, tree2, num_games=40):
    """Play two equal rounds, one with each searcher moving first.

    Each round has ``num_games // 2`` games. Returns
    ``(wins_1, wins_2, draws)`` where the win counters belong to ``tree1``
    and ``tree2`` respectively.
    """
    half = num_games // 2
    wins_1 = wins_2 = draws = 0

    # (first mover, second mover, whether tree1 is the first mover)
    rounds = ((tree1, tree2, True), (tree2, tree1, False))
    for first, second, tree1_first in rounds:
        for _ in range(half):
            outcome = play_single(first, second)
            if outcome == 1:
                first_mover_won = True
            elif outcome == -1:
                first_mover_won = False
            else:
                draws += 1
                continue
            if first_mover_won == tree1_first:
                wins_1 += 1
            else:
                wins_2 += 1

    return wins_1, wins_2, draws

In [None]:
# returns True if new net should replace prev net
def compare_nets(new_net, prev_net, thresh=0.6):
    """Pit two networks against each other and judge the new one.

    Prints the (new wins, previous wins, draws) tally. Returns True only
    when at least one game was decisive and the new network's share of
    decisive games is strictly greater than ``thresh``.
    """
    challenger = MCST(new_net)
    incumbent = MCST(prev_net)

    new_wins, prev_wins, draws = play_mult_games(challenger, incumbent)
    print(new_wins, prev_wins, draws)

    decisive = new_wins + prev_wins
    if decisive == 0:
        return False
    return new_wins / decisive > thresh


# Training

In [None]:
import _pickle as pickle
# save network
def save_net(net, path):
    """Write the ``state_dict`` of ``net`` to ``path`` with ``torch.save``."""
    weights = net.state_dict()
    torch.save(weights, path)

# load network
def load_net(path):
    """Build a default ``NNet`` on ``device`` and load weights from ``path``."""
    restored = NNet()
    restored.to(device)
    restored.load_state_dict(torch.load(path, map_location=device))
    return restored

def save_and_load(net, path):
    """Save ``net`` to ``path`` and return a new network loaded from it."""
    save_net(net, path)
    return load_net(path)

def save_prev_training(prev_training, path):
    """Pickle ``prev_training`` to ``path`` (protocol ``-1``)."""
    with open(path, 'wb') as sink:
        pickle.dump(prev_training, sink, protocol=-1)

def load_prev_training(path):
    """Return the object unpickled from ``path``.

    NOTE(review): unpickling can execute arbitrary code; only point this at
    files written by ``save_prev_training``.
    """
    with open(path, 'rb') as source:
        return pickle.load(source)

In [None]:
from collections import deque as deq
import time

def train_net(net, path, training_path, niters=100, neps=100,
              queue_cap=100000, prev_training_cap=20, start=0):
    """Run the self-play / train / evaluate loop and return the final net.

    Args:
        net: network to start from; it is saved to ``path`` immediately.
            Pass ``None`` to resume from the checkpoint at ``path`` and the
            pickled example history at ``training_path``.
        path: checkpoint file of the currently accepted network.
        training_path: pickle file holding the recent example history.
        niters: iteration index to stop at (exclusive upper bound).
        neps: self-play games generated per iteration.
        queue_cap: maximum number of examples kept per iteration.
        prev_training_cap: number of past iterations whose examples are kept.
        start: iteration index to begin from.
    """
    if net is not None:
        curr_net = net
        save_net(curr_net, path)
        prev_training = []
    else:
        curr_net = load_net(path)
        prev_training = load_prev_training(training_path)
        print(len(prev_training))

    for iteration in range(start, niters):
        print('Starting iter %d' % (iteration + 1))
        t_begin = time.time()

        # 1. Generate fresh self-play examples with the current network.
        iter_training = deq([], maxlen=queue_cap)
        for _ in range(neps):
            iter_training.extend(play_through_examples(curr_net))

        prev_training.append(iter_training)
        if len(prev_training) > prev_training_cap:
            prev_training.pop(0)  # drop the oldest iteration

        t_generated = time.time()
        print('Generating Training: %f seconds' % (t_generated - t_begin))

        save_prev_training(prev_training, training_path)

        # 2. Train on the examples of every retained iteration.
        training = [example for chunk in prev_training for example in chunk]

        # Snapshot of the last accepted weights, used as the opponent below.
        prev_net = load_net(path)

        train_mult_epochs(curr_net, training)

        t_trained = time.time()
        print('Training Model: %f seconds' % (t_trained - t_generated))

        # 3. Keep the trained network only if it beats the snapshot.
        accepted = compare_nets(curr_net, prev_net)

        t_compared = time.time()
        print('Final Comparison: %f seconds' % (t_compared - t_trained))
        if accepted:
            print('update model')
            save_net(curr_net, path)
        else:
            print('reject model')
            curr_net = prev_net

        print('Total Iteration Time: %f seconds\n' % (t_compared - t_begin))

    save_net(curr_net, path)
    return curr_net


In [None]:
# Locations on the mounted Google Drive: `path` is the network checkpoint
# file and `train_path` the pickled self-play example history.
datadir='/ME/My Drive/Colab Notebooks/'
path=datadir + 'models/c4ai'
train_path = datadir + 'training/c4train'

In [None]:
# To train from scratch instead, build a new network and pass it as the
# first argument:
#nnet = NNet()
#nnet.to(device)

# Passing None resumes from the checkpoint and example history on disk;
# start=25 makes the loop begin at iteration index 25.
c4ai = train_net(None, path, train_path,start=25)

# Play Against Bot


In [None]:
def represents_int(s):
    """Tell whether ``int(s)`` succeeds without raising ``ValueError``."""
    try:
        int(s)
    except ValueError:
        return False
    return True

def player_move(game):
    """Prompt on stdin until the user enters a playable column.

    The user types a 1-based column number; the returned value is the
    corresponding 0-based column, validated with ``game.check_valid_move``.
    The prompt range follows ``game.ncols`` so it stays correct for boards
    that are not 7 columns wide.
    """
    while True:
        print('Enter a valid column from 1-%d:' % game.ncols)
        a = input()
        if not represents_int(a):
            continue
        inp = int(a) - 1  # convert to 0-based column index
        if game.check_valid_move(inp):
            return inp


def play_against_net(net, mcst_sims=50, player=1):
    """Play an interactive game between a human and ``net``.

    Args:
        net: network driving the search opponent.
        mcst_sims: simulations the search runs per move.
        player: the human's side, ``1`` (moves first) or ``-1``.
    """
    game = Game()
    bot = MCST(net, num_sims=mcst_sims)
    while game.state == 2:
        if game.player != player:
            # Bot's turn: it searches on a normalized copy of the game.
            view = game.make_copy()
            view.normalize_board()
            move = bot.get_move(view)
        else:
            game.display_board()
            move = player_move(game)
        game.make_move(move)

    game.display_board()
    outcome = game.state
    if outcome == player:
        verdict = 'You Win!'
    elif outcome == -player:
        verdict = 'You Lose!'
    else:
        verdict = 'Draw.'
    print(verdict)

In [None]:
# Load the saved checkpoint and start an interactive game against it with
# the default settings (the human plays as player 1).
c4ai = load_net(path)

play_against_net(c4ai)