# Random Perturbations

Streamlined notebook for evaluating random perturbations on saved models and datasets.

## Imports

In [1]:
# Standard library
import os
import sys
import time

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Local package imports
from minima_volume.perturb_funcs import ( analyze_wiggles_metrics_large )

from minima_volume.dataset_funcs import (
    load_dataset,
    load_model,
    load_models_and_data,
    prepare_datasets,
    tensor_to_list,
)

# Use the GPU when CUDA is usable in this environment; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Input Parameters

In [2]:
# --- Perturbation settings ---
perturbation_seed = 1  # seed forwarded to the perturbation analysis
num_directions = 500   # number of directions requested from the analysis
N = 100                # number of coefficient values
x = np.linspace(0.0, 1.0, num=N)
coefficients = np.square(x)  # N values in [0, 1], quadratically spaced

# --- Other settings ---
# Each non-zero entry is a round total minus 50.
# NOTE(review): the 50 presumably matches the size of the base training set — confirm.
dataset_quantities = [0] + [total - 50 for total in (500, 2000, 5000, 20000)]
base_output_dir = ""

## Model + Dataset Specific Code

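This section holds the model- and dataset-specific code: it selects the model module, builds the model template, and fetches the loss function and additional metrics.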

In [None]:
# Model module used by this notebook (CIFAR-10 CNN). Swap this import to
# target a different model/dataset pair.
from minima_volume.models import CIFAR10_CNN_model_data as model_module

# CNN architecture parameters (adjust as desired).
conv_channels = [32, 64, 128]
fc_dims = [512, 256]

# Build the template model from the CNN parameters above.
model_template = model_module.get_model(
    conv_channels=conv_channels, fc_dims=fc_dims, device=device, seed=0
)

# The loss function and the additional metrics both come from the model module.
loss_fn = model_module.get_loss_fn()
other_metrics = model_module.get_additional_metrics()

## Loading Model and Datasets

In [4]:
# ------------------------------------
# Load the trained models and dataset
# ------------------------------------
target_dir = "models_and_data"  # relative path
loaded_models, loaded_model_data, loaded_dataset = load_models_and_data(
    model_template=model_template, target_dir=target_dir, device=device
)

# Summarize what was loaded.
dataset_type = loaded_dataset['dataset_type']
print(f"Dataset type: {dataset_type}")
print(f"Dataset quantities: {loaded_dataset['dataset_quantities']}")

print("\nTensor shapes:")
for key in ("x_base_train", "y_base_train", "x_additional", "y_additional", "x_test", "y_test"):
    # Entries without a `.shape` attribute are reported as None.
    shape = getattr(loaded_dataset[key], "shape", None)
    print(f"  {key}: {shape}")


def _build_model_record(model, model_data):
    """Combine a loaded model with its saved fields into one dict.

    Each saved field is passed through `tensor_to_list`, so values that
    happen to be tensors are safely converted to lists.
    """
    record = {"model": model}
    for field in ("train_loss", "train_accs", "test_loss", "test_accs", "additional_data", "dataset_type"):
        record[field] = tensor_to_list(model_data[field], key_path=field)
    return record


# One record per (model, saved data) pair, in the order they were loaded.
all_models = [
    _build_model_record(model, model_data)
    for model, model_data in zip(loaded_models, loaded_model_data)
]
print(f"Reconstructed {len(all_models)} trained models")

Looking for models and dataset in: models_and_data
Found 6 model files:
  - model_additional_0.pt
  - model_additional_1950.pt
  - model_additional_19950.pt
  - model_additional_450.pt
  - model_additional_4950.pt
  - model_additional_49950.pt
✅ Model loaded into provided instance from models_and_data\model_additional_0.pt
Successfully loaded: model_additional_0.pt
✅ Model loaded into provided instance from models_and_data\model_additional_1950.pt
Successfully loaded: model_additional_1950.pt
✅ Model loaded into provided instance from models_and_data\model_additional_19950.pt
Successfully loaded: model_additional_19950.pt
✅ Model loaded into provided instance from models_and_data\model_additional_450.pt
Successfully loaded: model_additional_450.pt
✅ Model loaded into provided instance from models_and_data\model_additional_4950.pt
Successfully loaded: model_additional_4950.pt
✅ Model loaded into provided instance from models_and_data\model_additional_49950.pt
Successfully loaded: model_

✅ Dataset loaded from models_and_data\dataset.pt
Dataset type: data
Dataset quantities: [0, 450, 1950, 4950, 19950, 49950]

Tensor shapes:
  x_base_train: torch.Size([50, 3, 32, 32])
  y_base_train: torch.Size([50])
  x_additional: torch.Size([49950, 3, 32, 32])
  y_additional: torch.Size([49950])
  x_test: torch.Size([10000, 3, 32, 32])
  y_test: torch.Size([10000])
Reconstructed 6 trained models


## Perturbations

Using the saved models and datasets, we run the random-perturbation analysis on each trained model.

In [5]:
# Choose the device for the perturbation run. CUDA is used when it is available;
# otherwise fall back to the CPU rather than hard-coding 'cuda', which would make
# the `.to(device)` calls below fail on a CPU-only host.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move every dataset split onto the selected device.
x_base_train = loaded_dataset['x_base_train'].to(device)
y_base_train = loaded_dataset['y_base_train'].to(device)
x_additional = loaded_dataset['x_additional'].to(device)
y_additional = loaded_dataset['y_additional'].to(device)
x_test = loaded_dataset['x_test'].to(device)
y_test = loaded_dataset['y_test'].to(device)

# Structure of the saved results. This is kept as comments rather than a bare
# string so the cell does not echo the text as its output.
# NOTE(review): this reads like the parameter description of the save routine
# used by the analysis — confirm against minima_volume.perturb_funcs.
#
# wiggle_results: List of dictionaries containing wiggle test results
#   Each dictionary is of the form
#   {
#     'loss':
#     'coefficients':
#     'accs':
#     'perturbation_seed':
#     'perturbation_norm':
#   }
# model: PyTorch model used in analysis (state_dict will be saved)
# output_dir: Directory to save results (default: "imgs/swiss/random_dirs")
# filename: Name of output file (default: "random_directions.npz")
# **kwargs: Additional key-value pairs to be saved in the output file
#   Typically:
#     'additional_data':
#     'model_trained_data':
#     'dataset_type':
#     'base_dataset_size':
#     'test_loss':
#     'test_accs':
#     'num_params':

# The loss function and metrics were already fetched from the model module.
# The trailing semicolon keeps Jupyter from echoing any return value.
analyze_wiggles_metrics_large(
    model_list = all_models,
    x_base_train = x_base_train,
    y_base_train = y_base_train,
    x_additional = x_additional,
    y_additional = y_additional,
    x_test = x_test,
    y_test = y_test,
    dataset_quantities = dataset_quantities,
    dataset_type = dataset_type,
    metrics = {"loss": loss_fn, **other_metrics},
    coefficients = coefficients,
    num_directions = num_directions,
    perturbation_seed = perturbation_seed,
    base_output_dir = base_output_dir,
    device = device,  # same device the tensors above were moved to
);

Testing on data with 0 samples - 500 directions
Testing model trained on 0 additional data.
Loss: 23.8123
Accs: 0.2173


Wiggle completed in 97.02 seconds for data model trained with 0 samples


Saved to data_0

Testing model trained on 1950 additional data.
Loss: 3.7354
Accs: 0.5099


Wiggle completed in 97.79 seconds for data model trained with 1950 samples


Saved to data_0

Testing model trained on 19950 additional data.
Loss: 3.0252
Accs: 0.6959


Wiggle completed in 106.56 seconds for data model trained with 19950 samples


Saved to data_0

Testing model trained on 450 additional data.
Loss: 5.6386
Accs: 0.3964


Wiggle completed in 108.03 seconds for data model trained with 450 samples


Saved to data_0

Testing model trained on 4950 additional data.
Loss: 3.4692
Accs: 0.5855


Wiggle completed in 110.38 seconds for data model trained with 4950 samples


Saved to data_0

Testing model trained on 49950 additional data.
Loss: 3.0697
Accs: 0.7590


Wiggle completed in 116.73 seconds for data model trained with 49950 samples


Saved to data_0

Testing on data with 450 samples - 500 directions
Testing model trained on 0 additional data.
Testing model trained on 1950 additional data.
Loss: 3.7354
Accs: 0.5099


Wiggle completed in 169.65 seconds for data model trained with 1950 samples


Saved to data_450

Testing model trained on 19950 additional data.
Loss: 3.0252
Accs: 0.6959


Wiggle completed in 167.42 seconds for data model trained with 19950 samples


Saved to data_450

Testing model trained on 450 additional data.
Loss: 5.6386
Accs: 0.3964


Wiggle completed in 167.85 seconds for data model trained with 450 samples


Saved to data_450

Testing model trained on 4950 additional data.
Loss: 3.4692
Accs: 0.5855


Wiggle completed in 174.45 seconds for data model trained with 4950 samples


Saved to data_450

Testing model trained on 49950 additional data.
Loss: 3.0697
Accs: 0.7590


Wiggle completed in 170.32 seconds for data model trained with 49950 samples


Saved to data_450

Testing on data with 1950 samples - 500 directions
Testing model trained on 0 additional data.
Testing model trained on 1950 additional data.
Loss: 3.7354
Accs: 0.5099


Wiggle completed in 544.31 seconds for data model trained with 1950 samples


Saved to data_1950

Testing model trained on 19950 additional data.
Loss: 3.0252
Accs: 0.6959


Wiggle completed in 547.39 seconds for data model trained with 19950 samples


Saved to data_1950

Testing model trained on 450 additional data.
Testing model trained on 4950 additional data.
Loss: 3.4692
Accs: 0.5855


Wiggle completed in 550.82 seconds for data model trained with 4950 samples


Saved to data_1950

Testing model trained on 49950 additional data.
Loss: 3.0697
Accs: 0.7590


Wiggle completed in 548.19 seconds for data model trained with 49950 samples


Saved to data_1950

Testing on data with 4950 samples - 500 directions
Testing model trained on 0 additional data.
Testing model trained on 1950 additional data.
Testing model trained on 19950 additional data.
Loss: 3.0252
Accs: 0.6959


Wiggle completed in 1258.50 seconds for data model trained with 19950 samples


Saved to data_4950

Testing model trained on 450 additional data.
Testing model trained on 4950 additional data.
Loss: 3.4692
Accs: 0.5855


Wiggle completed in 1263.24 seconds for data model trained with 4950 samples


Saved to data_4950

Testing model trained on 49950 additional data.
Loss: 3.0697
Accs: 0.7590


Wiggle completed in 1271.04 seconds for data model trained with 49950 samples


Saved to data_4950

Testing on data with 19950 samples - 500 directions
Testing model trained on 0 additional data.
Testing model trained on 1950 additional data.
Testing model trained on 19950 additional data.
Loss: 3.0252
Accs: 0.6959


Wiggle completed in 4769.03 seconds for data model trained with 19950 samples


Saved to data_19950

Testing model trained on 450 additional data.
Testing model trained on 4950 additional data.
Testing model trained on 49950 additional data.
Loss: 3.0697
Accs: 0.7590


Wiggle completed in 4722.94 seconds for data model trained with 49950 samples


Saved to data_19950



' Our saved results are structured as follows:\nwiggle_results: List of dictionaries containing wiggle test results\nEach dictionary is of the form\n{\n\'loss\':\n\'coefficients\':\n\'accs\':\n\'perturbation_seed\':\n\'perturbation_norm\':\n}\nmodel: PyTorch model used in analysis (state_dict will be saved)\noutput_dir: Directory to save results (default: "imgs/swiss/random_dirs")\nfilename: Name of output file (default: "random_directions.npz")\n**kwargs: Additional key-value pairs to be saved in the output file\nTypically:\n\'additional_data\':\n\'model_trained_data\':\n\'dataset_type\':\n\'base_dataset_size\': \n\'test_loss\':\n\'test_accs\':\n\'num_params\':\n'