In [1]:
import pandas
import numpy as np
import torch
import matplotlib.pyplot as plt
import random
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [3]:
import connect_four

In [69]:
def get_pairs():
    tmp_pairs = []
    for i in range(6):
        for j in range(7):
            tmp_pairs.append([(i, j), (i+1,j), (i+2,j), (i+3,j)])
            tmp_pairs.append([(i, j), (i+1,j+1), (i+2,j+2), (i+3,j+3)])
            tmp_pairs.append([(i, j), (i,j+1), (i,j+2), (i,j+3)])
            tmp_pairs.append([(i, j), (i-1,j+1), (i-2,j+2), (i-3,j+3)])
    pairs_filt = []
    for plist in tmp_pairs:
        plist = [(x, y) for x, y in plist if 0 <= x < 6 and 0 <= y < 7]
        if len(plist) == 4:
            pairs_filt.append(([p[0] for p in plist], [p[1] for p in plist]))
    # # Checks
    # print(len(tmp_pairs))
    # print(len(pairs_filt))
    # print(3 * 7 + 4 * 6 + 3*4 + 3*4)
    return pairs_filt

In [520]:
pairs = get_pairs()
test_state = np.zeros([6,7], dtype=np.int32)
test_state[5,3] = 1
test_state[5,2] = -1
test_state[4,3] = 1
test_state[3,3] = 1
# test_state[2,3] = 1
test_state[0,0] = 1
test_state[1,0] = -1
test_state[2,0] = 1
test_state[3,0] = -1
test_state[4,0] = 1
test_state[5,0] = -1

def start_state():
    return torch.zeros([1,1,6,7])

def gameEnded(state):
    state = state.squeeze()
    for plist in pairs:
        val = state[plist[0],plist[1]].sum()
        if val == 4:
            return 1
        elif val == -4:
            return -1
    if len(getValidActions(state)) == 0:
        return 0
    return 0

def getValidActions(state):
    val = torch.nonzero(((state.squeeze() != 0).sum(axis=0) < 6)).squeeze().tolist()
    if isinstance(val, int):
        return [val]
    else:
        return val

# State is always oriented to it being red's turn
def nextState(state, action):
    state = state.squeeze().clone()
    position = torch.argwhere(state[:,action] != 0)
    
    if len(position) == 0:
        state[5,action] = 1
        return state.reshape(1, 1, 6, 7)
    
    position = position.min()
    if position == 0:
        raise ValueError
    else:
        state[position-1, action] = 1
        return state.reshape(1, 1, 6, 7)

def to_rep(state):
    state = state + 1
    res = ''.join(state.to(torch.int).numpy().flatten().astype(str))
    return res

CF = ConnectFour()

In [521]:
class ConnectNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=1, padding=0),
            torch.nn.ReLU(),
        )
        self.fc = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,4), stride=1, padding=0),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear(64, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 8)
        )
    
    def forward(self, x):
        output = self.layer1(x)
        output = self.fc(output)
        return torch.tanh(output[:,0]), torch.nn.functional.softmax(output[:,1:], dim=1)

CN = ConnectNet()

In [522]:
c = 1

class MCTS:
    def __init__(self):
        self.visited = set()
        self.P = {}
        self.Q = {}
        self.N = {}
    
    def search(self, state, nnet):
        ge = gameEnded(state)
        if ge != 0:
            return -ge

        rep = to_rep(state)
        if rep not in self.visited:
            # print(f"NEW STATE {rep}")
            # print(state)
            self.visited.add(rep)
            v, probA = nnet(state)
            self.P[rep] = probA
            self.N[rep] = torch.zeros(7)
            self.Q[rep] = torch.zeros(7)
            return -v
        else:
            # print(f"OLD STATE {rep}")
            # print(state)
            max_u, best_a = -torch.inf, -1
            for a in getValidActions(state):
                u_explore = self.N[rep].sum().sqrt() / (1+self.N[rep][a])
                u = self.Q[rep][a] + c*self.P[rep][0,a]*u_explore
                if u > max_u:
                    max_u = u
                    best_a = a
            a = best_a

            # print(f"action {a}")
            sp = nextState(state, a)
            # print("NEXT STATE")
            # print(sp)
            v = self.search(sp * -1, nnet)

            self.Q[rep][a] = (self.N[rep][a] + v)/(self.N[rep][a] + 1)
            self.N[rep][a] += 1
            return -v
    
    def pi(self, state):
        rep = to_rep(state)
        return self.N[rep] / (self.N[rep].sum() + 1e-3)

