# Model Analysis 

This notebook takes a folder containing datasets and trained models and generates a series of images and videos describing how the loss landscape changes as the amount of additional training data varies.

## Imports

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch

import time
import glob

# Importing our existing funcs
import os
import sys
from pathlib import Path

from minima_volume.dataset_funcs import ( load_models_and_data )

from minima_volume.landscape_videos import (
    plot_2d_loss_landscape_3model,
    plot_3d_loss_landscape,
    load_computed_loss_landscapes,
)
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Model

In [2]:
# The model definition lives in a user-chosen module; change this import
# to analyse a different model family.
import swiss_model_data as model_module

# Swiss-specific architecture settings
input_dim, output_dim = 2, 1
hidden_dims = [32 for _ in range(5)]

# The seed is only a filler value for building the template
model_seed = 0

# Build the template model instance
model_template = model_module.get_model(
    input_dim=input_dim,
    hidden_dims=hidden_dims,
    output_dim=output_dim,
    device=device,
    seed=model_seed,
)

# Per-sample loss function supplied by the same module
loss_fn_per_sample = model_module.get_loss_fn_per_sample()

## Loading Model and Datasets

In [3]:
# Folder with the saved models and dataset, relative to the working directory
target_dir = Path("models_and_data")

# Load the models, their per-model additional data, and the dataset.
# Everything is loaded with device="cpu", regardless of `device`.
loaded_models, loaded_additional_data, loaded_dataset = load_models_and_data(
    model_template=model_template,
    target_dir=target_dir,
    device="cpu",
)

# Summarise the dataset
print(f"Dataset type: {loaded_dataset['dataset_type']}")
print(f"Dataset quantities: {loaded_dataset['dataset_quantities']}")

# Report each tensor's shape, or None when that entry is absent
print("\nTensor shapes:")
tensor_keys = ("x_base_train", "y_base_train", "x_additional", "y_additional", "x_test", "y_test")
for key in tensor_keys:
    entry = loaded_dataset[key]
    print(f"  {key}: {None if entry is None else entry.shape}")

Looking for models and dataset in: models_and_data
Found 6 model files:
  - model_additional_0.pt
  - model_additional_180.pt
  - model_additional_30.pt
  - model_additional_380.pt
  - model_additional_780.pt
  - model_additional_80.pt
✅ Model loaded into provided instance from models_and_data\model_additional_0.pt
Successfully loaded: model_additional_0.pt
✅ Model loaded into provided instance from models_and_data\model_additional_180.pt
Successfully loaded: model_additional_180.pt
✅ Model loaded into provided instance from models_and_data\model_additional_30.pt
Successfully loaded: model_additional_30.pt
✅ Model loaded into provided instance from models_and_data\model_additional_380.pt
Successfully loaded: model_additional_380.pt
✅ Model loaded into provided instance from models_and_data\model_additional_780.pt
Successfully loaded: model_additional_780.pt
✅ Model loaded into provided instance from models_and_data\model_additional_80.pt
Successfully loaded: model_additional_80.pt

Mod

## Plotting Parameters

Select which of the loaded models to use.

In [4]:
# Positions, within the loaded model list, of the three models to use
index1, index2, index3 = 0, 1, 4
index_list = [index1, index2, index3]

# Amount of additional data recorded for each selected model
selected_amounts = [loaded_additional_data[i]['additional_data'] for i in index_list]

print("Loading models trained with: " + ", ".join(f"{amount}" for amount in selected_amounts))

model_0, model_1, model_2 = (loaded_models[i] for i in index_list)

print("=== Selected Models ===")
dataset_type = loaded_dataset['dataset_type']

# One label per selected model, in the same order as index_list
model_labels = [
    f"Model Trained with {amount} {dataset_type}"
    for amount in selected_amounts
]

# Pull the base-train and additional tensors out of the loaded dataset
x_base_train, y_base_train, x_additional, y_additional = (
    loaded_dataset[name]
    for name in ('x_base_train', 'y_base_train', 'x_additional', 'y_additional')
)

Loading models trained with: 0, 180, 780
=== Selected Models ===


## Loss Landscape Parameters

The loss landscapes have already been computed; the parameters below are used to plot them.

In [5]:
# Storage folder; created here if it does not exist yet
varying_dataset_dir = Path("varying_dataset")
varying_dataset_dir.mkdir(exist_ok=True)

# Additional-data amounts to test: 0, 1, ..., len(x_additional)
# This can be varied depending on the use case!!
additional_data_amounts = range(0, len(x_additional) + 1)

# Square N x N grid over [span_low, span_high] on both axes
N = 80
span_low, span_high = -0.5, 1.5
x, y = (np.linspace(span_low, span_high, N) for _ in range(2))
C1_grid, C2_grid = np.meshgrid(x, y)
grid_coefficients = (C1_grid, C2_grid)

print(f"Grid size: {N}x{N}, Range: [{span_low}, {span_high}]")

Grid size: 80x80, Range: [-0.5, 1.5]


## Visualization

Load the precomputed loss landscapes from the storage folder, then visualize them.

In [6]:
# Request one landscape per entry of additional_data_amounts.
# Note: this rebinds C1_grid and C2_grid to the values returned by the loader.
landscape_indexes = range(len(additional_data_amounts))
loss_grids_list, C1_grid, C2_grid = load_computed_loss_landscapes(
    output_dir=varying_dataset_dir,
    indexes_to_load=landscape_indexes,
)

