# ACKTR (Actor–Critic using Kronecker-Factored Trust Region)

ACKTR is an **actor–critic** algorithm that updates the policy using a **trust region** (a KL-divergence constraint) and computes the update direction using **K-FAC** — a Kronecker-factored approximation of the curvature (Fisher / Gauss–Newton).

This notebook focuses on the *math and intuition* behind:

- **Trust regions** (why we constrain policy updates with KL)
- **Natural gradients** and the **Fisher information matrix**
- **K-FAC** (how we approximate the Fisher cheaply, layer-wise)

---

For a full **low-level PyTorch implementation** (including Plotly learning curves), see `01_acktr_from_scratch.ipynb`.


## Learning goals

By the end you should be able to:

- write the **KL-constrained** policy improvement objective used by trust-region methods
- derive the **natural gradient** step: $\Delta \theta \propto F^{-1} g$
- explain **why** $F$ is the right geometry for policy updates (invariance)
- explain **K-FAC** as a Kronecker factorization of per-layer curvature
- connect **ACKTR** to TRPO: a similar trust-region update, but with a cheaper curvature approximation


## 1) Trust regions: KL-constrained policy improvement

In policy gradient methods, we want to improve a policy $\pi_\theta(a\mid s)$ by changing parameters $\theta$.

A naive step $\theta \leftarrow \theta + \alpha \nabla_\theta J(\theta)$ can be **destructive**: it may change the policy too much and collapse performance.

Trust-region methods instead solve (conceptually):

$$
\max_{\theta'}\; \mathbb{E}_{s\sim d_{\pi_\theta},\;a\sim\pi_\theta}\left[\frac{\pi_{\theta'}(a\mid s)}{\pi_{\theta}(a\mid s)}\,\hat A_\theta(s,a)\right]
\quad\text{s.t.}\quad
\mathbb{E}_{s\sim d_{\pi_\theta}}\left[\mathrm{KL}\big(\pi_{\theta}(\cdot\mid s)\;\|\;\pi_{\theta'}(\cdot\mid s)\big)\right] \le \delta.
$$

- $\hat A_\theta(s,a)$ is an advantage estimate.
- The constraint bounds the **average KL** divergence from the old policy to the new one (averaged over states visited by $\pi_\theta$); $\delta$ is the trust-region radius.
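
To make the two quantities concrete, here is a minimal NumPy sketch that evaluates the surrogate objective and the average KL for a batch of categorical policies. The function name `surrogate_and_kl`, the random logits, the made-up advantages, and the radius `delta = 0.01` are all assumptions chosen for illustration; they are not taken from the companion implementation.

```python
import numpy as np

def softmax(logits):
    # Row-wise softmax, shifted for numerical stability.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def surrogate_and_kl(old_logits, new_logits, actions, advantages):
    """Return (surrogate objective, mean KL(old || new)) over a batch of states."""
    p_old = softmax(old_logits)
    p_new = softmax(new_logits)
    idx = np.arange(len(actions))
    # Importance ratio pi_new(a|s) / pi_old(a|s) for the sampled actions.
    ratio = p_new[idx, actions] / p_old[idx, actions]
    surrogate = np.mean(ratio * advantages)
    # KL(old || new) per state, then averaged over the batch.
    kl = np.mean(np.sum(p_old * (np.log(p_old) - np.log(p_new)), axis=-1))
    return surrogate, kl

rng = np.random.default_rng(0)
old_logits = rng.normal(size=(64, 3))      # hypothetical batch: 64 states, 3 actions
actions = rng.integers(0, 3, size=64)      # hypothetical sampled actions
advantages = rng.normal(size=64)           # hypothetical advantage estimates

# With an unchanged policy the ratio is 1 and the KL is 0.
same_surrogate, same_kl = surrogate_and_kl(old_logits, old_logits, actions, advantages)

# A small perturbation of the logits gives a small positive KL.
new_logits = old_logits + 0.1 * rng.normal(size=(64, 3))
new_surrogate, new_kl = surrogate_and_kl(old_logits, new_logits, actions, advantages)

delta = 0.01
print(f"mean KL = {new_kl:.5f}, inside trust region: {new_kl <= delta}")
```

A candidate $\theta'$ is acceptable only if the returned KL is at most $\delta$; among acceptable candidates, the one with the largest surrogate value is preferred.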


## 2) Natural gradient and the Fisher information matrix

If we linearize the objective and approximate the KL constraint to second order around the current $\theta$, the constrained problem has a closed-form solution: a step in the **natural gradient** direction.

Let:

- $g = \nabla_\theta J(\theta)$ be the (vanilla) policy gradient.
- $F$ be the **Fisher information matrix** of the policy:

$$
F = \mathbb{E}_{s\sim d_{\pi_\theta},\;a\sim\pi_\theta}\left[\nabla_\theta \log \pi_\theta(a\mid s)\;\nabla_\theta \log \pi_\theta(a\mid s)^\top\right].
$$
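
As a small sanity check of this definition, the sketch below computes $F$ exactly for a single-state softmax policy whose parameters are the logits themselves. That parameterization, the helper name `fisher_softmax`, and the example logits are assumptions made for illustration. In this case the expectation over actions can be summed exactly, and it should match the known closed form $\mathrm{diag}(p) - p p^\top$.

```python
import numpy as np

def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()

def fisher_softmax(logits):
    """Exact Fisher of a softmax policy parameterized directly by its logits."""
    p = softmax(logits)
    n = len(p)
    F = np.zeros((n, n))
    for a in range(n):
        # Score function: gradient of log pi(a) with respect to the logits.
        score = np.eye(n)[a] - p
        # Expectation over actions, weighted by pi(a).
        F += p[a] * np.outer(score, score)
    return F, p

logits = np.array([0.5, -1.0, 2.0])        # hypothetical logits
F, p = fisher_softmax(logits)
closed_form = np.diag(p) - np.outer(p, p)

print("max abs difference:", np.abs(F - closed_form).max())
print("F @ ones:", F @ np.ones(3))         # numerically zero: F is singular here
```

Note that this $F$ is singular (adding a constant to every logit leaves the policy unchanged), which is one reason practical implementations add damping before inverting.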

The natural gradient direction is:

$$
\Delta\theta_{\mathrm{nat}} = F^{-1} g.
$$

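The trust-region step size follows from the quadratic approximation of the KL, whose Hessian at $\theta$ is $F$. For a step $\alpha\,\Delta\theta_{\mathrm{nat}}$:

$$
\mathrm{KL}(\pi_{\theta}\|\pi_{\theta+\alpha\Delta\theta_{\mathrm{nat}}}) \approx \tfrac{1}{2}\,\alpha^2\,\Delta\theta_{\mathrm{nat}}^\top F\,\Delta\theta_{\mathrm{nat}} = \tfrac{1}{2}\,\alpha^2\, g^\top F^{-1} g.
$$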

Setting this predicted KL equal to $\delta$ gives the largest step that stays inside the trust region:

$$
\alpha = \sqrt{\frac{2\delta}{g^\top F^{-1} g}}.
$$
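
The following sketch applies these formulas to a toy problem and compares the target KL with the KL actually incurred by the step. It assumes a single-state softmax policy parameterized by its logits, made-up advantages `adv`, and a small damping term (the Fisher is singular in logit space); none of these choices come from the companion implementation.

```python
import numpy as np

def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()

def kl(p, q):
    return float(np.sum(p * (np.log(p) - np.log(q))))

theta = np.array([0.5, -1.0, 2.0])      # hypothetical logits of a single-state policy
adv = np.array([1.0, -0.5, 0.2])        # hypothetical advantage of each action
delta = 0.01                            # trust-region radius
damping = 1e-6                          # assumption: F is singular in logit space

p = softmax(theta)
F = np.diag(p) - np.outer(p, p)         # exact Fisher of a softmax over logits
g = p * (adv - p @ adv)                 # exact gradient of E_{a~pi}[adv[a]] w.r.t. logits

nat_dir = np.linalg.solve(F + damping * np.eye(3), g)   # approximately F^{-1} g
alpha = np.sqrt(2.0 * delta / (g @ nat_dir))            # step size from the formula above

actual_kl = kl(p, softmax(theta + alpha * nat_dir))
print(f"alpha = {alpha:.4f}, target KL = {delta}, actual KL = {actual_kl:.5f}")
```

The actual KL lands close to, but not exactly on, $\delta$ because the step size is derived from a second-order approximation; the gap shrinks as $\delta$ gets smaller.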


## 3) K-FAC: Kronecker-factored curvature (layer-wise)

Forming $F$ explicitly is infeasible for neural networks: with $n$ parameters it has $n^2$ entries, and inverting it costs $O(n^3)$, which is out of reach when $n$ is in the millions.

**K-FAC** exploits structure of feed-forward layers to approximate curvature **per layer**.

For a linear layer:

$$
h = W a + b
$$

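where $a$ is the input activation and $h$ the pre-activation output. Define $g = \nabla_h \mathcal{L}$ (the backprop signal into the layer). Note that in this section $g$ denotes this per-layer signal, not the policy gradient from Section 2.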

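K-FAC approximates the layer's Fisher block as a Kronecker product. The gradient with respect to $W$ is the outer product $g a^\top$, so the exact block is $\mathbb{E}[a a^\top \otimes g g^\top]$; K-FAC treats $a$ and $g$ as statistically independent and factors the expectation:

$$
F_W \approx \mathbb{E}[a a^\top] \otimes \mathbb{E}[g g^\top] \;=\; A \otimes G.
$$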

This makes the inverse cheap, because only the two small factors need to be inverted:

$$
(A \otimes G)^{-1} = A^{-1} \otimes G^{-1}.
$$

The Kronecker-factored natural gradient for the weight matrix becomes:

$$
\Delta W \approx G^{-1}\,\nabla_W \mathcal{L}\,A^{-1}.
$$
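
A quick numerical check of this identity, as a sketch: build the dense block $A \otimes G$ for a tiny layer, solve against the vectorized gradient, and compare with the factored formula. The layer sizes, random data, and the small diagonal term added to keep both factors invertible are assumptions for illustration. The vectorization is column-stacking, which is the convention under which `np.kron(A, G)` matches the formula above.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, batch = 4, 3, 256           # hypothetical layer sizes

a = rng.normal(size=(batch, n_in))       # layer inputs, one row per sample
g = rng.normal(size=(batch, n_out))      # backprop signals, one row per sample

# Empirical Kronecker factors, with a small diagonal term so both are invertible.
A = a.T @ a / batch + 1e-3 * np.eye(n_in)
G = g.T @ g / batch + 1e-3 * np.eye(n_out)

grad_W = rng.normal(size=(n_out, n_in))  # stand-in for the gradient w.r.t. W

# Dense route: solve with the full (n_in * n_out) x (n_in * n_out) block.
F_block = np.kron(A, G)
vec_grad = grad_W.reshape(-1, order="F")             # column-stacking vec
dense = np.linalg.solve(F_block, vec_grad).reshape(n_out, n_in, order="F")

# Factored route: only the two small factors are inverted.
factored = np.linalg.solve(G, grad_W) @ np.linalg.inv(A)

print("max abs difference:", np.abs(dense - factored).max())
```

The dense route scales with the cube of `n_in * n_out`, while the factored route only needs an `n_in`-sized and an `n_out`-sized inverse, which is what makes the approach practical for real layers.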

In practice we use:

- **exponential moving averages** for $A$ and $G$
- **damping** for numerical stability, e.g. $A \leftarrow A + \lambda I$, $G \leftarrow G + \lambda I$
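
These two practical details can be sketched as a small helper that tracks one factor ($A$ or $G$). The class name `RunningFactor`, the identity initialization, and the default `decay` and `damping` values are hypothetical choices for illustration, not the settings of any particular ACKTR implementation.

```python
import numpy as np

class RunningFactor:
    """Exponential moving average of one K-FAC factor (hypothetical helper)."""

    def __init__(self, dim, decay=0.95):
        self.decay = decay
        self.value = np.eye(dim)         # assumption: start from the identity

    def update(self, batch):
        """batch: (N, dim) array of activations or backprop signals."""
        estimate = batch.T @ batch / batch.shape[0]
        self.value = self.decay * self.value + (1.0 - self.decay) * estimate

    def damped_inverse(self, damping=1e-2):
        """Inverse of (factor + damping * I), as in the damping rule above."""
        dim = self.value.shape[0]
        return np.linalg.inv(self.value + damping * np.eye(dim))

rng = np.random.default_rng(0)
factor_A = RunningFactor(dim=4)
for _ in range(200):
    factor_A.update(rng.normal(size=(32, 4)))   # hypothetical activation batches

A_inv = factor_A.damped_inverse(damping=1e-2)
print(np.round(factor_A.value, 2))
```

The moving average smooths out the noise of individual mini-batches, and the damping term keeps the inverse well conditioned even when a factor is close to singular.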


## 4) What makes it “ACKTR”

ACKTR combines:

- **actor–critic** losses (policy loss + value loss + entropy bonus)
- a **K-FAC preconditioner** (approximate $F^{-1}$)
- a **trust region / KL clip** that scales steps to avoid large policy shifts

Conceptually the update looks like:

$$
\theta \leftarrow \theta - \alpha\,F^{-1}\,\nabla_\theta \mathcal{L}(\theta)
$$

with $\alpha$ chosen (or clipped) so that the predicted KL stays below $\delta$.
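
One way to read "chosen (or clipped)" is as a cap on the learning rate: use a maximum rate unless that would push the predicted KL above $\delta$. The sketch below illustrates this on a toy diagonal Fisher. The function name `kl_clipped_step` and the values of `lr_max` and `delta` are assumptions for illustration; the exact rule in a given implementation may differ.

```python
import numpy as np

def kl_clipped_step(grad, precond_grad, lr_max=0.25, delta=1e-3):
    """Scale a preconditioned gradient so the predicted KL is at most delta.

    precond_grad stands in for F^{-1} grad, so grad @ precond_grad equals the
    quadratic form precond_grad^T F precond_grad used in the KL prediction.
    """
    quad = max(float(grad @ precond_grad), 1e-12)    # guard against division by zero
    lr = min(lr_max, np.sqrt(2.0 * delta / quad))
    return lr, lr * precond_grad

F = np.diag([1.0, 4.0])                  # toy Fisher (hypothetical)

# Large gradient: the KL bound is the active constraint.
g_large = np.array([3.0, 8.0])
lr_large, step_large = kl_clipped_step(g_large, np.linalg.solve(F, g_large))
kl_large = 0.5 * step_large @ F @ step_large

# Small gradient: the maximum learning rate is the active constraint.
g_small = np.array([0.01, 0.02])
lr_small, step_small = kl_clipped_step(g_small, np.linalg.solve(F, g_small))
kl_small = 0.5 * step_small @ F @ step_small

print(f"large gradient: lr = {lr_large:.5f}, predicted KL = {kl_large:.6f}")
print(f"small gradient: lr = {lr_small:.5f}, predicted KL = {kl_small:.6f}")
```

In both cases the predicted KL stays at or below $\delta$: the cap only shortens steps that would otherwise leave the trust region.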


In [None]:
import os
import platform

import numpy as np
import plotly
import plotly.graph_objects as go
import plotly.io as pio

try:
    import torch
    TORCH_AVAILABLE = True
except Exception as e:
    TORCH_AVAILABLE = False
    _TORCH_IMPORT_ERROR = e

try:
    import gymnasium as gym
    GYMNASIUM_AVAILABLE = True
except Exception:
    GYMNASIUM_AVAILABLE = False

pio.templates.default = 'plotly_white'
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

print('Python', platform.python_version())
print('NumPy', np.__version__)
print('Plotly', plotly.__version__)
print('Torch', torch.__version__ if TORCH_AVAILABLE else _TORCH_IMPORT_ERROR)
print('Gymnasium', gym.__version__ if GYMNASIUM_AVAILABLE else 'not installed')


In [None]:
# A tiny Plotly sketch: trust region + K-FAC (conceptual)
fig = go.Figure()

fig.add_shape(type='rect', x0=0.05, x1=0.45, y0=0.55, y1=0.85, line=dict(width=2))
# Plotly annotation text uses <br> for line breaks; '\n' is not rendered as a new line.
fig.add_annotation(x=0.25, y=0.70, text='Policy update<br>(natural gradient)', showarrow=False, font=dict(size=14))

fig.add_shape(type='rect', x0=0.55, x1=0.95, y0=0.55, y1=0.85, line=dict(width=2))
fig.add_annotation(x=0.75, y=0.70, text='K-FAC<br>(Fisher approx.)', showarrow=False, font=dict(size=14))

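# Arrow between the two boxes. Both the head (x, y) and the tail (ax, ay) are given
# in axis coordinates, matching the shapes above; axref/ayref accept 'pixel' or an
# axis id, not 'paper'.
fig.add_annotation(
    x=0.55,
    y=0.70,
    ax=0.45,
    ay=0.70,
    xref='x',
    yref='y',
    axref='x',
    ayref='y',
    text='',
    showarrow=True,
    arrowhead=3,
    arrowsize=1.2,
)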

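# Plotly renders LaTeX only when the whole string is wrapped in $...$, so the
# constraint is written with Unicode and HTML subscripts to allow a line break.
fig.add_annotation(
    x=0.50,
    y=0.40,
    text='E[KL(π<sub>old</sub> ‖ π<sub>new</sub>)] ≤ δ<br>controls step size',
    showarrow=False,
    font=dict(size=13),
)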

fig.update_xaxes(visible=False, range=[0, 1])
fig.update_yaxes(visible=False, range=[0, 1])
fig.update_layout(title='ACKTR: trust region (KL) + K-FAC curvature', height=310)
fig.show()


## Next notebook

- `01_acktr_from_scratch.ipynb`: full ACKTR **low-level PyTorch** implementation on a Gymnasium environment + Plotly learning dynamics, including **episodic reward progression**.
