In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from collections import Counter
import pickle,random

In [None]:
%matplotlib inline

In [None]:
def game_over(test,flag):
    """Report whether the board has reached a terminal position.

    Parameters
    ----------
    test : np.ndarray
        Square board whose cells are 1, -1 or 0 (blank).
    flag : str
        Current status; returned unchanged when the game is not over.

    Returns
    -------
    str
        'game over' when the board has no blank cell left or when at least
        one row, column or diagonal sums to 3 or -3; otherwise `flag`.
    """
    # Sums of every column, every row, the main diagonal and the anti-diagonal.
    line_sums = np.concatenate([
        test.sum(axis=0),
        test.sum(axis=1),
        [np.trace(test), np.trace(np.fliplr(test))],
    ])

    board_full = not (test == 0).any()
    # "At least one" completed line: a single move can complete two lines at
    # once, which an exact `== 1` count would fail to recognise as a win.
    has_winner = bool((np.abs(line_sums) == 3).any())

    if board_full or has_winner:
        return 'game over'
    return flag

def check_valid_state(test):
    """Return 1 when the board passes the line-count sanity check, else 0.

    A board is rejected when more than two lines (rows, columns, diagonals)
    sum to -3, or when more than one line sums to 3.
    """
    # Column sums, row sums, then the two diagonals.
    lines = np.concatenate([
        test.sum(axis=0),
        test.sum(axis=1),
        [np.trace(test), np.trace(np.fliplr(test))],
    ])
    minus_lines = int((lines == -3).sum())
    plus_lines = int((lines == 3).sum())
    return int(minus_lines <= 2 and plus_lines <= 1)
    
# Assign proper probabilities to winning and losing states
def assign_prob(test,prob):
    """Return the value to store for a board.

    Gives 1 when any row, column or diagonal sums to 3, otherwise 0 when any
    of them sums to -3, otherwise the supplied `prob` unchanged. The check
    for 3 takes precedence when both kinds of line are present.
    """
    # Column sums, row sums, then the two diagonals.
    lines = np.concatenate([
        test.sum(axis=0),
        test.sum(axis=1),
        [np.trace(test), np.trace(np.fliplr(test))],
    ])
    if (lines == 3).any():
        return 1
    if (lines == -3).any():
        return 0
    return prob

### Value function based approach

