In [37]:
import numpy as np
import matplotlib.pyplot as plt
import random
from tqdm import tqdm

In [38]:
class MDP:
    def __init__(self):
        self.size = 5
        self.a = (0,1)
        self.b = (0,3)
        self.a_prime = (4,1)
        self.b_prime = (2,3)
        self.actions = [(0,1), (1,0), (0,-1), (-1,0)]
        
    def step(self, state, action):
        if(state == self.a):
            return self.a_prime, 10
        if(state == self.b):
            return self.b_prime, 5
        next_state = tuple(np.array(state) + np.array(action))
        if(next_state[0] < 0 or next_state[1] < 0 or next_state[0] >= self.size or next_state[1] >= self.size):
            return state, -1
        return next_state, 0
        
    def figure32(self):
        m = np.zeros((self.size, self.size), dtype=int)
        c = 0
        for i in range(self.size):
            for j in range(self.size):
                m[i,j] = c
                c+=1
            
        A = np.zeros((self.size**2, self.size**2))
        b = np.zeros(self.size**2)
        
        for i in range(self.size):
            for j in range(self.size):
                state = (i,j)
                for a in range(len(self.actions)):
                    (next_i, next_j), reward = self.step(state, self.actions[a])
#                     print(m[i,j], m[next_i, next_j])
                    A[m[i,j], m[next_i, next_j]] += 0.25*0.9
                    b[m[i,j]] += 0.25*reward
                A[m[i,j], m[i,j]] = A[m[i,j], m[i,j]]-1
        
        X = np.linalg.solve(A,b)
        X = np.round(X, 1)
        for i in range(self.size):
            for j in range(self.size):
                print(X[m[i,j]], end=" ")
            print()
            
class Iterations:
    def __init__(self):
        self.size = 4
        self.terminal = [(0,0), (3,3)]
        self.actions = [(0,1), (1,0), (0,-1), (-1,0)]
        self.policy = np.ones((self.size**2, len(self.actions)))/4
        
    def step(self, state, action):
        if(state in self.terminal):
            return state, 0
        next_state = tuple(np.array(state) + np.array(action))
        if(next_state[0] < 0 or next_state[1] < 0 or next_state[0] >= self.size or next_state[1] >= self.size):
            return state, -1
        return next_state, -1
        
    def policy_iteration(self):
        m = np.zeros((self.size, self.size), dtype=int)
        c = 0
        for i in range(self.size):
            for j in range(self.size):
                m[i,j] = c
                c+=1
                
        value_func = None
        while(True):
            value_func = np.zeros((self.size, self.size))
            #evaluation
            while(True):
                delta = 0
                new_v = np.copy(value_func)
                for i in range(self.size):
                    for j in range(self.size):
                        state = (i,j)
                        value = 0
                        for a in range(len(self.actions)):
                            next_state, reward = self.step(state, self.actions[a])
                            value += self.policy[m[i,j],a]*(reward + value_func[next_state[0], next_state[1]])
                        delta = max(delta, abs(value-value_func[i,j]))
                        value_func[i,j] = value
#                 print(value_func)
#                 print(delta)
#                 value_func = new_v
                if(delta < 1e-4):
                    break
                
#             print(value_func)
            #improvement
            policy_stable = True
            new_policy = []
            for i in range(self.size):
                for j in range(self.size):
                    state = (i,j)
                    action_values = np.zeros(len(self.actions))
                    for a in range(len(self.actions)):
                        next_state, reward = self.step(state, self.actions[a])
                        action_values[a] += (reward + value_func[next_state[0], next_state[1]])
                    
                    old_action = np.argmax(self.policy[m[i,j]])
                    new_action = np.argmax(action_values)
                    
                    if(old_action != new_action):
                        policy_stable = False
                    
                    temp = np.zeros(len(self.actions), dtype=int)
                    temp[new_action] = 1
                    self.policy[m[i,j]] = temp
                    
            if(policy_stable):
                break
                
        print(self.policy)
        print(value_func)
                

In [39]:
mdp = MDP()
mdp.figure32()

