In [None]:
import numpy as np
from numpy import ndarray

In [None]:
# Create a 4x4 gridworld

from dataclasses import dataclass, field

@dataclass
class GridCell:
    """One non-terminal state of the gridworld.

    Fields are positional in this order: index, value, optimal_action, next_states.
    """

    # Position of the state in the flattened (row-major) 4x4 grid.
    index: int

    # Current estimate of the state value V(s).
    value: float = 0.0

    # Action selected by the current policy: up = 0, right = 1, down = 2, left = 3.
    optimal_action: int = 0

    # Successor state reached by each action, ordered (up, right, down, left).
    next_states: tuple[int, int, int, int] = (0, 0, 0, 0)


# 1.
# Initialization
# V(s) ∈ R and π(s) ∈ A(s) arbitrarily for all s ∈ S [Note: here every V(s) is initialised to 0 and every π(s) to action 0 (up)]


# Non-terminal states 1..14 in row-major order, so state s is stored at list
# position s - 1.  The two terminal corners (0 and 15) are not stored; a
# successor of 0 means "terminal", and the moves that would enter corner 15
# (down from 11, right from 14) are written as 0 as well.
# Every state starts with value 0.0 and policy action 0 (up).
gridworld = [
    GridCell(index=state, value=0.0, optimal_action=0, next_states=successors)
    for state, successors in enumerate(
        [
            # up, right, down, left
            (1, 2, 5, 0),       # state 1
            (2, 3, 6, 1),       # state 2
            (3, 3, 7, 2),       # state 3

            (0, 5, 8, 4),       # state 4
            (1, 6, 9, 4),       # state 5
            (2, 7, 10, 5),      # state 6
            (3, 7, 11, 6),      # state 7

            (4, 9, 12, 8),      # state 8
            (5, 10, 13, 8),     # state 9
            (6, 11, 14, 9),     # state 10
            (7, 11, 0, 10),     # state 11

            (8, 13, 12, 12),    # state 12
            (9, 14, 13, 12),    # state 13
            (10, 0, 14, 13),    # state 14
        ],
        start=1,
    )
]

In [None]:
def get_next_state_value(next_state: int) -> float:
    """Return the current value estimate of a successor state.

    Args:
        next_state: Successor state number; 0 denotes a terminal state.

    Returns:
        0.0 for the terminal marker, otherwise the ``value`` of the cell stored
        at position ``next_state - 1`` of the module-level ``gridworld`` list.
    """
    if next_state != 0:
        # gridworld holds states starting at 1, hence the offset of one.
        return gridworld[next_state - 1].value

    return 0.0

In [None]:
# 2.
# Policy Evaluation
# Loop:
#   Δ ← 0
#   Loop for each s ∈ S:
#     v ← V(s)
#     V(s) ← ∑_{s´,r} p(s´, r | s, π(s))[r + γV(s´)]
#     Δ ← max(Δ, |v - V(s)|)
# until Δ < θ (a small positive number determining the accuracy of estimation)

def compute_state_value(gridcell: GridCell) -> float:
    """Compute the one-step backed-up value of a state under its current policy action.

    Uses a fixed reward of -1 per step and a discount factor of 0.9.

    Args:
        gridcell: The state to evaluate; its ``optimal_action`` selects which
            entry of ``next_states`` is followed.

    Returns:
        reward + gamma * V(s') for the single successor s' chosen by the policy.
    """
    step_reward = -1
    discount = 0.9

    # The policy picks exactly one action per state, so that action has probability 1.
    action_probability = 1

    successor = gridcell.next_states[gridcell.optimal_action]
    backed_up_value = step_reward + (discount * get_next_state_value(successor))
    return action_probability * backed_up_value
    

