# Train Low Test Models

This streamlined notebook generates minima with low test accuracy through three different means:
- Dataset Poisoning
- Adding Noise to Data
- Decreasing Dataset Sizes

## Imports

In [1]:
# Standard library
import copy
import os
import sys
import time

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Local package imports
from minima_volume.dataset_funcs import (
    prepare_datasets,
    save_dataset,
    save_model,
)
from minima_volume.train_funcs import evaluate, train

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Input Parameters

In [2]:

# ==============================
# Base Input Parameters
# ==============================
# --- SEEDS ---
data_seed = 16   # passed to prepare_datasets() as data_seed
model_seed = 6   # passed to get_model() and to torch.manual_seed() before each run

# --- Training configuration ---
epochs = 5000    # number of epochs handed to train() for every model

# --- Dataset configuration ---
# 9409 == 97 ** 2. The original note calls this the "modulo arithmetic dataset
# size"; presumably it is the number of (a, b) input pairs for modulus 97 --
# TODO confirm against model_module.get_dataset().
full_dataset_size = 9409

# Size of the base training set: 1% of the full dataset, truncated to an int.
base_data_size = int(0.01 * full_dataset_size)
dataset_type = "data"

# One model is trained per entry; each entry is the number of additional
# samples (a truncated fraction of the full dataset) added to the base set.
additional_fractions = (0.0, 0.09, 0.19, 0.39, 0.59)
dataset_quantities = [int(fraction * full_dataset_size) for fraction in additional_fractions]

# --- Output configuration ---
base_output_dir = ""            # base directory setting for generated artifacts
save_generated_dataset = True   # gate for the save_dataset() call in the saving cell
save_generated_models = True    # gate for the save_model() calls in the saving cell



In [3]:
# Echo the derived dataset sizes as a quick sanity check (one value per line).
print(base_data_size, dataset_quantities, sep="\n")

94
[0, 846, 1787, 3669, 5551]


## Model + Dataset Specific Code

This section holds the code that is specific to the chosen model and dataset.

In [4]:
# Model/dataset module for this experiment; swap this import to change the task.
from minima_volume.models import modulo_arithmetic_model_data as model_module

modulus = 97

# Build the base data pool and the test split for the chosen modulus.
x_base, y_base, x_test, y_test = model_module.get_dataset(modulus=modulus, device=device)

# Modular-arithmetic model settings; hidden_dims is forwarded to get_model().
hidden_dims = [250]

# Template network: each training run later works on a deep copy of this model.
model_template = model_module.get_model(
    N=modulus,
    hidden_dims=hidden_dims,
    device=device,
    seed=model_seed,
)

# Loss function and additional metrics are supplied by the model module.
loss_fn = model_module.get_loss_fn()
other_metrics = model_module.get_additional_metrics()

## Training

We generate the various datasets used to train our models here, before training them. We record the losses, and what each model was trained on.

In [5]:
# ==============================
# Prepare datasets
# ==============================
# prepare_datasets() returns a base training set plus a pool of additional
# samples; its exact semantics live in minima_volume.dataset_funcs.
x_base_train, y_base_train, x_additional, y_additional = prepare_datasets(
    x_base=x_base,
    y_base=y_base,
    dataset_type=dataset_type,
    dataset_quantities=dataset_quantities,
    base_data_size=base_data_size,
    data_seed=data_seed,
    seed_1=None,
    seed_2=None,
)

# Move every tensor onto the selected device before training.
x_base_train = x_base_train.to(device)
y_base_train = y_base_train.to(device)
x_additional = x_additional.to(device)
y_additional = y_additional.to(device)
x_test = x_test.to(device)
y_test = y_test.to(device)

# ==============================
# Training loop
# ==============================
# One result dict is appended per entry of dataset_quantities.
all_models = []

