In [150]:
# from google.colab import drive
# drive.mount('/content/drive')

In [151]:
from __future__ import print_function
import os
import sys
import logging
import argparse
import time
from time import strftime
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
import yaml
import copy


# Change to a directory in your Google Drive
# os.chdir('/content/drive/MyDrive/Colab Notebooks')

from vgg_cifar import vgg13

# Modify sys.argv to remove unwanted arguments
sys.argv = sys.argv[:1]

# settings
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 admm training')
parser.add_argument('--epochs', type=int, default=160, metavar='N',
                    help='number of epochs to train (default: 160)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='training batch size (default: 64)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--load-model-path', type=str, default="./model/cifar10_vgg13_acc_94.730.pt",
                    help='Path to pretrained model')
parser.add_argument('--sparsity-type', type=str, default='unstructured',
                    help="define sparsity_type: [unstructured, filter, etc.]")
parser.add_argument('--sparsity-method', type=str, default='omp',
                    help="define sparsity_method: [omp, imp, etc.]")
parser.add_argument('--yaml-path', type=str, default="./pruning_ratio_unstructured.yaml",
                    help='Path to yaml file')

args = parser.parse_args()

# --- for dubeg use ---------
# args_list = [
#     "--epochs", "160",
#     "--seed", "123",
#     # ... add other arguments and their values ...
# ]
# args = parser.parse_args(args_list)

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * float(correct) / float(len(test_loader.dataset))
    print("===========================PRUNED MODEL==============================================")

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), accuracy))

    return accuracy

def get_dataloaders(args):
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.Pad(4),
                             transforms.RandomCrop(32),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=args.batch_size, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, download=True,
                         transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                        ])),
        batch_size=256, shuffle=False)

    return train_loader, test_loader


In [152]:
def read_prune_ratios_from_yaml(file_name, model):
    """
    Reads user-defined layer-wise target pruning ratios from a yaml file.
    Ensures that layer names in the yaml file match the model's layers.
    """
    if not isinstance(file_name, str):
        raise Exception("filename must be a string")

    with open(file_name, "r") as stream:
        try:
            raw_dict = yaml.safe_load(stream)
            prune_ratio_dict = raw_dict['prune_ratios']

            # Check if layer names match model layers
            for layer_name in prune_ratio_dict:
                if layer_name not in dict(model.named_modules()):
                    print(f"Warning: {layer_name} not found in the model!")

            return prune_ratio_dict

        except yaml.YAMLError as exc:
            print(exc)


In [153]:
def unstructured_prune(tensor: torch.Tensor, sparsity: float) -> torch.Tensor:
    """
    Implement magnitude-based unstructured pruning for weight tensor (of a layer)
    :param tensor: torch.(cuda.)Tensor, weight of conv/fc layer
    :param sparsity: float, pruning sparsity
    
    :return:
        torch.(cuda.)Tensor, pruning mask (1 for nonzeros, 0 for zeros)
    """
    ##################### YOUR CODE STARTS HERE #####################

    # Step 1: Calculate the number of weights to prune
    num_elements = tensor.numel()
    num_prune = int(sparsity * num_elements)

    # Step 2: Find the threshold magnitude using absolute values of weights
    threshold = torch.topk(torch.abs(tensor).view(-1), num_prune, largest=False).values[-1]

    # Step 3: Create the pruning mask based on the absolute values
    mask = torch.abs(tensor) > threshold 

    # Step 4: Apply mask to the tensor
    tensor_pruned = tensor * mask.float()
    ##################### YOUR CODE ENDS HERE #######################

    # return the mask to record the pruning location ()
    return mask.float()



