
# TD Control Demo — SARSA vs Q-Learning (Cliff-Walking)

This notebook runs **SARSA** (on-policy) and **Q-learning** (off-policy) on a self-contained **Cliff-Walking** environment and plots learning curves.  
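It imports the environment, the exploration schedules, both algorithms and a plotting helper from the `ch7_td_control` package. That package must be importable; the setup cell below adds its folder to `sys.path`.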


In [None]:

# Make the `ch7_td_control` package importable when it lives outside the working
# directory (e.g., on Colab or after downloading it to /mnt/data).
import sys, os
pkg_dir = '/mnt/data'  # adjust if needed, e.g., '.' if `ch7_td_control/` is in the working dir
if pkg_dir not in sys.path:
    sys.path.insert(0, pkg_dir)

# Verify the package location. Pure Python (instead of a `!ls` shell escape) so the
# check follows `pkg_dir` when it is adjusted, and also works where `ls`/`head` are unavailable.
pkg_path = os.path.join(pkg_dir, 'ch7_td_control')
if os.path.isdir(pkg_path):
    print(f"Found {pkg_path}:")
    print("\n".join(sorted(os.listdir(pkg_path))[:50]))
else:
    print(f"WARNING: {pkg_path} not found; adjust `pkg_dir` above before importing ch7_td_control.")


In [None]:

from ch7_td_control.cliff_env import CliffWalkingEnv
from ch7_td_control.policies import fixed_epsilon, linear_decay, exp_decay, inverse_glie
from ch7_td_control.sarsa import sarsa
from ch7_td_control.q_learning import q_learning
from ch7_td_control.plot_utils import plot_learning_curves

# Experiment configuration, shared by both algorithms so the comparison is like-for-like.
episodes, alpha, gamma, seed = 500, 0.1, 0.99, 0  # training episodes, step size, discount, RNG seed

# Exploration schedule: to try another one, swap it in for the active line below.
#   sched = fixed_epsilon(0.1)
#   sched = linear_decay(eps0=0.3, eps_min=0.05, T=episodes)
#   sched = exp_decay(eps0=0.3, k=0.01, eps_min=0.05)
sched = inverse_glie(c=1.0)  # GLIE-like; presumably epsilon decays ~ c / episode (TODO confirm in policies.py)


In [None]:

# Train both agents under identical settings. Each algorithm gets its own freshly seeded
# environment, so neither run inherits environment RNG state consumed by the other and the
# result does not depend on which algorithm happens to be trained first.
# NOTE(review): `sched` is shared; this assumes it is a stateless function of the episode
# index rather than a stateful iterator. TODO confirm in policies.py.
env_sarsa = CliffWalkingEnv(seed=seed)
sarsa_Q, sarsa_returns = sarsa(env_sarsa, episodes=episodes, alpha=alpha, gamma=gamma, eps_schedule=sched, seed=seed)

env_ql = CliffWalkingEnv(seed=seed)
ql_Q, ql_returns = q_learning(env_ql, episodes=episodes, alpha=alpha, gamma=gamma, eps_schedule=sched, seed=seed)

# The x-axis below is one point per episode; make sure both curves line up with it.
assert len(sarsa_returns) == len(ql_returns) == episodes, "expected one return per episode"

plot_learning_curves(list(range(episodes)), sarsa_returns, ql_returns, ma_k=20,
                     title="Cliff-Walking: SARSA (safer transients) vs Q-learning (greedy target)")



## Visualize Greedy Paths from the Learned Q-tables

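Below we extract the greedy policy from each learned Q-table and roll it out. Each rollout is capped at 200 steps, because a greedy policy that has not learned a route to the goal can loop indefinitely.  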
`A` = agent, `S` = start, `G` = goal, `X` = cliff cells, `*` = visited path.


In [None]:

from typing import Tuple, Dict

def greedy_action(Q: Dict[Tuple[Tuple[int,int], int], float], s, n_actions=4):
    """Pick the action with the highest Q-value in state ``s``.

    Args:
        Q: Tabular action-value function keyed by ``(state, action)``.
            Missing entries are treated as 0.0.
        s: The state to act in.
        n_actions: Number of discrete actions; candidates are ``0 .. n_actions - 1``.

    Returns:
        The argmax action index. Ties go to the lowest index, and 0 is
        returned when there are no candidates.
    """
    q_values = [Q.get((s, a), 0.0) for a in range(n_actions)]
    chosen, top = 0, float("-inf")
    for action, value in enumerate(q_values):
        # Strict '>' keeps the first (lowest-index) action among equal values.
        if value > top:
            chosen, top = action, value
    return chosen

def rollout_greedy(env, Q, max_steps=200):
    """Run one episode that always takes the greedy action under ``Q``.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` (returning a
            4-tuple whose third item is the done flag) and ``action_space.n``.
        Q: Tabular action-value function passed through to ``greedy_action``.
        max_steps: Upper bound on the number of steps. A greedy policy may
            never reach a terminal state, so the rollout is capped.

    Returns:
        List of visited states, starting with the state returned by ``reset()``.
    """
    state = env.reset()
    trajectory = [state]
    for _ in range(max_steps):
        state, _reward, terminal, _info = env.step(
            greedy_action(Q, state, n_actions=env.action_space.n)
        )
        trajectory.append(state)
        if terminal:
            break
    return trajectory

# Roll out each learned greedy policy on a fresh, identically seeded environment and
# render the visited path. One loop replaces the two copy-pasted per-algorithm stanzas.
greedy_paths = {}
for label, learned_Q in (("SARSA", sarsa_Q), ("Q-learning", ql_Q)):
    env = CliffWalkingEnv(seed=seed)
    greedy_paths[label] = rollout_greedy(env, learned_Q)
    print(f"{label} greedy rollout:")
    env.render(greedy_paths[label])

# Keep the per-algorithm names this cell has always exposed.
path_sarsa, path_ql = greedy_paths["SARSA"], greedy_paths["Q-learning"]
