# Experiment 1.7: Sparse labels (per-sample regularization)

In previous experiments, we imposed structure on latent space with curriculum learning: varying the training data and hyperparameters over the course of the training run ([Ex 1.3](./ex-1.3-color-mlp-curriculum.ipynb), [Ex 1.5](./ex-1.5-color-mlp-anchoring.ipynb), [Ex 1.6](./ex-1.6-curriculum-comparison.ipynb)). It worked (in that the latent space looked _ok_), but we are unsure whether it worked because the color wheel was found and anchored _before_ expanding the data to include all colors, or whether it was just that the primary and secondary colors were (in effect) labelled: that is, special regularization was applied to those samples.

In this experiment, we do away with the phased curriculum and instead apply per-sample regularization. Our hypothesis is that this will outperform the curriculum-based methods, because the model will have access to the full data distribution from the start (limited only by batch size).

## Dataset design

We chose color as a domain because it's easy to reason about and visualize. Since our eventual goal is to apply these techniques to LLM training, we should consider how to constrain the labels in a way that could realistically be replicated for text. We assume that:

1. An LLM would be trained with something like internet text
2. Sentiment analysis could be run over it to generate labels — attributes that we care about, such as "malicious", "benign", "honest", etc.
3. Such labelling would not be entirely accurate.

For this experiment, we will apply the following regime:

1. An autoencoder trained on a color cube
2. Certain colors are given labels (e.g. $(1,0,0)=red$)
3. The labelling would be noisy, e.g. maybe $(1,0,0)$ wouldn't _always_ be labelled as $red$, and sometimes other colors close to red would be given that label.

Certain regularizers will be activated based on those labels, e.g. colors labelled $red$ could be penalized for not being embedded at $(1,0,0,0)$.
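
As a rough illustration of the noisy labelling in step 3, here is a minimal sketch. It assumes a labeller whose probability of saying "red" falls off linearly with distance from $(1,0,0)$, and which misses some fraction of true reds. The function names and the `radius` and `miss_rate` parameters are hypothetical; the labelling actually used in this notebook is defined in the Labelling section.

```python
import math
import random


def red_label_probability(rgb, radius=0.5, miss_rate=0.1):
    """Probability that a (hypothetical) noisy labeller tags this color as red."""
    # Euclidean distance from pure red in RGB space
    distance = math.dist(rgb, (1.0, 0.0, 0.0))
    # Linear falloff: 1 at pure red, 0 at `radius` or further away
    proximity = max(0.0, 1.0 - distance / radius)
    # Even pure red is missed some of the time
    return (1.0 - miss_rate) * proximity


def sample_red_label(rgb, rng, **kwargs):
    """Draw a discrete (True/False) label from the probability above."""
    return rng.random() < red_label_probability(rgb, **kwargs)


rng = random.Random(0)
for rgb in [(1.0, 0.0, 0.0), (0.9, 0.1, 0.0), (0.0, 0.0, 1.0)]:
    p = red_label_probability(rgb)
    hits = sum(sample_red_label(rgb, rng) for _ in range(1000))
    print(rgb, f'p={p:.2f}', f'labelled red in {hits}/1000 draws')
```

Under these assumptions, pure red is labelled most of the time but not always, nearby colors are labelled occasionally, and distant colors never are.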

In [1]:
from __future__ import annotations

In [2]:
import logging
from utils.logging import SimpleLoggingConfig

logging_config = SimpleLoggingConfig().info('notebook', 'utils', 'mini', 'ex_color')
logging_config.apply()

# ID for tagging assets
nbid = '1.7'
# This is the logger for this notebook
log = logging.getLogger(f'notebook.{nbid}')

## Model architecture

We use the same simple 2-layer MLP autoencoder with a bottleneck as in previous experiments. The key difference lies not in the architecture, but in the training process: regularizers are applied per sample, based on labels, rather than through a phased curriculum.


In [3]:
import torch
import torch.nn as nn

E = 4


class ColorMLP(nn.Module):
    def __init__(self):
        super().__init__()
        # RGB input (3D) → hidden layer → bottleneck → hidden layer → RGB output
        self.encoder = nn.Sequential(
            nn.Linear(3, 16),
            nn.GELU(),
            # nn.Linear(16, 16),
            # nn.GELU(),
            nn.Linear(16, E),  # Our critical bottleneck!
        )

        self.decoder = nn.Sequential(
            nn.Linear(E, 16),
            nn.GELU(),
            # nn.Linear(16, 16),
            # nn.GELU(),
            nn.Linear(16, 3),
            nn.Sigmoid(),  # Keep RGB values in [0,1]
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Get our bottleneck representation
        bottleneck = self.encoder(x)

        # Decode back to RGB
        output = self.decoder(bottleneck)
        return output, bottleneck

## Training machinery with timeline and events

The `train_color_model` function orchestrates the training process based on a `Timeline` derived from the dopesheet. It handles:

- Iterating through training steps.
- Fetching the next batch of data and soft labels from the training loader.
- Updating hyperparameters (like learning rate and loss weights) smoothly based on the timeline state.
- Calculating the combined loss from reconstruction and various regularizers, with each regularizer weighted per sample by its affinity with the labels.
- Executing the optimizer step.
- Emitting events at different points (phase start/end, pre-step, actions like 'anchor', step metrics) to trigger callbacks like plotting, recording, or updating loss terms.

In [4]:
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
from torch import Tensor
import torch.optim as optim

from mini.temporal.timeline import State


@dataclass
class InferenceResult:
    outputs: Tensor
    latents: Tensor

    def detach(self):
        return InferenceResult(self.outputs.detach(), self.latents.detach())

    def clone(self):
        return InferenceResult(self.outputs.clone(), self.latents.clone())

    def cpu(self):
        return InferenceResult(self.outputs.cpu(), self.latents.cpu())


@runtime_checkable
class LossCriterion(Protocol):
    def __call__(self, data: Tensor, res: InferenceResult) -> Tensor: ...


@dataclass
class RegularizerConfig:
    """Configuration for a regularizer, including label affinities."""
    name: str
    """Matched with hyperparameter for weighting"""
    criterion: LossCriterion
    label_affinities: dict[str, float] | None
    """Maps label names to affinity strengths"""


@dataclass(eq=False, frozen=True)
class Event:
    name: str
    step: int
    model: ColorMLP
    timeline_state: State
    optimizer: optim.Optimizer


@dataclass(eq=False, frozen=True)
class PhaseEndEvent(Event):
    validation_data: Tensor
    inference_result: InferenceResult


@dataclass(eq=False, frozen=True)
class StepMetricsEvent(Event):
    """Event carrying metrics calculated during a training step."""

    total_loss: float
    losses: dict[str, float]


class EventHandler[T](Protocol):
    def __call__(self, event: T) -> None: ...


class EventBinding[T]:
    """A class to bind events to handlers."""

    def __init__(self, event_name: str):
        self.event_name = event_name
        self.handlers: list[tuple[str, EventHandler[T]]] = []

    def add_handler(self, event_name: str, handler: EventHandler[T]) -> None:
        self.handlers.append((event_name, handler))

    def emit(self, event_name: str, event: T) -> None:
        for name, handler in self.handlers:
            if name == event_name:
                handler(event)


class EventHandlers:
    """A simple event system to allow for custom callbacks."""

    phase_start: EventBinding[Event]
    pre_step: EventBinding[Event]
    action: EventBinding[Event]
    phase_end: EventBinding[PhaseEndEvent]
    step_metrics: EventBinding[StepMetricsEvent]

    def __init__(self):
        self.phase_start = EventBinding[Event]('phase-start')
        self.pre_step = EventBinding[Event]('pre-step')
        self.action = EventBinding[Event]('action')
        self.phase_end = EventBinding[PhaseEndEvent]('phase-end')
        self.step_metrics = EventBinding[StepMetricsEvent]('step-metrics')

### Training loop

In [5]:
import random
import numpy as np
import torch
from typing import Iterable, Iterator
from torch.utils.data import DataLoader
import torch.optim as optim

from mini.temporal.dopesheet import Dopesheet
from mini.temporal.timeline import Timeline
from utils.progress import RichProgress


def seed_everything(seed: int):
    """Set seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    log.info(f'Global random seed set to {seed}')


def set_deterministic_mode(seed: int):
    """Make experiments reproducible."""
    seed_everything(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    log.info('PyTorch set to deterministic mode')


def reiterate[T](it: Iterable[T]) -> Iterator[T]:
    """
    Iterates over an iterable indefinitely.

    When the iterable is exhausted, it starts over from the beginning. Unlike
    `itertools.cycle`, yielded values are not cached — so each iteration may be
    different.
    """
    while True:
        yield from it


def train_color_model(  # noqa: C901
    model: ColorMLP,
    train_loader: DataLoader,
    val_data: Tensor,
    dopesheet: Dopesheet,
    loss_criterion: LossCriterion,
    regularizers: list[RegularizerConfig],
    event_handlers: EventHandlers | None = None,
):
    if event_handlers is None:
        event_handlers = EventHandlers()

    # --- Validate inputs ---
    if 'lr' not in dopesheet.props:
        raise ValueError("Dopesheet must define the 'lr' property column.")
    # --- End Validation ---

    timeline = Timeline(dopesheet)
    optimizer = optim.Adam(model.parameters(), lr=0)
    device = next(model.parameters()).device

    train_data = iter(reiterate(train_loader))

    total_steps = len(timeline)

    with RichProgress(total=total_steps, description='Training Steps') as pbar:
        for step in range(total_steps):
            # Get state *before* advancing timeline for this step's processing
            current_state = timeline.state
            current_phase_name = current_state.phase

            batch_data, batch_labels = next(train_data)
            # Should already be on device
            # batch_data = batch_data.to(device)
            # batch_labels = batch_labels.to(device)

            # --- Event Handling ---
            event_template = {
                'step': step,
                'model': model,
                'timeline_state': current_state,
                'optimizer': optimizer,
            }

            if current_state.is_phase_start:
                event = Event(name=f'phase-start:{current_phase_name}', **event_template)
                event_handlers.phase_start.emit(event.name, event)
                event_handlers.phase_start.emit('phase-start', event)

            for action in current_state.actions:
                event = Event(name=f'action:{action}', **event_template)
                event_handlers.action.emit(event.name, event)
                event_handlers.action.emit('action', event)

            event = Event(name='pre-step', **event_template)
            event_handlers.pre_step.emit('pre-step', event)

            # --- Training Step ---
            # ... (get data, update LR, zero grad, forward pass, calculate loss, backward, step) ...

            current_lr = current_state.props['lr']
            # REF_BATCH_SIZE = 32
            # lr_scale_factor = batch.shape[0] / REF_BATCH_SIZE
            # current_lr = current_lr * lr_scale_factor
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.zero_grad()

            outputs, latents = model(batch_data)
            current_results = InferenceResult(outputs, latents)

            primary_loss = loss_criterion(batch_data, current_results).mean()
            losses = {'recon': primary_loss.item()}
            total_loss = primary_loss
            zeros = torch.tensor(0.0, device=batch_data.device)

            for regularizer in regularizers:
                name = regularizer.name
                criterion = regularizer.criterion

                weight = current_state.props.get(name, 1.0)
                if weight == 0:
                    continue

                if regularizer.label_affinities is not None:
                    label_probs = [
                        batch_labels[k] * v  # Soft labels that indicate how much effect this regularizer has, based on its affinity with the label
                        for k, v in regularizer.label_affinities.items()
                        if k in batch_labels
                    ]
                    if not label_probs:
                        continue

                    sample_affinities = torch.stack(label_probs, dim=0).sum(dim=0)
                    sample_affinities = torch.clamp(sample_affinities, 0.0, 1.0)
                    if torch.allclose(sample_affinities, zeros):
                        continue
                else:
                    sample_affinities = torch.ones(batch_data.shape[0], device=batch_data.device)

                per_sample_loss = criterion(batch_data, current_results)
                if len(per_sample_loss.shape) == 0:
                    # If the loss is a scalar, we need to expand it to match the batch size
                    per_sample_loss = per_sample_loss.expand(batch_data.shape[0])
                assert per_sample_loss.shape[0] == batch_data.shape[0], f'Loss should be per-sample OR scalar: {name}'

                # Apply sample affinities
                weighted_loss = per_sample_loss * sample_affinities

                # Apply sample importance weights
                # weighted_loss *= batch_weights

                # Calculate mean only over selected samples. If we used torch.mean, it would average over all samples, including those with 0 weight
                term_loss = weighted_loss.sum() / (sample_affinities.sum() + 1e-8)

                losses[name] = term_loss.item()
                if not torch.isfinite(term_loss):
                    log.warning(f'Loss term {name} at step {step} is not finite: {term_loss}')
                    continue
                total_loss += term_loss * weight

            if total_loss > 0:
                total_loss.backward()
                optimizer.step()
            # --- End Training Step ---

            # Emit step metrics event
            step_metrics_event = StepMetricsEvent(
                name='step-metrics',
                **event_template,
                total_loss=total_loss.item(),
                losses=losses,
            )
            event_handlers.step_metrics.emit('step-metrics', step_metrics_event)

            # --- Post-Step Event Handling ---
            if current_state.is_phase_end:
                # Trigger phase-end for the *current* phase
                # validation_data = batch_data
                with torch.no_grad():
                    val_outputs, val_latents = model(val_data.to(device))
                event = PhaseEndEvent(
                    name=f'phase-end:{current_phase_name}',
                    **event_template,
                    validation_data=val_data,
                    inference_result=InferenceResult(val_outputs, val_latents),
                )
                event_handlers.phase_end.emit(event.name, event)
                event_handlers.phase_end.emit('phase-end', event)
            # --- End Event Handling ---

            # Update progress bar
            pbar.update(
                metrics={
                    'PHASE': current_phase_name,
                    'lr': f'{current_lr:.6f}',
                    'loss': f'{total_loss.item():.4f}',
                    **{name: f'{lt:.4f}' for name, lt in losses.items()},
                },
            )

            # Advance timeline *after* processing the current step
            if step < total_steps - 1:  # Avoid stepping past the end
                timeline.step()

    log.info('Training finished!')
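
The per-sample weighting inside the loop can be checked in isolation. The sketch below uses made-up loss values and label probabilities for a batch of four samples (none of these numbers come from the experiment). It compares the affinity-weighted mean used above with a plain mean over the whole batch.

```python
import torch

# Hypothetical per-sample losses and soft 'red' labels for a batch of four
per_sample_loss = torch.tensor([4.0, 2.0, 1.0, 3.0])
red_prob = torch.tensor([1.0, 0.5, 0.0, 0.0])

# An affinity of 1.0 applies the regularizer in full to samples labelled 'red'
affinity = 1.0
sample_affinities = torch.clamp(red_prob * affinity, 0.0, 1.0)

weighted_loss = per_sample_loss * sample_affinities

# Affinity-weighted mean: normalize by the total affinity, not the batch size
term_loss = weighted_loss.sum() / (sample_affinities.sum() + 1e-8)

# A plain mean would dilute the term with the two unlabelled samples
naive_loss = weighted_loss.mean()

print(f'weighted mean: {term_loss.item():.4f}, plain mean: {naive_loss.item():.4f}')
```

With the weighted mean, the size of the term does not shrink just because few samples in the batch happen to carry the label.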

### Visualization

In [None]:
import math

import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from torch import Tensor
from IPython.display import HTML

from utils.nb import save_fig


class PhasePlotter:
    """Event handler to plot latent space at the end of each phase."""

    def __init__(self, val_data: Tensor, *, dim_pairs: list[tuple[int, int]], interval: int = 100):
        from utils.nb import displayer

        # Store (phase_name, end_step, data, result) - data comes from event now
        self.val_data = val_data
        self.history: list[tuple[str, int, Tensor, Tensor]] = []
        self.display = displayer()
        self.dim_pairs = dim_pairs
        self.interval = interval

    def __call__(self, event: Event):
        # TODO: Don't assume device = CPU
        # TODO: Split this class so that the event handler is separate from the plotting, and so the plotting can happen locally with @run.hither
        if event.step % self.interval != 0:
            return
        phase_name = event.timeline_state.phase
        step = event.step
        output, latents = event.model(self.val_data)

        log.debug(f'Plotting end of phase: {phase_name} at step {step} using provided results.')

        # Append to history
        self.history.append((phase_name, step, output.detach().cpu(), latents.detach().cpu()))

        # Plotting logic remains the same as it already expected CPU tensors
        fig = self._plot_phase_history()
        self.display(
            HTML(
                save_fig(
                    fig,
                    f'large-assets/ex-{nbid}-color-phase-history.png',
                    alt_text='Visualizations of latent space at the end of each curriculum phase.',
                )
            )
        )

    def _plot_phase_history(self):
        if not self.history:
            fig, ax = plt.subplots()
            fig.set_facecolor('#333')
            ax.set_facecolor('#222')
            ax.text(0.5, 0.5, 'Waiting...', ha='center', va='center')
            return fig

        plt.style.use('dark_background')

        # Number of dimension pairs
        num_dim_pairs = len(self.dim_pairs)

        # Cap the number of thumbnails to a maximum for readability
        max_thumbnails = 10
        # If we have more history than max_thumbnails, subsample it evenly across the run
        stride = math.ceil(len(self.history) / max_thumbnails)
        history_to_show = self.history[::stride]
        history_to_show = history_to_show[:max_thumbnails]
        # history_to_show = self.history[-max_thumbnails:] if len(self.history) > max_thumbnails else self.history

        # Create figure with gridspec for flexible layout
        fig = plt.figure(figsize=(12, 5), facecolor='#333')

        # Create two separate gridspecs - one for thumbnails, one for latest state
        gs = fig.add_gridspec(2, 1, hspace=0.1, height_ratios=[1, 4])

        # Thumbnail gridspec (top row) - only first dimension pair
        # Remove spacing between thumbnails by setting wspace=0
        thumbnail_gs = gs[0].subgridspec(2, max_thumbnails, wspace=0, hspace=0.1, height_ratios=[0, 1])

        # Latest state gridspec (bottom row) - all dimension pairs
        latest_gs = gs[1].subgridspec(2, num_dim_pairs, wspace=0, hspace=0.1, height_ratios=[1, 0])

        # Get the data
        _colors = self.val_data.numpy()

        # Create thumbnail axes and plot history
        for i, (_, step, _, latents) in enumerate(history_to_show):
            _latents = latents.numpy()

            # Only plot the first dimension pair for thumbnails
            dim1, dim2 = self.dim_pairs[0]

            # Create title for the thumbnail as its own axes, so that it's aligned with the other titles
            axt = fig.add_subplot(thumbnail_gs[0, i])
            axt.text(
                0,
                0,
                f'{step}',
                # transform=axt.transAxes,
                horizontalalignment='center',
                fontsize=7,
            )
            # Remove all decorations
            axt.patch.set_alpha(0)
            axt.set_xticks([])
            axt.set_yticks([])
            axt.spines['top'].set_visible(False)
            axt.spines['right'].set_visible(False)
            axt.spines['bottom'].set_visible(False)
            axt.spines['left'].set_visible(False)

            # Create thumbnail axis
            ax = fig.add_subplot(thumbnail_gs[1, i])
            ax.sharex(axt)

            # Plot the data
            ax.scatter(_latents[:, dim1], _latents[:, dim2], c=_colors, s=50, alpha=0.7)

            # Add reference circle
            ax.add_patch(Circle((0, 0), 1, fill=True, facecolor='black', edgecolor='gray', zorder=-1))

            # Remove all decorations
            ax.patch.set_alpha(0)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.spines['left'].set_visible(False)

            # Ensure square aspect ratio
            ax.set_aspect('equal')
            ax.set_adjustable('box')

        # Plot latest state
        # Get the latest data
        phase_name, step, output, latents = self.history[-1]
        _latents = latents.numpy()

        for i, (dim1, dim2) in enumerate(self.dim_pairs):
            # Create title for this dimension pair as its own axes, so that it's aligned with the other titles
            axt = fig.add_subplot(latest_gs[1, i])
            axt.text(
                0,
                0,
                f'[{dim1}, {dim2}]',
                # transform=axt.transAxes,
                horizontalalignment='center',
                fontsize=10,
            )
            # Remove all decorations
            axt.patch.set_alpha(0)
            axt.set_xticks([])
            axt.set_yticks([])
            axt.spines['top'].set_visible(False)
            axt.spines['right'].set_visible(False)
            axt.spines['bottom'].set_visible(False)
            axt.spines['left'].set_visible(False)

            # Plot
            ax = fig.add_subplot(latest_gs[0, i])
            ax.sharex(axt)

            # Plot the data with larger markers
            ax.scatter(_latents[:, dim1], _latents[:, dim2], c=_colors, s=200, alpha=0.7)

            # Add reference circle
            ax.add_patch(Circle((0, 0), 1, fill=True, facecolor='black', edgecolor='gray', zorder=-1))

            # Remove all decorations
            ax.patch.set_alpha(0)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.spines['left'].set_visible(False)

            # Ensure square aspect ratio
            ax.set_aspect('equal')
            ax.set_adjustable('box')

        # Add overall title
        fig.suptitle('Latent Space', fontsize=12, color='white')

        # Use subplots_adjust instead of tight_layout to avoid warnings
        fig.subplots_adjust(top=0.85, bottom=0.1, left=0.1, right=0.95)

        return fig

## Hyperparameter dopesheet

We'll define [a dopesheet](./ex-1.7-dopesheet.csv) (timelines) to allow hyperparameters to vary over time.
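
To give a feel for what a keyframed schedule does, here is a minimal sketch of piecewise-linear interpolation between keyframes. This is an assumption made for illustration only: the `interpolate` function and the keyframe values are hypothetical, and the `mini.temporal` classes used in this notebook may interpolate differently (for example, with easing).

```python
from bisect import bisect_right


def interpolate(keyframes: list[tuple[int, float]], step: int) -> float:
    """Linearly interpolate a property between (step, value) keyframes sorted by step."""
    steps = [s for s, _ in keyframes]
    i = bisect_right(steps, step)
    if i == 0:
        return keyframes[0][1]  # Before the first keyframe: hold its value
    if i == len(keyframes):
        return keyframes[-1][1]  # At or after the last keyframe: hold its value
    (s0, v0), (s1, v1) = keyframes[i - 1], keyframes[i]
    t = (step - s0) / (s1 - s0)  # Fraction of the way from s0 to s1
    return v0 + t * (v1 - v0)


# Hypothetical learning-rate schedule: warm up, then decay
lr_keyframes = [(0, 0.0), (100, 0.01), (1000, 0.001)]
for step in (0, 50, 100, 550, 2000):
    print(step, f'{interpolate(lr_keyframes, step):.4f}')
```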

In [261]:
import re

from IPython.display import display, HTML
from matplotlib.figure import Figure

from mini.temporal.vis import plot_timeline, realize_timeline, ParamGroup
from mini.temporal.dopesheet import Dopesheet
from mini.temporal.timeline import Timeline
from utils.nb import save_fig


line_styles = [
    (re.compile(r'^data-'), {'linewidth': 5, 'zorder': -1, 'alpha': 0.5}),
    # (re.compile(r'-(anchor|norm)$'), {'linewidth': 2, 'linestyle': (0, (8, 1, 1, 1))}),
]


def load_dopesheet():
    dopesheet = Dopesheet.from_csv(f'ex-{nbid}-dopesheet.csv')
    # display(Markdown(f"""## Parameter schedule ({variant})\n{dopesheet.to_markdown()}"""))

    timeline = Timeline(dopesheet)
    history_df = realize_timeline(timeline)
    keyframes_df = dopesheet.as_df()

    groups = (
        ParamGroup(
            name='',
            params=[p for p in dopesheet.props if p not in {'lr'}],
            height_ratio=2,
        ),
        ParamGroup(
            name='',
            params=[p for p in dopesheet.props if p in {'lr'}],
            height_ratio=1,
        ),
    )
    fig, ax = plot_timeline(history_df, keyframes_df, groups, line_styles=line_styles)
    # Add assertion to satisfy type checker
    assert isinstance(fig, Figure), 'plot_timeline should return a Figure'
    display(
        HTML(
            save_fig(
                fig,
                f'large-assets/ex-{nbid}-color-timeline.png',
                alt_text='Line chart showing the hyperparameter schedule over time.',
            )
        )
    )
    return dopesheet


dopesheet = load_dopesheet()

These schedules seem pretty well matched for a fair comparison. The core hyperparameter targets are hit at similar times, with the main difference being, well, the smoothness. This should give us a good basis for seeing what impact the transition style has.

## Loss functions and regularizers

Like Ex 1.5, we use mean squared error for the main reconstruction loss (`loss-recon`), and regularizers that encourage embeddings of unit length, and for primary colors to be on the plane of the first two dimensions.

Unlike Ex 1.5, most of the criteria and regularizers now return per-sample loss, which allows new samples to be given lower weight (see data loaders below).

In [133]:
from torch import linalg as LA

from ex_color.data.color_cube import ColorCube
from ex_color.data.cyclic import arange_cyclic


def objective(fn):
    """Adapt loss function to look like a regularizer"""

    def wrapper(data: Tensor, res: InferenceResult) -> Tensor:
        loss = fn(data, res.outputs)
        # Reduce element-wise loss to per-sample loss by averaging over feature dimensions
        if loss.ndim > 1:
            # Calculate mean over all dimensions except the first (batch) dimension
            reduce_dims = tuple(range(1, loss.ndim))
            loss = torch.mean(loss, dim=reduce_dims)
        return loss

    return wrapper


def unitary(data: Tensor, res: InferenceResult) -> Tensor:
    """Regularize latents to have unit norm (vectors of length 1)"""
    norms = LA.vector_norm(res.latents, dim=-1)
    # Return per-sample loss, shape [B]
    return (norms - 1.0) ** 2


def planarity(data: Tensor, res: InferenceResult) -> Tensor:
    """Regularize latents to be planar in the first two channels (so zero in other channels)"""
    if res.latents.shape[1] <= 2:
        # No dimensions beyond the first two, return zero loss per sample
        return torch.zeros(res.latents.shape[0], device=res.latents.device)
    # Sum squares across the extra dimensions for each sample, shape [B]
    return torch.sum(res.latents[:, 2:] ** 2, dim=-1)


### Anchor

In [134]:
class Anchor(LossCriterion):
    def __init__(self, anchor_point: Tensor):
        self.anchor_point = anchor_point

    def __call__(self, data: Tensor, res: InferenceResult) -> Tensor:
        """
        Regularize latents to be close to the anchor point.

        Returns:
            loss: Per-sample loss, shape [B].
        """
        # Calculate squared distances to the anchor
        sq_dists = torch.sum((res.latents - self.anchor_point) ** 2, dim=-1)  # [B]
        return sq_dists

### Separate

In [135]:
class Separate(LossCriterion):
    def __init__(
        self,
        channels: tuple[int, ...] | None = None,
        bias: float = 1e-8,
        scale: float = 1.0,
        power: float = 1.0,
    ):
        self.channels = channels
        self.scale = scale
        self.power = power
        self.bias = bias

    # def __call__(self, data: Tensor, res: InferenceResult) -> Tensor:
    #     """
    #     Regularize latents to be separated from each other in first two channels.

    #     Returns:
    #         loss: Per-sample loss, shape [B].
    #     """
    #     # Get pairwise differences in the first two dimensions
    #     points = res.latents[:, self.channels]  # [B, C]
    #     diffs = points.unsqueeze(1) - points.unsqueeze(0)  # [B, B, C]
    #     diffs *= self.scale
    #     # diffs += 1.0

    #     # Calculate squared distances
    #     sq_dists = torch.sum(diffs**2, dim=-1)  # [B, B]
    #     invmask = torch.isclose(sq_dists, torch.zeros_like(sq_dists))  # [B, B]
    #     losses = 1.0 / (sq_dists + self.bias)

    #     # Remove self-loss (and degenerate pairs)
    #     losses[invmask] = 0.0

    #     losses = torch.sum(losses, dim=1)  # [B]
    #     # losses = torch.sigmoid_(losses)  # [B]
    #     # return losses
    #     return losses**self.power

    def __call__(self, data: Tensor, res: InferenceResult) -> Tensor:
        """
        Regularize latents to be separated from each other in the selected channels.

        Returns:
            loss: Per-sample loss, shape [B].
        """
        # Select the channels to compare (all of them if none were specified)
        if self.channels is None:
            points = res.latents  # [B, E]
        else:
            points = res.latents[:, list(self.channels)]  # [B, C]

        # Normalize to unit hypersphere, so that points are repelled along its surface
        magnitudes = torch.norm(points, dim=-1, keepdim=True)
        points = points / (magnitudes + 1e-8)  # Normalize to unit sphere

        diffs = points.unsqueeze(1) - points.unsqueeze(0)  # [B, B, C]
        diffs *= self.scale
        # diffs += 1.0

        # Calculate squared distances
        sq_dists = torch.sum(diffs**2, dim=-1)  # [B, B]
        invmask = torch.isclose(sq_dists, torch.zeros_like(sq_dists))  # [B, B]
        losses = 1.0 / (sq_dists + self.bias)

        # Remove self-loss (and degenerate pairs)
        losses[invmask] = 0.0

        losses = torch.sum(losses, dim=1)  # [B]
        # losses = torch.sigmoid_(losses)  # [B]
        # return losses
        return losses**self.power
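
To see how the repulsion term behaves, the sketch below re-implements its core on plain tensors and compares three points spread evenly around the unit circle with three points bunched together. The helper names `separation_loss` and `on_circle` are hypothetical, and the `scale` and `power` options are left out for brevity.

```python
import math

import torch


def separation_loss(latents: torch.Tensor, bias: float = 1e-8) -> torch.Tensor:
    """Inverse-square repulsion between directions; returns per-sample loss [B]."""
    # Project onto the unit sphere so that only direction matters
    points = latents / (torch.norm(latents, dim=-1, keepdim=True) + 1e-8)
    diffs = points.unsqueeze(1) - points.unsqueeze(0)  # [B, B, C]
    sq_dists = torch.sum(diffs**2, dim=-1)  # [B, B]
    # Mask out self-pairs (and coincident points), which would otherwise dominate
    mask = torch.isclose(sq_dists, torch.zeros_like(sq_dists))
    losses = torch.where(mask, torch.zeros_like(sq_dists), 1.0 / (sq_dists + bias))
    return losses.sum(dim=1)


def on_circle(angles: torch.Tensor) -> torch.Tensor:
    """Place points on the unit circle at the given angles (radians)."""
    return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)


spread = separation_loss(on_circle(torch.tensor([0.0, 2 * math.pi / 3, 4 * math.pi / 3])))
clustered = separation_loss(on_circle(torch.tensor([0.0, 0.1, 0.2])))
print('spread:   ', spread)
print('clustered:', clustered)
```

The bunched points incur a much larger loss than the evenly spread ones, which is the pressure that pushes labelled samples apart along the circle.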

## Data loading, sampling, and event handling

Here we set up:

- **Datasets:** Define the datasets used (primary/secondary colors, full color grid).
- **Sampler:** Use `DynamicWeightedRandomBatchSampler` for the full dataset. Its weights are updated by the `update_sampler_weights` callback, which responds to the `data-fraction` parameter from the dopesheet. This smoothly shifts the sampling focus from highly vibrant colors early on to the full range of colors later.
- **Recorders:** `ModelRecorder` and `MetricsRecorder` are event handlers that save the model state and loss values at each step.
- **Event bindings:** Connect event handlers to specific events (e.g., `plotter` to `phase-end`, `reg_anchor.on_anchor` to `action:anchor`, recorders to `pre-step` and `step-metrics`).
- **Training execution:** Finally, call `train_color_model` with the model, datasets, dopesheet, loss criteria, and configured event handlers.


### Recorders

In [256]:
import numpy as np
import torch.nn as nn


class ModelRecorder(EventHandler):
    """Event handler to record model parameters."""

    history: list[tuple[int, dict[str, Tensor]]]
    """A list of tuples (step, state_dict) where state_dict is a copy of the model's state dict."""

    def __init__(self):
        self.history = []

    def __call__(self, event: Event):
        # Get a *copy* of the state dict and move it to the CPU
        # so we don't hold onto GPU memory or track gradients unnecessarily.
        model_state = {k: v.cpu().clone() for k, v in event.model.state_dict().items()}
        self.history.append((event.step, model_state))
        log.debug(f'Recorded model state at step {event.step}')


class MetricsRecorder(EventHandler):
    """Event handler to record training metrics."""

    history: list[tuple[int, float, dict[str, float]]]
    """A list of tuples (step, total_loss, losses_dict)."""

    def __init__(self):
        self.history = []

    def __call__(self, event: StepMetricsEvent):
        # Ensure we are handling the correct event type
        if not isinstance(event, StepMetricsEvent):
            log.warning(f'MetricsRecorder received unexpected event type: {type(event)}')
            return

        self.history.append((event.step, event.total_loss, event.losses.copy()))
        log.debug(f'Recorded metrics at step {event.step}: loss={event.total_loss:.4f}')

### Labelling

In [257]:
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from torch.utils.data.dataloader import default_collate

# TODO: remove forced reload
if True:
    import importlib
    import ex_color.data.cube_sampler

    importlib.reload(ex_color.data.cube_sampler)


def generate_color_labels(data: Tensor, vibrancies: Tensor) -> dict[str, Tensor]:
    """
    Generate label probabilities based on RGB values.

    Args:
        data: Batch of RGB values [B, 3]
        vibrancies: Per-sample vibrancy (proximity to a fully-saturated, fully-bright color) [B]

    Returns:
        Dictionary mapping label names to probabilities str -> [B]
    """
    labels: dict[str, Tensor] = {}

    # Proximity to primary colors
    r, g, b = data[:, 0], data[:, 1], data[:, 2]
    labels['red'] = (r * (1 - g / 2 - b / 2)) ** 10
    # labels['green'] = g * (1 - r / 2 - b / 2)
    # labels['blue'] = b * (1 - r / 2 - g / 2)
    # Proximity to any fully-saturated, fully-bright color
    labels['vibrant'] = vibrancies ** 100

    return labels


def collate_with_generated_labels(batch):
    """
    Custom collate function that generates labels for the samples.

    Args:
        batch: A list of (data_tensor, vibrancy_tensor) tuples from TensorDataset.

    Returns:
        A tuple: (collated_data_tensor, label_probs), where label_probs maps
        label names to probability tensors of shape [B].
    """
    # Separate the color data from the vibrancy values
    data_list = [item[0] for item in batch]
    vibrancies = [item[1] for item in batch]

    # Stack each list into a single batch tensor
    collated_data = default_collate(data_list)
    vibrancies = default_collate(vibrancies)
    label_probs = generate_color_labels(collated_data, vibrancies)

    # Either return the probabilities directly
    return collated_data, label_probs

    # TODO: Sample discrete labels stochastically? Perhaps only one label per sample?
    # sampled_labels = sample_labels(label_probs)
    # return collated_data, sampled_labels
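
The TODO above mentions sampling discrete labels. One possible approach, sketched here as an assumption rather than what this experiment does, is to treat each probability as an independent Bernoulli parameter. `sample_labels` is a hypothetical helper, and the three example colors are made up.

```python
from typing import Optional

import torch
from torch import Tensor


def sample_labels(label_probs: dict[str, Tensor], generator: Optional[torch.Generator] = None) -> dict[str, Tensor]:
    """Draw an independent 0/1 label per sample from each probability tensor."""
    return {
        name: torch.bernoulli(probs.clamp(0, 1), generator=generator)
        for name, probs in label_probs.items()
    }


# Pure red, a slightly orange red, and mid gray
data = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.2, 0.0], [0.5, 0.5, 0.5]])
r, g, b = data[:, 0], data[:, 1], data[:, 2]
probs = {'red': (r * (1 - g / 2 - b / 2)) ** 10}  # same formula as generate_color_labels

gen = torch.Generator().manual_seed(0)
sampled = sample_labels(probs, generator=gen)
```

Under this scheme pure red is always labelled `red`, because its probability is exactly 1. Nearby colors are labelled only some of the time, which gives the noisy labelling described in the dataset design.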


### Datasets

In [258]:
from ex_color.data.cube_sampler import vibrancy

hsv_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=10 / 360),
    s=np.linspace(0, 1, 10),
    v=np.linspace(0, 1, 10),
)
hsv_tensor = torch.tensor(hsv_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
vibrancy_tensor = torch.tensor(vibrancy(hsv_cube).flatten(), dtype=torch.float32)
hsv_dataset = TensorDataset(hsv_tensor, vibrancy_tensor)

# Desaturated and dark colors are over-represented in the cube, so we use a weighted sampler to balance them out
hsv_loader = DataLoader(
    hsv_dataset,
    batch_size=32,
    sampler=WeightedRandomSampler(
        weights=hsv_cube.bias.flatten().tolist(),
        num_samples=len(hsv_dataset),
        replacement=True,
    ),
    collate_fn=collate_with_generated_labels,
)

rgb_cube = ColorCube.from_rgb(
    r=np.linspace(0, 1, 8),
    g=np.linspace(0, 1, 8),
    b=np.linspace(0, 1, 8),
)
rgb_tensor = torch.tensor(rgb_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
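
To see why the weighted sampler is needed, the sketch below counts how often each RGB color occurs in a regular HSV grid. It uses the standard-library `colorsys` module rather than `ColorCube`, and assumes the hue axis has 36 steps in $[0, 1)$, which is our reading of `arange_cyclic(step_size=10 / 360)`. It does not reproduce how `hsv_cube.bias` is computed.

```python
import colorsys
from collections import Counter

hues = [i * 10 / 360 for i in range(36)]  # assumed equivalent of arange_cyclic(step_size=10 / 360)
sats = [j / 9 for j in range(10)]  # like np.linspace(0, 1, 10)
vals = [k / 9 for k in range(10)]

# Count how many grid points land on each (rounded) RGB color
counts = Counter(
    tuple(round(c, 6) for c in colorsys.hsv_to_rgb(h, s, v))
    for h in hues
    for s in sats
    for v in vals
)

n_points = len(hues) * len(sats) * len(vals)
n_black = counts[(0.0, 0.0, 0.0)]  # every (h, s) pair with v=0 collapses to black
n_white = counts[(1.0, 1.0, 1.0)]  # every hue with s=0, v=1 collapses to white
print(f'{n_points} grid points, {len(counts)} distinct colors, black x{n_black}, white x{n_white}')
```

Sampling the grid uniformly would therefore show the model black and the grays far more often than any single vibrant color.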

### Train

In [262]:
def train(dopesheet: Dopesheet):
    """Train the model with the given dopesheet."""
    log.info('Training')
    recorder = ModelRecorder()
    metrics_recorder = MetricsRecorder()

    # seed = 0
    # set_deterministic_mode(seed)

    model = ColorMLP()
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.info(f'Model initialized with {total_params:,} trainable parameters.')

    event_handlers = EventHandlers()
    event_handlers.pre_step.add_handler('pre-step', recorder)
    event_handlers.step_metrics.add_handler('step-metrics', metrics_recorder)

    plotter = PhasePlotter(rgb_tensor, dim_pairs=[(0, 1), (0, 2), (0, 3)], interval=200)
    event_handlers.pre_step.add_handler('pre-step', plotter)

    regularizers = [
        RegularizerConfig(
            name='reg-polar',
            criterion=Anchor(torch.tensor([1, 0, 0, 0], dtype=torch.float32, device=hsv_tensor.device)),
            label_affinities={'red': 1.0},
        ),
        RegularizerConfig(
            name='reg-separate',
            # criterion=Separate(scale=1000),
            criterion=Separate(bias=1, scale=1_000, power=0.5),
            label_affinities=None,
        ),
        RegularizerConfig(
            name='reg-planar',
            criterion=planarity,
            label_affinities={'vibrant': 1.0},
        ),
        RegularizerConfig(
            name='reg-norm-v',
            criterion=unitary,
            label_affinities={'vibrant': 1.0},
        ),
        RegularizerConfig(
            name='reg-norm',
            criterion=unitary,
            label_affinities=None,
        ),
    ]

    train_color_model(
        model,
        hsv_loader,
        rgb_tensor,
        dopesheet,
        # loss_criterion=objective(nn.MSELoss(reduction='none')),  # No reduction; allows per-sample loss weights
        loss_criterion=objective(nn.MSELoss()),
        regularizers=regularizers,
        event_handlers=event_handlers,
    )

    return recorder, metrics_recorder
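
How `train_color_model` turns `label_affinities` into a loss term is not shown in this section. One plausible reading, sketched below under that assumption, is that each sample's penalty is weighted by its label probability times the affinity. `anchor_penalty` and `weighted_regularizer` are hypothetical names, and the squared-distance penalty is only a stand-in for `Anchor`.

```python
import torch
from torch import Tensor


def anchor_penalty(latents: Tensor, anchor: Tensor) -> Tensor:
    """Per-sample squared distance to the anchor point, shape [B]."""
    return ((latents - anchor) ** 2).sum(dim=-1)


def weighted_regularizer(per_sample: Tensor, labels: dict[str, Tensor], affinities: dict[str, float]) -> Tensor:
    """Weight each sample's penalty by sum(affinity * label probability), then average."""
    weights = torch.zeros_like(per_sample)
    for name, affinity in affinities.items():
        weights = weights + affinity * labels[name]
    return (weights * per_sample).mean()


latents = torch.tensor([[0.9, 0.1, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
labels = {'red': torch.tensor([1.0, 0.0])}  # only the first sample is labelled red
anchor = torch.tensor([1.0, 0.0, 0.0, 0.0])

loss = weighted_regularizer(anchor_penalty(latents, anchor), labels, {'red': 1.0})
```

Samples with a label probability of zero contribute nothing, so the second latent is free to sit anywhere as far as this term is concerned.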

In [263]:
recorder, metrics = train(dopesheet)

I 6480.4 no.1.7:Training
I 6480.4 no.1.7:Model initialized with 263 trainable parameters.


I 6530.3 no.1.7:Training finished!


Both models trained fairly well! There are some differences, but they look like they have similar characteristics. Surprisingly, the smooth variant seemed to have a _noisier_ (i.e. worse) latent space at the end of the _All hues_ phase.

## Validation

### Latent space evolution analysis

Let's visualize how the latent spaces evolved over time. As in Ex 1.5, we'll use the `ModelRecorder`'s history to load the model state at each recorded step and evaluate the latent positions for a fixed set of input colors (the full RGB grid). This gives us a sequence of latent space snapshots.


In [16]:
import numpy as np


def eval_latent_history(
    recorder: ModelRecorder,
    rgb_tensor: Tensor,
):
    """Evaluate the latent space for each step in the recorder's history."""
    from utils.progress import RichProgress

    # Create a new model instance to load each recorded state into
    model = ColorMLP()

    latent_history: list[tuple[int, np.ndarray]] = []
    # Iterate over the recorded history
    for step, state_dict in RichProgress(recorder.history, description='Evaluating latents'):
        # Load the model state dict
        model.load_state_dict(state_dict)
        model.eval()
        with torch.no_grad():
            # Get the latents for the RGB tensor
            _, latents = model(rgb_tensor.to(next(model.parameters()).device))
            latents = latents.cpu().numpy()
            latent_history.append((step, latents))
    return latent_history


latent_history = eval_latent_history(recorder, rgb_tensor)

### Animation of latent space

This final visualization combines multiple views into a single animation:

- **Latent space:** Shows the 2D projection (Dims 0 vs 1) of the latent embeddings for the full RGB color grid, colored by their true RGB values. We can see the color wheel forming.
- **Hyperparameters:** Replots the parameter schedule from the dopesheet, with a vertical line indicating the current step in the animation.
- **Training metrics:** Plots the total loss and the contribution of each individual loss/regularization term (on a log scale), again with a vertical line for the current step.

_(Note: A variable stride is used for sampling frames to focus on periods of rapid change.)_

The smooth training run is shown on the left, and the stepped run on the right.
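
The variable-stride sampling mentioned in the note above can be tried in isolation. The sketch below reuses the stride formula from this notebook on a made-up list of recorded steps. It assumes one recorded state per training step, and `sample_frame_indices` is a hypothetical helper that does not depend on the recorders.

```python
import math


def get_stride(step: int, a: float = 7.9236, b: float = 0.0005) -> float:
    """Stride grows logarithmically with the step, and is never less than 1."""
    return max(1.0, a * math.log(b * step + 1) + 1)


def sample_frame_indices(steps: list[int]) -> list[int]:
    """Pick indices into `steps`, densely at first and more sparsely later."""
    indices = [0]
    position = 0.0  # fractional position, so that small strides accumulate
    while True:
        position += get_stride(steps[round(position)])
        if round(position) >= len(steps):
            break
        indices.append(round(position))
    if indices[-1] < len(steps) - 1:
        indices.append(len(steps) - 1)  # always include the final frame
    return indices


steps = list(range(10_000))  # hypothetical: one recorded state per step
indices = sample_frame_indices(steps)
```

Because the stride starts at 1 and grows with the step, early training, where the latent space changes fastest, gets the most frames.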

In [17]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import imageio_ffmpeg
from matplotlib import rcParams
import pandas as pd
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Circle

from mini.temporal.dopesheet import RESERVED_COLS
from utils.progress import RichProgress

# TODO: remove forced reload
import importlib
import mini.temporal.vis
importlib.reload(mini.temporal.vis)
from mini.temporal.vis import group_properties_by_scale, plot_timeline

rcParams['animation.ffmpeg_path'] = imageio_ffmpeg.get_ffmpeg_exe()


def animate_latent_evolution_with_metrics(
    # Smooth variant data
    smooth_latent_history: list[tuple[int, np.ndarray]],
    smooth_metrics_history: list[tuple[int, float, dict[str, float]]],
    smooth_param_history_df: pd.DataFrame,
    smooth_param_keyframes_df: pd.DataFrame,
    # Stepped variant data
    stepped_latent_history: list[tuple[int, np.ndarray]],
    stepped_metrics_history: list[tuple[int, float, dict[str, float]]],
    stepped_param_history_df: pd.DataFrame,
    stepped_param_keyframes_df: pd.DataFrame,
    # Common data and settings
    colors: np.ndarray,
    dim_pair: tuple[int, int] = (0, 1),
    interval=1 / 30,
):
    """Create a side-by-side animation of latent space evolution, hyperparameters, and metrics."""
    plt.style.use('dark_background')
    # Aim for 16:9 aspect ratio, give latent plots more height
    fig = plt.figure(figsize=(16, 9))
    # Three rows (latents, parameters, metrics) and two columns (smooth, stepped)
    gs = GridSpec(3, 2, height_ratios=[5, 1, 1], width_ratios=[1, 1], hspace=0, wspace=0.02)

    # --- Create Axes ---
    # Latent plots (Top row) - No sharing needed initially
    ax_latent_s = fig.add_subplot(gs[0, 0])
    ax_latent_t = fig.add_subplot(gs[0, 1])

    # Parameter plots (Middle row) - Share x-axis with metrics plot BELOW
    ax_params_s = fig.add_subplot(gs[1, 0])
    ax_params_t = fig.add_subplot(gs[1, 1])

    # Metrics plots (Bottom row) - Share x-axis with parameter plot ABOVE
    ax_metrics_s = fig.add_subplot(gs[2, 0], sharex=ax_params_s)
    ax_metrics_t = fig.add_subplot(gs[2, 1], sharex=ax_params_t)


    fig.patch.set_facecolor('#333')
    all_axes = [ax_latent_s, ax_params_s, ax_metrics_s, ax_latent_t, ax_params_t, ax_metrics_t]
    for ax in all_axes:
        ax.patch.set_facecolor('#222')

    latent_lim = 1.1

    # --- Setup Smooth Plots (Left Column) ---
    step_s, current_latents_s = smooth_latent_history[0]
    ax_latent_s.set_xlim(-latent_lim, latent_lim)
    ax_latent_s.set_ylim(-latent_lim, latent_lim)
    ax_latent_s.set_aspect('equal', adjustable='datalim')
    # ax_latent_s.set_xlabel(f'Dim {dim_pair[0]}') # Set X label for latent plot
    ax_latent_s.tick_params(axis='x', labelbottom=False)  # Hide x labels
    plt.setp(ax_latent_s.get_xticklabels(), visible=False)
    # ax_latent_s.set_ylabel(f'Dim {dim_pair[1]}')
    ax_latent_s.set_ylabel('Latent space')
    ax_latent_s.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax_latent_s.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
    ax_latent_s.add_patch(Circle((0, 0), 1, fill=False, linestyle='--', color='gray', alpha=0.3))
    scatter_s = ax_latent_s.scatter(
        current_latents_s[:, dim_pair[0]], current_latents_s[:, dim_pair[1]], c=colors, s=150, alpha=0.7
    )
    title_latent_s = ax_latent_s.set_title('placeholder')  # Title set in update()

    param_props_s = smooth_param_keyframes_df.columns.difference(list(RESERVED_COLS)).tolist()
    param_groups_s = group_properties_by_scale(smooth_param_keyframes_df[param_props_s])
    # Hide the legend and phase labels to save space in the small timeline plot
    plot_timeline(smooth_param_history_df, smooth_param_keyframes_df, [param_groups_s[0]], ax=ax_params_s, show_legend=False, show_phase_labels=False, line_styles=line_styles)
    param_vline_s = ax_params_s.axvline(step_s, color='white', linestyle='--', lw=1)
    ax_params_s.set_ylabel('Param value', fontsize='x-small')
    ax_params_s.set_xlabel('') # Remove xlabel, it will be on the plot below
    # Hide x-tick labels because they are shared with the plot below
    plt.setp(ax_params_s.get_xticklabels(), visible=False)

    metrics_steps_s = [h[0] for h in smooth_metrics_history]
    total_losses_s = [h[1] for h in smooth_metrics_history]
    loss_components_s = {k: [h[2].get(k, np.nan) for h in smooth_metrics_history] for k in smooth_metrics_history[0][2].keys()}
    ax_metrics_s.plot(metrics_steps_s, total_losses_s, label='Total Loss', lw=1.5)
    for name, values in loss_components_s.items():
        ax_metrics_s.plot(metrics_steps_s, values, label=name, lw=1, alpha=0.8)
    ax_metrics_s.set_xlabel('Step') # Set X label for the bottom plot
    ax_metrics_s.set_ylabel('Loss (log)', fontsize='x-small')
    ax_metrics_s.set_yscale('log')
    ax_metrics_s.set_ylim(bottom=1e-6)
    metrics_vline_s = ax_metrics_s.axvline(step_s, color='white', linestyle='--', lw=1)

    # --- Setup Stepped Plots (Right Column) ---
    step_t, current_latents_t = stepped_latent_history[0]
    ax_latent_t.set_xlim(-latent_lim, latent_lim)
    ax_latent_t.set_ylim(-latent_lim, latent_lim)
    ax_latent_t.set_aspect('equal', adjustable='datalim')
    # ax_latent_t.set_xlabel(f'Dim {dim_pair[0]}') # Set X label for latent plot
    ax_latent_t.tick_params(axis='x', labelbottom=False)  # Hide x labels
    plt.setp(ax_latent_t.get_xticklabels(), visible=False)
    ax_latent_t.tick_params(axis='y', labelleft=False) # Hide y labels
    ax_latent_t.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax_latent_t.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
    ax_latent_t.add_patch(Circle((0, 0), 1, fill=False, linestyle='--', color='gray', alpha=0.3))
    scatter_t = ax_latent_t.scatter(
        current_latents_t[:, dim_pair[0]], current_latents_t[:, dim_pair[1]], c=colors, s=150, alpha=0.7
    )
    title_latent_t = ax_latent_t.set_title('placeholder')  # Title set in update()

    param_props_t = stepped_param_keyframes_df.columns.difference(list(RESERVED_COLS)).tolist()
    param_groups_t = group_properties_by_scale(stepped_param_keyframes_df[param_props_t])
    # Hide the legend and phase labels to save space in the small timeline plot
    plot_timeline(stepped_param_history_df, stepped_param_keyframes_df, [param_groups_t[0]], ax=ax_params_t, show_legend=False, show_phase_labels=False, line_styles=line_styles)
    param_vline_t = ax_params_t.axvline(step_t, color='white', linestyle='--', lw=1)
    ax_params_t.set_ylabel('')  # Y label only on left
    ax_params_t.set_xlabel('') # Remove xlabel, it will be on the plot below
    # Hide x-tick labels because they are shared with the plot below
    plt.setp(ax_params_t.get_xticklabels(), visible=False)
    ax_params_t.tick_params(axis='y', labelleft=False) # Hide y labels

    metrics_steps_t = [h[0] for h in stepped_metrics_history]
    total_losses_t = [h[1] for h in stepped_metrics_history]
    loss_components_t = {k: [h[2].get(k, np.nan) for h in stepped_metrics_history] for k in stepped_metrics_history[0][2].keys()}
    ax_metrics_t.plot(metrics_steps_t, total_losses_t, label='Total Loss', lw=1.5)
    for name, values in loss_components_t.items():
        ax_metrics_t.plot(metrics_steps_t, values, label=name, lw=1, alpha=0.8)
    ax_metrics_t.set_xlabel('Step') # Set X label for the bottom plot
    ax_metrics_t.set_yscale('log')
    ax_metrics_t.set_ylim(bottom=1e-6)
    ax_metrics_t.tick_params(axis='y', labelleft=False) # Hide y labels
    metrics_vline_t = ax_metrics_t.axvline(step_t, color='white', linestyle='--', lw=1)

    # --- Set common X limits ---
    # Only set xlim for the timeline plots (params and metrics)
    max_step = max(smooth_param_history_df['STEP'].max(), stepped_param_history_df['STEP'].max())
    for ax in [ax_params_s, ax_metrics_s, ax_params_t, ax_metrics_t]:
        ax.set_xlim(left=0, right=max_step)

    # fig.tight_layout(h_pad=0, w_pad=0.5)  # Adjust padding
    fig.subplots_adjust(
        left=0.05,    # Smaller left margin
        right=0.95,   # Smaller right margin
        bottom=0.08,  # Smaller bottom margin (leave room for x-label)
        top=0.95,     # Smaller top margin (leave room for titles)
        wspace=0.1,   # Adjust space between columns (tweak as needed)
        hspace=0.0    # Keep vertical space at 0 (set in GridSpec)
    )

    def update(frame: int):
        # Assume smooth and stepped histories have the same length and aligned steps after sampling
        smooth_step, current_latents_s = smooth_latent_history[frame]
        stepped_step, current_latents_t = stepped_latent_history[frame]
        # Use the smooth step for titles and lines, assuming they are aligned
        current_step = smooth_step

        # Update smooth plots
        scatter_s.set_offsets(current_latents_s[:, dim_pair])
        title_latent_s.set_text(f'Smooth curriculum (step {current_step})') # Use current_step
        param_vline_s.set_xdata([current_step])
        metrics_vline_s.set_xdata([current_step])

        # Update stepped plots
        scatter_t.set_offsets(current_latents_t[:, dim_pair])
        title_latent_t.set_text(f'Stepped curriculum (step {current_step})') # Use current_step
        param_vline_t.set_xdata([current_step])
        metrics_vline_t.set_xdata([current_step])

        return (
            scatter_s, title_latent_s, param_vline_s, metrics_vline_s,
            scatter_t, title_latent_t, param_vline_t, metrics_vline_t,
        )

    # Use the length of the (potentially strided) latent_history for frames
    # Assuming both histories have the same length after sampling
    num_frames = len(smooth_latent_history)
    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=interval * 1000, blit=True)
    return fig, ani


# --- Variable Stride Logic ---
def get_stride(step: int):
    import math

    a = 7.9236
    b = 0.0005
    # Ensure stride is at least 1
    return max(1.0, a * math.log(b * step + 1) + 1)


# Apply stride logic based on smooth history (assuming stepped is similar)
sampled_indices = [0]
last_sampled_index = 0
# Use smooth_latents for stride calculation
while True:
    current_step = smooth_latents[round(last_sampled_index)][0]
    stride = get_stride(current_step)
    next_index = last_sampled_index + stride
    # Ensure indices stay within bounds for *both* histories
    if round(next_index) >= len(smooth_latents) or round(next_index) >= len(stepped_latents):
        break
    sampled_indices.append(round(next_index))
    last_sampled_index = next_index

# Ensure the last frame is included if missed
if sampled_indices[-1] < len(smooth_latents) - 1:
    sampled_indices.append(len(smooth_latents) - 1)

# sampled_indices = sampled_indices[:200]  # Limit the number of samples during development

# Sample both latent histories using the same indices
sampled_smooth_latents = [smooth_latents[i] for i in sampled_indices]
sampled_stepped_latents = [stepped_latents[i] for i in sampled_indices]

# --- End Variable Stride Logic ---

# Filter metrics history to align with the *new* sampled latent history steps
# Use steps from the sampled smooth history (assuming alignment)
sampled_steps_set = {step for step, _ in sampled_smooth_latents}
filtered_smooth_metrics = [h for h in smooth_metrics.history if h[0] in sampled_steps_set]
filtered_stepped_metrics = [h for h in stepped_metrics.history if h[0] in sampled_steps_set]

# Realize timelines for both dopesheets
smooth_timeline = Timeline(smooth_dopesheet)
smooth_history_df = realize_timeline(smooth_timeline)
smooth_keyframes_df = smooth_dopesheet.as_df()

stepped_timeline = Timeline(stepped_dopesheet)
stepped_history_df = realize_timeline(stepped_timeline)
stepped_keyframes_df = stepped_dopesheet.as_df()


# --- Call the updated animation function ---
fig, ani = animate_latent_evolution_with_metrics(
    # Smooth
    smooth_latent_history=sampled_smooth_latents,
    smooth_metrics_history=filtered_smooth_metrics,
    smooth_param_history_df=smooth_history_df,
    smooth_param_keyframes_df=smooth_keyframes_df,
    # Stepped
    stepped_latent_history=sampled_stepped_latents,
    stepped_metrics_history=filtered_stepped_metrics,
    stepped_param_history_df=stepped_history_df,
    stepped_param_keyframes_df=stepped_keyframes_df,
    # Common
    colors=rgb_tensor.cpu().numpy(),
    dim_pair=(0, 1),
)

# --- Save the video ---
video_file = f'large-assets/ex-{nbid}-latent-evolution-comparison.mp4'
num_frames_to_render = len(sampled_smooth_latents)  # Based on the sampled length
with RichProgress(total=num_frames_to_render, description='Rendering comparison video') as pbar:
    ani.save(
        video_file,
        fps=30,
        extra_args=['-vcodec', 'libx264'],
        progress_callback=lambda i, n: pbar.update(1),
    )
plt.close(fig)

# --- Display the video ---
import secrets
from IPython.display import display, HTML

cache_buster = secrets.token_urlsafe()

display(
    HTML(
        f"""
        <video width="960" height="540" controls loop>
            <source src="{video_file}?v={cache_buster}" type="video/mp4">
            Your browser does not support the video tag.
        </video>
        """
    )
)

NameError: name 'smooth_latents' is not defined

## Observations

Qualitatively, we observe that:
- The Smooth variant seems noisier overall: it's more jittery in general, and becomes more misshapen during the _All hues_ phase. This might be due to the specific values of the hyperparameters, e.g. maybe the normalization loss was too high.
- The Stepped variant does indeed show loss spikes at the start of each phase, while the Smooth variant does not, as predicted! However, the spikes don't seem to cause any problem; perhaps they were fully mitigated by the LR warmup.
- Even though the data are introduced to each variant differently (in chunks to the Stepped variant, and gradually to the Smooth variant), the effect is almost identical. This is particularly apparent at the start of the _Full color space_ phase: the Stepped variant bulges suddenly at the start of the phase, while the Smooth variant bulges a little later and somewhat less violently — but both end up in almost the exact same shape.

Perhaps the dynamics and final latent space could be improved for the Smooth curriculum by reducing the learning rate at times when the parameters are changing a lot — but since per-phase LR schedules are already common in curriculum learning, using them _in addition_ to smooth parameter changes may not have much benefit. On the other hand, we note that the smooth curriculum was easier to specify than the stepped one, purely because it had fewer phases and fewer keyframes.

## Conclusion

Our hypothesis seems to have been wrong: smooth parameter changes _don't_ appear to improve training dynamics compared to a traditional curriculum.