In [154]:
def filter_prune(tensor: torch.Tensor, sparsity: float) -> torch.Tensor:
    """
    implement L2-norm-based filter pruning for weight tensor (of a layer)
    :param tensor: torch.(cuda.)Tensor, weight of conv/fc layer
    :param sparsity: float, pruning sparsity
    
    :return:
        torch.(cuda.)Tensor, pruning mask (1 for nonzeros, 0 for zeros)
    """
    
    ##################### YOUR CODE STARTS HERE #####################
    num_filters = tensor.shape[0]
    num_prune = int(sparsity * num_filters)

    # Calculate the L2 norm for each filter
    filter_norms = torch.norm(tensor.view(num_filters, -1), p=2, dim=1)  
    # Find the threshold norm for pruning
    threshold = torch.topk(filter_norms, num_prune, largest=False).values[-1]

    # Step 3: Get the pruning mask tensor based on the threshold
    #         ||filter||2 <= th -> mask=0,
    #         ||filter||2 >  th -> mask=1
    mask = (filter_norms > threshold).float().view(-1, *[1] * (tensor.dim() - 1))
    mask = mask.expand_as(tensor)  
    
    # Step 4: Apply mask tensor to the weight tensor
    tensor_pruned = tensor * mask

    ##################### YOUR CODE ENDS HERE #######################

    # Return the mask to record the pruning location
    return mask


In [155]:
def masked_retrain(model, sparsity_type, prune_ratio_dict, device, dataloader, criterion, optimizer, save_path, num_epochs=5):

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs * len(dataloader), eta_min=4e-08)
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Accuracy computation
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Step the scheduler
        scheduler.step()

        # Print Epoch Num and Accuracy
        epoch_accuracy = 100 * correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}], Accuracy: {epoch_accuracy:.2f}%")


        # Apply the pruning mask again to all layers after training
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):  
                if name in prune_ratio_dict:
                    sparsity = prune_ratio_dict[name]
                    if sparsity_type == 'unstructured':
                        mask = unstructured_prune(module.weight.data, sparsity)
                    elif sparsity_type == 'filter':
                        mask = filter_prune(module.weight.data, sparsity)

                    # Apply the mask to the layer's weights to keep pruned weights as zero
                    with torch.no_grad():
                        module.weight.data *= mask  # Reapply mask to keep pruned weights zero


    # Save the model
    torch.save(model.state_dict(), save_path)
    print(f"Model retrained and saved to {save_path}")

In [156]:
def apply_pruning(model, sparsity_type, prune_ratio_dict, device, dataloader, criterion, optimizer, save_path):
    
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):  # find only conv layers
            if name in prune_ratio_dict:
                sparsity = prune_ratio_dict[name]
                if sparsity_type == 'unstructured':
                    # unstructured pruning
                    mask = unstructured_prune(module.weight.data, sparsity)
                elif sparsity_type == 'filter':
                    # filter pruning 
                    mask = filter_prune(module.weight.data, sparsity)

                # Apply mask 
                module.weight.data *= mask
    
    masked_retrain(model, sparsity_type, prune_ratio_dict, device, dataloader, criterion, optimizer, save_path, num_epochs=5)



In [157]:
def oneshot_magnitude_prune(model, sparsity_type, prune_ratio_dict, device, train_loader, criterion, optimizer, save_path, num_epochs=5 ):
    
    # Apply pruning (unstructured or filter)
    apply_pruning(model, sparsity_type, prune_ratio_dict, device, train_loader, criterion, optimizer, save_path)

    # Retrain the pruned model (masked retraining)
    masked_retrain(model, sparsity_type, prune_ratio_dict, device, train_loader, criterion, optimizer, save_path ,num_epochs)
    
    # complete
    print("==== OneShot Magnitude Training Complete ====")



In [158]:
def iterative_magnitude_prune(model, prune_ratio_dict, sparsity_type, device, dataloader, criterion, optimizer, save_path, iterations=3, num_epochs=5):
    
    for iteration in range(iterations):
        print(f"==== Iteration {iteration + 1} of {iterations} ====")

        # current pruning ratios
        current_prune_ratios = {
            layer: prune_ratio_dict[layer] * ((iteration + 1) / iterations)
            for layer in prune_ratio_dict
        }

        # apply pruning
        print("Applying pruning...")
        apply_pruning(model, sparsity_type, current_prune_ratios, device, dataloader, criterion, optimizer, save_path)

        # retraining
        print("Retraining pruned model...")
        masked_retrain(model, sparsity_type, current_prune_ratios, device, dataloader, criterion, optimizer, save_path, num_epochs)

    print("=== Iterative Magnitude Pruning Complete ===")
    return model

