# Flow SDK Training Notebook

Train models from scratch using Flow SDK with support for distributed training, automatic checkpointing, and cost optimization.

This notebook covers:
- Single GPU training
- Multi-GPU distributed training
- Training from checkpoints
- Monitoring training progress
- Cost optimization strategies

## Setup

First, let's install and configure the Flow SDK:

In [None]:
# Install Flow SDK
!pip install flow-sdk --upgrade

# Import required libraries
import flow
from flow import TaskConfig
import json
import time
import matplotlib.pyplot as plt
from typing import Dict, List
import pandas as pd

In [None]:
# Initialize Flow client
# flow.Flow() is called with no arguments — presumably credentials and
# settings are discovered from the environment or a local config file;
# TODO confirm against the Flow SDK docs.
flow_client = flow.Flow()

# Check authentication
# NOTE(review): the first message is printed unconditionally once the
# constructor has returned; nothing in this cell explicitly verifies the
# credentials beyond reading `api_endpoint` from the client.
print("✓ Flow SDK initialized")
print(f"API Endpoint: {flow_client.api_endpoint}")

## 1. Quick Training Example

Let's start with a simple ResNet training on CIFAR-10:

In [None]:
# Create training script
#
# The script is stored in a *raw* triple-quoted string on purpose: it contains
# "\n" escapes inside its own f-strings.  In a normal string Python would
# expand them right here, and the file written below would contain string
# literals broken across two lines (a SyntaxError once the task runs it).
simple_train_script = r"""
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import time
import json

# Training configuration
config = {
    'batch_size': 128,
    'epochs': 10,
    'learning_rate': 0.001,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

print(f"Training on: {config['device']}")
if config['device'] == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Data preparation
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Random augmentation belongs to the training split only; the test split is
# evaluated on the unmodified images so the reported accuracy is deterministic.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    normalize,
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=2)

# Model setup (weights=None == randomly initialised; replaces the deprecated
# pretrained=False spelling)
model = models.resnet18(weights=None, num_classes=10)
model = model.to(config['device'])

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

# Training loop
metrics = {'train_loss': [], 'test_accuracy': []}

for epoch in range(config['epochs']):
    # Training
    model.train()
    train_loss = 0.0
    start_time = time.time()

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(config['device']), target.to(config['device'])

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        if batch_idx % 50 == 0:
            print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] '
                  f'Loss: {loss.item():.4f}')

    # Validation
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(config['device']), target.to(config['device'])
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = 100. * correct / len(test_loader.dataset)
    # Covers the training pass and the validation pass of this epoch
    epoch_time = time.time() - start_time

    print(f'\nEpoch {epoch}: Train Loss: {train_loss/len(train_loader):.4f}, '
          f'Test Accuracy: {accuracy:.2f}%, Time: {epoch_time:.1f}s\n')

    metrics['train_loss'].append(train_loss / len(train_loader))
    metrics['test_accuracy'].append(accuracy)

# Save model and metrics
torch.save(model.state_dict(), 'resnet18_cifar10.pth')
with open('training_metrics.json', 'w') as f:
    json.dump(metrics, f)

print(f"\nTraining complete! Final accuracy: {accuracy:.2f}%")
"""

# Save script (this local path is the key used by upload_files in the task config)
with open("/tmp/train_cifar10.py", "w") as f:
    f.write(simple_train_script)

print("✓ Training script created")

In [None]:
# Configure and run training
# The command installs PyTorch/torchvision on the instance and then runs the
# script that an earlier cell wrote to /tmp/train_cifar10.py.
training_config = TaskConfig(
    name="cifar10-resnet-training",
    command="""
    pip install torch torchvision
    python /workspace/train_cifar10.py
    """,
    instance_type="a100",  # Single A100 80GB
    # Local path -> remote name. The command above expects the file at
    # /workspace/train_cifar10.py — presumably uploads are placed under
    # /workspace; TODO confirm against the Flow SDK docs.
    upload_files={"/tmp/train_cifar10.py": "train_cifar10.py"},
    # The script saves 'resnet18_cifar10.pth' and 'training_metrics.json';
    # these patterns match both files.
    download_patterns=["*.pth", "*.json"],
    # NOTE(review): units/semantics of the two limits below are inferred from
    # the parameter names only — verify in the SDK reference.
    max_price_per_hour=10.00,
    max_run_time_hours=2
)

