In [53]:
import numpy as np
import copy
from collections import defaultdict

WINNING_POSITIONS = [(0,1,2), (3,4,5), (6,7,8), (0,3,6), (1,4,7), (2,5,8), (0,4,8),(2,4,6)]
gamma = 0.9
X_PLAYER, O_PLAYER = 1, 2
total_states = 0

def check_win(state):
    for a, b, c in WINNING_POSITIONS:
        # player - 1- 'X' wins, R(s) = -10
        if state[a] == state[b] == state[c] and state[a] == 'x':
            return -10
        # player - 2- 'O' wins, R(s) = 10
        elif state[a] == state[b] == state[c] and state[a] == 'o':
            return 10

    # draw
    if len(valid_moves(state)) == 0:
        return 0
    # non terminal state
    else:
        return 
        
def valid_moves(state):
    s = np.array(list(state))
    return np.asarray(np.where(s == '_')).flatten()
 
    
initial_state = '_________'
states = defaultdict()
root_states = []


def generate_child_states(parent):
    
    available_moves = valid_moves(parent)
    
    for move in available_moves:
        s = list(parent)
        s[move] = 'o'
        s = ''.join(s)
        
        reward = check_win(s)
        if reward != 1:
            states[s] = reward
            continue
        
        for position in valid_moves(s):
            new_parent = list(s)
            new_parent[position] = 'x'
            new_parent = ''.join(new_parent)
            
            r = check_win(new_parent)
            
            states[new_parent] = r
            
            if r != 1:
                continue
            else:
                generate_child_states(new_parent)        
    return
            
            
# Generate roots
def generate_initial_states():
    
    for i in range(9):
        s = list(initial_state)
        s[i] = 'x'
        s = ''.join(s)
        root_states.append(s)
        
        states[s] = check_win(s)
        
        # Generate child states
        generate_child_states(s)
        

generate_initial_states()

In [66]:
print(len(states))

states_keys = list(states.keys())
probability_table = np.zeros((9, len(states), len(states)))
rewards_list = np.zeros(len(states))

3055


In [67]:
def transition_probabilities(parent):
    available_moves = valid_moves(parent)
    parent_index = states_keys.index(parent)
    
    for move in available_moves:
        s = list(parent)
        s[move] = 'o'
        s = ''.join(s)
        
        if check_win(s) != 1:
            rewards_list[states_keys.index(s)] = states[s]
            probability_table[move, parent_index, states_keys.index(s)] = 1 # double check this
            continue
        
        for position in valid_moves(s):
            new_parent = list(s)
            new_parent[position] = 'x'
            new_parent = ''.join(new_parent)
            
            probability_table[move, parent_index, states_keys.index(new_parent)] = 1/len(valid_moves(s)) # double check this 
            rewards_list[states_keys.index(new_parent)] = states[new_parent]
            
            
            if check_win(new_parent) != 1:
                continue
            else:
                transition_probabilities(new_parent)        
    return

In [86]:
for root in root_states:
    rewards_list[states_keys.index(root)] = states[root]
    transition_probabilities(root)
    
print(len(rewards_list))

3055


In [120]:
# Printing trajectories

def print_grid(state):
    st = list(state)
    print(' {:1} | {:1} | {:1}'.format(state[0],state[1],state[2]))
    print('-----------')
    print(' {:1} | {:1} | {:1}'.format(state[3],state[4],state[5]))
    print('-----------')
    print(' {:1} | {:1} | {:1}'.format(state[6],state[7],state[8]))


rtg_list = []
    
def print_trajectory(parent):
    
    print_grid(parent)
    
    available_moves = valid_moves(parent)
    choice = np.random.choice(available_moves)
    s = list(parent)
    s[choice] = 'o'
    s = ''.join(s)
    
    reward = check_win(s)
    if reward != 1:
        print_grid(s)
        rtg_list.append(states[s])
        return states[s]
        
        
    next_x_moves = valid_moves(s)
    choice = np.random.choice(next_x_moves)
    new_parent = list(s)
    new_parent[choice] = 'x'
    new_parent = ''.join(new_parent)
    
    r = check_win(new_parent)

    if r != 1:
        print_grid(new_parent)
        rtg_list.append(states[new_parent])
        return states[new_parent]
    else:
        rtg = states[new_parent] + gamma * print_trajectory(new_parent)
        rtg_list.append(rtg)
        return rtg

    
    
random_seeds = [0, 10, 20, 30, 40]
for i in range(len(random_seeds)):
    np.random.seed(random_seeds[i])
    print('Reward to go for state1 is ' + str(states[root_states[i]] + gamma * print_trajectory(root_states[i])))
    
    for j in range(len(rtg_list)):
        print('Reward to go for state' + str(j+2) + ' is ' + str(rtg_list[len(rtg_list) - j - 1]))
    
    print('--------------------------------------------------------------------')
    
    rtg_list = []

 x | _ | _
-----------
 _ | _ | _
-----------
 _ | _ | _


 x | _ | _
-----------
 _ | _ | o
-----------
 _ | x | _


 x | o | _
-----------
 _ | _ | o
-----------
 x | x | _


 x | o | _
-----------
 x | _ | o
-----------
 x | x | o


Reward to go for state1 is -4.58
Reward to go for state2 is -6.2
Reward to go for state3 is -8.0
Reward to go for state4 is -10
--------------------------------------------------------------------
 _ | x | _
