# Template Notebook for Edge of Stochastic Stability Experiments

This notebook serves as a template for running measurements on neural networks (both trained and untrained) in the context of Edge of Stochastic Stability research. 

## Purpose
- Creates and initializes neural networks (with the given initialization seed for reproducibility)
- Loads corresponding datasets (with the given dataset seed for reproducibility)
- Loads pre-trained network checkpoints
- Evaluates training loss and other metrics
- Provides a foundation for EOSS-related measurements and analysis

Use this as a starting point for your own experiments and measurements.


In [1]:
import torch as T

import torch

from pathlib import Path

# import plotly.express as px
import numpy as np

import torch as T
import torch
import torch.nn as nn
from einops import rearrange, repeat
from torch import linalg as LA
import numpy as np
from copy import deepcopy
from pathlib import Path

from torch.autograd import grad

In [2]:
import sys
import os

# Put the directory one level above the current working directory at the front
# of the module search path (presumably so the `utils` package imported below resolves).
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.insert(0, parent_dir)

from utils.data import prepare_dataset, get_dataset_presets
from utils.nets import SquaredLoss, prepare_net_dataset_specific, MLP, CNN, prepare_net, initialize_net, prepare_optimizer, get_model_presets, get_path_of_last_net
from utils.measure import *


In [3]:
# Root folders for the datasets and for the saved training runs / checkpoints.
# NOTE(review): these are machine-specific absolute paths -- adjust them for your environment.
DATASET_FOLDER = Path('/scratch/gpfs/andreyev/datasets/')
RESULTS_FOLDER = Path('/scratch/gpfs/andreyev/eoss/results')

# Always a torch.device object, on both the CUDA and the CPU branch, so that
# attributes such as `device.type` are available regardless of the hardware.
device = T.device('cuda' if T.cuda.is_available() else 'cpu')

## Load everything

In [4]:
# --- Data selection (all forwarded to prepare_dataset below) ---
dataset = "cifar10"
num_data = 8192
classes = None
dataset_seed = 345

# --- Model selection (forwarded to prepare_net_dataset_specific below) ---
model_type = "mlp"

# --- Initialization ---
# When no_init is True the initialize_net(net, init_scale) call further down is skipped.
no_init = True
init_scale = 0.3

### Dataset

In [5]:
# Fetch the dataset presets from utils.data (not referenced again in the cells shown here).
dataset_presets = get_dataset_presets()

# Build the dataset from the configuration above; the result is unpacked into
# train/test inputs and targets in the next cell.
data = prepare_dataset(dataset, DATASET_FOLDER, num_data, classes, dataset_seed)


In [6]:
# Unpack the 4-tuple returned by prepare_dataset into train/test inputs and targets.
X_train, Y_train, X_test, Y_test = data

In [7]:
# Float copies of the training inputs and targets, placed on the compute device.
X, Y = (tensor.float().to(device) for tensor in (X_train, Y_train))

### Model

In [8]:
# Build the network for the chosen (model_type, dataset) pair and place it on the compute device.
net = prepare_net_dataset_specific(model_type, dataset).to(device)

In [9]:
# Loss criterion (SquaredLoss from utils.nets) used by every loss evaluation below.
loss_fn = SquaredLoss()

### (Optional) Initialize or load

In [18]:
# Re-initialize the network with `init_scale` only when `no_init` is False;
# with no_init = True (the value set in the config cell) this cell does nothing.
if not no_init:
    initialize_net(net, init_scale)

In [12]:
# Target training step and the batch size used to convert that step into an epoch index.
step_to_load_around = 65_000
batch_size = 16

# Number of full batches per pass over the training set (floor division).
steps_in_epoch = len(X_train) // batch_size

# Epoch that contains the requested step (floor division again).
epoch_to_load = step_to_load_around // steps_in_epoch

print(f'It should be epoch {epoch_to_load} around step {step_to_load_around}')


It should be epoch 126 around step 65000


In [None]:
# Run directory (relative to RESULTS_FOLDER) whose checkpoints should be loaded.
run_path = "cifar10_mlp/20250328_0221_53_lr0.01000_b16"
checkpoint_path = RESULTS_FOLDER / run_path / 'checkpoints'

