# Model Analysis 

This program takes in a folder name with datasets and models, and in that folder generates a series of images and videos describing how the loss landscape changes over time.

## Imports

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch

import time
import glob

# Importing our existing funcs
import os
import sys
from pathlib import Path

from minima_volume.dataset_funcs import ( load_models_and_data )

from minima_volume.landscape_videos import (
    plot_2d_loss_landscape_3model,
    plot_3d_loss_landscape,
    load_computed_loss_landscapes,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Model

In [2]:
# User specifies the model module name
import MNIST_model_data as model_module

# Generate dataset
#x_base, y_base, x_test, y_test = model_module.get_dataset(
#    device = device
#)

# MNIST specific initialization parameters
hidden_dims = [256, 128]

# Grab model
model_template = model_module.get_model(hidden_dims=hidden_dims, device=device, seed=0)

# Grab loss and metrics
loss_fn_per_sample = model_module.get_loss_fn_per_sample()

## Loading Model and Datasets

In [3]:
# Get the relative path
target_dir = Path("models_and_data") #current directory

# Lists to store loaded models and additional data
loaded_models, loaded_additional_data, loaded_dataset = load_models_and_data(model_template=model_template, target_dir=target_dir, device="cpu")

# Print dataset information
print(f"Dataset type: {loaded_dataset['dataset_type']}")
print(f"Dataset quantities: {loaded_dataset['dataset_quantities']}")

# Print tensor shapes
print("\nTensor shapes:")
for key in ["x_base_train", "y_base_train", "x_additional", "y_additional", "x_test", "y_test"]:
    if loaded_dataset[key] is not None:
        print(f"  {key}: {loaded_dataset[key].shape}")
    else:
        print(f"  {key}: None")

Looking for models and dataset in: models_and_data
Found 6 model files:
  - model_additional_0.pt
  - model_additional_1940.pt
  - model_additional_19940.pt
  - model_additional_540.pt
  - model_additional_5940.pt
  - model_additional_59940.pt
✅ Model loaded into provided instance from models_and_data\model_additional_0.pt
Successfully loaded: model_additional_0.pt
✅ Model loaded into provided instance from models_and_data\model_additional_1940.pt
Successfully loaded: model_additional_1940.pt
✅ Model loaded into provided instance from models_and_data\model_additional_19940.pt
Successfully loaded: model_additional_19940.pt
✅ Model loaded into provided instance from models_and_data\model_additional_540.pt
Successfully loaded: model_additional_540.pt
✅ Model loaded into provided instance from models_and_data\model_additional_5940.pt
Successfully loaded: model_additional_5940.pt
✅ Model loaded into provided instance from models_and_data\model_additional_59940.pt
Successfully loaded: model_

## Plotting Parameters

Determines what models to use.

In [4]:
index1 = 3
index2 = 4
index3 = 5
index_list = [index1, index2, index3]

print(
    f"Loading models trained with: "
    f"{loaded_additional_data[index1]['additional_data']}, "
    f"{loaded_additional_data[index2]['additional_data']}, "
    f"{loaded_additional_data[index3]['additional_data']}"
)

model_0 = loaded_models[index1]
model_1 = loaded_models[index2]
model_2 = loaded_models[index3]

print("=== Selected Models ===")
dataset_type = loaded_dataset['dataset_type']
model_labels = [
    f"Model Trained with {loaded_additional_data[index1]['additional_data']} {dataset_type}",
    f"Model Trained with {loaded_additional_data[index2]['additional_data']} {dataset_type}",
    f"Model Trained with {loaded_additional_data[index3]['additional_data']} {dataset_type}"
]

# Prepare dataset with additional_data = 0 (base dataset only)
x_base_train = loaded_dataset['x_base_train']
y_base_train = loaded_dataset['y_base_train']
x_additional = loaded_dataset['x_additional']
y_additional = loaded_dataset['y_additional']

Loading models trained with: 540, 5940, 59940
=== Selected Models ===


## Loss Landscape Parameters

Loss landscapes have been computed, here's parameters to help plot.

In [5]:
# Create storage directory
varying_dataset_dir = Path("varying_dataset")
varying_dataset_dir.mkdir(exist_ok=True)

# Define the additional data amounts to test
# This can be varied depending on the use case!!
anchor_points = [600-60, 6000-60, 60000-60]   # specify 2, 3, or more points
steps_per_segment = 200               # how many intervals between each pair

additional_data_amounts = []

for i in range(len(anchor_points) - 1):
    start, end = anchor_points[i], anchor_points[i+1]
    # linspace includes both start and end, so skip first element except first segment
    segment = np.linspace(start, end, steps_per_segment + 1, dtype=int)
    if i > 0:
        segment = segment[1:]  # avoid repeating the start of each segment
    additional_data_amounts.extend(segment.tolist())
# Create grid
N = 80
span_low, span_high = -0.5, 1.5
x = np.linspace(span_low, span_high, N)
y = np.linspace(span_low, span_high, N)
C1_grid, C2_grid = np.meshgrid(x, y)
grid_coefficients = (C1_grid, C2_grid)

print(f"Grid size: {N}x{N}, Range: [{span_low}, {span_high}]")

Grid size: 80x80, Range: [-0.5, 1.5]


## Visualization

Load from folder and then visualize.

In [6]:
# Load specific indexes
loss_grids_list, C1_grid, C2_grid = load_computed_loss_landscapes(
    output_dir=varying_dataset_dir,
    indexes_to_load=range(len(additional_data_amounts))
)

✓ Loaded index 0: landscape_index_0_amount_540.npz
✓ Loaded index 1: landscape_index_1_amount_567.npz
✓ Loaded index 2: landscape_index_2_amount_594.npz
✓ Loaded index 3: landscape_index_3_amount_621.npz
✓ Loaded index 4: landscape_index_4_amount_648.npz
✓ Loaded index 5: landscape_index_5_amount_675.npz
✓ Loaded index 6: landscape_index_6_amount_702.npz
✓ Loaded index 7: landscape_index_7_amount_729.npz
✓ Loaded index 8: landscape_index_8_amount_756.npz
✓ Loaded index 9: landscape_index_9_amount_783.npz
✓ Loaded index 10: landscape_index_10_amount_810.npz
✓ Loaded index 11: landscape_index_11_amount_837.npz
✓ Loaded index 12: landscape_index_12_amount_864.npz
✓ Loaded index 13: landscape_index_13_amount_891.npz
✓ Loaded index 14: landscape_index_14_amount_918.npz
✓ Loaded index 15: landscape_index_15_amount_945.npz
✓ Loaded index 16: landscape_index_16_amount_972.npz
✓ Loaded index 17: landscape_index_17_amount_999.npz
✓ Loaded index 18: landscape_index_18_amount_1026.npz
✓ Loaded ind

In [7]:
# Create directories for different plot types
plot_dirs = {
    'dynamic_vmin': 'plots_dynamic_vmin',
    'fixed_vmin_large': 'plots_fixed_vmin_large', 
    'fixed_vmin_small': 'plots_fixed_vmin_small', 
}

for dir_name in plot_dirs.values():
    Path(dir_name).mkdir(parents=True, exist_ok=True)

# Find global min and max across ALL landscapes for consistent fixed scaling
global_vmin_smallest = float('inf')
global_vmin_largest = float('-inf')
all_vmax = float('-inf')

for idx in range(len(loss_grids_list)):
    if 0 <= idx < len(loss_grids_list):
        loss_grid = loss_grids_list[idx]
        current_min = loss_grid[loss_grid > 0].min()
        current_max = loss_grid.max()
        global_vmin_smallest = min(global_vmin_smallest, current_min)
        global_vmin_largest = max(global_vmin_largest, current_min)
        all_vmax = max(all_vmax, current_max)

print(f"Global min across selected landscapes (smallest value, features are indistinguishable with low variability): {global_vmin_smallest:.6f}")
print(f"Global min across selected landscapes (largest value that captures the variability in the most constant version): {global_vmin_largest:.6f}")
print(f"Global max across ALL landscapes: {all_vmax:.6f}")

# Generate all three types of plots for each index

start_time = time.time()

for idx in range(len(loss_grids_list)):
    if idx < 0 or idx >= len(loss_grids_list):
        continue
    
    additional_data = additional_data_amounts[idx]
    loss_grid = loss_grids_list[idx]
    
    print(f"\n=== Generating plots for index {idx}, additional data: {additional_data} ===")
    print(f"Local min: {loss_grid[loss_grid > 0].min():.6f}, Local max: {loss_grid.max():.6f}")
    
    # 1. 2D plot with dynamic vmin (local scaling)
    plot_2d_loss_landscape_3model(
        C1_grid=C1_grid,
        C2_grid=C2_grid,
        loss_grid=loss_grid,
        span_low=span_low,
        span_high=span_high,
        additional_data=additional_data,
        model_labels=model_labels,
        vmin=None,  # Dynamic - uses local min
        vmax=20, 
        output_path=f"{plot_dirs['dynamic_vmin']}/landscape_2d_index_{idx:03d}_amount_{additional_data}.png",
        show=False
    )
    
    # 2. 2D plot with fixed vmin (global scaling)
    plot_2d_loss_landscape_3model(
        C1_grid=C1_grid,
        C2_grid=C2_grid,
        loss_grid=loss_grid,
        span_low=span_low,
        span_high=span_high,
        additional_data=additional_data,
        model_labels=model_labels,
        vmin=global_vmin_largest,  # Fixed global min
        vmax=20, 
        output_path=f"{plot_dirs['fixed_vmin_large']}/landscape_2d_index_{idx:03d}_amount_{additional_data}.png",
        show=False
    )

    # 3. 2D plot with fixed vmin (global scaling)
    plot_2d_loss_landscape_3model(
        C1_grid=C1_grid,
        C2_grid=C2_grid,
        loss_grid=loss_grid,
        span_low=span_low,
        span_high=span_high,
        additional_data=additional_data,
        model_labels=model_labels,
        vmin=global_vmin_smallest,  # Fixed global min
        vmax=20, 
        output_path=f"{plot_dirs['fixed_vmin_small']}/landscape_2d_index_{idx:03d}_amount_{additional_data}.png",
        show=False
    )

end_time = time.time()
elapsed = end_time - start_time

print(f"\n=== Total time for plotting all loss landscapes: {elapsed:.2f} seconds ({elapsed/60:.2f} minutes) ===")
print("\n✓ All plots generated successfully!")

Global min across selected landscapes (smallest value, features are indistinguishable with low variability): 0.000113
Global min across selected landscapes (largest value that captures the variability in the most constant version): 0.002054
Global max across ALL landscapes: 15.162881

=== Generating plots for index 0, additional data: 540 ===
Local min: 0.000113, Local max: 15.162881

=== Generating plots for index 1, additional data: 567 ===
Local min: 0.000415, Local max: 14.884861

=== Generating plots for index 2, additional data: 594 ===
Local min: 0.000407, Local max: 14.915216

=== Generating plots for index 3, additional data: 621 ===
Local min: 0.000402, Local max: 14.833491

=== Generating plots for index 4, additional data: 648 ===
Local min: 0.000401, Local max: 14.650093

=== Generating plots for index 5, additional data: 675 ===
Local min: 0.000410, Local max: 14.708021

=== Generating plots for index 6, additional data: 702 ===
Local min: 0.000405, Local max: 14.424446

