#### Training TODOs
- [x] MLP
- [x] CIFAR-10 / CIFAR-100
- [ ] ImageNet
- [ ] NLP Models & Datasets

**Environment:** Use the standard `train-viz` environment for this notebook.

# Training
### with Logging and Live-Visualization

This notebook trains a selected model on a selected dataset.
The training routine provides the following features:
- **Setup & Reproducibility**: seeds for `torch`, `numpy`, `random`; deterministic cuDNN; automatic device selection (CPU/GPU).  
- **Optimizer & SAM**: configures SGD/AdamW, weight decay, and optional Sharpness-Aware Minimization (`use_sam`, `rho`).  
- **Metrics Tracking**: epoch-wise train/val losses & accuracies; scheduler LR history.  
- **Early Stopping**: halts after `patience` epochs without val-loss improvement.  
- **Weight Snapshots**: saves flattened weights after each epoch (`save_model_weights_each_epoch`) for loss-landscape visualization.  
- **Embedding Snapshots**: capture embeddings from a fixed subset at configurable intervals (`embedding_records_per_epoch`).  
- **Embedding Drift**: compute latent-space drift between successive snapshots.  
- **Gradient Stats**: logs gradient norms, maximum gradient, and gradient/parameter ratios at batch intervals.  
- **Live Plots**: dynamic subplots for loss/accuracy, gradients, LR schedule, drift, and PCA, plus a placeholder for cosine similarity.  
- **Logging**: INFO-level history printed each epoch.  
- **Return Payload**: dict with all histories, embeddings, drifts, gradient stats, LR schedule, weight-snapshot dir, model repr, and config.  
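
To make the early-stopping rule in the list above concrete, here is a minimal sketch that counts epochs without validation-loss improvement. The `EarlyStopping` class and its `min_delta` argument are hypothetical names chosen for illustration; only `patience` comes from this notebook, and the actual logic in `helper.train` may differ.

```python
class EarlyStopping:
    """Hypothetical helper: stop after `patience` epochs without val-loss improvement."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # assumed tolerance, not part of the notebook
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch and return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Toy validation losses: no strict improvement after the third epoch.
val_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.60, 0.63]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
```

With `patience=3`, the loop stops at the third consecutive epoch that fails to beat the best loss seen so far.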


In [None]:
%matplotlib widget
# %load_ext autoreload
# %autoreload 2

import torch
import numpy as np
import random

# Reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__)
print(torch.version.cuda)  # CUDA version of this build, e.g. '11.8'; None on CPU-only builds
print("Using device:", device)

### Choose Dataset

In [None]:
from helper.vision_classification import init_dataset

#dataset = "mnist"
dataset = "cifar10"
#dataset = "cifar100"

train_loader, test_loader, subset_loader = init_dataset(dataset, samples_per_class=100 if dataset != "cifar100" else 10)
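
The `samples_per_class` argument suggests that `subset_loader` draws a fixed, class-balanced subset for embedding tracking. How `init_dataset` builds it is not shown in this notebook; the sketch below is one plausible approach on a plain list of labels, and `balanced_subset_indices` is a hypothetical helper, not part of `helper.vision_classification`.

```python
import random
from collections import defaultdict


def balanced_subset_indices(labels, samples_per_class, seed=42):
    """Hypothetical helper: pick up to `samples_per_class` indices per label."""
    rng = random.Random(seed)  # local RNG keeps the subset identical across runs
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    chosen = []
    for label in sorted(by_class):
        indices = by_class[label]
        k = min(samples_per_class, len(indices))
        chosen.extend(rng.sample(indices, k))
    return sorted(chosen)


# Toy labels: 3 classes with 5 samples each.
labels = [0, 1, 2] * 5
subset = balanced_subset_indices(labels, samples_per_class=2)
```

In a PyTorch setting, indices like these could be passed to `torch.utils.data.Subset` and wrapped in a non-shuffling `DataLoader`, so every snapshot sees the same samples in the same order.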

### Choose model

In [None]:
from helper.vision_classification import (
    init_mlp_for_dataset,
    init_cnn_for_dataset,
    init_vit_for_dataset,
    init_resnet_for_dataset,
    init_densenet_for_dataset,
)

#model = init_mlp_for_dataset(dataset, hidden_dims=[512, 254, 128], dropout=0.1).to(device)
#model = init_mlp_for_dataset(dataset, hidden_dims=[254, 64], dropout=0.1).to(device)

model = init_cnn_for_dataset(dataset, conv_dims=[64, 128, 256], kernel_sizes=[5, 3, 3], hidden_dims=[256, 128], dropout=0.2, residual=True).to(device)
#model = init_cnn_for_dataset(dataset, conv_dims=[64, 128, 256, 512], kernel_sizes=[5, 3, 3, 3], hidden_dims=[512, 256], dropout=0.2).to(device)
#model = init_cnn_for_dataset(dataset, conv_dims=[128, 256, 512, 1024], kernel_sizes=[5, 3, 3, 3], hidden_dims=[1024, 256], dropout=0.2).to(device)

#model = init_vit_for_dataset(dataset, emb_dim=32, depth=4, num_heads=4, mlp_dim=128, dropout=0.15, patch_size=4).to(device)
#model = init_vit_for_dataset(dataset, emb_dim=64, depth=6, num_heads=8, mlp_dim=128, dropout=0.1, patch_size=7).to(device)

# model = init_vit_for_dataset(dataset, emb_dim=192, depth=6, num_heads=6, mlp_dim=256, dropout=0.1, patch_size=4).to(device) #CIFAR100
#model = init_vit_for_dataset(dataset, emb_dim=256, depth=12, num_heads=8, mlp_dim=512, dropout=0.15).to(device)

#model = init_resnet_for_dataset(dataset, fc_hidden_dims=[128], dropout=0.2).to(device)

#model = init_densenet_for_dataset(dataset, fc_hidden_dims=[128], dropout=0.2).to(device)

In [None]:
repr(model)

### Choose Optimizer

In [None]:
# Optional: custom optimizer that overrides the default one (pass it via `optimizer=` in the training call)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.0005,
    momentum=0.9,
    weight_decay=0.05,
    nesterov=False
)

In [None]:
# Sharpness-Aware Minimization (SAM): set this flag to train with the SAM variant of SGD
use_sam = False
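
For readers unfamiliar with SAM, the following framework-free sketch shows the usual two-step update: an ascent step of radius `rho` along the normalized gradient, then a descent step using the gradient taken at that perturbed point. It works on plain Python lists and a toy quadratic loss; `sam_step` and `quadratic_grad` are hypothetical names, and how `helper.train` implements `use_sam` is not shown in this notebook.

```python
import math


def sam_step(weights, grad_fn, lr=0.1, rho=0.05):
    """One SAM update on a list of floats (illustrative only).

    grad_fn(weights) must return the loss gradient as a list of floats.
    """
    # 1) Gradient at the current weights.
    grad = grad_fn(weights)
    grad_norm = math.sqrt(sum(g * g for g in grad)) + 1e-12

    # 2) Ascent step: move to the approximate worst case within radius rho.
    perturbed = [w + rho * g / grad_norm for w, g in zip(weights, grad)]

    # 3) The gradient at the perturbed point drives the descent step,
    #    which is applied to the original (unperturbed) weights.
    sharp_grad = grad_fn(perturbed)
    return [w - lr * g for w, g in zip(weights, sharp_grad)]


def quadratic_grad(weights):
    # Gradient of the toy loss L(w) = w0^2 + 10 * w1^2.
    return [2.0 * weights[0], 20.0 * weights[1]]


weights = [1.0, 1.0]
for _ in range(100):
    weights = sam_step(weights, quadratic_grad, lr=0.05, rho=0.05)
```

In a PyTorch training loop the same idea costs two forward/backward passes per batch, because the gradient has to be evaluated once at the current weights and once at the perturbed weights.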

# Training Loop

In [None]:
%matplotlib inline

from helper.train import train_model_with_embedding_tracking

#optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) #0.001
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7)
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

results = train_model_with_embedding_tracking(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    subset_loader=subset_loader, #test_subset
    num_classes=100 if dataset == "cifar100" else 10,
    device=device,
    epochs=50, #Max 50
    learning_rate=0.001, # 0.0001
    
    save_model_weights_each_epoch=True, # For Loss Landscape (Neuro-Visualizer)
    embedding_records_per_epoch=4,
    track_embedding_drift=True,
    track_scheduled_lr=True,
    track_pca=False,
    use_sam=use_sam,
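    patience=10,
    # optimizer=optimizer,  # uncomment to use the custom optimizer defined above
    # scheduler=scheduler,  # requires one of the scheduler definitions above to be uncommented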
)
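
The `track_embedding_drift=True` flag asks the training routine to measure how far the subset embeddings move between successive snapshots. The exact metric is not defined in this notebook; the sketch below assumes the mean Euclidean distance between matching rows of two snapshots, and `embedding_drift` is a hypothetical helper rather than the function used in `helper.train`.

```python
import math


def embedding_drift(previous, current):
    """Mean Euclidean distance between matching rows of two snapshots.

    Assumption: both snapshots hold embeddings of the same samples in the
    same order, given as lists of equal-length float vectors.
    """
    if not previous or len(previous) != len(current):
        raise ValueError("snapshots must be non-empty and of equal size")
    distances = [math.dist(p, c) for p, c in zip(previous, current)]
    return sum(distances) / len(distances)


# Three toy snapshots of two 2-D embeddings.
snapshots = [
    [[0.0, 0.0], [1.0, 1.0]],
    [[3.0, 4.0], [1.0, 1.0]],
    [[3.0, 4.0], [1.0, 2.0]],
]
drifts = [embedding_drift(a, b) for a, b in zip(snapshots, snapshots[1:])]
```

Each entry of `drifts` compares one snapshot with its predecessor, so `n` snapshots yield `n - 1` drift values. This is also why the tracked subset needs a fixed sample order.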

## Export

In [None]:
from helper.data_manager import save_training_data

run_folder = results["ll_flattened_weights_dir"]
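# `emb_dim` may be missing on some architectures (e.g. the CNN selected above); fall back to the class name
model_tag = getattr(model, "emb_dim", type(model).__name__)
run = f"{run_folder}_{dataset}_{model_tag}_{max(results['val_accuracies']):.4f}"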
save_training_data(run, results)
print(run)