-3.3 -8.8 -4.4 -5.3 -1.5 
-1.5 -3.0 -2.3 -1.9 -0.5 
-0.1 -0.7 -0.7 -0.4 0.4 
1.0 0.4 0.4 0.6 1.2 
1.9 1.3 1.2 1.4 2.0 


In [40]:
it = Iterations()
it.policy_iteration()

[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
[[ 0. -1. -1. -1.]
 [-1. -1. -1. -1.]
 [-1. -1. -1. -1.]
 [-1. -1. -1.  0.]]
[[ 0.   -1.75 -2.   -2.  ]
 [-1.75 -2.   -2.   -2.  ]
 [-2.   -2.   -2.   -1.75]
 [-2.   -2.   -1.75  0.  ]]
[[ 0.     -2.4375 -2.9375 -3.    ]
 [-2.4375 -2.875  -3.     -2.9375]
 [-2.9375 -3.     -2.875  -2.4375]
 [-3.     -2.9375 -2.4375  0.    ]]
[[ 0.      -3.0625  -3.84375 -3.96875]
 [-3.0625  -3.71875 -3.90625 -3.84375]
 [-3.84375 -3.90625 -3.71875 -3.0625 ]
 [-3.96875 -3.84375 -3.0625   0.     ]]
[[ 0.        -3.65625   -4.6953125 -4.90625  ]
 [-3.65625   -4.484375  -4.78125   -4.6953125]
 [-4.6953125 -4.78125   -4.484375  -3.65625  ]
 [-4.90625   -4.6953125 -3.65625    0.       ]]
[[ 0.         -4.20898438 -5.50976562 -5.80078125]
 [-4.20898438 -5.21875    -5.58984375 -5.50976562]
 [-5.50976562 -5.58984375 -5.21875    -4.20898438]
 [-5.80078125 -5.50976562 -4.20898438  0.        ]]
[[ 0.         -4.734375   -6.27734375 -6.65527344]
 [-4.73437

 [-21.73110604 -19.75971425 -13.83784416   0.        ]]
[[  0.         -13.84647022 -19.77249651 -21.74541014]
 [-13.84647022 -17.79958288 -19.77401836 -19.77249651]
 [-19.77249651 -19.77401836 -17.79958288 -13.84647022]
 [-21.74541014 -19.77249651 -13.84647022   0.        ]]
[[  0.         -13.8546374  -19.78459881 -21.75895333]
 [-13.8546374  -17.81024429 -19.7860397  -19.78459881]
 [-19.78459881 -19.7860397  -17.81024429 -13.8546374 ]
 [-21.75895333 -19.78459881 -13.8546374    0.        ]]
[[  0.         -13.86237012 -19.79605731 -21.77177607]
 [-13.86237012 -17.82033855 -19.79742155 -19.79605731]
 [-19.79605731 -19.79742155 -17.82033855 -13.86237012]
 [-21.77177607 -19.79605731 -13.86237012   0.        ]]
[[  0.         -13.8696915  -19.80690626 -21.78391669]
 [-13.8696915  -17.82989584 -19.80819793 -19.80690626]
 [-19.80690626 -19.80819793 -17.82989584 -13.8696915 ]
 [-21.78391669 -19.80690626 -13.8696915    0.        ]]
[[  0.         -13.8766234  -19.81717809 -21.79541148]
 [-13

 [-21.99814115 -19.99833891 -13.99887902   0.        ]]
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
[[ 0. -1. -1. -1.]
 [-1. -1. -1. -1.]
 [-1. -1. -1. -1.]
 [-1. -1. -1.  0.]]
[[ 0. -1. -2. -2.]
 [-1. -2. -2. -2.]
 [-2. -2. -2. -1.]
 [-2. -2. -1.  0.]]
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
[[ 0. -1. -1. -1.]
 [-1. -1. -1. -1.]
 [-1. -1. -1. -1.]
 [-1. -1. -1.  0.]]
[[ 0. -1. -2. -2.]
 [-1. -2. -2. -2.]
 [-2. -2. -2. -1.]
 [-2. -2. -1.  0.]]
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
[[1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 1. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]]
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
