In [None]:
import numpy as np
from numpy import ndarray
from scipy.stats import poisson

def row_col_to_index(i, j, num_cols=21):
    """Map a (row, column) grid position to its index in a row-major flat list.

    Args:
        i: Row number (cars at location 1 in this notebook's state grid).
        j: Column number (cars at location 2 in this notebook's state grid).
        num_cols: Number of columns per row. Defaults to 21 (car counts 0
            through 20), which is the grid size the rest of the notebook uses.

    Returns:
        The flat index ``i * num_cols + j``.
    """
    return i * num_cols + j

In [None]:
# Define the record stored for each state of the 21x21 car-count grid

from dataclasses import dataclass

@dataclass
class LocationGridCell:
    """One state of the two-location car-rental problem and its current estimates."""

    # (cars_in_location1, cars_in_location2)
    state: tuple[int, int] = (0, 0)
    # Current value estimate for this state.
    value: float = 0.0
    # Action currently assigned to this state; one of
    # (-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5).
    optimal_action: int = 0

# 1.
# Initialization
# V(s) ∈ R and π(s) ∈ A(s) arbitrarily for all s ∈ S [Note: I am setting it all to 0 and all actions are equally optimal]


# Build the 21x21 state grid (441 cells) covering states (0, 0) through (20, 20).
# State (i, j) means i cars at location 1 and j cars at location 2. Cells are
# created in row-major order, so state (i, j) sits at list position i * 21 + j,
# which is the index row_col_to_index(i, j) computes.
num_cars_in_locations = [
    LocationGridCell((i, j))
    for i in range(21)  # cars_at_location1: 0 to 20
    for j in range(21)  # cars_at_location2: 0 to 20
]

In [None]:
def get_next_state_value(loc1: int, loc2: int) -> float:
    """Look up the current value estimate of the state (loc1, loc2).

    Args:
        loc1: Number of cars at location 1.
        loc2: Number of cars at location 2.

    Returns:
        The ``value`` stored on the matching cell of ``num_cars_in_locations``.
    """
    return num_cars_in_locations[row_col_to_index(loc1, loc2)].value

In [None]:
# 2.
# Policy Evaluation
# Loop:
#   Δ ← 0
#   Loop for each s ∈ S:
#     v ← V(s)
#     V(s) ← ∑_{s´,r} p(s´, r | s, π(s))[r + γV(s´)]
#     Δ ← max(Δ, |v - V(s)|)
# until Δ < θ (a small positive number determining the accuracy of estimation)

def compute_state_value(location_gridcell: LocationGridCell) -> float:
    """Compute the expected return of one state under its currently stored action.

    Evaluates ``sum_{s', r} p(s', r | s, pi(s)) [r + gamma * V(s')]`` by
    enumerating 0..10 rental requests and 0..10 returns at each location.
    Poisson probability mass beyond 10 events is ignored.

    Args:
        location_gridcell: Cell whose ``state`` and ``optimal_action`` are read.

    Returns:
        The backed-up state value, or 0.0 when the stored action would leave
        either location with a negative number of cars.
    """
    gamma = 0.9
    num_req_ret_to_consider = 11

    net_move_from_day_before = location_gridcell.optimal_action  # one of -5..5

    # A positive move takes cars away from location 1 and adds them to location 2.
    loc1_today = location_gridcell.state[0] - net_move_from_day_before
    loc2_today = location_gridcell.state[1] + net_move_from_day_before

    # Only a negative count is impossible. A location holding exactly zero cars
    # is a legitimate situation (the other location can still rent), so it must
    # be evaluated rather than skipped.
    if loc1_today < 0 or loc2_today < 0:
        return 0.0

    # NOTE(review): a move can push a location above 20 cars here; the cap of 20
    # is only applied to the end-of-day count below.

    # The probabilities depend only on the event counts, so compute each
    # distribution once instead of inside the 11^4 inner loop.
    counts = np.arange(num_req_ret_to_consider)
    p_request_loc1 = poisson.pmf(counts, 3)
    p_return_loc1 = poisson.pmf(counts, 3)
    p_request_loc2 = poisson.pmf(counts, 4)
    p_return_loc2 = poisson.pmf(counts, 2)

    state_value = 0.0

    # The next state depends on requests and returns at both locations, and on
    # how many cars are actually available to satisfy each request.
    for rq1 in range(num_req_ret_to_consider):
        actual_rentals_loc1 = min(rq1, loc1_today)
        for re1 in range(num_req_ret_to_consider):
            new_loc1 = max(0, min(loc1_today - actual_rentals_loc1 + re1, 20))
            for rq2 in range(num_req_ret_to_consider):
                actual_rentals_loc2 = min(rq2, loc2_today)
                for re2 in range(num_req_ret_to_consider):
                    transition_probability = (
                        p_request_loc1[rq1] * p_return_loc1[re1] * p_request_loc2[rq2] * p_return_loc2[re2]
                    )

                    # 2 per car moved overnight, 10 per car actually rented.
                    reward = (-2 * abs(net_move_from_day_before)) + (10 * actual_rentals_loc1) + (10 * actual_rentals_loc2)

                    new_loc2 = max(0, min(loc2_today - actual_rentals_loc2 + re2, 20))

                    state_value += transition_probability * (reward + gamma * get_next_state_value(new_loc1, new_loc2))

    return state_value
    

