# LunarLander: SARSA, REINFORCE, REINFORCE + Baseline

**Lecture Slides**

https://drive.google.com/drive/folders/1AlQPE7CJ8NMlhmpuwABtIU1ERTE-lOTY?usp=share_link






In [None]:
# If running on Google Colab, install dependencies
# You can re-run if versions change.
import sys
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    !pip -q install swig
    !pip -q install gymnasium[box2d] torch imageio imageio-ffmpeg --upgrade

import numpy as np
import torch, torch.nn as nn, torch.optim as optim
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
from IPython.display import Video, display
import tempfile, os, glob, random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

  Preparing metadata (setup.py) ... done
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 594.3/594.3 MB 849.2 kB/s eta 0:00:00
  Building wheel for box2d-py (setup.py) ... done
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.24.0+cu126 requires torch==2.9.0, but you have torch 2.9.1 which is incompatible.
torchaudio 2.9.0+cu126 requires torch==2.9.0, but you have torch 2.9.1 which is incompatible.
Using device: cuda


In [None]:
# Utilities: discounted returns and inline video display via RecordVideo
def compute_returns(rewards, gamma):
    """Return the discounted return-to-go for every step of one episode.

    Entry ``t`` of the result is ``rewards[t] + gamma * rewards[t+1] + ...``,
    computed with one backward sweep ``G_t = r_t + gamma * G_{t+1}`` starting
    from ``G_T = 0``.

    Args:
        rewards: per-step rewards of a single episode, in time order.
        gamma: discount factor.

    Returns:
        A list with one float per reward (empty for an empty episode).
    """
    n_steps = len(rewards)
    out = [0.0] * n_steps
    running = 0.0
    # Walk from the last step to the first so each G_t reuses G_{t+1}.
    for idx in range(n_steps - 1, -1, -1):
        running = rewards[idx] + gamma * running
        out[idx] = running
    return out

def evaluate_and_show_video(env_id, policy_fn, num_episodes=1, name_prefix='eval'):
    """Roll out ``policy_fn`` in ``env_id`` while recording, then embed a video inline.

    Every episode is recorded by ``RecordVideo`` into a fresh temporary
    directory, and the most recently written ``.mp4`` is embedded in the notebook.

    Args:
        env_id: Gymnasium environment id, e.g. ``'LunarLander-v3'``.
        policy_fn: callable mapping one observation to an action.
        num_episodes: number of evaluation episodes to run.
        name_prefix: filename prefix passed to ``RecordVideo``.

    Returns:
        ``(returns, tmpdir_obj)``: the per-episode undiscounted returns and the
        ``TemporaryDirectory`` that holds the videos. The directory is removed
        once ``tmpdir_obj`` is garbage-collected, so keep a reference for as
        long as the files are needed.
    """
    tmpdir_obj = tempfile.TemporaryDirectory()
    video_dir = tmpdir_obj.name
    env = gym.make(env_id, render_mode='rgb_array')
    env = RecordVideo(env, video_folder=video_dir, episode_trigger=lambda eid: True, name_prefix=name_prefix)

    returns = []
    try:
        for _ep in range(num_episodes):
            state, _ = env.reset()
            done = False
            ep_ret = 0.0
            while not done:
                action = policy_fn(state)
                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                ep_ret += reward
            returns.append(ep_ret)
    finally:
        # Release the environment (and its recorder) even if the policy raises.
        env.close()

    # Report the score regardless of whether a video file was produced.
    if returns:
        print(f'Average return over {len(returns)} eval episodes: {np.mean(returns):.2f}')

    # Choose the newest mp4 by modification time. Lexicographic order would
    # rank e.g. 'x-10.mp4' before 'x-2.mp4' when episode numbers are unpadded.
    mp4s = glob.glob(os.path.join(video_dir, '*.mp4'))
    if mp4s:
        newest = max(mp4s, key=os.path.getmtime)
        display(Video(newest, embed=True, html_attributes='controls loop autoplay'))
    else:
        print('No video produced.')
    return returns, tmpdir_obj


## SARSA

In [None]:
# SARSA network (from sarsa.py)
class QNetwork(nn.Module):
    """MLP mapping a state vector to one Q-value per discrete action.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear, with ``hidden_dim``
    units in both hidden layers.
    """

    def __init__(self, state_dim, num_actions, hidden_dim=128):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        ]
        # The attribute is named ``net`` so state_dict keys stay 'net.<i>.weight' etc.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map states of shape ``(..., state_dim)`` to Q-values of shape ``(..., num_actions)``."""
        q_values = self.net(x)
        return q_values

