In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Python module names cannot contain '-', so `from dml-py...` is a SyntaxError.
# NOTE(review): the importable package is assumed to be `dml_py` (the usual mapping
# of a hyphenated distribution name) — TODO confirm against the installed package.
from dml_py.trainers import DMLTrainer
from dml_py.models.cifar import resnet20, wrn_16_2
from dml_py.analysis import LossLandscape, RobustnessAnalyzer
from dml_py.analysis.visualization import plot_training_curves, plot_knowledge_flow

sns.set_style('whitegrid')

## 1. Train Models

In [None]:
# Prepare data
# Seed torch's global RNG (drawn on by torch weight initialisation and by the shuffling
# DataLoader) so that Restart & Run All reproduces the same models and curves.
SEED = 42
torch.manual_seed(SEED)

DATA_DIR = './data'
NUM_CLASSES = 10
EPOCHS = 30
TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 100

# Per-channel CIFAR-10 normalisation constants.
# NOTE(review): this widely copied std triple differs from the per-pixel training-set
# std (~0.247, 0.243, 0.261); kept unchanged so results match the original run.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

train_dataset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

# Train two different peer architectures together with DMLTrainer.
models = [resnet20(NUM_CLASSES), wrn_16_2(NUM_CLASSES)]
trainer = DMLTrainer(models, learning_rate=0.1, temperature=3)

# NOTE(review): this cell always retrains from scratch; checkpoints are presumably
# written under checkpoint_dir — consider reloading them on re-runs to save time.
results = trainer.train(
    train_loader,
    test_loader,
    epochs=EPOCHS,
    save_checkpoints=True,
    checkpoint_dir='checkpoints'
)

print("Training complete!")

## 2. Training Curves

In [None]:
# Test-accuracy history for each peer plus their average, taken from `results`.
accuracy_metrics = ['model_0_acc', 'model_1_acc', 'avg_acc']
plot_training_curves(results, metrics=accuracy_metrics,
                     title='DML Training - Test Accuracy',
                     save_path='training_accuracy.png')

# Loss history: the `train_loss` and `kl_loss` series recorded by the trainer.
loss_metrics = ['train_loss', 'kl_loss']
plot_training_curves(results, metrics=loss_metrics,
                     title='DML Training - Losses',
                     save_path='training_loss.png')

## 3. Loss Landscape Visualization

Visualize the loss surface around the first trained model (`trainer.models[0]`):

In [None]:
import torch.nn as nn

# `dml-py` is not a valid Python module name (SyntaxError).
# NOTE(review): assumed importable as `dml_py` — TODO confirm against the install.
from dml_py.analysis import LossLandscape

# Probe the loss surface around the first peer, scored with cross-entropy on test_loader.
landscape = LossLandscape(
    model=trainer.models[0],
    criterion=nn.CrossEntropyLoss(),
    dataloader=test_loader
)

# 1D slice: num_points evaluations for alpha in [alpha_min, alpha_max].
# NOTE(review): alpha presumably scales a perturbation direction in weight space —
# confirm against the LossLandscape documentation.
landscape.plot_1d(
    'loss_landscape_1d.png',
    alpha_min=-1.0,
    alpha_max=1.0,
    num_points=51
)

# 2D surface over (alpha, beta).
# NOTE(review): if every grid point evaluates the full test_loader, this costs
# num_points**2 (= 400) passes over the test set — reduce num_points for quick runs.
landscape.plot_2d(
    'loss_landscape_2d.png',
    alpha_min=-0.5,
    alpha_max=0.5,
    beta_min=-0.5,
    beta_max=0.5,
    num_points=20
)

print("✓ Loss landscape plots saved!")

## 4. Knowledge Flow Visualization

See how models transfer knowledge to each other:

In [None]:
# Render the knowledge-flow heatmap for the trained peers and write it to disk.
# NOTE(review): the plot is described as showing which peers learn most from which,
# the KL divergence between their predictions, and mutual-information flow —
# confirm against plot_knowledge_flow's documentation.
knowledge_flow_png = 'knowledge_flow.png'
plot_knowledge_flow(trainer, test_loader, save_path=knowledge_flow_png)

## 5. Model Agreement Analysis

In [None]:
import numpy as np

def compute_agreement(models, dataloader):
    """Return the percentage of samples on which every model predicts the same class.

    Args:
        models: Non-empty sequence of classifiers; each is called as ``model(inputs)``
            and its output is reduced with ``argmax(dim=1)``.
        dataloader: Iterable of ``(inputs, targets)`` batches; targets are ignored.

    Returns:
        Agreement rate in percent (0-100) as a float.

    Raises:
        ValueError: If ``models`` is empty or ``dataloader`` yields no samples.

    Note:
        Every model is switched to eval mode and left there.
    """
    models = list(models)
    if not models:
        raise ValueError("compute_agreement needs at least one model")

    def _device_of(model):
        # Feed inputs to wherever the model's weights live (a CPU model on a CUDA
        # machine would otherwise crash); parameter-free modules fall back to the
        # GPU-if-available rule.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            return first_param.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    devices = []
    for model in models:
        model.eval()
        devices.append(_device_of(model))

    n_agree = 0
    n_total = 0
    with torch.no_grad():
        for inputs, _targets in dataloader:
            # (n_models, batch) predictions, gathered on CPU so peers on different
            # devices compare cleanly.
            preds = torch.stack([
                model(inputs.to(device)).argmax(dim=1).cpu()
                for model, device in zip(models, devices)
            ])
            # Vectorised: a sample is agreed when every row equals the first model's row.
            agreed = preds.eq(preds[0]).all(dim=0)
            n_agree += int(agreed.sum())
            n_total += agreed.numel()

    if n_total == 0:
        raise ValueError("dataloader yielded no samples")
    return 100.0 * n_agree / n_total

