In [1]:
import numpy as np
import matplotlib.pyplot as plt
from enum import Enum
from typing import List, Tuple

In [2]:
ROWS = 4
COLS = 4
type State = Tuple[int, int]
type Reward = int

Actions = [(-1, 0), (0, 1), (1, 0), (0, -1)]

class GridWorld:
    def __init__(self, rows=ROWS, cols=COLS):
        self._rows = rows
        self._cols = cols

        self._grid = np.array((rows, cols))
        self._terminal_states = [(0, 0), (rows-1, cols-1)]

    def get_actions(self, state: State) -> List[Tuple[int, State, Reward]]:
        if state in self._terminal_states:
            return []

        output = []
        for i, action in enumerate(Actions):
            new_r = state[0] + action[0]
            new_c = state[1] + action[1]

            new_state = (new_r, new_c)
            if self._is_valid_state(new_state):
                output.append((i, new_state, -1))
            else:
                output.append((i, state, -1))
        return output

    def _is_valid_state(self, state) -> bool:
        if 0 <= state[0] < self._rows and 0 <= state[1] < self._cols:
            return True
        return False

In [3]:
THRESHOLD = 0.1
gridworld = GridWorld()

# rows * cols * 4 grid is a policy, each cell represents p(a|s)
type Policy = np.Array

equirandom = np.zeros((ROWS, COLS, 4)) + 0.25
equirandom[0, 0] = 0
equirandom[ROWS-1, COLS-1] = 0

def policy_evaluation_inplace(policy: Policy, threshold=THRESHOLD):
    V = np.zeros((ROWS, COLS))
    iterations = 0
    while True:
        
        delta = 0

        for r in range(ROWS):
            for c in range(COLS):
                old = V[r, c]
                actions = gridworld.get_actions((r, c))
                #print(r, c, actions)
                # the double sum from the bellman equation collapses here, since there is only one s',r for each action
                V[r,c] = sum([policy[r, c, action]*(reward + V[new_state]) for action, new_state, reward in actions])
                delta = max(delta, abs(old-V[r,c]))
        
        iterations += 1

        if (iterations %100 == 0):
            print(V)
        if delta < threshold:
            print(iterations)

            break
    return V


In [15]:
policy_evaluation_inplace(equirandom)

36


array([[  0.        , -13.36554418, -19.08788486, -20.99604821],
       [-13.36554418, -17.22143946, -19.15466216, -19.16432809],
       [-19.08788486, -19.15466216, -17.28668965, -13.46743372],
       [-20.99604821, -19.16432809, -13.46743372,   0.        ]])

In [16]:
def policy_improvement(old_policy: Policy, V) -> Tuple[bool, Policy]:
    # select a deterministic policy that is greedy wrt V, given a deterministic old_policy
    new_policy = np.zeros((ROWS, COLS, 4))
    stable_policy = True
    for r in range(ROWS):
        for c in range(COLS):
            old_greedy = np.argmax(old_policy[r, c])
            
            actions = gridworld.get_actions((r, c))
            if not actions:
                continue
                                           
            one_step_value = {i: reward + V[new_state] for i, new_state, reward in actions}

            greedy_action = max(one_step_value, key=one_step_value.get)
            new_policy[r, c, greedy_action] = 1

            # make sure that we actually have something different. 
            # There can be multiple greedy actions, don't want to get stuck in a loop
            if abs(one_step_value[old_greedy] - one_step_value[greedy_action]) >= THRESHOLD:
                stable_policy = False

            
    return stable_policy, new_policy

            

In [17]:
def policy_iteration():
    # everything goes up, lol
    policy = np.zeros((ROWS, COLS, 4))
    policy[:, :, 0] = 1

    stable_policy = False
    while not stable_policy:
        V = policy_evaluation_inplace(policy)
        print(V)
        stable_policy, policy = policy_improvement(policy, V)
        print("======")

    print("Done Policy Iteration!")
    print(policy)
    print(V)
    

In [18]:
policy_iteration()

KeyboardInterrupt: 

In [None]:
A = np.array([[1, 2, 3], [4, 5, 6]])

In [None]:
state = (1,1)

In [None]:
A[state]

In [5]:
policy = np.zeros((ROWS, COLS, 4))
policy[:, :, 0] = 1
policy

array([[[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]],

       [[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]],

       [[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]],

       [[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]]])

In [6]:
policy_evaluation_inplace(policy)

[[   0. -100. -100. -100.]
 [  -1. -101. -101. -101.]
 [  -2. -102. -102. -102.]
 [  -3. -103. -103.    0.]]
[[   0. -200. -200. -200.]
 [  -1. -201. -201. -201.]
 [  -2. -202. -202. -202.]
 [  -3. -203. -203.    0.]]
[[   0. -300. -300. -300.]
 [  -1. -301. -301. -301.]
 [  -2. -302. -302. -302.]
 [  -3. -303. -303.    0.]]
[[   0. -400. -400. -400.]
 [  -1. -401. -401. -401.]
 [  -2. -402. -402. -402.]
 [  -3. -403. -403.    0.]]
[[   0. -500. -500. -500.]
 [  -1. -501. -501. -501.]
 [  -2. -502. -502. -502.]
 [  -3. -503. -503.    0.]]
[[   0. -600. -600. -600.]
 [  -1. -601. -601. -601.]
 [  -2. -602. -602. -602.]
 [  -3. -603. -603.    0.]]
[[   0. -700. -700. -700.]
 [  -1. -701. -701. -701.]
 [  -2. -702. -702. -702.]
 [  -3. -703. -703.    0.]]
[[   0. -800. -800. -800.]
 [  -1. -801. -801. -801.]
 [  -2. -802. -802. -802.]
 [  -3. -803. -803.    0.]]
[[   0. -900. -900. -900.]
 [  -1. -901. -901. -901.]
 [  -2. -902. -902. -902.]
 [  -3. -903. -903.    0.]]
[[ 0.000e+00 -1.000

KeyboardInterrupt: 