In [2]:
import numpy as np
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
from tqdm.notebook import tqdm

In [None]:
def initialize_q_table(env):
    """
    Build a zero-filled Q-table sized to the environment.

    Rows index states and columns index actions, so entry [s, a] holds the
    current value estimate for taking action `a` in state `s`.

    Parameters:
    - env (gym.Env): Environment with discrete observation and action spaces
      (both must expose `.n`).

    Returns:
    - q_table (np.ndarray): Float array of zeros with shape (n_states, n_actions).
    """
    table_shape = (env.observation_space.n, env.action_space.n)
    return np.zeros(table_shape)


def epsilon_greedy_policy(state, q_table, epsilon, rng=None):
    """
    Choose an action using the epsilon-greedy policy.

    With probability epsilon, a uniformly random action is selected (exploration).
    Otherwise an action with the highest Q-value for the current state is selected
    (exploitation); if several actions share that maximum, one of them is chosen
    uniformly at random instead of always the lowest-indexed one.

    Parameters:
    - state (int): The current state.
    - q_table (np.ndarray): The current Q-table, shape (n_states, n_actions).
    - epsilon (float): The probability of choosing a random action.
    - rng (np.random.Generator | np.random.RandomState | None): Source of randomness.
      Defaults to NumPy's global legacy RNG, so `np.random.seed` still controls it.

    Returns:
    - action (int): The chosen action.
    """
    # Both the np.random module and Generator/RandomState expose .random() and .choice().
    rng = np.random if rng is None else rng
    q_values = q_table[state]

    if rng.random() < epsilon:
        return int(rng.choice(len(q_values)))

    # np.argmax always returns the first maximum; with a zero-initialised table that
    # biases exploitation toward action 0. Break ties uniformly instead.
    best_actions = np.flatnonzero(q_values == q_values.max())
    if best_actions.size == 1:
        return int(best_actions[0])
    return int(rng.choice(best_actions))


def sarsa_update(q_table, state, action, reward, next_state, next_action, alpha, gamma, terminated=False):
    """
    Perform the SARSA update on the Q-table (in place).

    The SARSA update rule is:
    Q(s, a) = Q(s, a) + alpha * [reward + gamma * Q(s', a') - Q(s, a)]
    For a terminal transition there is no successor, so Q(s', a') is taken as 0.

    Parameters:
    - q_table (np.ndarray): The current Q-table; modified in place.
    - state (int): The current state.
    - action (int): The action taken in the current state.
    - reward (float): The reward received after taking the action.
    - next_state (int): The next state reached.
    - next_action (int): The next action chosen from the next state.
    - alpha (float): The learning rate.
    - gamma (float): The discount factor.
    - terminated (bool): True if next_state is terminal; disables bootstrapping.
      Defaults to False (bootstrap from Q(s', a')).

    Returns:
    - None
    """
    next_value = 0.0 if terminated else q_table[next_state, next_action]
    td_target = reward + gamma * next_value
    q_table[state, action] += alpha * (td_target - q_table[state, action])


def train_sarsa(env, num_episodes, alpha, gamma, max_steps=1000, log_every=5000,
                epsilon_scale=10000, explore_fraction=0.9):
    """
    Train the agent using the SARSA algorithm over a given number of episodes.

    For each episode:
      - Reset the environment.
      - Choose the initial action using the epsilon-greedy policy.
      - For each time step, perform the SARSA update until the episode terminates
        or is truncated (or `max_steps` steps have elapsed).
      - Epsilon decays as 1 / (episode / epsilon_scale + 1) for the first
        `explore_fraction` of episodes, then drops to 0 (purely greedy).

    Parameters:
    - env (gym.Env): The environment to train on.
    - num_episodes (int): Total number of episodes for training.
    - alpha (float): The learning rate.
    - gamma (float): The discount factor.
    - max_steps (int): Hard cap on steps per episode. Default 1000.
    - log_every (int): Print the mean reward of the last `log_every` episodes every
      `log_every` episodes; 0 disables logging. Default 5000.
    - epsilon_scale (float): Episodes after which epsilon has halved. Default 10000.
    - explore_fraction (float): Fraction of episodes that use exploration; the
      remainder run with epsilon = 0. Default 0.9.

    Returns:
    - q_table (np.ndarray): The trained Q-table.
    - rewards (list): Total reward per episode (useful for monitoring learning progress).
    """
    q_table = initialize_q_table(env)
    rewards = []
    # Episodes with index >= explore_until run fully greedy.
    explore_until = num_episodes * explore_fraction

    for eps_id in tqdm(range(num_episodes)):
        epsilon = 1.0 / (eps_id / epsilon_scale + 1) if eps_id < explore_until else 0
        state, _ = env.reset()
        action = epsilon_greedy_policy(state, q_table, epsilon)
        episode_reward = 0
        for _ in range(max_steps):
            next_state, reward, terminated, truncated, _ = env.step(action)
            next_action = epsilon_greedy_policy(next_state, q_table, epsilon)
            # Q(next_state, .) for a terminal next_state is never itself updated (the loop
            # breaks before it becomes `state`), so it stays at its initial value.
            sarsa_update(q_table, state, action, reward, next_state, next_action, alpha, gamma)
            episode_reward += reward
            if terminated or truncated:
                break
            state, action = next_state, next_action
        rewards.append(episode_reward)

        # Print progress every `log_every` episodes; tqdm.write avoids breaking the progress bar.
        if log_every and (eps_id + 1) % log_every == 0:
            avg_reward = np.mean(rewards[-log_every:])
            tqdm.write(f"Episode {eps_id + 1}, Average Reward (last {log_every} episodes): {avg_reward:.2f}")

    return q_table, rewards

