In [None]:
import torch

from NeuroVisualizer.neuro_aux.AEmodel import UniformAutoencoder

from helper.neuro_viz import get_dataloader_flat

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Neuro-Visualizer
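This notebook trains (or loads) a NeuroVisualizer autoencoder on saved training checkpoints and uses it to plot the loss landscape together with the training trajectories.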

In [None]:
# MNIST comparison. The configuration cells below reassign the same variables,
# so run only the one you want to visualize.
dataset_name = 'mnist'

# (run id, plot title) pairs -- kept together so the two lists cannot drift apart.
run_ids, titles = map(list, zip(
    ("run-0011-CNN_mnist_32_0.9776", "SAM, 0.9776"),  # No Residual
    ("run-0012-CNN_mnist_32_0.9768", "SGD, 0.9768"),  # No Residual
))

In [None]:
# CNN x CIFAR 10
dataset_name = 'cifar10'

# (run id, plot title) pairs -- one SAM and one SGD run per seed.
run_ids, titles = map(list, zip(
    ("run-0017-CNN_cifar10_128_0.8072", "Seed 42, SAM, 0.8072"),
    ("run-0019-CNN_cifar10_128_0.8487", "Seed 42, SGD, 0.8487"),
    ("run-0021-CNN_cifar10_128_0.8054", "Seed 11, SAM, 0.8054"),
    ("run-0023-CNN_cifar10_128_0.8509", "Seed 11, SGD, 0.8509"),
    ("run-0025-CNN_cifar10_128_0.8062", "Seed 6, SAM, 0.8062"),
    ("run-0027-CNN_cifar10_128_0.8503", "Seed 6, SGD, 0.8503"),
))

In [None]:
# CNN Residual x CIFAR 10
dataset_name = 'cifar10'

# (run id, plot title) pairs -- one SAM and one SGD run per seed, all with residual connections.
# Keeping each id next to its title prevents the two lists from getting out of sync.
run_ids, titles = map(list, zip(
    ("run-0016-CNN_cifar10_128_0.8093", "Seed 42, SAM, Residual 0.8093"),
    ("run-0018-CNN_cifar10_128_0.8499", "Seed 42, SGD, Residual 0.8499"),
    ("run-0020-CNN_cifar10_128_0.8079", "Seed 11, SAM, Residual 0.8079"),
    ("run-0022-CNN_cifar10_128_0.8519", "Seed 11, SGD, Residual 0.8519"),
    ("run-0024-CNN_cifar10_128_0.8062", "Seed 6, SAM, Residual 0.8062"),
    ("run-0026-CNN_cifar10_128_0.8504", "Seed 6, SGD, Residual 0.8504"),
))

### Load Paths

In [None]:
from helper.visualization import Run

# One Run handle per configured run id (all from the same dataset).
runs = [Run(run_id, dataset_name) for run_id in run_ids]

In [None]:
# Every checkpoint file reported by each run (the filter cell below may narrow this down).
pt_files_per_run = []
for current_run in runs:
    pt_files_per_run.append(current_run.get_pt_files())

In [None]:
import os

# Name the AE checkpoint after the flattened-weight directories of all runs in this comparison.
vis_id = ' x '.join(run.results["ll_flattened_weights_dir"] for run in runs)
model_file = f'ae_models/{vis_id}.pt'
# Create the checkpoint folder up front so saving the AE does not fail on a fresh checkout.
os.makedirs(os.path.dirname(model_file), exist_ok=True)
print(model_file)

#model_file = "ae_models/run-0016-CNN x run-0018-CNN x run-0020-CNN x run-0022-CNN x run-0024-CNN x run-0026-CNN.pt"

In [None]:
# Inspect the per-run checkpoint file lists before filtering.
pt_files_per_run

In [None]:
# Filter for final epochs only: for each run, keep the checkpoints starting at the first
# epoch whose validation loss is within LOSS_TOLERANCE x (best validation loss).
LOSS_TOLERANCE = 1.1

pt_files_per_run = []

for run in runs:
    val_losses = run.results["val_losses"]
    pt_files = run.get_pt_files()
    if len(val_losses) == 0:
        # No logged losses to compare against: keep every checkpoint.
        pt_files_per_run.append(pt_files)
        continue

    max_visualize = min(val_losses) * LOSS_TOLERANCE
    print(max_visualize)
    # Index (into val_losses) of the first epoch at or below the threshold.
    idx = next((i for i, v in enumerate(val_losses) if v <= max_visualize), None)
    # The checkpoint list has extra leading file(s) without a logged loss (one epoch-0
    # checkpoint per the padding cell further below), so shift the index to land on the
    # checkpoint that belongs to val_losses[idx].
    # NOTE(review): assumes the extra files come first -- confirm against Run.get_pt_files.
    offset = max(len(pt_files) - len(val_losses), 0)
    start = 0 if idx is None else idx + offset
    pt_files_per_run.append(pt_files[start:])

pt_files_per_run

## Train AE Model
Run this section to train the autoencoder (AE) model.

In [None]:
# Mini-batch size for AE training (original note: 4 - 32).
batch_size = 3

loader, normalizer = get_dataloader_flat(
    pt_files_per_run,
    batch_size,
    include_lmc=False,
    shuffle=True,
    # When True, draws more samples from later epochs (which diverge more).
    oversample_later=False,
    power=1.0,
)

In [None]:
# Release unused cached GPU memory held by PyTorch's allocator before the AE is built.
torch.cuda.empty_cache()

Adjust: choose hidden dimensions that still fit into GPU memory for this model.

In [None]:
# Length of one training sample from the loader (presumably one flattened checkpoint).
input_dim = loader.dataset[0].shape[0]
print(f"Input dimension: {input_dim}")

# 2-D latent space (so it can be plotted) and 4 layers.
latent_dim, num_layers = 2, 4

# Alternative with explicit hidden sizes for more aggressive compression (scales with the
# first hidden dim), e.g.:
#   h = [input_dim, 64, 32, 8]
#   h = [input_dim, 126, 64, 32]
#   h = [input_dim, 200, 100, 50]
#   ae = UniformAutoencoder(input_dim, num_layers, latent_dim, h=h).to(device)
ae = UniformAutoencoder(input_dim, num_layers, latent_dim).to(device)

In [None]:
# Parameter counts and approximate in-memory size of the autoencoder.
total_params = sum(p.numel() for p in ae.parameters())
trainable_params = sum(p.numel() for p in ae.parameters() if p.requires_grad)

# Use each tensor's actual element size rather than assuming 4-byte (float32) weights,
# so the estimate stays correct for half/bfloat16 models too.
size_mb = sum(p.numel() * p.element_size() for p in ae.parameters()) / (1024**2)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Approx. size: {size_mb:.2f} MB")

In [None]:
# Load from previous train process (if available, eg. after a crash)
import os

# Resume from a previous training run (e.g. after a crash) only if a checkpoint exists;
# on the very first run there is nothing to load.
# map_location lets a checkpoint saved from a GPU session be restored on a CPU-only machine.
if os.path.exists(model_file):
    ae.load_state_dict(torch.load(model_file, map_location=device, weights_only=True))
    print(f"Loaded AE weights from {model_file}")
else:
    print(f"No AE checkpoint at {model_file}; starting from freshly initialized weights.")

In [None]:
from helper.neuro_viz import train_autoencoder

# Train the AE; save_path points the helper at model_file for its checkpoints.
trained_model = train_autoencoder(
    model=ae,
    train_loader=loader,
    device=device,
    save_path=model_file,
    num_epochs=100,           # 1000 would be better if time allows
    lr=0.0001,                # original note: start with 0.01
    patience=15,
    avoid_overheat=False,     # enable to avoid crashes on the Nembus computer
    last_saved_loss=0.40143,  # minimum loss to save
    save_delta_pct=0.02,
    verbose=True,
)
# Reference point: a loss of ~0.0173 is reachable for the CIFAR10 CNN runs.

In [None]:
# Persist the current AE weights to model_file.
# NOTE(review): train_autoencoder was given the same save_path; if it only saves on improvement,
# this unconditional save may replace the best checkpoint with the last one -- confirm.
torch.save(ae.state_dict(), model_file)

## Visualize Trajectory
Start here if a trained autoencoder (AE) can already be loaded (the configuration and *Load Paths* cells above must still be run first).

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from NeuroVisualizer.neuro_aux.AEmodel import UniformAutoencoder
from NeuroVisualizer.neuro_aux.utils import get_files, repopulate_model
from NeuroVisualizer.neuro_aux.trajectories_data import get_trajectory_dataloader

In [None]:
# Settings for projecting the trajectory and evaluating the decoded models.
loss_name = 'test_loss'
# NOTE(review): the original comment says 'mse' selects CrossEntropyLoss -- verify in helper.neuro_viz.Loss.
whichloss = 'mse'
batch_size = 4
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Get file list
# pt_files = get_files(model_folder, prefix="model-")

# Load the trained AE. Its input size is the length of one stored checkpoint tensor; that
# checkpoint is only inspected for its shape, so load it onto the CPU.
example_tensor = torch.load(pt_files_per_run[0][0], map_location="cpu", weights_only=True)
input_dim = example_tensor.shape[0]
# Must match the settings the AE was trained with.
latent_dim = 2
num_layers = 4
# Alternatives with explicit hidden sizes (use the same h as during training):
#h = [input_dim, 64, 32, 8]
#h = [input_dim, 128, 64, 16]
#h = [input_dim, 200, 100, 50]
#ae_model = UniformAutoencoder(input_dim, num_layers, latent_dim, h=h).to(device)

ae_model = UniformAutoencoder(input_dim, num_layers, latent_dim).to(device)
# map_location lets an AE saved from a GPU session be restored on a CPU-only machine.
ae_model.load_state_dict(torch.load(model_file, map_location=device, weights_only=True))
_ = ae_model.eval()

In [None]:
# ---- Load data ----
from helper.neuro_viz import get_dataloader_flat

# Unshuffled loader over all selected checkpoints, plus the normalization it applied
# (its mean/std are used later to de-normalize decoded weights).
# For a quick test pass a subset instead, e.g. pt_files_per_run[:5].
trajectory_loader, transform = get_dataloader_flat(
    pt_files_per_run,
    batch_size,
    shuffle=False,
)

### Repopulate original Model Architecture
**IMPORTANT:** the model below must have the same architecture as the one the checkpoints were trained with.

In [None]:
# Print the logged architecture info of every run so the model below can be matched to it.
for run in runs:
    model_info = run.results["model_info"]
    print(model_info)

In [None]:
from helper.vision_classification import init_mlp_for_dataset, init_cnn_for_dataset
from helper.neuro_viz import Loss

# TODO: the architecture must match the one the checkpoints were trained with
#       (compare with the model_info printed above).
model = init_cnn_for_dataset(
    dataset_name,
    conv_dims=[8, 16],
    kernel_sizes=[3, 3],
    hidden_dims=[32],
    dropout=0.25,
    residual=False,
).to(device)

# Other architectures used in the experiments:
#model = init_cnn_for_dataset(dataset_name, conv_dims=[8, 16], kernel_sizes=[3, 3], hidden_dims=[32], dropout=0.25, residual=True).to(device)
#model = init_cnn_for_dataset(dataset_name, conv_dims=[32, 64], kernel_sizes=[3, 3], hidden_dims=[128], dropout=0.25, residual=False).to(device)
#model = init_cnn_for_dataset(dataset_name, conv_dims=[32, 64], kernel_sizes=[3, 3], hidden_dims=[128], dropout=0.25, residual=True).to(device)
#model = init_cnn_for_dataset(dataset_name, conv_dims=[64, 128, 256], kernel_sizes=[5, 3, 3], hidden_dims=[256, 128], dropout=0.2, residual=True).to(device)
#model = init_mlp_for_dataset(dataset_name, hidden_dims=[254, 64], dropout=0.1).to(device)

loss_obj = Loss(dataset_name, device)

In [None]:
def count_parameters(model, trainable_only=True):
    """Return the number of parameters of ``model``.

    Args:
        model: Any ``torch.nn.Module``.
        trainable_only: Count only parameters with ``requires_grad`` (the default, as before);
            pass ``False`` to count every parameter.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad or not trainable_only)

print(f"Trainable parameters: {count_parameters(model):,}")
ae_input_dim = ae_model.encoder.fcs[0].in_features
print(f"AE input/output dim:  {ae_input_dim:,}")

# The trajectory code requires every decoded AE vector to fill *all* model parameters, so a
# wrong architecture choice in the previous cell is reported here instead of mid-computation.
total_model_params = count_parameters(model, trainable_only=False)
if total_model_params != ae_input_dim:
    raise ValueError(
        f"Model has {total_model_params:,} parameters but the AE expects {ae_input_dim:,}: "
        "check the architecture chosen in the previous cell."
    )

#### Compute trajectory (Coordinates and Loss)

In [None]:
from helper.neuro_viz import compute_trajectory

# Project every stored checkpoint into the AE latent space and evaluate the decoded weights.
(
    trajectory_coordinates,
    trajectory_models,
    trajectory_losses,
) = compute_trajectory(
    trajectory_loader, ae_model, transform, loss_obj, model,
    loss_name, whichloss, device,
    recalibrate_bn=True,  # refresh BN statistics for more precise losses
)

In [None]:
import torch
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector
from tqdm import tqdm
from helper.neuro_viz import repopulate_model_fixed

def _flatten_model_vec(m: torch.nn.Module) -> torch.Tensor:
    """Concatenate all parameters of ``m`` into one detached 1-D tensor.

    Only parameters are included (``parameters_to_vector``), not buffers such as BatchNorm
    running statistics. Replace with a project-specific flattener if the AE was trained on a
    different layout.
    """
    return parameters_to_vector(m.parameters()).detach()


def _batchnorm_layers(m: torch.nn.Module):
    """Return every BatchNorm layer (1d/2d/3d) contained in ``m``."""
    return [mod for mod in m.modules() if isinstance(mod, torch.nn.modules.batchnorm._BatchNorm)]


def _recalibrate_bn(model: torch.nn.Module, loader, num_batches: int, model_device) -> None:
    """Re-estimate the BatchNorm running statistics of ``model`` from up to ``num_batches`` batches.

    Mirrors ``torch.optim.swa_utils.update_bn``: running stats are reset and accumulated as a
    plain average (``momentum=None``), so statistics of previously loaded weights do not leak
    into the estimate. The original momenta are restored and the model is left in eval mode.
    """
    saved_momenta = {}
    for bn in _batchnorm_layers(model):
        bn.reset_running_stats()
        saved_momenta[bn] = bn.momentum
        bn.momentum = None
    model.train()
    try:
        with torch.no_grad():
            for b_idx, batch in enumerate(loader):
                # Data loaders usually yield (inputs, targets); only the inputs are needed.
                x = batch[0] if (isinstance(batch, (list, tuple)) and len(batch) >= 1) else batch
                model(x.to(model_device, non_blocking=True))
                if b_idx + 1 >= num_batches:
                    break
    finally:
        for bn, momentum in saved_momenta.items():
            bn.momentum = momentum
        model.eval()


def _decode_trajectory(trajectory_loader, ae_model, mean, std, device):
    """Run every checkpoint batch through the AE.

    Returns (latent codes, de-normalized decoded weights [N, D], per-sample AE reconstruction
    MSE [N]), all on the CPU.

    Raises:
        ValueError: If the loader yields no batches.
    """
    coords, decoded, ae_mse = [], [], []
    with torch.no_grad():
        for batch in tqdm(trajectory_loader, desc="Decoding trajectory"):
            batch = batch.to(device)               # normalized flattened weights, [B, D]
            x_recon_norm, z = ae_model(batch)      # reconstruction (normalized space), latent code
            ae_mse.append(F.mse_loss(x_recon_norm, batch, reduction='none').mean(dim=1).cpu())
            coords.append(z.cpu())
            decoded.append((x_recon_norm * std + mean).cpu())  # back to the raw weight scale
    if not coords:
        raise ValueError("trajectory_loader yielded no batches; nothing to decode.")
    return torch.cat(coords, dim=0), torch.cat(decoded, dim=0), torch.cat(ae_mse, dim=0).float()


def compute_trajectory(
    trajectory_loader,
    ae_model,
    transform,
    loss_obj,
    model,
    loss_name,
    whichloss,
    device,
    recalibrate_bn: bool = True,
    bn_recal_batches: int = 100,
    bn_loader=None,
    return_ae_losses: bool = False,
):
    """Project stored checkpoints through the AE and evaluate the decoded models.

    Args:
        trajectory_loader: Unshuffled loader of normalized, flattened checkpoint weights.
        ae_model: Autoencoder whose forward returns ``(reconstruction, latent)``.
        transform: Normalizer with ``mean`` / ``std`` tensors.
        loss_obj: Provides ``get_loss(model, loss_name, whichloss)`` and, optionally, a
            ``train_loader`` of input data used for BN recalibration.
        model: Network whose parameters are overwritten with each decoded weight vector.
        loss_name, whichloss: Forwarded to ``loss_obj.get_loss``.
        device: Device for the AE computations.
        recalibrate_bn: Re-estimate BatchNorm statistics for every decoded model
            (skipped automatically when the model has no BatchNorm layers).
        bn_recal_batches: Maximum number of batches per BN recalibration.
        bn_loader: Input-data loader for BN recalibration; defaults to ``loss_obj.train_loader``.
        return_ae_losses: Also return the two AE-loss tensors (default keeps the 3-tuple).

    Returns:
        trajectory_coordinates: latent codes, one row per checkpoint.
        trajectory_models:     [N, D] decoded, de-normalized weights.
        trajectory_losses:     [N] task loss of each repopulated model.
        ae_losses_decode:      [N] AE MSE on the normalized inputs (only with ``return_ae_losses``).
        ae_losses_finetuned:   [N] AE MSE on the re-flattened model after repopulation / BN
                               recalibration (only with ``return_ae_losses``).

    Raises:
        ValueError: If the loader is empty or a decoded vector does not match the model size.
    """
    if not recalibrate_bn:
        print("Tip: set recalibrate_bn=True to refresh BN stats for more accurate losses.")

    ae_model.eval()
    model = model.to(device)
    model_device = next(model.parameters()).device
    total_params = sum(p.numel() for p in model.parameters())

    mean = transform.mean.to(device)
    std = transform.std.to(device)

    # ---- Decode trajectory + AE loss (no fine-tuning) ----
    trajectory_coordinates, trajectory_models, ae_losses_decode = _decode_trajectory(
        trajectory_loader, ae_model, mean, std, device
    )
    print(f"✅ Decoded trajectory shapes: coords {trajectory_coordinates.shape}, models {trajectory_models.shape}")

    # BN recalibration needs real *input data*. The trajectory loader yields flattened
    # weights, so it must never be used as a fallback here.
    bn_src = bn_loader if bn_loader is not None else getattr(loss_obj, "train_loader", None)
    do_bn_recal = recalibrate_bn and bool(_batchnorm_layers(model))
    if do_bn_recal and bn_src is None:
        print("Warning: no bn_loader given and loss_obj has no train_loader; skipping BN recalibration.")
        do_bn_recal = False

    # ---- Task losses + AE loss after repopulation ----
    trajectory_losses, ae_losses_finetuned = [], []
    for i in tqdm(range(trajectory_models.shape[0]), desc="Computing trajectory & AE(finetuned) losses"):
        flat_cpu = trajectory_models[i]
        if flat_cpu.numel() != total_params:
            raise ValueError(
                f"Mismatch in parameter size: decoded vector has {flat_cpu.numel()} values, "
                f"model has {total_params} parameters."
            )

        with torch.no_grad():
            model = repopulate_model_fixed(flat_cpu.to(model_device, non_blocking=True), model)
        # Always evaluate in eval mode (disables dropout), also when BN recalibration is off.
        model.eval()

        if do_bn_recal:
            _recalibrate_bn(model, bn_src, bn_recal_batches, model_device)

        with torch.no_grad():
            trajectory_losses.append(loss_obj.get_loss(model, loss_name, whichloss).item())

            # AE loss on the re-flattened model, normalized back into the AE input space.
            flat_ft_norm = (_flatten_model_vec(model).to(device).view(1, -1) - mean) / std
            recon_ft_norm, _ = ae_model(flat_ft_norm)
            ae_losses_finetuned.append(F.mse_loss(recon_ft_norm, flat_ft_norm).item())

    trajectory_losses = torch.tensor(trajectory_losses, dtype=torch.float32)
    ae_losses_finetuned = torch.tensor(ae_losses_finetuned, dtype=torch.float32)

    print(f"✅ Computed {trajectory_losses.shape[0]} task losses\n"
          f"AE(decode) losses: {ae_losses_decode.mean()}, "
          f"AE(finetuned) losses: {ae_losses_finetuned.mean()}")

    if return_ae_losses:
        return (trajectory_coordinates, trajectory_models, trajectory_losses,
                ae_losses_decode, ae_losses_finetuned)
    return trajectory_coordinates, trajectory_models, trajectory_losses

# Also unpack the AE losses so later cells can inspect them.
(
    trajectory_coordinates,
    trajectory_models,
    trajectory_losses,
    ae_losses_decode,
    ae_losses_finetuned,
) = compute_trajectory(
    trajectory_loader,
    ae_model,
    transform,
    loss_obj,
    model,
    loss_name,
    whichloss,
    device,
    recalibrate_bn=True, # Optimizes Loss precision
    return_ae_losses=True,
)

In [None]:
# Mean AE reconstruction loss over the trajectory: on the stored (normalized) checkpoints,
# and on the model re-flattened after repopulation from the decoded weights.
# NOTE(review): both names must exist in the kernel namespace -- inside compute_trajectory they
# are local variables, so the calling cell has to unpack them from its return value.
print(ae_losses_decode.mean())
print(ae_losses_finetuned.mean())

In [None]:
# Number of projected checkpoints per run; the trajectory tensors hold the runs back to back
# (the trajectory loader is unshuffled over pt_files_per_run).
chunk_sizes = [len(files) for files in pt_files_per_run]
num_chunks = len(chunk_sizes)

# np.split silently produces empty/misaligned chunks if the sizes do not add up.
if sum(chunk_sizes) != len(trajectory_losses):
    raise ValueError(
        f"Checkpoint counts sum to {sum(chunk_sizes)} but the trajectory has "
        f"{len(trajectory_losses)} points; re-run the trajectory cells after changing the file lists."
    )

split_points = np.cumsum(chunk_sizes)[:-1]
tr_losses = np.split(trajectory_losses.cpu().numpy(), split_points)
tr_coordinates = np.split(trajectory_coordinates.cpu().numpy(), split_points)

In [None]:
# Logged validation loss per epoch, one list per run.
real_losses = []
for logged_run in runs:
    real_losses.append(logged_run.results["val_losses"])

In [None]:
# Align the logged validation losses with the checkpoints that were actually projected.
# The checkpoint list contains one extra epoch-0 file without a logged loss, so pad with NaN
# at the front, then keep the last chunk_sizes[i] values; this also matches checkpoint lists
# shortened by the "final epochs" filter.
# Rebuilt from run.results each time, so re-running this cell does not prepend more NaNs.
# (np.nan, not np.NaN: the capitalized alias was removed in NumPy 2.0.)
real_losses = []
for run, n_points in zip(runs, chunk_sizes):
    padded = np.concatenate(([np.nan], np.asarray(run.results["val_losses"], dtype=float)))
    real_losses.append(padded[max(len(padded) - n_points, 0):])

In [None]:
# Show the titles used for the per-run subplots and the landscape labels below.
titles

In [None]:
import matplotlib.pyplot as plt

# One panel per run comparing logged and AE-projected validation loss, two panels per row.
cols = 2
rows = int(np.ceil(num_chunks / cols))

fig, axes = plt.subplots(rows, cols, figsize=(14, rows * 4), squeeze=False)
flat_axes = axes.ravel()  # row-major, same order as divmod(i, cols)

for i in range(num_chunks):
    ax = flat_axes[i]
    ax.plot(real_losses[i], label='Logged Validation Loss', marker='o')
    ax.plot(tr_losses[i], label='AE-Projected Validation Loss', marker='x')
    ax.set_title(titles[i])
    ax.set_xlabel('Checkpoint Index')
    ax.set_ylabel('Loss (Cross Entropy)')
    ax.grid(True)
    ax.legend()

# Drop the empty panel(s) left over when the number of runs is odd.
for ax in flat_axes[num_chunks:]:
    fig.delaxes(ax)

plt.tight_layout()
plt.show()

In [None]:
# Generate grid in latent space
from helper.neuro_viz import generate_latent_grid, compute_grid_losses, compute_grid_losses_batched
# Regular grid in the latent plane (min_map/max_map presumably bound both axes);
# xnum sets the resolution, 3 - 25 being a practical range.
xx, yy, grid_coords = generate_latent_grid(
    min_map=-1.1,
    max_map=1.1,
    xnum=10,
    device=device,
)

# Decode each grid point into model weights and evaluate its loss
# (bn_recal_batches presumably caps the BN-recalibration batches per point).
grid_losses = compute_grid_losses_batched(
    grid_coords, transform, ae_model, model, loss_obj, loss_name, whichloss, device,
    bn_recal_batches=30,
)

# Restore the 2-D grid layout expected by the contour plots.
grid_losses = grid_losses.view(xx.shape)

In [None]:
# Smallest and largest grid loss (a quick sanity check for the plot's colour scale).
print(grid_losses.min().item(), grid_losses.max().item())

In [None]:
# Decode every grid point into de-normalized model weights.
# no_grad: the decoded weights are only used as data; without it autograd keeps the whole
# decoder graph for every grid point alive, which is what exhausts GPU memory.
with torch.no_grad():
    rec_grid_models = ae_model.decoder(grid_coords)
    rec_grid_models = rec_grid_models * transform.std.to(device) + transform.mean.to(device)

If the cell above runs out of CUDA memory, decode the grid in batches instead:

In [None]:
def decode_grid_in_batches(ae_model, grid_coords, transform, device, batch_size=32):
    """Decode latent grid points into de-normalized weight vectors, a chunk at a time.

    Memory-friendly alternative to decoding the whole grid at once: each chunk is decoded
    under ``torch.no_grad()`` and moved to the CPU right away.

    Args:
        ae_model: Autoencoder exposing a ``decoder`` module; it is switched to eval mode.
        grid_coords: Latent coordinates, one row per grid point.
        transform: Object with ``mean`` and ``std`` tensors used to undo the input normalization.
        device: Device the decoder runs on.
        batch_size: Grid points decoded per chunk (must be >= 1).

    Returns:
        CPU tensor with one decoded, de-normalized weight vector per grid point.

    Raises:
        ValueError: If ``batch_size < 1`` or ``grid_coords`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if grid_coords.size(0) == 0:
        raise ValueError("grid_coords is empty; nothing to decode")

    ae_model.eval()
    std = transform.std.to(device)
    mean = transform.mean.to(device)

    with torch.no_grad():
        # The GPU result of each chunk is a temporary; only its CPU copy is kept.
        chunks = [
            (ae_model.decoder(coords.to(device)) * std + mean).cpu()
            for coords in grid_coords.split(batch_size)
        ]

    # Release cached GPU blocks once at the end: between chunks the allocator simply reuses
    # the memory freed by the previous chunk, so emptying the cache each time only adds syncs.
    if torch.device(device).type == "cuda":
        torch.cuda.empty_cache()

    return torch.cat(chunks, dim=0)

In [None]:
# Batched decode of the whole grid; the result lives on the CPU.
rec_grid_models = decode_grid_in_batches(
    ae_model=ae_model,
    grid_coords=grid_coords,
    transform=transform,
    device=device,
    batch_size=16,
)

In [None]:
from helper.neuro_viz import plot_loss_landscape

# Landscape with all trajectories; points are coloured by the logged losses
# (pass tr_losses instead to colour them by the AE-projected losses).
fig = plot_loss_landscape(
    xx, yy, grid_losses, real_losses, tr_coordinates,
    rec_grid_models=rec_grid_models,
    draw_density=False,
    filled_contours=False,
)

In [None]:
# Save the current landscape figure as a PDF named after this comparison.
pdf_path = f'plots/loss_landscape_{vis_id}.pdf'
os.makedirs('plots', exist_ok=True)
fig.savefig(pdf_path, dpi=300, bbox_inches='tight', format='pdf')
print(f"Saved PDF to {pdf_path}")

In [None]:
# Same landscape, with each trajectory labelled by its title at its last point.
# One (ha, va) pair per trajectory: ha in 'left'|'center'|'right', va in 'top'|'center'|'bottom'.
fig = plot_loss_landscape(
    xx, yy, grid_losses, real_losses, tr_coordinates,
    rec_grid_models=rec_grid_models,
    draw_density=False,
    filled_contours=False,
    trajectory_labels=titles,
    label_positions=[
        ('left', 'top'),
        ('left', 'bottom'),
        ('right', 'bottom'),
        ('right', 'top'),
        ('left', 'center'),
        ('center', 'bottom'),
    ],
)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.ticker as ticker
import numpy as np

def plot_loss_landscape(
    xx, yy,
    grid_losses, trajectory_losses_list, trajectory_coords_list,
    rec_grid_models=None,
    draw_density=True,
    filled_contours=True,
    cmap='viridis',
    loss_label='Cross Entropy Loss',
    trajectory_labels=None,
    label_positions=None,
):
    """Plot the AE latent-space loss landscape with one or more training trajectories.

    Args:
        xx, yy: Latent grid coordinate tensors (converted via ``.cpu().numpy()``).
        grid_losses: Tensor of losses on that grid, same shape as ``xx``.
        trajectory_losses_list: Per-trajectory loss sequences used to colour the points.
            NaN / inf entries (e.g. the NaN padding for epoch 0) are ignored when choosing
            the colour scale.
        trajectory_coords_list: Per-trajectory latent coordinates; columns 0 and 1 are plotted.
        rec_grid_models: Decoded grid weights; only used when ``draw_density`` is True.
        draw_density: Overlay density contours from ``get_density`` (best effort).
        filled_contours: Use ``contourf`` instead of labelled line contours.
        cmap: Colormap shared by the landscape and the trajectory points.
        loss_label: Colorbar label.
        trajectory_labels: Text placed at each trajectory's last point (default ``"traj {i}"``).
        label_positions: Per-trajectory ``(ha, va)`` tuples or ``'auto'``; missing entries
            fall back to ``'auto'``.

    Returns:
        The matplotlib Figure (not shown here).

    Raises:
        ValueError: If there is no finite loss value at all.
    """
    # === PREPARE LOSSES ===
    grid_losses_np = grid_losses.detach().cpu().numpy()

    # === SHARED (LOG) COLOR SCALE ===
    # NaN/inf-aware: a plain min()/max() over NaN-padded logged losses returns NaN and
    # would break the levels, the LogNorm and the colorbar.
    value_arrays = [grid_losses_np.ravel()]
    value_arrays += [np.asarray(t, dtype=float).ravel() for t in trajectory_losses_list]
    all_losses = np.concatenate(value_arrays)
    finite_losses = all_losses[np.isfinite(all_losses)]
    if finite_losses.size == 0:
        raise ValueError("No finite loss values to plot.")
    vmin = max(finite_losses.min() / 1.2, 1e-5)  # LogNorm needs a strictly positive lower bound
    vmax = finite_losses.max() * 1.2

    if vmin >= vmax or np.isclose(vmin, vmax):
        vmax = vmin * 10
        print(f"Adjusted nearly-constant losses: vmin={vmin}, vmax={vmax}")

    levels = np.logspace(np.log10(vmin), np.log10(vmax), 30)
    norm = LogNorm(vmin=vmin, vmax=vmax)

    # === BEGIN PLOTTING ===
    fig, ax = plt.subplots(figsize=(8, 6))

    # -- 1 Loss Landscape --
    X = xx.cpu().numpy()
    Y = yy.cpu().numpy()

    if filled_contours:
        contour = ax.contourf(X, Y, grid_losses_np, levels=levels, norm=norm, cmap=cmap)
    else:
        contour = ax.contour(X, Y, grid_losses_np, levels=levels, norm=norm, cmap=cmap)
        ax.clabel(contour, fmt="%.2e", fontsize=8)

    cbar = plt.colorbar(contour, ax=ax, shrink=0.8)
    cbar.set_ticks(np.logspace(np.log10(vmin), np.log10(vmax), 5))
    cbar.ax.set_ylabel(loss_label, fontsize=12)

    # -- 2 & 3: Trajectories: black polyline plus points coloured on the shared scale --
    for z, losses in zip(trajectory_coords_list, trajectory_losses_list):
        z = np.asarray(z)
        ax.plot(z[:, 0], z[:, 1], color='k', linewidth=1)
        ax.scatter(
            z[:, 0], z[:, 1],
            c=losses,
            cmap=cmap,
            norm=norm,
            s=40,
            edgecolors='k'
        )

    # ===== 3b: Annotate each trajectory at its last point =====
    offset_pts = 5  # how far, in points, to shift the label

    n_traj = len(trajectory_coords_list)
    if trajectory_labels is None:
        trajectory_labels = [f"traj {i}" for i in range(n_traj)]
    if label_positions is None:
        label_positions = []

    for idx, (z, lab) in enumerate(zip(trajectory_coords_list, trajectory_labels)):
        z = np.asarray(z)
        if len(z) == 0:
            continue  # nothing to label
        x_end, y_end = float(z[-1, 0]), float(z[-1, 1])

        # Decide alignment: explicit (ha, va), or follow the direction of the last step.
        pos = label_positions[idx] if idx < len(label_positions) else 'auto'
        if pos != 'auto':
            ha, va = pos
        elif len(z) >= 2:
            dx = z[-1, 0] - z[-2, 0]
            dy = z[-1, 1] - z[-2, 1]
            ha = 'left'   if dx >= 0 else 'right'
            va = 'bottom' if dy >= 0 else 'top'
        else:
            ha, va = 'left', 'bottom'  # single point: no direction to follow

        # Convert alignment into a point offset away from the end point.
        ox = offset_pts if ha == 'left' else (-offset_pts if ha == 'right' else 0)
        oy = offset_pts if va == 'bottom' else (-offset_pts if va == 'top' else 0)

        ax.annotate(
            lab,
            xy=(x_end, y_end),
            xytext=(ox, oy),
            textcoords='offset points',
            ha=ha, va=va,
            fontsize=7,
            bbox=dict(boxstyle="round,pad=0.2", fc="white", alpha=0.6),
            arrowprops=dict(arrowstyle='-', lw=0)
        )

    # -- 4 OPTIONAL: Density Contours (best effort; failures only skip the overlay) --
    if draw_density and rec_grid_models is not None:
        try:
            from NeuroVisualizer.neuro_aux.utils import get_density
            density = get_density(rec_grid_models.detach().cpu().numpy(), type='inverse', p=2)
            density = density.reshape(xx.shape)
            density_levels = np.logspace(
                np.log10(max(density.min(), 1e-3)),
                np.log10(density.max()),
                15
            )
            CS_density = ax.contour(
                X, Y, density,
                levels=density_levels,
                colors='white',
                linewidths=0.8
            )
            ax.clabel(CS_density, fmt=ticker.FormatStrFormatter('%.1f'), fontsize=7)
        except Exception as e:
            print("Density contour skipped:", e)

    # -- 5 Labels, Grid, Style --
    ax.set_title('Loss Landscape with Training Trajectory', fontsize=14)
    ax.set_xlabel('Latent Dimension 1', fontsize=12)
    ax.set_ylabel('Latent Dimension 2', fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.3)

    return fig