In [None]:
import sys
import os
import time
import torch
import pickle
from datetime import datetime
from algorithms.nn.utils import seed_everything, create_resnet18, load_cifar10
from algorithms.nn.trainers import train_sgd, train_sgda

# Seed once at notebook start; each experiment helper below re-seeds with the
# same value before it builds its model.
seed_everything(42)

# Use the GPU when PyTorch reports one as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

def save_run(metrics, algorithm_name, hyperparameters, base_dir='./results/runs'):
    """
    Save a single training run to a pickle file.

    The file is named ``<algorithm>_<params>_<timestamp>.pkl``. ``<params>`` is
    built from every hyperparameter whose value is a number or a list/tuple:
    each contributes ``<key><value>`` (sequence items joined with ``-``) and the
    entries are joined with ``+``. Values of any other type (strings, None, ...)
    are left out of the filename but are still stored inside the pickle. When no
    hyperparameter qualifies, the name is ``<algorithm>_<timestamp>.pkl``.

    Args:
        metrics: Object exposing the attributes ``train_losses``, ``train_accs``,
            ``test_losses``, ``test_accs``, ``learning_rates``, ``epoch_times``
            and ``batch_losses``; each is stored under the same key.
        algorithm_name: Label stored under ``'name'``; lower-cased in the filename.
        hyperparameters: Mapping of hyperparameter name to value, stored as given.
        base_dir: Output directory; created if it does not exist.

    Returns:
        The path of the pickle file that was written.
    """
    os.makedirs(base_dir, exist_ok=True)

    # One-second resolution: two saves that produce the same name within the
    # same second write to the same path, and the later one overwrites.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Create a descriptive filename
    parts = []
    for k, v in hyperparameters.items():
        if isinstance(v, (int, float)):
            parts.append(f"{k}{v}")
        elif isinstance(v, (list, tuple)):
            # Handle sequences (like lr_milestones) by joining with hyphens.
            # Tuples are accepted too so they are not silently dropped from the name.
            val_str = "-".join(map(str, v))
            parts.append(f"{k}{val_str}")

    params_str = "+".join(parts)

    if params_str:
        filename = f"{algorithm_name.lower()}_{params_str}_{timestamp}.pkl"
    else:
        filename = f"{algorithm_name.lower()}_{timestamp}.pkl"

    run_data = {
        'name': algorithm_name,
        'timestamp': timestamp,
        'metrics': {
            'train_losses': metrics.train_losses,
            'train_accs': metrics.train_accs,
            'test_losses': metrics.test_losses,
            'test_accs': metrics.test_accs,
            'learning_rates': metrics.learning_rates,
            'epoch_times': metrics.epoch_times,
            'batch_losses': metrics.batch_losses,
        },
        'hyperparameters': hyperparameters
    }

    filepath = os.path.join(base_dir, filename)
    with open(filepath, 'wb') as f:
        pickle.dump(run_data, f)

    print(f"Run saved to {filepath}")
    return filepath

