In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import time
import torch.nn.functional as F
from sklearn.cluster import KMeans
import os
import json
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
# `device` is a notebook-wide global that later cells read directly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device}")

# Seed the PyTorch and NumPy global RNGs with a fixed value so runs are repeatable.
torch.manual_seed(1212)
np.random.seed(1212)
if torch.cuda.is_available():
    # Also seed the current CUDA device's generator explicitly.
    torch.cuda.manual_seed(1212)

cuda


In [None]:
class VGG(nn.Module):
    """VGG-style convolutional network with batch normalisation.

    The convolutional trunk is generated from a layer table (``self.cfg``);
    it is followed by a three-layer fully connected head whose first layer
    expects 512 flattened features.

    Args:
        vgg_name: Key into ``self.cfg`` ('VGG11', 'VGG13', 'VGG16' or 'VGG19').
        num_classes: Output width of the final linear layer.
    """

    def __init__(self, vgg_name='VGG16', num_classes=10):
        super().__init__()
        # Integers are conv output widths; 'M' marks a 2x2 max-pool.
        self.cfg = {
            'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
            'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
        }
        self.features = self._make_layers(self.cfg[vgg_name])

        head = [
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def _make_layers(self, cfg):
        """Translate a layer table into an ``nn.Sequential`` trunk.

        Each integer becomes Conv3x3 -> BatchNorm -> ReLU; each 'M' becomes a
        2x2 max-pool. A 1x1 average pool (kernel 1, stride 1) is appended last.
        """
        stack = []
        channels_in = 3
        for spec in cfg:
            if spec == 'M':
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stack.extend((
                nn.Conv2d(channels_in, spec, kernel_size=3, padding=1),
                nn.BatchNorm2d(spec),
                nn.ReLU(inplace=True),
            ))
            channels_in = spec
        stack.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*stack)

    def forward(self, x):
        """Run the trunk, flatten everything but the batch axis, apply the head."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

# Build the default VGG16 variant on the selected device and report its
# total parameter count (trainable or not).
model = VGG('VGG16').to(device)
print(f"# of parameter: {sum(p.numel() for p in model.parameters()):,}")


# of parameter: 15,253,578


In [None]:
# Training-time preprocessing: random 32x32 crop after 4-pixel padding, random
# horizontal flip, tensor conversion, then per-channel normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test-time preprocessing: no augmentation, same normalisation constants.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# CIFAR-10 training split (downloaded into ./data on first use).
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

# CIFAR-10 test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Small scoring subset: every 50th training sample. Because it wraps `trainset`,
# the random training augmentations apply here too, and batches are shuffled.
score_trainset = torch.utils.data.Subset(trainset, range(0, len(trainset), 50))  # 2% of dataset
score_loader = torch.utils.data.DataLoader(
    score_trainset, batch_size=128, shuffle=True, num_workers=2)

# Human-readable class names, indexed by CIFAR-10 label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"Size of training set: {len(trainset)}")
print(f"Size of testing set: {len(testset)}")
print(f"Size fo score subset: {len(score_trainset)}")

100%|██████████| 170M/170M [00:05<00:00, 30.8MB/s]


Size of training set: 50000
Size of testing set: 10000
Size fo score subset: 1000


# Noise data

In [None]:
!git clone https://github.com/UCSC-REAL/cifar-10-100n.git

In [None]:
def apply_symmetric_label_noise(labels, noise_rate, num_classes=10, seed=None):
    """Flip a random fraction of labels to a different class, chosen uniformly.

    Each sample is selected for flipping independently with probability
    ``noise_rate``. A selected sample receives a label drawn uniformly from the
    other ``num_classes - 1`` classes (re-drawn until it differs from the
    original).

    Args:
        labels: 1-D integer tensor of clean labels. It is not modified.
        noise_rate: Per-sample flip probability.
        num_classes: Number of classes; replacement labels lie in
            ``[0, num_classes)``.
        seed: Seed for the private ``torch.Generator`` used for every draw.
            NOTE(review): a freshly built ``torch.Generator`` appears to start
            from PyTorch's fixed default seed, so ``seed=None`` looks
            deterministic as well - confirm whether that is intended.

    Returns:
        ``(noisy_labels, flip_mask)``: a new int64 label tensor and a boolean
        mask marking the samples whose label was changed.

    Raises:
        ValueError: if labels must be flipped but ``num_classes < 2``, because
            no different class exists (this previously looped forever).
    """
    labels = labels.clone().long()
    N = labels.size(0)

    g = torch.Generator()
    if seed is not None:
        g.manual_seed(seed)

    flip_mask = torch.rand(N, generator=g) < noise_rate
    # Plain int so it can be used as a tensor size below.
    num_flip = int(flip_mask.sum().item())

    if num_flip == 0:
        return labels, flip_mask

    if num_classes < 2:
        raise ValueError(
            f"num_classes must be at least 2 to flip labels, got {num_classes}")

    new_labels = torch.randint(low=0, high=num_classes, size=(num_flip,), generator=g)
    orig = labels[flip_mask]
    same = new_labels == orig
    # Rejection sampling: re-draw only the entries that landed on their own class.
    while same.any():
        num_same = int(same.sum().item())
        new_labels[same] = torch.randint(0, num_classes, (num_same,), generator=g)
        same = new_labels == orig

    labels[flip_mask] = new_labels
    return labels, flip_mask

class CIFAR10NWithNoise(Dataset):
    """Dataset view that pairs the images of ``base_dataset`` with replacement labels.

    ``__getitem__`` takes the image (first element) from ``base_dataset[idx]``,
    discards the base label, and returns ``labels[idx]`` as a Python int instead.

    Args:
        base_dataset: Indexable dataset yielding ``(image, label)`` pairs.
        labels_tensor: One label per base sample. Anything accepted by
            ``torch.as_tensor`` (tensor, list, NumPy array) is allowed.

    Raises:
        ValueError: if the number of labels differs from ``len(base_dataset)``,
            which would otherwise pair images with the wrong labels or fail
            only when a late index is reached.
    """

    def __init__(self, base_dataset, labels_tensor):
        labels = torch.as_tensor(labels_tensor).long()
        if labels.dim() < 1 or labels.size(0) != len(base_dataset):
            raise ValueError(
                f"expected {len(base_dataset)} labels (one per sample), "
                f"got a tensor of shape {tuple(labels.shape)}")
        self.base = base_dataset
        self.labels = labels

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, _ = self.base[idx]
        label = self.labels[idx].item()
        return img, label

# Load the CIFAR-10N human-annotation file from the cloned repository.
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects - only acceptable because this file comes from a trusted source.
noise_file = torch.load("./cifar-10-100n/data/CIFAR-10_human.pt", weights_only=False)

# Label sets read from the file, converted to int64 tensors.
clean_label  = torch.as_tensor(noise_file['clean_label']).long()
worst_label  = torch.as_tensor(noise_file['worse_label']).long()
aggre_label  = torch.as_tensor(noise_file['aggre_label']).long()
# random_label1 = noise_file["random_label1"]
# random_label2 = noise_file["random_label2"]
# random_label3 = noise_file["random_label3"]

# Synthetic symmetric noise at several rates, keyed by rate.
noise_rates = [0.1, 0.3, 0.5]
noisy_labels_dict = {}

for rate in noise_rates:
    # The same seed is reused for every rate, so the uniform draws behind
    # flip_mask are identical across rates and only the threshold changes.
    noisy_labels, flip_mask = apply_symmetric_label_noise(
        clean_label,
        noise_rate=rate,
        num_classes=10,
        seed=42
    )
    noisy_labels_dict[rate] = noisy_labels
    print(f"noise rate {rate*100:.0f}%: {flip_mask.sum().item()} / {len(clean_label)}")

In [None]:
# Training-set views sharing the images (and augmentation) of `trainset` and
# differing only in their labels.
train_clean = CIFAR10NWithNoise(trainset, clean_label)
# 10% noise
train_noise_10 = CIFAR10NWithNoise(trainset, noisy_labels_dict[0.1])
# 30% noise
train_noise_30 = CIFAR10NWithNoise(trainset, noisy_labels_dict[0.3])
# 50% noise
train_noise_50 = CIFAR10NWithNoise(trainset, noisy_labels_dict[0.5])

# Example - to train with 30% label noise, rebuild the loader like this:
# trainloader = DataLoader(train_noise_30, batch_size=128, shuffle=True, num_workers=4)


# checkpoint

In [None]:
def save_checkpoint(model, results, optimizer, epoch, filepath, model_type='original'):
    """
    Write a training checkpoint to disk with ``torch.save``.

    The saved dictionary holds the model and optimizer state dicts together
    with the epoch number, the results dictionary and the model-type tag.

    Args:
        model: Model whose ``state_dict()`` is stored
        results: Training results dictionary
        optimizer: Optimizer whose ``state_dict()`` is stored
        epoch: Current epoch
        filepath: Destination path
        model_type: Model type tag ('original', 'snip', 'random')
    """
    state = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        epoch=epoch,
        results=results,
        model_type=model_type,
    )
    torch.save(state, filepath)
    print("Checkpoint saved to {}".format(filepath))


def load_checkpoint(model, optimizer, filepath):
    """
    Restore model (and optionally optimizer) state from a checkpoint file.

    When ``filepath`` does not exist nothing is loaded and the inputs are
    returned unchanged with epoch 0. Tensors are mapped onto the notebook's
    global ``device``.

    Returns:
        model, optimizer, epoch, results, model_type
        (``epoch`` is 0 and the last two are ``None`` when no file was found)
    """
    if os.path.exists(filepath):
        state = torch.load(filepath, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        # The optimizer is optional, e.g. when loading for evaluation only.
        if optimizer is not None:
            optimizer.load_state_dict(state['optimizer_state_dict'])

        print(f"Checkpoint loaded from {filepath}")
        return model, optimizer, state['epoch'], state['results'], state['model_type']

    print(f"Checkpoint not found: {filepath}")
    return model, optimizer, 0, None, None


def save_results(results, filepath='results.json'):
    """Dump ``results`` to ``filepath`` as JSON indented by four spaces."""
    with open(filepath, 'w') as handle:
        json.dump(results, handle, indent=4)
    print("Results saved to {}".format(filepath))


def load_results(filepath='results.json'):
    """Read training results back from a JSON file.

    Returns the parsed object, or ``None`` (after printing a notice) when
    ``filepath`` does not exist.
    """
    if os.path.exists(filepath):
        with open(filepath, 'r') as handle:
            loaded = json.load(handle)
        print(f"Results loaded from {filepath}")
        return loaded

    print(f"Results file not found: {filepath}")
    return None


def save_pruning_masks(keep_masks, filepath='pruning_masks.pth'):
    """Serialize ``keep_masks`` to ``filepath`` with ``torch.save`` and report the path."""
    torch.save(keep_masks, filepath)
    print("Pruning masks saved to {}".format(filepath))


def load_pruning_masks(filepath='pruning_masks.pth'):
    """Load pruning masks saved with ``torch.save``.

    Tensors are mapped onto the notebook's global ``device``. Returns ``None``
    (after printing a notice) when ``filepath`` does not exist.
    """
    if os.path.exists(filepath):
        loaded = torch.load(filepath, map_location=device)
        print(f"Pruning masks loaded from {filepath}")
        return loaded

    print(f"Pruning masks not found: {filepath}")
    return None

def save_masks(mask_data, filepath):
    """Serialize ``mask_data`` to ``filepath`` with ``torch.save`` and report the path.

    Same operation as ``save_pruning_masks``; only the log message and the
    lack of a default path differ.
    """
    torch.save(mask_data, filepath)
    print("✓ Masks saved: {}".format(filepath))


def load_masks(filepath):
    """Load masks saved with ``torch.save``.

    Tensors are mapped onto the notebook's global ``device``. Returns ``None``
    (after printing a notice) when ``filepath`` does not exist.
    """
    if os.path.exists(filepath):
        loaded = torch.load(filepath, map_location=device)
        print(f"✓ Masks loaded: {filepath}")
        return loaded

    print(f"No masks found: {filepath}")
    return None

# SNIP Function


In [None]:
def snip_forward_conv2d(self, x):
    """Conv2d forward pass that multiplies the weight by ``self.weight_mask`` first.

    Intended to be bound to a Conv2d instance in place of its normal
    ``forward`` so gradients with respect to the mask can be measured.
    """
    masked_weight = self.weight * self.weight_mask
    return nn.functional.conv2d(
        x,
        masked_weight,
        self.bias,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
        groups=self.groups,
    )

def snip_forward_linear(self, x):
    """Linear forward pass that multiplies the weight by ``self.weight_mask`` first.

    Intended to be bound to a Linear instance in place of its normal
    ``forward`` so gradients with respect to the mask can be measured.
    """
    masked_weight = self.weight * self.weight_mask
    return nn.functional.linear(x, masked_weight, self.bias)

def apply_prune_mask(model, keep_masks):
    """Zero the pruned weights of ``model`` and keep them at zero during training.

    For every Conv2d/Linear layer (in ``model.modules()`` order) the weights
    where the matching mask is 0 are set to 0, and a gradient hook is
    registered that multiplies incoming gradients by the mask so those
    weights receive no gradient.

    Args:
        model: Model whose Conv2d/Linear weights are pruned in place.
        keep_masks: One mask per prunable layer, same shape as that layer's
            weight; 1 keeps a weight, 0 prunes it.

    Raises:
        ValueError: if the number of masks differs from the number of prunable
            layers (previously the extra layers were silently left unpruned).
        AssertionError: if a mask's shape differs from its layer's weight shape.
    """
    prunable_layers = [
        layer for layer in model.modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]
    keep_masks = list(keep_masks)
    if len(keep_masks) != len(prunable_layers):
        raise ValueError(
            f"got {len(keep_masks)} masks for {len(prunable_layers)} prunable layers")

    def hook_factory(mask):
        # A factory so that each hook closes over its own mask.
        def hook(grads):
            return grads * mask
        return hook

    for layer, keep_mask in zip(prunable_layers, keep_masks):
        assert layer.weight.shape == keep_mask.shape, (
            f"mask shape {tuple(keep_mask.shape)} does not match "
            f"weight shape {tuple(layer.weight.shape)}")

        # Keep the mask on the weight's device/dtype so both the in-place
        # zeroing and the gradient hook work wherever the masks were loaded.
        keep_mask = keep_mask.to(device=layer.weight.device, dtype=layer.weight.dtype)

        # Zero the pruned weights now ...
        layer.weight.data[keep_mask == 0.] = 0.
        # ... and register a hook so they also receive zero gradients.
        layer.weight.register_hook(hook_factory(keep_mask))

def snip_pruning(model, dataloader, sparsity=0.5):
    """
    Perform SNIP pruning (single-shot pruning from connection sensitivity).

    SNIP scores a weight ``w`` by ``|dL/dc|`` where ``c`` is a multiplicative
    mask on ``w`` evaluated at ``c = 1``. By the chain rule this equals
    ``|w * dL/dw|``, so the scores are computed directly from the weight
    gradients of one mini-batch. No forward method is patched and no
    ``.grad`` attribute on the model is touched.

    The model's train/eval mode is left as the caller set it, and the batch is
    moved to the device of the model's first prunable weight.

    Args:
        model: Model to prune (its Conv2d and Linear weights are scored)
        dataloader: Data loader; only its first batch is used
        sparsity: Pruning rate in [0, 1] (0.5 means prune 50% of parameters)

    Returns:
        keep_masks: One float mask (1 = keep, 0 = prune) per Conv2d/Linear
        layer, in ``model.modules()`` order, shaped like that layer's weight.
        ``int(total * (1 - sparsity))`` weights are kept overall, barring
        exact ties at the threshold.

    Raises:
        ValueError: if ``sparsity`` is outside [0, 1] or the model has no
            Conv2d/Linear layer.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

    prunable_layers = [
        layer for layer in model.modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]
    if not prunable_layers:
        raise ValueError("model has no Conv2d or Linear layers to prune")

    weights = [layer.weight for layer in prunable_layers]
    model_device = weights[0].device

    # Frozen weights are scored too: enable their gradients for this one pass
    # and restore the original flags even if the forward/backward pass fails.
    original_requires_grad = [w.requires_grad for w in weights]
    for w in weights:
        w.requires_grad_(True)
    try:
        criterion = nn.CrossEntropyLoss()

        inputs, targets = next(iter(dataloader))
        inputs, targets = inputs.to(model_device), targets.to(model_device)

        loss = criterion(model(inputs), targets)
        # autograd.grad returns the gradients without accumulating into .grad.
        grads = torch.autograd.grad(loss, weights, allow_unused=True)
    finally:
        for w, flag in zip(weights, original_requires_grad):
            w.requires_grad_(flag)

    # Connection sensitivity |w * dL/dw|; a layer the loss does not depend on scores 0.
    grads_abs = [
        torch.zeros_like(w) if g is None else (g * w.detach()).abs()
        for w, g in zip(weights, grads)
    ]

    # Concatenate all scores into a 1D vector
    all_scores = torch.cat([torch.flatten(x) for x in grads_abs])
    total = all_scores.numel()
    num_params_to_keep = int(total * (1 - sparsity))

    if num_params_to_keep <= 0:
        # Everything is pruned.
        threshold = float("inf")
        keep_masks = [torch.zeros_like(g) for g in grads_abs]
    else:
        # The threshold is the smallest score that is still kept, i.e. the
        # (total - keep + 1)-th smallest. Using the (total - keep)-th smallest
        # with ">=" kept one weight too many and required k = 0 at sparsity 0.
        threshold_t, _ = torch.kthvalue(all_scores, total - num_params_to_keep + 1)
        threshold = float(threshold_t)
        keep_masks = [(g >= threshold_t).float() for g in grads_abs]

    print(f"SNIP pruning completed! Pruning rate: {sparsity * 100:.1f}%")
    print(f"Threshold: {threshold:.6f}")

    return keep_masks