for additional_data in dataset_quantities:
    # Assemble training dataset: the base set plus the first `additional_data`
    # samples of the additional pool. NOTE: slicing silently truncates if the
    # pool holds fewer samples than requested.
    x_train = torch.cat([x_base_train, x_additional[:additional_data]], dim=0)
    y_train = torch.cat([y_base_train, y_additional[:additional_data]], dim=0)

    # Reseed the global torch RNG and start from a deep copy of the template,
    # so every run begins from the same weights and the same RNG state.
    torch.manual_seed(model_seed)
    model = copy.deepcopy(model_template)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    batch_size = len(x_train)  # batch size equals the training-set size

    # Train model
    train_loss, train_other_metrics, test_loss, test_other_metrics = train(
        model = model,
        x_train = x_train, y_train = y_train,
        x_test = x_test, y_test = y_test,
        loss_fn = loss_fn,
        metrics = other_metrics,
        optimizer = optimizer,
        epochs=epochs,
        batch_size=batch_size,
        verbose_every=100,
    )

    # Build dictionaries dynamically for additional metrics.
    # Each *_other_metrics value is a list of dicts, one dict per epoch. Both
    # histories are guarded independently so that a missing (None) or empty
    # history (e.g. epochs == 0) yields no keys instead of raising on `[0]`.
    train_metrics_dict = {}
    test_metrics_dict = {}
    if train_other_metrics is not None and len(train_other_metrics) > 0:
        for metric_name in train_other_metrics[0]:  # keys from first epoch
            train_metrics_dict[f"train_{metric_name}"] = [m[metric_name] for m in train_other_metrics]
    if test_other_metrics is not None and len(test_other_metrics) > 0:
        for metric_name in test_other_metrics[0]:  # keys from first epoch
            test_metrics_dict[f"test_{metric_name}"] = [m[metric_name] for m in test_other_metrics]

    # Store results
    trained_model = {
        "model": model,
        "train_loss": train_loss,
        "test_loss": test_loss,
        "additional_data": additional_data,
        "dataset_type": dataset_type,
        **train_metrics_dict,  # dynamically include additional metrics
        **test_metrics_dict,
    }

    all_models.append(trained_model)

    print(f"Completed training with {additional_data} additional samples of {dataset_type}")

    # Free memory (important for large GPU datasets)
    del x_train, y_train
    torch.cuda.empty_cache()


Epoch 1/5000: Train Loss 0.0121 | Test Loss 0.0118 | accs Train 0.0000 Test 0.0095


Epoch 100/5000: Train Loss 0.0002 | Test Loss 0.0151 | accs Train 1.0000 Test 0.0118


Epoch 200/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 300/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 400/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 500/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 600/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 700/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 800/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 900/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 1000/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 1100/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 1200/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 1300/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 1400/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0115


Epoch 1500/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 1600/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 1700/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0115


Epoch 1800/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 1900/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 2000/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 2100/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0115


Epoch 2200/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0115


Epoch 2300/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 2400/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 2500/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0115


Epoch 2600/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 2700/5000: Train Loss 0.0000 | Test Loss 0.0155 | accs Train 1.0000 Test 0.0116


Epoch 2800/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0115


Epoch 2900/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0115


Epoch 3000/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0115


Epoch 3100/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0116


Epoch 3200/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0116


Epoch 3300/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0115


Epoch 3400/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0116


Epoch 3500/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0116


Epoch 3600/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0116


Epoch 3700/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0115


Epoch 3800/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0116


Epoch 3900/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0116


Epoch 4000/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0116


Epoch 4100/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0115


Epoch 4200/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0115


Epoch 4300/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0116


Epoch 4400/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0115


Epoch 4500/5000: Train Loss 0.0000 | Test Loss 0.0154 | accs Train 1.0000 Test 0.0115


Epoch 4600/5000: Train Loss 0.0000 | Test Loss 0.0153 | accs Train 1.0000 Test 0.0116


Epoch 4700/5000: Train Loss 0.0000 | Test Loss 0.0153 | accs Train 1.0000 Test 0.0115


Epoch 4800/5000: Train Loss 0.0000 | Test Loss 0.0153 | accs Train 1.0000 Test 0.0115


Epoch 4900/5000: Train Loss 0.0000 | Test Loss 0.0153 | accs Train 1.0000 Test 0.0115


Epoch 5000/5000: Train Loss 0.0000 | Test Loss 0.0153 | accs Train 1.0000 Test 0.0116
Completed training with 0 additional samples of data
Epoch 1/5000: Train Loss 0.0122 | Test Loss 0.0117 | accs Train 0.0106 Test 0.0095


Epoch 100/5000: Train Loss 0.0063 | Test Loss 0.0125 | accs Train 1.0000 Test 0.0999


Epoch 200/5000: Train Loss 0.0036 | Test Loss 0.0155 | accs Train 1.0000 Test 0.1003


Epoch 300/5000: Train Loss 0.0024 | Test Loss 0.0182 | accs Train 1.0000 Test 0.1011


Epoch 400/5000: Train Loss 0.0019 | Test Loss 0.0205 | accs Train 1.0000 Test 0.1022


Epoch 500/5000: Train Loss 0.0015 | Test Loss 0.0223 | accs Train 1.0000 Test 0.1027


Epoch 600/5000: Train Loss 0.0013 | Test Loss 0.0238 | accs Train 1.0000 Test 0.1027


Epoch 700/5000: Train Loss 0.0012 | Test Loss 0.0251 | accs Train 1.0000 Test 0.1030