# Pick the checkpoint via get_path_of_last_net.
# To load the checkpoint of a specific epoch instead, use:
# net_to_load = checkpoint_path / f'net_{epoch_to_load}.pt'
net_to_load = get_path_of_last_net(checkpoint_path)

# weights_only=True restricts unpickling to tensor data; map_location places the
# loaded tensors on the compute device.
state_dict = torch.load(net_to_load, weights_only=True, map_location=device)
net.load_state_dict(state_dict)


<All keys matched successfully>

## Do a test run

In [20]:
# Smoke test on the first four training samples.
# Use X / Y (the float copies already moved to `device`) rather than X_train / Y_train,
# so inputs and targets live on the same device as `net`.
preds = net(X[:4])
loss = loss_fn(preds, Y[:4])
loss.item()

0.00012003588199149817

In [21]:
# Smoke test on the first four test samples.
# Cast to float and move to `device` so inputs and targets live on the same device as `net`.
preds = net(X_test[:4].float().to(device))
loss = loss_fn(preds, Y_test[:4].float().to(device))
loss.item()

0.8840587139129639

# Total loss

We compute the loss over the full training set as a sanity check that the dataset was loaded correctly: if the loaded data did not match the data the checkpoint was trained on, the training loss would be noticeably higher.

In [23]:
# Loss over the full training set.
# Use X / Y (the float copies already moved to `device`) rather than X_train / Y_train,
# so inputs and targets live on the same device as `net`.
preds = net(X)
loss = loss_fn(preds, Y)
loss.item()

0.00030041515128687024

In [24]:
def calculate_accuracy(predictions, targets):
    """
    Calculate the accuracy given the model predictions and target labels.

    Args:
        predictions: tensor of shape (num_samples, num_classes) with model predictions
        targets: tensor of shape (num_samples, num_classes) with one-hot encoded labels,
                or tensor of shape (num_samples,) or (num_samples, 1) with class indices.
                May live on a different device than `predictions`.

    Returns:
        accuracy: float representing the accuracy (0.0 to 1.0)

    Raises:
        ValueError: if the number of targets differs from the number of predictions.
        ZeroDivisionError: if there are no samples.
    """
    # Get the predicted class (highest value in each row)
    pred_classes = torch.argmax(predictions, dim=1)

    # Check if targets are one-hot encoded or class indices
    if targets.dim() > 1 and targets.shape[1] > 1:
        # One-hot encoded targets
        true_classes = torch.argmax(targets, dim=1)
    else:
        # Class indices. Flatten so that a (num_samples, 1) column is compared
        # element-wise instead of broadcasting against (num_samples,) into an
        # (num_samples, num_samples) matrix, which would inflate the count.
        true_classes = targets.reshape(-1)

    if true_classes.shape[0] != pred_classes.shape[0]:
        raise ValueError(
            f"Number of targets ({true_classes.shape[0]}) does not match "
            f"number of predictions ({pred_classes.shape[0]})"
        )

    # Comparing tensors on different devices raises, so align the targets first.
    true_classes = true_classes.to(pred_classes.device)

    # Compare and compute accuracy
    correct = (pred_classes == true_classes).sum().item()
    total = pred_classes.size(0)

    return correct / total

# Evaluation only: no autograd graph is needed, and inputs/targets are moved to
# `device` so they live on the same device as `net`.
with torch.no_grad():
    # Calculate accuracy on training data
    train_preds = net(X_train.float().to(device))
    train_accuracy = calculate_accuracy(train_preds, Y_train.to(device))
    print(f"Training accuracy: {train_accuracy:.4f} ({train_accuracy*100:.2f}%)")

    # Calculate accuracy on test data
    test_preds = net(X_test.float().to(device))
    test_accuracy = calculate_accuracy(test_preds, Y_test.to(device))
    print(f"Test accuracy: {test_accuracy:.5f} ({test_accuracy*100:.2f}%)")

Training accuracy: 1.0000 (100.00%)
Test accuracy: 0.44220 (44.22%)


In [25]:
# Display the training targets for visual inspection (the rows look one-hot encoded in the output below).
Y_train

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])