agreement = compute_agreement(trainer.models, test_loader)
disagreement = 100 - agreement
print(f"Model agreement: {agreement:.2f}%")
print(f"Disagreement: {disagreement:.2f}%")

## 6. Robustness Analysis

In [None]:
# `dml-py` is not a valid Python module name (SyntaxError).
# NOTE(review): assumed importable as `dml_py` — TODO confirm against the install.
from dml_py.analysis import RobustnessAnalyzer

# Accuracy of the first peer as input noise grows.
# NOTE(review): noise_std is forwarded to evaluate_with_noise; whether the noise is
# Gaussian and applied before or after normalisation is not visible here — confirm.
analyzer = RobustnessAnalyzer(trainer.models[0])

noise_levels = [0.0, 0.05, 0.1, 0.15, 0.2]
accuracies = []

for noise in noise_levels:
    acc = analyzer.evaluate_with_noise(test_loader, noise_std=noise)
    accuracies.append(acc)
    print(f"Noise std={noise:.2f}: Accuracy={acc:.2f}%")

# Robustness curve (explicit figure/axes interface).
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(noise_levels, accuracies, 'b-o', linewidth=2, markersize=8)
ax.set_xlabel('Noise Standard Deviation', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('Model Robustness to Input Noise', fontsize=14)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig('robustness_plot.png', dpi=150)
plt.show()

## 7. Prediction Confidence Analysis

In [None]:
import torch.nn.functional as F

def analyze_confidence(model, dataloader):
    """Split the model's top-1 softmax confidence by prediction correctness.

    Args:
        model: Classifier called as ``model(inputs)``; its output is softmaxed over dim 1.
        dataloader: Iterable of ``(inputs, targets)`` batches with integer class targets.

    Returns:
        ``(confidences_correct, confidences_incorrect)``: two lists holding the max
        softmax probability of each correctly / incorrectly classified sample.

    Note:
        The model is switched to eval mode and left there.
    """
    confidences_correct = []
    confidences_incorrect = []

    # Run on whatever device holds the weights (a CPU model on a CUDA machine would
    # otherwise crash); parameter-free modules fall back to the GPU-if-available rule.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            probs = F.softmax(model(inputs), dim=1)
            max_probs, preds = probs.max(dim=1)

            # One host transfer per tensor, then split by correctness on the CPU.
            max_probs_np = max_probs.cpu().numpy()
            correct_np = (preds == targets).cpu().numpy()
            confidences_correct.extend(max_probs_np[correct_np])
            confidences_incorrect.extend(max_probs_np[~correct_np])

    return confidences_correct, confidences_incorrect

# Analyze both models
conf_correct_1, conf_incorrect_1 = analyze_confidence(trainer.models[0], test_loader)
conf_correct_2, conf_incorrect_2 = analyze_confidence(trainer.models[1], test_loader)

# Plot confidence distributions
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Model 1
axes[0].hist(conf_correct_1, bins=50, alpha=0.7, label='Correct', color='green')
axes[0].hist(conf_incorrect_1, bins=50, alpha=0.7, label='Incorrect', color='red')
axes[0].set_xlabel('Confidence', fontsize=12)
axes[0].set_ylabel('Count', fontsize=12)
axes[0].set_title('Model 1 - Prediction Confidence', fontsize=14)
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Model 2
axes[1].hist(conf_correct_2, bins=50, alpha=0.7, label='Correct', color='green')
axes[1].hist(conf_incorrect_2, bins=50, alpha=0.7, label='Incorrect', color='red')
axes[1].set_xlabel('Confidence', fontsize=12)
axes[1].set_ylabel('Count', fontsize=12)
axes[1].set_title('Model 2 - Prediction Confidence', fontsize=14)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('confidence_analysis.png', dpi=150)
plt.show()

print(f"\nModel 1 - Avg confidence (correct): {np.mean(conf_correct_1):.3f}")
print(f"Model 1 - Avg confidence (incorrect): {np.mean(conf_incorrect_1):.3f}")
print(f"\nModel 2 - Avg confidence (correct): {np.mean(conf_correct_2):.3f}")
print(f"Model 2 - Avg confidence (incorrect): {np.mean(conf_incorrect_2):.3f}")

## Summary

You've learned how to:

✅ Visualize training curves and metrics  
✅ Plot loss landscapes  
✅ Analyze knowledge flow between models  
✅ Measure model agreement  
✅ Test robustness to noise  
✅ Analyze prediction confidence  

These tools help you:
- Understand what models are learning
- Debug training issues
- Compare different approaches
- Ensure model quality

## Next Steps

- Apply these analyses to your own models
- Create custom visualization functions
- Check out `05_deployment.ipynb` for production deployment