def train_sarsa(
    env_id='LunarLander-v3',
    gamma=0.99, lr=5e-4,
    max_episodes=500, max_steps=1000,
    eps_start=1.0, eps_end=0.05, eps_decay_steps=200_000
):
    """Train a Q-network with online one-step (semi-gradient) SARSA.

    Every transition (s, a, r, s', a') triggers one Adam step on the Huber
    loss between Q(s, a) and the on-policy target r + gamma * Q(s', a'), or r
    alone when s' is a true terminal state. Actions come from an epsilon-greedy
    policy whose epsilon decays linearly from ``eps_start`` to ``eps_end`` over
    ``eps_decay_steps`` environment steps, counted across episodes.

    Args:
        env_id: Gymnasium environment id with a discrete action space.
        gamma: discount factor.
        lr: Adam learning rate.
        max_episodes: number of training episodes.
        max_steps: per-episode step cap applied by this loop.
        eps_start, eps_end, eps_decay_steps: linear epsilon schedule.

    Returns:
        The trained ``QNetwork`` (on ``device``).
    """
    env = gym.make(env_id)
    state_dim = env.observation_space.shape[0]
    num_actions = env.action_space.n
    q_net = QNetwork(state_dim, num_actions).to(device)
    optimizer = optim.Adam(q_net.parameters(), lr=lr)

    def get_epsilon(step):
        frac = min(step / eps_decay_steps, 1.0)
        return eps_start + frac * (eps_end - eps_start)

    def select_action(state, epsilon):
        if random.random() < epsilon:
            return env.action_space.sample()
        with torch.no_grad():
            s = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
            q_vals = q_net(s)
            return int(torch.argmax(q_vals, dim=1).item())

    global_step = 0
    try:
        for ep in range(1, max_episodes + 1):
            state, _ = env.reset()
            ep_ret = 0.0
            epsilon = get_epsilon(global_step)
            action = select_action(state, epsilon)
            for _step in range(max_steps):
                global_step += 1
                epsilon = get_epsilon(global_step)
                next_state, reward, terminated, truncated, _ = env.step(action)
                ep_ret += reward
                # Only a true terminal state has zero future value. A time-limit
                # truncation merely stops the episode, so its target must still
                # bootstrap from Q(s', a').
                if terminated:
                    next_action = 0  # placeholder; never used in the target
                else:
                    next_action = select_action(next_state, epsilon)
                # SARSA update
                s_t = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
                q_sa = q_net(s_t)[0, action]
                with torch.no_grad():
                    if terminated:
                        target = torch.as_tensor(reward, dtype=torch.float32, device=device)
                    else:
                        ns_t = torch.as_tensor(next_state, dtype=torch.float32, device=device).unsqueeze(0)
                        target = reward + gamma * q_net(ns_t)[0, next_action]
                loss = nn.functional.smooth_l1_loss(q_sa, target)
                optimizer.zero_grad()
                loss.backward()
                # Cap the gradient norm so one large TD error cannot cause a huge step.
                nn.utils.clip_grad_norm_(q_net.parameters(), 10.0)
                optimizer.step()
                state, action = next_state, next_action
                if terminated or truncated:
                    break
            if ep % 10 == 0:
                print(f'[SARSA] Episode {ep:4d} | Return: {ep_ret:7.2f} | Epsilon: {epsilon:.3f}')
    finally:
        # Close the environment even when training is interrupted.
        env.close()
    return q_net

def greedy_policy_sarsa(q_net):
    """Wrap ``q_net`` as a deterministic policy.

    The returned callable takes one observation and returns, as a Python int,
    the index of the action with the largest Q-value. No gradients are tracked.
    """
    def _fn(state):
        with torch.no_grad():
            obs = torch.as_tensor(state, dtype=torch.float32, device=device)
            q_values = q_net(obs.unsqueeze(0))
        best_action = torch.argmax(q_values, dim=1)
        return int(best_action.item())
    return _fn


### Train SARSA (2400 episodes; lower `max_episodes` for a quick demo)

In [None]:
# Train SARSA for 2400 episodes. This is not a quick demo; increase
# max_episodes further for better results, or lower it for a faster run.
sarsa_episodes = 2400
sarsa_net = train_sarsa(max_episodes=sarsa_episodes)