Epoch 800/5000: Train Loss 0.0011 | Test Loss 0.0261 | accs Train 1.0000 Test 0.1027


Epoch 900/5000: Train Loss 0.0010 | Test Loss 0.0271 | accs Train 1.0000 Test 0.1023


Epoch 1000/5000: Train Loss 0.0010 | Test Loss 0.0279 | accs Train 1.0000 Test 0.1023


Epoch 1100/5000: Train Loss 0.0009 | Test Loss 0.0286 | accs Train 1.0000 Test 0.1029


Epoch 1200/5000: Train Loss 0.0009 | Test Loss 0.0292 | accs Train 1.0000 Test 0.1032


Epoch 1300/5000: Train Loss 0.0009 | Test Loss 0.0297 | accs Train 1.0000 Test 0.1036


Epoch 1400/5000: Train Loss 0.0008 | Test Loss 0.0303 | accs Train 1.0000 Test 0.1036


Epoch 1500/5000: Train Loss 0.0008 | Test Loss 0.0307 | accs Train 1.0000 Test 0.1036


Epoch 1600/5000: Train Loss 0.0008 | Test Loss 0.0312 | accs Train 1.0000 Test 0.1036


Epoch 1700/5000: Train Loss 0.0008 | Test Loss 0.0316 | accs Train 1.0000 Test 0.1037


Epoch 1800/5000: Train Loss 0.0008 | Test Loss 0.0319 | accs Train 1.0000 Test 0.1038


Epoch 1900/5000: Train Loss 0.0008 | Test Loss 0.0322 | accs Train 1.0000 Test 0.1044


Epoch 2000/5000: Train Loss 0.0007 | Test Loss 0.0325 | accs Train 1.0000 Test 0.1040


Epoch 2100/5000: Train Loss 0.0007 | Test Loss 0.0328 | accs Train 1.0000 Test 0.1044


Epoch 2200/5000: Train Loss 0.0007 | Test Loss 0.0331 | accs Train 1.0000 Test 0.1046


Epoch 2300/5000: Train Loss 0.0007 | Test Loss 0.0333 | accs Train 1.0000 Test 0.1046


Epoch 2400/5000: Train Loss 0.0007 | Test Loss 0.0335 | accs Train 1.0000 Test 0.1045


Epoch 2500/5000: Train Loss 0.0007 | Test Loss 0.0337 | accs Train 1.0000 Test 0.1043


Epoch 2600/5000: Train Loss 0.0007 | Test Loss 0.0339 | accs Train 1.0000 Test 0.1046


Epoch 2700/5000: Train Loss 0.0007 | Test Loss 0.0341 | accs Train 1.0000 Test 0.1043


Epoch 2800/5000: Train Loss 0.0007 | Test Loss 0.0343 | accs Train 1.0000 Test 0.1043


Epoch 2900/5000: Train Loss 0.0007 | Test Loss 0.0345 | accs Train 1.0000 Test 0.1042


Epoch 3000/5000: Train Loss 0.0007 | Test Loss 0.0347 | accs Train 1.0000 Test 0.1043


Epoch 3100/5000: Train Loss 0.0007 | Test Loss 0.0349 | accs Train 1.0000 Test 0.1043


Epoch 3200/5000: Train Loss 0.0007 | Test Loss 0.0351 | accs Train 1.0000 Test 0.1042


Epoch 3300/5000: Train Loss 0.0007 | Test Loss 0.0352 | accs Train 1.0000 Test 0.1043


Epoch 3400/5000: Train Loss 0.0006 | Test Loss 0.0354 | accs Train 1.0000 Test 0.1042


Epoch 3500/5000: Train Loss 0.0006 | Test Loss 0.0355 | accs Train 1.0000 Test 0.1040


Epoch 3600/5000: Train Loss 0.0006 | Test Loss 0.0357 | accs Train 1.0000 Test 0.1043


Epoch 3700/5000: Train Loss 0.0006 | Test Loss 0.0358 | accs Train 1.0000 Test 0.1039


Epoch 3800/5000: Train Loss 0.0006 | Test Loss 0.0359 | accs Train 1.0000 Test 0.1042


Epoch 3900/5000: Train Loss 0.0006 | Test Loss 0.0361 | accs Train 1.0000 Test 0.1040


Epoch 4000/5000: Train Loss 0.0006 | Test Loss 0.0362 | accs Train 1.0000 Test 0.1042


Epoch 4100/5000: Train Loss 0.0006 | Test Loss 0.0363 | accs Train 1.0000 Test 0.1043