print("SNIP algorithm implementation completed!")

SNIP 演算法實作完成！


# Random Function

In [None]:
def random_pruning(model, sparsity=0.5):
    """
    Perform random pruning - randomly select weights to keep

    Every Conv2d/Linear layer is pruned independently at the same rate:
    ``int(numel * sparsity)`` randomly chosen weights of each layer are marked
    for removal. Randomness comes from PyTorch's global CPU generator, so
    seeding it makes the masks reproducible. Each mask is placed on the device
    of the weight it belongs to.

    Args:
        model: Model to prune (only inspected; weights are not modified)
        sparsity: Pruning rate in [0, 1] (0.5 means prune 50% of parameters)

    Returns:
        keep_masks: One float mask (1 = keep, 0 = prune) per Conv2d/Linear
        layer, in ``model.modules()`` order; an empty list if there is none.

    Raises:
        ValueError: if ``sparsity`` is outside [0, 1].
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

    print(f"Performing random pruning (pruning rate: {sparsity * 100:.1f}%)...")

    keep_masks = []
    total_params = 0
    kept_params = 0

    prunable_layers = [
        layer for layer in model.modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]

    for layer in prunable_layers:
        weight = layer.weight
        total_elements = weight.numel()
        num_to_prune = int(total_elements * sparsity)

        # Start from "keep everything" and clear a random subset of positions.
        flat_mask = torch.ones(total_elements)
        prune_indices = torch.randperm(total_elements)[:num_to_prune]
        flat_mask[prune_indices] = 0

        # Follow the layer's own device rather than a notebook-level global so
        # the mask can always be applied to the weight it describes.
        keep_mask = flat_mask.view(weight.shape).to(weight.device)
        keep_masks.append(keep_mask)

        total_params += total_elements
        kept_params += total_elements - num_to_prune

    # A model without prunable layers has nothing to report a ratio on.
    actual_sparsity = 1 - (kept_params / total_params) if total_params else 0.0
    print("Random pruning completed!")
    print(f"Actual pruning rate: {actual_sparsity * 100:.2f}%")
    print(f"Parameters kept: {kept_params:,} / {total_params:,}")

    return keep_masks

# GSRP structured (Filter-level) Pruning

In [None]:
def get_conv_layers(model):
    """List every Conv2d in ``model`` in ``named_modules()`` order.

    Returns:
        A list of dicts with keys ``"name"`` (qualified module name),
        ``"layer"`` (the module itself) and ``"out_channels"``.
    """
    return [
        {"name": name, "layer": module, "out_channels": module.out_channels}
        for name, module in model.named_modules()
        if isinstance(module, nn.Conv2d)
    ]

def compute_gsrp_scores(model, loader, device, T=3):
    """Compute a per-filter GSRP score for every Conv2d layer of ``model``.

    For each of ``T`` mini-batches the cross-entropy loss is back-propagated
    (with the model in eval mode) and, per filter, two quantities are recorded:
    the sign of the filter's mean weight gradient, and the sum of ``|g * w|``
    over the filter's weights (accumulated across batches).

    The final score is ``S = C * B`` where
      * ``C`` is the fraction of the ``T`` batches whose gradient sign matches
        the filter's majority sign, and
      * ``B`` is the accumulated ``|g * w|`` sum divided by the layer maximum
        (left unnormalised when that maximum is 0).

    Args:
        model: Network to score; every Conv2d weight must receive a gradient.
        loader: Iterable of ``(inputs, targets)`` batches; it is restarted if
            it is exhausted before ``T`` batches have been drawn.
        device: Device the batches are moved to.
        T: Number of mini-batches to score on.

    Returns:
        The list from ``get_conv_layers(model)``, each dict extended with the
        CPU tensors ``"signs"`` (n_filters x T), ``"snip"``, ``"C"``, ``"B"``
        and ``"S"``.

    Side effects: switches the model to eval mode and leaves the gradients of
    the last scored batch on the parameters.
    """
    model.eval()
    conv_infos = get_conv_layers(model)

    # Per-layer accumulators, kept on the CPU.
    for info in conv_infos:
        n_filters = info["out_channels"]
        info["signs"] = torch.zeros(n_filters, T, dtype=torch.int8)
        info["snip"] = torch.zeros(n_filters, dtype=torch.float32)

    data_iter = iter(loader)

    for t in range(T):
        # Restart the loader if it holds fewer than T batches.
        try:
            x, y = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            x, y = next(data_iter)

        x, y = x.to(device), y.to(device)
        # Clear gradients so each batch is scored on its own.
        for p in model.parameters():
            if p.grad is not None:
                p.grad.zero_()

        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()

        for info in conv_infos:
            layer = info["layer"]
            g = layer.weight.grad.detach()
            w = layer.weight.detach()

            # One sign per filter: sign of the mean gradient over the filter's
            # weights, with an exact zero counted as positive.
            g_scalar = g.view(g.size(0), -1).mean(dim=1)
            sign = torch.sign(g_scalar)
            sign[sign == 0] = 1
            info["signs"][:, t] = sign.cpu().to(torch.int8)

            # Per-filter sum of |g * w|, accumulated over the T batches.
            snip = (g * w).abs().view(g.size(0), -1).sum(dim=1)
            info["snip"] += snip.cpu()

    for info in conv_infos:
        signs = info["signs"].float()
        # Majority sign across batches; a tie counts as positive.
        majority = torch.sign(signs.sum(dim=1))
        majority[majority == 0] = 1
        # Fraction of batches that agree with the majority sign.
        agree = (signs == majority.unsqueeze(1)).float().mean(dim=1)
        C = agree

        # Normalise the saliency by the layer maximum (skipped if all zero).
        B = info["snip"]
        if B.max() > 0:
            B = B / B.max()

        S = C * B
        info["C"] = C
        info["B"] = B
        info["S"] = S

    return conv_infos

def collect_conv_activations(model, loader, device, conv_infos, num_batches=10):
    """Record the mean output of every listed conv layer, per output channel.

    Forward hooks average each layer's output over the batch and both spatial
    axes; these per-batch means are then averaged over at most ``num_batches``
    batches of ``loader``. The model is switched to eval mode and run under
    ``torch.no_grad()``.

    Args:
        model: Network containing the layers referenced by ``conv_infos``.
        loader: Iterable of ``(inputs, _)`` batches.
        device: Device the inputs are moved to.
        conv_infos: Dicts with at least ``"layer"`` and ``"out_channels"``.
        num_batches: Maximum number of batches to run.

    Returns:
        ``conv_infos``, each dict extended with ``"act_sum"``, ``"act_count"``
        and ``"acts"`` (the averaged activation per output channel, on CPU).
    """
    model.eval()

    for info in conv_infos:
        info["act_sum"] = torch.zeros(info["out_channels"])
        info["act_count"] = 0

    def _accumulate_into(target):
        # One closure per layer so each hook updates its own dict.
        def _hook(module, inp, out):
            target["act_sum"] += out.detach().mean(dim=(0, 2, 3)).cpu()
            target["act_count"] += 1
        return _hook

    handles = [
        info["layer"].register_forward_hook(_accumulate_into(info))
        for info in conv_infos
    ]

    with torch.no_grad():
        for batch_idx, (x, _) in enumerate(loader):
            if batch_idx >= num_batches:
                break
            model(x.to(device))

    for handle in handles:
        handle.remove()

    for info in conv_infos:
        seen = info["act_count"]
        # With no batches seen the (all-zero) sum is returned as-is.
        info["acts"] = info["act_sum"] / seen if seen > 0 else info["act_sum"]

    return conv_infos

def choose_filters_by_gsrp(conv_infos, target_sparsity=0.9, use_clustering=True):
    """Pick, for every conv layer, which filters to keep.

    Each layer keeps ``max(1, int(out_channels * (1 - target_sparsity)))``
    filters. Without clustering these are simply the top-scoring filters by
    ``info["S"]``. With clustering, the filters are grouped by KMeans on their
    mean activation (``info["acts"]``), the best-scoring filter of each
    cluster is kept, and any shortfall is filled with the highest-scoring
    remaining filters.

    Args:
        conv_infos: Dicts with ``"out_channels"``, ``"S"`` and (on the
            clustering path) ``"acts"``.
        target_sparsity: Fraction of filters to remove per layer.
        use_clustering: Whether to diversify the selection with KMeans.

    Returns:
        ``conv_infos``, each dict extended with a boolean ``"mask"`` of length
        ``out_channels`` (True = keep).
    """
    def _bool_mask(size, kept):
        selected = torch.zeros(size, dtype=torch.bool)
        selected[kept] = True
        return selected

    for info in conv_infos:
        n_filters = info["out_channels"]
        n_keep = max(1, int(n_filters * (1 - target_sparsity)))
        scores = info["S"]

        # Plain top-k when clustering is off or nothing would be pruned.
        if not use_clustering or n_keep >= n_filters:
            _, top_idx = torch.topk(scores, k=n_keep)
            info["mask"] = _bool_mask(n_filters, top_idx)
            continue

        acts = info["acts"].view(-1, 1).numpy()
        n_clusters = min(n_keep, n_filters)

        # A single cluster degenerates to "keep the best filter".
        if n_clusters <= 1:
            info["mask"] = _bool_mask(n_filters, torch.argmax(scores).unsqueeze(0))
            continue

        clusterer = KMeans(n_clusters=n_clusters, n_init=10, random_state=12)
        cluster_of = clusterer.fit_predict(acts)

        scores_np = scores.numpy()
        winners = set()
        for cluster_id in range(n_clusters):
            members = np.where(cluster_of == cluster_id)[0]
            if len(members) == 0:
                continue
            winners.add(members[scores_np[members].argmax()])

        kept = sorted(winners)

        # Empty clusters leave fewer winners than requested: top up by score.
        if len(kept) < n_keep:
            leftovers = [i for i in range(n_filters) if i not in kept]
            shortfall = n_keep - len(kept)
            kept.extend(sorted(leftovers, key=lambda i: -scores_np[i])[:shortfall])

        info["mask"] = _bool_mask(n_filters, kept)

    return conv_infos

@torch.no_grad()
def apply_masks(conv_infos):
    """Zero, in place, every conv filter whose ``"mask"`` entry is False.

    For each dict that carries a ``"mask"``, the weights of the dropped
    filters of ``info["layer"]`` are set to 0, and so are their biases when
    the layer has a bias. Dicts without a mask are skipped.
    """
    for info in conv_infos:
        keep = info.get("mask")
        if keep is None:
            continue
        conv = info["layer"]
        dropped = ~keep
        conv.weight[dropped, :, :, :] = 0.0
        if conv.bias is not None:
            conv.bias[dropped] = 0.0

# GSRP weight-wise (unstructured) pruning


In [None]:

def get_prunable_params(model):
    """List the weight of every Conv2d and Linear module in ``model``.

    Returns:
        A list of dicts, in ``named_modules()`` order, with keys ``"name"``,
        ``"module"``, ``"weight"`` (the parameter itself, not a copy) and
        ``"shape"`` (the weight's shape).
    """
    return [
        {
            "name": name,
            "module": module,
            "weight": module.weight,
            "shape": module.weight.shape,
        }
        for name, module in model.named_modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]

def compute_gsrp_scores_weightwise(model, loader, device, T=3):
    """Compute a GSRP score for every individual Conv2d/Linear weight.

    For each of ``T`` mini-batches the cross-entropy loss is back-propagated
    (with the model in eval mode) and, per weight, the gradient sign and
    ``|g * w|`` are accumulated.

    The final score is ``S = C * B`` where
      * ``C = (T + |sum of signs|) / (2T)``, which is the fraction of batches
        agreeing with the majority sign, and
      * ``B`` is the accumulated ``|g * w|`` divided by the maximum within the
        same layer (left unnormalised when that maximum is 0).

    Args:
        model: Network to score; every prunable weight must receive a gradient.
        loader: Iterable of ``(inputs, targets)`` batches; it is restarted if
            it is exhausted before ``T`` batches have been drawn.
        device: Device for the batches and for the score accumulators.
        T: Number of mini-batches to score on.

    Returns:
        The list from ``get_prunable_params(model)``, each dict extended with
        ``"sign_sum"``, ``"snip"``, ``"C"``, ``"B"`` and ``"S"`` tensors shaped
        like the weight.

    Side effects: switches the model to eval mode and leaves the gradients of
    the last scored batch on the parameters.
    """
    model.eval()
    param_infos = get_prunable_params(model)

    # Per-weight accumulators, allocated on `device`.
    for info in param_infos:
        shape = info["shape"]
        info["sign_sum"] = torch.zeros(shape, dtype=torch.float32, device=device)
        info["snip"] = torch.zeros(shape, dtype=torch.float32, device=device)

    data_iter = iter(loader)

    for t in range(T):
        # Restart the loader if it holds fewer than T batches.
        try:
            x, y = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            x, y = next(data_iter)

        x, y = x.to(device), y.to(device)
        # Drop old gradients so each batch is scored on its own.
        model.zero_grad(set_to_none=True)

        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()

        for info in param_infos:
            w = info["weight"]
            g = w.grad.detach()

            # Gradient sign per weight, with an exact zero counted as positive.
            sign = torch.sign(g)
            sign[sign == 0] = 1.0
            info["sign_sum"] += sign

            # |g * w| per weight, accumulated over the T batches.
            snip = (g * w).abs()
            info["snip"] += snip

    for info in param_infos:
        sign_sum = info["sign_sum"]
        snip_sum = info["snip"]

        # With signs in {-1, +1}, |sign_sum| = 2 * (majority count) - T, so
        # this is (majority count) / T.
        C = (T + sign_sum.abs()) / (2.0 * T)

        # Normalise by the maximum of this layer only (skipped if all zero).
        B = snip_sum
        maxB = B.max()
        if maxB > 0:
            B = B / maxB

        S = C * B

        info["C"] = C
        info["B"] = B
        info["S"] = S

    return param_infos

def choose_weights_by_gsrp_weightwise(param_infos, target_sparsity=0.9):
    """Derive per-weight keep masks from one global score threshold.

    The scores ``info["S"]`` of all layers are pooled; the threshold is the
    smallest of the top ``max(1, int(N * (1 - target_sparsity)))`` scores, and
    every weight scoring at least that much is kept (ties at the threshold
    are all kept).

    Args:
        param_infos: Dicts carrying a score tensor under ``"S"``.
        target_sparsity: Fraction of all weights to remove.

    Returns:
        ``param_infos``, each dict extended with a boolean ``"mask"`` shaped
        like its ``"S"`` (True = keep).
    """
    pooled = torch.cat([info["S"].view(-1) for info in param_infos])

    total = pooled.numel()
    n_keep = max(1, int(total * (1.0 - target_sparsity)))

    if n_keep < total:
        kept_scores, _ = torch.topk(pooled, n_keep, largest=True)
        cutoff = kept_scores[-1].item()
    else:
        # Keeping everything: use a cutoff below every score.
        cutoff = pooled.min() - 1.0

    for info in param_infos:
        info["mask"] = info["S"] >= cutoff

    return param_infos

@torch.no_grad()
def apply_weightwise_masks(param_infos):
    """Multiply each weight in place by its ``"mask"`` (pruned entries become 0).

    The mask is moved to the weight's device and cast to its dtype before the
    multiplication. Dicts without a ``"mask"`` are skipped.
    """
    for info in param_infos:
        keep = info.get("mask")
        if keep is not None:
            weight = info["weight"]
            weight.mul_(keep.to(weight.device).to(weight.dtype))

# Structured pruning (for methods 3 & 4)


In [None]:

def prune_vgg_model_physically(original_model, gsrp_infos):
    """
    Create a new VGG model with filters physically removed.
    This creates a smaller, faster model.

    Every conv layer keeps only the filters selected by its mask, and
    everything that consumes those filters is sliced with the SAME indices:
    the following BatchNorm (affine parameters and running statistics), the
    input channels of the next conv layer, and the input columns of the first
    classifier layer. In eval mode the result is therefore equivalent to the
    original model with the dropped channels held at zero.

    Args:
        original_model: Model exposing ``features`` (an iterable of layers:
            Conv2d with ``groups=1``, BatchNorm2d, and parameter-free layers)
            and ``classifier`` (a sequence whose first module is ``nn.Linear``).
            It is not modified.
        gsrp_infos: Iterable of dicts with ``'name'`` (qualified module name of
            a conv layer, e.g. ``'features.0'``) and ``'mask'`` (boolean or 0/1
            tensor, True/non-zero = keep). Convs without an entry keep all
            of their filters.

    Returns:
        A new ``nn.Module`` with ``features`` and ``classifier`` attributes, on
        the same device and dtype as ``original_model``'s first parameter.

    Raises:
        ValueError: if a conv uses ``groups != 1``, a mask has the wrong
            length or keeps no filter, or the first classifier layer's input
            size is not a multiple of the last conv's channel count.
        TypeError: if ``classifier[0]`` is not an ``nn.Linear``.
    """
    from copy import deepcopy

    conv_masks = {info['name']: info['mask'] for info in gsrp_infos}
    # Resolve module -> qualified name once instead of searching per conv layer.
    names_by_id = {id(m): name for name, m in original_model.named_modules()}

    new_features = []
    prev_keep = None    # kept output-channel indices of the latest conv (None: nothing pruned yet)
    prev_width = None   # that conv's ORIGINAL out_channels (needed to slice the classifier)

    with torch.no_grad():
        for module in original_model.features:
            if isinstance(module, nn.Conv2d):
                if module.groups != 1:
                    raise ValueError(
                        "prune_vgg_model_physically only supports convolutions with groups=1")

                weight = module.weight
                mask = conv_masks.get(names_by_id.get(id(module)))
                if mask is None:
                    keep = torch.arange(module.out_channels, device=weight.device)
                else:
                    if mask.numel() != module.out_channels:
                        raise ValueError(
                            f"mask for {names_by_id.get(id(module))!r} has {mask.numel()} "
                            f"entries, expected {module.out_channels}")
                    keep = torch.nonzero(mask.reshape(-1).to(weight.device) != 0).flatten()
                    if keep.numel() == 0:
                        raise ValueError(
                            f"mask for {names_by_id.get(id(module))!r} keeps no filter")

                # Input channels must follow the filters the previous conv kept.
                if prev_keep is None:
                    in_keep = torch.arange(module.in_channels, device=weight.device)
                else:
                    in_keep = prev_keep.to(weight.device)

                new_conv = nn.Conv2d(
                    in_channels=int(in_keep.numel()),
                    out_channels=int(keep.numel()),
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding,
                    dilation=module.dilation,
                    groups=module.groups,
                    bias=(module.bias is not None)
                )
                new_conv.weight.copy_(weight[keep][:, in_keep])
                if module.bias is not None:
                    new_conv.bias.copy_(module.bias[keep])

                new_features.append(new_conv)
                prev_keep = keep
                prev_width = module.out_channels

            elif isinstance(module, nn.BatchNorm2d):
                if prev_keep is None:
                    idx = torch.arange(module.num_features)
                else:
                    idx = prev_keep

                new_bn = nn.BatchNorm2d(
                    int(idx.numel()),
                    eps=module.eps,
                    momentum=module.momentum,
                    affine=module.affine,
                    track_running_stats=module.track_running_stats,
                )
                # Slice with the kept indices (not "the first N") so each
                # surviving filter keeps its own scale, shift and statistics.
                if module.affine:
                    new_bn.weight.copy_(module.weight[idx.to(module.weight.device)])
                    new_bn.bias.copy_(module.bias[idx.to(module.bias.device)])
                if module.track_running_stats:
                    stats_idx = idx.to(module.running_mean.device)
                    new_bn.running_mean.copy_(module.running_mean[stats_idx])
                    new_bn.running_var.copy_(module.running_var[stats_idx])
                    new_bn.num_batches_tracked.copy_(module.num_batches_tracked)
                new_features.append(new_bn)

            else:
                new_features.append(deepcopy(module))

        # Build new classifier
        orig_classifier = original_model.classifier
        orig_first_linear = orig_classifier[0]
        if not isinstance(orig_first_linear, nn.Linear):
            raise TypeError("the first classifier module must be nn.Linear")

        if prev_keep is None:
            # No conv layer seen: the classifier input is unchanged.
            new_first_linear = deepcopy(orig_first_linear)
        else:
            if orig_first_linear.in_features % prev_width != 0:
                raise ValueError(
                    f"classifier input size {orig_first_linear.in_features} is not a "
                    f"multiple of the last conv's {prev_width} channels")
            # Flattening (N, C, H, W) is channel-major, so every channel owns a
            # contiguous run of `spatial` columns (1 for 32x32 CIFAR inputs
            # after five 2x2 poolings).
            spatial = orig_first_linear.in_features // prev_width
            out_features = orig_first_linear.out_features
            new_first_linear = nn.Linear(
                int(prev_keep.numel()) * spatial,
                out_features,
                bias=(orig_first_linear.bias is not None),
            )
            first_weight = orig_first_linear.weight
            kept_columns = first_weight.reshape(out_features, prev_width, spatial)[
                :, prev_keep.to(first_weight.device), :]
            new_first_linear.weight.copy_(kept_columns.reshape(out_features, -1))
            if orig_first_linear.bias is not None:
                new_first_linear.bias.copy_(orig_first_linear.bias)

        # The remaining classifier layers are unaffected by filter pruning.
        new_classifier = nn.Sequential(
            new_first_linear,
            *[deepcopy(layer) for layer in list(orig_classifier)[1:]]
        )

    class PrunedVGG(nn.Module):
        def __init__(self, features, classifier):
            super().__init__()
            self.features = features
            self.classifier = classifier

        def forward(self, x):
            out = self.features(x)
            out = out.view(out.size(0), -1)
            out = self.classifier(out)
            return out

    pruned_model = PrunedVGG(nn.Sequential(*new_features), new_classifier)

    # New layers are created on the CPU; follow the original model's placement.
    reference = next(original_model.parameters(), None)
    if reference is not None:
        pruned_model.to(device=reference.device, dtype=reference.dtype)
    return pruned_model

 # Train and test

In [None]:
def train(model, loader, optimizer, criterion, device, mask_infos=None, weight_mask_infos=None):
    """Train ``model`` for one epoch over ``loader``.

    Args:
        model: Network to train (switched to train mode).
        loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Loss function called as ``criterion(logits, targets)``.
        device: Device the batches are moved to.
        mask_infos: Optional filter-level mask records (the format consumed by
            ``apply_masks``). When given, the masks are re-applied after every
            optimizer step so pruned filters stay at zero.
        weight_mask_infos: Optional weight-level mask records (the format
            consumed by ``apply_weightwise_masks``), re-applied the same way.

    Returns:
        ``(train_loss, train_acc)``: the mean per-batch loss over the batches
        of ``loader`` and the accuracy in percent; ``(0.0, 0.0)`` when the
        loader yields no samples.
    """
    model.train()
    total_loss, total_correct, total = 0.0, 0, 0
    num_batches = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # The update (gradient, momentum, weight decay) can move pruned
        # weights away from zero; re-impose the masks to keep the sparsity.
        if mask_infos is not None:
            apply_masks(mask_infos)
        if weight_mask_infos is not None:
            apply_weightwise_masks(weight_mask_infos)

        total_loss += loss.item()
        num_batches += 1
        _, predicted = logits.max(1)
        total += y.size(0)
        total_correct += predicted.eq(y).sum().item()

    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    # Average over the batches actually processed, not over the global
    # `trainloader`, so any loader passed in gives the right mean.
    train_loss = total_loss / num_batches
    train_acc = 100. * total_correct / total

    return train_loss, train_acc

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate (switched to eval mode).
        loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(logits, targets)``.
        device: Device the batches are moved to.

    Returns:
        ``(loss, accuracy)``: the mean per-batch loss over the batches of
        ``loader`` and the accuracy in percent; ``(0.0, 0.0)`` when the loader
        yields no samples.
    """
    model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    num_batches = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)

        total_loss += loss.item()
        num_batches += 1
        _, predicted = logits.max(1)
        total += y.size(0)
        total_correct += predicted.eq(y).sum().item()

    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    # Average over the batches actually processed, not over the global
    # `testloader`, so any loader passed in gives the right mean.
    total_loss = total_loss / num_batches
    test_acc = 100. * total_correct / total

    return total_loss, test_acc

