In [2]:
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
import time

In [4]:
# Environment Check

print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
print(f"Using device: {device}")
print(f"Using dtype: {dtype}")

PyTorch Version: 2.9.0+cu126
CUDA Available: True
CUDA Version: 12.6
GPU: Tesla T4
GPU Memory: 15.83 GB
Using device: cuda
Using dtype: torch.bfloat16


# Motivation

Standard Training:  
  Forward: x -> a1 -> a2 -> a3 -> loss  
           [store] [store] [store]  
  Backward: uses stored a1, a2, a3  
  Memory: O(n) activations  
  
With Checkpointing:  
  Forward: x -> a1 -> a2 -> a3 -> loss  
           [save]       [save]  
  Backward: recompute a2 from a1, then use  
  Memory: O(sqrt(n)) with optimal placement  
  
For 100 layers:  
  Standard: 100 activations stored  
  Checkpointed: ~10 activations stored (at checkpoint boundaries)  
  Memory reduction: 10x  
  
Compute analysis:  
  Without checkpointing: n forward + n backward = 2n  
  With checkpointing:    n forward + n backward + n recompute = 3n  
  Compute overhead: ~50% of total training time  

KEY: The sqrt(n) rule - optimal checkpointing reduces memory from O(n) to O(sqrt(n))  

# Optimizations by torch.compile()

How torch.compile handles activations:

1. TRACING
   torch.compile traces forward AND backward into one graph.
   This lets it see the whole picture.

2. MIN-CUT PARTITIONING
   The graph is split at the optimal points.
   Algorithm minimizes: tensors crossing the cut
   (These are the tensors saved for backward)

3. AUTOMATIC RECOMPUTATION
   Cheap ops (relu, add, mul) are recomputed automatically.
   No user intervention needed.

4. FUSION
   Pointwise ops get fused into kernels.
   Fused ops are fast to recompute.

So looks like torch.compile gives us SOME memory savings for FREE, plus speed improvements from fusion.

# Resnet50 activation checkpointing

In [13]:
# 5.1 ResNet50 Architecture Overview

try:
    import torchvision.models as models
except Exception as e:
    models = None
    print("torchvision is required for the ResNet50 case study.")
    print(f"Details: {type(e).__name__}: {e}")

if models is not None:
    # Load ResNet50 without downloading weights
    try:
        resnet50 = models.resnet50(weights=None)
    except TypeError:
        # Older torchvision
        resnet50 = models.resnet50(pretrained=False)

    print("ResNet50 Structure:")
    print("-" * 60)
    print(f"conv1:  1 conv layer")
    print(f"layer1: {len(resnet50.layer1)} Bottleneck blocks (64 -> 256 channels)")
    print(f"layer2: {len(resnet50.layer2)} Bottleneck blocks (128 -> 512 channels)")
    print(f"layer3: {len(resnet50.layer3)} Bottleneck blocks (256 -> 1024 channels)")
    print(f"layer4: {len(resnet50.layer4)} Bottleneck blocks (512 -> 2048 channels)")
    print(f"fc:     1 linear layer")
    print()
    print(
        f"Total Bottleneck blocks: {len(resnet50.layer1) + len(resnet50.layer2) + len(resnet50.layer3) + len(resnet50.layer4)}"
    )
    print()
    print("Each Bottleneck block has 3 conv layers + skip connection.")
    print("Feature maps grow larger in early layers, then shrink spatially.")
    print("layer3 often has the largest activation memory (1024 channels, moderate spatial size).")

ResNet50 Structure:
------------------------------------------------------------
conv1:  1 conv layer
layer1: 3 Bottleneck blocks (64 -> 256 channels)
layer2: 4 Bottleneck blocks (128 -> 512 channels)
layer3: 6 Bottleneck blocks (256 -> 1024 channels)
layer4: 3 Bottleneck blocks (512 -> 2048 channels)
fc:     1 linear layer

Total Bottleneck blocks: 16

