In [167]:
import numpy as np
import pickle

ALL_STATE_OBJS = []
ALL_STATE_OBJ_H = []
num_rows, num_cols = 3 , 3 
bsize = num_rows * num_cols



class State:
    def __init__(self):
        """ 
        Board is represented by a n * n array, 1 represents the player who moves first, -1 represents chessman of another player
        # 0 represents empty position
        """
        self.data = np.zeros((num_rows, num_cols))
        self.winner = None
        self.hash_value = None
        self.end = None
    def is_beginning(self):
        if np.sum(np.abs(self.data)) == 0:
            return True
        return False
    
    def num_chances_played(self):
        return np.sum(np.abs(self.data))
        
    def next_symbol(self):
        # Get the player which needs to play the next chance.
        # if sum is 1 this means player 1 has played his turn.
        if np.sum(self.data) == 0:
            return -1
        else:
            return 1
        
    def __hash__(self):
        return self.reduce_val
    
    def reduce_val(self):
        if self.hash_value is None:
            self.hash_value = 0
            for i in self.data.reshape(num_rows * num_cols):
                if i == -1:
                    i = 2
                self.hash_value = self.hash_value * 3 + i
        return int(self.hash_value)

    
    def is_terminal(self):
        if self.end is not None:
            # In already caclulated for the state return the value 
            return self.end

        # Check if all the state values are occupied.   
        if np.sum(np.abs(self.data)) == num_rows * num_cols:
            self.winner = None
            self.end = True
            return self.end    
        
        results = []
        cntd2, cntd1 = 0, 0
        for i in range(0, num_rows):
            results.append(np.sum(self.data[i, :]))
            results.append(np.sum(self.data[:, i]))
            cntd1 += self.data[i, i]
            cntd2 += self.data[i, num_rows - 1 - i]
        results += [cntd1, cntd2]        

        for result in results:
            if result == 3:
                self.winner = 1
                self.end = True
                return self.end
            if result == -3:
                self.winner = -1
                self.end = True
                return self.end
        self.end = False
        return self.end

    def get_next_state(self, i, j, symbol):
        newState = State()
        newState.data = np.copy(self.data)
        newState.data[i, j] = symbol
        return newState

    # print the board
    def print_state(self):
        for i in range(0, num_rows):
            print('-------------')
            out = '| '
            for j in range(0, num_cols):
                if self.data[i, j] == 1:
                    token = 'O'
                if self.data[i, j] == 0:
                    token = f'{num_cols*i+j+1}'
                if self.data[i, j] == -1:
                    token = 'X'
                out += token + ' | '
            print(out)
        print('-------------')



        
        



In [168]:

def dfs_states(currentState, currentSymbol, allStates):
    for i in range(0, num_rows):
        for j in range(0, num_cols):
            if currentState.data[i][j] == 0:
                newState = currentState.get_next_state(i, j, currentSymbol)
                newHash = newState.reduce_val()
                if newHash not in allStates.keys():
#                     if np.sum(newState.data)<0:
#                         ALL_STATE_OBJS.append(newState)
                    ALL_STATE_OBJS.append(newState)
                    ALL_STATE_OBJ_H.append(newState.reduce_val())
    
                    is_terminal = newState.is_terminal()
                    allStates[newHash] = (newState, is_terminal)
                    if not is_terminal:
                        dfs_states(newState, -currentSymbol, allStates)

def find_all_states():
    currentSymbol = -1
    currentState = State()
    allStates = dict()
    allStates[currentState.reduce_val()] = (currentState, currentState.is_terminal())
    dfs_states(currentState, currentSymbol, allStates)
    return allStates


In [169]:
_ = find_all_states()
for st in ALL_STATE_OBJS:
    assert np.sum(st.data) <= 0

In [170]:
from collections import defaultdict
next_states = {}
                        
def find_next_states():
    #import pdb;pdb.set_trace()
    for state in ALL_STATE_OBJS:
        next_symbol = state.next_symbol()
        pos_states = []
        for i in range(num_rows):
            for j in range(num_cols):
                if state.data[i][j]==0 and not state.is_terminal():
                    st1 = state.get_next_state(i, j, next_symbol)
                    pos_states.append(st1)         
        next_states[state.reduce_val()] = pos_states
        
find_next_states()        


In [171]:
st = State()
st.data = np.array([[-1,1,-1], [1,-1,-1], [1,-1,1]])
print(st.is_terminal())
print(st.winner is None)
st.print_state()
state_val[st.reduce_val()]

True
True
-------------
| X | O | X | 
-------------
| O | X | X | 
-------------
| O | X | O | 
-------------


0