def run_policy_evaluation():
    """Sweep every state in place until the largest value change is below theta.

    Each sweep overwrites ``value`` on every cell of ``num_cars_in_locations``
    with ``compute_state_value`` and prints the largest absolute change seen.
    The loop ends after the first sweep whose change is smaller than 0.01.
    """
    theta = 0.01
    iteration = 0

    while True:
        iteration += 1
        delta = 0.0

        for cell in num_cars_in_locations:
            previous_value = cell.value
            cell.value = compute_state_value(cell)
            change = abs(previous_value - cell.value)
            # Track the biggest update of this sweep.
            if change > delta:
                delta = change

        print(f"After iteration {iteration} - delta = {delta}")

        if delta < theta:
            print(f"Converged after iteration {iteration}")
            return

In [None]:
# Evaluate the initial policy once (every cell starts with optimal_action = 0).
run_policy_evaluation()

In [None]:
def compute_improved_policy_indices(location_gridcell: LocationGridCell, actions: ndarray[int]) -> ndarray[int]:
    """Find which candidate actions maximise the one-step backed-up value.

    For every action in ``actions`` this evaluates
    ``sum_{s', r} p(s', r | s, a) [r + gamma * V(s')]`` for the state held by
    ``location_gridcell``, enumerating 0..10 rental requests and 0..10 returns
    at each location.

    Args:
        location_gridcell: Cell whose ``state`` is read.
        actions: Candidate net car moves. A positive value takes cars away from
            location 1 and adds them to location 2.

    Returns:
        Positions in ``actions`` (not the action values themselves) of all
        actions tied for the highest value. Actions that would leave a location
        with a negative car count are scored ``-inf``.
    """
    gamma = 0.9
    num_req_ret_to_consider = 11

    # The probabilities depend only on the event counts, so compute each
    # distribution once instead of inside the innermost loop.
    counts = np.arange(num_req_ret_to_consider)
    p_request_loc1 = poisson.pmf(counts, 3)
    p_return_loc1 = poisson.pmf(counts, 3)
    p_request_loc2 = poisson.pmf(counts, 4)
    p_return_loc2 = poisson.pmf(counts, 2)

    loc1 = location_gridcell.state[0]
    loc2 = location_gridcell.state[1]

    # One slot per candidate action. Sizing from `actions` avoids an IndexError
    # for longer action sets and spurious 0.0 slots for shorter ones.
    value_states = np.zeros(len(actions), dtype=float)

    for i, a in enumerate(actions):
        net_move_from_day_before = a

        loc1_today = loc1 - net_move_from_day_before
        loc2_today = loc2 + net_move_from_day_before

        # Make sure we don't go negative
        if loc1_today < 0 or loc2_today < 0:
            value_states[i] = float('-inf')
            continue  # Skip this invalid scenario

        state_value = 0.0

        for rq1 in range(num_req_ret_to_consider):
            actual_rentals_loc1 = min(rq1, loc1_today)
            for re1 in range(num_req_ret_to_consider):
                new_loc1 = max(0, min(loc1_today - actual_rentals_loc1 + re1, 20))
                for rq2 in range(num_req_ret_to_consider):
                    actual_rentals_loc2 = min(rq2, loc2_today)
                    for re2 in range(num_req_ret_to_consider):
                        transition_probability = (
                            p_request_loc1[rq1] * p_return_loc1[re1] * p_request_loc2[rq2] * p_return_loc2[re2]
                        )

                        # 2 per car moved overnight, 10 per car actually rented.
                        reward = (-2 * abs(net_move_from_day_before)) + (10 * actual_rentals_loc1) + (10 * actual_rentals_loc2)

                        new_loc2 = max(0, min(loc2_today - actual_rentals_loc2 + re2, 20))

                        state_value += transition_probability * (reward + gamma * get_next_state_value(new_loc1, new_loc2))

        value_states[i] = state_value

    max_values = np.max(value_states)
    ties = np.flatnonzero(value_states == max_values)
    return ties
    