Each Bottleneck block has 3 conv layers + skip connection.
Feature maps grow larger in early layers, then shrink spatially.
layer3 often has the largest activation memory (1024 channels, moderate spatial size).


In [7]:
# 5.2 ResNet50 with Checkpointing: Three Strategies

try:
    import torchvision.models as models
except Exception as e:
    models = None
    ResNet50Checkpointed = None
    print("torchvision is required for ResNet50Checkpointed.")
    print(f"Details: {type(e).__name__}: {e}")

if models is not None:
    class ResNet50Checkpointed(nn.Module):
        """ResNet50 with configurable checkpointing strategies."""

        def __init__(self, num_classes=1000, checkpoint_strategy='none'):
            """
            Args:
                checkpoint_strategy: 'none', 'per_stage', 'per_block', or 'aggressive'
            """
            super().__init__()

            # Load base ResNet50 without downloading weights
            try:
                base = models.resnet50(weights=None)
            except TypeError:
                base = models.resnet50(pretrained=False)

            # Copy all layers
            self.conv1 = base.conv1
            self.bn1 = base.bn1
            self.relu = base.relu
            self.maxpool = base.maxpool
            self.layer1 = base.layer1
            self.layer2 = base.layer2
            self.layer3 = base.layer3
            self.layer4 = base.layer4
            self.avgpool = base.avgpool
            self.fc = nn.Linear(2048, num_classes)

            self.checkpoint_strategy = checkpoint_strategy

        def _forward_stage(self, stage, x):
            """Forward through a stage (layer1, layer2, etc.)."""
            for block in stage:
                x = block(x)
            return x

        def forward(self, x):
            # Stem
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            # NOTE: ResNet blocks include BatchNorm, which updates running stats in train mode.
            # Checkpointing recomputes forward in backward, which can update BN stats twice.
            # For strict parity, consider freezing BN stats (eval for BN) or using GroupNorm.

            if self.checkpoint_strategy == 'none':
                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)

            elif self.checkpoint_strategy == 'per_stage':
                x = checkpoint(lambda t: self._forward_stage(self.layer1, t), x, use_reentrant=False)
                x = checkpoint(lambda t: self._forward_stage(self.layer2, t), x, use_reentrant=False)
                x = checkpoint(lambda t: self._forward_stage(self.layer3, t), x, use_reentrant=False)
                x = checkpoint(lambda t: self._forward_stage(self.layer4, t), x, use_reentrant=False)

            elif self.checkpoint_strategy == 'per_block':
                for block in self.layer1:
                    x = checkpoint(block, x, use_reentrant=False)
                for block in self.layer2:
                    x = checkpoint(block, x, use_reentrant=False)
                for block in self.layer3:
                    x = checkpoint(block, x, use_reentrant=False)
                for block in self.layer4:
                    x = checkpoint(block, x, use_reentrant=False)

            elif self.checkpoint_strategy == 'aggressive':
                x = self.layer1(x)
                x = self.layer2(x)
                for block in self.layer3:
                    x = checkpoint(block, x, use_reentrant=False)
                for block in self.layer4:
                    x = checkpoint(block, x, use_reentrant=False)
            else:
                raise ValueError(f"Unknown checkpoint_strategy: {self.checkpoint_strategy}")

            # Head
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
            return x

    print("ResNet50Checkpointed created with 4 strategies:")
    print("-" * 60)
    print("'none':       No checkpointing (baseline)")
    print("'per_stage':  Checkpoint each of the 4 stages")
    print("'per_block':  Checkpoint each of the 16 bottleneck blocks")
    print("'aggressive': Only checkpoint layer3 and layer4 (best ROI)")

ResNet50Checkpointed created with 4 strategies:
------------------------------------------------------------
'none':       No checkpointing (baseline)
'per_stage':  Checkpoint each of the 4 stages
'per_block':  Checkpoint each of the 16 bottleneck blocks
'aggressive': Only checkpoint layer3 and layer4 (best ROI)