In [196]:
def calc_V_values(reward_non_terminal = 1):
    # value iteration
    state_val = defaultdict(lambda:1)
    for state in ALL_STATE_OBJS:
        st_k = state.reduce_val()
        if state.is_terminal():
            if state.winner is None:
                state_val[st_k] = 0 
            else:
                state_val[st_k] = 10 if state.winner == 1 else -10
    # Calculate V values 
    delta = 0.001
    cnt = 0
    gamma = 0.9

    for it in range(100):
        cnt += 1
        state_vals_up = {}

        for st in ALL_STATE_OBJS:

            if st.is_terminal():
                state_vals_up[st.reduce_val()] = state_val[st.reduce_val()]
                continue

            if np.sum(st.data) == 0:
                # Intermediate state, no need to calculate
                assert st.num_chances_played()%2 == 0
                continue

            hash_st = st.reduce_val()
            maxx_expectation = -1000
            next_sts = next_states[hash_st]

            for st1 in next_sts:
                assert np.sum(st1.data)==0

                next_sts_p = next_states[st1.reduce_val()]
                len_next_states = len(next_sts_p)

                expectation = 0

                if st1.is_terminal():
                    assert len_next_states == 0
                    expectation = state_val[st1.reduce_val()]
                else:    
                    for state in next_sts_p:
                        assert np.sum(state.data)==-1
                        expectation += (1/len_next_states) * state_val[state.reduce_val()]

                if expectation > maxx_expectation:
                    maxx_expectation = expectation

            state_vals_up[hash_st] = reward_non_terminal + gamma * maxx_expectation


        maxx_diff = -100
        for h, v in state_val.items():
            maxx_diff = max(maxx_diff, abs(state_vals_up[h] - v))
        if maxx_diff < delta:
            #print(f"converged in {cnt}")
            break
        state_val = state_vals_up   
    return state_val    

       

In [197]:
state_val = calc_V_values() 

In [198]:
st = State()
st.data = np.array([[-1,0,0], [1,-1,-1], [1,0,0]])
st.print_state()
state_val[st.reduce_val()]

-------------
| X | 2 | 3 | 
-------------
| O | X | X | 
-------------
| O | 8 | 9 | 
-------------


7.299999999999999

In [175]:
def get_next_mdp_states(st):
    next_mdp_states = []
    assert np.sum(st.data) == -1
    hash_st = st.reduce_val()
    next_sts = next_states[hash_st]    
    for st1 in next_sts:
        assert np.sum(st1.data)==0
        next_sts_p = next_states[st1.reduce_val()]
        len_next_states = len(next_sts_p)
        if st1.is_terminal():
            assert len_next_states == 0
            next_mdp_states.append((st1, st1))
        for state in next_sts_p:
            assert np.sum(state.data)==-1
            next_mdp_states.append((st1, state))
    return next_mdp_states

## 1. (5 terminal states of MDP)

In [176]:
cnt = 0
print("TERMINAL WINNING STATES FOR X")
for state in ALL_STATE_OBJS:
    # X -> -1, O -> 1.
    if state.is_terminal() and state.winner==1:
        cnt += 1
        state.print_state()
        if cnt == 5:
            break

TERMINAL WINNING STATES FOR X
-------------
| X | O | X | 
-------------
| O | O | X | 
-------------
| X | O | 9 | 
-------------
-------------
| X | O | X | 
-------------
| O | O | O | 
-------------
| X | X | 9 | 
-------------
-------------
| X | O | X | 
-------------
| O | O | O | 
-------------
| X | 8 | X | 
-------------
-------------
| X | O | X | 
-------------
| O | O | 6 | 
-------------
| X | O | X | 
-------------
-------------
| X | O | X | 
-------------
| O | O | O | 
-------------
| 7 | X | X | 
-------------


## 2,3. Value to Go and V values for the 5 different trajectories

In [177]:
trajectories = []
for state in ALL_STATE_OBJS:
    #import pdb;pdb.set_trace()
    if len(trajectories) == 5:
        break
    if state.num_chances_played() == 1:
        trajectory = [state]
        #import pdb;pdb.set_trace()
        while True:
            #pdb.set_trace()
            states = next_states[state.reduce_val()]
            #pdb.set_trace()
            for st in states:
                assert st.reduce_val() in ALL_STATE_OBJ_H
            state = states[np.random.randint(len(states))]
            trajectory.append(state)  
            if state.is_terminal():
                trajectories.append(trajectory)
                break

In [178]:
val_to_go_trajectories = []
for i, tr in enumerate (trajectories):
    pref_sum = [state_val[tr[-1].reduce_val()]]
    for st in tr[::-1][1:]:
        if st.num_chances_played()%2 == 0:
            continue
        st_v = 1 if not st.is_terminal() else state_val[st.reduce_val()]
        pref_sum.append(pref_sum[-1]*gamma + st_v)
    val_to_go_trajectories.append(pref_sum[::-1])    

    
    
for tr, v_to_go, i in zip(trajectories, val_to_go_trajectories, range(len(trajectories))):
    print("##"*50)
    print(f"\n Trajectory number {i+1} is:")
    cnt = 0
    for st in tr:        
        if st.num_chances_played() % 2:
            st.print_state()            
            print("Value for above state is {}".format(state_val[st.reduce_val()]))
            print("Value to go above state is {}".format(v_to_go[cnt]))
            print("*"*50)
            cnt += 1
        else:
            if st.is_terminal():
                st.print_state()
            #print("Intermediate stage")

####################################################################################################

 Trajectory number 1 is:
