# SARSA
SARSA (State-Action-Reward-State-Action) is an on-policy reinforcement learning algorithm used to learn the optimal policy for a Markov Decision Process. At each step, SARSA updates the action-value function (Q-value) based on the current state, the action taken, the reward received, the next state, and the next action chosen by the current policy. The update rule is:

Q(s, a) ← Q(s, a) + α [r + γ Q(s', a') − Q(s, a)]

where:
- s: current state
- a: current action
- r: reward received after taking action a in state s
- s': next state
- a': next action chosen in state s'
- α: learning rate
- γ: discount factor

SARSA is called "on-policy" because it updates its Q-values using the actions actually taken by the current policy, rather than the best possible actions.

## Key Point of SARSA

The main point of the SARSA algorithm is to learn the action-value function (Q-function) for the current policy, enabling the agent to improve its policy over time by balancing exploration and exploitation. SARSA updates its estimates based on the actions actually taken, making it sensitive to the exploration strategy used during learning.

## Relationship to TD(λ)

SARSA applies TD(lambda) prediction methods to state-action pairs, rather than to states
We then need a trace for each action-state pair: E(s,a)

In [8]:
from utils import compress_state, generate_extreme_value_state_image
import numpy as np

import minari
from collections import defaultdict
from IPython.display import display, HTML
import uuid

def sarsa_algorithm(dataset_id, gamma = 0.99, lambda_ = 0.0, alpha = 1.0):
    """
    Implements the SARSA algorithm for reinforcement learning: updates the Q-values based on the agent's actions and rewards.
    Args:
        dataset_id (str): The ID of the dataset to use.
        gamma (float): Discount factor for future rewards.
        lambda_ (float): Eligibility trace decay factor.
        alpha (float): Learning rate for Q-value updates.
    Returns:
        Q (dict): A dictionary mapping state-action pairs to Q-values.

        state_locations (dict): maps state keys to (episode_index, timestep) of first occurrence.

    """
    dataset = minari.load_dataset(dataset_id)
    Q = defaultdict(float)
    state_locations = {}

    for episode_index, episode in enumerate(dataset.iterate_episodes()):
        observations = episode.observations
        rewards = episode.rewards
        actions = episode.actions

        for t in range(len(actions)):
            obs_t = {k: v[t] for k, v in observations.items()}
            obs_tp1 = {k: v[t + 1] for k, v in observations.items()} if t + 1 < len(actions) else None
            reward = rewards[t]
            action = actions[t]

            s_t = compress_state(obs_t)
            s_tp1 = compress_state(obs_tp1) if obs_tp1 is not None else None

            # SARSA update
            if s_tp1 is not None:
                next_action = np.argmax([Q[(s_tp1, a)] for a in range(dataset.action_space.n)])
                Q[(s_t, action)] += alpha * (reward + gamma * Q[(s_tp1, next_action)] - Q[(s_t, action)])
            else:
                # Fix: dataset.action_space is a gym.spaces.Discrete, use range(n) to iterate actions
                next_action = np.argmax([Q[(s_tp1, a)] for a in range(dataset.action_space.n)])
                Q[(s_t, action)] += alpha * (reward + gamma * Q[(s_tp1, next_action)] - Q[(s_t, action)])
        
        if s_t not in state_locations:
                state_locations[s_t] = (episode_index, t)

        # Print Q-value stats
        q_values = np.array(list(Q.values()))
        print("Q-value function statistics:")
        print(f"  Count:       {len(q_values)}")
        print(f"  Min value:   {np.min(q_values):.4f}")
        print(f"  Max value:   {np.max(q_values):.4f}")
        print(f"  Mean value:  {np.mean(q_values):.4f}")
        print(f"  Std dev:     {np.std(q_values):.4f}")
        return Q, state_locations
                

In [9]:
from utils import generate_extreme_value_state_image_from_q_table
dataset_id = "minigrid/BabyAI-Pickup/optimal-fullobs-v0"
output_path = "./minigrid/BabyAI-Pickup/optimal-fullobs-v0/q_learning/highest_value_function.png"

generate_extreme_value_state_image_from_q_table(
    dataset_id=dataset_id,
    output_path=output_path,
    q_fn_generator=sarsa_algorithm,
    highest=True
) # Generate image of highest Q-values

# === Display in notebook ===
cache_buster = uuid.uuid4().hex
HTML(f'<img src="{output_path}?v={cache_buster}" width="400">')

Q-value function statistics:
  Count:       175
  Min value:   0.0000
  Max value:   0.9594
  Mean value:  0.0055
  Std dev:     0.0723
Selected state: ((2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 2, 5, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 2, 5, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 2, 5, 0, 5, 2, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 2, 5, 0, 2, 5, 0, 1, 0, 0, 1, 0, 0, 7, 4, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 2, 5, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 4, 3, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 6, 5, 0, 1, 0, 0, 1, 0, 0, 2, 5, 0, 2, 5, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 2, 5, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 2, 5, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 2, 5, 0, 2, 5, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 

  from pkg_resources import resource_stream, resource_exists


Sampling rejected: unreachable object at (15, 5)