Epoch 4200/5000: Train Loss 0.0006 | Test Loss 0.0365 | accs Train 1.0000 Test 0.1046


Epoch 4300/5000: Train Loss 0.0006 | Test Loss 0.0366 | accs Train 1.0000 Test 0.1047


Epoch 4400/5000: Train Loss 0.0006 | Test Loss 0.0367 | accs Train 1.0000 Test 0.1051


Epoch 4500/5000: Train Loss 0.0006 | Test Loss 0.0368 | accs Train 1.0000 Test 0.1051


Epoch 4600/5000: Train Loss 0.0006 | Test Loss 0.0369 | accs Train 1.0000 Test 0.1049


Epoch 4700/5000: Train Loss 0.0006 | Test Loss 0.0371 | accs Train 1.0000 Test 0.1050


Epoch 4800/5000: Train Loss 0.0006 | Test Loss 0.0372 | accs Train 1.0000 Test 0.1050


Epoch 4900/5000: Train Loss 0.0006 | Test Loss 0.0373 | accs Train 1.0000 Test 0.1049


Epoch 5000/5000: Train Loss 0.0006 | Test Loss 0.0374 | accs Train 1.0000 Test 0.1051
Completed training with 846 additional samples of data
Epoch 1/5000: Train Loss 0.0122 | Test Loss 0.0117 | accs Train 0.0096 Test 0.0097


Epoch 100/5000: Train Loss 0.0080 | Test Loss 0.0112 | accs Train 0.9681 Test 0.1940


Epoch 200/5000: Train Loss 0.0062 | Test Loss 0.0126 | accs Train 0.9995 Test 0.2010


Epoch 300/5000: Train Loss 0.0054 | Test Loss 0.0139 | accs Train 1.0000 Test 0.2020


Epoch 400/5000: Train Loss 0.0049 | Test Loss 0.0148 | accs Train 1.0000 Test 0.2030


Epoch 500/5000: Train Loss 0.0046 | Test Loss 0.0155 | accs Train 1.0000 Test 0.2028


Epoch 600/5000: Train Loss 0.0045 | Test Loss 0.0160 | accs Train 1.0000 Test 0.2025


Epoch 700/5000: Train Loss 0.0044 | Test Loss 0.0164 | accs Train 1.0000 Test 0.2031


Epoch 800/5000: Train Loss 0.0043 | Test Loss 0.0167 | accs Train 1.0000 Test 0.2034


Epoch 900/5000: Train Loss 0.0042 | Test Loss 0.0170 | accs Train 1.0000 Test 0.2026


Epoch 1000/5000: Train Loss 0.0042 | Test Loss 0.0172 | accs Train 1.0000 Test 0.2033


Epoch 1100/5000: Train Loss 0.0041 | Test Loss 0.0173 | accs Train 1.0000 Test 0.2037


Epoch 1200/5000: Train Loss 0.0041 | Test Loss 0.0175 | accs Train 1.0000 Test 0.2040


Epoch 1300/5000: Train Loss 0.0041 | Test Loss 0.0176 | accs Train 1.0000 Test 0.2040


Epoch 1400/5000: Train Loss 0.0040 | Test Loss 0.0177 | accs Train 1.0000 Test 0.2041


Epoch 1500/5000: Train Loss 0.0040 | Test Loss 0.0178 | accs Train 1.0000 Test 0.2040


Epoch 1600/5000: Train Loss 0.0040 | Test Loss 0.0179 | accs Train 1.0000 Test 0.2036


Epoch 1700/5000: Train Loss 0.0040 | Test Loss 0.0180 | accs Train 1.0000 Test 0.2032


Epoch 1800/5000: Train Loss 0.0040 | Test Loss 0.0181 | accs Train 1.0000 Test 0.2030


Epoch 1900/5000: Train Loss 0.0040 | Test Loss 0.0181 | accs Train 1.0000 Test 0.2030


Epoch 2000/5000: Train Loss 0.0039 | Test Loss 0.0182 | accs Train 1.0000 Test 0.2031


Epoch 2100/5000: Train Loss 0.0039 | Test Loss 0.0182 | accs Train 1.0000 Test 0.2028


Epoch 2200/5000: Train Loss 0.0039 | Test Loss 0.0183 | accs Train 1.0000 Test 0.2030


Epoch 2300/5000: Train Loss 0.0039 | Test Loss 0.0183 | accs Train 1.0000 Test 0.2035


Epoch 2400/5000: Train Loss 0.0039 | Test Loss 0.0184 | accs Train 1.0000 Test 0.2035


Epoch 2500/5000: Train Loss 0.0039 | Test Loss 0.0184 | accs Train 1.0000 Test 0.2036