In [8]:
# 5.3 Memory Comparison: ResNet50 Strategies

def benchmark_resnet50(strategy, batch_size=32, image_size=224, num_iters=5):
    """Benchmark memory and time for a ResNet50 checkpointing strategy."""
    if ResNet50Checkpointed is None:
        raise RuntimeError("ResNet50Checkpointed is not available (torchvision import likely failed).")

    model = ResNet50Checkpointed(num_classes=1000, checkpoint_strategy=strategy).cuda()
    model.train()

    x = torch.randn(batch_size, 3, image_size, image_size, device="cuda")
    target = torch.randint(0, 1000, (batch_size,), device="cuda")
    criterion = nn.CrossEntropyLoss()

    # Warmup
    for _ in range(2):
        out = model(x)
        loss = criterion(out, target)
        loss.backward()
        model.zero_grad(set_to_none=True)

    # Measure memory + time
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_iters):
        out = model(x)
        loss = criterion(out, target)
        loss.backward()
        model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()

    elapsed = (time.time() - start) / num_iters * 1000  # ms/iter
    peak_mem = torch.cuda.max_memory_allocated() / 1e6  # MB

    del model, x, target
    torch.cuda.empty_cache()

    return peak_mem, elapsed

if torch.cuda.is_available() and ResNet50Checkpointed is not None:
    batch_size = 32
    image_size = 224

    print("ResNet50 Checkpointing Comparison")
    print(f"Batch size: {batch_size}, Image size: {image_size}x{image_size}")
    print("-" * 70)
    print(f"{'Strategy':<15} {'Peak Memory (MB)':<20} {'Time (ms)':<15} {'Mem Savings':<15}")
    print("-" * 70)

    strategies = ['none', 'aggressive', 'per_stage', 'per_block']
    results = {}

    for strategy in strategies:
        mem, time_ms = benchmark_resnet50(strategy, batch_size, image_size)
        results[strategy] = (mem, time_ms)

    baseline_mem = results['none'][0]
    for strategy in strategies:
        mem, time_ms = results[strategy]
        savings = (1 - mem / baseline_mem) * 100
        print(f"{strategy:<15} {mem:<20.2f} {time_ms:<15.2f} {savings:>6.1f}%")

    print("-" * 70)
else:
    print("[Run on GPU and ensure torchvision is installed for ResNet50 benchmarks]")

ResNet50 Checkpointing Comparison
Batch size: 32, Image size: 224x224
----------------------------------------------------------------------
Strategy        Peak Memory (MB)     Time (ms)       Mem Savings    
----------------------------------------------------------------------
none            3033.27              276.69             0.0%
aggressive      2627.26              316.66            13.4%
per_stage       1765.28              368.39            41.8%
per_block       1360.01              368.90            55.2%
----------------------------------------------------------------------


In [9]:
# 5.4 Maximum Batch Size: ResNet50

import torch
import torch.nn as nn

def find_max_batch_resnet50(strategy, start=16, max_batch=256):
    """Find maximum batch size before OOM for ResNet50 (CUDA only)."""
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for find_max_batch_resnet50().")
    if ResNet50Checkpointed is None:
        raise RuntimeError("ResNet50Checkpointed is not available (torchvision import likely failed).")

    start = max(1, int(start))
    max_batch = max(start, int(max_batch))

    def can_run(batch_size):
        model = None
        x = None
        target = None
        out = None
        loss = None
        try:
            torch.cuda.empty_cache()
            model = ResNet50Checkpointed(checkpoint_strategy=strategy).cuda()
            model.train()
            x = torch.randn(batch_size, 3, 224, 224, device="cuda")
            target = torch.randint(0, 1000, (batch_size,), device="cuda")

            out = model(x)
            loss = nn.CrossEntropyLoss()(out, target)
            loss.backward()
            return True
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                return False
            raise
        finally:
            try:
                if model is not None:
                    model.zero_grad(set_to_none=True)
            except Exception:
                pass
            del model, x, target, out, loss
            torch.cuda.empty_cache()

    # Exponential search to find an upper bound, then binary search.
    max_working = 0
    batch = start
    while batch <= max_batch and can_run(batch):
        max_working = batch
        batch *= 2

    low = max_working + 1
    high = min(batch, max_batch)

    while low <= high:
        mid = (low + high) // 2
        if can_run(mid):
            max_working = mid
            low = mid + 1
        else:
            high = mid - 1

    return max_working