print("🚀 Starting training job...")
# The returned task object is reused by the monitoring cell via `task_id`.
training_task = flow_client.run(training_config)
print(f"Task ID: {training_task.task_id}")
print(f"Status: {training_task.status}")

In [None]:
# Monitor training progress
# Poll the task every 30 seconds, echoing the latest log lines and running
# cost, until it reaches one of the terminal states below.
print("⏳ Monitoring training progress...\n")

task_info = flow_client.get_task(training_task.task_id)
while task_info.status not in ("completed", "failed", "cancelled"):
    # Get recent logs (skipped when there is nothing to show)
    logs = task_info.logs(tail=5)
    if logs:
        print(logs)

    print(f"\nStatus: {task_info.status} | Cost: ${task_info.total_cost:.3f}")
    time.sleep(30)

    # Refresh the task snapshot for the next iteration
    task_info = flow_client.get_task(training_task.task_id)

print(f"\nTask {task_info.status}!")

## 2. Distributed Multi-GPU Training

Scale up training with PyTorch Distributed Data Parallel:

In [None]:
# Distributed training script
#
# Two quoting details matter here:
#   * the outer string uses ''' so that the script's own """docstrings""" do
#     not terminate it (with an outer """ this cell is a SyntaxError);
#   * it is a raw string so the "\n" escapes inside the script's f-strings are
#     written to the file verbatim instead of being expanded in this cell.
distributed_script = r'''
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import time

def setup_distributed():
    """Initialize distributed training.

    Reads RANK, LOCAL_RANK and WORLD_SIZE from the environment (torchrun sets
    them); the defaults describe a plain single-process run.

    Returns:
        (rank, local_rank, world_size)
    """
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    if world_size > 1:
        # The CUDA device index is the rank *within this node*; the global
        # rank only coincides with it on a single-node job.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl')

    return rank, local_rank, world_size

def train_distributed():
    """Train ResNet-50 on CIFAR-100, using DDP when launched with several processes."""
    rank, local_rank, world_size = setup_distributed()
    use_cuda = torch.cuda.is_available()
    device = f'cuda:{local_rank}' if use_cuda else 'cpu'

    # Only print from rank 0
    if rank == 0:
        print(f"Training on {world_size} GPUs")
        if use_cuda:
            print(f"Device: {torch.cuda.get_device_name(local_rank)}")

    # Data preparation
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Let local rank 0 download and extract the dataset first; the remaining
    # processes wait at the barrier and then find the files already in ./data
    # instead of racing to write them.
    if world_size > 1 and local_rank != 0:
        dist.barrier()

    # Use ImageNet subset or CIFAR-100 for demo
    train_dataset = datasets.CIFAR100(
        root='./data',
        train=True,
        download=True,
        transform=transform
    )

    if world_size > 1 and local_rank == 0:
        dist.barrier()

    # Distributed sampler
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank
    ) if world_size > 1 else None

    train_loader = DataLoader(
        train_dataset,
        batch_size=64,  # Per GPU batch size
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True
    )

    # Model setup (weights=None == randomly initialised)
    model = models.resnet50(weights=None, num_classes=100)
    model = model.to(device)

    if world_size > 1:
        model = DDP(model, device_ids=[local_rank])

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.1 * world_size,  # Scale learning rate
        momentum=0.9
    )

    # Training loop
    epochs = 5
    for epoch in range(epochs):
        if train_sampler is not None:
            # Re-seed the sampler so every epoch sees a different shuffle
            train_sampler.set_epoch(epoch)

        model.train()
        total_loss = 0.0
        start_time = time.time()

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if rank == 0 and batch_idx % 20 == 0:
                print(f'Epoch {epoch} [{batch_idx * len(data) * world_size}'
                      f'/{len(train_loader.dataset)}] Loss: {loss.item():.4f}')

        epoch_time = time.time() - start_time

        if rank == 0:
            # Loss is rank 0's local average; throughput counts the whole
            # dataset because all ranks work through their shards in parallel.
            avg_loss = total_loss / len(train_loader)
            throughput = len(train_loader.dataset) / epoch_time
            print(f'\nEpoch {epoch}: Avg Loss: {avg_loss:.4f}, '
                  f'Time: {epoch_time:.1f}s, Throughput: {throughput:.1f} img/s\n')

    # Save model (only from rank 0)
    if rank == 0:
        if world_size > 1:
            # Unwrap DDP so the checkpoint loads into a plain ResNet-50
            torch.save(model.module.state_dict(), 'resnet50_distributed.pth')
        else:
            torch.save(model.state_dict(), 'resnet50_distributed.pth')
        print("Model saved!")

    # Cleanup
    if world_size > 1:
        dist.destroy_process_group()

if __name__ == "__main__":
    train_distributed()
'''