# 3.
# Policy Improvement
# policy-stable ← true
# Loop for each s ∈ S:
#   old-action ← π(s)
#   π(s) ← argmax_{a} ∑_{s´,r} p(s´, r | s, a)[r + γV(s´)]
#   If old-action ≠ π(s), then policy-stable ← false
# If policy-stable, then stop and return V ≈ v_* and π ≈ π_*; else go to 2

def run_policy_improvement() -> bool:
    """Make the policy greedy with respect to the current value estimates.

    For every cell of ``num_cars_in_locations`` the best actions are found with
    ``compute_improved_policy_indices``. A cell whose current action is already
    among them is left alone; otherwise one of the best actions is picked at
    random and stored on the cell.

    Returns:
        True when no cell changed its action, False otherwise.
    """
    actions = np.arange(-5, 6, dtype=int)
    policy_stable = True

    for cell in num_cars_in_locations:
        previous_action = cell.optimal_action
        best_actions = actions[compute_improved_policy_indices(cell, actions)]

        # Keep the existing action when it is tied for best, so ties alone
        # never count as a policy change.
        if previous_action in best_actions:
            continue

        cell.optimal_action = np.random.choice(best_actions)
        if previous_action != cell.optimal_action:
            policy_stable = False

    return policy_stable

In [None]:
# Policy iteration: alternate evaluation and improvement until an improvement
# pass leaves every state's action unchanged.
while True:
    run_policy_evaluation()
    policy_stable = run_policy_improvement()
    if policy_stable:
        break

In [None]:
def print_grid_world_optimal_policy():
    """Print the action stored for every state as a 21x21 grid of signed moves.

    Row i, column j holds ``optimal_action`` for the state with i cars at
    location 1 and j cars at location 2. A positive entry takes that many cars
    away from location 1 and adds them to location 2.
    """
    # Actions range over -5..5, so print them as integers. Cells are placed by
    # their `state`, because LocationGridCell has no `index` attribute.
    policy_grid = np.zeros((21, 21), dtype=int)
    for gridcell in num_cars_in_locations:
        cars_loc1, cars_loc2 = gridcell.state
        policy_grid[cars_loc1, cars_loc2] = gridcell.optimal_action

    print(policy_grid)

def print_grid_world_value_states():
    """Print the value estimate of every state as a 21x21 grid.

    Row i, column j holds ``value`` for the state with i cars at location 1
    and j cars at location 2.
    """
    # Cells are placed by their `state`, because LocationGridCell has no
    # `index` attribute.
    value_grid = np.zeros((21, 21), dtype=float)
    for gridcell in num_cars_in_locations:
        cars_loc1, cars_loc2 = gridcell.state
        value_grid[cars_loc1, cars_loc2] = gridcell.value
    print(value_grid)

# Print the current policy and value estimates for every state.
print_grid_world_optimal_policy()
print_grid_world_value_states()

In [None]:
# Scratch check: indexing `actions` with an array of tie positions yields the
# tied action values, and np.random.choice picks one of them.
actions = np.arange(-5, 6)
ties = np.array([1, 5, 9])

# Show which position holds which action.
for position, action in enumerate(actions):
    print(f'{position}:{action}')

print('\n\n\n')
selected_actions = actions[ties]

# The actions sitting at the tie positions.
for action in selected_actions:
    print(f'{action}')

test = np.random.choice(selected_actions)
print(test)
