# Problem 1.3: Performance Comparison

This notebook compares training performance on a single device versus multiple devices (GPUs).

We will:
1. Run single-device training (world_size=1, batch_size=64)
2. Run two-device training (world_size=2, batch_size=128)
3. Collect and compare:
   - Training time (averaged across devices for each epoch, then averaged over epochs)
   - Tokens per second (summed across devices for each epoch to get total throughput, then averaged over epochs)
4. Visualize the scaling improvement

**Note:** We drop the first epoch as warmup before collecting metrics.

## Important Notes

### Two Ways to Complete Problem 1.3:

#### Option A: Use This Notebook (Modal Cloud)
- Run all cells in this notebook
- Requires Modal account and credits
- Fully automated end-to-end
- Takes ~1-2 hours to complete

#### Option B: Use Standalone Script (Recommended)
- Run training experiments locally or on PSC:
  ```bash
  # Run single device
  python project/run_data_parallel.py --world_size 1 --batch_size 64
  
  # Run multi-device
  python project/run_data_parallel.py --world_size 2 --batch_size 128
  ```
- Then analyze results:
  ```bash
  python problem_1_3_analysis.py --workdir ./workdir
  ```
- See `PROBLEM_1_3_README.md` for full instructions

**Choose one option and proceed accordingly.**

In [7]:
import modal
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os

## Setup Modal App

We'll use the Modal infrastructure to run distributed training.

**Note:** Make sure you have the Modal CLI installed and authenticated:
```bash
pip install modal
modal token new
```

In [8]:
# Modal application that owns the remote training function defined below.
app = modal.App("problem-1-3-training")

# Python packages installed into the container image (pinned where versions matter).
_PIP_PACKAGES = (
    "numpy<2",
    "torch==2.2.0",
    "datasets<4.0.0",
    "transformers==4.37.2",
    "sacrebleu==2.4.0",
    "tokenizers",
    "tqdm",
)

# Debian-slim base with Python 3.12 and the packages above; the local
# `project` and `data_parallel` packages are then added as Python sources.
image = modal.Image.debian_slim(python_version="3.12").pip_install(*_PIP_PACKAGES)
for _local_pkg in ("project", "data_parallel"):
    image = image.add_local_python_source(_local_pkg)

# Persistent volume shared by both experiments for training outputs.
volume = modal.Volume.from_name("training-workdir-p1-3", create_if_missing=True)

## Define Training Function

This function runs distributed training on Modal with the specified `world_size`, launching one training process per rank inside a single container with two GPUs.

