# Experiment 1.6: Smooth vs. stepped hyperparameter transitions

In previous experiments, we explored curriculum learning ([Ex 1.3](./ex-1.3-color-mlp-curriculum.ipynb)) with abrupt phase changes and later introduced smooth hyperparameter transitions using a dopesheet and minimum jerk interpolation ([Ex 1.5](./ex-1.5-color-mlp-anchoring.ipynb)).

This notebook directly compares these two approaches:

1.  **Stepped transitions:** Mimicking the traditional approach with discrete phases and sharp parameter changes at boundaries. We'll simulate the LR warmup used in Ex 1.3 within the dopesheet.
2.  **Smooth transitions:** Using the minimum jerk trajectories from Ex 1.5 for all hyperparameters.

Both methods use the same 4D bottleneck model architecture, initialization seeds, and loss functions (including anchoring), and both target the same final hyperparameter values at equivalent points in the curriculum.
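
To make the difference between the two schedules concrete, here is a minimal sketch of a single transition between two hyperparameter values. It uses the standard minimum jerk polynomial, $10t^3 - 15t^4 + 6t^5$, which is an assumption about the shape of the curve rather than the dopesheet library's actual implementation. The function names and the learning-rate values are hypothetical.

```python
import numpy as np


def minimum_jerk(t: np.ndarray) -> np.ndarray:
    """Standard minimum jerk easing on [0, 1]: 10t^3 - 15t^4 + 6t^5."""
    t = np.clip(t, 0.0, 1.0)
    return 10 * t**3 - 15 * t**4 + 6 * t**5


def stepped_schedule(t: np.ndarray, start: float, end: float) -> np.ndarray:
    """Hold `start` for the whole phase, then jump to `end` at the boundary (t = 1)."""
    return np.where(np.asarray(t) < 1.0, start, end)


def smooth_schedule(t: np.ndarray, start: float, end: float) -> np.ndarray:
    """Ease from `start` to `end` with zero velocity and acceleration at both ends."""
    return start + (end - start) * minimum_jerk(t)


# Hypothetical values, for illustration only: one phase of learning-rate decay.
lr_start, lr_end = 1e-2, 1e-3
t = np.linspace(0.0, 1.0, 5)  # normalized time within the phase

print('stepped:', stepped_schedule(t, lr_start, lr_end))
print('smooth: ', smooth_schedule(t, lr_start, lr_end))
```

Both schedules agree at the start and end of the phase. They differ only in how they get from one value to the other, which is the variable this experiment isolates.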

## Hypothesis

While both approaches might reach similar final performance, we hypothesize that the smooth transitions will lead to:

- **More stable training:** Fewer and smaller loss spikes, especially during periods corresponding to phase transitions in the stepped approach.
- **Lower loss variance:** Quantifiably less fluctuation in the loss signal.
- **Smoother latent space evolution:** A more gradual and less chaotic development of the final representation structure.

## Experiment design

We'll train the 4D MLP autoencoder using two different dopesheets representing the stepped and smooth schedules. We will track:

- Training loss curves (total and components).
- Loss variance over time.
- Final latent space structure.
- Evolution of the latent space (via animation, similar to Ex 1.5).
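
One way to quantify "loss variance over time" is a rolling variance over a fixed window of training steps. The sketch below is an assumption about how that metric could be computed, not the notebook's actual analysis code. The helper name `rolling_variance`, the window size, and the synthetic loss curves are all hypothetical.

```python
import numpy as np


def rolling_variance(losses, window: int) -> np.ndarray:
    """Population variance of each length-`window` slice of a 1D loss history."""
    losses = np.asarray(losses, dtype=float)
    if window < 1 or window > len(losses):
        raise ValueError('window must be between 1 and len(losses)')
    # Shape: (len(losses) - window + 1, window)
    windows = np.lib.stride_tricks.sliding_window_view(losses, window)
    return windows.var(axis=1)


# Synthetic curves, for illustration only: a smooth decay, and the same decay
# with a short spike of the kind an abrupt phase change might cause.
smooth_losses = np.exp(-np.linspace(0.0, 3.0, 200))
spiky_losses = smooth_losses.copy()
spiky_losses[100:105] += 0.5

print('peak rolling variance (smooth):', rolling_variance(smooth_losses, 20).max())
print('peak rolling variance (spiky): ', rolling_variance(spiky_losses, 20).max())
```

A rolling window keeps the metric local in time, so a spike at a phase boundary shows up as a bump in the variance curve instead of being averaged away over the whole run.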

In [None]:
from __future__ import annotations

In [None]:
import logging

from utils.logging import SimpleLoggingConfig

# Configure logging
logging_config = SimpleLoggingConfig().info('notebook', 'utils', 'mini', 'ex_color')
logging_config.apply()
log = logging.getLogger('notebook')

## Model architecture

We'll use the same 4-dimensional bottleneck MLP as in Experiment 1.5.

In [None]:
import torch.nn as nn
from torch import Tensor

E = 4  # Bottleneck dimension


class ColorMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(3, 16),
            nn.GELU(),
            nn.Linear(16, E),
        )
        self.decoder = nn.Sequential(
            nn.Linear(E, 16),
            nn.GELU(),
            nn.Linear(16, 3),
            nn.Sigmoid(),
        )

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        bottleneck = self.encoder(x)
        output = self.decoder(bottleneck)
        return output, bottleneck


# Instantiate once to check params, will re-instantiate for each run
temp_model = ColorMLP()
total_params = sum(p.numel() for p in temp_model.parameters() if p.requires_grad)
log.info(f'Model initialized with {total_params:,} trainable parameters.')
del temp_model

## Data loading

We need the same datasets as in Ex 1.5:
- `primary_dataset`: Just the 6 primary and secondary colors (used for anchoring).
- `full_dataset`: The full grid of H, S, V colors (used for general training).
- `rgb_tensor`: A dense grid covering the RGB cube (used for validation and visualization).

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset

from ex_color.losses import Separate, planarity, objective, unitary
from ex_color.data.color_cube import ColorCube
from ex_color.data.cyclic import arange_cyclic

# --- Primary Colors (for Anchor) ---
primary_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=1 / 6),
    s=np.ones(1),
    v=np.ones(1),
)
primary_tensor = torch.tensor(primary_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
primary_dataset = TensorDataset(primary_tensor)
# Batch size = all primary colors, as Anchor processes them together
primary_loader = DataLoader(primary_dataset, batch_size=len(primary_tensor))
log.info(f'Primary dataset shape: {primary_tensor.shape}')

# --- Full HSV Grid (for Training) ---
full_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=10 / 360),
    s=np.linspace(0, 1, 10),
    v=np.linspace(0, 1, 10),
)
full_tensor = torch.tensor(full_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
full_dataset = TensorDataset(full_tensor)
# Sampler will be configured per run
log.info(f'Full HSV dataset shape: {full_tensor.shape}')

# --- Full RGB Grid (for Validation/Visualization) ---
rgb_cube = ColorCube.from_rgb(
    r=np.linspace(0, 1, 10),
    g=np.linspace(0, 1, 10),
    b=np.linspace(0, 1, 10),
)
rgb_tensor = torch.tensor(rgb_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
log.info(f'RGB validation tensor shape: {rgb_tensor.shape}')

# --- Datasets Dictionary (used by training loop) ---
# We'll create the specific loaders/samplers within each run setup
datasets_config: dict[str, tuple[TensorDataset, Tensor]] = {
    'Primary & secondary': (primary_dataset, primary_tensor),
    'All hues': (full_dataset, rgb_tensor),
    'Full color space': (full_dataset, rgb_tensor),
}

# --- Loss Criteria ---
loss_criteria = {
    'loss-recon': objective(nn.MSELoss()),
    'reg-separate': Separate((0, 1)),  # Separate based on first two dims
    'reg-planar': planarity,  # Penalize dims 2 and 3
    'reg-norm': unitary,
    # Anchor needs to be instantiated per run with fresh ref_data
    # 'reg-anchor': Anchor(ref_data=primary_tensor),
}

## Event handlers and callbacks

We'll reuse the `ModelRecorder` and `MetricsRecorder` from the engine. We also need a callback that updates the `DynamicWeightedRandomBatchSampler` from the `data-fraction` property in the dopesheet.

In [None]:
from ex_color.data.cube_sampler import DynamicWeightedRandomBatchSampler
from ex_color.data.filters import levels
from ex_color.engine.events import Event


# Callback to update the sampler weights based on dopesheet
def update_sampler_weights(event: Event, sampler: DynamicWeightedRandomBatchSampler, weights_source: np.ndarray):
    frac = event.timeline_state.props['data-fraction']
    # Interpolation logic from Ex 1.5 to shift focus from vibrant to all colors
    in_low = np.interp(frac, [0, 0.5], [0.99, 0])
    out_low = np.interp(frac, [0.5, 1], [0, 1])
    sampler.weights = levels(weights_source, in_low=in_low, out_low=out_low)
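
The two `np.interp` calls in the callback above split the `data-fraction` range into two halves: the first half lowers `in_low` from 0.99 to 0, and the second half raises `out_low` from 0 to 1. The sketch below tabulates just those two values, with no dependency on `ex_color`. The helper name `level_bounds` is hypothetical, and how `levels` turns these bounds into sampling weights is not shown here.

```python
import numpy as np


def level_bounds(frac: float) -> tuple[float, float]:
    """Reproduce the two interpolations from `update_sampler_weights`."""
    # np.interp clamps outside the given x-range, so in_low stays at 0 for frac > 0.5
    in_low = float(np.interp(frac, [0, 0.5], [0.99, 0]))
    # ...and out_low stays at 0 for frac < 0.5
    out_low = float(np.interp(frac, [0.5, 1], [0, 1]))
    return in_low, out_low


for frac in (0.0, 0.25, 0.5, 0.75, 1.0):
    in_low, out_low = level_bounds(frac)
    print(f'data-fraction={frac:.2f}  in_low={in_low:.3f}  out_low={out_low:.3f}')
```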

## Dopesheets for comparison

We define two dopesheets, stored in separate CSV files:
1. **Stepped:** [`ex-1.6-stepped-dopesheet.csv`](./ex-1.6-stepped-dopesheet.csv) - Mimics the abrupt phase changes and LR schedule from Ex 1.3.
2. **Smooth:** [`ex-1.6-smooth-dopesheet.csv`](./ex-1.6-smooth-dopesheet.csv) - Uses minimum jerk interpolation between keyframes, based on Ex 1.5.

In [None]:
from pathlib import Path

from mini.temporal.dopesheet import Dopesheet

dopesheet_dir = Path('.')  # Assumes notebook is run from docs/

# --- Load Dopesheets ---
stepped_dopesheet_path = dopesheet_dir / 'ex-1.6-stepped-dopesheet.csv'
smooth_dopesheet_path = dopesheet_dir / 'ex-1.6-smooth-dopesheet.csv'

stepped_dopesheet = Dopesheet.from_csv(stepped_dopesheet_path)
smooth_dopesheet = Dopesheet.from_csv(smooth_dopesheet_path)

log.info(f'Loaded stepped dopesheet from {stepped_dopesheet_path}')
log.info(f'Loaded smooth dopesheet from {smooth_dopesheet_path}')

# Display the resolved dopesheets (optional)
# display(stepped_dopesheet.as_df(styled=True))
# display(smooth_dopesheet.as_df(styled=True))

# --- Verify Initial Values ---
initial_stepped = stepped_dopesheet.get_initial_values()
initial_smooth = smooth_dopesheet.get_initial_values()
log.info(f'Initial values (Stepped): {initial_stepped}')
log.info(f'Initial values (Smooth): {initial_smooth}')

# Check that they are identical (they should be)
assert initial_stepped == initial_smooth, 'Initial values differ between dopesheets!'