In [65]:
import torch
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# Value function

In [66]:

# Setup: 4 states in a line: 0 -> 1 -> 2 -> 3 (goal)
# Actions: 0 = left, 1 = right.
states = 4
actions = 2
γ = 0.9  # discount factor

# Reward values; reward index r_idx refers to possible_rewards[r_idx].
possible_rewards = torch.tensor([0.0, 1.0])
# Derived from the reward values so the reward axis can never disagree with them.
num_rewards = len(possible_rewards)

# Joint distribution p(s', r | s, a), indexed as joint_dist[s, a, s', r_idx]
joint_dist = torch.zeros(states, actions, states, num_rewards)

# State 0
joint_dist[0, 0, 0, 0] = 1.0  # left -> stay at 0, reward 0
joint_dist[0, 1, 1, 0] = 1.0  # right -> 1, reward 0

# State 1
joint_dist[1, 0, 0, 0] = 1.0  # left -> 0, reward 0
joint_dist[1, 1, 2, 0] = 1.0  # right -> 2, reward 0

# State 2
joint_dist[2, 0, 1, 0] = 1.0  # left -> 1, reward 0
joint_dist[2, 1, 3, 0] = 0.3  # right -> 3, reward 0
joint_dist[2, 1, 3, 1] = 0.7  # right -> 3, reward 1 (stochastic reward!)

# State 3 (goal): every action stays at 3 with reward 0, i.e. an absorbing state.
joint_dist[3, 0, 3, 0] = 1.0
joint_dist[3, 1, 3, 0] = 1.0

# Sanity check: for every (s, a), p(·, · | s, a) must sum to 1 over (s', r).
assert torch.allclose(joint_dist.sum(dim=(2, 3)), torch.ones(states, actions)), \
    "each p(s', r | s, a) must sum to 1 over (s', r)"

# Equiprobable random policy: π(a | s) = 1 / |A| for every state.
policy = torch.ones(states, actions) / actions
assert torch.allclose(policy.sum(dim=1), torch.ones(states)), "π(· | s) must sum to 1"

# Initial value estimate for iterative policy evaluation.
v = torch.zeros(states)


In [70]:
# Value function: one Bellman-expectation backup for a fixed policy π
#   v_new[s] = r_π[s] + γ * (P_π @ v)[s]
# where
#   r_π[s]       = Σ_a π(a|s) Σ_{s', r} p(s', r | s, a) * r
#   (P_π @ v)[s] = Σ_a π(a|s) Σ_{s', r} p(s', r | s, a) * v[s']

def state_value(s, *, values=None, pi=None, dynamics=None, rewards=None, gamma=None):
    """Return the Bellman-expectation backup of state ``s``.

    Computes
    ``Σ_a π(a|s) Σ_{s',r} p(s',r|s,a)·r  +  γ · Σ_a π(a|s) Σ_{s',r} p(s',r|s,a)·v[s']``.

    Every keyword argument is optional. When omitted, it falls back to the
    notebook global of the same role. The fallback is resolved at call time,
    not at definition time, so the iteration cell's rebinding of ``v`` is
    picked up. Plain ``state_value(s)`` calls therefore behave exactly as before.

    Args:
        s: Index of the state to back up.
        values: Current value estimate, indexed by state (default: global ``v``).
        pi: Policy indexed ``[s, a]`` (default: global ``policy``).
        dynamics: Joint distribution indexed ``[s, a, s', r_idx]``
            (default: global ``joint_dist``).
        rewards: Reward value for each ``r_idx`` (default: global ``possible_rewards``).
        gamma: Discount factor (default: global ``γ``).

    Returns:
        The backed-up value of ``s``. It is a 0-dim tensor when the
        inputs are tensors.
    """
    if values is None:
        values = v
    if pi is None:
        pi = policy
    if dynamics is None:
        dynamics = joint_dist
    if rewards is None:
        rewards = possible_rewards
    if gamma is None:
        gamma = γ

    # Loop bounds come from the dynamics tensor itself: [s, a, s', r_idx].
    n_actions, n_next_states, n_rewards = dynamics.shape[1], dynamics.shape[2], dynamics.shape[3]

    immediate_reward = 0.0  # r_π[s]
    future_value = 0.0      # (P_π @ v)[s]
    # A single pass over (a, s', r) accumulates both sums. Each accumulator sees
    # the same terms in the same order as the original two separate loops.
    for a in range(n_actions):
        for s_prime in range(n_next_states):
            for r_idx in range(n_rewards):
                weight = pi[s, a] * dynamics[s, a, s_prime, r_idx]  # π(a|s) · p(s', r | s, a)
                immediate_reward += weight * rewards[r_idx]
                future_value += weight * values[s_prime]

    return immediate_reward + gamma * future_value