def run_sgd_experiment(
    train_loader, test_loader,
    lr=0.1, momentum=0.9, weight_decay=5e-4,
    lr_milestones=[50, 125, 175], lr_gamma=0.1,
    n_epochs=200, batch_size=128, train_pct=1.0
):
    """
    Train a freshly created model with SGD and save the run.

    Re-seeds with 42, builds a model via ``create_resnet18()`` on the global
    ``device``, trains it with ``train_sgd`` and writes the result with
    ``save_run`` under the algorithm name ``"SGD"``.

    Args:
        train_loader, test_loader: Data loaders forwarded to ``train_sgd``.
        lr, momentum, weight_decay, lr_gamma, n_epochs: Forwarded to ``train_sgd``.
        lr_milestones: Sequence of milestone epochs forwarded to ``train_sgd``.
            It is copied into a new list before use.
        batch_size, train_pct: Only recorded in the saved hyperparameters; this
            function does not use them to build or change the loaders.

    Returns:
        The file path returned by ``save_run``.
    """
    seed_everything(42)
    model = create_resnet18().to(device)

    # Work on a private list: the default above is one shared object, and both
    # train_sgd and the saved record receive this value. Copying keeps the default
    # (and the caller's list) untouched and turns tuples into lists so save_run
    # includes them in the filename.
    lr_milestones = list(lr_milestones)

    params = {
        'lr': lr,
        'momentum': momentum,
        'weight_decay': weight_decay,
        'lr_milestones': lr_milestones,
        'lr_gamma': lr_gamma,
        'n_epochs': n_epochs,
        'batch_size': batch_size,
        'train_pct': train_pct
    }

    print("=" * 80)
    print(f"TRAINING WITH SGD {params}")
    print("=" * 80)

    metrics = train_sgd(
        model, train_loader, test_loader, device,
        n_epochs=n_epochs, lr=lr, momentum=momentum, weight_decay=weight_decay,
        lr_milestones=lr_milestones, lr_gamma=lr_gamma
    )

    return save_run(metrics, "SGD", params)

def run_sgda_experiment(
    train_loader, test_loader,
    lr=0.1, sigma=0.1, kappa=0.75, momentum=0.9, weight_decay=5e-4,
    n_epochs=200, batch_size=128, train_pct=1.0, warmup_epochs=0
):
    """
    Train a freshly created model with SGDA and save the run.

    Re-seeds with 42, builds a model via ``create_resnet18()`` on the global
    ``device``, trains it with ``train_sgda`` and writes the result with
    ``save_run`` under the algorithm name ``"SGDA"``.

    Args:
        train_loader, test_loader: Data loaders forwarded to ``train_sgda``.
        lr, sigma, kappa, momentum, weight_decay, n_epochs, warmup_epochs:
            Forwarded to ``train_sgda``.
        batch_size, train_pct: Only recorded in the saved hyperparameters; this
            function does not use them to build or change the loaders.

    Returns:
        The file path returned by ``save_run``.
    """
    seed_everything(42)
    model = create_resnet18().to(device)

    params = {
        'lr': lr,
        'sigma': sigma,
        'kappa': kappa,
        'momentum': momentum,
        'weight_decay': weight_decay,
        'n_epochs': n_epochs,
        'batch_size': batch_size,
        'train_pct': train_pct,
        # Recorded because it is passed to train_sgda below; without it, runs
        # that differ only in warm-up length are indistinguishable once saved.
        'warmup_epochs': warmup_epochs
    }

    print("=" * 80)
    print(f"TRAINING WITH SGDA {params}")
    print("=" * 80)

    metrics = train_sgda(
        model, train_loader, test_loader, device,
        n_epochs=n_epochs, lr=lr, sigma=sigma, kappa=kappa,
        momentum=momentum, weight_decay=weight_decay, warmup_epochs=warmup_epochs
    )

    return save_run(metrics, "SGDA", params)

## Hyperparameters

In [None]:
# Settings shared by every experiment cell in this notebook
TRAIN_PCT = 1.0  # 1.0 = use all of the training data
TEST_PCT = 1.0   # 1.0 = use all of the test data
BATCH_SIZE = 128
N_EPOCHS = 100

print("Common hyperparameters defined.")

## Load Data

In [None]:
# Build both loaders once from the shared settings; every experiment cell reuses them.
train_loader, test_loader = load_cifar10(
    batch_size=BATCH_SIZE,
    train_pct=TRAIN_PCT,
    test_pct=TEST_PCT,
)
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")

# Experiments

## SGD

### Fixed LR = 0.1

In [None]:
# SGD settings for the constant-learning-rate run at 0.1
SGD_LR = 0.1
SGD_MOMENTUM = 0.9
SGD_WEIGHT_DECAY = 5e-4
# The only milestone (201) lies beyond N_EPOCHS (100 in the config cell) —
# presumably so the step decay never triggers; confirm against train_sgd.
SGD_LR_MILESTONES = [201]
SGD_LR_GAMMA = 0.1