if torch.cuda.is_available() and ResNet50Checkpointed is not None:
    print("Finding maximum batch size for ResNet50 (224x224 images)")
    print("-" * 50)

    strategies = ['none', 'aggressive', 'per_block']
    baseline = None

    for strategy in strategies:
        max_batch = find_max_batch_resnet50(strategy)
        if baseline is None:
            baseline = max_batch
        improvement = (max_batch / baseline) if baseline else float('inf')
        print(f"{strategy:<15}: max batch = {max_batch:>3d} ({improvement:.2f}x)")

    print("-" * 50)
    print("\nWith checkpointing, you can often fit larger batches.")
else:
    print("[Run on GPU and ensure torchvision is installed to find max batch sizes]")

Finding maximum batch size for ResNet50 (224x224 images)
--------------------------------------------------
none           : max batch = 177 (1.00x)
aggressive     : max batch = 205 (1.16x)
per_block      : max batch = 256 (1.45x)
--------------------------------------------------

With checkpointing, you can often fit larger batches.


In [10]:
# 5.5 Complete Training Loop: ResNet50 with Checkpointing

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def train_resnet50_with_checkpointing(
    checkpoint_strategy='aggressive',
    batch_size=32,
    num_epochs=3,
    num_samples=256,  # Small for demo
    use_amp=True,
):
    """Complete training loop with checkpointing and mixed precision."""

    if ResNet50Checkpointed is None:
        raise RuntimeError("ResNet50Checkpointed is not available (torchvision import likely failed).")

    # Create model
    model = ResNet50Checkpointed(
        num_classes=10,  # Simplified for demo
        checkpoint_strategy=checkpoint_strategy,
    ).cuda()

    # Create synthetic dataset
    X = torch.randn(num_samples, 3, 224, 224)
    y = torch.randint(0, 10, (num_samples,))
    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer and loss
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler() if (use_amp and torch.cuda.is_available()) else None

    print(f"Training ResNet50 with '{checkpoint_strategy}' checkpointing")
    print(f"Batch size: {batch_size}, Mixed precision: {use_amp and scaler is not None}")
    print("-" * 50)

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        torch.cuda.reset_peak_memory_stats()

        for batch_x, batch_y in loader:
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()

            optimizer.zero_grad(set_to_none=True)

            if scaler is not None:
                with torch.cuda.amp.autocast():
                    outputs = model(batch_x)
                    loss = criterion(outputs, batch_y)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        peak_mem = torch.cuda.max_memory_allocated() / 1e6
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}: loss = {avg_loss:.4f}, peak memory = {peak_mem:.2f} MB")

    print("-" * 50)
    print("Training complete!")
    return model

if torch.cuda.is_available() and ResNet50Checkpointed is not None:
    model = train_resnet50_with_checkpointing(
        checkpoint_strategy='aggressive',
        batch_size=32,
        num_epochs=3,
        use_amp=True,
    )
else:
    print("[Run on GPU and ensure torchvision is installed for training demo]")

  scaler = torch.cuda.amp.GradScaler() if (use_amp and torch.cuda.is_available()) else None
  with torch.cuda.amp.autocast():


Training ResNet50 with 'aggressive' checkpointing
Batch size: 32, Mixed precision: True
--------------------------------------------------
Epoch 1/3: loss = 2.6300, peak memory = 1630.90 MB
Epoch 2/3: loss = 1.9535, peak memory = 1630.90 MB
Epoch 3/3: loss = 1.5271, peak memory = 1630.90 MB
--------------------------------------------------
Training complete!
