# Soft Actor-Critic (SAC) for Continuous Action Spaces (low-level PyTorch)

Soft Actor-Critic (SAC) is an **off-policy** actor-critic algorithm that learns a **stochastic** policy by maximizing expected return *and* **entropy**. The entropy term makes exploration a first-class objective, which tends to improve stability and robustness.

In this notebook you will:
- Derive the **maximum-entropy** objective and the SAC losses
- Implement SAC **from scratch in PyTorch** (replay buffer, twin critics, target networks, squashed Gaussian policy, temperature tuning)
- Train on a small **continuous-control** environment (no downloads required)
- Visualize **episodic rewards**, **policy entropy**, and **Q-values** with **Plotly**

---

## Learning goals
- Understand why SAC maximizes entropy and how that changes the Bellman backup
- Implement the core SAC update rules directly in PyTorch
- Diagnose learning using reward/entropy/Q plots


## Notebook roadmap
1. Maximum-entropy RL: objective + key equations (LaTeX), followed by a small continuous-action environment that needs no downloads
2. From scratch (PyTorch): replay buffer
3. From scratch (PyTorch): networks (squashed Gaussian actor, twin critics)
4. From scratch (PyTorch): SAC update (critic, actor, temperature)
5. Training loop + Plotly diagnostics (entropy, Q-values, rewards)
6. Stable-Baselines3 SAC: reference code + hyperparameters (final section)


In [None]:
import math
import os
import sys
import time
from dataclasses import dataclass

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")

np.set_printoptions(precision=4, suppress=True)

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.distributions import Normal

    TORCH_AVAILABLE = True
except Exception as e:
    TORCH_AVAILABLE = False
    print("PyTorch import failed:", e)


In [None]:
SEED = 42
rng = np.random.default_rng(SEED)

if TORCH_AVAILABLE:
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Only touch `torch` when the import succeeded; otherwise leave `device` unset.
device = None
if TORCH_AVAILABLE:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Python:", sys.version.split()[0])
print("NumPy:", np.__version__)
import plotly

print("Plotly:", plotly.__version__)
if TORCH_AVAILABLE:
    print("PyTorch:", torch.__version__)
    print("Device:", device)


## 1) Maximum-entropy RL (entropy maximization)

Classic RL maximizes expected return:

$$
\max_\pi\; \mathbb{E}_{\tau\sim\pi}\Big[\sum_{t=0}^{T-1} \gamma^t\, r(s_t,a_t)\Big].
$$

SAC instead uses a **maximum-entropy** objective:

$$
\max_\pi\; \mathbb{E}_{\tau\sim\pi}\Big[\sum_{t=0}^{T-1} \gamma^t\,\big(r(s_t,a_t) + \alpha\,\mathcal{H}(\pi(\cdot\mid s_t))\big)\Big],
$$

where the (differential) entropy of a continuous policy is

$$
\mathcal{H}(\pi(\cdot\mid s)) = -\mathbb{E}_{a\sim\pi(\cdot\mid s)}\big[\log \pi(a\mid s)\big].
$$

- Larger entropy means a **broader** action distribution (more exploration).
- $\alpha>0$ is the **temperature**: it trades off reward vs. entropy.
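
To make the entropy definition concrete, here is a minimal sketch that compares the closed-form differential entropy of a 1-D Gaussian with the sample estimate $-\mathbb{E}[\log\pi(a)]$. The helper names (`gaussian_entropy`, `mc_entropy`) and the standard-library sampler are illustrative assumptions, not part of the notebook's implementation.

```python
import math
import random


def gaussian_entropy(sigma: float) -> float:
    """Closed-form differential entropy of N(mu, sigma^2), in nats."""
    return 0.5 * math.log(2.0 * math.pi * math.e * sigma**2)


def gaussian_log_prob(x: float, mu: float, sigma: float) -> float:
    """Log-density of N(mu, sigma^2) evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def mc_entropy(mu: float, sigma: float, n: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo estimate of H = -E[log pi(a)] from n samples."""
    sampler = random.Random(seed)
    total = 0.0
    for _ in range(n):
        a = sampler.gauss(mu, sigma)
        total -= gaussian_log_prob(a, mu, sigma)
    return total / n


for sigma in (0.1, 1.0, 2.0):
    print(
        f"sigma={sigma}: closed form={gaussian_entropy(sigma):+.3f}, "
        f"sample estimate={mc_entropy(0.0, sigma):+.3f}"
    )
```

Unlike discrete entropy, differential entropy becomes negative for narrow distributions (here for `sigma=0.1`), which is why a negative target entropy is meaningful for continuous actions.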


### Soft value functions

SAC defines a **soft** value function:

$$
V_\pi(s) = \mathbb{E}_{a\sim\pi}\big[Q_\pi(s,a) - \alpha\,\log\pi(a\mid s)\big].
$$

This leads to a soft Bellman backup for $Q$:

$$
Q_\pi(s,a) = r(s,a) + \gamma\,\mathbb{E}_{s'\sim P}\big[V_\pi(s')\big].
$$

Intuition: the next-state value is high if we can both (1) get high Q and (2) keep the policy **stochastic** (high entropy).
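
The intuition can be checked on a toy case. The sketch below assumes a single state with three **discrete** actions and made-up Q-values (an assumption for illustration; SAC itself works with continuous actions). It evaluates the soft value under a greedy, a uniform, and a Boltzmann policy $\pi(a)\propto\exp(Q(a)/\alpha)$; for discrete actions the Boltzmann policy attains $V = \alpha\log\sum_a \exp(Q(a)/\alpha)$.

```python
import math


def soft_value(q_values, probs, alpha):
    """V = sum_a pi(a) * (Q(a) - alpha * log pi(a)) for a discrete toy policy."""
    return sum(p * (q - alpha * math.log(p)) for q, p in zip(q_values, probs) if p > 0.0)


def boltzmann_policy(q_values, alpha):
    """pi(a) proportional to exp(Q(a) / alpha), shifted by max(Q) for stability."""
    q_max = max(q_values)
    weights = [math.exp((q - q_max) / alpha) for q in q_values]
    total = sum(weights)
    return [w / total for w in weights]


q = [1.0, 0.8, -0.5]  # hypothetical Q-values for three actions in one state
alpha = 0.5

greedy = [1.0, 0.0, 0.0]
uniform = [1 / 3, 1 / 3, 1 / 3]
soft = boltzmann_policy(q, alpha)

for name, pi in (("greedy", greedy), ("uniform", uniform), ("boltzmann", soft)):
    print(f"{name:>9}: V = {soft_value(q, pi, alpha):.4f}")

# Closed form attained by the Boltzmann policy in the discrete case.
log_sum_exp = alpha * math.log(sum(math.exp(v / alpha) for v in q))
print(f"alpha * logsumexp(Q / alpha) = {log_sum_exp:.4f}")
```

The greedy policy collects the highest Q-value but no entropy bonus, the uniform policy collects the largest bonus but wastes probability on a poor action, and the Boltzmann policy balances the two.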


### Losses used in SAC (twin critics + temperature tuning)

SAC typically uses **two** Q-networks $Q_{\theta_1}, Q_{\theta_2}$ to reduce overestimation bias.
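
The effect of taking the minimum of two critics can be illustrated with a small simulation. This is a sketch under simplifying assumptions that are not part of SAC itself: a handful of discrete actions whose true Q-value is exactly 0, and independent Gaussian estimation noise per critic. Selecting the best-looking action from one noisy critic is biased upward; using the element-wise minimum of two critics shrinks that bias.

```python
import random
import statistics


def overestimation_demo(n_actions=5, noise_std=1.0, trials=20_000, seed=0):
    """Compare E[max_a Q_hat] for one noisy critic vs. the min of two critics.

    The true Q-value of every action is 0, so any positive average is pure bias.
    """
    sampler = random.Random(seed)
    single, twin = [], []
    for _ in range(trials):
        q1 = [sampler.gauss(0.0, noise_std) for _ in range(n_actions)]
        q2 = [sampler.gauss(0.0, noise_std) for _ in range(n_actions)]
        single.append(max(q1))
        twin.append(max(min(a, b) for a, b in zip(q1, q2)))
    return statistics.fmean(single), statistics.fmean(twin)


single_bias, twin_bias = overestimation_demo()
print(f"single critic: {single_bias:+.3f}   min of two critics: {twin_bias:+.3f}")
```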

1) **Critic target** (with target networks and a sampled next action $a'\sim\pi_\phi(\cdot\mid s')$):

$$
y = r + \gamma(1-d)\Big(\min_i Q_{\bar\theta_i}(s', a') - \alpha\,\log \pi_\phi(a'\mid s')\Big).
$$

2) **Critic loss**:

$$
\mathcal{L}_Q = \mathbb{E}\big[(Q_{\theta_1}(s,a)-y)^2 + (Q_{\theta_2}(s,a)-y)^2\big].
$$

3) **Actor loss** (reparameterization trick):

$$
\mathcal{L}_\pi = \mathbb{E}_{s\sim\mathcal{D},\,a\sim\pi_\phi}\big[\alpha\,\log\pi_\phi(a\mid s) - \min_i Q_{\theta_i}(s,a)\big].
$$

4) **Automatic temperature** (optional): learn $\alpha$ to match a target entropy $\mathcal{H}_{\text{target}}$.

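A common form is given below. Implementations usually optimize $\log\alpha$ so that $\alpha$ stays positive, and some (including the code in this notebook) also replace the leading $\alpha$ with $\log\alpha$, which leaves the gradient sign unchanged: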

$$
\mathcal{L}_\alpha = \mathbb{E}_{a\sim\pi_\phi}\big[-\alpha\,(\log\pi_\phi(a\mid s) + \mathcal{H}_{\text{target}})\big].
$$
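
The four formulas above can be evaluated by hand for a single transition. The following is a minimal scalar sketch with made-up numbers; the function names are hypothetical, and the real implementation operates on batched tensors.

```python
def soft_td_target(reward, done, q1_next, q2_next, logp_next, gamma=0.99, alpha=0.2):
    """y = r + gamma * (1 - d) * (min(Q1', Q2') - alpha * log pi(a'|s'))."""
    soft_q_next = min(q1_next, q2_next) - alpha * logp_next
    return reward + gamma * (1.0 - float(done)) * soft_q_next


def critic_loss(q1, q2, y):
    """Sum of squared errors of both critics against the shared target y."""
    return (q1 - y) ** 2 + (q2 - y) ** 2


def actor_loss(logp, q1, q2, alpha=0.2):
    """alpha * log pi(a|s) - min(Q1, Q2) for one sampled action."""
    return alpha * logp - min(q1, q2)


def alpha_grad(logp, target_entropy):
    """d/d(alpha) of L_alpha = -alpha * (log pi(a|s) + H_target)."""
    return -(logp + target_entropy)


# Hypothetical numbers for one transition (not produced by the notebook).
y = soft_td_target(reward=-1.0, done=False, q1_next=10.0, q2_next=8.0, logp_next=-0.5)
y_terminal = soft_td_target(reward=-1.0, done=True, q1_next=10.0, q2_next=8.0, logp_next=-0.5)
print(f"target: {y:.3f}, terminal target: {y_terminal:.3f}")
print(f"critic loss: {critic_loss(7.5, 6.0, y):.3f}")
print(f"actor loss: {actor_loss(logp=-0.5, q1=7.5, q2=6.0):.3f}")

# One gradient-descent step on alpha with target entropy -1.
lr, alpha = 0.1, 0.2
too_deterministic = alpha - lr * alpha_grad(logp=2.0, target_entropy=-1.0)  # entropy -2 < -1
too_random = alpha - lr * alpha_grad(logp=-0.5, target_entropy=-1.0)  # entropy 0.5 > -1
print(f"alpha after step: {too_deterministic:.2f} (raised), {too_random:.2f} (lowered)")
```

When the entropy estimate $-\log\pi$ falls below the target, the step raises $\alpha$ and the entropy bonus grows; when the policy is more random than the target, $\alpha$ is lowered.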


In [None]:
class ContinuousPointMass1DEnv:
    """A tiny continuous-control environment (no external deps).

    State:  [position, velocity]
    Action: acceleration in [-action_limit, +action_limit]
    Reward: negative quadratic cost (stabilize at 0)
    """

    def __init__(
        self,
        dt: float = 0.05,
        max_steps: int = 200,
        action_limit: float = 2.0,
        seed: int = 42,
    ):
        self.dt = float(dt)
        self.max_steps = int(max_steps)
        self._action_limit = float(action_limit)
        self.action_low = np.array([-self._action_limit], dtype=np.float32)
        self.action_high = np.array([+self._action_limit], dtype=np.float32)
        self.obs_dim = 2
        self.act_dim = 1
        self.rng = np.random.default_rng(seed)
        self.t = 0
        self.state = np.zeros(self.obs_dim, dtype=np.float32)

    def reset(self, seed: int | None = None):
        if seed is not None:
            self.rng = np.random.default_rng(seed)
        self.t = 0
        pos = self.rng.uniform(-1.0, 1.0)
        vel = self.rng.uniform(-1.0, 1.0)
        self.state = np.array([pos, vel], dtype=np.float32)
        return self.state.copy(), {}

    def step(self, action):
        a = float(np.asarray(action, dtype=np.float32).reshape(-1)[0])
        a = float(np.clip(a, self.action_low[0], self.action_high[0]))

        pos, vel = float(self.state[0]), float(self.state[1])
        vel = vel + self.dt * a
        pos = pos + self.dt * vel
        self.state = np.array([pos, vel], dtype=np.float32)

        reward = -(pos * pos + 0.1 * vel * vel + 0.001 * a * a)

        self.t += 1
        terminated = False
        truncated = self.t >= self.max_steps
        info = {}
        return self.state.copy(), float(reward), terminated, truncated, info


def reset_env(env, seed: int | None = None) -> np.ndarray:
    out = env.reset(seed=seed) if seed is not None else env.reset()
    if isinstance(out, tuple) and len(out) == 2:
        obs, _info = out
    else:
        obs = out
    return np.asarray(obs, dtype=np.float32)


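def step_env(env, action: np.ndarray):
    """Step `env`, accepting both the 5-tuple (Gymnasium) and 4-tuple (old Gym) APIs."""
    out = env.step(action)
    if isinstance(out, tuple) and len(out) == 5:
        obs, reward, terminated, truncated, info = out
        done = bool(terminated) or bool(truncated)
    else:
        obs, reward, done, info = out
        # Old Gym flags time-limit cut-offs via info["TimeLimit.truncated"].
        terminated = bool(done) and not info.get("TimeLimit.truncated", False)
    info = dict(info)
    info["terminated"] = bool(terminated)  # true terminal flag, excludes truncation
    return np.asarray(obs, dtype=np.float32), float(reward), bool(done), info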


def make_env(seed: int = 42, prefer_gym_pendulum: bool = False):
    if prefer_gym_pendulum:
        for pkg in ("gymnasium", "gym"):
            try:
                gym = __import__(pkg)
                env = gym.make("Pendulum-v1")
                reset_env(env, seed)
                name = f"{pkg}:Pendulum-v1"
                action_low = env.action_space.low.astype(np.float32)
                action_high = env.action_space.high.astype(np.float32)
                obs_dim = int(env.observation_space.shape[0])
                act_dim = int(env.action_space.shape[0])
                return env, name, obs_dim, act_dim, action_low, action_high
            except Exception:
                pass

    env = ContinuousPointMass1DEnv(seed=seed)
    reset_env(env, seed)
    name = "custom:ContinuousPointMass1DEnv"
    return env, name, env.obs_dim, env.act_dim, env.action_low, env.action_high


In [None]:
env, env_name, obs_dim, act_dim, act_low, act_high = make_env(seed=SEED, prefer_gym_pendulum=False)
print("Env:", env_name)
print("obs_dim:", obs_dim, "act_dim:", act_dim)
print("action bounds:", act_low, act_high)

# Quick sanity rollout (random actions)
obs = reset_env(env, seed=SEED)
positions, velocities, rewards = [], [], []
for t in range(80):
    a = rng.uniform(act_low, act_high)
    obs, r, done, _ = step_env(env, a)
    positions.append(float(obs[0]))
    velocities.append(float(obs[1]))
    rewards.append(float(r))
    if done:
        break

fig = make_subplots(rows=3, cols=1, shared_xaxes=True)
fig.add_trace(go.Scatter(y=positions, mode="lines", name="position"), row=1, col=1)
fig.add_trace(go.Scatter(y=velocities, mode="lines", name="velocity"), row=2, col=1)
fig.add_trace(go.Scatter(y=rewards, mode="lines", name="reward"), row=3, col=1)
fig.update_layout(height=650, title="Sanity check rollout (random actions)")
fig.update_yaxes(title_text="pos", row=1, col=1)
fig.update_yaxes(title_text="vel", row=2, col=1)
fig.update_yaxes(title_text="r", row=3, col=1)
fig.update_xaxes(title_text="step", row=3, col=1)
fig

## 2) Replay buffer (from scratch)

SAC is **off-policy**, so it stores transitions in a replay buffer and samples mini-batches for gradient updates.
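
The buffer is circular: once it is full, new transitions overwrite the oldest ones. The indexing can be sketched in isolation as follows; `ring_buffer_positions` is a hypothetical helper written only for this illustration.

```python
def ring_buffer_positions(num_writes: int, capacity: int):
    """Return (write positions, final stored size) for a circular buffer."""
    ptr, size, positions = 0, 0, []
    for _ in range(num_writes):
        positions.append(ptr)
        ptr = (ptr + 1) % capacity  # wrap around once the end is reached
        size = min(size + 1, capacity)  # saturates at capacity
    return positions, size


positions, size = ring_buffer_positions(num_writes=7, capacity=4)
print(positions)  # [0, 1, 2, 3, 0, 1, 2]
print(size)  # 4
```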


In [None]:
class ReplayBuffer:
    def __init__(self, obs_dim: int, act_dim: int, size: int, seed: int = 0):
        self.obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.next_obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros((size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros((size, 1), dtype=np.float32)
        self.done_buf = np.zeros((size, 1), dtype=np.float32)

        self.max_size = int(size)
        self.ptr = 0
        self.size = 0
        self.rng = np.random.default_rng(seed)

    def store(self, obs, act, rew, next_obs, done):
        self.obs_buf[self.ptr] = obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.next_obs_buf[self.ptr] = next_obs
        self.done_buf[self.ptr] = float(done)

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self, batch_size: int, device: torch.device):
        idx = self.rng.integers(0, self.size, size=batch_size)
        batch = dict(
            obs=torch.as_tensor(self.obs_buf[idx], device=device),
            act=torch.as_tensor(self.act_buf[idx], device=device),
            rew=torch.as_tensor(self.rew_buf[idx], device=device),
            next_obs=torch.as_tensor(self.next_obs_buf[idx], device=device),
            done=torch.as_tensor(self.done_buf[idx], device=device),
        )
        return batch


## 3) Networks (from scratch)

We implement:
- A **squashed Gaussian** actor $\pi_\phi(a\mid s)$ using `tanh` to respect action bounds
- Two critics $Q_{\theta_1}$ and $Q_{\theta_2}$

The key low-level detail is the **log-prob correction** for the `tanh` squash (change of variables).
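
For $a=\tanh(u)$ the correction term is $\log(1-\tanh^2(u))$ per action dimension. The sketch below compares the direct form with a small epsilon against the algebraically equivalent rewrite $2(\log 2 - u - \mathrm{softplus}(-2u))$, a common alternative that stays accurate for large $|u|$. It also checks numerically that the resulting squashed density integrates to about 1. All function names and the chosen $\mu$, $\sigma$ are assumptions for illustration.

```python
import math


def log_det_naive(u: float, eps: float = 1e-6) -> float:
    """Direct form: log(1 - tanh(u)^2 + eps)."""
    return math.log(1.0 - math.tanh(u) ** 2 + eps)


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def log_det_stable(u: float) -> float:
    """log(1 - tanh(u)^2) rewritten as 2 * (log 2 - u - softplus(-2u))."""
    return 2.0 * (math.log(2.0) - u - softplus(-2.0 * u))


def squashed_log_prob(a: float, mu: float, sigma: float) -> float:
    """Log-density of a = tanh(u) with u ~ N(mu, sigma^2), for a in (-1, 1)."""
    u = math.atanh(a)
    log_normal = -0.5 * ((u - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    return log_normal - log_det_stable(u)


for u in (0.5, 3.0, 10.0):
    print(f"u={u}: naive={log_det_naive(u):.4f}, stable={log_det_stable(u):.4f}")

# Midpoint-rule integral of the squashed density over (-1, 1).
n = 20_000
h = 2.0 / n
mass = sum(math.exp(squashed_log_prob(-1.0 + (i + 0.5) * h, 0.3, 0.5)) * h for i in range(n))
print(f"integral of squashed density: {mass:.4f}")
```

The two forms agree for moderate `u`; for large `|u|` the epsilon in the direct form dominates and caps the correction.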


In [None]:
def mlp(sizes, activation=nn.ReLU, output_activation=nn.Identity):
    layers = []
    for j in range(len(sizes) - 1):
        act = activation if j < len(sizes) - 2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j + 1]), act()]
    return nn.Sequential(*layers)


class SquashedGaussianActor(nn.Module):
    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        act_low: np.ndarray,
        act_high: np.ndarray,
        hidden_sizes=(256, 256),
        log_std_min: float = -20.0,
        log_std_max: float = 2.0,
    ):
        super().__init__()
        self.net = mlp([obs_dim, *hidden_sizes], activation=nn.ReLU, output_activation=nn.ReLU)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)

        self.log_std_min = float(log_std_min)
        self.log_std_max = float(log_std_max)

        act_low = np.asarray(act_low, dtype=np.float32)
        act_high = np.asarray(act_high, dtype=np.float32)
        action_scale = (act_high - act_low) / 2.0
        action_bias = (act_high + act_low) / 2.0
        self.register_buffer("action_scale", torch.as_tensor(action_scale))
        self.register_buffer("action_bias", torch.as_tensor(action_bias))

    def forward(self, obs: torch.Tensor):
        h = self.net(obs)
        mu = self.mu_layer(h)
        log_std = self.log_std_layer(h)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        return mu, std

    def sample(self, obs: torch.Tensor):
        mu, std = self(obs)
        dist = Normal(mu, std)
        u = dist.rsample()  # reparameterization
        a = torch.tanh(u)

        action = a * self.action_scale + self.action_bias

        # Log prob with tanh + scaling correction (change of variables)
        log_prob_u = dist.log_prob(u).sum(dim=-1, keepdim=True)
        log_det = (
            torch.log(self.action_scale + 1e-8)
            + torch.log(1.0 - a.pow(2) + 1e-6)
        ).sum(dim=-1, keepdim=True)
        log_prob = log_prob_u - log_det

        mu_action = torch.tanh(mu) * self.action_scale + self.action_bias
        return action, log_prob, mu_action

    def act(self, obs: np.ndarray, deterministic: bool = False):
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.action_scale.device).unsqueeze(0)
        with torch.no_grad():
            if deterministic:
                mu, _std = self(obs_t)
                a = torch.tanh(mu) * self.action_scale + self.action_bias
            else:
                a, _logp, _mu_a = self.sample(obs_t)
        return a.squeeze(0).cpu().numpy()


class QNetwork(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(256, 256)):
        super().__init__()
        self.q = mlp([obs_dim + act_dim, *hidden_sizes, 1], activation=nn.ReLU)

    def forward(self, obs: torch.Tensor, act: torch.Tensor):
        x = torch.cat([obs, act], dim=-1)
        return self.q(x)


## 4) SAC update (from scratch)

This is the core of the algorithm: compute targets, optimize critics, optimize actor, optionally tune $\alpha$, then Polyak-update target critics.
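
The Polyak step moves each target parameter a fraction $\tau$ of the way toward its source: $\bar\theta \leftarrow (1-\tau)\bar\theta + \tau\theta$. As a rough sketch of how slowly the targets track the critics, the scalar example below (hypothetical helpers, fixed source value) counts the updates needed to halve the gap at $\tau = 0.005$.

```python
import math


def polyak_step(target: float, source: float, tau: float) -> float:
    """One soft update: target <- (1 - tau) * target + tau * source."""
    return (1.0 - tau) * target + tau * source


def steps_to_close_gap(tau: float, fraction: float = 0.5) -> int:
    """Smallest n with (1 - tau)^n <= fraction, for a fixed source value."""
    return math.ceil(math.log(fraction) / math.log(1.0 - tau))


tau = 0.005
target, source = 0.0, 1.0
n_half = steps_to_close_gap(tau)
for _ in range(n_half):
    target = polyak_step(target, source, tau)
print(n_half, round(target, 3))
```

With $\tau = 0.005$ the gap halves after 139 updates, so the target critics lag the online critics by a few hundred gradient steps.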


In [None]:
def polyak_update(source: nn.Module, target: nn.Module, tau: float):
    with torch.no_grad():
        for p, p_targ in zip(source.parameters(), target.parameters(), strict=True):
            p_targ.data.mul_(1.0 - tau)
            p_targ.data.add_(tau * p.data)


@dataclass
class SACConfig:
    gamma: float = 0.99
    tau: float = 0.005
    actor_lr: float = 3e-4
    critic_lr: float = 3e-4
    alpha_lr: float = 3e-4
    batch_size: int = 256
    replay_size: int = 200_000
    updates_per_step: int = 1
    start_steps: int = 1_000
    auto_alpha: bool = True
    init_alpha: float = 0.2


class SACAgent:
    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        act_low: np.ndarray,
        act_high: np.ndarray,
        config: SACConfig,
        device: torch.device,
    ):
        self.cfg = config
        self.device = device

        self.actor = SquashedGaussianActor(obs_dim, act_dim, act_low, act_high).to(device)
        self.q1 = QNetwork(obs_dim, act_dim).to(device)
        self.q2 = QNetwork(obs_dim, act_dim).to(device)

        self.q1_targ = QNetwork(obs_dim, act_dim).to(device)
        self.q2_targ = QNetwork(obs_dim, act_dim).to(device)
        self.q1_targ.load_state_dict(self.q1.state_dict())
        self.q2_targ.load_state_dict(self.q2.state_dict())

        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=config.actor_lr)
        self.q_opt = torch.optim.Adam(
            list(self.q1.parameters()) + list(self.q2.parameters()), lr=config.critic_lr
        )

        self.auto_alpha = bool(config.auto_alpha)
        if self.auto_alpha:
            # A common heuristic for target entropy: -|A|
            self.target_entropy = -float(act_dim)
            self.log_alpha = torch.tensor(
                math.log(config.init_alpha), dtype=torch.float32, device=device, requires_grad=True
            )
            self.alpha_opt = torch.optim.Adam([self.log_alpha], lr=config.alpha_lr)
        else:
            self.target_entropy = None
            self.log_alpha = torch.tensor(math.log(config.init_alpha), dtype=torch.float32, device=device)
            self.alpha_opt = None

    @property
    def alpha(self) -> torch.Tensor:
        return self.log_alpha.exp()

    def act(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        return self.actor.act(obs, deterministic=deterministic)

    def update(self, batch: dict[str, torch.Tensor]):
        obs = batch["obs"]
        act = batch["act"]
        rew = batch["rew"]
        next_obs = batch["next_obs"]
        done = batch["done"]

        # --- Critic update ---
        with torch.no_grad():
            next_a, next_logp, _next_mu_a = self.actor.sample(next_obs)
            q1_next = self.q1_targ(next_obs, next_a)
            q2_next = self.q2_targ(next_obs, next_a)
            q_next = torch.min(q1_next, q2_next) - self.alpha * next_logp
            backup = rew + self.cfg.gamma * (1.0 - done) * q_next

        q1 = self.q1(obs, act)
        q2 = self.q2(obs, act)
        critic_loss = F.mse_loss(q1, backup) + F.mse_loss(q2, backup)

        self.q_opt.zero_grad(set_to_none=True)
        critic_loss.backward()
        self.q_opt.step()

        # --- Actor update ---
        a_pi, logp_pi, _mu_a = self.actor.sample(obs)
        q1_pi = self.q1(obs, a_pi)
        q2_pi = self.q2(obs, a_pi)
        q_pi = torch.min(q1_pi, q2_pi)
        # Detach alpha here: the temperature is trained only by its own loss below.
        actor_loss = (self.alpha.detach() * logp_pi - q_pi).mean()

        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.actor_opt.step()

        # --- Temperature (alpha) update ---
        alpha_loss = torch.tensor(0.0, device=self.device)
        if self.auto_alpha:
            alpha_loss = -(self.log_alpha * (logp_pi + self.target_entropy).detach()).mean()
            self.alpha_opt.zero_grad(set_to_none=True)
            alpha_loss.backward()
            self.alpha_opt.step()

        # --- Target networks ---
        polyak_update(self.q1, self.q1_targ, self.cfg.tau)
        polyak_update(self.q2, self.q2_targ, self.cfg.tau)

        metrics = {
            "critic_loss": float(critic_loss.item()),
            "actor_loss": float(actor_loss.item()),
            "alpha": float(self.alpha.item()),
            "alpha_loss": float(alpha_loss.item()),
            "mean_logp": float(logp_pi.mean().item()),
            "mean_q": float(q_pi.mean().item()),
        }
        return metrics


## 5) Training loop + Plotly diagnostics

We log per-episode:
- Episodic return (sum of rewards)
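- Average entropy estimate $-\log\pi(a\mid s)$, using one action sampled from the current policy per step
- Average Q-value estimate $\min(Q_1,Q_2)$ at that sampled action (during warmup this differs from the random action actually executed)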


In [None]:
FAST_RUN = True  # set False for longer training

cfg = SACConfig(
    replay_size=50_000 if FAST_RUN else 200_000,
    batch_size=128 if FAST_RUN else 256,
    start_steps=500 if FAST_RUN else 1_000,
    updates_per_step=1,
)

TOTAL_EPISODES = 60 if FAST_RUN else 200
MAX_STEPS_PER_EP = 200

if not TORCH_AVAILABLE:
    raise RuntimeError("PyTorch is required for the from-scratch SAC implementation.")


In [None]:
buffer = ReplayBuffer(obs_dim, act_dim, size=cfg.replay_size, seed=SEED)
agent = SACAgent(obs_dim, act_dim, act_low, act_high, config=cfg, device=device)


def moving_average(x, window: int = 10):
    x = np.asarray(x, dtype=float)
    if len(x) < window:
        return x
    w = np.ones(window) / window
    y = np.convolve(x, w, mode="valid")
    # pad left to match length
    return np.concatenate([np.full(window - 1, np.nan), y])


ep_returns = []
ep_entropies = []
ep_q_values = []
ep_alphas = []

total_steps = 0
t0 = time.time()
for ep in range(TOTAL_EPISODES):
    obs = reset_env(env, seed=SEED + ep)
    ep_ret = 0.0
    ent_sum = 0.0
    q_sum = 0.0
    steps = 0

    for _ in range(MAX_STEPS_PER_EP):
        if total_steps < cfg.start_steps:
            act = rng.uniform(act_low, act_high).astype(np.float32)
            # For plotting, estimate entropy from current policy anyway
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
            with torch.no_grad():
                _a_pi, logp_pi, _mu_a = agent.actor.sample(obs_t)
                entropy_est = float((-logp_pi).item())
                q_est = float(
                    torch.min(agent.q1(obs_t, _a_pi), agent.q2(obs_t, _a_pi)).item()
                )
        else:
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
            with torch.no_grad():
                a_pi, logp_pi, _mu_a = agent.actor.sample(obs_t)
                act = a_pi.squeeze(0).cpu().numpy().astype(np.float32)
                entropy_est = float((-logp_pi).item())
                q_est = float(
                    torch.min(agent.q1(obs_t, a_pi), agent.q2(obs_t, a_pi)).item()
                )

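        next_obs, r, done, _info = step_env(env, act)

        # Store the true terminal flag when `step_env` reports one, so the critic
        # still bootstraps through time-limit truncation; otherwise fall back to `done`.
        buffer.store(obs, act, r, next_obs, _info.get("terminated", done))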

        obs = next_obs
        ep_ret += r
        ent_sum += entropy_est
        q_sum += q_est
        steps += 1
        total_steps += 1

        # Gradient updates
        if buffer.size >= cfg.batch_size:
            for _ in range(cfg.updates_per_step):
                batch = buffer.sample_batch(cfg.batch_size, device=device)
                _metrics = agent.update(batch)

        if done:
            break

    ep_returns.append(ep_ret)
    ep_entropies.append(ent_sum / max(1, steps))
    ep_q_values.append(q_sum / max(1, steps))
    ep_alphas.append(float(agent.alpha.item()))

dt = time.time() - t0
print(f"Finished {TOTAL_EPISODES} episodes, {total_steps} steps in {dt:.1f}s")


In [None]:
episodes = np.arange(1, len(ep_returns) + 1)
ma_window = 10

fig = make_subplots(
    rows=4,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.06,
    subplot_titles=(
        "Episodic rewards (return)",
        "Policy entropy estimate (mean -log π(a|s))",
        "Q-value estimate (mean min(Q1,Q2))",
        "Temperature α",
    ),
)

fig.add_trace(go.Scatter(x=episodes, y=ep_returns, mode="lines", name="return"), row=1, col=1)
fig.add_trace(
    go.Scatter(
        x=episodes,
        y=moving_average(ep_returns, ma_window),
        mode="lines",
        name=f"return (MA{ma_window})",
    ),
    row=1,
    col=1,
)

fig.add_trace(go.Scatter(x=episodes, y=ep_entropies, mode="lines", name="entropy"), row=2, col=1)
fig.add_trace(go.Scatter(x=episodes, y=ep_q_values, mode="lines", name="Q"), row=3, col=1)
fig.add_trace(go.Scatter(x=episodes, y=ep_alphas, mode="lines", name="alpha"), row=4, col=1)

fig.update_layout(height=900, title="SAC training diagnostics")
fig.update_yaxes(title_text="return", row=1, col=1)
fig.update_yaxes(title_text="-logπ", row=2, col=1)
fig.update_yaxes(title_text="Q", row=3, col=1)
fig.update_yaxes(title_text="α", row=4, col=1)
fig.update_xaxes(title_text="episode", row=4, col=1)
fig

### A minimal evaluation run (deterministic actions)

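Evaluation uses the deterministic action $\tanh(\mu)$ (the Gaussian mean passed through the squash and rescaled to the action bounds). This usually yields higher returns than sampling because it removes exploration noise.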


In [None]:
def eval_policy(agent: SACAgent, env, episodes: int = 5):
    returns = []
    for k in range(episodes):
        obs = reset_env(env, seed=10_000 + k)
        ep_ret = 0.0
        for _ in range(MAX_STEPS_PER_EP):
            act = agent.act(obs, deterministic=True)
            obs, r, done, _ = step_env(env, act)
            ep_ret += r
            if done:
                break
        returns.append(ep_ret)
    return returns


eval_returns = eval_policy(agent, env, episodes=8)
fig = go.Figure()
fig.add_trace(go.Bar(x=list(range(1, len(eval_returns) + 1)), y=eval_returns, name="eval return"))
fig.update_layout(
    title="Deterministic evaluation (per-episode return)",
    xaxis_title="eval episode",
    yaxis_title="return",
)
fig.show()

print("Eval mean return:", float(np.mean(eval_returns)))


## 6) Stable-Baselines3 SAC (reference implementation + hyperparameters)

A widely used reference implementation is **Stable-Baselines3 (SB3)**, which includes `stable_baselines3.SAC`.

Useful links:
- SB3 source: https://github.com/DLR-RM/stable-baselines3
- SAC paper (Haarnoja et al.): https://arxiv.org/abs/1801.01290
- Spinning Up SAC overview: https://spinningup.openai.com/en/latest/algorithms/sac.html

### SB3 usage example

```python
# pip install stable-baselines3 gymnasium

import gymnasium as gym
from stable_baselines3 import SAC

env = gym.make("Pendulum-v1")
model = SAC(
    policy="MlpPolicy",
    env=env,
    learning_rate=3e-4,
    buffer_size=1_000_000,
    learning_starts=100,
    batch_size=256,
    tau=0.005,
    gamma=0.99,
    train_freq=1,
    gradient_steps=1,
    ent_coef="auto",
    target_entropy="auto",
)
model.learn(total_timesteps=100_000)
```

### SB3 SAC hyperparameters (what they mean)

Below is a practical summary of the main `stable_baselines3.SAC(...)` constructor arguments (based on the SB3 implementation/docstring):

- `policy`: Policy architecture wrapper (`"MlpPolicy"`, `"CnnPolicy"`, ...).
- `env`: Gym/Gymnasium env instance (or env ID string).
- `learning_rate`: Adam learning rate shared by the actor, the critics, and the learned entropy coefficient (can be a schedule).
- `buffer_size`: Replay buffer capacity.
- `learning_starts`: Number of environment steps to collect before training starts.
- `batch_size`: Mini-batch size per gradient update.
- `tau`: Target-network Polyak coefficient.
- `gamma`: Discount factor.
- `train_freq`: How often to train (e.g., every N steps, or `(N, "episode")`).
- `gradient_steps`: Gradient updates after each rollout; `-1` means "as many as steps collected".
- `action_noise`: Optional action noise process (often unused for SAC because the policy is stochastic).
- `replay_buffer_class` / `replay_buffer_kwargs`: Swap replay buffer type/params (e.g., HER).
- `optimize_memory_usage`: Memory-efficient replay buffer variant.
- `n_steps`: If >1, use n-step returns via an n-step replay buffer.
- `ent_coef`: Entropy coefficient $\alpha$; set to `"auto"` to learn it (`"auto_0.1"` uses 0.1 init).
- `target_update_interval`: Target network update interval (in gradient steps).
- `target_entropy`: Target entropy when `ent_coef="auto"` (often `"auto"` to use a heuristic).
- `use_sde`: Use generalized State Dependent Exploration (gSDE) instead of action noise.
- `sde_sample_freq`: Resample gSDE noise every N steps (`-1` means only at rollout start).
- `use_sde_at_warmup`: Use gSDE during warmup instead of uniform random actions.
- `stats_window_size`: Window size for logged rollout stats.
- `tensorboard_log`: TensorBoard logging directory.
- `policy_kwargs`: Extra kwargs for the policy/networks (e.g., `net_arch`, activation function).
- `verbose`, `seed`, `device`: Usual runtime controls.
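
As a rough sketch of how these arguments combine, the fragment below builds a keyword dictionary that could be passed as `SAC("MlpPolicy", env, **sac_kwargs)`. The specific values are illustrative assumptions, not tuned settings. `auto_target_entropy` is a hypothetical helper that spells out the $-|A|$ heuristic used earlier in this notebook, which, as far as the SB3 source indicates, is also what `target_entropy="auto"` resolves to.

```python
import math


def auto_target_entropy(action_shape) -> float:
    """The -|A| heuristic: minus the number of action dimensions."""
    return -float(math.prod(action_shape))


# Hypothetical settings; pass as SAC("MlpPolicy", env, **sac_kwargs).
sac_kwargs = dict(
    learning_rate=3e-4,
    buffer_size=200_000,
    batch_size=256,
    tau=0.005,
    gamma=0.99,
    ent_coef="auto_0.1",  # learn alpha, starting from 0.1
    target_entropy=auto_target_entropy((1,)),  # Pendulum-v1 has a 1-D action
    policy_kwargs=dict(net_arch=[256, 256]),  # two hidden layers of 256 units
    seed=42,
)
print(sac_kwargs["target_entropy"])  # -1.0
```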