run_sgd_experiment(
    train_loader,
    test_loader,
    lr=SGD_LR,
    momentum=SGD_MOMENTUM,
    weight_decay=SGD_WEIGHT_DECAY,
    lr_milestones=SGD_LR_MILESTONES,
    lr_gamma=SGD_LR_GAMMA,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    train_pct=TRAIN_PCT,
)

### Fixed LR = 0.01

In [None]:
# SGD settings for the constant-learning-rate run at 0.01
SGD_LR = 0.01
SGD_MOMENTUM = 0.9
SGD_WEIGHT_DECAY = 5e-4
# The only milestone (201) lies beyond N_EPOCHS (100 in the config cell) —
# presumably so the step decay never triggers; confirm against train_sgd.
SGD_LR_MILESTONES = [201]
SGD_LR_GAMMA = 0.1

run_sgd_experiment(
    train_loader,
    test_loader,
    lr=SGD_LR,
    momentum=SGD_MOMENTUM,
    weight_decay=SGD_WEIGHT_DECAY,
    lr_milestones=SGD_LR_MILESTONES,
    lr_gamma=SGD_LR_GAMMA,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    train_pct=TRAIN_PCT,
)

### Multistep LR = 0.1, Milestones = [40, 75], LR gamma = 0.1

In [None]:
# SGD settings for the multistep run: start at 0.1, milestones at epochs 40 and 75,
# with lr_gamma = 0.1
SGD_LR = 0.1
SGD_MOMENTUM = 0.9
SGD_WEIGHT_DECAY = 5e-4
SGD_LR_MILESTONES = [40, 75]
SGD_LR_GAMMA = 0.1

run_sgd_experiment(
    train_loader,
    test_loader,
    lr=SGD_LR,
    momentum=SGD_MOMENTUM,
    weight_decay=SGD_WEIGHT_DECAY,
    lr_milestones=SGD_LR_MILESTONES,
    lr_gamma=SGD_LR_GAMMA,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    train_pct=TRAIN_PCT,
)

### Multistep LR = 0.1, Milestones = [40, 75], LR gamma = 0.1, No momentum

In [None]:
# SGD settings for the multistep run with momentum switched off (momentum = 0);
# the remaining values match the multistep run with momentum.
SGD_LR = 0.1
SGD_MOMENTUM = 0
SGD_WEIGHT_DECAY = 5e-4
SGD_LR_MILESTONES = [40, 75]
SGD_LR_GAMMA = 0.1

run_sgd_experiment(
    train_loader,
    test_loader,
    lr=SGD_LR,
    momentum=SGD_MOMENTUM,
    weight_decay=SGD_WEIGHT_DECAY,
    lr_milestones=SGD_LR_MILESTONES,
    lr_gamma=SGD_LR_GAMMA,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    train_pct=TRAIN_PCT,
)

## SGDA

### LR = 0.5, SIGMA = 0.1, KAPPA = 0.75, 2 WARMUP EPOCHS

In [None]:
# SGDA settings
SGDA_LR = 0.5             # initial learning rate
SGDA_SIGMA = 0.1          # Armijo condition parameter
SGDA_KAPPA = 0.75         # learning rate reduction factor
SGDA_MOMENTUM = 0         # kept at 0: momentum > 0 will cause instability
SGDA_WEIGHT_DECAY = 5e-4  # weight decay
SGDA_WARMUP_EPOCHS = 2    # number of warm-up epochs

run_sgda_experiment(
    train_loader,
    test_loader,
    lr=SGDA_LR,
    sigma=SGDA_SIGMA,
    kappa=SGDA_KAPPA,
    momentum=SGDA_MOMENTUM,
    weight_decay=SGDA_WEIGHT_DECAY,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    train_pct=TRAIN_PCT,
    warmup_epochs=SGDA_WARMUP_EPOCHS,
)