def run_policy_evaluation(theta: float = 0.01, verbose: bool = True):
    """Iteratively evaluate the current policy in place until values stabilise.

    Sweeps over every cell in the module-level ``gridworld`` list, overwriting
    each cell's ``value`` with ``compute_state_value(cell)`` (in-place updates,
    so later cells in a sweep already see the new values of earlier ones).
    Sweeping stops once the largest change in a full sweep is below ``theta``.

    Args:
        theta: Positive accuracy threshold for the largest per-sweep change.
        verbose: When True (the default, matching the previous behaviour), print
            the running delta after every state update. Pass False to silence
            the per-state log.

    Raises:
        ValueError: If ``theta`` is not positive; the stopping test
            ``delta < theta`` could then never succeed.
    """
    if theta <= 0:
        raise ValueError("theta must be positive")

    while True:
        delta = 0.0
        gridcell: GridCell
        for gridcell in gridworld:
            v = gridcell.value
            gridcell.value = compute_state_value(gridcell)
            # Track the largest absolute change seen in this sweep.
            delta = max(delta, abs(v - gridcell.value))
            if verbose:
                print(f"After location ({gridcell.index}) - delta = {delta}")

        if delta < theta:
            break

In [None]:
# Evaluate the policy currently stored in `gridworld`, updating each cell's value in place.
run_policy_evaluation()

In [None]:
# 3.
# Policy Improvement
# policy-stable ← true
# Loop for each s ∈ S:
#   old-action ← π(s)
#   π(s) ← argmax_{a} ∑_{s´,r} p(s´, r | s, a)[r + γV(s´)]
#   If old-action ≠ π(s), then policy-stable ← false
# If policy-stable, then stop and return V ≈ v_* and π ≈ π_*; else go to 2

def run_policy_improvement() -> bool:
    """Make the policy greedy with respect to the current state values.

    For every cell in the module-level ``gridworld`` list, computes the
    one-step lookahead value ``reward + gamma * V(s')`` of each action (the
    environment is deterministic, so each action has exactly one successor
    with probability 1) and switches ``optimal_action`` to a maximising action.
    If the current action is already among the maximisers it is kept;
    otherwise one maximiser is drawn with ``np.random.choice``, which uses
    NumPy's global random state.

    Returns:
        True if no cell changed its action (the policy is stable), else False.
    """
    # Loop-invariant model constants: -1 reward per step, discount factor 0.9.
    reward = -1
    gamma = 0.9

    policy_stable = True
    gridcell: GridCell
    for gridcell in gridworld:
        old_optimal_action = gridcell.optimal_action

        # Action values q(s, a), in the same order as gridcell.next_states.
        action_values = np.array(
            [
                reward + (gamma * get_next_state_value(next_state))
                for next_state in gridcell.next_states
            ],
            dtype=float,
        )

        # Indices of every action that attains the maximum value.
        best_actions = np.flatnonzero(action_values == action_values.max())

        # Keep the old action when it is still greedy so that ties cannot make
        # the policy flip forever; only a strictly worse action is replaced.
        if old_optimal_action not in best_actions:
            gridcell.optimal_action = int(np.random.choice(best_actions))
            policy_stable = False

    return policy_stable

In [None]:
# Policy iteration: alternate evaluation and greedy improvement until an
# improvement pass leaves every action unchanged.
while True:
    run_policy_evaluation()
    policy_stable = run_policy_improvement()
    if policy_stable:
        break

In [None]:
def print_grid_world_optimal_policy():
    """Print the current policy as a 4x4 grid of arrows.

    The two corner cells (positions 0 and 15) are shown as '⊠'; every cell
    found in the module-level ``gridworld`` list shows the arrow of its
    ``optimal_action``.
    """
    # Arrow glyph per action: up = 0, right = 1, down = 2, left = 3.
    arrows = ['↑', '→', '↓', '←']

    # '0' is a placeholder that is overwritten for every cell in gridworld.
    board = np.array(['⊠'] + ['0'] * 14 + ['⊠'])
    for cell in gridworld:
        board[cell.index] = arrows[cell.optimal_action]

    print(board.reshape(4, 4))

def print_grid_world_value_states():
    """Print the current state values as a 4x4 grid.

    Positions not present in the module-level ``gridworld`` list (the two
    terminal corners) are shown as 0.
    """
    # The buffer must be floating point: an array built from integer literals
    # gets an integer dtype, and assigning a float into it silently truncates
    # the value (e.g. -2.71 would be displayed as -2).
    gridworld_rep = np.zeros(16, dtype=float)
    gridcell: GridCell
    for gridcell in gridworld:
        gridworld_rep[gridcell.index] = gridcell.value
    print(gridworld_rep.reshape(4, 4))

# Display the policy (as arrows) and the state values currently stored in `gridworld`, each as a 4x4 grid.
print_grid_world_optimal_policy()
print_grid_world_value_states()