In [52]:
import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np
from enum import Enum
from random import choice

In [58]:
class Options(Enum):
    # Primitive options
    UP = 0
    DOWN = 1
    LEFT = 2
    RIGHT = 3
    
    # Multi-step options
    OUT_11 = 4
    OUT_12 = 5
    OUT_21 = 6
    OUT_22 = 7
    OUT_31 = 8
    OUT_32 = 9
    OUT_41 = 10
    OUT_42 = 11
    

class Hallway:
    def __init__(self, start_position, goals):
        
        self._start_position = start_position
        self._goals = goals
        
        
        self.primitive_options = [Options.UP, Options.DOWN, Options.LEFT, Options.RIGHT]
        self.multistep_options = [Options.OUT_11,
                                   Options.OUT_12,
                                   Options.OUT_21,
                                   Options.OUT_22,
                                   Options.OUT_31,
                                   Options.OUT_32,
                                   Options.OUT_41,
                                   Options.OUT_42]

        self._all_options = self.primitive_options + self.multistep_options
        
        self._current_state = list(self._start_position)
        self._grid = self.__get_grid()
        self._states = self.__get_states()
        
        if self._start_position not in self._states:
            raise Exception('Invalid start position')
            
        for goal in self._goals:
            if goal not in self._states:
                raise Exception('Invalid goal')
        
    def reset(self):
        self._current_state = list(self._start_position)
        self._grid = self.__get_grid()
        self._states = self.__get_states()
        
        
    def step(self, option):
        
        if option in self.primitive_options:
            next_state, r = self.__primitive_step(self, option)
        else:
            r = 0
            next_state = self._current_state
        
        self._current_state = next_state
        return next_state, r
    
    def render(self):
        fig, ax = plt.subplots()
        ax.imshow(self._grid)

        # draw gridlines
        ax.grid(which='major', axis='both', linestyle='-', color='k', linewidth=2)
        ax.set_xticks(np.arange(-0.5, 12, 1))
        ax.set_yticks(np.arange(-0.5, 12, 1))

        ax.set_yticklabels([])
        ax.set_xticklabels([])
        plt.show()
        
    def __primitive_step(self, option):
        real_option = option
        if np.random.random_sample() <= 1./3.:
            real_option = choice([op for op in self.primitive_options if op != option])
                
        if real_option == Options.UP:
            new_state = [self._current_state[0] - 1, self._current_state[1]]
        elif real_option == Options.DOWN:
            new_state = [self._current_state[0] + 1, self._current_state[1]]
        elif real_option == Options.LEFT:
            new_state = [self._current_state[0], self._current_state[1] - 1]
        elif real_option == Options.RIGHT:
            new_state = [self._current_state[0], self._current_state[1] + 1]
        else:
            raise Exception('Invalid option called')
            
        
        if new_state in self._states:
            if new_state in self._goals:
                return new_state, +1
            else:
                return new_state, 0
            
        return self._current_state, 0
        
        
    def __get_states(self):
        states = set()
        for row in range(len(self._grid)):
            for column in range(len(self._grid[0])):
                if self._grid[row][column] > 0:
                    states.add((row,column))
        
    def __get_grid(self):
        return [
            [0,0,0,0,0,0,0,0,0,0,0,0,0],
            [0,1,1,1,1,1,0,1,1,1,1,1,0],
            [0,1,1,1,1,1,0,1,1,1,1,1,0],
            [0,1,1,1,1,1,2,1,1,1,1,1,0],
            [0,1,1,1,1,1,0,1,1,1,1,1,0],
            [0,1,1,1,1,1,0,1,1,1,1,1,0],
            [0,0,2,0,0,0,0,1,1,1,1,1,0],
            [0,1,1,1,1,1,0,0,0,2,0,0,0],
            [0,1,1,1,1,1,0,1,1,1,1,1,0],
            [0,1,1,1,1,1,0,1,1,1,1,1,0],
            [0,1,1,1,1,1,2,1,1,1,1,1,0],
            [0,1,1,1,1,1,0,1,1,1,1,1,0],
            [0,0,0,0,0,0,0,0,0,0,0,0,0]
        ]


In [59]:
hall = Hallway([0,0], [[7,9]])

TypeError: argument of type 'NoneType' is not iterable