In [134]:
import numpy as np
from numpy import ndarray

In [135]:
# Create a 4x4 gridworld

from dataclasses import dataclass, field

@dataclass
class GridCell:
    """One stored cell of the gridworld with its value estimate and policy.

    Attributes:
        index: Position of the cell in the flattened grid (the print helpers
            reshape the 16 positions to 4x4).
        value: Current state-value estimate V(s) for this cell.
        optimal_actions: 0/1 mask over the actions (up, right, down, left);
            a non-zero entry marks an action the current policy may take.
        next_states: State reached by each action, in the same
            (up, right, down, left) order. ``get_next_state_value`` treats a
            next state of 0 as terminal.
    """

    index: int
    value: float = 0.0
    # default_factory gives every cell its own mask rather than one shared array.
    optimal_actions: ndarray = field(default_factory=lambda: np.array([1, 1, 1, 1]))
    #                  up   right   down    left
    next_states: tuple[int, int,    int,    int] = (0, 0, 0, 0)

    def __eq__(self, other: object) -> bool:
        """Compare two cells field by field.

        The dataclass-generated ``__eq__`` compares ``optimal_actions`` with
        ``==``, which yields an array and raises ``ValueError`` when its truth
        value is taken. ``np.array_equal`` gives a single bool instead.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(
            self.index == other.index
            and self.value == other.value
            and self.next_states == other.next_states
            and np.array_equal(self.optimal_actions, other.optimal_actions)
        )


# 1.
# Initialization
# V(s) ∈ R and π(s) ∈ A(s) arbitrarily for all s ∈ S [Note: I am setting all values to 0 and treating all actions as equally optimal]


# The two terminal corners (cells 0 and 15) are not stored in the list; a next
# state of 0 stands for "terminal". Each row of the table gives, for cells
# 1..14 in order, the state reached by moving (up, right, down, left). A move
# that would leave the grid keeps the agent in its current cell.
gridworld = [
    GridCell(cell_index, 0.0, np.array([1, 1, 1, 1]), successors)
    for cell_index, successors in enumerate(
        (
            # first row (cells 1-3)
            (1, 2, 5, 0),
            (2, 3, 6, 1),
            (3, 3, 7, 2),
            # second row (cells 4-7)
            (0, 5, 8, 4),
            (1, 6, 9, 4),
            (2, 7, 10, 5),
            (3, 7, 11, 6),
            # third row (cells 8-11)
            (4, 9, 12, 8),
            (5, 10, 13, 8),
            (6, 11, 14, 9),
            (7, 11, 0, 10),
            # fourth row (cells 12-14)
            (8, 13, 12, 12),
            (9, 14, 13, 12),
            (10, 0, 14, 13),
        ),
        start=1,
    )
]

In [136]:
def get_next_state_value(next_state: int) -> float:
    """Return the current value estimate of the state an action leads to.

    State 0 denotes the terminal state, whose value is fixed at 0.0. Any other
    state is read from ``gridworld``; the lookup is shifted by one because the
    list starts at cell 1 (the terminal cell 0 is not stored).
    """
    if next_state == 0:
        return 0.0

    return gridworld[next_state - 1].value

In [137]:
# 2.
# Policy Evaluation
# Loop:
#   Δ ← 0
#   Loop for each s ∈ S:
#     v ← V(s)
#     V(s) ← ∑_{s´,r} p(s´, r | s, π(s))[r + γV(s´)]
#     Δ ← max(Δ, |v - V(s)|)
# until Δ < θ (a small positive number determining the accuracy of estimation)

def compute_state_value(gridcell: GridCell, gamma: float = 0.9, reward: float = -1) -> float:
    """Return the one-step Bellman backup of ``gridcell`` under its current policy.

    The policy picks uniformly among the actions flagged in
    ``gridcell.optimal_actions``. Each action leads to the single state listed
    in ``gridcell.next_states``, so the backup is the average of
    ``reward + gamma * V(s')`` over the flagged actions.

    Args:
        gridcell: Cell whose value is being evaluated. It is not modified; the
            caller decides when to store the returned value.
        gamma: Discount factor applied to the next state's value.
        reward: Reward received for every step.

    Returns:
        The updated value estimate for ``gridcell``.

    Raises:
        ValueError: If no action is flagged in ``gridcell.optimal_actions``.
    """
    policy_actions = np.flatnonzero(gridcell.optimal_actions)
    if policy_actions.size == 0:
        raise ValueError(f"cell {gridcell.index} has no action flagged in optimal_actions")

    # This is the policy's probability of choosing each flagged action,
    # π(a|s), not an environment transition probability.
    action_probability = 1 / policy_actions.size

    state_value = 0.0
    for action_index in policy_actions.tolist():
        next_state = gridcell.next_states[action_index]
        state_value += action_probability * (reward + gamma * get_next_state_value(next_state))

    return state_value
    

def run_policy_evaluation(theta: float = 0.01) -> None:
    """Sweep ``gridworld`` until the value estimates stop changing.

    Every sweep overwrites each cell's value in place with
    ``compute_state_value(cell)``, so cells later in a sweep already see the
    values updated earlier in the same sweep. Sweeping stops once the largest
    change in a sweep is below ``theta``.

    Args:
        theta: Convergence threshold on the largest per-sweep value change.

    Raises:
        ValueError: If ``theta`` is not positive; the stopping test
            ``delta < theta`` could then never succeed.
    """
    if theta <= 0:
        raise ValueError("theta must be a positive number")

    while True:
        delta = 0.0
        gridcell: GridCell
        for gridcell in gridworld:
            previous_value = gridcell.value
            gridcell.value = compute_state_value(gridcell)
            delta = max(delta, abs(previous_value - gridcell.value))

        if delta < theta:
            break


In [138]:
# 3.
# Policy Improvement
# policy-stable ← true
# Loop for each s ∈ S:
#   old-action ← π(s)
#   π(s) ← argmax_{a} ∑_{s´,r} p(s´, r | s, a)[r + γV(s´)]
#   If old-action ≠ π(s), then policy-stable ← false
# If policy-stable, then stop and return V ≈ v_* and π ≈ π_*; else go to 2

def run_policy_improvement(gamma: float = 0.9, reward: float = -1) -> bool:
    """Make the policy of every cell in ``gridworld`` greedy w.r.t. the current values.

    Each action leads to the single state listed in ``next_states``, so its
    value is simply ``q(s, a) = reward + gamma * V(s')``. All actions whose
    value matches the best one are flagged in ``optimal_actions``.

    Args:
        gamma: Discount factor. The default equals the discount hard-coded in
            ``compute_state_value`` so evaluation and improvement agree.
        reward: Reward received for every step.

    Returns:
        True if no cell's set of flagged actions changed, False otherwise.
    """
    policy_stable = True
    gridcell: GridCell

    for gridcell in gridworld:
        old_optimal_actions = gridcell.optimal_actions

        action_values = np.array(
            [reward + gamma * get_next_state_value(next_state) for next_state in gridcell.next_states],
            dtype=float,
        )

        # Values come from an approximate evaluation, so ties are detected
        # with a tolerance instead of exact float equality.
        greedy_mask = np.isclose(action_values, action_values.max())
        gridcell.optimal_actions = greedy_mask.astype(int)

        if not np.array_equal(old_optimal_actions, gridcell.optimal_actions):
            policy_stable = False

    return policy_stable

In [139]:
# Policy iteration: evaluate the current policy, then make it greedy, and
# repeat until an improvement pass leaves every cell's action set unchanged.
while True:
    run_policy_evaluation()
    policy_stable = run_policy_improvement()
    if policy_stable:
        break

In [None]:
def print_grid_world_optimal_policy():
    """Print the current policy as a 4x4 grid of arrows.

    Each stored cell shows one arrow per action flagged in its
    ``optimal_actions`` mask, in (up, right, down, left) order. Positions with
    no stored cell keep their placeholder: '⊠' for the two corners, '0'
    elsewhere.
    """
    arrows = ('↑', '→', '↓', '←')
    board = np.array(['⊠'] + ['0'] * 14 + ['⊠'], dtype=object)

    for cell in gridworld:
        glyphs = ''.join(arrows[k] for k in np.flatnonzero(cell.optimal_actions))
        shown = board[cell.index]
        # The '0' placeholder is replaced by the arrows; anything else
        # already in the slot has the arrows appended to it.
        if shown == '0' and glyphs:
            board[cell.index] = glyphs
        else:
            board[cell.index] = shown + glyphs

    print(board.reshape(4, 4))

def print_grid_world_value_states():
    """Print the current state values of ``gridworld`` as a 4x4 grid.

    Positions without a stored cell (the terminal corners) are shown as 0.
    """
    # The buffer must be float: an array built from integer literals would
    # silently truncate values such as -1.9 to -1 on assignment.
    value_grid = np.zeros(16, dtype=float)
    gridcell: GridCell
    for gridcell in gridworld:
        value_grid[gridcell.index] = gridcell.value
    print(value_grid.reshape(4, 4))

# Show the policy stored in gridworld as arrows, then the state values.
print_grid_world_optimal_policy()
print_grid_world_value_states()

[['⊠' '←' '←' '↓←']
 ['↑' '↑←' '↑→↓←' '↓']
 ['↑' '↑→↓←' '→↓' '↓']
 ['↑→' '→' '→' '⊠']]
