In [125]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import random

In [126]:
#utitlity functions
def isterminal(state, size):
    if(state == (0,0) or state == (size-1,size-1)):
        return True
    return False

In [127]:
class DP:
    def __init__(self, size, actions, policy, discount):
        self.size = size
        self.actions = actions
        self.discount = discount
        self.policy = policy
        
    def getNextState(self, state, action):
        if(isterminal(state, self.size)):
            return state, 0
        else:
            next_state = tuple(np.array(state) + np.array(action))
            if(next_state[0] < 4 and next_state[1] < 4):
                return next_state, -1
            else:
                return state, -1
        
    def action_values(self, state, values):
        expected_values = np.zeros(len(self.actions))
        for i in range(len(self.actions)):
            [new_state, reward] = self.getNextState(state, self.actions[i])
            expected_values[i] += (reward + self.discount*values[new_state[0], new_state[1]])
        return expected_values
        
    def policy_eval(self):
        updated_values = np.zeros((self.size, self.size))
        delta = 1000
        
        while(delta > 1e-5):
            delta = 0
            for i in range(self.size):
                for j in range(self.size):
                    value = 0
                    state = (i,j)
                    for a in range(len(self.actions)):
                        [new_state, reward] = self.getNextState(state, self.actions[a])
                        value += self.policy[state][a]*(reward + self.discount*updated_values[new_state[0], new_state[1]])
                    delta = max(delta, np.abs(value - updated_values[i, j]))
                    updated_values[i, j] = value
                    
        return updated_values
    
    def policy_iter(self):
        stable = True
        
        while(True):
            v = self.policy_eval()
            stable = True
            
            for i in range(self.size):
                for j in range(self.size):
                    state = (i,j)
                    best_action = np.argmax(self.policy[state])
                    
                    new_best_action = np.argmax(self.action_values(state, v))
                    
                    if(best_action != new_best_action):
                        stable = False
                    temp = np.zeros(4)
                    temp[new_best_action] = 1
                    self.policy[state] = np.eye(len(self.actions))[best_action]
                    print(v)
            if(stable):
                return self.policy, v  
        

In [128]:
fake_policy = {}

for i in range(4):
    for j in range(4):
        fake_policy[(i,j)] = np.ones(4)/4
        
dp = DP(4, [(0,-1),(0,1),(1,0),(-1,0)], fake_policy, 1)

policy, v = dp.policy_iter()

print("policy distribution")
for i in range(4):
    for j in range(4):
        print((i,j), policy[(i,j)])
print()

m = [[0 for i in range(4)] for j in range(4)]

c = 0
for i in range(4):
    for j in range(4):
        m[i][j] = c
        c+=1
        
print(m)

new_policy = np.zeros((16,4))
for i in range(4):
    for j in range(4):
        new_policy[m[i][j]] = np.array(policy[i,j])
print()
        
print(new_policy)
print()

print(np.reshape(np.argmax(new_policy, axis=1), (4,4)))
print()

print(np.round(v))
print()



[[  0.          -9.76649827 -11.16727335  -8.93381999]
 [ -9.76649827 -12.2645443  -12.76259949 -11.63420089]
 [-11.16727335 -12.76259949 -11.98439532  -9.20619685]
 [ -8.93381999 -11.63420089  -9.20619685   0.        ]]
[[  0.          -9.76649827 -11.16727335  -8.93381999]
 [ -9.76649827 -12.2645443  -12.76259949 -11.63420089]
 [-11.16727335 -12.76259949 -11.98439532  -9.20619685]
 [ -8.93381999 -11.63420089  -9.20619685   0.        ]]
[[  0.          -9.76649827 -11.16727335  -8.93381999]
 [ -9.76649827 -12.2645443  -12.76259949 -11.63420089]
 [-11.16727335 -12.76259949 -11.98439532  -9.20619685]
 [ -8.93381999 -11.63420089  -9.20619685   0.        ]]
[[  0.          -9.76649827 -11.16727335  -8.93381999]
 [ -9.76649827 -12.2645443  -12.76259949 -11.63420089]
 [-11.16727335 -12.76259949 -11.98439532  -9.20619685]
 [ -8.93381999 -11.63420089  -9.20619685   0.        ]]
[[  0.          -9.76649827 -11.16727335  -8.93381999]
 [ -9.76649827 -12.2645443  -12.76259949 -11.63420089]
 [-11.

KeyboardInterrupt: 