1. Set up a table of numbers, one for each possible state: 3**9 = 19,683 board configurations.
2. Remove invalid states.
3. Assign each state an initial probability of winning from that state, i.e. its VALUE. The whole table is the value function.
4. To select our move, we consider all states reachable from the current state and choose the one with the highest value (greedy), but sometimes we choose randomly instead (exploration).
5. While playing, we update the values of the states we visit to make them more accurate estimates of the winning probability.
6. The current value of the earlier state is adjusted to be closer to the value of the later state, i.e.
    if we let s denote the state before the greedy move and s′ the state after the move, then the update to the estimated value of s, denoted V(s), can be written as
                    V(s) = V(s) + a(V(s') - V(s))  , where a is the step size (the temporal-difference learning rate, `alpha` in the code)

In [None]:
# Create all states
# Every one of the K*N cells is -1, 1 or 0, giving 3**(K*N) candidate boards.
K=3
N=3
all_states = pd.DataFrame({
    'state': [np.reshape(np.array(cells), (K, N)) for cells in product([-1, 1, 0], repeat=K*N)],
})

# Every board starts with a neutral estimate of 0.5.
all_states['probability'] = 1/2

# Flag boards that fail the line-count sanity check (1 = keep, 0 = drop).
all_states['valid'] = all_states['state'].apply(check_valid_state)

val_states = all_states[all_states['valid']==1].reset_index(drop=True)
val_states['num_blanks'] = val_states['state'].apply(lambda x: x[x==0].shape[0])

# assign_prob needs both the board and its current estimate, so it has to be
# applied row-wise; calling it with the board alone raises a TypeError.
val_states['probability'] = val_states.apply(lambda x: assign_prob(x['state'], x['probability']), axis=1)

In [None]:
# Resume from a previously saved value table when one exists. On a first run
# the file has not been written yet, so keep the table built in the cell above
# instead of failing with FileNotFoundError.
# NOTE: unpickling can execute arbitrary code -- only load a file you created yourself.
try:
    val_states = pd.read_pickle('./tic_tac_toe_rand_agnt_policy.pkl')
except FileNotFoundError:
    print('No saved policy found; starting from the freshly initialised value table.')

# Play the game

In [None]:
## Playing against a random opponent
# Trains the value table in `val_states` by playing 5000 games in which the
# agent (mark 1) faces an opponent that picks a blank cell uniformly at random
# (mark -1). After every move, by either player, the stored value of the
# state before the move is nudged towards the value of the state after it.
#
# NOTE(review): neither `random` nor `np.random` is seeded in this cell, so two
# runs will generally not produce the same table.
# NOTE(review): every state lookup below scans the whole `val_states['state']`
# column with `.apply`, which looks like the dominant cost of the loop.

# Step size of the update V(s) += alpha * (V(s') - V(s)).
alpha = 0.2

# Marks written on the board by the agent and by the opponent.
ai_mv = 1
opp_mv = -1
flag = 'continue'

for k in range(5000):
    flag = 'continue'
    explore_exploit_flag = 'exploit'
    # Move counter: odd values take the opponent branch, even values the agent
    # branch, so the opponent makes the first move of every game.
    i=1
    initial_state = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
    # The pickle round trip is used throughout as a deep copy, so later edits to
    # the working copy never touch the arrays stored inside `val_states`.
    cur_state = pickle.loads(pickle.dumps(val_states.loc[val_states['state'].apply(lambda x: (x==initial_state).all())].squeeze()))
    print(f'#################### Match {k} ####################')
    while flag=='continue':
        if i%2==0:
            # Agent's turn: try the agent's mark in every blank cell and collect
            # the matching rows of the value table as candidate next states.
            test = pickle.loads(pickle.dumps(cur_state['state']))
            pos_states = []
            for row,col in np.argwhere(test==0):
                test[row,col] = ai_mv
                st = pickle.loads(pickle.dumps(val_states[val_states['state'].apply(lambda x: (x==test).all())]))
                pos_states.append(st)
                # Undo the trial move before trying the next blank cell.
                test[row,col]=0
            pos_states = pd.concat(pos_states)
            # Explore with probability 0.4, otherwise exploit.
            explore_exploit_flag = np.random.choice(['exploit','explore'],p=[0.6,0.4])
            if explore_exploit_flag=='explore':
                # Exploration is not uniform: candidates are drawn with weights
                # equal to their current value.
                next_stg = pickle.loads(pickle.dumps(pos_states.sample(1,weights=pos_states['probability']).squeeze()))            
            else:
                # Greedy choice; ties on the maximum value are broken at random.
                next_stg = pickle.loads(pickle.dumps(pos_states[pos_states['probability']==pos_states['probability'].max()].sample(1).squeeze()))
            
    
        else:
            # Opponent's turn. The commented-out input() calls would let a human
            # type the row and column instead of the random choice below.
            #r = int(input())
            #c = int(input())
            next_stg = pickle.loads(pickle.dumps(cur_state))
            r,c = random.choice(np.argwhere(next_stg['state']==0))
            next_stg['state'][r,c]=opp_mv
            # Replace the edited copy with the table row for the resulting board
            # so that `next_stg['probability']` is that board's stored value.
            next_stg = pickle.loads(pickle.dumps(val_states.loc[val_states['state'].apply(lambda x: (x==next_stg['state']).all())].squeeze()))


            
        next_prob = next_stg['probability']
        # update
        # Temporal-difference step on the state the move was made from; it runs
        # after agent moves (greedy or exploratory) and after opponent moves.
        # NOTE(review): the trailing `.copy()` on the index label looks redundant -- confirm.
        cur_index = val_states.loc[val_states['state'].apply(lambda x: (x==cur_state['state']).all())].index[0].copy()
        cur_prob = val_states.loc[cur_index,'probability']
        update = alpha*(next_prob-cur_prob)
        val_states.loc[cur_index,'probability']+=update        
            
        cur_state = pickle.loads(pickle.dumps(next_stg))
        # Optional board visualisation, left disabled.
        #print('-------------------------')
            
        #fig, ax = plt.subplots()
        #Using matshow here just because it sets the ticks up nicely. imshow is faster.
        #ax.matshow(cur_state['state'], cmap='Oranges')

        #for (p, k), z in np.ndenumerate(cur_state['state']):
        #    ax.text(k,p, '{:0.1f}'.format(z), ha='center', va='center')

        #plt.show()
        
        
        i+=1
        flag = game_over(next_stg['state'],flag)
        
        # The while condition would also stop the loop; the break just makes
        # the exit explicit.
        if flag=='game over':
            break
        


In [None]:
# Re-pin decided boards after training: boards with a line summing to 3 get 1,
# boards with a line summing to -3 get 0, and all others keep their learned value.
val_states['probability'] = [
    assign_prob(board, value)
    for board, value in zip(val_states['state'], val_states['probability'])
]

In [None]:
# Persist the value table so a later session can reload it.
val_states.to_pickle(path='./tic_tac_toe_rand_agnt_policy.pkl')

In [None]:
# Look up a hand-picked position and list every board the agent (mark `ai_mv`)
# can reach from it in one move, together with the stored value of each.
initial_state = np.array([
    [ 0,  0, 0],
    [-1, -1, 0],
    [ 0,  0, 1],
])
cur_state = pickle.loads(pickle.dumps(
    val_states.loc[
        val_states['state'].apply(lambda board: (board == initial_state).all())
    ].squeeze()
))

# Work on a private copy of the board so the trial moves never touch the table.
test = cur_state['state'].copy()
pos_states = []
for row, col in np.argwhere(test == 0):
    test[row, col] = ai_mv
    st = pickle.loads(pickle.dumps(
        val_states[val_states['state'].apply(lambda board: (board == test).all())]
    ))
    pos_states.append(st)
    # Restore the blank before trying the next cell.
    test[row, col] = 0
pos_states = pd.concat(pos_states)

In [None]:
# Display the candidate successor rows collected in `pos_states`.
pos_states