In [1]:
import numpy as np
import matplotlib.pyplot as plt
from enum import Enum
from typing import List, Tuple

In [2]:
ROWS = 4
COLS = 4
type State = Tuple[int, int]
type Reward = int

Actions = [(-1, 0), (0, 1), (1, 0), (0, -1)]

class GridWorld:
    def __init__(self, rows=ROWS, cols=COLS):
        self._rows = rows
        self._cols = cols

        self._grid = np.array((rows, cols))
        self._terminal_states = [(0, 0), (rows-1, cols-1)]

    def get_actions(self, state: State) -> List[Tuple[int, State, Reward]]:
        if state in self._terminal_states:
            return []

        output = []
        for i, action in enumerate(Actions):
            new_r = state[0] + action[0]
            new_c = state[1] + action[1]

            new_state = (new_r, new_c)
            if self._is_valid_state(new_state):
                output.append((i, new_state, -1))
            else:
                output.append((i, state, -1))
        return output

    def _is_valid_state(self, state) -> bool:
        if 0 <= state[0] < self._rows and 0 <= state[1] < self._cols:
            return True
        return False

In [12]:
THRESHOLD = 1e-6
gridworld = GridWorld()

# rows * cols * 4 grid is a policy, each cell represents p(a|s)
type Policy = np.Array

equirandom = np.zeros((ROWS, COLS, 4)) + 0.25
equirandom[0, 0] = 0
equirandom[ROWS-1, COLS-1] = 0

def policy_evaluation_inplace(policy: Policy, threshold=THRESHOLD):
    V = np.zeros((ROWS, COLS))
    iterations = 0
    while True:
        
        delta = 0

        for r in range(ROWS):
            for c in range(COLS):
                old = V[r, c]
                actions = gridworld.get_actions((r, c))
                #print(r, c, actions)
                # the double sum from the bellman equation collapses here, since there is only one s',r for each action
                V[r,c] = sum([policy[r, c, action]*(reward + V[new_state]) for action, new_state, reward in actions])
                delta = max(delta, abs(old-V[r,c]))
        
        iterations += 1

        if (iterations %100 == 0):
            print(V)
        if delta < threshold:
            print(iterations)

            break
    return V


In [13]:
policy_evaluation_inplace(equirandom)

[[  0.         -13.99765839 -19.99663362 -21.99629468]
 [-13.99765839 -17.99712654 -19.99688008 -19.99691576]
 [-19.99663362 -19.99688008 -17.99736736 -13.99803444]
 [-21.99629468 -19.99691576 -13.99803444   0.        ]]
167


array([[  0.        , -13.99999335, -19.99999044, -21.99998948],
       [-13.99999335, -17.99999184, -19.99999114, -19.99999125],
       [-19.99999044, -19.99999114, -17.99999253, -13.99999442],
       [-21.99998948, -19.99999125, -13.99999442,   0.        ]])

In [14]:
def policy_improvement(old_policy: Policy, V) -> Tuple[bool, Policy]:
    # select a deterministic policy that is greedy wrt V, given a deterministic old_policy
    new_policy = np.zeros((ROWS, COLS, 4))
    stable_policy = True
    for r in range(ROWS):
        for c in range(COLS):
            old_greedy = np.argmax(old_policy[r, c])
            
            actions = gridworld.get_actions((r, c))
            if not actions:
                continue
                                           
            one_step_value = {i: reward + V[new_state] for i, new_state, reward in actions}

            greedy_action = max(one_step_value, key=one_step_value.get)
            new_policy[r, c, greedy_action] = 1

            # make sure that we actually have something different. 
            # There can be multiple greedy actions, don't want to get stuck in a loop
            if abs(one_step_value[old_greedy] - one_step_value[greedy_action]) >= THRESHOLD:
                stable_policy = False

            
    return stable_policy, new_policy

            

In [15]:
def policy_iteration():
    # everything goes up, lol
    equirandom = np.zeros((ROWS, COLS, 4)) + 0.25
    equirandom[0, 0] = 0
    equirandom[ROWS-1, COLS-1] = 0
    
    V = policy_evaluation_inplace(equirandom)
    stable_policy = False
    policy = equirandom
    
    while not stable_policy:
        V = policy_evaluation_inplace(policy)
        print(V)
        stable_policy, policy = policy_improvement(policy, V)
        print("======")

    print("Done Policy Iteration!")
    print(policy)
    print(V)
    

In [16]:
policy_iteration()

[[  0.         -13.99765839 -19.99663362 -21.99629468]
 [-13.99765839 -17.99712654 -19.99688008 -19.99691576]
 [-19.99663362 -19.99688008 -17.99736736 -13.99803444]
 [-21.99629468 -19.99691576 -13.99803444   0.        ]]
167
[[  0.         -13.99765839 -19.99663362 -21.99629468]
 [-13.99765839 -17.99712654 -19.99688008 -19.99691576]
 [-19.99663362 -19.99688008 -17.99736736 -13.99803444]
 [-21.99629468 -19.99691576 -13.99803444   0.        ]]
167
[[  0.         -13.99999335 -19.99999044 -21.99998948]
 [-13.99999335 -17.99999184 -19.99999114 -19.99999125]
 [-19.99999044 -19.99999114 -17.99999253 -13.99999442]
 [-21.99998948 -19.99999125 -13.99999442   0.        ]]
3
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
Done Policy Iteration!
[[[0. 0. 0. 0.]
  [0. 0. 0. 1.]
  [0. 0. 0. 1.]
  [0. 0. 1. 0.]]

 [[1. 0. 0. 0.]
  [1. 0. 0. 0.]
  [1. 0. 0. 0.]
  [0. 0. 1. 0.]]

 [[1. 0. 0. 0.]
  [1. 0. 0. 0.]
  [0. 1. 0. 0.]
  [0. 0. 1. 0.]]

 [[1. 0. 0. 0.]
  [0. 1. 0. 0

In [23]:
def value_iteration(threshold=THRESHOLD):
    V = np.zeros((ROWS, COLS))
    iterations = 0
    while True:
        
        delta = 0

        for r in range(ROWS):
            for c in range(COLS):
                old = V[r, c]
                actions = gridworld.get_actions((r, c))
                if not actions:
                    continue
                # the double sum from the bellman equation collapses here, since there is only one s',r for each action
                V[r,c] = max([(reward + V[new_state]) for _, new_state, reward in actions])
                delta = max(delta, abs(old-V[r,c]))
        
        iterations += 1

        if (iterations % 100 == 0):
            print(V)
        if delta < threshold:
            print(iterations)
            break

    zero_policy = np.zeros((ROWS, COLS, 4))
    _, policy = policy_improvement(zero_policy, V)
    return V, policy

In [24]:
value_iteration()

4


(array([[ 0., -1., -2., -3.],
        [-1., -2., -3., -2.],
        [-2., -3., -2., -1.],
        [-3., -2., -1.,  0.]]),
 array([[[0., 0., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 0., 1.],
         [0., 0., 1., 0.]],
 
        [[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 0., 1., 0.]],
 
        [[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.]],
 
        [[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 0., 0.]]]))