In [159]:
def test_sparity(model, sparsity_type="unstructured"):
    
    print(f"Sparsity type is: {sparsity_type}")
    total_zeros = 0
    total_params = 0

    conv_layers = {}
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Conv2d):
            conv_layers[name +".weight"] = layer
                
    if sparsity_type == 'unstructured':
      for name, param in model.named_parameters():
          if 'weight' in name:
                  if name == "classifier.weight" or param.numel()<1000:
                    continue
                  else:
                    zero_count = (param == 0).sum().item()
                    total_zeros += zero_count
                    total_params += param.numel()
                    print(f"(zero/total) weights of {name} is: ({zero_count}/{param.numel()}). Sparsity is: {zero_count / param.numel():.4f}")
    
    elif sparsity_type == 'filter':
      for name, param in model.named_parameters(): 
          if name in conv_layers.keys():        
            filters_zero = (param.view(param.size(0), -1).norm(p=2, dim=1) == 0).sum().item()
            total_zeros += filters_zero
            total_params += param.size(0)
            print(f"(empty/total) filters of {name} is: ({filters_zero}/{param.size(0)}). Filter sparsity is: {filters_zero / param.size(0):.4f}")

    overall_sparsity = (total_zeros / total_params)*100
    print(f"Total sparsity is: {overall_sparsity:.4f}")


In [160]:
def get_pruned_filters(pruned_model: nn.Module) -> dict:
   
    pruned_filters_dict = {}

    for name, layer in pruned_model.named_modules():
        # Check if convolutional layer
        if isinstance(layer, nn.Conv2d):
            # Get weight 
            weights = layer.weight.data
            num_filters = weights.shape[0]  # Number of filters (output channels)
            
            # find filters that are pruned 
            pruned_filters = []
            for i in range(num_filters):
                if torch.all(weights[i] == 0):
                    pruned_filters.append(i)
            
            if pruned_filters:
                pruned_filters_dict[name] = pruned_filters


    return pruned_filters_dict



In [161]:
def prune_channels_after_filter_prune(model, pruned_filter_dict):

    for i, (layer_name, filter_indices) in enumerate(pruned_filter_dict.items()):
        
        # Determine the next layer's name
        next_layer = list(pruned_filter_dict.keys())[i + 1] if i + 1 < len(pruned_filter_dict) else None
        
        if next_layer not in pruned_filter_dict:
            continue
        if next_layer == None:
            continue
        
        
        # prune the channels from the next layer
        for name, module in model.named_modules():
            if name == next_layer: 
                # Get the current weights of next layer
                weight = module.weight.data
                
                # prune channels
                for i in filter_indices:
                    weight[:, i, :, :] = 0  

                module.weight.data = weight
                break
        
    return model


In [162]:
def main():

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Setup random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Set up model architecture and load pretrained dense model
    model = vgg13()
    model.to(device)
    # Load model state dict with device mapping
    model.load_state_dict(torch.load(args.load_model_path, map_location=device))


    # Get the training and testing data loaders
    train_loader, test_loader = get_dataloaders(args)

    # Select loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)


    # You may use this lr scheduler to fine-tune your pruned model.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs * len(train_loader), eta_min=1e-5)


    print("============================UNPRUNED MODEL===========================================")

    # # Test the model after pruning and fine-tuning
    # print("Testing the UNPRUNED model...")
    # test(model, device, test_loader)

    # print("Sparsity of UNPRUNED model:")
    # test_sparity(model, "unstructured")


    print("===========================PRUNED MODEL==============================================")

    model_copy = copy.deepcopy(model)
    model_copy.to(device)
    
    # Read pruning ratios from the YAML file
    print(f"Reading pruning ratios from {args.yaml_path}...")
    prune_ratio_dict = read_prune_ratios_from_yaml(args.yaml_path, model)
    # print(prune_ratio_dict)

    pruning_type = "unstructured"
    # save_path = "./model/OMP_filter_0.4.pt"

    # # Apply OMP
    # save_path = "./model/OMP_unstructured_0.8.pt" if pruning_type == "unstructured" else "./model/OMP_filter_0.4.pt"
    # print(f'Applying oneshot magnitude {pruning_type} pruning...')
    # oneshot_magnitude_prune(model_copy, pruning_type, prune_ratio_dict, device, train_loader, criterion, optimizer, save_path, num_epochs=5)

    # Apply IMP
    save_path = "./model/IMP_unstructured_0.8.pt" if pruning_type == "unstructured" else "./model/IMP_filter_0.4.pt"
    print(f'Applying Iterative magnitude {pruning_type} pruning...')
    iterative_magnitude_prune(model_copy, prune_ratio_dict, pruning_type, device, train_loader, criterion, optimizer, save_path, 3, 5)

    pruned_model = vgg13()
    pruned_model.to(device)
    pruned_model.load_state_dict(torch.load(save_path, map_location=device))
    
    # Test the model after pruning and fine-tuning
    print("Testing the pruned and fine-tuned model...")
    test(pruned_model, device, test_loader)

    # Check model sparsity after pruning
    print("Testing sparsity after pruning...")
    test_sparity(pruned_model, sparsity_type=pruning_type)
    
    # print("===========================TESTING CHANNEL PRUNED MODEL==============================================")
    
    
    # pruned_filters_dict = get_pruned_filters(pruned_model)    
    # new_model = prune_channels_after_filter_prune(pruned_model, pruned_filters_dict)
    
    # print("Testing the pruned and fine-tuned model...")
    # test(pruned_model, device, test_loader)
    
    # print("Testing the PRUNED CHANNELS MODEL...")
    # test(new_model, device, test_loader)
    