In [9]:
@app.function(
    image=image,
    gpu="A10G:2",  # Single container with 2 GPUs
    volumes={"/workdir": volume},
    timeout=3600 * 4,
    cpu=16.0,
    memory=65536,
    serialized=False,
)
def run_training(
    world_size: int = 2,
    batch_size: int = 128,
    n_epochs: int = 10,
    dataset: str = "bbaaaa/iwslt14-de-en-preprocess",
    model_max_length: int = 128,
    learning_rate: float = 1e-4,
):
    """Run distributed training in a single container with multiple GPUs.

    Launches ``world_size`` processes, each executing
    ``project.run_data_parallel.run_dp`` for one rank. It waits for all of them,
    commits the shared volume and reads back every rank's per-epoch JSON
    result files.

    Args:
        world_size: Number of ranks (processes) to spawn.
        batch_size: Passed through to ``run_dp``.
        n_epochs: Passed through to ``run_dp``; also the number of result
            files probed per rank.
        dataset: Dataset name passed through to ``run_dp``.
        model_max_length: Passed through to ``run_dp``.
        learning_rate: Passed through to ``run_dp``.

    Returns:
        Dict with keys 'status', 'world_size', 'batch_size' and 'metrics'.
        'metrics' maps each rank (int) to the list of its loaded epoch result
        dicts in epoch order. Epochs with no result file are skipped.

    Raises:
        RuntimeError: If any rank process exits with a non-zero exit code.
    """
    # Local imports: this body executes inside the Modal container. `json` is
    # imported here as well so the function does not rely on notebook globals.
    import json
    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    # Set spawn method before doing anything with CUDA
    mp.set_start_method("spawn", force=True)

    # Import the training module
    import project.run_data_parallel as rdp

    # Set environment variables
    os.environ["PYTEST"] = "False"
    rdp.PYTEST = False

    # Set workdir to Modal volume
    os.chdir("/workdir")

    def result_path(rank, epoch_idx):
        """Absolute path of the JSON result file for ``rank`` / ``epoch_idx``."""
        return f"/workdir/workdir/rank{rank}_results_epoch{epoch_idx}.json"

    # Every experiment uses the same volume and the same result file names.
    # Remove leftovers first; otherwise an epoch this run failed to write would
    # silently be filled in with a previous run's numbers.
    for rank in range(world_size):
        for epoch_idx in range(n_epochs):
            stale_file = result_path(rank, epoch_idx)
            if os.path.exists(stale_file):
                os.remove(stale_file)

    # Get the run_dp function
    run_dp = rdp.run_dp
    backend = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO

    # Spawn one process per rank, just like the original code
    processes = []
    for rank in range(world_size):
        p = mp.Process(
            target=run_dp,
            args=(
                rank,
                world_size,
                backend,
                dataset,
                model_max_length,
                n_epochs,
                batch_size,
                learning_rate,
            ),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Commit before checking for failures so partial outputs are persisted
    # for debugging even when a rank crashed.
    volume.commit()

    # Without this check a crashed rank went unnoticed and the caller received
    # empty or partial metrics. Negative exit codes mean killed by a signal.
    failed = {rank: p.exitcode for rank, p in enumerate(processes) if p.exitcode != 0}
    if failed:
        raise RuntimeError(
            f"Training failed: non-zero exit code(s) by rank: {failed}"
        )

    # Collect metrics from all result files
    metrics_by_rank = {}
    for rank in range(world_size):
        metrics_by_rank[rank] = []
        for epoch_idx in range(n_epochs):
            filename = result_path(rank, epoch_idx)
            if os.path.exists(filename):
                with open(filename, "r") as f:
                    metrics_by_rank[rank].append(json.load(f))

    return {"status": "complete", "world_size": world_size, "batch_size": batch_size, "metrics": metrics_by_rank}

## Helper Functions for Metric Collection

In [10]:
def compute_aggregated_metrics(metrics_by_rank, drop_first_epoch=True):
    """Compute aggregated metrics across ranks and epochs.

    For every retained epoch, ``training_time`` is averaged across the ranks
    that reported that epoch. ``tokens_per_sec`` is summed across them, giving
    the total throughput of the job. The mean and population standard
    deviation of those per-epoch values are then taken across epochs.

    Args:
        metrics_by_rank: Dict mapping rank to the list of that rank's epoch
            metric dicts (each with 'training_time' and 'tokens_per_sec').
            Rank keys may be of any type (e.g. ``0`` or ``"0"`` after a JSON
            round trip) and need not be contiguous.
        drop_first_epoch: Whether to drop the first epoch (list position 0 of
            every rank) as warmup.

    Returns:
        Dict with 'avg_training_time', 'std_training_time',
        'avg_tokens_per_sec', 'std_tokens_per_sec' and 'world_size' (number of
        ranks in ``metrics_by_rank``). All statistics are 0.0 when there is no
        data left after filtering.
    """
    world_size = len(metrics_by_rank) if metrics_by_rank else 0
    empty_result = {
        'avg_training_time': 0.0,
        'std_training_time': 0.0,
        'avg_tokens_per_sec': 0.0,
        'std_tokens_per_sec': 0.0,
        'world_size': world_size,
    }
    if not metrics_by_rank:
        return empty_result

    # Use whatever ranks are present instead of looking up keys 0..world_size-1.
    # That lookup raised KeyError on gaps and returned all zeros for string keys.
    per_rank_epochs = list(metrics_by_rank.values())
    n_epochs = max(len(epochs) for epochs in per_rank_epochs)
    start_epoch = 1 if drop_first_epoch else 0

    epoch_training_times = []
    epoch_total_tokens_per_sec = []
    for epoch_idx in range(start_epoch, n_epochs):
        # Ranks that produced fewer epochs simply do not contribute here. At
        # least one rank always does, because n_epochs is the longest list.
        records = [epochs[epoch_idx] for epochs in per_rank_epochs if epoch_idx < len(epochs)]
        epoch_training_times.append(np.mean([r['training_time'] for r in records]))
        epoch_total_tokens_per_sec.append(np.sum([r['tokens_per_sec'] for r in records]))

    if not epoch_training_times:
        return empty_result

    return {
        'avg_training_time': np.mean(epoch_training_times),
        'std_training_time': np.std(epoch_training_times),
        'avg_tokens_per_sec': np.mean(epoch_total_tokens_per_sec),
        'std_tokens_per_sec': np.std(epoch_total_tokens_per_sec),
        'world_size': world_size,
    }

## Run Training Experiments

### Experiment 1: Single Device (world_size=1, batch_size=64)

In [None]:
print("Starting single device training...")
with app.run():
    # `.remote()` blocks until the remote call finishes and returns the
    # function's return value (a dict) directly. It is not a future, so the
    # previous `.get()` call raised TypeError. Use `.spawn(...).get()` for
    # non-blocking behaviour.
    try:
        result_single = run_training.remote(
            world_size=1,
            batch_size=64,
            n_epochs=10,
        )
    except Exception:
        # Add context, then re-raise so the notebook shows the full traceback once.
        print("Error during single-device training (see traceback below)")
        raise
print("Single device training complete")
metrics_single = result_single['metrics']

Starting single device training...


### Experiment 2: Multiple Devices (world_size=2, batch_size=128)

In [None]:
print("Starting multi-device training...")
with app.run():
    # `.remote()` blocks until the remote call finishes and returns the
    # function's return value (a dict) directly. It is not a future, so the
    # previous `.get()` call raised TypeError. Use `.spawn(...).get()` for
    # non-blocking behaviour.
    try:
        result_multi = run_training.remote(
            world_size=2,
            batch_size=128,
            n_epochs=10,
        )
    except Exception:
        # Add context, then re-raise so the notebook shows the full traceback once.
        print("Error during multi-device training (see traceback below)")
        raise
print("Multi-device training complete")
metrics_multi = result_multi['metrics']

## Collect Metrics

Now we'll collect the metrics from both experiments, excluding the first epoch.

In [None]:
def _print_aggregated(header, agg):
    """Print mean ± std training time and throughput from an aggregated-metrics dict."""
    print(header)
    print(f"  Training Time: {agg['avg_training_time']:.2f} ± {agg['std_training_time']:.2f} seconds")
    print(f"  Tokens/Second: {agg['avg_tokens_per_sec']:.2f} ± {agg['std_tokens_per_sec']:.2f}")


# Single-device run: aggregate (warmup epoch excluded) and report.
print("Processing single device metrics...")
aggregated_single = compute_aggregated_metrics(metrics_single, drop_first_epoch=True)
_print_aggregated("Single Device Metrics:", aggregated_single)

# Multi-device run: same aggregation and report.
print("\nProcessing multi-device metrics...")
aggregated_multi = compute_aggregated_metrics(metrics_multi, drop_first_epoch=True)
_print_aggregated("Multi-Device Metrics:", aggregated_multi)

## Visualize Results

Create comparison plots for training time and tokens per second.

In [None]:
def plot_comparison(means, stds, labels, ylabel, title, filename, output_dir='submit/figures'):
    """Create a bar plot comparing metrics, save it to disk and display it.

    Args:
        means: Bar heights, one per configuration.
        stds: Symmetric error-bar half-widths, aligned with ``means``.
        labels: X tick labels, aligned with ``means``.
        ylabel: Y-axis label.
        title: Figure title.
        filename: File name of the saved PNG inside ``output_dir``.
        output_dir: Directory the figure is written to; created if missing.
            Defaults to 'submit/figures'.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    x = np.arange(len(means))

    ax.bar(x, means, yerr=stds,
           align='center', alpha=0.7, ecolor='red', capsize=10, width=0.6,
           color=['skyblue', 'lightcoral'])

    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=11)
    ax.yaxis.grid(True, alpha=0.3)

    # Put value labels above the top of each error bar. At the bar top they
    # overlapped the whisker and cap. Autoscaling ignores text, so add headroom.
    for i, (mean, std) in enumerate(zip(means, stds)):
        ax.text(i, mean + std, f'{mean:.2f}\n±{std:.2f}',
                ha='center', va='bottom', fontsize=10)
    ax.margins(y=0.15)

    fig.tight_layout()

    # Save via the figure object rather than pyplot's implicit "current figure".
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / filename
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved figure to {output_path}")

    plt.show()
    plt.close(fig)

In [None]:
# Plot Training Time comparison: one bar per run, error bars = std across epochs.
time_runs = (aggregated_single, aggregated_multi)
plot_comparison(
    means=[run['avg_training_time'] for run in time_runs],
    stds=[run['std_training_time'] for run in time_runs],
    labels=['Single Device\n(world_size=1)', 'Multi-Device\n(world_size=2)'],
    ylabel='Training Time (seconds)',
    title='Training Time Comparison',
    filename='training_time_comparison.png',
)

In [None]:
# Plot Tokens Per Second comparison: total throughput per run, error bars = std across epochs.
throughput_runs = (aggregated_single, aggregated_multi)
plot_comparison(
    means=[run['avg_tokens_per_sec'] for run in throughput_runs],
    stds=[run['std_tokens_per_sec'] for run in throughput_runs],
    labels=['Single Device\n(world_size=1)', 'Multi-Device\n(world_size=2)'],
    ylabel='Tokens Per Second',
    title='Throughput Comparison',
    filename='tokens_per_second_comparison.png',
)

## Summary

Print a summary of the speedup achieved.

In [None]:
# Side-by-side recap of both runs followed by the derived speedup factors.
rule = "=" * 60

print("\n" + rule)
print("PERFORMANCE COMPARISON SUMMARY")
print(rule)

for heading, agg in (
    ("\nSingle Device (world_size=1, batch_size=64):", aggregated_single),
    ("\nMulti-Device (world_size=2, batch_size=128):", aggregated_multi),
):
    print(heading)
    print(f"  Average Training Time: {agg['avg_training_time']:.2f} ± {agg['std_training_time']:.2f} seconds")
    print(f"  Average Throughput: {agg['avg_tokens_per_sec']:.2f} ± {agg['std_tokens_per_sec']:.2f} tokens/sec")

# Each ratio's denominator must be positive; the aggregation returns 0.0 when data is missing.
can_compute_speedup = (
    aggregated_multi['avg_training_time'] > 0
    and aggregated_single['avg_tokens_per_sec'] > 0
)
if not can_compute_speedup:
    print("\nSpeedup: Cannot compute (insufficient data)")
else:
    time_speedup = aggregated_single['avg_training_time'] / aggregated_multi['avg_training_time']
    throughput_speedup = aggregated_multi['avg_tokens_per_sec'] / aggregated_single['avg_tokens_per_sec']
    print("\nSpeedup:")
    print(f"  Training Time Speedup: {time_speedup:.2f}x")
    print(f"  Throughput Speedup: {throughput_speedup:.2f}x")

print("\n" + rule)
print("\nFigures saved to: submit/figures/")
for figure_name in ("training_time_comparison.png", "tokens_per_second_comparison.png"):
    print(f"  - {figure_name}")
print(rule)