Epoch 2600/5000: Train Loss 0.0039 | Test Loss 0.0184 | accs Train 1.0000 Test 0.2035


Epoch 2700/5000: Train Loss 0.0039 | Test Loss 0.0184 | accs Train 1.0000 Test 0.2035


Epoch 2800/5000: Train Loss 0.0039 | Test Loss 0.0185 | accs Train 1.0000 Test 0.2040


Epoch 2900/5000: Train Loss 0.0039 | Test Loss 0.0185 | accs Train 1.0000 Test 0.2037


Epoch 3000/5000: Train Loss 0.0039 | Test Loss 0.0185 | accs Train 1.0000 Test 0.2036


Epoch 3100/5000: Train Loss 0.0039 | Test Loss 0.0185 | accs Train 1.0000 Test 0.2037


Epoch 3200/5000: Train Loss 0.0039 | Test Loss 0.0185 | accs Train 1.0000 Test 0.2038


Epoch 3300/5000: Train Loss 0.0039 | Test Loss 0.0186 | accs Train 1.0000 Test 0.2038


Epoch 3400/5000: Train Loss 0.0039 | Test Loss 0.0186 | accs Train 1.0000 Test 0.2038


Epoch 3500/5000: Train Loss 0.0039 | Test Loss 0.0186 | accs Train 1.0000 Test 0.2038


Epoch 3600/5000: Train Loss 0.0039 | Test Loss 0.0186 | accs Train 1.0000 Test 0.2040


Epoch 3700/5000: Train Loss 0.0039 | Test Loss 0.0186 | accs Train 1.0000 Test 0.2040


Epoch 3800/5000: Train Loss 0.0039 | Test Loss 0.0186 | accs Train 1.0000 Test 0.2038


Epoch 3900/5000: Train Loss 0.0039 | Test Loss 0.0186 | accs Train 1.0000 Test 0.2038


Epoch 4000/5000: Train Loss 0.0038 | Test Loss 0.0186 | accs Train 1.0000 Test 0.2038


Epoch 4100/5000: Train Loss 0.0038 | Test Loss 0.0187 | accs Train 1.0000 Test 0.2040


Epoch 4200/5000: Train Loss 0.0038 | Test Loss 0.0187 | accs Train 1.0000 Test 0.2041


Epoch 4300/5000: Train Loss 0.0038 | Test Loss 0.0187 | accs Train 1.0000 Test 0.2037


Epoch 4400/5000: Train Loss 0.0038 | Test Loss 0.0187 | accs Train 1.0000 Test 0.2042


Epoch 4500/5000: Train Loss 0.0038 | Test Loss 0.0187 | accs Train 1.0000 Test 0.2043


Epoch 4600/5000: Train Loss 0.0038 | Test Loss 0.0187 | accs Train 1.0000 Test 0.2044


Epoch 4700/5000: Train Loss 0.0038 | Test Loss 0.0187 | accs Train 1.0000 Test 0.2042


Epoch 4800/5000: Train Loss 0.0038 | Test Loss 0.0187 | accs Train 1.0000 Test 0.2042


Epoch 4900/5000: Train Loss 0.0038 | Test Loss 0.0187 | accs Train 1.0000 Test 0.2044


Epoch 5000/5000: Train Loss 0.0038 | Test Loss 0.0187 | accs Train 1.0000 Test 0.2045
Completed training with 1787 additional samples of data
Epoch 1/5000: Train Loss 0.0122 | Test Loss 0.0117 | accs Train 0.0093 Test 0.0097


Epoch 100/5000: Train Loss 0.0092 | Test Loss 0.0104 | accs Train 0.7598 Test 0.3056


Epoch 200/5000: Train Loss 0.0081 | Test Loss 0.0105 | accs Train 0.9346 Test 0.3769


Epoch 300/5000: Train Loss 0.0073 | Test Loss 0.0103 | accs Train 0.9875 Test 0.4165


Epoch 400/5000: Train Loss 0.0066 | Test Loss 0.0097 | accs Train 0.9989 Test 0.5086


Epoch 500/5000: Train Loss 0.0059 | Test Loss 0.0088 | accs Train 1.0000 Test 0.7491


Epoch 600/5000: Train Loss 0.0052 | Test Loss 0.0078 | accs Train 1.0000 Test 0.9528


Epoch 700/5000: Train Loss 0.0044 | Test Loss 0.0065 | accs Train 1.0000 Test 0.9983


Epoch 800/5000: Train Loss 0.0035 | Test Loss 0.0052 | accs Train 1.0000 Test 1.0000


Epoch 900/5000: Train Loss 0.0028 | Test Loss 0.0040 | accs Train 1.0000 Test 1.0000


