# Twin Delayed DDPG (TD3): from scratch in PyTorch

TD3 (Fujimoto, van Hoof, Meger, 2018) is a deterministic actor-critic algorithm for **continuous control**.
It improves DDPG with three small but crucial modifications:

1. **Twin critics**: learn two Q-functions and use the *minimum* in the bootstrap target.
2. **Target policy smoothing**: add clipped noise to the target action when computing the target Q.
3. **Delayed policy updates**: update the actor (and target networks) less often than the critics.

In this notebook we:
- write the TD3 update equations precisely (LaTeX)
- implement TD3 at a **low level** in PyTorch (no RL libraries)
- train on a Gymnasium environment and **plot episodic returns** (Plotly)

---

## Learning goals

- Understand why DDPG overestimates Q-values and how TD3 reduces that bias
- Implement replay buffer + target networks + twin critics + delayed updates
- Train a working agent and visualize learning curves


## Prerequisites

- Basic PyTorch (modules, optimizers, autograd)
- Q-learning / bootstrapping and the Bellman equation
- Actor-critic idea (policy + value function)
- Continuous action spaces (e.g., Pendulum)


In [None]:
import copy
import os
import time

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    TORCH_AVAILABLE = True
except Exception as e:
    TORCH_AVAILABLE = False
    _TORCH_IMPORT_ERROR = e

try:
    import gymnasium as gym

    GYM_AVAILABLE = True
except Exception as e:
    GYM_AVAILABLE = False
    _GYM_IMPORT_ERROR = e


pio.templates.default = 'plotly_white'
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

SEED = 42
rng = np.random.default_rng(SEED)


In [None]:
# --- Run configuration ---
FAST_RUN = True

ENV_ID = 'Pendulum-v1'

TOTAL_TIMESTEPS = 10_000 if FAST_RUN else 200_000
START_STEPS = 1_000 if FAST_RUN else 10_000
UPDATE_AFTER = 1_000 if FAST_RUN else 10_000

BATCH_SIZE = 256
BUFFER_SIZE = 200_000

# TD3 hyperparameters
GAMMA = 0.99
TAU = 0.005

ACTOR_LR = 1e-3
CRITIC_LR = 1e-3

POLICY_DELAY = 2
TARGET_POLICY_NOISE = 0.2
TARGET_NOISE_CLIP = 0.5

EXPLORATION_NOISE = 0.1

HIDDEN_SIZES = (256, 256)


## 1) TD3: the exact updates (twin critics + delayed actor)

We use:

- deterministic policy (actor)  $a = \pi_\phi(s)$
- **two** critics  $Q_{\theta_1}(s,a)$ and $Q_{\theta_2}(s,a)$
- target networks  $(\phi', \theta_1', \theta_2')$ updated by Polyak averaging

Given a transition $(s,a,r,s',\text{terminal})$ sampled from the replay buffer, TD3 builds the target in three steps.

### 1. Target policy smoothing

TD3 does *not* evaluate the target critics at the raw target action $\pi_{\phi'}(s')$.
Instead it adds clipped Gaussian noise:

$$
\tilde a = \pi_{\phi'}(s') + \epsilon,\qquad
\epsilon \sim \mathrm{clip}(\mathcal N(0, \sigma^2),\,-c,\,+c)
$$

$$
\tilde a \leftarrow \mathrm{clip}(\tilde a, a_{\min}, a_{\max})
$$

Intuition: averaging over a small neighborhood of actions makes the target Q-value less sensitive to small action errors and prevents the policy from exploiting sharp, unrealistic peaks in the critic's estimate of $Q$.
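
As a minimal sketch of the two clipping steps, here is a NumPy version applied to a few hypothetical target-actor outputs. The function name `smoothed_target_action`, the example values, and the Pendulum-like bounds $[-2, 2]$ are assumptions for illustration; the PyTorch version used for training lives in `TD3Agent.train_step`.

```python
import numpy as np


def smoothed_target_action(mu_next, sigma, noise_clip, a_min, a_max, rng):
    """Return clip(mu_next + clip(eps, -c, c), a_min, a_max) with eps ~ N(0, sigma^2)."""
    eps = rng.normal(0.0, sigma, size=mu_next.shape)
    eps = np.clip(eps, -noise_clip, noise_clip)      # first clip: bound the noise
    return np.clip(mu_next + eps, a_min, a_max)      # second clip: stay in the action range


rng = np.random.default_rng(0)
mu_next = np.array([[0.0], [1.9], [-2.0]])  # hypothetical target-actor outputs for 3 states
a_tilde = smoothed_target_action(
    mu_next, sigma=0.2, noise_clip=0.5, a_min=-2.0, a_max=2.0, rng=rng
)
print(a_tilde)
```

Because the result is clipped back into the valid range, the perturbation actually applied near an action bound can be smaller than the sampled noise.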

### 2. Twin critics (min target)

Compute both target Q-values and take the **minimum**:

$$
y = r + \gamma (1-\text{terminal})\,\min\Big( Q_{\theta_1'}(s', \tilde a),\; Q_{\theta_2'}(s', \tilde a)\Big)
$$

Each critic minimizes an MSE to this same target:

$$
L(\theta_i) = \mathbb E\big[(Q_{\theta_i}(s,a)-y)^2\big],\qquad i\in\{1,2\}
$$

Taking the minimum is a simple bias-reduction trick: it turns DDPG's optimistic
target into a more conservative estimate, reducing overestimation error.
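
To see why the minimum helps, consider a toy setting (an assumption for illustration, not an experiment from the paper): the true value of every action is 0, and each critic returns the truth plus independent Gaussian noise. A greedy policy that picks the action its critic rates highest then reports an inflated value, while the clipped estimate $\min(Q_1, Q_2)$ at that same action is never larger.

```python
import numpy as np

rng = np.random.default_rng(0)
n_trials, n_actions = 10_000, 10

# True Q is 0 everywhere; each critic sees it through independent unit-variance noise.
q1 = rng.normal(0.0, 1.0, size=(n_trials, n_actions))
q2 = rng.normal(0.0, 1.0, size=(n_trials, n_actions))

# The "policy" exploits critic 1: it picks the action critic 1 rates highest.
greedy = q1.argmax(axis=1)
rows = np.arange(n_trials)

# Single-critic (DDPG-style) value of the chosen action.
single_estimate = q1[rows, greedy]
# Clipped (TD3-style) value: minimum of both critics at the same action.
clipped_estimate = np.minimum(single_estimate, q2[rows, greedy])

print("mean single-critic estimate:", single_estimate.mean())
print("mean clipped estimate:      ", clipped_estimate.mean())
```

Both means should be read against the true value 0. The single-critic estimate is biased upward because the action was selected using the same noise that is then reported. The clipped estimate can never exceed the second critic's value at that action, which is unbiased here, so in expectation it sits at or below the true value.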

### 3. Delayed policy updates

The critics are updated **every gradient step**.
The actor is updated only once every $d$ critic updates (e.g. $d=2$). When it is updated, it ascends the value that the first critic assigns to its own actions:

$$
\max_\phi\; J(\phi) = \mathbb E\big[ Q_{\theta_1}(s, \pi_\phi(s)) \big]
$$

In code we minimize the negative:

$$
L_\pi(\phi) = -\mathbb E\big[ Q_{\theta_1}(s, \pi_\phi(s)) \big]
$$

When (and only when) we update the actor, we also update **all** target networks with Polyak averaging:

$$
\theta_i' \leftarrow \tau \theta_i + (1-\tau)\theta_i',\qquad
\phi' \leftarrow \tau \phi + (1-\tau)\phi'
$$

Delaying the actor update lets the critics move closer to their fixed point, so the actor sees a less noisy / less biased gradient.
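
A minimal NumPy sketch of the Polyak update, using a hypothetical three-element parameter vector in place of real network weights. With the online parameters held fixed and the target starting at zero, $k$ updates leave the target at $(1-(1-\tau)^k)\,\theta$, which shows how slowly the target tracks the online network for small $\tau$.

```python
import numpy as np


def polyak_update(target, online, tau):
    """In-place soft update: target <- tau * online + (1 - tau) * target."""
    target *= 1.0 - tau
    target += tau * online


tau = 0.005
online = np.array([1.0, -2.0, 0.5])  # hypothetical online parameters (held fixed here)
target = np.zeros_like(online)       # hypothetical target parameters

k = 200
for _ in range(k):
    polyak_update(target, online, tau)

closed_form = (1.0 - (1.0 - tau) ** k) * online
print(target)
print(closed_form)
```

With $\tau = 0.005$, 200 target updates close roughly 63% of the gap to the online parameters. Since targets are updated only on actor steps, that corresponds to $200\,d$ critic updates.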


## 2) Implementation roadmap

We will implement TD3 as a small set of building blocks:

1. Gymnasium environment helpers (reset/step API differences)
2. Replay buffer (NumPy storage, batches returned as PyTorch tensors)
3. Actor network $\pi_\phi(s)$
4. Twin critic networks $Q_{\theta_1}(s,a), Q_{\theta_2}(s,a)$
5. TD3 update step (critic update every step, actor + target update every `POLICY_DELAY` steps)
6. Training loop + Plotly learning curve


In [None]:
def set_global_seeds(seed: int) -> None:
    np.random.seed(seed)
    if TORCH_AVAILABLE:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def env_reset(env, seed=None):
    out = env.reset(seed=seed) if seed is not None else env.reset()
    if isinstance(out, tuple) and len(out) == 2:
        obs, info = out
        return obs, info
    return out, {}


def env_step(env, action):
    out = env.step(action)
    if isinstance(out, tuple) and len(out) == 5:
        next_obs, reward, terminated, truncated, info = out
        done = bool(terminated or truncated)
        terminal = bool(terminated)  # time-limit truncation is not a terminal state
        return next_obs, float(reward), done, terminal, info
    if isinstance(out, tuple) and len(out) == 4:
        next_obs, reward, done, info = out
        terminal = bool(done)
        return next_obs, float(reward), bool(done), terminal, info
    raise ValueError('Unexpected env.step(...) output format')


In [None]:
if not TORCH_AVAILABLE:
    raise RuntimeError(f'PyTorch import failed: {_TORCH_IMPORT_ERROR}')

if not GYM_AVAILABLE:
    raise RuntimeError(f'Gymnasium import failed: {_GYM_IMPORT_ERROR}')


set_global_seeds(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

env = gym.make(ENV_ID)
env.action_space.seed(SEED)  # make the random warm-up actions reproducible
obs, _ = env_reset(env, seed=SEED)

obs_dim = int(np.prod(env.observation_space.shape))
act_dim = int(np.prod(env.action_space.shape))

action_low = env.action_space.low.astype(np.float32)
action_high = env.action_space.high.astype(np.float32)

print('env:', ENV_ID)
print('obs_dim:', obs_dim)
print('act_dim:', act_dim)
print('action_low:', action_low)
print('action_high:', action_high)
print('device:', device)


In [None]:
class ReplayBuffer:
    def __init__(self, obs_dim: int, act_dim: int, size: int, seed: int, device: torch.device):
        self.obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.next_obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros((size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros((size, 1), dtype=np.float32)
        self.done_buf = np.zeros((size, 1), dtype=np.float32)

        self.max_size = int(size)
        self.ptr = 0
        self.size = 0

        self.rng = np.random.default_rng(seed)
        self.device = device

    def add(self, obs, act, rew: float, next_obs, terminal: bool) -> None:
        self.obs_buf[self.ptr] = np.asarray(obs, dtype=np.float32).reshape(-1)
        self.next_obs_buf[self.ptr] = np.asarray(next_obs, dtype=np.float32).reshape(-1)
        self.act_buf[self.ptr] = np.asarray(act, dtype=np.float32).reshape(-1)
        self.rew_buf[self.ptr] = float(rew)
        self.done_buf[self.ptr] = 1.0 if terminal else 0.0

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size: int):
        if self.size < batch_size:
            raise ValueError(f'Not enough samples: size={self.size}, batch_size={batch_size}')
        idxs = self.rng.integers(0, self.size, size=batch_size)

        obs = torch.as_tensor(self.obs_buf[idxs], device=self.device)
        act = torch.as_tensor(self.act_buf[idxs], device=self.device)
        rew = torch.as_tensor(self.rew_buf[idxs], device=self.device)
        next_obs = torch.as_tensor(self.next_obs_buf[idxs], device=self.device)
        done = torch.as_tensor(self.done_buf[idxs], device=self.device)

        return obs, act, rew, next_obs, done


In [None]:
def mlp(layer_sizes, activation=nn.ReLU, output_activation=nn.Identity):
    layers = []
    for i in range(len(layer_sizes) - 1):
        act = activation if i < len(layer_sizes) - 2 else output_activation
        layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
        layers.append(act())
    return nn.Sequential(*layers)


class Actor(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes, action_low, action_high):
        super().__init__()
        self.net = mlp([obs_dim, *hidden_sizes, act_dim], activation=nn.ReLU, output_activation=nn.Identity)

        action_low_t = torch.as_tensor(action_low, dtype=torch.float32)
        action_high_t = torch.as_tensor(action_high, dtype=torch.float32)

        self.register_buffer('action_scale', (action_high_t - action_low_t) / 2.0)
        self.register_buffer('action_bias', (action_high_t + action_low_t) / 2.0)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        a = torch.tanh(self.net(obs))
        return a * self.action_scale + self.action_bias


class QNetwork(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes):
        super().__init__()
        self.net = mlp([obs_dim + act_dim, *hidden_sizes, 1], activation=nn.ReLU, output_activation=nn.Identity)

    def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        x = torch.cat([obs, act], dim=-1)
        return self.net(x)


class TwinCritic(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes):
        super().__init__()
        self.q1 = QNetwork(obs_dim, act_dim, hidden_sizes)
        self.q2 = QNetwork(obs_dim, act_dim, hidden_sizes)

    def forward(self, obs: torch.Tensor, act: torch.Tensor):
        return self.q1(obs, act), self.q2(obs, act)

    def q1_only(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        return self.q1(obs, act)


In [None]:
class TD3Agent:
    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        hidden_sizes,
        action_low,
        action_high,
        device: torch.device,
        gamma: float = 0.99,
        tau: float = 0.005,
        actor_lr: float = 1e-3,
        critic_lr: float = 1e-3,
        policy_delay: int = 2,
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
    ):
        self.device = device

        self.actor = Actor(obs_dim, act_dim, hidden_sizes, action_low, action_high).to(device)
        self.actor_target = copy.deepcopy(self.actor).to(device)

        self.critic = TwinCritic(obs_dim, act_dim, hidden_sizes).to(device)
        self.critic_target = copy.deepcopy(self.critic).to(device)

        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.gamma = float(gamma)
        self.tau = float(tau)

        self.policy_delay = int(policy_delay)
        self.target_policy_noise = float(target_policy_noise)
        self.target_noise_clip = float(target_noise_clip)

        self.action_low_t = torch.as_tensor(action_low, dtype=torch.float32, device=device)
        self.action_high_t = torch.as_tensor(action_high, dtype=torch.float32, device=device)

        self.total_it = 0

        # Targets start identical to online nets
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())

    @torch.no_grad()
    def select_action(self, obs, noise_scale: float = 0.0):
        obs_t = torch.as_tensor(np.asarray(obs, dtype=np.float32).reshape(1, -1), device=self.device)
        action = self.actor(obs_t).cpu().numpy().reshape(-1)
        if noise_scale and noise_scale > 0:
            action = action + np.random.normal(0.0, noise_scale, size=action.shape).astype(np.float32)
        action = np.clip(action, self.action_low_t.cpu().numpy(), self.action_high_t.cpu().numpy())
        return action

    def _soft_update_(self, source: nn.Module, target: nn.Module) -> None:
        # Polyak averaging: target <- tau * source + (1 - tau) * target
        with torch.no_grad():
            for p, p_targ in zip(source.parameters(), target.parameters()):
                p_targ.mul_(1.0 - self.tau)
                p_targ.add_(p, alpha=self.tau)

    def train_step(self, replay_buffer: ReplayBuffer, batch_size: int):
        self.total_it += 1

        obs, act, rew, next_obs, done = replay_buffer.sample(batch_size)

        # --- Critic update (every step) ---
        with torch.no_grad():
            noise = torch.randn_like(act) * self.target_policy_noise
            noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)

            next_action = self.actor_target(next_obs) + noise
            next_action = torch.max(torch.min(next_action, self.action_high_t), self.action_low_t)

            target_q1, target_q2 = self.critic_target(next_obs, next_action)
            target_q = torch.min(target_q1, target_q2)

            y = rew + (1.0 - done) * self.gamma * target_q

        current_q1, current_q2 = self.critic(obs, act)
        critic_loss = F.mse_loss(current_q1, y) + F.mse_loss(current_q2, y)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        info = {'critic_loss': float(critic_loss.item())}

        # --- Delayed actor + target updates ---
        if self.total_it % self.policy_delay == 0:
            actor_loss = -self.critic.q1_only(obs, self.actor(obs)).mean()
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            info['actor_loss'] = float(actor_loss.item())

            self._soft_update_(self.critic, self.critic_target)
            self._soft_update_(self.actor, self.actor_target)

        return info


In [None]:
replay = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=BUFFER_SIZE, seed=SEED, device=device)
agent = TD3Agent(
    obs_dim=obs_dim,
    act_dim=act_dim,
    hidden_sizes=HIDDEN_SIZES,
    action_low=action_low,
    action_high=action_high,
    device=device,
    gamma=GAMMA,
    tau=TAU,
    actor_lr=ACTOR_LR,
    critic_lr=CRITIC_LR,
    policy_delay=POLICY_DELAY,
    target_policy_noise=TARGET_POLICY_NOISE,
    target_noise_clip=TARGET_NOISE_CLIP,
)

episode_returns = []
episode_lengths = []

critic_losses = []
actor_losses = []

obs, _ = env_reset(env, seed=SEED)
ep_return = 0.0
ep_len = 0

t0 = time.time()

for t in range(TOTAL_TIMESTEPS):
    if t < START_STEPS:
        action = env.action_space.sample()
    else:
        action = agent.select_action(obs, noise_scale=EXPLORATION_NOISE)

    next_obs, reward, done, terminal, _info = env_step(env, action)

    replay.add(obs, action, reward, next_obs, terminal)

    obs = next_obs
    ep_return += reward
    ep_len += 1

    if t >= UPDATE_AFTER:
        train_info = agent.train_step(replay, batch_size=BATCH_SIZE)
        critic_losses.append(train_info['critic_loss'])
        if 'actor_loss' in train_info:
            actor_losses.append(train_info['actor_loss'])

    if done:
        episode_returns.append(ep_return)
        episode_lengths.append(ep_len)

        if len(episode_returns) % 5 == 0 or not FAST_RUN:
            elapsed = time.time() - t0
            print(
                f"Episode {len(episode_returns):4d} | return {ep_return:9.1f} | len {ep_len:3d} | "
                f"t {t + 1:6d}/{TOTAL_TIMESTEPS} | elapsed {elapsed:6.1f}s"
            )

        obs, _ = env_reset(env)
        ep_return = 0.0
        ep_len = 0

env.close()

print('episodes:', len(episode_returns))
print('last return:', episode_returns[-1] if episode_returns else None)


In [None]:
# Plot episodic returns
df = pd.DataFrame(
    {
        'episode': np.arange(1, len(episode_returns) + 1),
        'return': episode_returns,
        'length': episode_lengths,
    }
)

window = min(10, max(1, len(df)))
df['return_ma'] = df['return'].rolling(window=window, min_periods=1).mean()

fig = go.Figure()
fig.add_trace(go.Scatter(x=df['episode'], y=df['return'], mode='lines+markers', name='Return'))
fig.add_trace(go.Scatter(x=df['episode'], y=df['return_ma'], mode='lines', name=f'{window}-episode MA'))
fig.update_layout(
    title=f'TD3 training on {ENV_ID}: episodic return',
    xaxis_title='Episode',
    yaxis_title='Return',
)
fig.show()


## Notes, diagnostics, and common pitfalls

- **Terminal masking**: a time-limit truncation does not end the underlying MDP, so the target should still bootstrap from $s'$. We therefore store only `terminated` (Gymnasium) as the mask and ignore `truncated`.
- **Twin critics**: the *minimum* is used only in the bootstrap target $y$; the actor loss uses $Q_{\theta_1}$ alone.
- **Delayed updates**: do not update the actor every step; it should be updated every `POLICY_DELAY` critic updates.
- **Target policy smoothing**: the noise added to *target actions* is separate from exploration noise.
- **Exploration**: the TD3 policy is deterministic, so you must add noise to actions during data collection.
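
The terminal-masking point can be made concrete with a three-transition batch. The numbers are made up for illustration, and `min_target_q` stands in for $\min(Q_{\theta_1'}, Q_{\theta_2'})$ evaluated at the smoothed target action.

```python
import numpy as np

gamma = 0.99
reward = np.array([-1.0, -1.0, -1.0])
min_target_q = np.array([-50.0, -50.0, -50.0])  # hypothetical min of the two target critics
terminated = np.array([0.0, 1.0, 0.0])          # true end of episode
truncated = np.array([0.0, 0.0, 1.0])           # time limit hit; the state is not terminal

# Mask on `terminated` only, as this notebook does.
y = reward + gamma * (1.0 - terminated) * min_target_q

# Masking on `terminated or truncated` treats the time limit as a real terminal state.
done = np.maximum(terminated, truncated)
y_wrong = reward + gamma * (1.0 - done) * min_target_q

print(y)        # bootstraps for transitions 0 and 2
print(y_wrong)  # also drops the bootstrap term for transition 2
```

This matters for `Pendulum-v1` in particular: its episodes end only through the time limit, so masking on `done` would cut off the bootstrap on the last transition of every episode.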


## Stable-Baselines3 TD3 (reference implementation)

Stable-Baselines3 (SB3) includes a PyTorch TD3 implementation:
https://stable-baselines3.readthedocs.io/en/master/modules/td3.html

This is useful as a reference and a quick way to validate your intuition against a well-tested baseline.

If you want to run it locally:

```bash
pip install stable-baselines3
```

If you have SB3 installed, a minimal training script looks like:

```python
import gymnasium as gym
import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise

env = gym.make('Pendulum-v1')
n_actions = env.action_space.shape[-1]

action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = TD3(
    policy='MlpPolicy',
    env=env,
    action_noise=action_noise,
    verbose=1,
)
model.learn(total_timesteps=100_000)
```

The next section summarizes SB3's TD3 hyperparameters.


## Stable-Baselines3 TD3 hyperparameters (glossary + defaults)

Source: https://stable-baselines3.readthedocs.io/en/master/modules/td3.html

Constructor signature (defaults):

`TD3(policy, env, learning_rate=0.001, buffer_size=1000000, learning_starts=100, batch_size=256, tau=0.005, gamma=0.99, train_freq=1, gradient_steps=1, action_noise=None, replay_buffer_class=None, replay_buffer_kwargs=None, optimize_memory_usage=False, n_steps=1, policy_delay=2, target_policy_noise=0.2, target_noise_clip=0.5, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)`

Glossary:

- `policy`: policy class/name (e.g., `MlpPolicy`, `CnnPolicy`).
- `env`: environment instance or env ID string.
- `learning_rate` (default `1e-3`): Adam learning rate (SB3 uses the same LR for actor and critics).
- `buffer_size` (default `1_000_000`): replay buffer capacity.
- `learning_starts` (default `100`): number of environment steps collected before training begins.
- `batch_size` (default `256`): mini-batch size sampled from replay.
- `tau` (default `0.005`): Polyak coefficient $\tau$ for target network updates.
- `gamma` (default `0.99`): discount factor $\gamma$.
- `train_freq` (default `1`): how often to train (steps), or a tuple like `(n, 'step')` / `(n, 'episode')`.
- `gradient_steps` (default `1`): gradient updates per training iteration.
- `action_noise` (default `None`): exploration noise used when collecting data (e.g. Gaussian or OU noise).
- `policy_delay` (default `2`): actor/target update period $d$ (critics update every step).
- `target_policy_noise` (default `0.2`): $\sigma$ in target policy smoothing.
- `target_noise_clip` (default `0.5`): $c$ in target policy smoothing (clip range).
- `replay_buffer_class` (default `None`): custom replay buffer class.
- `replay_buffer_kwargs` (default `None`): kwargs passed to the replay buffer.
- `optimize_memory_usage` (default `False`): memory-efficient replay buffer variant.
- `n_steps` (default `1`): n-step returns (when >1 uses an n-step replay buffer).
- `stats_window_size` (default `100`): logging window size (episodes averaged).
- `tensorboard_log` (default `None`): TensorBoard log directory.
- `policy_kwargs` (default `None`): policy/network architecture options.
- `verbose` (default `0`): verbosity (0/1/2).
- `seed` (default `None`): RNG seed.
- `device` (default `'auto'`): device selection (CPU/GPU).
- `_init_setup_model` (default `True`): whether to build networks at init.
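
As a sketch of how this notebook's settings would map onto that constructor, the dictionary below collects the matching keyword arguments. The mapping is an assumption made for illustration: SB3 exposes a single `learning_rate` rather than separate actor and critic rates, and `learning_starts` is used here as the closest counterpart of `UPDATE_AFTER`. The values are the non-`FAST_RUN` settings from the configuration cell.

```python
# Notebook setting -> SB3 TD3 keyword argument (names taken from the constructor signature).
sb3_td3_kwargs = {
    "learning_rate": 1e-3,       # ACTOR_LR == CRITIC_LR in this notebook
    "buffer_size": 200_000,      # BUFFER_SIZE
    "learning_starts": 10_000,   # UPDATE_AFTER (non-FAST_RUN value)
    "batch_size": 256,           # BATCH_SIZE
    "tau": 0.005,                # TAU
    "gamma": 0.99,               # GAMMA
    "policy_delay": 2,           # POLICY_DELAY
    "target_policy_noise": 0.2,  # TARGET_POLICY_NOISE
    "target_noise_clip": 0.5,    # TARGET_NOISE_CLIP
    "seed": 42,                  # SEED
}

# With SB3 installed, these would be passed as:
#   model = TD3("MlpPolicy", env, action_noise=action_noise, **sb3_td3_kwargs)
for name, value in sb3_td3_kwargs.items():
    print(f"{name:>20} = {value}")
```

Exploration noise is not part of this dictionary because SB3 takes it as an `action_noise` object, and network sizes such as `HIDDEN_SIZES` would go through `policy_kwargs`.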


## References

- Fujimoto, van Hoof, Meger (2018): *Addressing Function Approximation Error in Actor-Critic Methods* (TD3)
- Stable-Baselines3 docs / source code (TD3)