# Save script (this local path is the key used by upload_files in the task config)
with open("/tmp/train_distributed.py", "w") as f:
    f.write(distributed_script)

print("✓ Distributed training script created")

In [None]:
# Run distributed training on 4 GPUs
distributed_config = TaskConfig(
    name="distributed-resnet-training",
    # `command` is a normal (non-raw) string, so every backslash that is
    # immediately followed by a newline is dropped by Python: the torchrun
    # invocation reaches the task as one long line. The "# Run with torchrun"
    # line lives inside the string, i.e. it is a shell comment.
    command="""
    pip install torch torchvision
    
    # Run with torchrun for distributed training
    torchrun \
        --nproc_per_node=4 \
        --master_port=29500 \
        /workspace/train_distributed.py
    """,
    # --nproc_per_node=4 above is meant to match the GPU count of this type.
    instance_type="4xa100",  # 4x A100 80GB
    # Local path -> remote name; the command expects the script under
    # /workspace — presumably the upload root; TODO confirm.
    upload_files={"/tmp/train_distributed.py": "train_distributed.py"},
    # The script saves 'resnet50_distributed.pth' from rank 0.
    download_patterns=["*.pth"],
    max_price_per_hour=40.00,
    max_run_time_hours=2
)

print("🚀 Starting distributed training on 4x A100...")
dist_task = flow_client.run(distributed_config)
print(f"Task ID: {dist_task.task_id}")

## 3. Training with Checkpoints

Implement automatic checkpointing for long-running training:

In [None]:
# Checkpointing training script
#
# Kept in a raw string so the "\n" escapes inside the script's f-strings are
# written to the file verbatim; in a normal string they would be expanded in
# this cell and the generated file would not parse.
checkpoint_script = r"""
import torch
import torch.nn as nn
# AdamW comes from torch.optim: the copy that used to be exported by
# transformers is deprecated and missing from current releases.
from torch.optim import AdamW
from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import DataLoader, Dataset
import os
import json
from datetime import datetime

class CheckpointManager:
    # Saves and restores model/optimizer state under `checkpoint_dir`.

    def __init__(self, checkpoint_dir="checkpoints"):
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)

    def save_checkpoint(self, epoch, model, optimizer, metrics, is_best=False):
        # Writes checkpoint_latest.pth and checkpoint_epoch_<epoch>.pth, plus
        # checkpoint_best.pth when is_best is true.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics,
            'timestamp': datetime.now().isoformat()
        }

        # Save latest checkpoint
        path = os.path.join(self.checkpoint_dir, 'checkpoint_latest.pth')
        torch.save(checkpoint, path)

        # Save epoch checkpoint
        epoch_path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
        torch.save(checkpoint, epoch_path)

        # Save best model
        if is_best:
            best_path = os.path.join(self.checkpoint_dir, 'checkpoint_best.pth')
            torch.save(checkpoint, best_path)

        print(f"Checkpoint saved: epoch {epoch}")

    def load_checkpoint(self, model, optimizer, checkpoint_name='checkpoint_latest.pth'):
        # Returns (epoch, metrics) of the stored checkpoint, or (0, {}) when
        # the file does not exist.
        path = os.path.join(self.checkpoint_dir, checkpoint_name)
        if not os.path.exists(path):
            return 0, {}

        # Load onto the CPU first so a checkpoint written on a GPU can also be
        # restored on a machine without one; load_state_dict copies the values
        # onto whatever device the model parameters live on.
        checkpoint = torch.load(path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Resumed from epoch {checkpoint['epoch']}")
        return checkpoint['epoch'], checkpoint['metrics']

# Dummy dataset for demo
class TextDataset(Dataset):
    def __init__(self, size=10000):
        self.size = size
        self.texts = [f"Sample text {i}" for i in range(size)]
        self.labels = [i % 2 for i in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.texts[idx], self.labels[idx]

# Training with checkpoints
def train_with_checkpoints():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Training on: {device}")

    # Model setup
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=2
    ).to(device)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    optimizer = AdamW(model.parameters(), lr=2e-5)

    # Checkpoint manager
    ckpt_manager = CheckpointManager()

    # Try to resume from checkpoint
    start_epoch, metrics = ckpt_manager.load_checkpoint(model, optimizer)
    best_accuracy = metrics.get('best_accuracy', 0.0)
    # An end-of-epoch checkpoint means that epoch is finished, so continue
    # with the next one. A mid-epoch checkpoint restarts its epoch from the
    # first batch (the position inside the epoch is not restored).
    if metrics.get('epoch_complete', False):
        start_epoch += 1

    # Data loading
    dataset = TextDataset()
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Training loop
    epochs = 10
    for epoch in range(start_epoch, epochs):
        model.train()
        total_loss = 0
        correct = 0

        for batch_idx, (texts, labels) in enumerate(dataloader):
            # Tokenize (the default collate function may hand the strings
            # over as a tuple; the tokenizer gets a plain list)
            inputs = tokenizer(
                list(texts),
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=128
            ).to(device)

            # The DataLoader already collated the labels into a tensor
            labels = torch.as_tensor(labels).to(device)

            # Forward pass
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Calculate accuracy
            predictions = outputs.logits.argmax(dim=-1)
            correct += (predictions == labels).sum().item()

            # Save checkpoint after every 100 completed batches. best_accuracy
            # travels with it so a resume from here does not reset it.
            if (batch_idx + 1) % 100 == 0:
                ckpt_manager.save_checkpoint(
                    epoch, model, optimizer,
                    {
                        'batch': batch_idx,
                        'loss': total_loss / (batch_idx + 1),
                        'best_accuracy': best_accuracy,
                        'epoch_complete': False
                    }
                )

        # Calculate epoch metrics
        accuracy = correct / len(dataset)
        avg_loss = total_loss / len(dataloader)

        print(f"\nEpoch {epoch}: Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

        # Save checkpoint
        is_best = accuracy > best_accuracy
        best_accuracy = max(best_accuracy, accuracy)

        metrics = {
            'loss': avg_loss,
            'accuracy': accuracy,
            'best_accuracy': best_accuracy,
            'epoch_complete': True
        }

        ckpt_manager.save_checkpoint(epoch, model, optimizer, metrics, is_best)

    print(f"\nTraining complete! Best accuracy: {best_accuracy:.4f}")

if __name__ == "__main__":
    train_with_checkpoints()
"""

# Save script (this local path is the key used by upload_files in the task config)
with open("/tmp/train_with_checkpoints.py", "w") as f:
    f.write(checkpoint_script)

print("✓ Checkpoint training script created")