Epoch 1000/5000: Train Loss 0.0022 | Test Loss 0.0032 | accs Train 1.0000 Test 1.0000


Epoch 1100/5000: Train Loss 0.0018 | Test Loss 0.0026 | accs Train 1.0000 Test 1.0000


Epoch 1200/5000: Train Loss 0.0016 | Test Loss 0.0022 | accs Train 1.0000 Test 1.0000


Epoch 1300/5000: Train Loss 0.0014 | Test Loss 0.0020 | accs Train 1.0000 Test 1.0000


Epoch 1400/5000: Train Loss 0.0013 | Test Loss 0.0019 | accs Train 1.0000 Test 1.0000


Epoch 1500/5000: Train Loss 0.0012 | Test Loss 0.0017 | accs Train 1.0000 Test 1.0000


Epoch 1600/5000: Train Loss 0.0011 | Test Loss 0.0016 | accs Train 1.0000 Test 1.0000


Epoch 1700/5000: Train Loss 0.0011 | Test Loss 0.0015 | accs Train 1.0000 Test 1.0000


Epoch 1800/5000: Train Loss 0.0010 | Test Loss 0.0015 | accs Train 1.0000 Test 1.0000


Epoch 1900/5000: Train Loss 0.0010 | Test Loss 0.0014 | accs Train 1.0000 Test 1.0000


Epoch 2000/5000: Train Loss 0.0009 | Test Loss 0.0014 | accs Train 1.0000 Test 1.0000


Epoch 2100/5000: Train Loss 0.0009 | Test Loss 0.0014 | accs Train 1.0000 Test 1.0000


Epoch 2200/5000: Train Loss 0.0009 | Test Loss 0.0013 | accs Train 1.0000 Test 1.0000


Epoch 2300/5000: Train Loss 0.0009 | Test Loss 0.0013 | accs Train 1.0000 Test 1.0000


Epoch 2400/5000: Train Loss 0.0009 | Test Loss 0.0013 | accs Train 1.0000 Test 1.0000


Epoch 2500/5000: Train Loss 0.0008 | Test Loss 0.0013 | accs Train 1.0000 Test 1.0000


Epoch 2600/5000: Train Loss 0.0008 | Test Loss 0.0012 | accs Train 1.0000 Test 1.0000


Epoch 2700/5000: Train Loss 0.0008 | Test Loss 0.0012 | accs Train 1.0000 Test 1.0000


Epoch 2800/5000: Train Loss 0.0008 | Test Loss 0.0012 | accs Train 1.0000 Test 1.0000


Epoch 2900/5000: Train Loss 0.0008 | Test Loss 0.0012 | accs Train 1.0000 Test 1.0000


Epoch 3000/5000: Train Loss 0.0008 | Test Loss 0.0012 | accs Train 1.0000 Test 1.0000


Epoch 3100/5000: Train Loss 0.0008 | Test Loss 0.0012 | accs Train 1.0000 Test 1.0000


Epoch 3200/5000: Train Loss 0.0008 | Test Loss 0.0012 | accs Train 1.0000 Test 1.0000


Epoch 3300/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 3400/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 3500/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 3600/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 3700/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 3800/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 3900/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 4000/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 4100/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 4200/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 4300/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 4400/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 4500/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 4600/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 4700/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 4800/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 4900/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 5000/5000: Train Loss 0.0007 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000
Completed training with 3669 additional samples of data
Epoch 1/5000: Train Loss 0.0122 | Test Loss 0.0117 | accs Train 0.0094 Test 0.0096


Epoch 100/5000: Train Loss 0.0096 | Test Loss 0.0101 | accs Train 0.5355 Test 0.3249


Epoch 200/5000: Train Loss 0.0084 | Test Loss 0.0093 | accs Train 0.8810 Test 0.5656


Epoch 300/5000: Train Loss 0.0063 | Test Loss 0.0071 | accs Train 1.0000 Test 0.9858


Epoch 400/5000: Train Loss 0.0041 | Test Loss 0.0046 | accs Train 1.0000 Test 1.0000


Epoch 500/5000: Train Loss 0.0025 | Test Loss 0.0028 | accs Train 1.0000 Test 1.0000


Epoch 600/5000: Train Loss 0.0018 | Test Loss 0.0020 | accs Train 1.0000 Test 1.0000


Epoch 700/5000: Train Loss 0.0015 | Test Loss 0.0017 | accs Train 1.0000 Test 1.0000


Epoch 800/5000: Train Loss 0.0014 | Test Loss 0.0015 | accs Train 1.0000 Test 1.0000


