# A2C (Advantage Actor-Critic) — Low-Level PyTorch Implementation (CartPole-v1)

A2C is an **on-policy actor-critic** algorithm:

- the **actor** learns a policy $\pi_\theta(a\mid s)$ (how to act)
- the **critic** learns a value function $V_\phi(s)$ (how good a state is)
- the actor is trained with **advantages** ("better than expected" signals)

This notebook builds the math carefully, then implements A2C with **minimal PyTorch** (no RL libraries, no high-level training abstractions), using a **vectorized Gymnasium environment** for synchronous rollouts.

---

## Learning goals

By the end you should be able to:

- derive the A2C update from the **policy gradient theorem**
- explain why the **baseline** (critic) reduces variance
- implement **GAE($\gamma,\lambda$)** and n-step bootstrapped returns
- train an A2C agent on `CartPole-v1` and visualize learning with Plotly
- map the concepts to **Stable-Baselines3** A2C hyperparameters


## Notebook roadmap

1. A2C intuition + what “advantage” means
2. Mathematical formulation (LaTeX)
3. Low-level PyTorch implementation (actor + critic)
4. Training on CartPole with vectorized rollouts
5. Plotly diagnostics (returns, losses, policy/value slices)
6. Stable-Baselines3 A2C reference + hyperparameters


In [None]:
import math
import time

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio
from plotly.subplots import make_subplots

try:
    import gymnasium as gym
    GYMNASIUM_AVAILABLE = True
except Exception as e:
    GYMNASIUM_AVAILABLE = False
    _GYM_IMPORT_ERROR = e

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    TORCH_AVAILABLE = True
except Exception as e:
    TORCH_AVAILABLE = False
    _TORCH_IMPORT_ERROR = e


pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

assert GYMNASIUM_AVAILABLE, f"gymnasium import failed: {_GYM_IMPORT_ERROR}"
assert TORCH_AVAILABLE, f"torch import failed: {_TORCH_IMPORT_ERROR}"

print('gymnasium', gym.__version__)
print('torch', torch.__version__)


In [None]:
# --- Run configuration ---

# Keep FAST_RUN=True for a quick demo.
# For a more reliable "solve", set FAST_RUN=False.
FAST_RUN = True

ENV_ID = "CartPole-v1"  # discrete actions, small continuous state
SEED = 42

# A2C is usually run with multiple envs in parallel.
N_ENVS = 8 if FAST_RUN else 16

# Rollout horizon per env (A2C commonly uses small n_steps).
N_STEPS = 5

# Total interaction budget
TOTAL_TIMESTEPS = 30_000 if FAST_RUN else 200_000

# Core RL hyperparameters
GAMMA = 0.99
GAE_LAMBDA = 1.0  # 1.0 => classic advantage w/ n-step bootstrapping

# Loss weights
ENT_COEF = 0.01
VF_COEF = 0.5

# Optimization
LR = 7e-4
MAX_GRAD_NORM = 0.5
RMSPROP_EPS = 1e-5

# Optional: normalize advantage each update
NORMALIZE_ADVANTAGE = True

# Network
HIDDEN_SIZES = (128, 128)

# Logging
LOG_EVERY_UPDATES = 50
RETURN_SMOOTHING_WINDOW = 50

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device', DEVICE)


## 1) A2C intuition: actor + critic + advantage

### Actor
The actor is a stochastic policy $\pi_\theta(a\mid s)$.

- It outputs a distribution over actions.
- We sample actions from that distribution to explore.

### Critic
The critic is a value function $V_\phi(s)$.

- It predicts the **expected discounted return** from state $s$ when actions are chosen by the current policy.
- It is trained via regression to match a bootstrapped return target.

### Advantage
The advantage measures how much better an action did compared to what the critic expected:

$$
A(s_t, a_t) = Q(s_t, a_t) - V(s_t).
$$

If $A(s_t,a_t)$ is positive, the action was better than expected, and the actor should increase its probability.

### Why “A2C”?
A2C is the **synchronous** version of A3C:

- A3C: many workers update parameters asynchronously.
- A2C: many workers collect experience **in parallel**, then we do a **single synchronized** update.

In practice, A2C typically uses a vectorized environment and batches data as:

$$
\text{batch size} = n_{\text{env}} \times n_{\text{steps}}.
$$
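
As a quick check on the bookkeeping, the sketch below computes the batch size and the resulting number of gradient updates for the `FAST_RUN=True` settings from the configuration cell (8 envs, 5 steps, 30,000 timesteps). The variable names are illustrative only.

```python
# Batch-size arithmetic for a synchronous rollout (values mirror the FAST_RUN config).
n_envs = 8
n_steps = 5
total_timesteps = 30_000

batch_size = n_envs * n_steps               # transitions consumed per gradient update
n_updates = total_timesteps // batch_size   # one update per rollout

print(f"batch_size={batch_size}, n_updates={n_updates}")
```

With such a short horizon, each update sees only 40 transitions, which is why running several environments in parallel matters for gradient quality.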


## 2) Mathematical formulation (policy gradient + baseline)

We model the environment as an MDP $(\mathcal{S}, \mathcal{A}, P, r, \gamma)$.

### Return
The discounted return from time $t$ is:

$$
G_t = \sum_{k=0}^{\infty} \gamma^k r_{t+k}.
$$
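
For a finished episode the sum is finite, and it can be computed in one backward pass using $G_t = r_t + \gamma G_{t+1}$. The following is a minimal NumPy sketch; the helper name `discounted_returns` and the toy rewards are assumptions made for illustration.

```python
import numpy as np

def discounted_returns(rewards, gamma):
    """Return G_t for every step of a finished episode via G_t = r_t + gamma * G_{t+1}."""
    rewards = np.asarray(rewards, dtype=np.float64)
    returns = np.zeros_like(rewards)
    running = 0.0  # G after the last step is zero
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

G = discounted_returns([1.0, 1.0, 1.0], gamma=0.5)
# G_2 = 1.0, G_1 = 1 + 0.5 * 1.0 = 1.5, G_0 = 1 + 0.5 * 1.5 = 1.75
print(G)
```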

### Objective
We want to maximize expected return:

$$
J(\theta) = \mathbb{E}_{\tau\sim\pi_\theta}\left[\sum_{t=0}^{\infty} \gamma^t r_t\right].
$$

### Policy gradient theorem
A standard form is:

$$
\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\nabla_\theta \log \pi_\theta(a_t\mid s_t)\, Q^{\pi_\theta}(s_t, a_t)\right].
$$

### Baseline (variance reduction)
We can subtract a baseline $b(s_t)$ without changing the expectation:

$$
\mathbb{E}[\nabla_\theta \log \pi_\theta(a_t\mid s_t)\, b(s_t)] = 0.
$$
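
To see why, fix a state $s$ and use the fact that $b(s)$ does not depend on the action. The derivation below is written for a discrete action space; for continuous actions the sum becomes an integral.

$$
\begin{aligned}
\mathbb{E}_{a\sim\pi_\theta(\cdot\mid s)}\left[\nabla_\theta \log \pi_\theta(a\mid s)\, b(s)\right]
&= b(s) \sum_a \pi_\theta(a\mid s)\, \nabla_\theta \log \pi_\theta(a\mid s) \\
&= b(s) \sum_a \nabla_\theta \pi_\theta(a\mid s) \\
&= b(s)\, \nabla_\theta \sum_a \pi_\theta(a\mid s)
= b(s)\, \nabla_\theta 1 = 0.
\end{aligned}
$$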

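Choosing the state value as the baseline, approximated by the critic $b(s_t)=V_\phi(s_t)$, yields the advantage form:

$$
\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\nabla_\theta \log \pi_\theta(a_t\mid s_t)\, A^{\pi_\theta}(s_t,a_t)\right],
\quad A^{\pi_\theta}(s_t,a_t) = Q^{\pi_\theta}(s_t,a_t) - V^{\pi_\theta}(s_t).
$$

In practice the true advantage is unknown and is replaced by an estimate $\hat{A}_t$ built from sampled rewards and the critic.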
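
The variance claim can be checked exactly on a toy problem. The sketch below assumes a single state with three actions, a softmax policy over hypothetical logits, and made-up action values `q`; it compares the score-function gradient estimator with no baseline against the one using $V(s)$ as baseline.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([0.0, 0.5, -0.5])   # hypothetical policy logits for one state
q = np.array([10.0, 11.0, 9.0])       # hypothetical action values Q(s, a)
pi = softmax(logits)
v = float(pi @ q)                     # V(s) = E_a[Q(s, a)]

def grad_stats(baseline):
    # Score function w.r.t. the logits: d log pi(a) / d logits = onehot(a) - pi.
    score = np.eye(len(pi)) - pi                      # row a is the score for action a
    g = score * (q - baseline)[:, None]               # estimator value for each action
    mean = pi @ g                                     # exact expectation over a ~ pi
    var = float(pi @ ((g - mean) ** 2).sum(axis=1))   # total variance over components
    return mean, var

mean_plain, var_plain = grad_stats(0.0)
mean_base, var_base = grad_stats(v)
print("same expected gradient:", np.allclose(mean_plain, mean_base))
print(f"variance without baseline: {var_plain:.3f}, with V(s): {var_base:.3f}")
```

Both estimators have the same expectation, but in this example the baseline removes the large common offset in `q`, which is what dominates the variance of the plain estimator.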

### Bootstrapped n-step return
With a rollout horizon $T$ (a.k.a. `n_steps`), we use a bootstrapped target:

$$
\hat{R}_t = \sum_{k=0}^{T-1-t} \gamma^k r_{t+k} + \gamma^{T-t} V_\phi(s_T).
$$
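
The same target can be computed with a backward recursion seeded by the critic's value of the last state. This is a minimal NumPy sketch for one environment, assuming no episode ends inside the rollout; the helper name `n_step_returns` and the numbers are illustrative.

```python
import numpy as np

def n_step_returns(rewards, last_value, gamma):
    """Bootstrapped targets R_t = r_t + gamma * R_{t+1}, seeded with R_T = V(s_T)."""
    T = len(rewards)
    returns = np.zeros(T)
    running = last_value
    for t in reversed(range(T)):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

R = n_step_returns(rewards=[1.0, 1.0, 1.0], last_value=4.0, gamma=0.5)
# R_2 = 1 + 0.5 * 4 = 3.0, R_1 = 1 + 0.5 * 3 = 2.5, R_0 = 1 + 0.5 * 2.5 = 2.25
print(R)
```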

### Generalized Advantage Estimation (GAE)
GAE defines the TD residual:

$$
\delta_t = r_t + \gamma V_\phi(s_{t+1}) - V_\phi(s_t)
$$

and computes advantages with an exponentially-weighted sum:

$$
\hat{A}_t^{\mathrm{GAE}(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l\, \delta_{t+l}.
$$

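- $\lambda=1$ recovers the classic advantage, the bootstrapped return minus $V_\phi(s_t)$: lowest bias, highest variance.
- $\lambda=0$ reduces to the one-step TD residual $\delta_t$; in general, smaller $\lambda$ reduces variance but increases bias.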
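
Both limits can be verified numerically. The sketch below truncates the sum at the end of a 3-step rollout, assumes no episode boundary inside it, and uses hypothetical critic values; `gae_no_dones` is an illustrative helper, not part of the notebook.

```python
import numpy as np

def gae_no_dones(rewards, values, last_value, gamma, lam):
    """GAE over one rollout segment, assuming no episode ends inside it."""
    T = len(rewards)
    next_values = np.append(values[1:], last_value)
    deltas = rewards + gamma * next_values - values   # TD residuals
    adv = np.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        running = deltas[t] + gamma * lam * running
        adv[t] = running
    return adv, deltas

rewards = np.array([1.0, 1.0, 1.0])
values = np.array([2.0, 1.5, 1.0])    # hypothetical critic outputs V(s_0), V(s_1), V(s_2)
last_value, gamma = 4.0, 0.5

adv_lam1, deltas = gae_no_dones(rewards, values, last_value, gamma, lam=1.0)
adv_lam0, _ = gae_no_dones(rewards, values, last_value, gamma, lam=0.0)

# Bootstrapped n-step returns for the same segment.
n_step = np.zeros(3)
running = last_value
for t in reversed(range(3)):
    running = rewards[t] + gamma * running
    n_step[t] = running

print("lambda=1 equals n-step return minus V:", np.allclose(adv_lam1, n_step - values))
print("lambda=0 equals TD residual:", np.allclose(adv_lam0, deltas))
```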

### Loss functions (minimization form)

Actor loss (to maximize expected advantage):

$$
\mathcal{L}_{\text{actor}}(\theta) = -\mathbb{E}\left[\log \pi_\theta(a_t\mid s_t)\, \hat{A}_t\right].
$$

Critic loss (value regression):

$$
\mathcal{L}_{\text{critic}}(\phi) = \frac{1}{2}\,\mathbb{E}\left[(\hat{R}_t - V_\phi(s_t))^2\right].
$$

Entropy bonus (encourage exploration):

$$
\mathcal{L}_{\text{entropy}}(\theta) = -\mathbb{E}\left[\mathcal{H}(\pi_\theta(\cdot\mid s_t))\right].
$$

Total loss:

$$
\mathcal{L} = \mathcal{L}_{\text{actor}} + c_v\,\mathcal{L}_{\text{critic}} + c_e\,\mathcal{L}_{\text{entropy}}.
$$
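
To make the sign conventions concrete, here is a small NumPy sketch that evaluates the three terms on a hypothetical mini-batch of three transitions with two actions. All numbers are made up, and the advantages and returns are treated as constants, as they would be when computing gradients.

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

logits = np.array([[0.0, 0.0], [1.0, -1.0], [-0.5, 0.5]])  # actor outputs
actions = np.array([0, 1, 1])                 # sampled actions a_t
advantages = np.array([1.0, -0.5, 2.0])       # advantage estimates
returns = np.array([1.5, 0.0, 3.0])           # critic regression targets
values = np.array([1.0, 0.5, 2.0])            # critic predictions V(s_t)
c_v, c_e = 0.5, 0.01                          # loss weights

log_probs = log_softmax(logits)
logp_taken = log_probs[np.arange(len(actions)), actions]
entropy = -(np.exp(log_probs) * log_probs).sum(axis=-1).mean()

actor_loss = -(logp_taken * advantages).mean()
critic_loss = 0.5 * ((returns - values) ** 2).mean()
entropy_loss = -entropy                       # minimizing this maximizes entropy
total_loss = actor_loss + c_v * critic_loss + c_e * entropy_loss

print(f"actor={actor_loss:.4f} critic={critic_loss:.4f} entropy={entropy:.4f} total={total_loss:.4f}")
```

Because the entropy term enters with a negative sign, a larger `c_e` rewards more uniform action distributions.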


In [None]:
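def make_vec_env(env_id: str, n_envs: int, seed: int) -> gym.vector.SyncVectorEnv:
    env_fns = [lambda: gym.make(env_id) for _ in range(n_envs)]
    # SAME_STEP autoreset: when an episode ends, step() already returns the first
    # observation of the next episode, so no extra "reset" transition enters the rollout.
    # Note: `autoreset_mode` is only available in recent Gymnasium releases.
    env = gym.vector.SyncVectorEnv(env_fns, autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
    # Seed each sub-environment differently so the parallel rollouts are decorrelated.
    env.reset(seed=[seed + i for i in range(n_envs)])
    return env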


env = make_vec_env(ENV_ID, N_ENVS, SEED)

obs_space = env.single_observation_space
act_space = env.single_action_space

assert isinstance(act_space, gym.spaces.Discrete), "This notebook's implementation uses discrete actions (Categorical)."

OBS_DIM = int(np.prod(obs_space.shape))
N_ACTIONS = int(act_space.n)

print('obs_space', obs_space)
print('act_space', act_space)
print('OBS_DIM', OBS_DIM, 'N_ACTIONS', N_ACTIONS)


## 3) Actor-Critic network (low-level PyTorch)

We use a shared MLP trunk, then two heads:

- **actor head** outputs logits for a categorical distribution
- **critic head** outputs a scalar value $V(s)$

This is not the only architecture (you can also use separate networks), but it’s a common and effective baseline.


In [None]:
class ActorCritic(nn.Module):
    def __init__(self, obs_dim: int, n_actions: int, hidden_sizes: tuple[int, int] = (128, 128)):
        super().__init__()

        h1, h2 = hidden_sizes
        self.fc1 = nn.Linear(obs_dim, h1)
        self.fc2 = nn.Linear(h1, h2)

        self.actor = nn.Linear(h2, n_actions)
        self.critic = nn.Linear(h2, 1)

    def forward(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # obs: (B, obs_dim)
        x = torch.tanh(self.fc1(obs))
        x = torch.tanh(self.fc2(x))
        logits = self.actor(x)            # (B, n_actions)
        values = self.critic(x).squeeze(-1)  # (B,)
        return logits, values


def sample_actions_and_logp(logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level categorical sampling without torch.distributions.

    Returns:
      actions: (B,) int64
      logp:    (B,) log-prob of sampled action
      entropy: (B,) categorical entropy
    """
    log_probs = F.log_softmax(logits, dim=-1)  # (B, A)
    probs = log_probs.exp()

    actions = torch.multinomial(probs, num_samples=1).squeeze(-1)  # (B,)
    logp = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
    entropy = -(probs * log_probs).sum(dim=-1)

    return actions, logp, entropy


@torch.no_grad()
def policy_action_probs(logits: torch.Tensor) -> torch.Tensor:
    return F.softmax(logits, dim=-1)


## 4) GAE implementation

We compute advantages backwards in time:

$$
\delta_t = r_t + \gamma (1-d_t) V(s_{t+1}) - V(s_t)
$$

$$
A_t = \delta_t + \gamma\lambda(1-d_t) A_{t+1}
$$

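where $d_t\in\{0,1\}$ is the done flag, so neither the bootstrap value nor the accumulated advantage leaks across an episode boundary. The critic's regression targets are then $\hat{R}_t = A_t + V(s_t)$.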
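
The effect of the mask is easiest to see on a single environment. The sketch below is a NumPy mirror of the recursion with hypothetical values: an episode ends after step 1, and changing everything after that boundary leaves the advantages before it untouched.

```python
import numpy as np

def gae_with_dones(rewards, dones, values, last_value, gamma, lam):
    """Single-env GAE; dones[t] = 1 means the episode ended with the step taken at t."""
    T = len(rewards)
    adv = np.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        mask = 1.0 - dones[t]
        next_value = last_value if t == T - 1 else values[t + 1]
        delta = rewards[t] + gamma * mask * next_value - values[t]
        running = delta + gamma * lam * mask * running
        adv[t] = running
    return adv

rewards = np.array([1.0, 1.0, 1.0, 1.0])
dones = np.array([0.0, 1.0, 0.0, 0.0])     # episode boundary after step 1
values = np.array([2.0, 1.0, 5.0, 4.0])    # hypothetical critic outputs

adv_a = gae_with_dones(rewards, dones, values, last_value=3.0, gamma=0.9, lam=1.0)

# Perturb only what happens after the boundary.
values_b = values * np.array([1.0, 1.0, 10.0, 10.0])
adv_b = gae_with_dones(rewards, dones, values_b, last_value=30.0, gamma=0.9, lam=1.0)

print("before boundary unchanged:", np.allclose(adv_a[:2], adv_b[:2]))
print("after boundary changed:", not np.allclose(adv_a[2:], adv_b[2:]))
```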


In [None]:
def compute_gae(
    rewards: torch.Tensor,   # (T, N)
    dones: torch.Tensor,     # (T, N) float32 {0,1}
    values: torch.Tensor,    # (T, N)
    last_values: torch.Tensor,  # (N,)
    gamma: float,
    gae_lambda: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    T, N = rewards.shape
    advantages = torch.zeros((T, N), device=rewards.device, dtype=torch.float32)
    last_adv = torch.zeros((N,), device=rewards.device, dtype=torch.float32)

    for t in reversed(range(T)):
        mask = 1.0 - dones[t]
        next_values = last_values if t == T - 1 else values[t + 1]
        delta = rewards[t] + gamma * mask * next_values - values[t]
        last_adv = delta + gamma * gae_lambda * mask * last_adv
        advantages[t] = last_adv

    returns = advantages + values
    return advantages, returns


## 5) Training loop (A2C)

Key design choices in this minimal implementation:

- **Vectorized envs** (`n_envs`) to match A2C’s synchronous batching.
- Rollout buffer of shape `(n_steps, n_envs, ...)`.
- Compute **GAE + bootstrapped returns**.
- Single gradient update per rollout (no replay buffer, no off-policy corrections).

We also record:

- episodic return (score) whenever any env finishes an episode
- actor loss, critic loss, entropy, explained variance (optional diagnostic)


In [None]:
def explained_variance(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    var_y = np.var(y_true)
    if var_y < 1e-12:
        return float('nan')
    return 1.0 - float(np.var(y_true - y_pred) / var_y)


def train_a2c(
    env_id: str,
    seed: int,
    device: torch.device,
    n_envs: int,
    n_steps: int,
    total_timesteps: int,
    gamma: float,
    gae_lambda: float,
    ent_coef: float,
    vf_coef: float,
    lr: float,
    max_grad_norm: float,
    rmsprop_eps: float,
    hidden_sizes: tuple[int, int],
    normalize_advantage: bool,
    log_every_updates: int = 50,
):
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = make_vec_env(env_id, n_envs, seed)
    obs_space = env.single_observation_space
    act_space = env.single_action_space

    obs_dim = int(np.prod(obs_space.shape))
    n_actions = int(act_space.n)

    model = ActorCritic(obs_dim, n_actions, hidden_sizes=hidden_sizes).to(device)
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, eps=rmsprop_eps)

    # Rollout buffers
    obs_buf = torch.zeros((n_steps, n_envs, obs_dim), device=device, dtype=torch.float32)
    act_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.int64)
    rew_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.float32)
    done_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.float32)
    val_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.float32)

    obs, _ = env.reset(seed=[seed + i for i in range(n_envs)])

    # Episode tracking across vector envs
    ep_returns_running = np.zeros((n_envs,), dtype=np.float32)
    ep_lengths_running = np.zeros((n_envs,), dtype=np.int32)
    ep_returns: list[float] = []
    ep_lengths: list[int] = []

    updates = total_timesteps // (n_envs * n_steps)
    history_updates: list[dict] = []
    last_adv_flat = None

    t0 = time.time()
    global_step = 0
    model.train()

    for update in range(1, updates + 1):
        # --- Collect rollout ---
        for t in range(n_steps):
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
            obs_buf[t] = obs_t

            with torch.no_grad():
                logits, values = model(obs_t)
                actions, _, _ = sample_actions_and_logp(logits)

            act_buf[t] = actions
            val_buf[t] = values

            next_obs, rewards, terminated, truncated, _ = env.step(actions.cpu().numpy())
            dones = np.logical_or(terminated, truncated)

            rew_buf[t] = torch.as_tensor(rewards, dtype=torch.float32, device=device)
            done_buf[t] = torch.as_tensor(dones, dtype=torch.float32, device=device)

            # Episode bookkeeping
            ep_returns_running += rewards
            ep_lengths_running += 1
            for i in np.where(dones)[0]:
                ep_returns.append(float(ep_returns_running[i]))
                ep_lengths.append(int(ep_lengths_running[i]))
                ep_returns_running[i] = 0.0
                ep_lengths_running[i] = 0

            obs = next_obs
            global_step += n_envs

        # Bootstrap value from last observation
        with torch.no_grad():
            obs_last = torch.as_tensor(obs, dtype=torch.float32, device=device)
            _, last_values = model(obs_last)  # (N,)

        advantages, returns = compute_gae(
            rewards=rew_buf,
            dones=done_buf,
            values=val_buf,
            last_values=last_values,
            gamma=gamma,
            gae_lambda=gae_lambda,
        )

        # Flatten (T, N, ...) -> (T*N, ...)
        b_obs = obs_buf.reshape(-1, obs_dim)
        b_act = act_buf.reshape(-1)
        b_adv = advantages.reshape(-1)
        b_ret = returns.reshape(-1)

        if normalize_advantage:
            b_adv = (b_adv - b_adv.mean()) / (b_adv.std() + 1e-8)

        # --- Compute losses ---
        logits, values_pred = model(b_obs)
        log_probs = F.log_softmax(logits, dim=-1)
        probs = log_probs.exp()

        b_logp = log_probs.gather(1, b_act.unsqueeze(1)).squeeze(1)
        entropy = -(probs * log_probs).sum(dim=-1).mean()

        actor_loss = -(b_logp * b_adv.detach()).mean()
        critic_loss = 0.5 * F.mse_loss(values_pred, b_ret.detach())

        loss = actor_loss + vf_coef * critic_loss - ent_coef * entropy

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        last_adv_flat = b_adv.detach().cpu().numpy()

        # Diagnostics
        y_true = b_ret.detach().cpu().numpy()
        y_pred = values_pred.detach().cpu().numpy()

        mean_ep_return = float(np.mean(ep_returns[-RETURN_SMOOTHING_WINDOW:])) if len(ep_returns) else float('nan')

        history_updates.append(
            dict(
                update=update,
                timesteps=global_step,
                actor_loss=float(actor_loss.detach().cpu().item()),
                critic_loss=float(critic_loss.detach().cpu().item()),
                entropy=float(entropy.detach().cpu().item()),
                explained_variance=explained_variance(y_true, y_pred),
                episodes=len(ep_returns),
                mean_return_window=mean_ep_return,
            )
        )

        if update % log_every_updates == 0 or update == 1 or update == updates:
            elapsed = time.time() - t0
            print(
                f"update {update:>4d}/{updates} | steps {global_step:>7d} | episodes {len(ep_returns):>5d} | "
                f"mean_return@{RETURN_SMOOTHING_WINDOW} {mean_ep_return:>7.1f} | "
                f"loss {float(loss.detach().cpu()):>8.4f} | {elapsed:>6.1f}s"
            )

    env.close()
    hist_df = pd.DataFrame(history_updates)
    return model, hist_df, np.array(ep_returns, dtype=np.float32), np.array(ep_lengths, dtype=np.int32), last_adv_flat


In [None]:
model, hist_df, ep_returns, ep_lengths, last_adv_flat = train_a2c(
    env_id=ENV_ID,
    seed=SEED,
    device=DEVICE,
    n_envs=N_ENVS,
    n_steps=N_STEPS,
    total_timesteps=TOTAL_TIMESTEPS,
    gamma=GAMMA,
    gae_lambda=GAE_LAMBDA,
    ent_coef=ENT_COEF,
    vf_coef=VF_COEF,
    lr=LR,
    max_grad_norm=MAX_GRAD_NORM,
    rmsprop_eps=RMSPROP_EPS,
    hidden_sizes=HIDDEN_SIZES,
    normalize_advantage=NORMALIZE_ADVANTAGE,
    log_every_updates=LOG_EVERY_UPDATES,
)

hist_df.tail()


## 6) Plot: score (return) per episode

CartPole gives reward $+1$ per time step, so **episode return = episode length** (up to 500).


In [None]:
episodes = np.arange(1, len(ep_returns) + 1)

roll_mean = pd.Series(ep_returns).rolling(RETURN_SMOOTHING_WINDOW).mean().to_numpy()

fig = go.Figure()
fig.add_trace(go.Scatter(x=episodes, y=ep_returns, mode='lines', name='return', line=dict(width=1)))
fig.add_trace(go.Scatter(x=episodes, y=roll_mean, mode='lines', name=f'mean@{RETURN_SMOOTHING_WINDOW}', line=dict(width=3)))
fig.update_layout(
    title='A2C on CartPole-v1 — score (return) per episode',
    xaxis_title='episode',
    yaxis_title='return',
)
fig.show()


## 7) Plot: training diagnostics (losses, entropy, explained variance)

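- **Actor loss** is not a progress metric. It is the batch mean of $-\log \pi_\theta(a_t\mid s_t)\,\hat{A}_t$, and with `NORMALIZE_ADVANTAGE = True` the advantages are zero-mean within every batch, so its sign and magnitude mainly reflect how log-probabilities correlate with advantages in that batch.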
- **Critic loss** should generally decrease as the value function fits the returns.
- **Entropy** typically decreases as the policy becomes more confident.
- **Explained variance** (rough critic diagnostic) near 1 is good; near 0 means the critic explains little.
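
A few reference values help when reading the explained-variance panel. The sketch below repeats the same formula used in the training cell on hypothetical return targets; note that a constant offset in the predictions does not lower the score, so it should be read together with the critic loss.

```python
import numpy as np

def explained_variance(y_true, y_pred):
    var_y = np.var(y_true)
    if var_y < 1e-12:
        return float("nan")
    return 1.0 - float(np.var(y_true - y_pred) / var_y)

returns = np.array([1.0, 2.0, 3.0, 4.0])   # hypothetical return targets

ev_perfect = explained_variance(returns, returns)                       # critic matches targets
ev_constant = explained_variance(returns, np.full(4, returns.mean()))   # critic predicts the mean
ev_inverted = explained_variance(returns, -returns)                     # worse than predicting the mean
ev_offset = explained_variance(returns, returns + 10.0)                 # biased by a constant

print(ev_perfect, ev_constant, ev_inverted, ev_offset)
```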


In [None]:
fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=("Actor loss", "Critic loss", "Entropy", "Explained variance"),
)

fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['actor_loss'], name='actor_loss'), row=1, col=1)
fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['critic_loss'], name='critic_loss'), row=1, col=2)
fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['entropy'], name='entropy'), row=2, col=1)
fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['explained_variance'], name='explained_variance'), row=2, col=2)

fig.update_layout(height=700, title='A2C training diagnostics', showlegend=False)
fig.update_xaxes(title_text='timesteps')
fig.show()


## 8) Plot: advantage distribution (last update)

A2C pushes up the log-probability of actions with **positive advantage** and pushes down those with **negative advantage**.
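
The sketch below illustrates this on the simplest possible case: a two-action softmax policy whose logits are updated directly (a simplification, since the notebook updates network weights rather than logits). The helper name `a2c_logit_step` and the step size are assumptions for illustration.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def a2c_logit_step(logits, action, advantage, lr=0.1):
    """One gradient-descent step on -log pi(action) * advantage w.r.t. the logits."""
    pi = softmax(logits)
    grad_logp = -pi                 # new array: d log pi(a) / d logits = onehot(a) - pi
    grad_logp[action] += 1.0
    return logits + lr * advantage * grad_logp

logits = np.zeros(2)                # uniform policy over {left, right}
p_before = softmax(logits)[1]
p_up = softmax(a2c_logit_step(logits, action=1, advantage=+2.0))[1]
p_down = softmax(a2c_logit_step(logits, action=1, advantage=-2.0))[1]

print(f"P(right): before={p_before:.3f}, after A>0: {p_up:.3f}, after A<0: {p_down:.3f}")
```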


In [None]:
fig = px.histogram(
    x=last_adv_flat,
    nbins=60,
    title='Advantage histogram (last update)',
)
fig.update_layout(xaxis_title='advantage', yaxis_title='count')
fig.show()


## 9) Visualize the learned policy + value function (2D slice)

CartPole states are 4D:

$$
s = (x, \dot{x}, \theta, \dot{\theta}).
$$

To visualize the learned policy and value function, we take a **2D slice** over pole angle $\theta$ and pole angular velocity $\dot{\theta}$, while fixing $x=0$ and $\dot{x}=0$.

- Left plot: $\pi(a=1\mid s)$ (probability of pushing right)
- Right plot: $V(s)$ (critic estimate)


In [None]:
@torch.no_grad()
def policy_value_slice(model: nn.Module, device: torch.device, grid_n: int = 70):
    model.eval()
    angles = np.linspace(-0.21, 0.21, grid_n)  # roughly CartPole angle limits
    ang_vels = np.linspace(-3.0, 3.0, grid_n)

    theta, theta_dot = np.meshgrid(angles, ang_vels)
    states = np.zeros((grid_n * grid_n, 4), dtype=np.float32)
    states[:, 2] = theta.ravel()
    states[:, 3] = theta_dot.ravel()

    obs_t = torch.as_tensor(states, dtype=torch.float32, device=device)
    logits, values = model(obs_t)
    probs = policy_action_probs(logits)
    p_right = probs[:, 1].reshape(grid_n, grid_n).cpu().numpy()
    v = values.reshape(grid_n, grid_n).cpu().numpy()

    return angles, ang_vels, p_right, v


angles, ang_vels, p_right, v = policy_value_slice(model, DEVICE)

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Policy: P(push right)", "Critic: V(s)"),
)

fig.add_trace(
    go.Heatmap(
        x=angles,
        y=ang_vels,
        z=p_right,
        colorscale='RdBu',
        zmin=0.0,
        zmax=1.0,
        colorbar=dict(title='P(right)'),
    ),
    row=1,
    col=1,
)

fig.add_trace(
    go.Heatmap(
        x=angles,
        y=ang_vels,
        z=v,
        colorscale='Viridis',
        colorbar=dict(title='V(s)'),
    ),
    row=1,
    col=2,
)

fig.update_layout(
    height=420,
    title='Learned policy/value on a 2D state slice (x=0, xdot=0)',
)
fig.update_xaxes(title_text='pole angle θ', row=1, col=1)
fig.update_yaxes(title_text='pole angular velocity θdot', row=1, col=1)
fig.update_xaxes(title_text='pole angle θ', row=1, col=2)
fig.update_yaxes(title_text='pole angular velocity θdot', row=1, col=2)
fig.show()


## 10) Quick evaluation (deterministic actions)

We evaluate by taking the greedy action $\arg\max_a \pi(a\mid s)$.


In [None]:
@torch.no_grad()
def evaluate(model: nn.Module, env_id: str, n_episodes: int = 10, seed: int = 0):
    env = gym.make(env_id)
    device = next(model.parameters()).device  # run on whichever device the model lives on
    returns = []
    for ep in range(n_episodes):
        obs, _ = env.reset(seed=seed + ep)
        done = False
        ret = 0.0
        while not done:
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
            logits, _ = model(obs_t)
            action = int(torch.argmax(logits, dim=-1).item())

            obs, reward, terminated, truncated, _ = env.step(action)
            done = bool(terminated or truncated)
            ret += float(reward)
        returns.append(ret)
    env.close()
    return np.array(returns, dtype=np.float32)


model.eval()
eval_returns = evaluate(model, ENV_ID, n_episodes=10, seed=SEED + 1000)
print('eval returns:', eval_returns)
print('mean ± std:', float(eval_returns.mean()), '±', float(eval_returns.std()))


## 11) Pitfalls + diagnostics

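- **On-policy constraint**: A2C's gradient estimate assumes the data was collected by the *current* policy. Reusing older experience without an off-policy correction biases the update.
- **Done handling**: Bootstrapping must stop at episode boundaries. Here we treat `terminated OR truncated` as terminal for simplicity; since truncation at CartPole-v1's 500-step time limit is not a true terminal state, this drops the bootstrap value at those steps and biases the value targets low near the limit.
- **Entropy coefficient**: Too high keeps the policy close to random; too low can collapse exploration early.
- **Actor/critic balance**: The trunk is shared, so `vf_coef` controls how strongly the value-regression gradient competes with the policy gradient. A poorly fit critic yields noisy advantages, while an overly dominant value loss can swamp the policy signal; either can destabilize learning.
- **Parallel envs matter**: With too few envs, each batch contains fewer and more correlated transitions, so updates have higher variance.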
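
One possible way to handle truncation more faithfully is to keep the done mask (so the recursion still stops at the boundary) but fold the discounted value of the final pre-reset observation into the reward at truncated steps. The sketch below shows only that reward adjustment in NumPy. `final_values` is a hypothetical array holding the critic's estimate for the last observation of each truncated episode; this notebook does not compute it, and how the final observation is exposed depends on the Gymnasium version.

```python
import numpy as np

def bootstrap_truncated_rewards(rewards, truncated, final_values, gamma):
    """Add gamma * V(final_obs) to the reward wherever an episode hit a time limit.

    rewards, truncated, final_values: arrays of shape (T, N).
    final_values is only used where truncated is True.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    truncated = np.asarray(truncated, dtype=bool)
    return np.where(truncated, rewards + gamma * final_values, rewards)

rewards = np.ones((2, 2))
truncated = np.array([[False, True], [False, False]])
final_values = np.array([[0.0, 50.0], [0.0, 0.0]])   # hypothetical V(final_obs)

adjusted = bootstrap_truncated_rewards(rewards, truncated, final_values, gamma=0.99)
print(adjusted)
```

With the done mask still set at that step, the TD residual becomes $r_t + \gamma V(s_{\text{final}}) - V(s_t)$, which is the target a non-terminal transition would have had.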

Good quick checks:

- returns increase over time
- entropy decreases gradually (not instantly)
- critic loss decreases and explained variance improves


## 12) Exercises

1. Change $\lambda$ in GAE (e.g. 0.9) and compare learning curves.
2. Swap RMSprop for Adam and compare stability.
3. Implement **continuous actions** by outputting a Gaussian policy (mean + log-std) and testing on `Pendulum-v1`.
4. Add a learning-rate schedule.
5. Add observation normalization and compare speed.


## 13) Stable-Baselines3 A2C reference implementation

Stable-Baselines3 (SB3) includes an A2C implementation.

- Docs page: https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html

### Minimal usage

```python
from stable_baselines3 import A2C
import gymnasium as gym

env = gym.make("CartPole-v1")
model = A2C(
    policy="MlpPolicy",
    env=env,
    learning_rate=7e-4,
    n_steps=5,
    gamma=0.99,
    gae_lambda=1.0,
    ent_coef=0.0,
    vf_coef=0.5,
    max_grad_norm=0.5,
    rms_prop_eps=1e-5,
    use_rms_prop=True,
    normalize_advantage=False,
)
model.learn(total_timesteps=200_000)
```

### SB3 A2C hyperparameters (signature + meaning)

From the SB3 docs, the constructor signature is:

```
A2C(policy, env, learning_rate=0.0007, n_steps=5, gamma=0.99, gae_lambda=1.0,
    ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, rms_prop_eps=1e-05,
    use_rms_prop=True, use_sde=False, sde_sample_freq=-1,
    rollout_buffer_class=None, rollout_buffer_kwargs=None,
    normalize_advantage=False, stats_window_size=100, tensorboard_log=None,
    policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)
```

Parameter meanings (SB3 docs):

- `policy`: policy class (e.g. `MlpPolicy`, `CnnPolicy`)
- `env`: environment (Gym env, VecEnv, or registered env id string)
- `learning_rate`: float or schedule
- `n_steps`: rollout length per env (batch size = `n_steps * n_envs`, where `n_envs` is the number of parallel environment copies)
- `gamma`: discount factor
- `gae_lambda`: bias/variance trade-off for GAE; `1.0` equals classic advantage
- `ent_coef`: entropy coefficient
- `vf_coef`: value loss coefficient
- `max_grad_norm`: gradient clipping threshold
- `rms_prop_eps`: RMSprop epsilon
- `use_rms_prop`: use RMSprop (default) vs Adam
- `use_sde`: generalized State Dependent Exploration (gSDE)
- `sde_sample_freq`: resample gSDE noise every n steps (`-1` = only at rollout start)
- `rollout_buffer_class`: custom rollout buffer class
- `rollout_buffer_kwargs`: kwargs for rollout buffer
- `normalize_advantage`: normalize advantages
- `stats_window_size`: episodes window for logging averages
- `tensorboard_log`: tensorboard log dir
- `policy_kwargs`: kwargs for policy network/architecture
- `verbose`: verbosity level
- `seed`: random seed
- `device`: `cpu`, `cuda`, or `auto`
- `_init_setup_model`: build the network immediately
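
To connect the two implementations, the sketch below maps this notebook's configuration constants onto the SB3 keyword names from the signature above and lists where the notebook deviates from the SB3 defaults. It is plain Python and does not import SB3; the dictionary names are illustrative.

```python
# Notebook settings (names follow the run-configuration cell).
notebook_cfg = {
    "LR": 7e-4, "N_STEPS": 5, "GAMMA": 0.99, "GAE_LAMBDA": 1.0,
    "ENT_COEF": 0.01, "VF_COEF": 0.5, "MAX_GRAD_NORM": 0.5,
    "RMSPROP_EPS": 1e-5, "NORMALIZE_ADVANTAGE": True, "SEED": 42,
}

# Notebook name -> SB3 keyword (keywords taken from the signature above).
name_map = {
    "LR": "learning_rate", "N_STEPS": "n_steps", "GAMMA": "gamma",
    "GAE_LAMBDA": "gae_lambda", "ENT_COEF": "ent_coef", "VF_COEF": "vf_coef",
    "MAX_GRAD_NORM": "max_grad_norm", "RMSPROP_EPS": "rms_prop_eps",
    "NORMALIZE_ADVANTAGE": "normalize_advantage", "SEED": "seed",
}

sb3_kwargs = {name_map[k]: v for k, v in notebook_cfg.items()}
sb3_kwargs["use_rms_prop"] = True   # the notebook optimizes with RMSprop

# Defaults copied from the signature above, to highlight the differences.
sb3_defaults = {
    "learning_rate": 7e-4, "n_steps": 5, "gamma": 0.99, "gae_lambda": 1.0,
    "ent_coef": 0.0, "vf_coef": 0.5, "max_grad_norm": 0.5, "rms_prop_eps": 1e-5,
    "normalize_advantage": False, "seed": None, "use_rms_prop": True,
}
deviations = {k: v for k, v in sb3_kwargs.items() if sb3_defaults[k] != v}
print(deviations)
```

The intended use would be `A2C("MlpPolicy", env, **sb3_kwargs)`. The number of parallel environments is not a constructor argument; it is determined by the `env` that is passed in, and the network size would go through `policy_kwargs`.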


## References

- Mnih et al. (2016), *Asynchronous Methods for Deep Reinforcement Learning* (A3C)
- Schulman et al. (2016), *High-Dimensional Continuous Control Using Generalized Advantage Estimation*
- Stable-Baselines3 A2C docs: https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html