[SARSA] Episode   10 | Return: -231.77 | Epsilon: 0.996
[SARSA] Episode   20 | Return:    3.56 | Epsilon: 0.990
[SARSA] Episode   30 | Return:  -94.49 | Epsilon: 0.986
[SARSA] Episode   40 | Return: -335.23 | Epsilon: 0.981
[SARSA] Episode   50 | Return: -300.18 | Epsilon: 0.976
[SARSA] Episode   60 | Return:  -96.63 | Epsilon: 0.972
[SARSA] Episode   70 | Return: -104.98 | Epsilon: 0.968
[SARSA] Episode   80 | Return: -614.92 | Epsilon: 0.964
[SARSA] Episode   90 | Return: -123.82 | Epsilon: 0.959
[SARSA] Episode  100 | Return: -144.94 | Epsilon: 0.955
[SARSA] Episode  110 | Return: -298.32 | Epsilon: 0.949
[SARSA] Episode  120 | Return: -120.27 | Epsilon: 0.945
[SARSA] Episode  130 | Return: -104.18 | Epsilon: 0.940
[SARSA] Episode  140 | Return: -111.48 | Epsilon: 0.936
[SARSA] Episode  150 | Return: -246.98 | Epsilon: 0.931
[SARSA] Episode  160 | Return: -124.74 | Epsilon: 0.926
[SARSA] Episode  170 | Return:  -79.40 | Epsilon: 0.922
[SARSA] Episode  180 | Return: -132.32 | Epsilon

  logger.warn(
  IMAGEMAGICK_BINARY = r"C:\Program Files\ImageMagick-6.8.8-Q16\magick.exe"


Average return over 1 eval episodes: 301.09


In [None]:
# Greedy (argmax-Q) rollouts of the trained SARSA network over 5 episodes, with
# a recorded video embedded inline. Binding the result keeps a reference to
# the returned TemporaryDirectory that holds the video files.
sarsa_greedy = greedy_policy_sarsa(sarsa_net)
_ = evaluate_and_show_video('LunarLander-v3', sarsa_greedy, num_episodes=5, name_prefix='sarsa')

  logger.warn(


Average return over 5 eval episodes: 174.76


## REINFORCE

In [None]:
class PolicyNet(nn.Module):
    """MLP policy that maps a state vector to unnormalized action logits.

    Two hidden layers of ``hidden_dim`` ReLU units. The output is consumed via
    ``torch.distributions.Categorical(logits=...)`` in training.
    """

    def __init__(self, state_dim, num_actions, hidden_dim=256):
        super().__init__()
        hidden = [
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        ]
        head = nn.Linear(hidden_dim, num_actions)
        # Attribute name ``net`` fixes the state_dict key layout; keep it.
        self.net = nn.Sequential(*hidden, head)

    def forward(self, x):
        """Return action logits (unnormalized log-probabilities)."""
        logits = self.net(x)
        return logits

def train_reinforce(
    env_id='LunarLander-v3', gamma=0.99, lr=1e-3,
    max_episodes=500, max_steps=1000, batch_episodes=10
):
    """Train a softmax policy with vanilla REINFORCE (Monte-Carlo policy gradient).

    Episodes are sampled with the current policy. After every
    ``batch_episodes`` episodes the policy takes one Adam step on
    ``-mean_t(log pi(a_t | s_t) * G_t)`` over all time steps in the batch,
    where ``G_t`` is the discounted return-to-go (no baseline). A trailing
    partial batch, which occurs when ``max_episodes`` is not a multiple of
    ``batch_episodes``, is used for a final update rather than discarded.

    Args:
        env_id: Gymnasium environment id with a discrete action space.
        gamma: discount factor for the returns.
        lr: Adam learning rate.
        max_episodes: number of training episodes.
        max_steps: per-episode step cap applied by this loop.
        batch_episodes: episodes collected per gradient step.

    Returns:
        The trained ``PolicyNet`` (on ``device``).
    """
    env = gym.make(env_id)
    state_dim = env.observation_space.shape[0]
    num_actions = env.action_space.n
    policy = PolicyNet(state_dim, num_actions).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)

    # Per-step log pi(a|s) tensors (with graph) and the matching float returns.
    batch_log_probs, batch_returns = [], []

    def _update():
        """Take one policy-gradient step on the collected batch, then clear it."""
        if not batch_log_probs:
            return
        logps_t = torch.stack(batch_log_probs)
        returns_t = torch.as_tensor(batch_returns, dtype=torch.float32, device=device)
        loss = -(logps_t * returns_t).mean()
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(policy.parameters(), 10.0)
        optimizer.step()
        batch_log_probs.clear()
        batch_returns.clear()

    episodes_in_batch = 0
    try:
        for ep in range(1, max_episodes + 1):
            state, _ = env.reset()
            ep_rewards = []
            for _step in range(max_steps):
                s = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
                dist = torch.distributions.Categorical(logits=policy(s))
                action_t = dist.sample()  # shape (1,), matching the batch of one state
                # Keep the autograd graph: this log-prob is differentiated in _update.
                batch_log_probs.append(dist.log_prob(action_t).squeeze(0))
                state, reward, terminated, truncated, _ = env.step(int(action_t.item()))
                ep_rewards.append(reward)
                if terminated or truncated:
                    break
            batch_returns.extend(compute_returns(ep_rewards, gamma))
            episodes_in_batch += 1
            if ep % 10 == 0:
                print(f'[REINFORCE] Episode {ep:4d} | Return: {sum(ep_rewards):7.2f}')
            if episodes_in_batch >= batch_episodes:
                _update()
                episodes_in_batch = 0
        if episodes_in_batch:
            _update()
    finally:
        # Close the environment even when training is interrupted.
        env.close()
    return policy

def greedy_policy_reinforce(policy):
    """Turn a trained ``PolicyNet`` into a deterministic policy.

    The returned callable maps one observation to the index (Python int) of
    the action with the largest logit. It does no sampling and tracks no gradients.
    """
    def _fn(state):
        obs = torch.as_tensor(state, dtype=torch.float32, device=device)
        with torch.no_grad():
            logits = policy(obs.unsqueeze(0))
        return int(logits.argmax(dim=1).item())
    return _fn


### Train REINFORCE (2000 episodes; lower `max_episodes` for a quick demo)

In [None]:
# Train vanilla REINFORCE (2000 episodes, one update per 20 episodes), then run
# its greedy policy for one episode and embed the recorded video.
reinforce_policy = train_reinforce(max_episodes=2000, batch_episodes=20)
reinforce_greedy = greedy_policy_reinforce(reinforce_policy)
_ = evaluate_and_show_video('LunarLander-v3', reinforce_greedy, num_episodes=1, name_prefix='reinforce')


[REINFORCE] Episode   10 | Return: -147.64
[REINFORCE] Episode   20 | Return: -265.04
[REINFORCE] Episode   30 | Return:  -87.81
[REINFORCE] Episode   40 | Return: -464.75
[REINFORCE] Episode   50 | Return:  -98.99
[REINFORCE] Episode   60 | Return: -151.67
[REINFORCE] Episode   70 | Return: -382.93
[REINFORCE] Episode   80 | Return: -381.29
[REINFORCE] Episode   90 | Return: -141.87
[REINFORCE] Episode  100 | Return: -293.73
[REINFORCE] Episode  110 | Return:  -94.92
[REINFORCE] Episode  120 | Return: -110.78
[REINFORCE] Episode  130 | Return: -190.50
[REINFORCE] Episode  140 | Return: -314.10
[REINFORCE] Episode  150 | Return: -147.35
[REINFORCE] Episode  160 | Return: -139.79
[REINFORCE] Episode  170 | Return: -101.40
[REINFORCE] Episode  180 | Return: -191.04
[REINFORCE] Episode  190 | Return:   26.86
[REINFORCE] Episode  200 | Return: -110.58
[REINFORCE] Episode  210 | Return: -139.92
[REINFORCE] Episode  220 | Return: -134.18
[REINFORCE] Episode  230 | Return: -209.14
[REINFORCE]

KeyboardInterrupt: 

## REINFORCE + Baseline

In [None]:
class ActorCritic(nn.Module):
    """Shared-trunk network with a policy (actor) head and a state-value (critic) head.

    ``forward`` returns ``(logits, value)``. ``logits`` has shape
    ``(..., num_actions)``. ``value`` is the critic output with its trailing
    singleton dimension squeezed away, i.e. shape ``(...)``.
    """

    def __init__(self, state_dim, num_actions, hidden_dim=256):
        super().__init__()
        trunk_layers = [
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        ]
        self.body = nn.Sequential(*trunk_layers)
        self.pi = nn.Linear(hidden_dim, num_actions)  # actor head: action logits
        self.v = nn.Linear(hidden_dim, 1)             # critic head: scalar value

    def forward(self, x):
        features = self.body(x)
        logits = self.pi(features)
        value = self.v(features).squeeze(-1)
        return logits, value

def train_reinforce_baseline(
    env_id='LunarLander-v3', gamma=0.99, lr=1e-3,
    max_episodes=500, max_steps=1000, batch_episodes=10, value_coef=0.5
):
    """Train an ``ActorCritic`` with REINFORCE and a learned state-value baseline.

    Episodes are sampled with the current policy. After every
    ``batch_episodes`` episodes one Adam step is taken on
    ``policy_loss + value_coef * value_loss``, where

    * ``policy_loss = -mean_t(log pi(a_t | s_t) * A_t)``, with ``A_t`` the
      normalized advantage ``G_t - V(s_t)`` (V detached), and
    * ``value_loss = Huber(V(s_t), G_t)``, regressing the critic onto the
      discounted Monte-Carlo returns.

    A trailing partial batch is used for a final update rather than discarded.

    Args:
        env_id: Gymnasium environment id with a discrete action space.
        gamma: discount factor for the returns.
        lr: Adam learning rate.
        max_episodes: number of training episodes.
        max_steps: per-episode step cap applied by this loop.
        batch_episodes: episodes collected per gradient step.
        value_coef: weight of the critic loss in the combined loss.

    Returns:
        The trained ``ActorCritic`` (on ``device``).
    """
    env = gym.make(env_id)
    state_dim = env.observation_space.shape[0]
    num_actions = env.action_space.n
    ac = ActorCritic(state_dim, num_actions).to(device)
    optimizer = optim.Adam(ac.parameters(), lr=lr)

    # Per-step log-probs and values (with graph), plus the matching float returns.
    b_logps, b_values, b_returns = [], [], []

    def _update():
        """Take one actor-critic step on the collected batch, then clear it."""
        if not b_logps:
            return
        logps_t = torch.stack(b_logps)
        values_t = torch.stack(b_values)
        returns_t = torch.as_tensor(b_returns, dtype=torch.float32, device=device)
        # Advantage = return minus baseline. The baseline is detached so the
        # policy term does not backpropagate through V(s).
        adv = returns_t - values_t.detach()
        # Normalize advantages to stabilize learning. std() of a single sample
        # is NaN, so only rescale when there is more than one.
        adv = adv - adv.mean()
        if adv.numel() > 1:
            adv = adv / (adv.std() + 1e-8)
        policy_loss = -(logps_t * adv).mean()
        value_loss = nn.functional.smooth_l1_loss(values_t, returns_t)
        loss = policy_loss + value_coef * value_loss
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(ac.parameters(), 10.0)
        optimizer.step()
        b_logps.clear()
        b_values.clear()
        b_returns.clear()

    episodes_in_batch = 0
    try:
        for ep in range(1, max_episodes + 1):
            state, _ = env.reset()
            ep_rewards = []
            for _step in range(max_steps):
                s = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
                logits, value = ac(s)
                dist = torch.distributions.Categorical(logits=logits)
                action_t = dist.sample()  # shape (1,), matching the batch of one state
                b_logps.append(dist.log_prob(action_t).squeeze(0))
                b_values.append(value.squeeze(0))
                state, reward, terminated, truncated, _ = env.step(int(action_t.item()))
                ep_rewards.append(reward)
                if terminated or truncated:
                    break
            b_returns.extend(compute_returns(ep_rewards, gamma))
            episodes_in_batch += 1
            if ep % 10 == 0:
                print(f'[REINFORCE+BL] Episode {ep:4d} | Return: {sum(ep_rewards):7.2f}')
            if episodes_in_batch >= batch_episodes:
                _update()
                episodes_in_batch = 0
        if episodes_in_batch:
            _update()
    finally:
        # Close the environment even when training is interrupted.
        env.close()
    return ac

def greedy_policy_ac(ac):
    """Turn a trained ``ActorCritic`` into a deterministic policy.

    The returned callable maps one observation to the index (Python int) of
    the action with the largest actor logit. The critic output is ignored, and
    no gradients are tracked.
    """
    def _fn(state):
        with torch.no_grad():
            batch = torch.as_tensor(state, dtype=torch.float32, device=device)[None]
            action_logits = ac(batch)[0]
            return int(action_logits.argmax(dim=1).item())
    return _fn


### Train REINFORCE + Baseline (6000 episodes; lower `max_episodes` for a quick demo)

In [None]:
# Train REINFORCE with a learned value baseline (6000 episodes, one update per
# 20 episodes, critic loss weighted by 0.5). Then run its greedy policy for one
# episode and embed the recorded video.
ac_model = train_reinforce_baseline(max_episodes=6000, batch_episodes=20, value_coef=0.5)
ac_greedy = greedy_policy_ac(ac_model)
_ = evaluate_and_show_video('LunarLander-v3', ac_greedy, num_episodes=1, name_prefix='reinforce_bl')
