In [85]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import random

In [86]:
#utitlity functions
def isterminal(state, size):
    if(state == (0,0) or state == (size-1,size-1)):
        return True
    return False

In [87]:
class DP:
    def __init__(self, size, actions, policy, discount):
        self.size = size
        self.actions = actions
        self.discount = discount
        self.policy = policy
        
    def getNextState(self, state, action):
        if(isterminal(state, self.size)):
            return state, 0
        else:
            next_state = tuple(np.array(state) + np.array(action))
            if(next_state[0] < self.size and next_state[1] < self.size and next_state[0] > 0 and next_state[1] > 0):
                return next_state, -1
            else:
                return state, -1
        
    def policy_iteration(self):
        
        new_v =  np.zeros((self.size, self.size))
        while(True):
            #policy evaluation
            new_v = np.zeros((self.size,self.size))
            while(True):
                delta = 0
                for i in range(self.size):
                    for j in range(self.size):
                        state = (i,j)
                        v = 0
                        for a in range(len(self.actions)):
                            new_state, reward = self.getNextState(state, self.actions[a])
                            v += self.policy[state][a]*(reward + self.discount*new_v[new_state[0], new_state[1]]) 
#                         print(state, v)
                        
                        delta = max(delta, np.abs(v - new_v[i,j]))
                        new_v[i,j] = v
                    
#                 for i in range(self.size):
#                     for j in range(self.size):
#                         print(new_v[i,j], end=" ")
#                     print()
#                 print()
                if(delta < 1e-4):
                    break

#             print("Pnefn")
            #policy improvement
            policy_stable = True
            for i in range(self.size):
                for j in range(self.size):
                    action_values = np.zeros(len(self.actions))
                    state = (i,j)
                    for a in range(len(self.actions)):
                        new_state , reward = self.getNextState(state, self.actions[a])
                        action_values[a] += (reward + self.discount*new_v[new_state[0], new_state[1]])
                    old_action = np.argmax(self.policy[state])
                    new_action = np.argmax(action_values)
                    
                    if(old_action != new_action):
                        policy_stable = False
                    self.policy[state] = np.eye(4)[new_action]
                        
#             for i in self.policy:
#                 print(self.policy[i])
#             print()
            
            if(policy_stable):
                break
        return self.policy, new_v
                        
                

In [88]:
fake_policy = {}

for i in range(4):
    for j in range(4):
        fake_policy[(i,j)] = np.ones(4)/4
        
dp = DP(4, [(-1,0),(0,1),(1,0),(0,-1)], fake_policy, 1)

for i in fake_policy:
    print(fake_policy[i])
print()

policy, v = dp.policy_iteration()

print("policy distribution")
for i in range(4):
    for j in range(4):
        print((i,j), policy[(i,j)])
print()

temp = []
for i in range(4):
    for j in range(4):
        print(np.argmax(policy[(i,j)]), end=" ")
    print()

print(np.round(v))
print()



[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]

policy distribution
(0, 0) [1. 0. 0. 0.]
(0, 1) [0. 0. 1. 0.]
(0, 2) [0. 0. 1. 0.]
(0, 3) [0. 0. 1. 0.]
(1, 0) [0. 1. 0. 0.]
(1, 1) [0. 1. 0. 0.]
(1, 2) [0. 1. 0. 0.]
(1, 3) [0. 0. 1. 0.]
(2, 0) [0. 1. 0. 0.]
(2, 1) [0. 1. 0. 0.]
(2, 2) [0. 1. 0. 0.]
(2, 3) [0. 0. 1. 0.]
(3, 0) [0. 1. 0. 0.]
(3, 1) [0. 1. 0. 0.]
(3, 2) [0. 1. 0. 0.]
(3, 3) [1. 0. 0. 0.]

0 2 2 2 
1 1 1 2 
1 1 1 2 
1 1 1 0 
[[ 0. -5. -4. -3.]
 [-5. -4. -3. -2.]
 [-4. -3. -2. -1.]
 [-3. -2. -1.  0.]]

