In [13]:
import numpy as np
import plotly.graph_objects as go

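## Defining the MDP

A toy Markov decision process with four states and four actions: the state and action sets, the transition probabilities $P(s' \mid s, a)$, and the rewards $R(s, a)$.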

In [14]:
# Toy state and action sets (rewards are defined further below).
states = list(range(4))                  # Example states: 0, 1, 2, 3
actions = 'up down left right'.split()   # Example actions; list position = action index

In [15]:
# Transition probabilities P(s' | s, a), stored as P[next_state, current_state, action].
# Every entry starts at zero; only the transitions set explicitly carry probability mass.
P = np.zeros(shape=(len(states),) * 2 + (len(actions),))

In [16]:
# Reward matrix R(s, a): R[state, action] is the immediate reward for taking `action`
# in `state`. It is 2-D, so the reward does not depend on the next state s'.
# All rewards start at zero.
R = np.zeros((len(states), len(actions)))

In [17]:
# Fill in a deliberately tiny, deterministic set of dynamics (real MDPs are richer).
# Indexing follows P[next_state, current_state, action]; action 0 is 'up', action 1 is 'down'.
P[(1, 0, 0)] = 1.0  # state 0 --up--> state 1, with certainty
P[(0, 1, 1)] = 1.0  # state 1 --down--> state 0, with certainty

In [18]:
# Immediate rewards, indexed R[state, action].
R[(0, 0)] = 10.0    # 'up' from state 0 (the move that leads to state 1) earns +10
R[(1, 1)] = -10.0   # 'down' from state 1 (the move that leads to state 0) costs 10

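## Value iteration

Sweep over the states, applying the Bellman optimality update, until the largest change in any state's value during a sweep falls below $\theta$:

$$V(s) \leftarrow \max_a \sum_{s'} P(s' \mid s, a)\,\big[R(s, a) + \gamma V(s')\big]$$

The greedy policy is then read off as the $\arg\max_a$ of the same expression, evaluated with the converged $V$.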

In [19]:
def _action_values(P, R, V, s, gamma):
    """Return Q(s, a) for every action a, given the current value estimate V.

    Q(s, a) = sum_{s'} P[s', s, a] * (R[s, a] + gamma * V[s'])
    """
    P_s = P[:, s, :]  # (n_states, n_actions): column a is the next-state distribution of action a
    # The reward term is weighted by the total probability mass of each action's column,
    # exactly as in the per-s' sum above.
    return R[s] * P_s.sum(axis=0) + gamma * (V @ P_s)


def value_iteration(states, actions, P, R, gamma=0.9, theta=1e-6, max_iter=None):
    """Solve an MDP with value iteration and return the values and greedy policy.

    Values are updated in place within a sweep, so states later in the sweep already
    see the new values of earlier states (Gauss-Seidel style).

    Parameters
    ----------
    states : sequence
        State labels; only ``len(states)`` is used.
    actions : sequence
        Action labels; only ``len(actions)`` is used.
    P : array_like, shape (n_states, n_states, n_actions)
        Transition probabilities, ``P[next_state, state, action]``.
    R : array_like, shape (n_states, n_actions)
        Immediate rewards, ``R[state, action]``.
    gamma : float
        Discount factor.
    theta : float
        Convergence threshold. Iteration stops once the largest per-state change in
        a full sweep is below ``theta``.
    max_iter : int or None
        Optional cap on the number of sweeps. ``None`` (default) means no cap.

    Returns
    -------
    V : numpy.ndarray of float, shape (n_states,)
        Converged state values.
    policy : numpy.ndarray of int, shape (n_states,)
        Index into ``actions`` of the greedy action per state. Ties resolve to the
        lowest action index, because ``np.argmax`` is used.

    Raises
    ------
    ValueError
        If ``P`` or ``R`` does not match the sizes of ``states`` and ``actions``.
    RuntimeError
        If ``max_iter`` sweeps complete without reaching ``theta``.

    Notes
    -----
    An action whose column of ``P`` is all zero has value 0 in that state, and its
    reward is ignored, because rewards are weighted by transition probability.
    """
    n_states, n_actions = len(states), len(actions)
    P = np.asarray(P, dtype=float)
    R = np.asarray(R, dtype=float)
    if P.shape != (n_states, n_states, n_actions):
        raise ValueError(
            f"P must have shape {(n_states, n_states, n_actions)}, got {P.shape}"
        )
    if R.shape != (n_states, n_actions):
        raise ValueError(f"R must have shape {(n_states, n_actions)}, got {R.shape}")

    V = np.zeros(n_states)  # Initialize value function to zero for all states
    sweeps = 0
    while True:
        delta = 0.0
        for s in range(n_states):
            previous = V[s]
            V[s] = _action_values(P, R, V, s, gamma).max()
            delta = max(delta, abs(previous - V[s]))
        sweeps += 1

        # Converged: the largest change in this sweep is below the threshold.
        if delta < theta:
            break
        # Guard against non-contracting setups (e.g. gamma >= 1 with unbounded returns).
        if max_iter is not None and sweeps >= max_iter:
            raise RuntimeError(
                f"value iteration did not converge within {max_iter} sweeps "
                f"(last delta={delta})"
            )

    # Derive the greedy policy from the converged value function.
    policy = np.array(
        [int(np.argmax(_action_values(P, R, V, s, gamma))) for s in range(n_states)],
        dtype=int,
    )
    return V, policy

In [20]:
# Plotting the Value Function
def plot_value_function(V):
    """Show the value of each state as an interactive Plotly line-and-marker chart.

    Parameters
    ----------
    V : sequence of float
        One value per state. Only y-values are passed to the trace; no x-values are given.
    """
    trace = go.Scatter(y=V, mode='lines+markers')
    layout = dict(
        title='Value Function over States',
        xaxis_title='States',
        yaxis_title='Value',
    )
    fig = go.Figure(data=trace)
    fig.update_layout(**layout)
    fig.show()

In [21]:
# Solve the MDP with value_iteration's defaults (gamma=0.9, theta=1e-6) and plot the values.
V, policy = value_iteration(states, actions, P, R)
plot_value_function(V)

# Show the greedy policy by action name. policy holds indices into `actions`; ties
# resolve to the first action because the policy is built with np.argmax.
{state: actions[a] for state, a in zip(states, policy)}