# ACKTR from scratch (low-level PyTorch) — CartPole-v1

This notebook implements **ACKTR-style** policy optimization in **low-level PyTorch**:

- Actor update uses a **K-FAC preconditioner** (approx. Fisher inverse) + **trust-region clipping**.
- Critic (value-function baseline) is trained with Adam, a first-order optimizer, for simplicity and stability; the original ACKTR also applies K-FAC to the critic.

We log training dynamics and visualize them with **Plotly**, including **episodic reward progression**.

Prereqs:

- PyTorch
- Gymnasium
- Plotly

Theory reference: see `00_overview.ipynb` in this folder.


## Notebook roadmap

1. Setup + environment
2. Actor/Critic networks
3. Rollout collection + GAE
4. K-FAC optimizer (Linear layers)
5. Training loop (ACKTR update)
6. Plotly diagnostics (reward + KL + losses)
7. Stable-Baselines ACKTR reference + hyperparameters


In [None]:
import os
import random
import time

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

import gymnasium as gym

import torch
import torch.nn as nn
from torch.distributions import Categorical

pio.templates.default = 'plotly_white'
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

print('NumPy', np.__version__)
print('Pandas', pd.__version__)
print('Plotly', plotly.__version__)
print('Gymnasium', gym.__version__)
print('Torch', torch.__version__)


In [None]:
# --- Reproducibility ---
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Keep the implementation CPU-friendly and deterministic-ish.
DEVICE = torch.device('cpu')
print('DEVICE', DEVICE)


In [None]:
# --- Run configuration ---
FAST_RUN = True  # set False for a longer, smoother curve

# Environment
ENV_ID = 'CartPole-v1'

# Rollout / training
TOTAL_TIMESTEPS = 40_000 if FAST_RUN else 200_000
ROLLOUT_STEPS = 128

# Discounting / advantage
GAMMA = 0.99
GAE_LAMBDA = 0.95

# Loss weights
ENT_COEF = 0.00

# Critic optimizer
CRITIC_LR = 1e-3

# K-FAC / ACKTR knobs (actor)
ACTOR_LR = 0.10
KFAC_DAMPING = 0.03
KFAC_STATS_DECAY = 0.95
KFAC_CLIP = 0.01  # trust region / KL clip (see theory)
INVERSE_UPDATE_INTERVAL = 1

print('TOTAL_TIMESTEPS', TOTAL_TIMESTEPS)


## 1) Environment

`CartPole-v1` is a classic discrete-action benchmark:

- state $s \in \mathbb{R}^4$
- actions $a \in \{0,1\}$
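- reward $r_t = +1$ for every step, including the final one; an episode terminates when the pole tips over or the cart leaves the track, and is truncated after 500 steps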

It’s a good fit for a minimal ACKTR demonstration because the policy is a **categorical distribution**.


In [None]:
env = gym.make(ENV_ID)
obs_dim = int(env.observation_space.shape[0])
act_dim = int(env.action_space.n)

obs, _ = env.reset(seed=SEED)
print('obs_dim', obs_dim, 'act_dim', act_dim)
print('first obs', obs)


## 2) Actor–critic parameterization

We use two networks:

- Actor: logits for a categorical policy $\pi_\theta(a\mid s)$.
- Critic: a value function baseline $V_\phi(s)$.

The actor loss (policy gradient with an entropy bonus weighted by $\beta$, `ENT_COEF` in the code) is:

$$
\mathcal{L}_{\pi}(\theta)
    = -\mathbb{E}\left[\log \pi_\theta(a\mid s)\,\hat A(s,a)\right]
    - \beta\,\mathbb{E}\left[\mathcal{H}(\pi_\theta(\cdot\mid s))\right].
$$

The critic trains by regression to (bootstrapped) returns:

$$
\mathcal{L}_V(\phi) = \tfrac{1}{2}\,\mathbb{E}\left[(V_\phi(s) - \hat R)^2\right].
$$
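As a framework-free sanity check of the two losses, here is a NumPy sketch for a categorical policy given raw logits. The helper names (`log_softmax`, `actor_loss_np`, `critic_loss_np`) are illustrative and not part of the notebook, which computes the same quantities with `torch.distributions.Categorical`.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row max for numerical stability before exponentiating.
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def actor_loss_np(logits, actions, advantages, beta):
    """L_pi = -mean(log pi(a|s) * A) - beta * mean(H(pi(.|s))); also returns the mean entropy."""
    logp_all = log_softmax(logits)                     # (N, num_actions)
    probs = np.exp(logp_all)
    logp = logp_all[np.arange(len(actions)), actions]  # log pi(a_t | s_t)
    entropy = -(probs * logp_all).sum(axis=-1)         # entropy of pi(.|s_t) per state
    loss = -(logp * advantages).mean() - beta * entropy.mean()
    return loss, entropy.mean()

def critic_loss_np(values, returns):
    """L_V = 0.5 * mean((V(s) - R)^2)."""
    return 0.5 * np.mean((values - returns) ** 2)

logits = np.zeros((4, 2))                  # uniform policy over 2 actions
actions = np.array([0, 1, 1, 0])
advantages = np.array([1.0, -1.0, 0.5, 0.0])
loss, ent = actor_loss_np(logits, actions, advantages, beta=0.01)
print(loss, ent)
```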


In [None]:
class Actor(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        layers = []
        last = obs_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(last, h))
            layers.append(nn.Tanh())
            last = h
        self.net = nn.Sequential(*layers)
        self.logits = nn.Linear(last, act_dim)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.logits(self.net(obs))


class Critic(nn.Module):
    def __init__(self, obs_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        layers = []
        last = obs_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(last, h))
            layers.append(nn.Tanh())
            last = h
        self.net = nn.Sequential(*layers)
        self.v = nn.Linear(last, 1)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.v(self.net(obs)).squeeze(-1)


actor = Actor(obs_dim, act_dim).to(DEVICE)
critic = Critic(obs_dim).to(DEVICE)

critic_optim = torch.optim.Adam(critic.parameters(), lr=CRITIC_LR)

print('actor params', sum(p.numel() for p in actor.parameters()))
print('critic params', sum(p.numel() for p in critic.parameters()))


## 3) Rollouts + GAE

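We collect on-policy rollouts of length $T$ and compute **generalized advantage estimation** (GAE):

$$
\delta_t = r_t + \gamma (1-d_t) V(s_{t+1}) - V(s_t)
$$

$$
\hat A_t = \sum_{l=0}^{\infty} (\gamma\lambda)^l\,\delta_{t+l}
$$

with $d_t \in \{0,1\}$ indicating that the episode ended at step $t$ (the code below sets $d_t = 1$ on both termination and time-limit truncation). In a finite rollout the sum is cut at episode boundaries and at the end of the segment, where the last observation's value is used for bootstrapping. Starting from $\hat A_T = 0$, this gives the backward recursion that `compute_gae` implements, with returns defined as advantage plus baseline:

$$
\hat A_t = \delta_t + \gamma\lambda\,(1-d_t)\,\hat A_{t+1},\qquad \hat R_t = \hat A_t + V(s_t).
$$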
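The following NumPy sketch checks, on synthetic numbers, that a backward recursion (the form used by `compute_gae` in the next cell) reproduces the truncated sum when no episode ends inside the segment. The helper names `gae_recursive` and `gae_explicit` are illustrative only. The same helpers also show the two limits: $\lambda = 0$ gives the one-step TD error, and $\lambda = 1$ gives the bootstrapped Monte Carlo return minus $V(s_t)$.

```python
import numpy as np

def gae_recursive(rewards, values, dones, last_value, gamma, lam):
    """Backward recursion: A_t = delta_t + gamma * lam * (1 - d_t) * A_{t+1}."""
    T = len(rewards)
    adv = np.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        delta = rewards[t] + gamma * (1.0 - dones[t]) * next_value - values[t]
        running = delta + gamma * lam * (1.0 - dones[t]) * running
        adv[t] = running
    return adv

def gae_explicit(rewards, values, last_value, gamma, lam):
    """Truncated sum A_t = sum_{l=0}^{T-1-t} (gamma*lam)^l delta_{t+l}; assumes no episode ends."""
    T = len(rewards)
    deltas = rewards + gamma * np.append(values[1:], last_value) - values
    return np.array([sum((gamma * lam) ** l * deltas[t + l] for l in range(T - t)) for t in range(T)])

rng = np.random.default_rng(0)
r, v = rng.normal(size=5), rng.normal(size=5)   # synthetic rewards and value estimates
no_dones, gamma, last_v = np.zeros(5), 0.99, 0.3

a_rec = gae_recursive(r, v, no_dones, last_v, gamma, lam=0.95)
a_sum = gae_explicit(r, v, last_v, gamma, lam=0.95)
print(np.allclose(a_rec, a_sum))
```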


In [None]:
def compute_gae(rewards, values, dones, last_value, gamma: float, lam: float):
    """NumPy GAE for a single rollout segment."""
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    gae = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        gae = delta + gamma * lam * next_nonterminal * gae
        advantages[t] = gae
    returns = advantages + values
    return advantages, returns
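`compute_gae` above treats every `dones[t] == 1` as a hard stop, and the training loop sets `done = terminated or truncated`, so hitting CartPole's 500-step time limit is treated as if the pole fell. A truncation-aware variant needs the value of the final observation *before* `env.reset()`. The sketch below assumes a hypothetical `next_values` array holding $V(s_{t+1})$ for the true next state, which the training loop would have to record; it is an illustration, not part of the notebook.

```python
import numpy as np

def compute_gae_truncation_aware(rewards, values, next_values, terminated, truncated, gamma, lam):
    """Hypothetical variant of compute_gae that separates termination from truncation.

    next_values[t] is assumed to be V(s_{t+1}) evaluated on the final observation of
    step t (before any reset), so a time-limit truncation can still bootstrap.
    """
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    gae = 0.0
    for t in reversed(range(T)):
        # Only a true terminal state has zero future value.
        not_terminal = 1.0 - terminated[t]
        delta = rewards[t] + gamma * not_terminal * next_values[t] - values[t]
        # Any episode boundary stops the advantage from leaking into the next episode.
        not_boundary = 1.0 - max(terminated[t], truncated[t])
        gae = delta + gamma * lam * not_boundary * gae
        advantages[t] = gae
    returns = advantages + values
    return advantages, returns

# Three steps, truncated (not terminated) at the last one: the last step still bootstraps.
adv, ret = compute_gae_truncation_aware(
    rewards=np.ones(3), values=np.ones(3), next_values=np.array([1.0, 1.0, 2.0]),
    terminated=np.zeros(3), truncated=np.array([0.0, 0.0, 1.0]), gamma=0.5, lam=1.0,
)
print(adv, ret)
```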


## 4) K-FAC optimizer (Linear layers)

ACKTR replaces a vanilla gradient step with a **(preconditioned) natural gradient** step.

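For the policy parameters $\theta$, the natural gradient direction is:

$$
\Delta\theta = F^{-1} g,\qquad g = \nabla_\theta J(\theta),
$$

where $F = \mathbb{E}_{s,\,a\sim\pi_\theta}\!\left[\nabla_\theta \log \pi_\theta(a\mid s)\,\nabla_\theta \log \pi_\theta(a\mid s)^\top\right]$ is the Fisher information matrix of the policy. $F$ is also the Hessian of $\mathbb{E}_s\!\left[\mathrm{KL}\left(\pi_\theta(\cdot\mid s) \,\|\, \pi_{\theta+\Delta\theta}(\cdot\mid s)\right)\right]$ at $\Delta\theta = 0$, so a small step $\Delta\theta$ changes the policy by roughly $\tfrac12\,\Delta\theta^\top F\,\Delta\theta$ in KL.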
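To see why the Fisher matrix measures policy change, take the simplest case where the parameters are the logits $z$ of a single categorical distribution $p = \mathrm{softmax}(z)$. Then $\nabla_z \log\pi(a) = e_a - p$ and $F = \operatorname{diag}(p) - pp^\top$. The NumPy sketch below (illustrative logits and step) compares the exact KL after a small logit step with the quadratic model $\tfrac12\Delta^\top F\Delta$. It also shows that this $F$ is singular: adding the same constant to every logit leaves the policy unchanged. For the logits layer, the $G$ factor is an average of such matrices and inherits this null direction, which is one reason the factors are damped before inversion.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())                 # shift for numerical stability
    return e / e.sum()

def kl_categorical(p, q):
    """KL(p || q) for two categorical distributions with full support."""
    return float(np.sum(p * (np.log(p) - np.log(q))))

def fisher_logits(z):
    """Fisher of softmax(z) w.r.t. z: E_a[(e_a - p)(e_a - p)^T] = diag(p) - p p^T."""
    p = softmax(z)
    return np.diag(p) - np.outer(p, p)

z = np.array([0.2, -0.5, 1.0])              # illustrative logits
step = 1e-2 * np.array([1.0, -2.0, 0.5])    # small parameter change Delta
exact_kl = kl_categorical(softmax(z), softmax(z + step))
quadratic_kl = 0.5 * step @ fisher_logits(z) @ step
print(exact_kl, quadratic_kl)
```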

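K-FAC approximates $F$ block-wise per layer using Kronecker factors:

$$
F_{\ell} \approx G_{\ell} \otimes A_{\ell},
\quad A_{\ell}=\mathbb{E}[a a^\top],\quad G_{\ell}=\mathbb{E}[g g^\top],
$$

where, for layer $\ell$, $a$ is the layer input (an activation, not an action, augmented with a constant $1$ to absorb the bias) and $g$ is the per-sample gradient of the sampled action's log-likelihood with respect to the layer's pre-activation output.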
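Written out, the exact block is $\mathbb{E}[(gg^\top)\otimes(aa^\top)]$, because the per-sample weight gradient is $g a^\top$ and its row-major flattening is $g \otimes a$. K-FAC replaces it with $\mathbb{E}[gg^\top]\otimes\mathbb{E}[aa^\top]$. The NumPy sketch below builds both from synthetic per-sample inputs and output gradients of one linear layer (random data, not taken from the notebook), folding the bias in with a constant-1 input as `KFACOptimizer` does. The helper name `kfac_vs_exact` is illustrative. The two matrices agree exactly for a single sample and differ in general for a batch.

```python
import numpy as np

def kfac_vs_exact(a, g):
    """Return (exact empirical Fisher block, K-FAC approximation) for one linear layer.

    a: (N, d_in) layer inputs; g: (N, d_out) per-sample gradients w.r.t. the layer output.
    """
    n = a.shape[0]
    a_aug = np.hstack([a, np.ones((n, 1))])      # fold the bias in via a constant-1 input
    A = a_aug.T @ a_aug / n                      # E[a a^T], shape (d_in+1, d_in+1)
    G = g.T @ g / n                              # E[g g^T], shape (d_out, d_out)
    # Per-sample gradient of [W | b] is g a^T; row-major flattening turns it into kron(g, a).
    per_sample = np.einsum('ni,nj->nij', g, a_aug).reshape(n, -1)
    F_exact = per_sample.T @ per_sample / n      # E[(g g^T) kron (a a^T)]
    F_kfac = np.kron(G, A)                       # E[g g^T] kron E[a a^T]
    return F_exact, F_kfac

rng = np.random.default_rng(0)
a = rng.normal(size=(256, 3))
g = rng.normal(size=(256, 2))
F_exact, F_kfac = kfac_vs_exact(a, g)
rel_err = np.linalg.norm(F_exact - F_kfac) / np.linalg.norm(F_exact)
print(F_exact.shape, round(float(rel_err), 3))
```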

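For a linear layer with weights $W_{\ell}$ (bias appended as an extra column), this yields the matrix-form update. The code adds damping $\epsilon I$ (`KFAC_DAMPING`) to each factor before inverting:

$$
\Delta W_{\ell} \approx (G_{\ell}+\epsilon I)^{-1}\,\nabla_{W_{\ell}}\mathcal{L}\,(A_{\ell}+\epsilon I)^{-1}.
$$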
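The matrix form avoids ever building the Kronecker product: with row-major vectorization and symmetric $A$, $(G\otimes A)^{-1}\operatorname{vec}(V) = \operatorname{vec}(G^{-1} V A^{-1})$. A quick NumPy check on random symmetric positive definite factors (illustrative sizes; `random_spd` is a hypothetical helper):

```python
import numpy as np

rng = np.random.default_rng(1)
d_out, d_in = 2, 4                       # d_in counts the appended bias column

def random_spd(n, damping=0.03):
    """Random symmetric positive definite matrix, damped like a K-FAC factor."""
    m = rng.normal(size=(n, n))
    return m @ m.T / n + damping * np.eye(n)

A, G = random_spd(d_in), random_spd(d_out)
V = rng.normal(size=(d_out, d_in))       # stand-in for the gradient w.r.t. [W | b]

# Matrix form: two small inverses of size d_out x d_out and d_in x d_in.
nat_matrix = np.linalg.inv(G) @ V @ np.linalg.inv(A)

# Vector form: solve with the full (d_out*d_in) x (d_out*d_in) Kronecker product,
# using row-major flattening (NumPy's default) for vec(V).
nat_vec = np.linalg.solve(np.kron(G, A), V.reshape(-1))
print(np.allclose(nat_matrix.reshape(-1), nat_vec))
```

For a layer with a 64-dimensional input and 64 outputs, the matrix form inverts two matrices of size $65\times65$ and $64\times64$ instead of one of size $4160\times4160$.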

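We also apply a trust-region-style scaling so the policy does not change too much: with step size $\eta$ (`ACTOR_LR`), the unscaled step is predicted to change the policy by $\tfrac12\,\eta^2\, g^\top F^{-1} g$ in KL, and the step is multiplied by $\nu = \min\!\left(1,\ \sqrt{\delta_{\mathrm{KL}} / \left(\tfrac12\,\eta^2\, g^\top F^{-1} g\right)}\right)$, where $\delta_{\mathrm{KL}}$ is `KFAC_CLIP`. Since $\nu \le 1$, the step is only ever scaled down.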
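A sketch of the scaling rule, assuming the quadratic KL model: an unscaled step $\eta F^{-1}g$ is predicted to change the policy by $\tfrac12\eta^2 g^\top F^{-1} g$, and $\nu$ shrinks the step just enough to keep that prediction at or below the clip. The function name and signature are hypothetical; inside `KFACOptimizer.step` the quantity $g^\top F^{-1} g$ is accumulated layer by layer as `shs`.

```python
import numpy as np

def trust_region_scale(grad, nat_grad, lr, kl_clip):
    """Hypothetical helper: return nu in (0, 1] so the predicted KL of lr * nu * nat_grad
    stays at or below kl_clip.

    grad, nat_grad: flattened g and F^{-1} g for all preconditioned parameters.
    """
    shs = float(np.dot(grad, nat_grad))       # g^T F^{-1} g, >= 0 for a PSD preconditioner
    predicted_kl = 0.5 * lr ** 2 * shs        # quadratic KL model of the unscaled step
    if predicted_kl <= kl_clip:
        return 1.0                            # already inside the trust region: never scale up
    return float(np.sqrt(kl_clip / predicted_kl))

g = np.array([3.0, 4.0])
nat = g.copy()                                # F = I for illustration, so F^{-1} g = g
nu = trust_region_scale(g, nat, lr=0.1, kl_clip=0.01)
applied_kl = 0.5 * (0.1 * nu) ** 2 * float(g @ nat)
print(nu, applied_kl)                         # the scaled step sits on the KL clip
```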


In [None]:
class KFACOptimizer:
    """Minimal K-FAC for nn.Linear modules (actor only).

    - Collects factor stats (A,G) via forward/backward hooks on a Fisher-like loss.
    - Preconditions parameter gradients with G^{-1} @ grad @ A^{-1}.
    - Scales the step using a trust-region clip.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float,
        damping: float,
        stats_decay: float,
        kfac_clip: float,
        inverse_update_interval: int = 1,
    ):
        self.model = model
        self.lr = float(lr)
        self.damping = float(damping)
        self.stats_decay = float(stats_decay)
        self.kfac_clip = float(kfac_clip)
        self.inverse_update_interval = int(inverse_update_interval)

        self._collect_stats = False
        self._step = 0

        self.modules = []
        self.state = {}

        for module in self.model.modules():
            if isinstance(module, nn.Linear):
                self.modules.append(module)
                self.state[module] = {
                    'A': None,
                    'G': None,
                    'A_inv': None,
                    'G_inv': None,
                }
                module.register_forward_hook(self._forward_hook)
                module.register_full_backward_hook(self._backward_hook)

    def set_collect_stats(self, collect: bool):
        self._collect_stats = bool(collect)

    def _forward_hook(self, module, inputs, output):
        if not self._collect_stats:
            return
        module._kfac_input = inputs[0].detach()

    def _backward_hook(self, module, grad_input, grad_output):
        if not self._collect_stats:
            return
        module._kfac_grad_output = grad_output[0].detach()

    @torch.no_grad()
    def update_stats(self):
        for module in self.modules:
            if not hasattr(module, '_kfac_input') or not hasattr(module, '_kfac_grad_output'):
                continue

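            a = module._kfac_input
            g = module._kfac_grad_output

            if a.dim() != 2 or g.dim() != 2:
                continue

            batch = a.shape[0]
            # The Fisher-like loss is a batch mean, so each row of g carries a factor 1/batch.
            # Rescale so G estimates E[g g^T] over per-sample gradients.
            g = g * batch
            ones = torch.ones(batch, 1, device=a.device, dtype=a.dtype)
            a_aug = torch.cat([a, ones], dim=1)  # append 1 so the bias is part of the A factor

            A_new = (a_aug.t() @ a_aug) / batch
            G_new = (g.t() @ g) / batch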

            st = self.state[module]
            if st['A'] is None:
                st['A'] = A_new
                st['G'] = G_new
            else:
                d = self.stats_decay
                st['A'] = d * st['A'] + (1 - d) * A_new
                st['G'] = d * st['G'] + (1 - d) * G_new

        self._step += 1
        if self._step % self.inverse_update_interval == 0:
            self._update_inverses()

    @torch.no_grad()
    def _update_inverses(self):
        for module in self.modules:
            st = self.state[module]
            if st['A'] is None or st['G'] is None:
                continue

            A = st['A'] + self.damping * torch.eye(st['A'].shape[0], device=st['A'].device, dtype=st['A'].dtype)
            G = st['G'] + self.damping * torch.eye(st['G'].shape[0], device=st['G'].device, dtype=st['G'].dtype)

            st['A_inv'] = torch.linalg.inv(A)
            st['G_inv'] = torch.linalg.inv(G)

    @torch.no_grad()
    def step(self):
        eps = 1e-8

        shs = 0.0  # proxy for g^T F^{-1} g
        updates = []

        for module in self.modules:
            st = self.state[module]
            if st['A_inv'] is None or st['G_inv'] is None:
                continue

            if module.weight.grad is None:
                continue
            if module.bias is None or module.bias.grad is None:
                continue

            grad_w = module.weight.grad
            grad_b = module.bias.grad
            grad_wb = torch.cat([grad_w, grad_b.unsqueeze(1)], dim=1)

            nat_wb = st['G_inv'] @ grad_wb @ st['A_inv']
            nat_w = nat_wb[:, :-1]
            nat_b = nat_wb[:, -1]

            shs += float((grad_w * nat_w).sum().item() + (grad_b * nat_b).sum().item())
            updates.append((module.weight, nat_w))
            updates.append((module.bias, nat_b))

        # Trust-region / KL clip: only scale down.
        # Quadratic model: the step lr * F^{-1} g changes the policy by KL ~= 0.5 * lr^2 * g^T F^{-1} g.
        nu = 1.0
        predicted_kl = 0.5 * self.lr ** 2 * shs if shs > 0 else 0.0
        if predicted_kl > 0:
            nu = float(min(1.0, np.sqrt(self.kfac_clip / (predicted_kl + eps))))

        for param, nat_grad in updates:
            param.add_(nat_grad, alpha=-self.lr * nu)

        return {
            'shs': shs,
            # Predicted KL of the step actually taken (after scaling by nu), comparable to approx_kl.
            'predicted_kl': predicted_kl * nu ** 2,
            'nu': nu,
        }


actor_kfac = KFACOptimizer(
    actor,
    lr=ACTOR_LR,
    damping=KFAC_DAMPING,
    stats_decay=KFAC_STATS_DECAY,
    kfac_clip=KFAC_CLIP,
    inverse_update_interval=INVERSE_UPDATE_INTERVAL,
)


## 5) Training loop (ACKTR update)

Each update:

1. Collect $T$ on-policy transitions.
2. Compute $\hat A_t$ and $\hat R_t$ with GAE.
3. Update critic by minimizing $\mathcal{L}_V$.
4. For the actor:
   - build a Fisher-like loss from the log-likelihood of the rollout actions (to collect K-FAC stats)
   - backprop that loss to update the factors $A$ and $G$
   - backprop the policy loss and take a K-FAC-preconditioned, KL-clipped step

We log:

- episodic returns
- actor loss, critic loss, entropy
- KL between the pre-update and post-update policies (measured on the batch), alongside the quadratic-model prediction
- trust-region scale factor $\nu$
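Step 4 relies on the rollout actions being samples from the current policy (the actor has not been updated since the rollout), so averaging outer products of their log-likelihood gradients gives an unbiased Monte Carlo estimate of the Fisher. A NumPy sketch for a single state with logit parameters (illustrative probabilities), where the exact Fisher is $\operatorname{diag}(p) - pp^\top$:

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.7, 0.2, 0.1])                       # current policy pi(.|s) at one state (illustrative)
actions = rng.choice(len(p), size=200_000, p=p)     # on-policy samples a ~ pi(.|s)

# Score with respect to the logits z: grad_z log pi(a) = onehot(a) - p
scores = np.eye(len(p))[actions] - p
F_mc = scores.T @ scores / len(actions)             # Monte Carlo estimate of E[score score^T]
F_exact = np.diag(p) - np.outer(p, p)               # analytic Fisher for softmax logits
print(np.abs(F_mc - F_exact).max())
```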


In [None]:
num_updates = TOTAL_TIMESTEPS // ROLLOUT_STEPS
print('num_updates', num_updates)

obs, _ = env.reset(seed=SEED)

episode_return = 0.0
episode_len = 0
episode_returns = []
episode_lengths = []

logs = []
start = time.time()

for update in range(1, num_updates + 1):
    # --- Rollout buffers ---
    obs_buf = np.zeros((ROLLOUT_STEPS, obs_dim), dtype=np.float32)
    act_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.int64)
    rew_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.float32)
    done_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.float32)
    val_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.float32)

    for t in range(ROLLOUT_STEPS):
        obs_buf[t] = obs

        obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        with torch.no_grad():
            logits = actor(obs_t)
            dist = Categorical(logits=logits)
            action = dist.sample()
            value = critic(obs_t)

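        next_obs, reward, terminated, truncated, _ = env.step(int(action.item()))
        # Simplification: time-limit truncation is treated like termination, so the value of
        # the truncated state is not bootstrapped in GAE.
        done = bool(terminated or truncated)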

        act_buf[t] = int(action.item())
        rew_buf[t] = float(reward)
        done_buf[t] = float(done)
        val_buf[t] = float(value.item())

        episode_return += float(reward)
        episode_len += 1

        obs = next_obs
        if done:
            episode_returns.append(episode_return)
            episode_lengths.append(episode_len)
            episode_return = 0.0
            episode_len = 0
            obs, _ = env.reset()

    with torch.no_grad():
        if done_buf[-1] == 1.0:
            last_value = 0.0
        else:
            last_obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
            last_value = float(critic(last_obs_t).item())

    advantages, returns = compute_gae(
        rewards=rew_buf,
        values=val_buf,
        dones=done_buf,
        last_value=last_value,
        gamma=GAMMA,
        lam=GAE_LAMBDA,
    )

    obs_batch = torch.tensor(obs_buf, dtype=torch.float32, device=DEVICE)
    act_batch = torch.tensor(act_buf, dtype=torch.int64, device=DEVICE)
    adv_batch = torch.tensor(advantages, dtype=torch.float32, device=DEVICE)
    ret_batch = torch.tensor(returns, dtype=torch.float32, device=DEVICE)

    adv_batch = (adv_batch - adv_batch.mean()) / (adv_batch.std() + 1e-8)

    # --- Critic update (first-order) ---
    critic_optim.zero_grad(set_to_none=True)
    v_pred = critic(obs_batch)
    critic_loss = 0.5 * (ret_batch - v_pred).pow(2).mean()
    critic_loss.backward()
    critic_optim.step()

    # --- Actor update (ACKTR-style via K-FAC) ---
    actor_kfac.set_collect_stats(True)
    with torch.no_grad():
        logits_old = actor(obs_batch)  # pre-update policy, only used for the KL diagnostic
    dist_old = Categorical(logits=logits_old)

    logits = actor(obs_batch)
    dist = Categorical(logits=logits)
    logp = dist.log_prob(act_batch)
    entropy = dist.entropy().mean()

    actor_loss = -(logp * adv_batch.detach()).mean() - ENT_COEF * entropy

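    # Fisher-like loss: mean log-likelihood of the rollout actions. They were sampled from the
    # current (not yet updated) actor, so their gradients are valid samples for the Fisher factors.
    # The sign does not matter, since K-FAC only uses outer products of the gradients.
    fisher_loss = -logp.mean()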

    actor.zero_grad(set_to_none=True)
    fisher_loss.backward(retain_graph=True)
    actor_kfac.set_collect_stats(False)
    actor_kfac.update_stats()

    actor.zero_grad(set_to_none=True)
    actor_loss.backward()
    step_info = actor_kfac.step()

    with torch.no_grad():
        logits_new = actor(obs_batch)
        dist_new = Categorical(logits=logits_new)
        approx_kl = torch.distributions.kl_divergence(dist_old, dist_new).mean().item()

    logs.append(
        {
            'update': update,
            'timesteps': update * ROLLOUT_STEPS,
            'episodes': len(episode_returns),
            'actor_loss': float(actor_loss.item()),
            'critic_loss': float(critic_loss.item()),
            'entropy': float(entropy.item()),
            'approx_kl': float(approx_kl),
            **step_info,
        }
    )

    if update % 25 == 0:
        recent = episode_returns[-20:]
        mean_20 = float(np.mean(recent)) if recent else float('nan')
        elapsed = time.time() - start
        print(
            f'update {update:4d}/{num_updates} | episodes {len(episode_returns):4d} '
            f'| mean_return_20 {mean_20:7.2f} | kl {approx_kl:9.2e} | nu {step_info["nu"]:7.3f} '
            f'| elapsed {elapsed:6.1f}s'
        )

env.close()


## 6) Plotly: learning dynamics

We visualize:

- episodic reward progression (raw + smoothed)
- estimated KL per update
- actor/critic losses
- trust-region scaling factor $\nu$


In [None]:
df_logs = pd.DataFrame(logs)
df_eps = pd.DataFrame({'episode': np.arange(1, len(episode_returns) + 1), 'return': episode_returns})
df_eps['return_smooth'] = df_eps['return'].rolling(window=20, min_periods=1).mean()

fig = go.Figure()
fig.add_trace(go.Scatter(x=df_eps['episode'], y=df_eps['return'], mode='lines', name='return', line=dict(width=1)))
fig.add_trace(go.Scatter(x=df_eps['episode'], y=df_eps['return_smooth'], mode='lines', name='return (20-ep mean)', line=dict(width=3)))
fig.update_layout(
    title='Episodic reward progression (CartPole-v1)',
    xaxis_title='Episode',
    yaxis_title='Return',
    height=420,
)
fig.show()

fig2 = px.line(df_logs, x='timesteps', y=['approx_kl', 'predicted_kl'], title='KL diagnostics per update')
fig2.update_layout(height=380)
fig2.show()

fig3 = px.line(df_logs, x='timesteps', y=['actor_loss', 'critic_loss'], title='Losses per update')
fig3.update_layout(height=380)
fig3.show()

fig4 = px.line(df_logs, x='timesteps', y=['nu'], title='Trust-region scaling (nu)')
fig4.update_layout(height=320)
fig4.show()


## 7) Stable-Baselines ACKTR (reference)

This section gives a reference snippet for the (TensorFlow-based) Stable-Baselines implementation of ACKTR and explains its key hyperparameters.

This section is **reference only** — the implementation above is the main deliverable.


### Stable-Baselines usage (snippet)

```python
# pip install stable-baselines==2.*  (TensorFlow 1.x based)
from stable_baselines import ACKTR

model = ACKTR(
    policy='MlpPolicy',
    env='CartPole-v1',
    n_steps=20,
    gamma=0.99,
    ent_coef=0.01,
    vf_coef=0.25,
    vf_fisher_coef=1.0,
    learning_rate=0.25,
    max_grad_norm=0.5,
    kfac_clip=0.001,
    lr_schedule='linear',
    kfac_update=1,
    gae_lambda=None,
    verbose=1,
)

model.learn(total_timesteps=200_000)
```


### Hyperparameters (Stable-Baselines) explained

Stable-Baselines ("v2", TensorFlow 1.x) includes an `ACKTR` implementation (see `stable_baselines/acktr/acktr.py`). The constructor signature is:

```python
ACKTR(
  policy,
  env,
  gamma=0.99,
  n_steps=20,
  ent_coef=0.01,
  vf_coef=0.25,
  vf_fisher_coef=1.0,
  learning_rate=0.25,
  max_grad_norm=0.5,
  kfac_clip=0.001,
  lr_schedule='linear',
  async_eigen_decomp=False,
  kfac_update=1,
  gae_lambda=None,
  policy_kwargs=None,
  seed=None,
  n_cpu_tf_sess=1,
  # + logging/boilerplate args
)
```

**Core RL knobs**

- `gamma`: discount factor.
- `n_steps`: rollout length per environment before each update.
- `gae_lambda`: if not `None`, Stable-Baselines computes GAE with parameter $\lambda$; if `None`, it uses the classic advantage (no GAE).
- `ent_coef`: entropy bonus weight (encourages exploration).
- `vf_coef`: value loss weight in the joint loss.

**ACKTR / K-FAC + trust region knobs**

- `kfac_clip`: KL-based clip used inside the K-FAC optimizer (trust-region-like safeguard; called `clip_kl` in the underlying optimizer).
- `vf_fisher_coef`: weight on the **value-function Fisher loss**. In the Stable-Baselines code, the value Fisher is constructed by adding noise to the value output and backpropagating a Gaussian negative log-likelihood; this lets K-FAC build curvature stats for the critic.
- `learning_rate`: the step size used by the K-FAC optimizer (and scheduled by `lr_schedule`).
- `lr_schedule`: learning-rate schedule string (`'linear'`, `'constant'`, `'double_linear_con'`, `'middle_drop'`, `'double_middle_drop'`).
- `kfac_update`: update frequency for K-FAC statistics / eigen decompositions.
- `async_eigen_decomp`: compute eigen decompositions asynchronously (speed/throughput trade-off).
- `max_grad_norm`: global gradient clipping.
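To make `lr_schedule` concrete: a `'linear'` schedule decays the step size from `learning_rate` toward zero over training, while `'constant'` keeps it fixed. The helpers below are a hypothetical sketch of that idea, not Stable-Baselines' own scheduler code. The notebook's `KFACOptimizer` uses a fixed `lr`; one way to apply such a schedule there would be to overwrite `actor_kfac.lr` before each update.

```python
def linear_lr(initial_lr: float, step: int, total_steps: int) -> float:
    """Hypothetical helper: decay linearly from initial_lr to 0 over total_steps updates."""
    progress = min(step / total_steps, 1.0)   # fraction of training completed, capped at 1
    return initial_lr * (1.0 - progress)

def constant_lr(initial_lr: float, step: int, total_steps: int) -> float:
    """Hypothetical helper: keep the learning rate fixed."""
    return initial_lr

print([linear_lr(0.25, s, 100) for s in (0, 50, 100)])
```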

**Practical / reproducibility knobs**

- `policy`: policy network type (e.g. `MlpPolicy`, `CnnPolicy`, `CnnLstmPolicy`).
- `env`: Gym env instance or env id string.
- `policy_kwargs`: extra arguments forwarded to the policy.
- `seed`: seeds python/NumPy/TensorFlow RNGs.
- `n_cpu_tf_sess`: TensorFlow thread count (for determinism, set this to `1`).

Note: Stable-Baselines wires `ACKTR` into `kfac.KfacOptimizer(...)` with additional internal defaults (e.g. `momentum=0.9`, `epsilon=0.01`, `stats_decay=0.99`, `cold_iter=10`).