# # Equivalent single-pass version: computes the whole backup at once with the unified Bellman equation for v_π (Eq. 3.14)
# def state_value_compact(s):
#     total = 0.0
#     for a in range(actions):
#         for s_prime in range(states):
#             for r_idx in range(num_rewards):
#                 prob_ = joint_dist[s, a, s_prime, r_idx]
#                 reward_ = possible_rewards[r_idx]
#                 policy_ = policy[s, a]
#                 value_function_ = v[s_prime]
#
#                 total += policy_ * prob_ * (reward_ + γ *value_function_ )
#     return total


In [68]:

# Iterative policy evaluation: repeat full sweeps of the Bellman backup until the
# value estimate stops changing or the sweep budget runs out.
# `state_value` reads the global `v`, so `v` is rebound only after a whole sweep.
# This is a synchronous, two-array update.
# The loop warm-starts from the current global `v`.
MAX_SWEEPS = 100  # upper bound on the number of sweeps
THETA = 1e-6      # stop once the largest per-state change in a sweep is below this

for iteration in range(MAX_SWEEPS):
    v_new = torch.zeros(states)

    for s in range(states):
        v_new[s] = state_value(s)

    delta = (v_new - v).abs().max().item()  # largest change in this sweep
    v = v_new
    if delta < THETA:
        break


v


tensor([0.3002, 0.3669, 0.5151, 0.0000])

In [69]:
# Direct solution of the Bellman expectation equation in matrix form:
#   v = r_π + γ P_π v   =>   (I - γ P_π) v = r_π
# with
#   P_π[s, s'] = Σ_a π(a|s) Σ_r p(s', r | s, a)            (transition matrix under π)
#   r_π[s]     = Σ_a π(a|s) Σ_{s', r} p(s', r | s, a) * r  (expected one-step reward under π)
# einsum letters: s = state, a = action, t = next state s', k = reward index.
P = torch.einsum("sa,satk->st", policy, joint_dist)
r = torch.einsum("sa,satk,k->s", policy, joint_dist, possible_rewards)

# Each row of P_π must be a probability distribution over next states.
assert torch.allclose(P.sum(dim=1), torch.ones(states)), "rows of P_π must sum to 1"

# Solve
I = torch.eye(states)
A = I - γ * P
v_direct = torch.linalg.solve(A, r)

v_direct




# P = torch.zeros(states, states)
# r = torch.zeros(states)
#
# for s in range(states):
#     for a in range(actions):
#         for s_prime in range(states):
#             for r_idx in range(num_rewards):
#                 prob = joint_dist[s, a, s_prime, r_idx]
#                 reward_val = possible_rewards[r_idx]
#                 P[s, s_prime] += policy[s, a] * prob
#                 r[s] += policy[s, a] * prob * reward_val
#
# I = torch.eye(states)
# v_direct = torch.linalg.solve(I - γ * P, r)
# v_direct # --> calculated the same way but all in one


tensor([0.3002, 0.3669, 0.5151, 0.0000])