In [0]:
import random
from operator import add
# NOTE(review): not referenced anywhere in the visible code; presumably the
# convergence threshold for a solver defined elsewhere -- confirm.
MIN_DELTA = 1e-4

class GridMarkovDP(object):
    """Grid-world Markov decision process with noisy single-tile moves.

    States are ``(row, col)`` tuples, one per grid cell that is not an
    obstacle.  Actions are the four unit offsets in ``self.actions``.  A
    move that would leave the state space (off the grid or into an
    obstacle) leaves the agent where it is.
    """

    def __init__(self, metadata):
        """Build the MDP from a metadata mapping.

        Keys read from ``metadata``:
            width, height: grid dimensions (columns, rows).
            initial_value: starting value for every grid cell and state.
            obstacles: iterable of ``(row, col)`` cells to remove.
            living_cost: reward returned by ``R`` for non-terminal states.
            discount: stored as ``self.discount``; not used in this class.
            transition_distribution: mapping with the keys ``'forward'``,
                ``'left'``, ``'right'`` and ``'backward'``.
            terminals: iterable of mappings with ``'state'`` and
                ``'reward'`` entries.
        """
        self.width = metadata['width']
        self.height = metadata['height']
        self.initial_value = metadata['initial_value']
        self.obstacles = metadata['obstacles']
        self.living_cost = metadata['living_cost']

        self.discount = metadata['discount']
        self.transition_distribution = metadata['transition_distribution']
        self.rewards = {
            tuple(terminal['state']): terminal['reward']
            for terminal in metadata['terminals']
        }
        self.terminals = list(self.rewards.keys())

        self._init_grid()

        # Enumerate the state space: every cell that is not an obstacle.
        self.states = set()
        for row in range(self.height):
            for col in range(self.width):
                if self.grid[row][col] is not None:
                    self.states.add((row, col))

        # Move one tile at a time.  The list order matters: neighbouring
        # entries are used as the "left"/"right" deviations of an action
        # and the entry two positions away as "backward".
        self.actions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
        self.num_actions = len(self.actions)

        self._reset_values_and_policy()

    def R(self, state):
        """Return the terminal reward for ``state``, else the living cost."""
        if state in self.terminals:
            return self.rewards[state]
        return self.living_cost

    def _init_grid(self):
        """(Re)build ``self.grid`` and mark obstacle cells with ``None``."""
        self.grid = [
            [self.initial_value for _ in range(self.width)]
            for _ in range(self.height)
        ]
        for obstacle in self.obstacles:
            self.grid[obstacle[0]][obstacle[1]] = None

    def _reset_values_and_policy(self):
        """Bind fresh value and policy dicts covering every state.

        New dicts are created rather than updating the current ones in
        place: ``update_values`` / ``update_policy`` store the caller's
        object by reference, so an in-place reset would overwrite data the
        caller may still be holding.
        """
        values = {}
        policy = {}
        for state in self.states:
            values[state] = self.initial_value
            policy[state] = random.choice(self.actions)
        self.values = values
        self.policy = policy

    def _rotate(self, action, steps):
        """Return the action ``steps`` positions after ``action`` in
        ``self.actions``, wrapping around.

        Raises ``ValueError`` if ``action`` is not in ``self.actions``.
        """
        index = self.actions.index(action)
        return self.actions[(index + steps) % self.num_actions]

    def _apply(self, state, action):
        """Return ``state + action``, or ``state`` if that is not a state."""
        new_state = tuple(map(add, state, action))
        return new_state if new_state in self.states else state

    def _move_forward(self, state, action):
        """Result of moving from ``state`` exactly along ``action``."""
        return self._apply(state, action)

    def _move_backward(self, state, action):
        """Result of moving along the action two positions after ``action``."""
        return self._apply(state, self._rotate(action, 2))

    def _move_left(self, state, action):
        """Result of moving along the action listed just before ``action``."""
        return self._apply(state, self._rotate(action, -1))

    def _move_right(self, state, action):
        """Result of moving along the action listed just after ``action``."""
        return self._apply(state, self._rotate(action, 1))

    def allowed_actions(self, state):
        """Return ``[None]`` for terminal states, else all four actions."""
        if state in self.terminals:
            return [None]
        return self.actions

    def next_state_distribution(self, state, action):
        """Return a list of ``(probability, next_state)`` pairs.

        For ``action is None`` the only entry is ``state`` itself with
        weight 0.0, so a weighted sum over the result contributes nothing.
        Otherwise the entries are the forward, left, right and backward
        outcomes, in that order, weighted by ``transition_distribution``.
        """
        if action is None:
            return [(0.0, state)]
        weights = self.transition_distribution
        return [
            (weights['forward'], self._move_forward(state, action)),
            (weights['left'], self._move_left(state, action)),
            (weights['right'], self._move_right(state, action)),
            (weights['backward'], self._move_backward(state, action)),
        ]

    def update_values(self, values):
        """Replace the value table with ``values`` (stored by reference)."""
        self.values = values

    def update_policy(self, policy):
        """Replace the policy with ``policy`` (stored by reference)."""
        self.policy = policy

    def clear(self):
        """Rebuild the grid and reset values and policy to a fresh start.

        Every state gets ``initial_value`` and a random action.  Dicts
        previously passed to ``update_values`` / ``update_policy`` are left
        untouched.
        """
        self._init_grid()
        self._reset_values_and_policy()
            
            
    
      