In [None]:
# Run training with checkpoints and persistent storage
checkpoint_config = TaskConfig(
    name="bert-checkpoint-training",
    command="""
    pip install torch transformers
    python /workspace/train_with_checkpoints.py
    """,
    instance_type="a100",
    # Local path -> remote name; the command expects the script under
    # /workspace — presumably the upload root; TODO confirm.
    upload_files={"/tmp/train_with_checkpoints.py": "train_with_checkpoints.py"},
    download_patterns=["checkpoints/*"],
    
    # Enable automatic retries on spot instance preemption
    # NOTE(review): a retried run can only resume if it sees the checkpoint
    # files of the previous attempt — that is what the volume below is for.
    retry_on_failure=True,
    max_retries=3,
    
    # Cost optimization
    max_price_per_hour=10.00,
    max_run_time_hours=4,
    
    # Persistent volume for checkpoints
    # NOTE(review): the training script writes to the *relative* directory
    # "checkpoints"; that only ends up on this mount if the task's working
    # directory is /workspace — verify against the Flow SDK docs.
    volumes=[
        {
            "name": "training-checkpoints",
            "mount_path": "/workspace/checkpoints",
            "size_gb": 50
        }
    ]
)

print("🚀 Starting checkpoint-enabled training...")
ckpt_task = flow_client.run(checkpoint_config)
print(f"Task ID: {ckpt_task.task_id}")

## 4. Training Pipeline with Monitoring

Complete training pipeline with TensorBoard monitoring:

In [None]:
# Training pipeline with monitoring
#
# `command` is a shell script that (1) installs dependencies, (2) starts
# TensorBoard in the background, (3) feeds an inline Python training program
# to the interpreter through a heredoc and (4) keeps the task alive afterwards.
#
# The string is raw on purpose: the inline program uses "\n" inside its
# f-strings, and in a normal string those would be expanded here, splitting
# the string literals across lines in the program the task receives.
# The heredoc body and its closing EOF marker must stay at column 0.
pipeline_config = TaskConfig(
    name="training-pipeline-monitored",
    command=r"""
    # Install dependencies
    pip install torch torchvision tensorboard matplotlib

    # Start TensorBoard in background
    tensorboard --logdir=/workspace/logs --host=0.0.0.0 --port=6006 &

    # Create training script with TensorBoard logging
    python - << 'EOF'
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import os

# Setup
writer = SummaryWriter('/workspace/logs')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Log system info
if device == 'cuda':
    writer.add_text('System/GPU', torch.cuda.get_device_name(0))
    writer.add_text('System/CUDA', torch.version.cuda)

# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = datasets.MNIST('./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = datasets.MNIST('./data', train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# Model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)

model = SimpleNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Log model graph
dummy_input = torch.randn(1, 1, 28, 28).to(device)
writer.add_graph(model, dummy_input)

# Training loop with monitoring
global_step = 0
epochs = 5

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for batch_idx, (data, target) in enumerate(trainloader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = nn.functional.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Log to TensorBoard
        if batch_idx % 10 == 0:
            writer.add_scalar('Loss/train', loss.item(), global_step)

            # Log gradients
            for name, param in model.named_parameters():
                if param.grad is not None:
                    writer.add_histogram(f'Gradients/{name}', param.grad, global_step)

        global_step += 1

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(trainloader.dataset)}] '
                  f'Loss: {loss.item():.6f}')

    # Validation
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(testloader.dataset)
    accuracy = 100. * correct / len(testloader.dataset)

    # Log validation metrics
    writer.add_scalar('Loss/test', test_loss, epoch)
    writer.add_scalar('Accuracy/test', accuracy, epoch)

    print(f'\nTest set: Avg loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{len(testloader.dataset)} ({accuracy:.2f}%)\n')

# Save model
torch.save(model.state_dict(), 'mnist_model.pth')
writer.close()

print("Training complete! TensorBoard logs saved to /workspace/logs")
EOF

    # Keep TensorBoard running
    sleep infinity
    """,
    instance_type="a100",
    ports=[6006],  # TensorBoard port
    download_patterns=["*.pth", "logs/*"],
    max_price_per_hour=10.00,
    # NOTE(review): because of `sleep infinity` the command never exits on its
    # own — presumably the task runs until this limit or until it is cancelled.
    max_run_time_hours=2
)

print("🚀 Starting training pipeline with monitoring...")
pipeline_task = flow_client.run(pipeline_config)
print(f"Task ID: {pipeline_task.task_id}")
print("\n📊 TensorBoard will be available at port 6006 once training starts")

## 5. Cost Analysis and Optimization

Analyze training costs and optimize resource usage:

In [None]:
# Get all training tasks from the last 24 hours
import datetime

since = datetime.datetime.now() - datetime.timedelta(days=1)

# Keep only tasks whose name mentions "train" (case-insensitive)
training_tasks = [
    t
    for t in flow_client.list_tasks(created_after=since)
    if 'train' in t.name.lower()
]

# Analyze costs
separator = "=" * 60
print("💰 Training Cost Analysis")
print(separator)
print(f"{'Task Name':<30} {'Duration':<12} {'Cost':<10} {'Instance':<15}")
print(separator)

# Instance-type markers and the GPU count they stand for, checked in this
# order; anything without a marker counts as a single GPU.
gpu_markers = (('2xa100', 2), ('4xa100', 4), ('8xa100', 8))

total_cost = 0
total_gpu_hours = 0

for task in training_tasks:
    # Tasks lacking these attributes contribute zero
    duration = getattr(task, 'duration_hours', 0)
    cost = getattr(task, 'total_cost', 0)
    total_cost += cost

    # Calculate GPU hours
    gpu_count = next(
        (count for marker, count in gpu_markers if marker in task.instance_type),
        1,
    )
    total_gpu_hours += duration * gpu_count

    print(f"{task.name[:30]:<30} {duration:>6.2f} hrs  ${cost:>8.2f}  {task.instance_type:<15}")

print(separator)
print(f"{'Total:':<30} {'':<12} ${total_cost:>8.2f}")
print(f"\nTotal GPU-hours: {total_gpu_hours:.1f}")
if total_gpu_hours > 0:
    print(f"Average cost per GPU-hour: ${total_cost/total_gpu_hours:.2f}")
else:
    # No usage recorded: an empty line is printed instead of an average
    print("")

In [None]:
# Training cost optimization recommendations
# Static help text: it is only printed below and not parsed anywhere in this cell.
optimization_tips = """
🎯 Training Cost Optimization Tips:

1. **Use Mixed Precision Training**
   - 2x faster training with minimal accuracy loss
   - Add: `--fp16` or use torch.cuda.amp

2. **Gradient Checkpointing**
   - Trade compute for memory
   - Allows larger batch sizes on same GPU

3. **Right-size Your Instance**
   - Monitor GPU utilization
   - Scale down if under 80% utilized

4. **Use Spot Instances**
   - Up to 70% cost savings
   - Implement checkpointing for resilience

5. **Optimize Data Loading**
   - Use multiple workers: num_workers=4
   - Pin memory: pin_memory=True
   - Prefetch data to GPU
"""

print(optimization_tips)

# Example optimized configuration
# NOTE(review): this config is only constructed — it is not passed to
# flow_client.run() in this cell, and no upload_files entry ships train.py,
# so it reads as an illustrative template rather than a runnable job.
# `command` is a non-raw string: each backslash directly followed by a newline
# is removed by Python, so the flags end up on a single shell line.
optimized_config = TaskConfig(
    name="optimized-training",
    command="""
    python train.py \
        --fp16 \
        --gradient_checkpointing \
        --batch_size 128 \
        --num_workers 4 \
        --pin_memory
    """,
    instance_type="a100",
    max_price_per_hour=8.00,  # Lower bid for spot
    retry_on_failure=True     # Handle preemptions
)

print("\n📋 Example optimized configuration created")

## 6. Advanced Training Patterns

### Hyperparameter Tuning

In [None]:
# Hyperparameter sweep
# Dry run over the full learning-rate x batch-size grid: a TaskConfig is built
# for every combination, but submission stays commented out.
learning_rates = [1e-4, 5e-4, 1e-3]
batch_sizes = [32, 64, 128]

sweep_tasks = []

# Learning rate varies slowest, batch size fastest
sweep_grid = [(lr, bs) for lr in learning_rates for bs in batch_sizes]

for lr, bs in sweep_grid:
    config = TaskConfig(
        name=f"sweep-lr{lr}-bs{bs}",
        command=f"""
            python train.py \
                --learning_rate {lr} \
                --batch_size {bs} \
                --epochs 5
            """,
        instance_type="a100",
        max_price_per_hour=10.00,
        max_run_time_hours=1
    )

    # task = flow_client.run(config)
    # sweep_tasks.append(task)
    print(f"Would run: lr={lr}, batch_size={bs}")