Epoch 900/5000: Train Loss 0.0013 | Test Loss 0.0014 | accs Train 1.0000 Test 1.0000


Epoch 1000/5000: Train Loss 0.0012 | Test Loss 0.0013 | accs Train 1.0000 Test 1.0000


Epoch 1100/5000: Train Loss 0.0012 | Test Loss 0.0013 | accs Train 1.0000 Test 1.0000


Epoch 1200/5000: Train Loss 0.0011 | Test Loss 0.0012 | accs Train 1.0000 Test 1.0000


Epoch 1300/5000: Train Loss 0.0011 | Test Loss 0.0012 | accs Train 1.0000 Test 1.0000


Epoch 1400/5000: Train Loss 0.0010 | Test Loss 0.0012 | accs Train 1.0000 Test 1.0000


Epoch 1500/5000: Train Loss 0.0010 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 1600/5000: Train Loss 0.0010 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 1700/5000: Train Loss 0.0010 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 1800/5000: Train Loss 0.0010 | Test Loss 0.0011 | accs Train 1.0000 Test 1.0000


Epoch 1900/5000: Train Loss 0.0009 | Test Loss 0.0010 | accs Train 1.0000 Test 1.0000


Epoch 2000/5000: Train Loss 0.0009 | Test Loss 0.0010 | accs Train 1.0000 Test 1.0000


Epoch 2100/5000: Train Loss 0.0009 | Test Loss 0.0010 | accs Train 1.0000 Test 1.0000


Epoch 2200/5000: Train Loss 0.0009 | Test Loss 0.0010 | accs Train 1.0000 Test 1.0000


Epoch 2300/5000: Train Loss 0.0009 | Test Loss 0.0010 | accs Train 1.0000 Test 1.0000


Epoch 2400/5000: Train Loss 0.0009 | Test Loss 0.0010 | accs Train 1.0000 Test 1.0000


Epoch 2500/5000: Train Loss 0.0009 | Test Loss 0.0010 | accs Train 1.0000 Test 1.0000


Epoch 2600/5000: Train Loss 0.0009 | Test Loss 0.0010 | accs Train 1.0000 Test 1.0000


Epoch 2700/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 2800/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 2900/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 3000/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 3100/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 3200/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 3300/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 3400/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 3500/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 3600/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 3700/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 3800/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 3900/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 4000/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 4100/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 4200/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 4300/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 4400/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 4500/5000: Train Loss 0.0008 | Test Loss 0.0009 | accs Train 1.0000 Test 1.0000


Epoch 4600/5000: Train Loss 0.0008 | Test Loss 0.0008 | accs Train 1.0000 Test 1.0000


Epoch 4700/5000: Train Loss 0.0008 | Test Loss 0.0008 | accs Train 1.0000 Test 1.0000


Epoch 4800/5000: Train Loss 0.0007 | Test Loss 0.0008 | accs Train 1.0000 Test 1.0000


Epoch 4900/5000: Train Loss 0.0007 | Test Loss 0.0008 | accs Train 1.0000 Test 1.0000


Epoch 5000/5000: Train Loss 0.0007 | Test Loss 0.0008 | accs Train 1.0000 Test 1.0000
Completed training with 5551 additional samples of data


## Training Summary

In [6]:
# ====================================
# Summary of Training Results
# ====================================

# Part 1: evaluate every trained model on the test split.
print("=== True Generalization ===")
for entry in all_models:
    eval_loss, eval_metrics = evaluate(
        model=entry["model"],
        x_test=x_test,
        y_test=y_test,
        loss_fn=loss_fn,
        metrics=other_metrics,
    )

    # Append the extra metrics only when evaluate() returned any.
    metric_text = " | ".join(f"{name}: {val:.4f}" for name, val in eval_metrics.items())
    suffix = f" | {metric_text}" if metric_text else ""
    print(f"{entry['additional_data']:>4} samples | Test Loss: {eval_loss:.4f}{suffix}")

# Part 2: evaluate every model on each of the training sets used above.
print("\n=== Model Diagnostics by Training Data ===")
for n_extra in dataset_quantities:
    # Rebuild the training set that contains this many additional samples.
    x_subset = torch.cat([x_base_train, x_additional[:n_extra]], dim=0)
    y_subset = torch.cat([y_base_train, y_additional[:n_extra]], dim=0)

    print(f"\nDataset type: {dataset_type}, additional samples: {n_extra}")

    for entry in all_models:
        # evaluate() takes the data through its x_test / y_test parameters,
        # even though a training subset is being scored here.
        subset_loss, subset_metrics = evaluate(
            model=entry["model"],
            x_test=x_subset,
            y_test=y_subset,
            loss_fn=loss_fn,
            metrics=other_metrics,
        )

        metric_text = " | ".join(f"{name}: {val:.4f}" for name, val in subset_metrics.items())
        suffix = f" | {metric_text}" if metric_text else ""
        print(f" Model {entry['additional_data']:>4} | Train Loss: {subset_loss:.4f}{suffix}")

    # Release the temporary tensors before building the next subset.
    del x_subset, y_subset
    torch.cuda.empty_cache()

