# TRPO (Trust Region Policy Optimization) — low-level PyTorch implementation

TRPO is an on-policy policy-gradient method that makes stable, approximately monotonic policy improvements by constraining how far the policy may move at each iteration with a **KL-divergence trust region**.

In this notebook you will:
- Derive the **KL constraint** (LaTeX) and show how it leads to a **natural-gradient** step
- Implement TRPO "from scratch" with **PyTorch autograd** + **conjugate gradient** + **backtracking line search**
- Visualize **policy updates**, **KL per update**, and **episodic returns** with **Plotly**
- See a reference **Stable-Baselines TRPO** implementation and understand its hyperparameters


## Notebook roadmap

1. TRPO objective + the KL-divergence constraint (math)
2. A tiny offline-friendly continuous-control environment (no downloads)
3. Gaussian policy + value baseline (PyTorch)
4. TRPO building blocks (GAE advantages, value function fit, Fisher-vector product, conjugate gradient, line search), followed by the training loop and Plotly diagnostics (episodic returns, KL per update, policy update snapshots)
5. Stable-Baselines TRPO: usage + hyperparameters (end)


In [None]:
import os
import sys
import time

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

import torch
import torch.nn as nn
import torch.nn.functional as F

pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")

np.set_printoptions(precision=4, suppress=True)

DEVICE = torch.device("cpu")

SEED = 42
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)


In [None]:
import plotly

print("Python:", sys.version.split()[0])
print("NumPy:", np.__version__)
print("Plotly:", plotly.__version__)
print("PyTorch:", torch.__version__)
print("Device:", DEVICE)


## 1) TRPO objective and the KL-divergence constraint

TRPO is usually presented as the constrained optimization problem:

\[
\max_\theta\; \mathbb{E}_{s,a\sim \pi_{\theta_{\text{old}}}}\left[\frac{\pi_\theta(a\mid s)}{\pi_{\theta_{\text{old}}}(a\mid s)}\,\hat A_{\theta_{\text{old}}}(s,a)\right]
\qquad\text{s.t.}\qquad
\mathbb{E}_{s\sim \pi_{\theta_{\text{old}}}}\left[D_{\mathrm{KL}}\!\left(\pi_{\theta_{\text{old}}}(\cdot\mid s)\,\|\,\pi_\theta(\cdot\mid s)\right)\right] \le \delta.
\]

The trust region is defined by the **average KL divergence** between the old and the new policy, averaged over states visited by the old policy. Intuition: *"move in a direction that increases the objective, but don't move too far in policy space."*

We use the standard definition:

\[
D_{\mathrm{KL}}(p\|q) = \mathbb{E}_{x\sim p}\left[\log\frac{p(x)}{q(x)}\right].
\]
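
To make the two quantities concrete, here is a minimal NumPy sketch for a single state with three discrete actions. The probabilities and advantages are made-up numbers chosen for illustration, not values from this notebook, and the expectation over actions is computed exactly rather than sampled.

```python
import numpy as np


def kl_divergence(p, q):
    """D_KL(p || q) for two discrete distributions given as probability vectors."""
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    return float(np.sum(p * np.log(p / q)))


# Hypothetical single-state problem with three actions.
pi_old = np.array([0.5, 0.3, 0.2])
pi_new = np.array([0.6, 0.25, 0.15])
advantages = np.array([1.0, -0.5, -1.0])  # made-up advantage estimates

# Surrogate: E_{a ~ pi_old}[ (pi_new / pi_old) * A ].
ratio = pi_new / pi_old
surrogate_old = float(np.sum(pi_old * advantages))          # ratio is 1 at theta_old
surrogate_new = float(np.sum(pi_old * ratio * advantages))

kl_old_new = kl_divergence(pi_old, pi_new)  # direction used in the TRPO constraint
kl_new_old = kl_divergence(pi_new, pi_old)  # reversed direction, for comparison

delta = 0.01
print("surrogate old/new:", surrogate_old, surrogate_new)
print("KL(old||new):", kl_old_new, "KL(new||old):", kl_new_old)
print("inside trust region:", kl_old_new <= delta)
```

In this toy case the candidate policy raises the surrogate from 0.15 to 0.325, but its KL of roughly 0.021 exceeds \(\delta = 0.01\), so a trust-region method would have to take a smaller step. The two KL directions also differ slightly, which is why the argument order in the constraint matters.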


### 1.1) Why this leads to a natural-gradient step

Let \(\theta\) be the policy parameters and \(\theta_{\text{old}}\) the pre-update parameters.

TRPO uses two approximations around \(\theta_{\text{old}}\):

- **First-order** (linear) approximation of the surrogate objective:

\[
L(\theta) \approx L(\theta_{\text{old}}) + g^\top (\theta - \theta_{\text{old}})
\quad\text{where}\quad g = \nabla_\theta L(\theta)\big\rvert_{\theta=\theta_{\text{old}}}.
\]

- **Second-order** (quadratic) approximation of the KL constraint:

\[
\bar D_{\mathrm{KL}}(\theta_{\text{old}},\theta)
\approx \tfrac12 (\theta - \theta_{\text{old}})^\top H (\theta - \theta_{\text{old}}),
\]

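where \(H\) is the Hessian of the average KL at \(\theta_{\text{old}}\). The expansion has no constant or linear term because the KL and its gradient both vanish at \(\theta=\theta_{\text{old}}\). Under the usual regularity conditions this Hessian equals the policy's **Fisher information matrix**, averaged over the visited states.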
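
As a quick numerical sanity check of the Hessian/Fisher connection, the sketch below assumes a 1-D Gaussian parameterized by \((\mu, \log\sigma)\), for which the Fisher matrix is known in closed form as \(\mathrm{diag}(1/\sigma^2,\, 2)\). It estimates the Hessian of the KL by central finite differences. The helper names are illustrative and not part of the notebook code.

```python
import numpy as np


def gaussian_kl(theta_old, theta_new):
    """KL( N_old || N_new ) for 1-D Gaussians with theta = (mean, log_std)."""
    mu_o, ls_o = theta_old
    mu_n, ls_n = theta_new
    return ls_n - ls_o + (np.exp(2 * ls_o) + (mu_o - mu_n) ** 2) / (2 * np.exp(2 * ls_n)) - 0.5


def numerical_hessian(f, x, eps=1e-4):
    """Central finite-difference Hessian of a scalar function f at x."""
    n = x.size
    hess = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            e_i = np.zeros(n)
            e_j = np.zeros(n)
            e_i[i] = eps
            e_j[j] = eps
            hess[i, j] = (
                f(x + e_i + e_j) - f(x + e_i - e_j) - f(x - e_i + e_j) + f(x - e_i - e_j)
            ) / (4 * eps**2)
    return hess


sigma = 0.5
theta_old = np.array([0.3, np.log(sigma)])

# Hessian of theta -> KL(theta_old || theta), evaluated at theta = theta_old.
H_kl = numerical_hessian(lambda th: gaussian_kl(theta_old, th), theta_old)

# Closed-form Fisher information of N(mu, sigma^2) in (mu, log sigma) coordinates.
fisher = np.diag([1.0 / sigma**2, 2.0])

print(np.round(H_kl, 4))
print(fisher)
```

Up to finite-difference error the two matrices agree, which is the property that lets TRPO treat "Hessian of the KL" and "Fisher matrix" interchangeably.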

Define the step \(p = \theta - \theta_{\text{old}}\). The constrained problem becomes:

\[
\max_p\; g^\top p
\qquad\text{s.t.}\qquad
\tfrac12 p^\top H p \le \delta.
\]

The solution is:

\[
p^{\ast} = \sqrt{\frac{2\delta}{g^\top H^{-1} g}}\; H^{-1} g.
\]
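
One way to see where this expression comes from is the standard Lagrangian argument, sketched here under the assumptions that \(H\) is positive definite and \(g \neq 0\) (so the constraint is active at the optimum). The multiplier is written \(\nu\) to avoid a clash with the GAE parameter \(\lambda\):

\[
\begin{aligned}
\mathcal{L}(p,\nu) &= g^\top p - \nu\left(\tfrac12 p^\top H p - \delta\right), \\
\nabla_p \mathcal{L} = g - \nu H p = 0 \;&\Rightarrow\; p = \tfrac{1}{\nu}\, H^{-1} g, \\
\tfrac12 p^\top H p = \delta \;&\Rightarrow\; \frac{1}{2\nu^2}\, g^\top H^{-1} g = \delta
\;\Rightarrow\; \frac{1}{\nu} = \sqrt{\frac{2\delta}{g^\top H^{-1} g}}.
\end{aligned}
\]

Substituting \(1/\nu\) back into \(p = \tfrac{1}{\nu} H^{-1} g\) gives the step above: the direction is the natural gradient \(H^{-1} g\), and the scale is whatever puts the step exactly on the boundary of the quadratic trust region.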

So we need:
1. The policy-gradient \(g\)
2. The product \(H^{-1} g\) (without forming \(H\) explicitly) → **conjugate gradient** + **Hessian-vector products**
3. A step scaling + **backtracking line search** to satisfy the *true* KL constraint and improve the surrogate.
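
The three ingredients can be checked on a tiny problem where \(H\) is small enough to form explicitly. The matrix and gradient below are hypothetical stand-ins, not quantities computed from a policy; the sketch only illustrates the scaling formula and compares it with a plain-gradient step of the same quadratic-KL size.

```python
import numpy as np

# Hypothetical 2-parameter problem: H stands in for the KL Hessian, g for the policy gradient.
H = np.array([[4.0, 1.0],
              [1.0, 3.0]])  # symmetric positive definite
g = np.array([1.0, 2.0])
delta = 0.01

# Natural-gradient direction x = H^{-1} g (solved directly here; TRPO uses conjugate gradient).
x = np.linalg.solve(H, g)

# Scale the direction so the quadratic KL model equals delta.
step_scale = np.sqrt(2.0 * delta / (g @ x))
p_star = step_scale * x

# Plain-gradient step scaled onto the same trust-region boundary, for comparison.
p_plain = g * np.sqrt(2.0 * delta / (g @ H @ g))

quad_kl_star = 0.5 * p_star @ H @ p_star
quad_kl_plain = 0.5 * p_plain @ H @ p_plain

print("quadratic KL (natural, plain):", quad_kl_star, quad_kl_plain)
print("linear gain  (natural, plain):", g @ p_star, g @ p_plain)
```

Both steps sit on the boundary \(\tfrac12 p^\top H p = \delta\), but the natural-gradient step achieves the larger linearized improvement \(g^\top p\), as the constrained problem requires.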


## 2) A tiny offline-friendly continuous-control environment

To keep the notebook self-contained (no Gym downloads), we use a 1D point-mass with state \(s=(x,v)\) and action \(a\in[-1,1]\):

- Dynamics: small acceleration changes velocity, velocity changes position
- Goal: reach \(x=0\) with small velocity
- Reward: negative quadratic cost (plus a small terminal bonus when reaching the goal)

This is *not* meant to be a benchmark; it's just enough to show that TRPO learns and that the KL trust region stabilizes updates.
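
Read off from the `PointMass1DEnv.step` implementation in the next cell with its default arguments (time step \(\Delta t = 0.05\), goal at \(x = 0\), tolerance \(0.05\), bonus \(5\)) and the action clipped to \([-1, 1]\), the dynamics and reward are:

\[
\begin{aligned}
v_{t+1} &= 0.99\, v_t + a_t\, \Delta t, \\
x_{t+1} &= x_t + v_{t+1}\, \Delta t, \\
r_t &= -\left(x_{t+1}^2 + 0.1\, v_{t+1}^2 + 0.001\, a_t^2\right)
      + 5 \cdot \mathbb{1}\!\left[\,|x_{t+1}| < 0.05 \ \text{and}\ |v_{t+1}| < 0.05\,\right].
\end{aligned}
\]

The position update uses the already-updated velocity, and an episode ends when the goal condition holds or after 150 steps.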


In [None]:
class PointMass1DEnv:
    def __init__(
        self,
        dt: float = 0.05,
        max_steps: int = 150,
        x_init_range: float = 2.0,
        v_init_range: float = 0.5,
        action_max: float = 1.0,
        goal_x: float = 0.0,
        goal_tol: float = 0.05,
        goal_bonus: float = 5.0,
        seed: int | None = None,
    ):
        self.dt = float(dt)
        self.max_steps = int(max_steps)
        self.x_init_range = float(x_init_range)
        self.v_init_range = float(v_init_range)
        self.action_max = float(action_max)
        self.goal_x = float(goal_x)
        self.goal_tol = float(goal_tol)
        self.goal_bonus = float(goal_bonus)
        self.rng = np.random.default_rng(seed)

        self.steps = 0
        self.x = 0.0
        self.v = 0.0

    @property
    def obs_dim(self):
        return 2

    @property
    def act_dim(self):
        return 1

    def reset(self, seed: int | None = None):
        if seed is not None:
            self.rng = np.random.default_rng(seed)
        self.steps = 0
        self.x = self.rng.uniform(-self.x_init_range, self.x_init_range)
        self.v = self.rng.uniform(-self.v_init_range, self.v_init_range)
        return np.array([self.x, self.v], dtype=np.float32)

    def step(self, action):
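        # accept a Python scalar or a size-1 array (collect_batch passes shape (1,));
        # calling float() directly on a 1-D array is deprecated in recent NumPy
        a_raw = np.asarray(action, dtype=np.float64).reshape(-1)[0]
        a = float(np.clip(a_raw, -self.action_max, self.action_max))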

        # simple damped dynamics
        self.v = 0.99 * self.v + a * self.dt
        self.x = self.x + self.v * self.dt
        self.steps += 1

        # quadratic cost around the goal
        cost = (self.x - self.goal_x) ** 2 + 0.1 * (self.v**2) + 0.001 * (a**2)
        reward = -float(cost)

        done = False
        if abs(self.x - self.goal_x) < self.goal_tol and abs(self.v) < self.goal_tol:
            done = True
            reward += float(self.goal_bonus)
        if self.steps >= self.max_steps:
            done = True

        obs = np.array([self.x, self.v], dtype=np.float32)
        return obs, reward, done, {}


In [None]:
env = PointMass1DEnv(seed=SEED)
obs = env.reset()

xs, vs, acts, rews = [obs[0]], [obs[1]], [], []
done = False
while not done:
    a = rng.uniform(-1.0, 1.0)
    obs, r, done, _ = env.step(a)
    xs.append(obs[0])
    vs.append(obs[1])
    acts.append(a)
    rews.append(r)

fig = make_subplots(rows=3, cols=1, shared_xaxes=True)
fig.add_trace(go.Scatter(y=xs, mode="lines", name="x"), row=1, col=1)
fig.add_trace(go.Scatter(y=vs, mode="lines", name="v"), row=2, col=1)
fig.add_trace(go.Scatter(y=acts, mode="lines", name="a"), row=3, col=1)
fig.update_layout(
    title="One random rollout in the toy env",
    height=650,
    showlegend=True,
)
fig.update_yaxes(title_text="position x", row=1, col=1)
fig.update_yaxes(title_text="velocity v", row=2, col=1)
fig.update_yaxes(title_text="action a", row=3, col=1)
fig.update_xaxes(title_text="time step", row=3, col=1)
fig.show()

print("Return (sum reward):", float(np.sum(rews)))


## 3) Policy and value function (PyTorch)

We'll use:
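- A **Gaussian policy** \(\pi_\theta(a\mid s)=\mathcal{N}(\mu_\theta(s),\sigma_\theta^2)\) with diagonal covariance (here 1D); the mean \(\mu_\theta(s)\) comes from an MLP, while \(\log\sigma_\theta\) is a learned, state-independent parameter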
- A **value network** \(V_\phi(s)\) as a baseline

For TRPO we need:
- \(\log \pi_\theta(a\mid s)\) to compute the surrogate objective
- The **KL** between old and new Gaussian policies to build the trust region (and its Hessian-vector product)
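
Both quantities have simple closed forms for a 1-D Gaussian. The standard-library sketch below spells them out for a single action; the means, standard deviations, and the action value are hypothetical numbers, and the function names are illustrative rather than taken from the notebook.

```python
import math


def gaussian_log_prob(a, mean, log_std):
    """log N(a; mean, exp(log_std)^2) for a scalar action."""
    std = math.exp(log_std)
    return -0.5 * ((a - mean) / std) ** 2 - log_std - 0.5 * math.log(2.0 * math.pi)


def gaussian_kl_1d(mean_old, log_std_old, mean_new, log_std_new):
    """KL( N_old || N_new ) for scalar Gaussians."""
    var_old = math.exp(2.0 * log_std_old)
    var_new = math.exp(2.0 * log_std_new)
    return log_std_new - log_std_old + (var_old + (mean_old - mean_new) ** 2) / (2.0 * var_new) - 0.5


# Hypothetical numbers: the new policy shifts the mean from 0.0 to 0.2, std stays at 1.
a = 0.4
logp_old = gaussian_log_prob(a, mean=0.0, log_std=0.0)
logp_new = gaussian_log_prob(a, mean=0.2, log_std=0.0)

ratio = math.exp(logp_new - logp_old)   # importance weight used in the surrogate
kl = gaussian_kl_1d(0.0, 0.0, 0.2, 0.0)  # 0.5 * 0.2**2 when both stds are 1

print("ratio:", ratio)
print("KL(old||new):", kl)
```

The ratio exceeds 1 because the sampled action lies closer to the new mean than to the old one, so a positive advantage for this action would increase the surrogate.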


In [None]:
def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
    layers = []
    for i in range(len(sizes) - 1):
        act = activation if i < len(sizes) - 2 else output_activation
        layers.append(nn.Linear(sizes[i], sizes[i + 1]))
        layers.append(act())
    return nn.Sequential(*layers)


class GaussianPolicy(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        self.net = mlp([obs_dim, *hidden_sizes, act_dim], activation=nn.Tanh)
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs: torch.Tensor):
        mean = self.net(obs)
        log_std = self.log_std.expand_as(mean)
        return mean, log_std

    def dist(self, obs: torch.Tensor):
        mean, log_std = self.forward(obs)
        return torch.distributions.Normal(mean, torch.exp(log_std))

    @torch.no_grad()
    def act(self, obs: torch.Tensor):
        dist = self.dist(obs)
        action = dist.sample()
        logp = dist.log_prob(action).sum(-1)
        return action, logp


class ValueNet(nn.Module):
    def __init__(self, obs_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        self.net = mlp([obs_dim, *hidden_sizes, 1], activation=nn.Tanh)

    def forward(self, obs: torch.Tensor):
        return self.net(obs).squeeze(-1)


## 4) TRPO building blocks (low-level)

We implement:
- GAE(\(\gamma,\lambda\)) for advantages
- Value function regression
- Conjugate gradient for solving \(H x = g\)
- Fisher/Hessian-vector product via autograd on the mean KL
- Backtracking line search enforcing the KL constraint
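
To illustrate the conjugate-gradient item in isolation, here is a small NumPy sketch that solves \(A x = b\) using only matrix-vector products, which is the property TRPO relies on for \(H x = g\). The matrix `A` is a random symmetric positive-definite stand-in for the damped Fisher matrix, not something computed from a policy.

```python
import numpy as np


def conjugate_gradient_np(matvec, b, iters=10, tol=1e-10):
    """Approximately solve A x = b using only products matvec(v) = A @ v."""
    x = np.zeros_like(b)
    r = b.copy()  # residual b - A x (x starts at zero)
    p = b.copy()  # current search direction
    rdotr = r @ r
    for _ in range(iters):
        Ap = matvec(p)
        alpha = rdotr / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        new_rdotr = r @ r
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x


rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))
A = M @ M.T + 5.0 * np.eye(5)  # symmetric positive definite
b = rng.standard_normal(5)

x_cg = conjugate_gradient_np(lambda v: A @ v, b)
x_exact = np.linalg.solve(A, b)

print("max abs error:", np.max(np.abs(x_cg - x_exact)))
```

In exact arithmetic CG solves an \(n\)-dimensional system in at most \(n\) iterations; for a policy with thousands of parameters, a handful of iterations is typically used as an approximation instead.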


In [None]:
def gaussian_kl(mean_old, log_std_old, mean_new, log_std_new):
    """KL( N_old || N_new ) for diagonal Gaussians; returns shape (batch,)."""
    var_old = torch.exp(2.0 * log_std_old)
    var_new = torch.exp(2.0 * log_std_new)
    kl_per_dim = (
        log_std_new
        - log_std_old
        + (var_old + (mean_old - mean_new) ** 2) / (2.0 * var_new)
        - 0.5
    )
    return kl_per_dim.sum(dim=-1)


def flat_params(model: nn.Module):
    return torch.cat([p.data.view(-1) for p in model.parameters()])


def set_flat_params(model: nn.Module, flat: torch.Tensor):
    idx = 0
    with torch.no_grad():
        for p in model.parameters():
            n = p.numel()
            p.copy_(flat[idx : idx + n].view_as(p))
            idx += n


def flat_grad(grads, params):
    out = []
    for g, p in zip(grads, params):
        if g is None:
            out.append(torch.zeros_like(p).view(-1))
        else:
            out.append(g.contiguous().view(-1))
    return torch.cat(out)


def conjugate_gradient(fvp_fn, b, cg_iters=10, residual_tol=1e-10):
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)

    for _ in range(cg_iters):
        Avp = fvp_fn(p)
        alpha = rdotr / (torch.dot(p, Avp) + 1e-8)
        x = x + alpha * p
        r = r - alpha * Avp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        beta = new_rdotr / (rdotr + 1e-8)
        p = r + beta * p
        rdotr = new_rdotr

    return x


def trpo_update(
    policy: GaussianPolicy,
    obs: torch.Tensor,
    act: torch.Tensor,
    adv: torch.Tensor,
    logp_old: torch.Tensor,
    max_kl: float = 0.01,
    cg_iters: int = 10,
    cg_damping: float = 1e-2,
    backtrack_iters: int = 10,
    backtrack_coeff: float = 0.8,
):
    """One TRPO policy update step."""

    params = list(policy.parameters())
    old_params = flat_params(policy)

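    with torch.no_grad():
        mean_old, log_std_old = policy.forward(obs)
        # log_std_old is an expanded view of the policy.log_std parameter, and detach()
        # would still share its storage. Clone so that the in-place parameter writes
        # during the line search cannot silently change the "old" policy.
        mean_old = mean_old.clone()
        log_std_old = log_std_old.clone()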

    def surrogate():
        dist = policy.dist(obs)
        logp = dist.log_prob(act).sum(-1)
        ratio = torch.exp(logp - logp_old)
        return (ratio * adv).mean()

    def mean_kl():
        mean_new, log_std_new = policy.forward(obs)
        return gaussian_kl(mean_old, log_std_old, mean_new, log_std_new).mean()

    surr = surrogate()
    g = torch.autograd.grad(surr, params, retain_graph=True, allow_unused=True)
    g_flat = flat_grad(g, params).detach()

    def fvp(v):
        kl = mean_kl()
        grads = torch.autograd.grad(kl, params, create_graph=True, allow_unused=True)
        flat_kl_grad = flat_grad(grads, params)
        kl_v = torch.dot(flat_kl_grad, v)
        grads2 = torch.autograd.grad(kl_v, params, allow_unused=True)
        hvp = flat_grad(grads2, params).detach()
        return hvp + cg_damping * v

    step_dir = conjugate_gradient(fvp, g_flat, cg_iters=cg_iters)
    shs = torch.dot(step_dir, fvp(step_dir))
    step_size = torch.sqrt(torch.tensor(2.0 * max_kl, dtype=shs.dtype) / (shs + 1e-8))
    full_step = step_dir * step_size

    def eval_surr_and_kl():
        with torch.no_grad():
            s = surrogate().item()
            k = mean_kl().item()
        return s, k

    surr_old_val, _ = eval_surr_and_kl()

    step_frac = 1.0
    accepted = False
    surr_new_val = surr_old_val
    kl_new_val = 0.0

    for _ in range(backtrack_iters):
        new_params = old_params + step_frac * full_step
        set_flat_params(policy, new_params)

        surr_new_val, kl_new_val = eval_surr_and_kl()

        if (surr_new_val > surr_old_val) and (kl_new_val <= max_kl):
            accepted = True
            break
        step_frac *= backtrack_coeff

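    if not accepted:
        # restore the old policy and report a no-op update (surrogate unchanged, KL 0)
        set_flat_params(policy, old_params)
        surr_new_val, kl_new_val = surr_old_val, 0.0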

    return {
        "surr_old": float(surr_old_val),
        "surr_new": float(surr_new_val),
        "kl": float(kl_new_val),
        "step_frac": float(step_frac if accepted else 0.0),
        "accepted": bool(accepted),
    }
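
The accept/reject logic of the line search at the end of `trpo_update` can be seen on a scalar toy problem. In this sketch the "surrogate" and "KL" are hypothetical closed-form functions of one parameter, chosen so that the full step overshoots; only the control flow mirrors the update above.

```python
def backtracking_line_search(theta_old, full_step, surrogate, kl, max_kl,
                             max_backtracks=10, shrink=0.8):
    """Return (theta_new, step_frac); falls back to (theta_old, 0.0) if nothing is accepted."""
    surr_old = surrogate(theta_old)
    step_frac = 1.0
    for _ in range(max_backtracks):
        theta_new = theta_old + step_frac * full_step
        # Accept only if the surrogate improves AND the KL constraint holds.
        if surrogate(theta_new) > surr_old and kl(theta_new) <= max_kl:
            return theta_new, step_frac
        step_frac *= shrink
    return theta_old, 0.0


THETA_OLD = 0.0


def toy_surrogate(theta):
    return -(theta - 0.1) ** 2  # maximized at theta = 0.1


def toy_kl(theta):
    return 0.5 * (theta - THETA_OLD) ** 2  # quadratic stand-in for the mean KL


theta_new, step_frac = backtracking_line_search(
    THETA_OLD, full_step=0.3, surrogate=toy_surrogate, kl=toy_kl, max_kl=0.01
)
print("accepted theta:", theta_new, "step fraction:", step_frac)

# A step in a direction that never improves the surrogate is rejected outright.
rejected = backtracking_line_search(
    THETA_OLD, full_step=0.3, surrogate=lambda th: -th, kl=toy_kl, max_kl=0.01
)
print("rejected result:", rejected)
```

Here the first two candidates fail because the surrogate gets worse, the next two improve the surrogate but violate the KL bound, and the fifth candidate (step fraction \(0.8^4\)) is accepted.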


In [None]:
def collect_batch(env, policy, value_net, steps_per_batch, gamma=0.99, lam=0.98):
    obs_buf = np.zeros((steps_per_batch, env.obs_dim), dtype=np.float32)
    act_buf = np.zeros((steps_per_batch, env.act_dim), dtype=np.float32)
    rew_buf = np.zeros(steps_per_batch, dtype=np.float32)
    done_buf = np.zeros(steps_per_batch, dtype=np.float32)
    val_buf = np.zeros(steps_per_batch, dtype=np.float32)
    logp_buf = np.zeros(steps_per_batch, dtype=np.float32)

    ep_returns = []
    ep_ret = 0.0

    obs = env.reset()

    for t in range(steps_per_batch):
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        with torch.no_grad():
            a_t, logp_t = policy.act(obs_t)
            v_t = value_net(obs_t)

        a = a_t.squeeze(0).cpu().numpy()
        logp = float(logp_t.item())
        v = float(v_t.item())

        next_obs, r, done, _ = env.step(a)

        obs_buf[t] = obs
        act_buf[t] = a
        rew_buf[t] = r
        done_buf[t] = float(done)
        val_buf[t] = v
        logp_buf[t] = logp

        ep_ret += float(r)

        obs = next_obs
        if done:
            ep_returns.append(ep_ret)
            ep_ret = 0.0
            obs = env.reset()

    # bootstrap value for the last state (if last transition wasn't terminal)
    with torch.no_grad():
        last_val = value_net(
            torch.as_tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        ).item()

    adv_buf = np.zeros(steps_per_batch, dtype=np.float32)
    last_gae = 0.0

    for t in reversed(range(steps_per_batch)):
        # done_buf[t] marks whether the transition taken at step t ended the episode
        next_nonterminal = 1.0 - done_buf[t]
        next_value = last_val if t == steps_per_batch - 1 else val_buf[t + 1]

        delta = rew_buf[t] + gamma * next_value * next_nonterminal - val_buf[t]
        last_gae = delta + gamma * lam * next_nonterminal * last_gae
        adv_buf[t] = last_gae

    ret_buf = adv_buf + val_buf

    # normalize advantages (very common and usually helpful)
    adv_buf = (adv_buf - adv_buf.mean()) / (adv_buf.std() + 1e-8)

    batch = {
        "obs": torch.as_tensor(obs_buf, dtype=torch.float32, device=DEVICE),
        "act": torch.as_tensor(act_buf, dtype=torch.float32, device=DEVICE),
        "logp_old": torch.as_tensor(logp_buf, dtype=torch.float32, device=DEVICE),
        "adv": torch.as_tensor(adv_buf, dtype=torch.float32, device=DEVICE),
        "ret": torch.as_tensor(ret_buf, dtype=torch.float32, device=DEVICE),
        "ep_returns": ep_returns,
    }
    return batch
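
The GAE recursion inside `collect_batch` is easier to verify on a three-step episode worked by hand. The sketch below re-implements just that recursion in NumPy with made-up rewards and value estimates; the function name is illustrative.

```python
import numpy as np


def gae_advantages(rewards, values, dones, last_value, gamma=0.99, lam=0.98):
    """GAE(gamma, lambda) advantages for one batch of consecutive transitions."""
    T = len(rewards)
    adv = np.zeros(T)
    last_gae = 0.0
    for t in reversed(range(T)):
        nonterminal = 1.0 - dones[t]
        next_value = last_value if t == T - 1 else values[t + 1]
        delta = rewards[t] + gamma * next_value * nonterminal - values[t]
        last_gae = delta + gamma * lam * nonterminal * last_gae
        adv[t] = last_gae
    return adv


rewards = np.array([1.0, 1.0, 1.0])
values = np.array([0.5, 0.5, 0.5])
dones = np.array([0.0, 0.0, 1.0])  # the episode ends on the last step

# lam = 1: advantage = discounted Monte Carlo return minus the value estimate.
adv_mc = gae_advantages(rewards, values, dones, last_value=0.0, gamma=0.9, lam=1.0)
# lam = 0: advantage = one-step TD error.
adv_td = gae_advantages(rewards, values, dones, last_value=0.0, gamma=0.9, lam=0.0)

print("lambda=1:", adv_mc)  # [2.21, 1.4, 0.5]
print("lambda=0:", adv_td)  # [0.95, 0.95, 0.5]
```

With \(\gamma = 0.9\) the discounted returns are \(2.71, 1.9, 1.0\), so subtracting the constant value estimate of 0.5 gives the \(\lambda = 1\) row; intermediate \(\lambda\) values interpolate between the two rows.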


In [None]:
# --- Run configuration ---
FAST_RUN = True  # set False for a longer run

TOTAL_ITERS = 25 if FAST_RUN else 150
STEPS_PER_BATCH = 1024 if FAST_RUN else 4096

GAMMA = 0.99
LAMBDA = 0.98

MAX_KL = 0.01
CG_ITERS = 10
CG_DAMPING = 1e-2
BACKTRACK_ITERS = 10
BACKTRACK_COEFF = 0.8

VF_LR = 3e-4
VF_ITERS = 10 if FAST_RUN else 80
VF_BATCH = 128

SNAPSHOT_EVERY = 5

env = PointMass1DEnv(seed=SEED)
policy = GaussianPolicy(env.obs_dim, env.act_dim, hidden_sizes=(64, 64)).to(DEVICE)
value_net = ValueNet(env.obs_dim, hidden_sizes=(64, 64)).to(DEVICE)

vf_optim = torch.optim.Adam(value_net.parameters(), lr=VF_LR)

x_grid = np.linspace(-env.x_init_range, env.x_init_range, 101, dtype=np.float32)


In [None]:
history = {
    "iter": [],
    "ep_ret_mean": [],
    "ep_ret_p10": [],
    "ep_ret_p90": [],
    "kl": [],
    "surr_old": [],
    "surr_new": [],
    "step_frac": [],
    "policy_std": [],
}

policy_snapshots = []

t0 = time.time()

for it in range(TOTAL_ITERS):
    batch = collect_batch(
        env,
        policy,
        value_net,
        steps_per_batch=STEPS_PER_BATCH,
        gamma=GAMMA,
        lam=LAMBDA,
    )

    # --- Fit value function ---
    for _ in range(VF_ITERS):
        n = batch["obs"].shape[0]
        bs = min(VF_BATCH, n)
        idx = torch.as_tensor(rng.choice(n, size=bs, replace=False), device=DEVICE)
        v_pred = value_net(batch["obs"][idx])
        v_loss = F.mse_loss(v_pred, batch["ret"][idx])
        vf_optim.zero_grad()
        v_loss.backward()
        vf_optim.step()

    # --- TRPO policy update ---
    stats = trpo_update(
        policy,
        obs=batch["obs"],
        act=batch["act"],
        adv=batch["adv"],
        logp_old=batch["logp_old"],
        max_kl=MAX_KL,
        cg_iters=CG_ITERS,
        cg_damping=CG_DAMPING,
        backtrack_iters=BACKTRACK_ITERS,
        backtrack_coeff=BACKTRACK_COEFF,
    )

    # --- Metrics ---
    ep_returns = batch["ep_returns"]
    if len(ep_returns) > 0:
        ep_mean = float(np.mean(ep_returns))
        ep_p10 = float(np.percentile(ep_returns, 10))
        ep_p90 = float(np.percentile(ep_returns, 90))
    else:
        ep_mean, ep_p10, ep_p90 = float("nan"), float("nan"), float("nan")

    with torch.no_grad():
        policy_std = float(torch.exp(policy.log_std).mean().item())

    history["iter"].append(it)
    history["ep_ret_mean"].append(ep_mean)
    history["ep_ret_p10"].append(ep_p10)
    history["ep_ret_p90"].append(ep_p90)
    history["kl"].append(stats["kl"])
    history["surr_old"].append(stats["surr_old"])
    history["surr_new"].append(stats["surr_new"])
    history["step_frac"].append(stats["step_frac"])
    history["policy_std"].append(policy_std)

    # snapshot policy mean(action|x,v=0) over a grid
    if (it == 0) or (it % SNAPSHOT_EVERY == 0) or (it == TOTAL_ITERS - 1):
        obs_grid = np.stack([x_grid, np.zeros_like(x_grid)], axis=1)
        with torch.no_grad():
            mu, _ = policy.forward(torch.as_tensor(obs_grid, dtype=torch.float32, device=DEVICE))
        policy_snapshots.append({"iter": it, "mu": mu.squeeze(-1).cpu().numpy()})

    if (it + 1) % max(1, TOTAL_ITERS // 5) == 0 or it == 0:
        print(
            f"iter {it:03d} | ep_ret_mean {ep_mean:8.2f} | KL {stats['kl']:.4f} | "
            f"step_frac {stats['step_frac']:.3f} | std {policy_std:.3f}"
        )

print(f"Done in {time.time() - t0:.2f}s")


In [None]:
# Plotly: learning curves and trust-region diagnostics

iters = history["iter"]

fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    subplot_titles=(
        "Episodic return (mean + 10/90 percentile band)",
        "Mean KL(old || new) per update (should be ≤ max_kl)",
        "Policy std (exp(log_std))",
    ),
)

# return band
fig.add_trace(
    go.Scatter(x=iters, y=history["ep_ret_p90"], mode="lines", line=dict(width=0), showlegend=False),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=iters,
        y=history["ep_ret_p10"],
        mode="lines",
        fill="tonexty",
        line=dict(width=0),
        name="p10–p90",
        opacity=0.25,
    ),
    row=1,
    col=1,
)

fig.add_trace(
    go.Scatter(x=iters, y=history["ep_ret_mean"], mode="lines+markers", name="mean"),
    row=1,
    col=1,
)

# KL curve
fig.add_trace(
    go.Scatter(x=iters, y=history["kl"], mode="lines+markers", name="KL"),
    row=2,
    col=1,
)
fig.add_hline(y=MAX_KL, line_dash="dash", line_color="black", row=2, col=1)

# policy std
fig.add_trace(
    go.Scatter(x=iters, y=history["policy_std"], mode="lines+markers", name="std"),
    row=3,
    col=1,
)

fig.update_layout(height=850, title="TRPO learning diagnostics")
fig.update_xaxes(title_text="iteration", row=3, col=1)
fig.update_yaxes(title_text="return", row=1, col=1)
fig.update_yaxes(title_text="KL", row=2, col=1)
fig.update_yaxes(title_text="std", row=3, col=1)
fig.show()


In [None]:
# Plotly: how the policy mean changes over iterations

fig = go.Figure()
for snap in policy_snapshots:
    fig.add_trace(
        go.Scatter(
            x=x_grid,
            y=snap["mu"],
            mode="lines",
            name=f"iter {snap['iter']}",
        )
    )

fig.update_layout(
    title="Policy mean action μ(x, v=0) snapshots",
    xaxis_title="position x (with v fixed at 0)",
    yaxis_title="mean action μ",
    height=450,
)
fig.show()


## 5) Stable-Baselines TRPO (reference implementation)

The original `stable-baselines` project (TensorFlow 1.x) includes TRPO as `stable_baselines.trpo_mpi.TRPO`; it is also re-exported as `stable_baselines.TRPO` when `mpi4py` is installed.

Example usage (not executed here):

```python
import gym

# Requires the original stable-baselines (TensorFlow) + mpi4py.
from stable_baselines import TRPO
from stable_baselines.common.policies import MlpPolicy

env = gym.make("CartPole-v1")
model = TRPO(
    MlpPolicy,
    env,
    gamma=0.99,
    timesteps_per_batch=1024,
    max_kl=0.01,
    cg_iters=10,
    lam=0.98,
    entcoeff=0.0,
    cg_damping=1e-2,
    vf_stepsize=3e-4,
    vf_iters=3,
    verbose=1,
)
model.learn(total_timesteps=200_000)
```

Source used to verify signature and defaults:
- https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/trpo_mpi/trpo_mpi.py


### Stable-Baselines TRPO hyperparameters (what they mean)

From the upstream `TRPO.__init__` signature:

- `gamma` — discount factor \(\gamma\)
- `timesteps_per_batch` — on-policy batch size (number of environment steps collected before each TRPO update)
- `max_kl` — trust-region radius \(\delta\): target/upper bound on mean KL(old \|\| new)
- `cg_iters` — number of conjugate-gradient iterations used to approximately solve \(H x = g\)
- `lam` — GAE parameter \(\lambda\) controlling bias/variance tradeoff in advantages
- `entcoeff` — entropy bonus coefficient (encourages exploration by penalizing low entropy)
- `cg_damping` — adds a small multiple of the identity to the Fisher/Hessian-vector product for numerical stability
- `vf_stepsize` — learning rate for the value function optimizer
- `vf_iters` — number of value-function optimization iterations per update
- `tensorboard_log` / `full_tensorboard_log` — logging configuration
- `policy_kwargs` — extra arguments passed to the policy network constructor
- `seed` — RNG seed
- `n_cpu_tf_sess` — TensorFlow session CPU threading configuration
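
For orientation, the constants from this notebook's run-configuration cell line up with these keyword names as follows. This is only a plain-Python mapping sketch (no library call is made); the values are copied from the `FAST_RUN = True` configuration above, and the dictionary names are hypothetical.

```python
# Values copied from the run-configuration cell of this notebook (FAST_RUN = True).
notebook_config = {
    "GAMMA": 0.99,
    "LAMBDA": 0.98,
    "STEPS_PER_BATCH": 1024,
    "MAX_KL": 0.01,
    "CG_ITERS": 10,
    "CG_DAMPING": 1e-2,
    "VF_LR": 3e-4,
    "VF_ITERS": 10,
}

# Notebook constant -> keyword name from the TRPO signature listed above.
name_map = {
    "GAMMA": "gamma",
    "LAMBDA": "lam",
    "STEPS_PER_BATCH": "timesteps_per_batch",
    "MAX_KL": "max_kl",
    "CG_ITERS": "cg_iters",
    "CG_DAMPING": "cg_damping",
    "VF_LR": "vf_stepsize",
    "VF_ITERS": "vf_iters",
}

sb_kwargs = {sb_name: notebook_config[nb_name] for nb_name, sb_name in name_map.items()}
for key, value in sb_kwargs.items():
    print(f"{key} = {value}")
```

The notebook's `BACKTRACK_ITERS` and `BACKTRACK_COEFF` have no counterpart in the parameter list above, and the notebook's `VF_ITERS` of 10 differs from the `vf_iters=3` used in the example call.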

Reasonable starting points when tuning TRPO:
- `max_kl` around `0.01`; raise it for faster learning, lower it for more stable updates
- a larger `timesteps_per_batch` for smoother, lower-variance updates (at higher compute cost)
- a slightly larger `cg_damping` if updates become numerically unstable