In [523]:
# test_state = torch.zeros([1, 1, 6, 7], dtype=torch.float32)
# test_state[0,0] = -1
test_state_tensor = torch.Tensor(test_state).reshape(1, 1, 6, 7)

print(test_state_tensor)
# print(test_state_tensor)
# ''.join(test_state_tensor.to(torch.int).numpy().flatten().astype(str))
print(to_rep(test_state_tensor))
print(getValidActions(test_state_tensor))
print(gameEnded(test_state_tensor))
test_v, test_probA = CN.forward(test_state_tensor)
print(test_v)
print(test_probA)

tensor([[[[ 1.,  0.,  0.,  0.,  0.,  0.,  0.],
          [-1.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 1.,  0.,  0.,  0.,  0.,  0.,  0.],
          [-1.,  0.,  0.,  1.,  0.,  0.,  0.],
          [ 1.,  0.,  0.,  1.,  0.,  0.,  0.],
          [-1.,  0., -1.,  1.,  0.,  0.,  0.]]]])
211111101111112111111011211121121110102111
[1, 2, 3, 4, 5, 6]
0
tensor([0.0409], grad_fn=<TanhBackward0>)
tensor([[0.1452, 0.1596, 0.1442, 0.1531, 0.1525, 0.1199, 0.1256]],
       grad_fn=<SoftmaxBackward0>)


In [524]:
(torch.zeros(1))

tensor([0.])

In [527]:
numMCTSsims = 15
numIters = 100
numEps = 10
learning_rate = 0.001

def policyIterSP():
    nnet = ConnectNet()
    optimizer = torch.optim.Adam(nnet.parameters(), lr=learning_rate)
    
    examples = []
    for i in range(numIters):
        print(i)
        for e in range(numEps):
            print(e)
            examples += executeEpisode(nnet) 
        trainNNet(nnet, examples, optimizer=optimizer)

def trainNNet(nnet, examples, optimizer):
    l = torch.zeros(1)
    for e in examples:
        v, proba = nnet(e[0])
        l += torch.square(v - e[2]) - torch.dot(e[1], torch.log(proba.squeeze()))
    
    optimizer.zero_grad()
    l.backward()
    optimizer.step()
        

def executeEpisode(nnet):
    examples = []
    mcts = MCTS()           
    state = start_state()
    while True:
        for i in range(numMCTSsims):
            mcts.search(state, nnet)
        pi_val = mcts.pi(state)
        examples.append([state, pi_val, None])
        a = np.argmax(np.random.multinomial(1, pvals=pi_val))
        state = nextState(state, a)
        ge = gameEnded(state)
        if ge != 0:
            return assignRewards(examples, ge)
        state = -state

def assignRewards(examples, reward):
    assigned_examples = []
    for l in reversed(examples):
        l[2] = reward
        reward = -reward
        assigned_examples.append(l)
    return assigned_examples
    

In [528]:
policyIterSP()

0
0
1
2
3
4
5
6
7
8
9
1
0
1
2
3
4
5
6
7
8
9
2
0
1
2
3
4
5
6
7
8
9
3
0
1
2
3
4
5
6
7
8
9
4
0
1
2
3
4
5
6
7
8
9
5
0
1
2
3
4
5
6
7
8
9
6
0
1
2
3
4
5
6
7
8
9
7
0
1
2
3
4
5
6
7
8
9
8
0
1
2
3
4
5
6
7
8
9
9
0
1
2
3
4
5
6
7
8
9
10
0
1
2
3
4
5
6
7
8
9
11
0
1
2
3
4
5
6
7
8
9
12
0
1
2
3
4
5
6
7
8
9
13
0
1
2
3
4
5
6
7
8
9
14
0
1
2
3
4
5
6
7
8
9
15
0
1
2
3
4
5
6
7
8
9
16
0
1
2
3
4
5
6
7
8
9
17
0
1
2
3
4
5
6
7
8


ValueError: 