-----------
 _ | _ | _
-----------
 _ | _ | _


 _ | x | o
-----------
 _ | _ | _
-----------
 _ | x | _


 x | x | o
-----------
 _ | _ | _
-----------
 o | x | _


 x | x | o
-----------
 _ | o | _
-----------
 o | x | _


Reward to go for state1 is 10.0
Reward to go for state2 is 10.0
Reward to go for state3 is 10.0
Reward to go for state4 is 10
--------------------------------------------------------------------
 _ | _ | x
-----------
 _ | _ | _
-----------
 _ | _ | _


 _ | _ | x
-----------
 x | o | _
-----------
 _ | _ | _


 _ | _ | x
-------

In [102]:
import math

values = np.zeros(len(states))
actions = [0 for i in range(len(states))]


def check_for_convergence(old_vals):
    diff = values - old_vals
    flag = 0
    for d in diff:
        if abs(d) > 0.1:
            flag = 1
    
    if flag == 0:
        return True;
    else:
        return False;


while True:
    old_vals = copy.deepcopy(values)
    
    for i in range(len(states)):
        win = check_win(states_keys[i])
        if win != 1:
            values[i] = rewards_list[i]
            
        maximum = -math.inf
        action = 0
        
        for j in range(9):
            new_val = rewards_list[i] + gamma * (np.dot(probability_table[j, i, :], values))
            
            if new_val > maximum:
                maximum = new_val
                action = j
                
        values[i] = maximum
        actions[i] = action
        
    if check_for_convergence(old_vals) == False:
        continue;
    else:
        break;


In [126]:
def print_value_trajectory(parent):
    
    print_grid(parent)
    print('Value of the above state is ' + str(values[states_keys.index(parent)]))
    
    available_moves = valid_moves(parent)
    choice = np.random.choice(available_moves)
    s = list(parent)
    s[choice] = 'o'
    s = ''.join(s)
    
    reward = check_win(s)
    if reward != 1:
        print_grid(s)
        print('Value of the above state is ' + str(values[states_keys.index(s)]))
        return states[s]
        
        
    next_x_moves = valid_moves(s)
    choice = np.random.choice(next_x_moves)
    new_parent = list(s)
    new_parent[choice] = 'x'
    new_parent = ''.join(new_parent)
    
    r = check_win(new_parent)

    if r != 1:
        print_grid(new_parent)
        print('Value of the above state is ' + str(values[states_keys.index(new_parent)]))
        return 
    else:
        print_value_trajectory(new_parent)
        return 

    
    
random_seeds = [0, 10, 20, 30, 40]
for i in range(len(random_seeds)):
    np.random.seed(random_seeds[i])
    print_value_trajectory(root_states[i])
    print('-----------------------------------------------------------------')

 x | _ | _
-----------
 _ | _ | _
-----------
 _ | _ | _


Value of the above state is 9.562600000000002
 x | _ | _
-----------
 _ | _ | o
-----------
 _ | x | _


Value of the above state is 9.514000000000001
 x | o | _
-----------
 _ | _ | o
-----------
 x | x | _


Value of the above state is 1.3000000000000003
 x | o | _
-----------
 x | _ | o
-----------
 x | x | o


Value of the above state is -10.0
-----------------------------------------------------------------
 _ | x | _
-----------
 _ | _ | _
-----------
 _ | _ | _


Value of the above state is 9.736171428571428
 _ | x | o
-----------
 _ | _ | _
-----------
 _ | x | _


Value of the above state is 10.0
 x | x | o
-----------
 _ | _ | _
-----------
 o | x | _


Value of the above state is 10.0
 x | x | o
-----------
 _ | o | _
-----------
 o | x | _


Value of the above state is 10.0
-----------------------------------------------------------------
 _ | _ | x
-----------
 _ | _ | _
-----------
 _ | _ | _


Value of the above 

In [133]:
def print_Q_states(parent):
    
    index = states_keys.index(parent)
    #print('Q Value of the above state is ' + str(rewards_list[index] + gamma * values[index]))
    
    available_moves = valid_moves(parent)
    
    for move in available_moves:
        s = list(parent)
        s[move] = 'o'
        s = ''.join(s)

        reward = check_win(s)
        if reward != 1:
            print('Q - Value of the above state is ' + str(values[states_keys.index(s)]))
            continue
        
        total = 0
        
        for position in valid_moves(s):
                
            new_parent = list(s)
            new_parent[position] = 'x'
            new_parent = ''.join(new_parent)
            
            r = check_win(new_parent)
            
            
            total += (probability_table[move, index, states_keys.index(new_parent)] * values[states_keys.index(new_parent)])
            
        total *= gamma 
        total += rewards_list[states_keys.index(parent)]
        
        print(total)



for i in range(3,5):
    print_Q_states(root_states[i])
    print('-----------------------------------------------------------------')

9.187685714285713
8.958571428571428
9.736171428571428
9.625085714285715
8.410085714285714
9.187685714285713
8.958571428571428
9.736171428571428
-----------------------------------------------------------------
8.312885714285715
6.924314285714286
8.312885714285713
6.924314285714285
6.924314285714286
8.312885714285713
6.924314285714286
8.312885714285713
-----------------------------------------------------------------