if __name__ == '__main__':
    main()


  model.load_state_dict(torch.load(args.load_model_path, map_location=device))


Files already downloaded and verified
Files already downloaded and verified
Reading pruning ratios from ./pruning_ratio_unstructured.yaml...
Applying Iterative magnitude unstructured pruning...
==== Iteration 1 of 3 ====
Applying pruning...
Epoch [1/5], Accuracy: 99.98%
Epoch [2/5], Accuracy: 99.98%
Epoch [3/5], Accuracy: 99.97%
Epoch [4/5], Accuracy: 99.99%
Epoch [5/5], Accuracy: 99.97%
Model retrained and saved to ./model/IMP_unstructured_0.8.pt
Retraining pruned model...
Epoch [1/5], Accuracy: 99.99%
Epoch [2/5], Accuracy: 99.97%
Epoch [3/5], Accuracy: 99.98%
Epoch [4/5], Accuracy: 99.97%
Epoch [5/5], Accuracy: 99.97%
Model retrained and saved to ./model/IMP_unstructured_0.8.pt
==== Iteration 2 of 3 ====
Applying pruning...
Epoch [1/5], Accuracy: 99.96%
Epoch [2/5], Accuracy: 99.97%
Epoch [3/5], Accuracy: 99.95%
Epoch [4/5], Accuracy: 99.97%
Epoch [5/5], Accuracy: 99.95%
Model retrained and saved to ./model/IMP_unstructured_0.8.pt
Retraining pruned model...
Epoch [1/5], Accuracy: 99

  pruned_model.load_state_dict(torch.load(save_path, map_location=device))



Test set: Average loss: -2.9464, Accuracy: 9378/10000 (93.7800%)

Testing sparsity after pruning...
Sparsity type is: unstructured
(zero/total) weights of features.0.weight is: (0/1728). Sparsity is: 0.0000
(zero/total) weights of features.3.weight is: (3686/36864). Sparsity is: 0.1000
(zero/total) weights of features.7.weight is: (11059/73728). Sparsity is: 0.1500
(zero/total) weights of features.10.weight is: (29491/147456). Sparsity is: 0.2000
(zero/total) weights of features.14.weight is: (103219/294912). Sparsity is: 0.3500
(zero/total) weights of features.17.weight is: (412876/589824). Sparsity is: 0.7000
(zero/total) weights of features.21.weight is: (943718/1179648). Sparsity is: 0.8000
(zero/total) weights of features.24.weight is: (2123366/2359296). Sparsity is: 0.9000
(zero/total) weights of features.28.weight is: (2123366/2359296). Sparsity is: 0.9000
(zero/total) weights of features.31.weight is: (1887436/2359296). Sparsity is: 0.8000
Total sparsity is: 81.2399
