In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time

# Use the same ResNet architecture from your sparse code
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

# Dense ResNet (same as your sparse version but without masking)
class DenseResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(DenseResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        layers = []
        strides = [stride] + [1] * (num_blocks - 1)
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def DenseResNet18(num_classes=10):
    return DenseResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

def initialize_weights_dense(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

def evaluate_dense(model, data_loader, criterion, device):
    model.eval()
    loss_total = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            loss_total += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    avg_loss = loss_total / total
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy

def train_cifar10_dense_baseline(epochs=200, device='cuda', seed=42):
    """Dense baseline with IDENTICAL setup to your sparse training"""
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Use your EXACT same data transforms
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])

    # Use your EXACT same dataset split
    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test
    )

    train_size = int(0.89 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_subset, val_subset = torch.utils.data.random_split(
        train_dataset, [train_size, val_size]
    )

    train_loader = torch.utils.data.DataLoader(
        train_subset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_subset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True
    )

    # Use your EXACT same model architecture (without masking)
    model = DenseResNet18(num_classes=10).to(device)

    # Use your EXACT same training setup
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Use your EXACT same initialization
    initialize_weights_dense(model)

    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        running_loss = 0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        scheduler.step()

        if epoch % 20 == 0 or epoch == epochs - 1:
            val_loss, val_accuracy = evaluate_dense(model, val_loader, criterion, device)
            train_accuracy = 100.0 * correct / total
            print(f"Dense Epoch {epoch+1:3d}: Train: {train_accuracy:.2f}%, Val: {val_accuracy:.2f}%")

    training_time = time.time() - start_time
    test_loss, test_accuracy = evaluate_dense(model, test_loader, criterion, device)

    print(f"\nDense CIFAR-10 training completed: {test_accuracy:.2f}% test accuracy in {training_time:.1f}s")

    return {
        'test_accuracy': test_accuracy,
        'training_time': training_time,
        'final_sparsity': 0.0
    }

def run_cifar10_comparison():
    """Compare your sparse results with dense baseline"""

    print("Running CIFAR-10 Dense vs Sparse Comparison...")
    print("1. Running Dense Baseline on CIFAR-10...")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dense_result = train_cifar10_dense_baseline(epochs=200, device=device, seed=42)

    # Your sparse results (from your code output)
    sparse_accuracy = 95.19
    sparse_std = 0.11
    sparse_sparsity = 79.6

    print("\n" + "="*60)
    print("CIFAR-10: DENSE vs SPARSE COMPARISON")
    print("="*60)
    print(f"Dense Baseline:      {dense_result['test_accuracy']:.2f}% (11.2M parameters)")
    print(f"Dynamic Pruning:     {sparse_accuracy:.2f} ± {sparse_std:.2f}% (~2.3M parameters)")
    print(f"Accuracy difference: {sparse_accuracy - dense_result['test_accuracy']:+.2f} percentage points")
    print(f"Parameter reduction: {sparse_sparsity:.1f}% (~8.9M parameters removed)")

    print(f"\nTraining Time:")
    print(f"Dense:               {dense_result['training_time']:.1f}s")
    print(f"Dynamic Pruning:     ~3553s (from your results)")

    accuracy_diff = sparse_accuracy - dense_result['test_accuracy']
    print(f"\nCIFAR-10 Efficiency Summary:")
    print(f"• Achieved {sparse_sparsity:.1f}% parameter reduction")
    print(f"• Accuracy performance: {accuracy_diff:+.2f} percentage points vs dense")
    print(f"• Single training cycle with progressive sparsification")

    return dense_result

if __name__ == "__main__":
    print("Running CIFAR-10 Dense Baseline...")
    result = run_cifar10_comparison()

Running CIFAR-10 Dense Baseline...
Running CIFAR-10 Dense vs Sparse Comparison...
1. Running Dense Baseline on CIFAR-10...


100%|██████████| 170M/170M [00:04<00:00, 39.5MB/s]


Dense Epoch   1: Train: 18.37%, Val: 24.45%
Dense Epoch  21: Train: 75.77%, Val: 72.60%
Dense Epoch  41: Train: 79.97%, Val: 74.47%
Dense Epoch  61: Train: 81.60%, Val: 77.78%
Dense Epoch  81: Train: 83.86%, Val: 76.69%
Dense Epoch 101: Train: 85.86%, Val: 81.27%
Dense Epoch 121: Train: 88.27%, Val: 81.67%
Dense Epoch 141: Train: 91.10%, Val: 86.07%
Dense Epoch 161: Train: 94.52%, Val: 89.44%
Dense Epoch 181: Train: 97.36%, Val: 91.36%
Dense Epoch 200: Train: 98.15%, Val: 91.60%

Dense CIFAR-10 training completed: 95.27% test accuracy in 3221.6s

CIFAR-10: DENSE vs SPARSE COMPARISON
Dense Baseline:      95.27% (11.2M parameters)
Dynamic Pruning:     95.19 ± 0.11% (~2.3M parameters)
Accuracy difference: -0.08 percentage points
Parameter reduction: 79.6% (~8.9M parameters removed)

Training Time:
Dense:               3221.6s
Dynamic Pruning:     ~3553s (from your results)

CIFAR-10 Efficiency Summary:
• Achieved 79.6% parameter reduction
• Accuracy performance: -0.08 percentage points vs