✓ Loaded index 0: landscape_index_0_amount_0.npz
✓ Loaded index 1: landscape_index_1_amount_1.npz
✓ Loaded index 2: landscape_index_2_amount_2.npz
✓ Loaded index 3: landscape_index_3_amount_3.npz
✓ Loaded index 4: landscape_index_4_amount_4.npz
✓ Loaded index 5: landscape_index_5_amount_5.npz
✓ Loaded index 6: landscape_index_6_amount_6.npz
✓ Loaded index 7: landscape_index_7_amount_7.npz
✓ Loaded index 8: landscape_index_8_amount_8.npz
✓ Loaded index 9: landscape_index_9_amount_9.npz
✓ Loaded index 10: landscape_index_10_amount_10.npz
✓ Loaded index 11: landscape_index_11_amount_11.npz
✓ Loaded index 12: landscape_index_12_amount_12.npz
✓ Loaded index 13: landscape_index_13_amount_13.npz
✓ Loaded index 14: landscape_index_14_amount_14.npz
✓ Loaded index 15: landscape_index_15_amount_15.npz
✓ Loaded index 16: landscape_index_16_amount_16.npz
✓ Loaded index 17: landscape_index_17_amount_17.npz
✓ Loaded index 18: landscape_index_18_amount_18.npz
✓ Loaded index 19: landscape_index_19_amou

In [7]:
# Create one output directory per plot type
plot_dirs = {
    'dynamic_vmin': 'plots_dynamic_vmin',
    'fixed_vmin_large': 'plots_fixed_vmin_large',
    'fixed_vmin_small': 'plots_fixed_vmin_small',
    '3d': 'plots_3d',
}

for dir_name in plot_dirs.values():
    Path(dir_name).mkdir(parents=True, exist_ok=True)

# Upper colour limit passed to every 2D plot (was a literal repeated per call).
VMAX_2D = 20

# Scan ALL landscapes for the extrema used by the fixed colour scales:
#   global_vmin_smallest - smallest per-landscape positive minimum
#   global_vmin_largest  - largest per-landscape positive minimum
#   all_vmax             - largest value in any landscape (reported only)
# Only strictly positive entries count towards a minimum. A landscape with no
# positive entry makes `.min()` raise ValueError on the empty selection.
global_vmin_smallest = float('inf')
global_vmin_largest = float('-inf')
all_vmax = float('-inf')

for loss_grid in loss_grids_list:
    current_min = loss_grid[loss_grid > 0].min()
    current_max = loss_grid.max()
    global_vmin_smallest = min(global_vmin_smallest, current_min)
    global_vmin_largest = max(global_vmin_largest, current_min)
    all_vmax = max(all_vmax, current_max)

print(f"Global min across selected landscapes (smallest value, features are indistinguishable with low variability): {global_vmin_smallest:.6f}")
print(f"Global min across selected landscapes (largest value that captures the variability in the most constant version): {global_vmin_largest:.6f}")
print(f"Global max across ALL landscapes: {all_vmax:.6f}")

# The three 2D variants differ only in their vmin and output directory.
# Dict order is the rendering order: dynamic, fixed (large), fixed (small).
vmin_by_variant = {
    'dynamic_vmin': None,                      # Dynamic - uses local min
    'fixed_vmin_large': global_vmin_largest,   # Fixed global min (largest)
    'fixed_vmin_small': global_vmin_smallest,  # Fixed global min (smallest)
}

# Generate all four plot types (three 2D variants + one 3D) for each index

start_time = time.time()

for idx, loss_grid in enumerate(loss_grids_list):
    # Landscapes and amounts are paired by position.
    additional_data = additional_data_amounts[idx]

    print(f"\n=== Generating plots for index {idx}, additional data: {additional_data} ===")
    print(f"Local min: {loss_grid[loss_grid > 0].min():.6f}, Local max: {loss_grid.max():.6f}")

    # Arguments shared by every plotting call for this landscape
    common_kwargs = dict(
        C1_grid=C1_grid,
        C2_grid=C2_grid,
        loss_grid=loss_grid,
        span_low=span_low,
        span_high=span_high,
        additional_data=additional_data,
        model_labels=model_labels,
        show=False,
    )

    # 1-3. 2D plots: dynamic vmin, then the two fixed global vmin scalings
    for variant, vmin in vmin_by_variant.items():
        plot_2d_loss_landscape_3model(
            vmin=vmin,
            vmax=VMAX_2D,
            output_path=f"{plot_dirs[variant]}/landscape_2d_index_{idx:03d}_amount_{additional_data}.png",
            **common_kwargs,
        )

    # 4. 3D plot without vmin/vmax constraints
    plot_3d_loss_landscape(
        output_path=f"{plot_dirs['3d']}/landscape_3d_index_{idx:03d}_amount_{additional_data}.png",
        **common_kwargs,
    )

end_time = time.time()
elapsed = end_time - start_time

print(f"\n=== Total time for plotting all loss landscapes: {elapsed:.2f} seconds ({elapsed/60:.2f} minutes) ===")
print("\n✓ All plots generated successfully!")

Global min across selected landscapes (smallest value, features are indistinguishable with low variability): 0.000055
Global min across selected landscapes (largest value that captures the variability in the most constant version): 0.000300
Global max across ALL landscapes: 96.429459

=== Generating plots for index 0, additional data: 0 ===
Local min: 0.000055, Local max: 94.768593

=== Generating plots for index 1, additional data: 1 ===
Local min: 0.000057, Local max: 93.152565

=== Generating plots for index 2, additional data: 2 ===
Local min: 0.000058, Local max: 90.697433

=== Generating plots for index 3, additional data: 3 ===
Local min: 0.000064, Local max: 86.754066

=== Generating plots for index 4, additional data: 4 ===
Local min: 0.000063, Local max: 84.745445

=== Generating plots for index 5, additional data: 5 ===
Local min: 0.000068, Local max: 81.355621

=== Generating plots for index 6, additional data: 6 ===
Local min: 0.000071, Local max: 78.533691

=== Generating