def count_parameters(model):
    """Return the number of elements across all parameters that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

def count_nonzero_parameters(model):
    """Return how many entries of the trainable parameters of ``model`` are non-zero."""
    nonzero = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        nonzero += param.ne(0).sum().item()
    return nonzero

# Hyperparameters

In [None]:
# --- Optimisation settings shared by every training cell (SGD) ---
LEARNING_RATE = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
# Number of epochs each method is trained for.
# NOTE(review): the training cells hard-code LR milestones [25, 40]; keep
# EPOCHS consistent with them when changing either.
EPOCHS = 100

# --- Pruning targets ---
# Passed as `sparsity` to snip_pruning and random_pruning.
SPARSITY = 0.9
# Presumably the target for the GSRP weight-wise method (Method 2) -- TODO confirm.
GSRP_WEIGHTWISE_SPARSITY = 0.9
# Filter-level target for the structured step of the hybrid method (Method 3).
GSRP_STRUCTURED_SPARSITY =  0.3
# Filter-level target for the structured-only method (Method 4).
GSRP_STRUCTURED_ONLY = 0.5
# Weight-level target applied on top of structured pruning in the hybrid method.
GSRP_HYBRID_WEIGHTWISE = 0.8

# Passed as `T` to the GSRP scoring functions.
T_SCORE = 3
# Directory for model checkpoints, pruning masks and the results JSON.
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Per-method training history: every method gets its own independent, initially
# empty list for each tracked metric.
results = {
    method: {
        metric: []
        for metric in ('train_acc', 'test_acc', 'train_loss', 'test_loss', 'epoch_time')
    }
    for method in ('original', 'GSRP_weightwise', 'GSRP_hybrid',
                   'GSRP_structure', 'snip', 'random')
}


# Method 1: Original (no pruning)

In [None]:
print("=" * 70)
print("Training original model (no pruning)")
print("=" * 70)

# Check if checkpoint exists
checkpoint_path = os.path.join(CHECKPOINT_DIR, 'model_original.pth')
start_epoch = 0

# LR schedule: multiply the learning rate by `lr_gamma` once each milestone
# (counted in completed epochs) has been reached.
lr_milestones = [25, 40]
lr_gamma = 0.1

model_original = VGG('VGG16').to(device)
criterion = nn.CrossEntropyLoss()
optimizer_original = optim.SGD(model_original.parameters(), lr=LEARNING_RATE,
                               momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

# Try to load checkpoint
if os.path.exists(checkpoint_path):
    print(f"Checkpoint found, load it? (y/n)")
    # Auto-load in Colab
    load_existing = True  # Set to False to retrain

    if load_existing:
        model_original, optimizer_original, start_epoch, loaded_results, _ = \
            load_checkpoint(model_original, optimizer_original, checkpoint_path)
        if loaded_results is not None:
            results['original'] = loaded_results
            print(f"Continue training from Epoch {start_epoch}")

# Build the scheduler only after a possible resume. A scheduler created before
# loading starts counting at epoch 0 again, which misaligns the milestones.
# Instead, set the LR that belongs to `start_epoch` explicitly and shift the
# remaining milestones. For a fresh run (start_epoch == 0) this is a no-op.
milestones_done = sum(m <= start_epoch for m in lr_milestones)
for group in optimizer_original.param_groups:
    group['lr'] = LEARNING_RATE * lr_gamma ** milestones_done
scheduler_original = optim.lr_scheduler.MultiStepLR(
    optimizer_original,
    milestones=[m - start_epoch for m in lr_milestones if m > start_epoch],
    gamma=lr_gamma)

total_params = count_parameters(model_original)
print(f"# of parameter: {total_params:,}")

if start_epoch < EPOCHS:
    start_time = time.time()

    for epoch in range(start_epoch, EPOCHS):
        e_start_time = time.time()
        train_loss, train_acc = train(model_original, trainloader,
                                      optimizer_original, criterion, device)
        test_loss, test_acc = evaluate(model_original, testloader, criterion, device)
        e_end_time = time.time()
        epoch_time = e_end_time - e_start_time

        results['original']['train_acc'].append(train_acc)
        results['original']['test_acc'].append(test_acc)
        results['original']['train_loss'].append(train_loss)
        results['original']['test_loss'].append(test_loss)
        results['original']['epoch_time'].append(epoch_time)

        scheduler_original.step()

        print(f"Epoch [{epoch+1}/{EPOCHS}] "
              f"Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.3f}, Test Acc: {test_acc:.2f}%")

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            save_checkpoint(model_original, results['original'], optimizer_original,
                            epoch + 1, checkpoint_path, 'original')

    original_time = time.time() - start_time

    # Training complete, save final checkpoint
    save_checkpoint(model_original, results['original'], optimizer_original,
                    EPOCHS, checkpoint_path, 'original')

    print(f"Training time: {original_time/60:.2f} minutes")
else:
    print("Model training already completed")

# The history can be empty (e.g. a finished checkpoint without stored results),
# so guard the lookup instead of raising IndexError.
if results['original']['test_acc']:
    print(f"Accuracy: {results['original']['test_acc'][-1]:.2f}%")
else:
    print("Accuracy: no test results recorded")

# Method 2: GSRP weight-wise (unstructured)

In [None]:
print("\n" + "=" * 80)
print("METHOD 2: GSRP WEIGHT-WISE (UNSTRUCTURED)")
print("=" * 80)

checkpoint_path_GSRP_weightwise = os.path.join(CHECKPOINT_DIR, 'model_GSRP_weightwise.pth')
masks_path_GSRP_weightwise = os.path.join(CHECKPOINT_DIR, 'masks_GSRP_weightwise.pth')
start_epoch = 0

# LR schedule: multiply the learning rate by `lr_gamma` once each milestone
# (counted in completed epochs) has been reached.
lr_milestones = [25, 40]
lr_gamma = 0.1

model_GSRP_weightwise = VGG('VGG16').to(device)
criterion = nn.CrossEntropyLoss()


# Try to load checkpoint. Resuming needs both the model checkpoint and the
# pruning masks, so only attempt it when both files are present.
if os.path.exists(checkpoint_path_GSRP_weightwise) and os.path.exists(masks_path_GSRP_weightwise):
    load_existing = True  # Set to False to retrain

    if load_existing:
        optimizer_GSRP_weightwise = optim.SGD(model_GSRP_weightwise.parameters(), lr=LEARNING_RATE,
                                              momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
        model_GSRP_weightwise, optimizer_GSRP_weightwise, start_epoch, loaded_results, _ = \
            load_checkpoint(model_GSRP_weightwise, optimizer_GSRP_weightwise, checkpoint_path_GSRP_weightwise)
        if loaded_results is not None:
            results['GSRP_weightwise'] = loaded_results
            # Load pruning masks
            keep_masks = load_pruning_masks(masks_path_GSRP_weightwise)
            if keep_masks is not None:
                apply_prune_mask(model_GSRP_weightwise, keep_masks)
            print(f"Continue training from Epoch {start_epoch}")

if start_epoch == 0:
    # Perform GSRP_weightwise pruning with the sparsity dedicated to this method
    # (GSRP_HYBRID_WEIGHTWISE belongs to the hybrid method).
    param_infos = compute_gsrp_scores_weightwise(model_GSRP_weightwise, score_loader, device, T=T_SCORE)
    param_infos = choose_weights_by_gsrp_weightwise(param_infos, target_sparsity=GSRP_WEIGHTWISE_SPARSITY)
    keep_masks = [info["mask"].to(device).float() for info in param_infos]

    apply_prune_mask(model_GSRP_weightwise, keep_masks)

    # Save pruning masks
    save_pruning_masks(keep_masks, masks_path_GSRP_weightwise)

if start_epoch < EPOCHS:
    if start_epoch == 0:
        # Fresh run: build the optimizer after pruning.
        optimizer_GSRP_weightwise = optim.SGD(model_GSRP_weightwise.parameters(), lr=LEARNING_RATE,
                                              momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    # When resuming, keep the optimizer returned by load_checkpoint instead of
    # replacing it with a brand-new one.

    # Put the LR schedule at the position that belongs to `start_epoch`:
    # set the already-decayed LR and shift the remaining milestones.
    # For a fresh run this is a no-op.
    milestones_done = sum(m <= start_epoch for m in lr_milestones)
    for group in optimizer_GSRP_weightwise.param_groups:
        group['lr'] = LEARNING_RATE * lr_gamma ** milestones_done
    scheduler_GSRP_weightwise = optim.lr_scheduler.MultiStepLR(
        optimizer_GSRP_weightwise,
        milestones=[m - start_epoch for m in lr_milestones if m > start_epoch],
        gamma=lr_gamma)

    start_time = time.time()

    for epoch in range(start_epoch, EPOCHS):
        e_start_time = time.time()
        train_loss, train_acc = train(model_GSRP_weightwise, trainloader,
                                      optimizer_GSRP_weightwise, criterion, device)
        test_loss, test_acc = evaluate(model_GSRP_weightwise, testloader, criterion, device)

        e_end_time = time.time()
        epoch_time = e_end_time - e_start_time

        results['GSRP_weightwise']['train_acc'].append(train_acc)
        results['GSRP_weightwise']['test_acc'].append(test_acc)
        results['GSRP_weightwise']['train_loss'].append(train_loss)
        results['GSRP_weightwise']['test_loss'].append(test_loss)
        results['GSRP_weightwise']['epoch_time'].append(epoch_time)

        scheduler_GSRP_weightwise.step()

        print(f"Epoch [{epoch+1}/{EPOCHS}] "
              f"Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.3f}, Test Acc: {test_acc:.2f}%")

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            save_checkpoint(model_GSRP_weightwise, results['GSRP_weightwise'], optimizer_GSRP_weightwise,
                            epoch + 1, checkpoint_path_GSRP_weightwise, 'GSRP_weightwise')

    GSRP_weightwise_time = time.time() - start_time

    # Training complete, save final checkpoint
    save_checkpoint(model_GSRP_weightwise, results['GSRP_weightwise'], optimizer_GSRP_weightwise,
                    EPOCHS, checkpoint_path_GSRP_weightwise, 'GSRP_weightwise')

    print(f"Training time: {GSRP_weightwise_time/60:.2f} minutes")
else:
    print("Model training already completed")

# Guard against an empty history instead of raising IndexError.
if results['GSRP_weightwise']['test_acc']:
    print(f"Accuracy: {results['GSRP_weightwise']['test_acc'][-1]:.2f}%")
else:
    print("Accuracy: no test results recorded")


METHOD 2: GSRP WEIGHT-WISE (UNSTRUCTURED)
找不到 checkpoint: ./checkpoints/model_GSRP_weightwise.pth
剪枝遮罩已儲存至 ./checkpoints/masks_GSRP_weightwise.pth
Epoch [1/100] Train Loss: 1.725, Train Acc: 32.59%, Test Loss: 1.441, Test Acc: 47.55%
Epoch [2/100] Train Loss: 1.181, Train Acc: 57.84%, Test Loss: 1.119, Test Acc: 62.46%
Epoch [3/100] Train Loss: 0.924, Train Acc: 68.50%, Test Loss: 0.930, Test Acc: 69.07%
Epoch [4/100] Train Loss: 0.799, Train Acc: 73.47%, Test Loss: 0.788, Test Acc: 73.51%
Epoch [5/100] Train Loss: 0.715, Train Acc: 76.38%, Test Loss: 0.725, Test Acc: 76.83%
Epoch [6/100] Train Loss: 0.645, Train Acc: 79.03%, Test Loss: 0.783, Test Acc: 74.21%
Epoch [7/100] Train Loss: 0.587, Train Acc: 81.04%, Test Loss: 0.785, Test Acc: 75.16%
Epoch [8/100] Train Loss: 0.547, Train Acc: 82.31%, Test Loss: 0.685, Test Acc: 77.84%
Epoch [9/100] Train Loss: 0.506, Train Acc: 83.57%, Test Loss: 0.547, Test Acc: 82.34%
Epoch [10/100] Train Loss: 0.481, Train Acc: 84.52%, Test Loss: 0.466

# Method 3: GSRP Hybrid (Structured + Weight-wise)

In [None]:
# Method 3: structured (filter-level) GSRP pruning with physical removal,
# followed by weight-wise GSRP masking of the smaller model, then training.
print("\n" + "=" * 80)
print("METHOD 3: GSRP Hybrid (Structured + Weight-wise) Pruning")
print("=" * 80)
# No resume logic in this cell: training always starts at epoch 0.
start_epoch = 0
checkpoint_path_GSRP_hybrid = os.path.join(CHECKPOINT_DIR, 'model_GSRP_hybrid.pth')


# Freshly initialised network; its parameter count is reported below as the
# "Original" size next to the pruned model.
base_model = VGG('VGG16').to(device)

# Step 1: Structured pruning with physical removal
gsrp_infos3 = compute_gsrp_scores(base_model, score_loader, device, T=T_SCORE)
gsrp_infos3 = collect_conv_activations(base_model, score_loader, device, gsrp_infos3, num_batches=5)
gsrp_infos3 = choose_filters_by_gsrp(
    gsrp_infos3,
    target_sparsity=GSRP_STRUCTURED_SPARSITY,
    use_clustering=True
)

model_GSRP_hybrid = prune_vgg_model_physically(base_model, gsrp_infos3).to(device)

print(f"After structured pruning:")
print(f"  Original: {sum(p.numel() for p in base_model.parameters()):,} params")
print(f"  Pruned:   {sum(p.numel() for p in model_GSRP_hybrid.parameters()):,} params")

# Step 2: Weight-wise pruning on the physically pruned model
param_infos3 = compute_gsrp_scores_weightwise(model_GSRP_hybrid, score_loader, device, T=T_SCORE)
param_infos3 = choose_weights_by_gsrp_weightwise(param_infos3, target_sparsity=GSRP_HYBRID_WEIGHTWISE)

# One float mask per entry of param_infos3, moved to the training device.
keep_masks3 = [info["mask"].to(device).float() for info in param_infos3]
apply_prune_mask(model_GSRP_hybrid, keep_masks3)


# Sparsity is measured over every parameter tensor of the pruned model
# (exact zeros in any parameter count, not only masked weights).
total_params = sum(p.numel() for p in model_GSRP_hybrid.parameters())
zero_params = sum((p == 0).sum().item() for p in model_GSRP_hybrid.parameters())
actual_sparsity = zero_params / total_params
print(f"\nAfter weight-wise pruning on top:")
print(f"  Total parameters: {total_params:,}")
print(f"  Zero parameters: {zero_params:,}")
print(f"  Weight sparsity: {actual_sparsity:.2%}")
print(f"  Combined reduction: Physical pruning + {actual_sparsity:.0%} zeros")

criterion = nn.CrossEntropyLoss()
optimizer_GSRP_hybrid = optim.SGD(
    model_GSRP_hybrid.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY
)

# LR is multiplied by 0.1 after 25 and after 40 completed epochs.
scheduler_GSRP_hybrid = optim.lr_scheduler.MultiStepLR(
    optimizer_GSRP_hybrid, milestones=[25, 40], gamma=0.1
)

start_time = time.time()

# NOTE(review): results['GSRP_hybrid'] is appended to, so re-running this cell
# without resetting `results` extends the existing history.
for epoch in range(start_epoch, EPOCHS):
    e_start_time = time.time()
    train_loss, train_acc = train(model_GSRP_hybrid, trainloader,
                                        optimizer_GSRP_hybrid, criterion, device)
    test_loss, test_acc = evaluate(model_GSRP_hybrid, testloader, criterion, device)
    e_end_time = time.time()
    # Epoch time covers both the training pass and the evaluation pass.
    epoch_time = e_end_time - e_start_time

    results['GSRP_hybrid']['train_acc'].append(train_acc)
    results['GSRP_hybrid']['test_acc'].append(test_acc)
    results['GSRP_hybrid']['train_loss'].append(train_loss)
    results['GSRP_hybrid']['test_loss'].append(test_loss)
    results['GSRP_hybrid']['epoch_time'].append(epoch_time)


    scheduler_GSRP_hybrid.step()

    print(f"Epoch [{epoch+1}/{EPOCHS}] "
          f"Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, "
          f"Test Loss: {test_loss:.3f}, Test Acc: {test_acc:.2f}%")

    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        save_checkpoint(model_GSRP_hybrid, results['GSRP_hybrid'], optimizer_GSRP_hybrid,
                      epoch + 1, checkpoint_path_GSRP_hybrid, 'GSRP_hybrid')

GSRP_hybrid_time = time.time() - start_time

# Training complete, save final checkpoint
save_checkpoint(model_GSRP_hybrid, results['GSRP_hybrid'], optimizer_GSRP_hybrid,
                EPOCHS, checkpoint_path_GSRP_hybrid, 'GSRP_hybrid')

print(f"Training time: {GSRP_hybrid_time/60:.2f} minutes")

print(f"Accuracy: {results['GSRP_hybrid']['test_acc'][-1]:.2f}%")


METHOD 3: GSRP Hybrid (Structured + Weight-wise) Pruning
After structured pruning:
  Original: 15,253,578 params
  Pruned:   7,649,827 params

After weight-wise pruning on top:
  Total parameters: 7,649,827
  Zero parameters: 6,114,904
  Weight sparsity: 79.94%
  Combined reduction: Physical pruning + 80% zeros
Epoch [1/50] Train Loss: 1.783, Train Acc: 30.08%, Test Loss: 1.499, Test Acc: 42.15%
Epoch [2/50] Train Loss: 1.321, Train Acc: 52.27%, Test Loss: 1.237, Test Acc: 58.25%
Epoch [3/50] Train Loss: 1.037, Train Acc: 64.26%, Test Loss: 0.901, Test Acc: 68.92%
Epoch [4/50] Train Loss: 0.891, Train Acc: 69.82%, Test Loss: 0.839, Test Acc: 71.15%
Epoch [5/50] Train Loss: 0.784, Train Acc: 73.79%, Test Loss: 0.721, Test Acc: 75.35%
Epoch [6/50] Train Loss: 0.714, Train Acc: 76.64%, Test Loss: 0.759, Test Acc: 74.75%
Epoch [7/50] Train Loss: 0.659, Train Acc: 78.50%, Test Loss: 0.663, Test Acc: 78.15%
Epoch [8/50] Train Loss: 0.610, Train Acc: 80.19%, Test Loss: 0.664, Test Acc: 78.12

# Method 4: GSRP Structured (gradient + clustering)

In [None]:
print("\n" + "=" * 80)
print("METHOD 4: GSRP Structured (Physical Only) Pruning")
print("=" * 80)


checkpoint_path_GSRP_structure = os.path.join(CHECKPOINT_DIR, 'model_GSRP_structure.pth')
# This cell has no resume logic, so training always starts at epoch 0. Set it
# here rather than relying on whatever value an earlier cell left behind.
start_epoch = 0

model_GSRP_structure = VGG('VGG16').to(device)
# Unpruned reference network, used only to report the original parameter count.
GSRP_structure_temp = VGG('VGG16').to(device)

gsrp_infos = compute_gsrp_scores(model_GSRP_structure, score_loader, device, T=T_SCORE)
gsrp_infos = collect_conv_activations(model_GSRP_structure, score_loader, device, gsrp_infos, num_batches=5)
gsrp_infos = choose_filters_by_gsrp(gsrp_infos, target_sparsity=GSRP_STRUCTURED_ONLY, use_clustering=True)

model_GSRP_structure = prune_vgg_model_physically(model_GSRP_structure, gsrp_infos).to(device)

# Count actual parameters
original_params = sum(p.numel() for p in GSRP_structure_temp.parameters())
pruned_params = sum(p.numel() for p in model_GSRP_structure.parameters())
reduction = 100 * (1 - pruned_params / original_params)

print(f"Original parameters: {original_params:,}")
print(f"Pruned parameters: {pruned_params:,}")
print(f"Parameter reduction: {reduction:.2f}%")
print(f"Model is now physically smaller and faster!")

# Define the loss locally so the cell does not depend on an earlier cell.
criterion = nn.CrossEntropyLoss()
optimizer_GSRP_structure = optim.SGD(model_GSRP_structure.parameters(),
                                     lr=LEARNING_RATE,
                                     momentum=MOMENTUM,
                                     weight_decay=WEIGHT_DECAY)
# LR is multiplied by 0.1 after 25 and after 40 completed epochs.
scheduler_GSRP_structure = optim.lr_scheduler.MultiStepLR(
    optimizer_GSRP_structure, milestones=[25, 40], gamma=0.1
)

# Start the clock for THIS method; without this the reported training time
# would be measured from the previous cell's `start_time`.
start_time = time.time()

for epoch in range(start_epoch, EPOCHS):
    e_start_time = time.time()
    train_loss, train_acc = train(model_GSRP_structure, trainloader,
                                  optimizer_GSRP_structure, criterion, device)
    test_loss, test_acc = evaluate(model_GSRP_structure, testloader, criterion, device)
    e_end_time = time.time()
    epoch_time = e_end_time - e_start_time

    results['GSRP_structure']['train_acc'].append(train_acc)
    results['GSRP_structure']['test_acc'].append(test_acc)
    results['GSRP_structure']['train_loss'].append(train_loss)
    results['GSRP_structure']['test_loss'].append(test_loss)
    results['GSRP_structure']['epoch_time'].append(epoch_time)

    scheduler_GSRP_structure.step()

    print(f"Epoch [{epoch+1}/{EPOCHS}] "
          f"Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, "
          f"Test Loss: {test_loss:.3f}, Test Acc: {test_acc:.2f}%")

    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        save_checkpoint(model_GSRP_structure, results['GSRP_structure'], optimizer_GSRP_structure,
                        epoch + 1, checkpoint_path_GSRP_structure, 'GSRP_structure')

GSRP_structure_time = time.time() - start_time

# Training complete, save final checkpoint
save_checkpoint(model_GSRP_structure, results['GSRP_structure'], optimizer_GSRP_structure,
                EPOCHS, checkpoint_path_GSRP_structure, 'GSRP_structure')

print(f"Training time: {GSRP_structure_time/60:.2f} minutes")

print(f"Accuracy: {results['GSRP_structure']['test_acc'][-1]:.2f}%")


METHOD 4: GSRP Structured (Physical Only) Pruning
Original parameters: 15,253,578
Pruned parameters: 4,083,754
Parameter reduction: 73.23%
Model is now physically smaller and faster!
Epoch [1/50] Train Loss: 1.678, Train Acc: 35.22%, Test Loss: 1.289, Test Acc: 53.66%
Epoch [2/50] Train Loss: 1.171, Train Acc: 58.69%, Test Loss: 1.049, Test Acc: 63.78%
Epoch [3/50] Train Loss: 0.931, Train Acc: 68.10%, Test Loss: 0.860, Test Acc: 70.43%
Epoch [4/50] Train Loss: 0.799, Train Acc: 73.00%, Test Loss: 0.760, Test Acc: 75.06%
Epoch [5/50] Train Loss: 0.711, Train Acc: 76.36%, Test Loss: 0.699, Test Acc: 76.72%
Epoch [6/50] Train Loss: 0.649, Train Acc: 78.42%, Test Loss: 0.666, Test Acc: 78.22%
Epoch [7/50] Train Loss: 0.600, Train Acc: 80.32%, Test Loss: 0.586, Test Acc: 80.41%
Epoch [8/50] Train Loss: 0.558, Train Acc: 81.72%, Test Loss: 0.558, Test Acc: 81.39%
Epoch [9/50] Train Loss: 0.526, Train Acc: 82.63%, Test Loss: 0.509, Test Acc: 83.18%
Epoch [10/50] Train Loss: 0.495, Train Acc

# Method 5: SNIP


In [None]:
print("\n" + "=" * 80)
print("METHOD 5: SNIP Pruning")
print("=" * 80)

checkpoint_path_snip = os.path.join(CHECKPOINT_DIR, 'model_snip.pth')
masks_path_snip = os.path.join(CHECKPOINT_DIR, 'masks_snip.pth')
start_epoch = 0

# LR schedule: multiply the learning rate by `lr_gamma` once each milestone
# (counted in completed epochs) has been reached.
lr_milestones = [25, 40]
lr_gamma = 0.1

model_snip = VGG('VGG16').to(device)
# Define the loss locally so the cell does not depend on an earlier cell.
criterion = nn.CrossEntropyLoss()
total_params = sum(p.numel() for p in model_snip.parameters())
# Try to load checkpoint
if os.path.exists(checkpoint_path_snip):
    load_existing = True  # Set to False to retrain

    if load_existing:
        optimizer_snip = optim.SGD(model_snip.parameters(), lr=LEARNING_RATE,
                                   momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
        model_snip, optimizer_snip, start_epoch, loaded_results, _ = \
            load_checkpoint(model_snip, optimizer_snip, checkpoint_path_snip)
        if loaded_results is not None:
            results['snip'] = loaded_results
            # Load pruning masks
            keep_masks = load_pruning_masks(masks_path_snip)
            if keep_masks is not None:
                apply_prune_mask(model_snip, keep_masks)
            print(f"Continue training from Epoch {start_epoch}")

if start_epoch == 0:
    # Perform SNIP pruning
    keep_masks = snip_pruning(model_snip, score_loader, sparsity=SPARSITY)
    apply_prune_mask(model_snip, keep_masks)

    # Save pruning masks
    save_pruning_masks(keep_masks, masks_path_snip)

nonzero_params = count_nonzero_parameters(model_snip)
print(f"# of parameter: {nonzero_params:,} ({nonzero_params/total_params*100:.2f}%)")

if start_epoch < EPOCHS:
    if start_epoch == 0:
        # Fresh run: build the optimizer after pruning.
        optimizer_snip = optim.SGD(model_snip.parameters(), lr=LEARNING_RATE,
                                   momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    # When resuming, keep the optimizer returned by load_checkpoint instead of
    # replacing it with a brand-new one.

    # Put the LR schedule at the position that belongs to `start_epoch`:
    # set the already-decayed LR and shift the remaining milestones.
    # For a fresh run this is a no-op.
    milestones_done = sum(m <= start_epoch for m in lr_milestones)
    for group in optimizer_snip.param_groups:
        group['lr'] = LEARNING_RATE * lr_gamma ** milestones_done
    scheduler_snip = optim.lr_scheduler.MultiStepLR(
        optimizer_snip,
        milestones=[m - start_epoch for m in lr_milestones if m > start_epoch],
        gamma=lr_gamma)

    start_time = time.time()

    for epoch in range(start_epoch, EPOCHS):
        e_start_time = time.time()
        train_loss, train_acc = train(model_snip, trainloader,
                                      optimizer_snip, criterion, device)
        test_loss, test_acc = evaluate(model_snip, testloader, criterion, device)
        e_end_time = time.time()
        epoch_time = e_end_time - e_start_time

        results['snip']['train_acc'].append(train_acc)
        results['snip']['test_acc'].append(test_acc)
        results['snip']['train_loss'].append(train_loss)
        results['snip']['test_loss'].append(test_loss)
        results['snip']['epoch_time'].append(epoch_time)

        scheduler_snip.step()

        print(f"Epoch [{epoch+1}/{EPOCHS}] "
              f"Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.3f}, Test Acc: {test_acc:.2f}%")

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            save_checkpoint(model_snip, results['snip'], optimizer_snip,
                            epoch + 1, checkpoint_path_snip, 'snip')

    snip_time = time.time() - start_time

    # Training complete, save final checkpoint
    save_checkpoint(model_snip, results['snip'], optimizer_snip,
                    EPOCHS, checkpoint_path_snip, 'snip')

    print(f"Training time: {snip_time/60:.2f} minutes")
else:
    print("Model training already completed")

# Guard against an empty history instead of raising IndexError.
if results['snip']['test_acc']:
    print(f"Accuracy: {results['snip']['test_acc'][-1]:.2f}%")
else:
    print("Accuracy: no test results recorded")


METHOD 5: SNIP Pruning
SNIP 剪枝完成！剪枝率: 90.0%
閾值: 0.000036
剪枝遮罩已儲存至 ./checkpoints/masks_snip.pth
# of parameter: 1,533,470 (10.05%)


# Method 6: Random


In [None]:
print("\n" + "=" * 80)
print("METHOD 6: Random Pruning")
print("=" * 80)

checkpoint_path_random = os.path.join(CHECKPOINT_DIR, 'model_random.pth')
masks_path_random = os.path.join(CHECKPOINT_DIR, 'masks_random.pth')
start_epoch = 0

# LR schedule: multiply the learning rate by `lr_gamma` once each milestone
# (counted in completed epochs) has been reached.
lr_milestones = [25, 40]
lr_gamma = 0.1

model_random = VGG('VGG16').to(device)
# Define the loss and the dense parameter count locally so the cell does not
# depend on values left behind by other cells.
criterion = nn.CrossEntropyLoss()
total_params = sum(p.numel() for p in model_random.parameters())

# Try to load checkpoint
if os.path.exists(checkpoint_path_random):
    load_existing = True  # Set to False to retrain

    if load_existing:
        optimizer_random = optim.SGD(model_random.parameters(), lr=LEARNING_RATE,
                                     momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
        model_random, optimizer_random, start_epoch, loaded_results, _ = \
            load_checkpoint(model_random, optimizer_random, checkpoint_path_random)
        if loaded_results is not None:
            results['random'] = loaded_results
            # Load pruning masks
            keep_masks_random = load_pruning_masks(masks_path_random)
            if keep_masks_random is not None:
                apply_prune_mask(model_random, keep_masks_random)
            print(f"Continue training from Epoch {start_epoch}")

if start_epoch == 0:
    # Perform random pruning
    keep_masks_random = random_pruning(model_random, sparsity=SPARSITY)
    apply_prune_mask(model_random, keep_masks_random)

    # Save pruning masks
    save_pruning_masks(keep_masks_random, masks_path_random)

nonzero_params_random = count_nonzero_parameters(model_random)
print(f"# of parameter: {nonzero_params_random:,} ({nonzero_params_random/total_params*100:.2f}%)")

if start_epoch < EPOCHS:
    if start_epoch == 0:
        # Fresh run: build the optimizer after pruning.
        optimizer_random = optim.SGD(model_random.parameters(), lr=LEARNING_RATE,
                                     momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    # When resuming, keep the optimizer returned by load_checkpoint instead of
    # replacing it with a brand-new one.

    # Put the LR schedule at the position that belongs to `start_epoch`:
    # set the already-decayed LR and shift the remaining milestones.
    # For a fresh run this is a no-op.
    milestones_done = sum(m <= start_epoch for m in lr_milestones)
    for group in optimizer_random.param_groups:
        group['lr'] = LEARNING_RATE * lr_gamma ** milestones_done
    scheduler_random = optim.lr_scheduler.MultiStepLR(
        optimizer_random,
        milestones=[m - start_epoch for m in lr_milestones if m > start_epoch],
        gamma=lr_gamma)

    start_time = time.time()

    for epoch in range(start_epoch, EPOCHS):
        e_start_time = time.time()
        train_loss, train_acc = train(model_random, trainloader,
                                      optimizer_random, criterion, device)
        test_loss, test_acc = evaluate(model_random, testloader, criterion, device)
        e_end_time = time.time()
        epoch_time = e_end_time - e_start_time

        results['random']['train_acc'].append(train_acc)
        results['random']['test_acc'].append(test_acc)
        results['random']['train_loss'].append(train_loss)
        results['random']['test_loss'].append(test_loss)
        results['random']['epoch_time'].append(epoch_time)

        scheduler_random.step()

        print(f"Epoch [{epoch+1}/{EPOCHS}] "
              f"Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.3f}, Test Acc: {test_acc:.2f}%")

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            save_checkpoint(model_random, results['random'], optimizer_random,
                            epoch + 1, checkpoint_path_random, 'random')

    random_time = time.time() - start_time

    # Training complete, save final checkpoint
    save_checkpoint(model_random, results['random'], optimizer_random,
                    EPOCHS, checkpoint_path_random, 'random')

    print(f"Training time: {random_time/60:.2f} minutes")
else:
    print("Model training already completed")

# Guard against an empty history instead of raising IndexError.
if results['random']['test_acc']:
    print(f"Accuracy: {results['random']['test_acc'][-1]:.2f}%")
else:
    print("Accuracy: no test results recorded")

# Save all results
save_results(results, os.path.join(CHECKPOINT_DIR, 'all_results.json'))

# Plot All Results


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Methods in plotting order, with their legend labels and colours.
method_keys   = ['original', 'GSRP_weightwise', 'GSRP_hybrid', 'GSRP_structure', 'snip', 'random']
method_labels = ['Original', 'GSRP Weight-wise', 'GSRP Hybrid', 'GSRP Structured', 'SNIP', 'Random']
colors        = ['#2E86AB', '#A23B72', '#F18F01', '#06A77D', '#D62839', '#8B4789']

fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# Panels 1-3 are per-epoch curves that differ only in metric, y-label and title.
curve_panels = [
    (axes[0, 0], 'train_acc', 'Training Accuracy (%)', 'Training Accuracy over Epochs'),
    (axes[0, 1], 'test_acc', 'Test Accuracy (%)', 'Test Accuracy over Epochs'),
    (axes[1, 0], 'train_loss', 'Training Loss', 'Training Loss over Epochs'),
]
for panel, metric, y_label, title in curve_panels:
    for key, label, c in zip(method_keys, method_labels, colors):
        panel.plot(results[key][metric], label=label, linewidth=2.5, alpha=0.8, color=c)
    panel.set_xlabel('Epoch', fontsize=12)
    panel.set_ylabel(y_label, fontsize=12)
    panel.set_title(title, fontsize=14, fontweight='bold')
    panel.legend(fontsize=9)
    panel.grid(True, alpha=0.3)

# Panel 4: bar chart of the last recorded test accuracy of every method.
bar_panel = axes[1, 1]
final_accs = [results[key]['test_acc'][-1] for key in method_keys]

bars = bar_panel.bar(method_labels, final_accs, color=colors,
                     alpha=0.8, edgecolor='black', linewidth=1.5)
bar_panel.set_ylabel('Test Accuracy (%)', fontsize=12)
bar_panel.set_title('Final Test Accuracy Comparison', fontsize=14, fontweight='bold')
bar_panel.set_ylim([0, 100])
bar_panel.grid(True, alpha=0.3, axis='y')

# Write the numeric value just above each bar.
for bar, acc in zip(bars, final_accs):
    bar_panel.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 1,
                   f'{acc:.2f}%', ha='center', va='bottom',
                   fontsize=9, fontweight='bold')

plt.tight_layout()
plt.savefig('all_methods_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Mean seconds per epoch for every method that has at least one recorded epoch.
timed_runs = {
    model_name: stats['epoch_time']
    for model_name, stats in results.items()
    if stats['epoch_time']
}
model_names = list(timed_runs)
avg_times = [sum(times) / len(times) for times in timed_runs.values()]

plt.figure()
plt.bar(model_names, avg_times)
plt.xlabel("Model")
plt.ylabel("Average time per epoch (seconds)")
plt.title("Average Epoch Time per Model")
plt.xticks(rotation=30)
plt.tight_layout()
plt.show()

In [None]:
# ============================================================================
# Load trained models and perform Inference and comparison
# No retraining needed, directly use saved checkpoints
# ============================================================================

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import json
import os

# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============================================================================
# Load model architecture (same as training)
# ============================================================================

class VGG(nn.Module):
    """VGG-style convolutional network with batch normalisation.

    ``features`` is a stack of 3x3 conv + batch-norm + ReLU groups with 2x2
    max-pooling between stages; ``classifier`` is a three-layer MLP that expects
    512 flattened features (e.g. a 32x32 input after the five poolings).

    Args:
        vgg_name: One of ``'VGG11'``, ``'VGG13'``, ``'VGG16'``, ``'VGG19'``.
        num_classes: Size of the final output layer.
    """

    def __init__(self, vgg_name='VGG16', num_classes=10):
        super().__init__()
        # Integers are conv output channels; 'M' marks a 2x2 max-pooling layer.
        self.cfg = {
            'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
            'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
        }
        self.features = self._make_layers(self.cfg[vgg_name])
        head = [
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def _make_layers(self, cfg):
        """Build the convolutional feature extractor described by ``cfg``."""
        modules = []
        channels_in = 3
        for spec in cfg:
            if spec == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.extend([
                nn.Conv2d(channels_in, spec, kernel_size=3, padding=1),
                nn.BatchNorm2d(spec),
                nn.ReLU(inplace=True),
            ])
            channels_in = spec
        # Trailing 1x1 average pool keeps the layer indices of the original layout.
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)


# ============================================================================
# Load test data
# ============================================================================

# Test-time preprocessing: tensor conversion plus per-channel normalisation.
# NOTE(review): the mean/std values are presumably the CIFAR-10 channel
# statistics and should match the transform used during training -- confirm.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# CIFAR-10 test split; it is downloaded into ./data if not already present.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
# Fixed order (shuffle=False), batches of 100.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

print(f"Test set size: {len(testset)}")


# ============================================================================
# Utility functions
# ============================================================================

def apply_prune_mask(model, keep_masks):
    """Zero the pruned weights of ``model`` and mask their future gradients.

    The Conv2d and Linear layers of ``model`` (in ``model.modules()`` order) are
    paired with ``keep_masks``. Weights where the mask equals 0 are set to 0,
    and a gradient hook multiplies later gradients by the same mask.

    Args:
        model: Module whose Conv2d/Linear weights are masked in place.
        keep_masks: One mask per prunable layer, same shape as its weight.
    """
    prunable_layers = (
        module for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    )

    for layer, keep_mask in zip(prunable_layers, keep_masks):
        assert (layer.weight.shape == keep_mask.shape)

        pruned = keep_mask == 0.
        layer.weight.data[pruned] = 0.
        # Bind the mask as a default argument so each hook keeps its own mask.
        layer.weight.register_hook(lambda grads, mask=keep_mask: grads * mask)


def test_model(model, testloader, device):
    """Test model"""
    model.eval()
    criterion = nn.CrossEntropyLoss()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_loss = test_loss / len(testloader)
    test_acc = 100. * correct / total

    return test_loss, test_acc


def count_nonzero_parameters(model):
    """Return how many entries of the model's trainable parameters are non-zero.

    Only parameters with ``requires_grad=True`` are inspected; every such
    parameter returned by ``model.parameters()`` contributes, whatever layer
    it belongs to.
    """
    nonzero_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        nonzero_total += param.ne(0).sum().item()
    return nonzero_total


# ============================================================================
# Load all models and perform Inference
# ============================================================================

CHECKPOINT_DIR = './checkpoints'

print("\n" + "=" * 70)
print("Loading models and performing Inference")
print("=" * 70)


def _load_checkpoint_file(filename):
    """Deserialize ``CHECKPOINT_DIR/filename`` with tensors mapped onto ``device``."""
    # NOTE(security): torch.load is built on pickle; depending on the installed
    # PyTorch version's ``weights_only`` default, loading a file can execute
    # arbitrary code. Only load checkpoints you produced yourself or trust.
    return torch.load(os.path.join(CHECKPOINT_DIR, filename), map_location=device)


def _load_vgg16(checkpoint_file):
    """Build a VGG16 on ``device`` and restore its weights.

    The weights are read from the ``'model_state_dict'`` entry of the
    checkpoint. Returns ``(model, checkpoint)``.
    """
    model = VGG('VGG16').to(device)
    checkpoint = _load_checkpoint_file(checkpoint_file)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, checkpoint


def _load_and_test_pruned(loading_msg, done_msg, checkpoint_file, masks_file,
                          dense_param_count):
    """Load one pruned VGG16, re-apply its masks, evaluate it and print a summary.

    Args:
        loading_msg: line printed before loading starts.
        done_msg: headline printed above the metrics.
        checkpoint_file: checkpoint file name inside ``CHECKPOINT_DIR``.
        masks_file: pruning-mask file name inside ``CHECKPOINT_DIR``.
        dense_param_count: parameter count of the unpruned model, used as the
            denominator of the printed keep percentage.

    Returns:
        ``(model, checkpoint, masks, nonzero_params, test_loss, test_acc)``.
    """
    print(loading_msg)
    model, checkpoint = _load_vgg16(checkpoint_file)

    # Load and apply pruning masks
    masks = _load_checkpoint_file(masks_file)
    apply_prune_mask(model, masks)

    nonzero = count_nonzero_parameters(model)
    loss, acc = test_model(model, testloader, device)

    print(done_msg)
    print(f"  Parameters: {nonzero:,} ({nonzero/dense_param_count*100:.2f}%)")
    print(f"  Test accuracy: {acc:.2f}%")
    print(f"  Test loss: {loss:.4f}")
    return model, checkpoint, masks, nonzero, loss, acc


# Load training results (optional: the plots are skipped later when missing)
results_path = os.path.join(CHECKPOINT_DIR, 'all_results.json')
if os.path.exists(results_path):
    with open(results_path, 'r') as f:
        results = json.load(f)
    print("✓ Training results loaded")
else:
    print("✗ Training results file not found")
    results = None

# 1. Load original model
print("\nLoading original model...")
model_original, checkpoint_original = _load_vgg16('model_original.pth')
# Trainable parameter count of the unpruned model; denominator for keep rates.
total_params = sum(p.numel() for p in model_original.parameters() if p.requires_grad)

loss_original, acc_original = test_model(model_original, testloader, device)
print("✓ Original model")
print(f"  Parameters: {total_params:,}")
print(f"  Test accuracy: {acc_original:.2f}%")
print(f"  Test loss: {loss_original:.4f}")


# 2. Load GSRP weight-wise pruned model
(model_weightwise, checkpoint_weightwise, masks_weightwise,
 nonzero_params_weightwise, loss_weightwise, acc_weightwise) = _load_and_test_pruned(
    "\nLoading GSRP weight-wise pruned model...",
    "✓ GSRP weight-wise model",
    'model_GSRP_weightwise.pth',
    'masks_GSRP_weightwise.pth',
    total_params,
)

# 3. Load SNIP pruned model
(model_snip, checkpoint_snip, masks_snip,
 nonzero_params_snip, loss_snip, acc_snip) = _load_and_test_pruned(
    "\nLoading SNIP pruned model...",
    "✓ SNIP pruned model",
    'model_snip.pth',
    'masks_snip.pth',
    total_params,
)

# 4. Load random pruned model
(model_random, checkpoint_random, masks_random,
 nonzero_params_random, loss_random, acc_random) = _load_and_test_pruned(
    "\nLoading random pruned model...",
    "✓ Random pruned model",
    'model_random.pth',
    'masks_random.pth',
    total_params,
)


# ============================================================================
# Visualization comparison
# ============================================================================
def _plot_metric_curves(ax, history, metric, ylabel, title, keys, labels, palette):
    """Draw the per-epoch series ``history[key][metric]`` for every method on ``ax``."""
    for method_key, legend_name, shade in zip(keys, labels, palette):
        ax.plot(history[method_key][metric], label=legend_name, linewidth=2.5, color=shade)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def _plot_labelled_bars(ax, names, values, palette, ylabel, title,
                        configure_axis, text_for, lift=0):
    """Draw one bar per value on ``ax`` and write ``text_for(value)`` above each bar.

    ``configure_axis(ax)`` runs after the title is set and before the grid is
    drawn; ``lift`` is added to the bar height to position the text.
    Returns the bar container.
    """
    rects = ax.bar(names, values, color=palette,
                   alpha=0.8, edgecolor='black', linewidth=1.5)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    configure_axis(ax)
    ax.grid(True, alpha=0.3, axis='y')
    for rect, value in zip(rects, values):
        ax.text(rect.get_x() + rect.get_width()/2., rect.get_height() + lift,
                text_for(value), ha='center', va='bottom',
                fontsize=10, fontweight='bold')
    return rects


# The curve panels need the saved training history; skip all plots without it.
if results is not None:
    print("\n" + "=" * 70)
    print("Plotting training curves and comparison charts")
    print("=" * 70)

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    # Only plot these four: original / GSRP_weightwise / snip / random
    method_keys   = ['original', 'GSRP_weightwise', 'snip', 'random']
    method_labels = ['Original', 'GSRP Weight-wise', 'SNIP', 'Random']
    colors        = ['#2E86AB', '#06A77D', '#A23B72', '#F18F01']

    # Panels 1, 2, 4, 5: per-epoch curves read from the training history.
    curve_panels = [
        (axes[0, 0], 'train_acc',  'Training Accuracy (%)', 'Training Accuracy'),
        (axes[0, 1], 'test_acc',   'Test Accuracy (%)',     'Test Accuracy'),
        (axes[1, 0], 'train_loss', 'Training Loss',         'Training Loss'),
        (axes[1, 1], 'test_loss',  'Test Loss',             'Test Loss'),
    ]
    for curve_ax, metric_key, y_label, panel_title in curve_panels:
        _plot_metric_curves(curve_ax, results, metric_key, y_label, panel_title,
                            method_keys, method_labels, colors)

    # Panel 3: final accuracy (use real inference results, not just last epoch)
    models_bar = method_labels
    accs_bar   = [acc_original, acc_weightwise, acc_snip, acc_random]
    bars = _plot_labelled_bars(
        axes[0, 2], models_bar, accs_bar, colors,
        'Test Accuracy (%)', 'Final Test Accuracy',
        configure_axis=lambda ax: ax.set_ylim([0, 100]),
        text_for=lambda value: f'{value:.2f}%',
        lift=1,
    )

    # Panel 6: parameter count (non-zero entries for the pruned models)
    params_bar = [total_params,
                  nonzero_params_weightwise,
                  nonzero_params_snip,
                  nonzero_params_random]
    bars = _plot_labelled_bars(
        axes[1, 2], models_bar, params_bar, colors,
        'Number of Parameters', 'Parameter Count',
        configure_axis=lambda ax: ax.ticklabel_format(style='plain', axis='y'),
        text_for=lambda value: f'{value/1e6:.2f}M',
    )

    plt.tight_layout()
    plt.savefig('inference_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ Chart saved as 'inference_comparison.png'")

# ============================================================================
# Detailed comparison report
# ============================================================================
def _print_comparison_report():
    """Print the summary table, pairwise accuracy gaps and parameter reductions.

    Reads the module-level metrics produced by the inference section
    (``total_params``, ``nonzero_params_*``, ``acc_*``, ``loss_*``). Wrapped in
    a function so its loop variables stay out of the notebook namespace.
    """
    rule = "=" * 80

    def banner(title):
        print("\n" + rule)
        print(title)
        print(rule)

    def keep_rate(kept):
        # Share of the dense parameter count that is still non-zero.
        return f'{kept/total_params*100:.1f}%'

    # ---- summary table -------------------------------------------------
    banner("Detailed Comparison Report")
    print(f"\n{'Model':<20} {'Parameters':<20} {'Keep Rate':<15} {'Test Accuracy':<15} {'Test Loss':<15}")
    print("-" * 80)

    table_rows = [
        ('Original', total_params, '100.0%', acc_original, loss_original),
        ('GSRP weight-wise', nonzero_params_weightwise,
         keep_rate(nonzero_params_weightwise), acc_weightwise, loss_weightwise),
        ('SNIP pruning', nonzero_params_snip,
         keep_rate(nonzero_params_snip), acc_snip, loss_snip),
        ('Random pruning', nonzero_params_random,
         keep_rate(nonzero_params_random), acc_random, loss_random),
    ]
    for row_name, n_params, rate, accuracy, loss in table_rows:
        print(f"{row_name:<20} {n_params:>15,}   {rate:<15} {accuracy:>10.2f}%   {loss:>10.4f}")

    # ---- pairwise accuracy gaps ------------------------------------------
    banner("Accuracy Differences")
    # (heading, left accuracy, right accuracy, text if left is strictly higher, text otherwise)
    comparisons = [
        ("GSRP weight-wise vs Original", acc_weightwise, acc_original, '✓ Better', '✗ Worse'),
        ("SNIP vs Original", acc_snip, acc_original, '✓ Better', '✗ Worse'),
        ("Random vs Original", acc_random, acc_original, '✓ Better', '✗ Worse'),
        ("GSRP weight-wise vs SNIP", acc_weightwise, acc_snip, '✓ GSRP better', '✗ SNIP better'),
        ("GSRP weight-wise vs Random", acc_weightwise, acc_random, '✓ GSRP better', '✗ Random better'),
        ("SNIP vs Random", acc_snip, acc_random, '✓ SNIP better', '✗ Random better'),
    ]
    for heading, left, right, if_ahead, otherwise in comparisons:
        verdict = if_ahead if left > right else otherwise
        print(f"{heading}: {left - right:+.2f}% ({verdict})")

    # ---- parameter reduction ---------------------------------------------
    banner("Conclusion")
    # Prefixes keep the original spacing after the colon.
    reductions = [
        ("Parameter reduction (GSRP): ", nonzero_params_weightwise),
        ("Parameter reduction (SNIP):  ", nonzero_params_snip),
        ("Parameter reduction (Random):  ", nonzero_params_random),
    ]
    for prefix, kept in reductions:
        print(f"{prefix}{(1 - kept/total_params)*100:.1f}%")


_print_comparison_report()
