In [None]:
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Hyphens are not valid in Python module names, so the library is imported as
# `dml_py` (the importable form of "dml-py") — confirm against the installed package.
from dml_py.trainers import DMLTrainer
from dml_py.models.cifar import resnet20, wrn_16_2, mobilenet_v2
from dml_py.strategies import CurriculumLearning, PeerSelection, TemperatureScaling
from dml_py.utils import AMPConfig, apply_amp_to_trainer, HyperparameterSpace, RandomSearcher

## 1. Curriculum Learning

Start with easy examples and gradually increase difficulty:

In [None]:
# Seed PyTorch's global RNG so weight initialisation and data shuffling are
# repeatable across Restart & Run All (some GPU kernels may still be
# nondeterministic).
SEED = 0
torch.manual_seed(SEED)

# Prepare data: convert images to tensors and normalise each channel with the
# commonly used CIFAR-10 per-channel mean/std values.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

# Curriculum strategy: difficulty_fn returns epoch / 30 capped at 1.0, so it
# reaches full difficulty at epoch 30 and stays there for the remaining
# epochs of the 50-epoch run below.
curriculum = CurriculumLearning(
    num_stages=3,
    difficulty_fn=lambda epoch: min(1.0, epoch / 30),
    start_easy=True,
)

# Two peers of different architectures learn from each other.
models = [resnet20(10), wrn_16_2(10)]
trainer = DMLTrainer(models)

# Attach the curriculum to the trainer before training starts.
curriculum.apply(trainer)

results = trainer.train(train_loader, test_loader, epochs=50)
print(f"Final accuracy with curriculum: {results['avg_acc'][-1]:.2f}%")

## 2. Peer Selection

Adaptively select which peers to learn from based on performance:

In [None]:
# Four models built from three architectures (ResNet-20 appears twice).
models = [
    resnet20(10),
    wrn_16_2(10),
    mobilenet_v2(10),
    resnet20(10),  # second ResNet-20
]

# Peer selection: each model learns only from its top_k peers, with the
# ranking refreshed every update_frequency epochs.
peer_selector = PeerSelection(
    selection_strategy='performance',  # rank peers by performance
    top_k=2,                           # learn from the top 2 peers only
    update_frequency=5,                # re-rank peers every 5 epochs
)

trainer = DMLTrainer(models)
peer_selector.apply(trainer)

results = trainer.train(train_loader, test_loader, epochs=30)

# Show which peers were selected at each ranking update.
print("\nPeer selection history:")
for epoch, selected in peer_selector.selection_history.items():
    print(f"Epoch {epoch}: Selected peers {selected}")

## 3. Temperature Scaling

Dynamically adjust the distillation temperature during training:

In [None]:
# Endpoints of the temperature schedule. The trainer's base temperature is set
# to the starting value so the two cannot drift apart.
INITIAL_TEMPERATURE = 4.0
FINAL_TEMPERATURE = 2.0

temp_scheduler = TemperatureScaling(
    initial_temperature=INITIAL_TEMPERATURE,
    final_temperature=FINAL_TEMPERATURE,
    schedule='cosine',  # cosine annealing
    warmup_epochs=5,
)

models = [resnet20(10), wrn_16_2(10)]
trainer = DMLTrainer(models, temperature=INITIAL_TEMPERATURE)

# Attach the schedule to the trainer before training starts.
temp_scheduler.apply(trainer)

results = trainer.train(train_loader, test_loader, epochs=50)

# Plot the temperatures recorded by the scheduler.
# NOTE(review): the x-axis label assumes one history entry per epoch — confirm.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(temp_scheduler.temperature_history, linewidth=2)
ax.set(xlabel='Epoch', ylabel='Temperature', title='Temperature Schedule')
ax.grid(True, alpha=0.3)
plt.show()

## 4. Mixed Precision Training

Train faster with automatic mixed precision (AMP). The speedup depends on the hardware and is largest on GPUs with fast float16 support:

In [None]:
N_BENCH_EPOCHS = 10


