In [None]:
import itertools

import torch
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torchvision
import torchvision.transforms as transforms

# ✅ Build ResNet-18 with torchvision's default pretrained weights and put it in
# inference mode right away. nn.Module.eval() returns the module itself, so the
# call can be chained directly onto construction.
model = torchvision.models.resnet18(weights='DEFAULT').eval()

# ✅ Simple dataset (CIFAR-10 subset for quick eval)
# torchvision's pretrained classifiers were trained on inputs normalized with the
# ImageNet per-channel mean/std below. Feeding raw [0, 1] tensors from ToTensor()
# shifts the input distribution and silently degrades the pretrained model.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.Resize(224),  # upsample the 32x32 CIFAR-10 images (shorter side -> 224)
    transforms.ToTensor(),   # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# shuffle=False keeps the evaluated subset (the first few batches) deterministic.
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

# ✅ Apply L1 structured pruning (20%) on all conv layers and make it permanent
def prune_conv_layers(net, amount=0.2, n=1, dim=0):
    """Apply Ln-norm structured pruning to every ``Conv2d`` in *net*, in place.

    For each convolution, ``prune.ln_structured`` zeroes the ``amount`` of
    weight slices along ``dim`` that have the smallest Ln norm. ``prune.remove``
    is then called immediately, so the zeros are written into ``module.weight``
    and the temporary ``weight_orig`` / ``weight_mask`` reparameterization is
    dropped.

    This only zeroes weights. Tensor shapes, and therefore parameter count,
    memory and FLOPs, are unchanged.

    Args:
        net: model whose ``torch.nn.Conv2d`` submodules are pruned in place.
        amount: fraction (float in [0, 1]) or absolute count (int) of slices
            to prune per layer, as accepted by ``prune.ln_structured``.
        n: order of the norm used to rank slices (1 = L1).
        dim: weight dimension to prune along. For a Conv2d weight of shape
            (out_channels, in_channels/groups, kH, kW), 0 selects whole filters.

    Returns:
        The number of ``Conv2d`` layers that were pruned.
    """
    pruned_layers = 0
    for module in net.modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.ln_structured(module, name="weight", amount=amount, n=n, dim=dim)
            # Finalize right away: pruning one layer never affects another, so a
            # single pass is equivalent to "prune everything, then remove everything".
            prune.remove(module, "weight")
            pruned_layers += 1
    return pruned_layers


prune_conv_layers(model, amount=0.2, n=1, dim=0)

# ✅ Quick evaluation
def evaluate_accuracy(net, loader, max_batches=None):
    """Count top-1 correct predictions of *net* over batches from *loader*.

    Args:
        net: classifier called as ``net(images)``. Its output is assumed to be
            scores of shape (batch, num_classes); the argmax is taken over dim 1.
        loader: iterable yielding ``(images, labels)`` batches.
        max_batches: evaluate at most this many batches. ``itertools.islice``
            is used, so no extra batch is loaded. ``None`` evaluates the whole
            loader.

    Returns:
        ``(correct, total)`` as Python ints; ``total`` is 0 if no batch was seen.
    """
    net.eval()  # guard against the notebook having left the model in train mode
    batches = loader if max_batches is None else itertools.islice(loader, max_batches)
    correct = total = 0
    with torch.no_grad():
        for images, labels in batches:
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


# NOTE(review): the pretrained ResNet18 head predicts ImageNet-1k class indices
# (0-999), while CIFAR-10 labels are 0-9 with unrelated meanings. The number
# printed below is therefore NOT CIFAR-10 accuracy. For a meaningful result,
# either evaluate on data whose labels match the model's head (e.g. ImageNet
# val) or fine-tune a 10-class `model.fc` first. Also compare against an
# unpruned baseline.
# Only the first 21 batches (indices 0..20) are evaluated, for speed.
correct, total = evaluate_accuracy(model, testloader, max_batches=21)

if total == 0:
    print("⚠️ No test samples were evaluated; accuracy is undefined.")
else:
    print(f"✅ Pruned model accuracy (approx): {100 * correct / total:.2f}%")