-------------
| X | 2 | 3 | 
-------------
| 4 | 5 | 6 | 
-------------
| 7 | 8 | 9 | 
-------------
Value for above state is 9.5626
Value to go above state is -4.58
**************************************************
-------------
| X | 2 | 3 | 
-------------
| X | 5 | 6 | 
-------------
| 7 | O | 9 | 
-------------
Value for above state is 10.0
Value to go above state is -6.2
**************************************************
-------------
| X | 2 | X | 
-------------
| X | O | 6 | 
-------------
| 7 | O | 9 | 
-------------
Value for above state is 10.0
Value to go above state is -8.0
**************************************************
-------------
| X | X | X | 
-------------
| X | O | 6 | 
-------------
| 7 | O | O | 
-------------
Value for above state is -10
Value to go above state is -10
**************************************************
################

## 4 .Get Q values for all the possible actions for 2 States

In [200]:
           
def get_all_q_values(st, state_val, reward_non_terminal=1):
        print("##"*50)
        print("Original State is")
        st.print_state()
        for i in range(num_rows):
            for j in range(num_cols):
                if st.data[i][j] == 0:
                    #import pdb;pdb.set_trace()
                    action = (i, j, st.next_symbol())
                    reward = reward_non_terminal if not st.is_terminal() else 10 * st.winner
                    
                    nst = st.get_next_state(*action)
                    qv = 0
                    if nst.is_terminal():
                        if nst.winner is not None:
                            qv = 10 * nst.winner
                    
                    else:
                        nxt_sts = next_states[nst.reduce_val()]

                        for state in nxt_sts:
                            qv += state_val[state.reduce_val()]/ len(nxt_sts)
                            
                    qv = qv* gamma + reward
                    
                    print(f"For action {action[0]*3+action[1]+1} on original state Q value is {qv}")   
                    #nst.print_state()
                    print("*"*30)
                    
cnt  = 0 
sts5 = []
for st in ALL_STATE_OBJS:
    if st.num_chances_played()==5:
        sts5.append(st)

                            

## Q values for all the possible actions from 2 MDP states. 

In [201]:
for i in range(2):
    import random
    idx = random.randint(0,len(sts3))
    get_all_q_values(sts5[idx], state_val)


####################################################################################################
Original State is
-------------
| 1 | O | 3 | 
-------------
| 4 | X | X | 
-------------
| X | O | 9 | 
-------------
For action 1 on original state Q value is -2.0
******************************
For action 3 on original state Q value is 1.3000000000000003
******************************
For action 4 on original state Q value is -1.4
******************************
For action 9 on original state Q value is -4.700000000000001
******************************
####################################################################################################
Original State is
-------------
| X | 2 | X | 
-------------
| 4 | O | X | 
-------------
| 7 | O | 9 | 
-------------
For action 2 on original state Q value is 10.0
******************************
For action 4 on original state Q value is -2.0
******************************
For action 7 on original state Q value is -2.0
*****************

## 5.  Change Reward function (+4 and -4)

In [194]:
q_game_state = State()
q_game_state.data = np.array([[-1, 0, -1], [1, 1, 0], [-1, 0, 0]])
state_val = calc_V_values(reward_non_terminal=4)

get_all_q_values(q_game_state, state_val)

converged in 5
####################################################################################################
Original State is
-------------
| X | 2 | X | 
-------------
| O | O | 6 | 
-------------
| X | 8 | 9 | 
-------------
For action 2 on original state Q value is 12.700000000000001
******************************
For action 6 on original state Q value is 10.0
******************************
For action 8 on original state Q value is 5.799999999999999
******************************
For action 9 on original state Q value is 3.0999999999999996
******************************


In [195]:
q_game_state = State()
q_game_state.data = np.array([[-1, 0, -1], [1, 1, 0], [-1, 0, 0]])
state_val = calc_V_values(reward_non_terminal=4)

get_all_q_values(q_game_state, state_val, reward_non_terminal=4)

converged in 5
####################################################################################################
Original State is
-------------
| X | 2 | X | 
-------------
| O | O | 6 | 
-------------
| X | 8 | 9 | 
-------------
For action 2 on original state Q value is 15.700000000000001
******************************
For action 6 on original state Q value is 13.0
******************************
For action 8 on original state Q value is 8.799999999999999
******************************
For action 9 on original state Q value is 6.1
******************************


In [189]:
# Negative non terminal reqard state 
state_val = calc_V_values(reward_non_terminal=-4)
q_game_state = State()
q_game_state.data = np.array([[-1, 0, -1], [1, 1, 0], [-1, 0, 0]])
get_all_q_values(q_game_state, state_val, reward_non_terminal=-4)

converged in 5
####################################################################################################
Original State is
-------------
| X | 2 | X | 
-------------
| O | O | 6 | 
-------------
| X | 8 | 9 | 
-------------
For action 2 on original state Q value is 0.5
******************************
For action 6 on original state Q value is 5.0
******************************
For action 8 on original state Q value is -4.0
******************************
For action 9 on original state Q value is -6.7
******************************
