In [236]:
import numpy as np
import matplotlib.pyplot as plt

In [237]:
class ENV:
      def __init__(self, size: tuple, terminal_states: list):
            
            self.grid = np.zeros(size)
            self.terminal_states = terminal_states
            self.actions = ['up', 'down', 'left', 'right']
            self.agent_position = [3,0]

      def make_move(self, state, action):
            if state in self.terminal_states:
                  return state
            if action == 'left':
                  self.agent_position = [state[0], max(0, state[1]-1)]
                  return [state[0], max(0, state[1]-1)]
            elif action == 'right':
                  self.agent_position = [state[0], min(self.grid.shape[1]-1, state[1]+1)]
                  return [state[0], min(self.grid.shape[1]-1, state[1]+1)]
            elif action == 'up':
                  self.agent_position = [max(state[0]-1, 0), state[1]]
                  return [max(state[0]-1, 0), state[1]]
            elif action == 'down':
                  self.agent_position = [min(self.grid.shape[0]-1, state[0]+1), state[1]]
                  return [min(self.grid.shape[0]-1, state[0]+1), state[1]]

      def get_reward(self, state, action):
            if state in self.terminal_states:
                  return 0, state
            new_state = self.make_move(state, action)
            return -1, new_state

      def display(self, agent = False):
            for i in range(self.grid.shape[0]):
                  temp = []
                  for j in range(self.grid.shape[1]):
                        
                        if [i,j] in self.terminal_states:
                              if agent and self.agent_position == [i,j]:
                                    temp.append('X*')
                              else:
                                    temp.append('X')
                              
                        else :
                              if agent and self.agent_position == [i,j]:
                                    temp.append('*')
                              else:
                                    temp.append('-')
                  print(' '.join(_ for _ in temp))
            


      
            


In [238]:
env = ENV((4,4), [[0,0], [3,3]])

In [239]:
env.display(True)

X - - -
- - - -
- - - -
* - - X


In [240]:
env.get_reward(env.agent_position, 'up')[0]

-1

In [241]:
class Policy:
      def __init__(self, env:ENV):
            self.actions = ['up', 'down', 'left', 'right']
            self.grid = env.grid
            self.V = env.grid
            self.terminal_states = env.terminal_states
            self.policy = {i:{j:{a:0.25 for a in self.actions} if [i,j] not in self.terminal_states else {a: 0 for a in self.actions} for j in range(self.grid.shape[1])  } for i in range(self.grid.shape[0])}
            self.env =  env

      def evaluate_policy(self, eps = 1e-2):
            delta = 0
            old_V = self.V
            err = 1e3
            while err > eps:
                  for i in range(self.grid.shape[0]):
                        for j in range(self.grid.shape[1]):
                              s_prime_reward = [self.env.get_reward([i,j], a) for a in self.env.actions]
                              self.V[i,j] = sum(self.policy[i][j][a]*(r_p + self.V[s_p[0], s_p[1]]) for (a, (r_p, s_p)) in zip(self.env.actions, s_prime_reward))
                  delta = max(delta, abs(old_V.sum() - self.V.sum()))
                  err = delta
                  print(abs(old_V.sum() - self.V.sum()))

      

In [242]:
pi = Policy(env)

In [243]:
pi.V

array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])

In [244]:
pi.policy[0][2].keys()

dict_keys(['up', 'down', 'left', 'right'])

In [245]:
pi.evaluate_policy()

0.0


In [246]:
pi.V

array([[ 0.       , -1.       , -1.25     , -1.3125   ],
       [-1.       , -1.5      , -1.6875   , -1.75     ],
       [-1.25     , -1.6875   , -1.84375  , -1.8984375],
       [-1.3125   , -1.75     , -1.8984375,  0.       ]])