def timed_train(trainer, epochs=N_BENCH_EPOCHS):
    """Run ``trainer.train`` on the CIFAR-10 loaders and return elapsed seconds.

    CUDA work is launched asynchronously, so the device is synchronised before
    each clock read to make sure all queued GPU work is counted.
    ``time.perf_counter`` is used because, unlike ``time.time``, it is
    monotonic and intended for measuring intervals.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    trainer.train(train_loader, test_loader, epochs=epochs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter() - start


# Caveat: the first run may also pay one-off costs (CUDA context creation,
# kernel autotuning, OS file caching), which slightly favours the second run.

# Baseline: train WITHOUT AMP (full FP32).
models = [resnet20(10), wrn_16_2(10)]
trainer_fp32 = DMLTrainer(models)
time_fp32 = timed_train(trainer_fp32)

# Train WITH AMP, using freshly constructed models so both runs start from scratch.
models = [resnet20(10), wrn_16_2(10)]
amp_config = AMPConfig(
    enabled=True,
    dtype=torch.float16,
)
trainer_amp = apply_amp_to_trainer(DMLTrainer(models), amp_config)
time_amp = timed_train(trainer_amp)

# Compare wall-clock times.
speedup = time_fp32 / time_amp
print("\nTraining Time:")
print(f"  FP32: {time_fp32:.1f}s")
print(f"  AMP:  {time_amp:.1f}s")
print(f"  Speedup: {speedup:.2f}x")

## 5. Combining Strategies

Apply several strategies to the same trainer (verify empirically that the combination outperforms the individual strategies):

In [None]:
# Three architectures of different capacity train together as peers.
models = [resnet20(10), wrn_16_2(10), mobilenet_v2(10)]

# A single trainer shared by every strategy below; its base temperature
# matches the starting point of the temperature schedule.
trainer = DMLTrainer(models, learning_rate=0.1, temperature=4.0)

# Strategies are attached one after another, in the same order as the
# sections above: curriculum -> peer selection -> temperature schedule.
curriculum = CurriculumLearning(num_stages=3)
curriculum.apply(trainer)

# NOTE(review): with three models each has only two peers, so top_k=2
# presumably selects every peer and peer selection has no effect here —
# confirm; top_k=1 would make the selection meaningful.
peer_selector = PeerSelection(top_k=2, selection_strategy='performance')
peer_selector.apply(trainer)

temp_scheduler = TemperatureScaling(
    schedule='cosine',
    initial_temperature=4.0,
    final_temperature=2.0,
)
temp_scheduler.apply(trainer)

# Mixed precision goes on last; the helper hands back the trainer to use.
trainer = apply_amp_to_trainer(trainer, AMPConfig(enabled=True))

# Train with every strategy active.
print("Training with combined strategies...")
results = trainer.train(train_loader, test_loader, epochs=50)

print(f"\n✓ Final accuracy: {results['avg_acc'][-1]:.2f}%")

## 6. Hyperparameter Search

Automatically search for good hyperparameters with random search:

In [None]:
# Tune on a validation split carved out of the training set. Choosing
# hyperparameters by test accuracy would leak the test set into model
# selection and make the reported test accuracy optimistically biased.
VAL_SIZE = 5_000
search_train_subset, search_val_subset = random_split(
    train_dataset,
    [len(train_dataset) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(0),  # same split on every run
)
search_train_loader = DataLoader(search_train_subset, batch_size=128, shuffle=True)
search_val_loader = DataLoader(search_val_subset, batch_size=100, shuffle=False)

# Candidate values for each hyperparameter.
space = HyperparameterSpace({
    'learning_rate': [0.01, 0.05, 0.1],
    'temperature': [2, 3, 4, 5],
    'kl_weight': [0.5, 1.0, 2.0],
})


def objective(config):
    """Train a fresh ResNet-20 / WRN-16-2 pair with ``config`` and score it.

    Args:
        config: mapping of hyperparameter names to values, passed straight to
            ``DMLTrainer`` as keyword arguments.

    Returns:
        The final entry of ``results['avg_acc']`` from a 10-epoch run.
        NOTE(review): assumes ``avg_acc`` is measured on the second loader
        passed to ``train`` (here the validation split) — confirm.
    """
    models = [resnet20(10), wrn_16_2(10)]
    trainer = DMLTrainer(models, **config)
    results = trainer.train(search_train_loader, search_val_loader, epochs=10)
    return results['avg_acc'][-1]


# Run the search.
searcher = RandomSearcher(objective, 'accuracy', 'maximize')
best_config = searcher.search(space, n_trials=10)

print(f"\nBest configuration: {best_config}")
searcher.save_results('search_results.json')
searcher.plot_results('search_plot.png')

## Summary

You've learned:

✅ **Curriculum Learning** - Gradual difficulty increase  
✅ **Peer Selection** - Adaptive peer choosing  
✅ **Temperature Scaling** - Dynamic temperature adjustment  
✅ **Mixed Precision** - Faster training with AMP  
✅ **Combined Strategies** - Using multiple techniques together  
✅ **Hyperparameter Search** - Automatic optimization  

## Next Steps

- Try these strategies on your own datasets
- Experiment with different strategy combinations
- Check out `03_distillation_methods.ipynb` for knowledge distillation
- See `05_deployment.ipynb` for production deployment