=== True Generalization ===
   0 samples | Test Loss: 0.0153 | accs: 0.0116
 846 samples | Test Loss: 0.0374 | accs: 0.1051
1787 samples | Test Loss: 0.0187 | accs: 0.2045
3669 samples | Test Loss: 0.0011 | accs: 1.0000
5551 samples | Test Loss: 0.0008 | accs: 1.0000

=== Model Diagnostics by Training Data ===

Dataset type: data, additional samples: 0
 Model    0 | Train Loss: 0.0000 | accs: 1.0000
 Model  846 | Train Loss: 0.0006 | accs: 1.0000
 Model 1787 | Train Loss: 0.0037 | accs: 1.0000
 Model 3669 | Train Loss: 0.0007 | accs: 1.0000
 Model 5551 | Train Loss: 0.0008 | accs: 1.0000

Dataset type: data, additional samples: 846
 Model    0 | Train Loss: 0.0139 | accs: 0.1011
 Model  846 | Train Loss: 0.0006 | accs: 1.0000
 Model 1787 | Train Loss: 0.0038 | accs: 1.0000
 Model 3669 | Train Loss: 0.0007 | accs: 1.0000
 Model 5551 | Train Loss: 0.0007 | accs: 1.0000

Dataset type: data, additional samples: 1787
 Model    0 | Train Loss: 0.0147 | accs: 0.0510
 Model  846 | Train Loss: 

### Model + Data Specific Verification

In [7]:
# Hand the trained models and every data split to the model module's own
# verification routine.
verification_inputs = {
    "all_models": all_models,
    "x_base_train": x_base_train,
    "y_base_train": y_base_train,
    "x_additional": x_additional,
    "y_additional": y_additional,
    "x_test": x_test,
    "y_test": y_test,
    "dataset_quantities": dataset_quantities,
    "dataset_type": dataset_type,
}
model_module.verify_model_results(**verification_inputs)

## Model Saving

In [8]:
# ====================================
# Save Datasets and Models
# ====================================
# Honour the configured base directory instead of ignoring it. With
# base_output_dir == "" this resolves to plain "models_and_data", i.e. the
# location that was previously hard-coded.
output_folder = os.path.join(base_output_dir, "models_and_data")

# Save dataset (Possible to skip)
if save_generated_dataset:
    save_dataset(
        folder=output_folder,
        filename="dataset.pt",
        x_base_train=x_base_train,
        y_base_train=y_base_train,
        x_additional=x_additional,
        y_additional=y_additional,
        x_test=x_test,
        y_test=y_test,
        dataset_quantities=dataset_quantities,
        dataset_type=dataset_type,
    )
    print(f"Saved dataset to {output_folder}/dataset.pt")

# Save trained models, one file per entry of all_models.
if save_generated_models:
    for model_data in all_models:
        # The file name encodes how many additional samples the model saw.
        filename = f"model_additional_{model_data['additional_data']}.pt"
        # NOTE: the "train_accs" / "test_accs" keys exist only when the model
        # module reports a metric named "accs"; otherwise these lookups raise
        # KeyError.
        save_model(
            folder=output_folder,
            filename=filename,
            model=model_data["model"],
            train_loss=model_data["train_loss"],
            train_accs=model_data["train_accs"],
            test_loss=model_data["test_loss"],
            test_accs=model_data["test_accs"],
            additional_data=model_data["additional_data"],
            dataset_type=model_data["dataset_type"],
        )
        print(f"Saved model: {output_folder}/{filename}")

✅ Dataset saved to models_and_data\dataset.pt
Saved dataset to models_and_data/dataset.pt
✅ Model saved to models_and_data\model_additional_0.pt
Saved model: models_and_data/model_additional_0.pt
✅ Model saved to models_and_data\model_additional_846.pt
Saved model: models_and_data/model_additional_846.pt
✅ Model saved to models_and_data\model_additional_1787.pt
Saved model: models_and_data/model_additional_1787.pt
✅ Model saved to models_and_data\model_additional_3669.pt
Saved model: models_and_data/model_additional_3669.pt
✅ Model saved to models_and_data\model_additional_5551.pt
Saved model: models_and_data/model_additional_5551.pt