In [None]:
num_episodes = 60000
SEED = 0

# Exploration in epsilon_greedy_policy draws from NumPy's global RNG; seed it (and the
# environment) so the training run is reproducible. No rendering is needed while training.
np.random.seed(SEED)
env = gym.make("FrozenLake-v1", is_slippery=False)
env.reset(seed=SEED)
try:
    q_table, rewards = train_sarsa(env, num_episodes, alpha=0.1, gamma=0.99)
finally:
    env.close()

  0%|          | 0/60000 [00:00<?, ?it/s]

Episode 5000, Average Reward (last 5000 episodes): 0.10
Episode 10000, Average Reward (last 5000 episodes): 0.30
Episode 15000, Average Reward (last 5000 episodes): 0.46
Episode 20000, Average Reward (last 5000 episodes): 0.56
Episode 25000, Average Reward (last 5000 episodes): 0.63
Episode 30000, Average Reward (last 5000 episodes): 0.68
Episode 35000, Average Reward (last 5000 episodes): 0.72
Episode 40000, Average Reward (last 5000 episodes): 0.74
Episode 45000, Average Reward (last 5000 episodes): 0.77
Episode 50000, Average Reward (last 5000 episodes): 0.79
Episode 55000, Average Reward (last 5000 episodes): 0.84
Episode 60000, Average Reward (last 5000 episodes): 1.00


In [None]:

def evaluate_policy(env, q_table, num_episodes, render=False, max_steps=1000):
    """
    Evaluate the greedy policy derived from the Q-table over a number of episodes.

    For each episode:
      - Reset the environment.
      - At each time step, choose the action with the highest Q-value (no exploration).
      - Optionally call env.render() before each action to visualize the agent.
      - Accumulate total rewards to measure policy performance.

    Parameters:
    - env (gym.Env): The environment to evaluate on.
    - q_table (np.ndarray): The Q-table representing the learned policy.
    - num_episodes (int): The number of evaluation episodes; must be positive.
    - render (bool): Whether to render the environment during evaluation.
    - max_steps (int): Hard cap on steps per episode. Default 1000.

    Returns:
    - average_reward (float): The average total reward over all evaluation episodes.

    Raises:
    - ValueError: If num_episodes is not positive.
    """
    if num_episodes <= 0:
        raise ValueError("num_episodes must be positive")

    total_reward = 0.0
    for _ in range(num_episodes):
        state, _ = env.reset()
        for _ in range(max_steps):
            if render:
                env.render()
            action = int(np.argmax(q_table[state]))
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            if terminated or truncated:
                break
    return total_reward / num_episodes

In [8]:
# Roll out one greedy episode with the learned Q-table and record it to ./videos.
env = gym.make("FrozenLake-v1", render_mode="rgb_array", is_slippery=False)
# RecordVideo captures frames itself on reset() and step(); calling env.render() by hand
# in rgb_array mode only returned a frame that was discarded.
env = RecordVideo(env, video_folder="./videos")
try:
    obs, _ = env.reset()
    for _ in range(1000):
        # Pure exploitation: take the highest-valued action (same as epsilon = 0).
        action = int(np.argmax(q_table[obs]))
        next_obs, reward, is_done, is_trunc, _ = env.step(action)
        if is_done or is_trunc:
            break
        obs = next_obs
finally:
    # close() finalizes the recording, so make sure it runs even if the rollout fails.
    env.close()