In [1]:
import numpy as np
import matplotlib.pyplot as plt

# Define the environment
num_states = 5  # Number of states in the 1D grid
trap_state = 0  # Trap position
goal_state = 4  # Goal position
start_state = 2  # Starting position

actions = [-1, 1]  # Possible actions: left (-1), right (+1)

# Rewards
rewards = np.zeros(num_states)
rewards[trap_state] = -1  # Negative reward for the trap
rewards[goal_state] = 1   # Positive reward for the goal

# Hyperparameters
gamma = 0.9  # Discount factor
alpha = 0.1  # Learning rate
epsilon = 0.1  # Exploration rate
num_episodes = 500  # Number of episodes

# Initialize Q-table
Q = np.zeros((num_states, len(actions)))

def epsilon_greedy(state, epsilon):
    if np.random.rand() < epsilon:
        return np.random.choice(len(actions))  # Explore: random action
    else:
        return np.argmax(Q[state])  # Exploit: best action

# Training loop
for episode in range(num_episodes):
    state = start_state
    while state != trap_state and state != goal_state:
        action_idx = epsilon_greedy(state, epsilon)
        action = actions[action_idx]

        # Take action and observe next state and reward
        next_state = max(0, min(num_states - 1, state + action))
        reward = rewards[next_state] - 0.01  # Include move cost

        # Update Q-value
        best_next_action = np.max(Q[next_state])
        Q[state, action_idx] += alpha * (reward + gamma * best_next_action - Q[state, action_idx])

        # Move to the next state
        state = next_state

# Derive policy from Q-table
policy = np.argmax(Q, axis=1)
policy_actions = [actions[a] for a in policy]

# Display results
print("Optimal Q-table:")
print(Q)
print("\nOptimal Policy:")
print(policy_actions)

# Visualization
def visualize_policy(policy):
    grid = ['Trap', ' ', 'Start', ' ', 'Goal']
    arrows = ['←' if action == -1 else '→' for action in policy]
    for i, arrow in enumerate(arrows):
        if i == trap_state:
            continue
        grid[i] = arrow
    return grid

visualized_policy = visualize_policy(policy)
print("\nVisualized Policy:")
print(visualized_policy)


Optimal Q-table:
[[ 0.          0.        ]
 [-0.347339    0.72422646]
 [ 0.51093028  0.881     ]
 [ 0.71071519  0.99      ]
 [ 0.          0.        ]]

Optimal Policy:
[-1, 1, 1, 1, -1]

Visualized Policy:
['Trap', '→', '→', '→', '→']