print(f"\n🎯 Hyperparameter sweep: {len(sweep_grid)} configurations")

### Data Parallel vs Model Parallel

In [None]:
# Comparison of parallelization strategies
# Each entry maps a strategy name to four fields: what it does, when to pick
# it, a suggested instance type and example models.
strategies = {
    "Data Parallel (DDP)": dict(
        description="Split batch across GPUs, replicate model",
        best_for="Models that fit on single GPU",
        instance="4xa100",
        example="ResNet, BERT-Base",
    ),
    "Model Parallel": dict(
        description="Split model layers across GPUs",
        best_for="Models too large for single GPU",
        instance="8xa100",
        example="GPT-3, LLaMA-70B",
    ),
    "Pipeline Parallel": dict(
        description="Split model stages, micro-batching",
        best_for="Very deep models",
        instance="4xa100",
        example="Deep ResNets, Transformers",
    ),
    "3D Parallel": dict(
        description="Combine all strategies",
        best_for="Extreme scale training",
        instance="8xa100",
        example="LLaMA-2, GPT-4",
    ),
}

print("🎯 Parallelization Strategy Guide\n")
for strategy, info in strategies.items():
    # One print per strategy: five text lines followed by an empty line
    print(
        f"**{strategy}**",
        f"  Description: {info['description']}",
        f"  Best for: {info['best_for']}",
        f"  Recommended: {info['instance']}",
        f"  Examples: {info['example']}",
        "",
        sep="\n",
    )

## 7. Cleanup and Summary

In [None]:
# Summary of all training runs
all_tasks = flow_client.list_tasks()
training_tasks = [t for t in all_tasks if 'train' in t.name.lower()]


def _gpu_count(instance_type):
    """Return the number of GPUs implied by an instance type name.

    Recognises the '2xa100', '4xa100' and '8xa100' markers (the same ones the
    cost-analysis cell uses); any other name counts as a single GPU.
    """
    for marker, count in (('2xa100', 2), ('4xa100', 4), ('8xa100', 8)):
        if marker in instance_type:
            return count
    return 1


summary_cost = sum(getattr(t, 'total_cost', 0) for t in training_tasks)
# GPU-hours = wall-clock hours x GPUs on the instance. Summing the durations
# alone would under-report every multi-GPU task and disagree with the figure
# printed by the cost-analysis cell.
summary_gpu_hours = sum(
    getattr(t, 'duration_hours', 0) * _gpu_count(getattr(t, 'instance_type', '') or '')
    for t in training_tasks
)

print("📊 Training Session Summary")
print("=" * 50)
print(f"Total training tasks: {len(training_tasks)}")
print(f"Total cost: ${summary_cost:.2f}")
print(f"Total GPU hours: {summary_gpu_hours:.1f}")

# Cleanup running tasks
running = [t for t in training_tasks if t.status == "running"]
if running:
    print(f"\n⚠️  {len(running)} tasks still running")
    # Uncomment to stop all
    # for task in running:
    #     flow_client.cancel(task.task_id)

## Summary

In this notebook, you learned how to:

1. **Run single-GPU training** with PyTorch
2. **Scale to multi-GPU** with distributed training
3. **Implement checkpointing** for fault tolerance
4. **Monitor training** with TensorBoard
5. **Analyze costs** and optimize resource usage
6. **Use advanced patterns** like hyperparameter sweeps

### Key Takeaways

- Start with single GPU and profile before scaling
- Always implement checkpointing for long runs
- Use mixed-precision training for up to a ~2x speedup
- Monitor GPU utilization to right-size instances
- Mithril's dynamic pricing rewards flexible scheduling

### Next Steps

- Explore [Fine-tuning Notebook](fine-tuning.ipynb) for model customization
- Check [Inference Notebook](inference.ipynb) for deployment
- Read [Distributed Training Guide](../../guides/distributed-training.md)