In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms

import torch.nn.utils.prune as prune

# import matplotlib.pyplot as plt

import numpy as np
import os
import time
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"currently available device: {device}")

currently available device: cuda


In [3]:
transform_train = transforms.Compose(
    [
      transforms.RandomCrop(32, padding = 4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ]
     )

transform_test = transforms.Compose(
    [
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ]
     )

In [4]:
train_dataset = torchvision.datasets.CIFAR10(
        root = './data', train = True,
        download = True, transform = transform_train
        )

test_dataset = torchvision.datasets.CIFAR10(
        root = './data', train = False,
        download = True, transform = transform_test
        )

Files already downloaded and verified
Files already downloaded and verified


In [5]:
batch_size = 256
train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size = batch_size,
        shuffle = True, num_workers=16, pin_memory=True
        )

test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size = batch_size,
        shuffle = False , num_workers=16, pin_memory=True
        )

In [6]:
# Model

from torchvision.models import ResNet50_Weights, resnet50
standard_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)


In [7]:
# from torchvision.models import quantization
# quantized_model = quantization.resnet50(weights=quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2, 
#                                         quantize=True)

In [8]:
num_ftrs = standard_model.fc.in_features

standard_model.fc = nn.Linear(in_features = num_ftrs, out_features = 10)

standard_model.conv1 = torch.nn.Conv2d(
    in_channels = 3, out_channels = 64,
    kernel_size = (3, 3), stride = (1, 1),
    padding = (1, 1), bias = False
)

In [9]:
standard_model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [10]:
# pruning 함수 정의

In [11]:
def remove_parameters(model):
    
    for module_name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            try:
                prune.remove(module, "weight")
            except:
                pass
            try:
                prune.remove(module, "bias")
            except:
                pass
        elif isinstance(module, torch.nn.Linear):
            try:
                prune.remove(module, "weight")
            except:
                pass
            try:
                prune.remove(module, "bias")
            except:
                pass

    return model

In [12]:
def measure_module_sparsity(module, weight=True, bias=False, use_mask=False):

    num_zeros = 0
    num_elements = 0

    if use_mask == True:
        for buffer_name, buffer in module.named_buffers():
            if "weight_mask" in buffer_name and weight == True:
                num_zeros += torch.sum(buffer == 0).item()
                num_elements += buffer.nelement()
            if "bias_mask" in buffer_name and bias == True:
                num_zeros += torch.sum(buffer == 0).item()
                num_elements += buffer.nelement()
    else:
        for param_name, param in module.named_parameters():
            if "weight" in param_name and weight == True:
                num_zeros += torch.sum(param == 0).item()
                num_elements += param.nelement()
            if "bias" in param_name and bias == True:
                num_zeros += torch.sum(param == 0).item()
                num_elements += param.nelement()

    sparsity = num_zeros / num_elements

    return num_zeros, num_elements, sparsity

In [13]:
def measure_global_sparsity(
    model, weight = True,
    bias = False, conv2d_use_mask = False,
    linear_use_mask = False):

    num_zeros = 0
    num_elements = 0

    for module_name, module in model.named_modules():

        if isinstance(module, torch.nn.Conv2d):

            module_num_zeros, module_num_elements, _ = measure_module_sparsity(
                module, weight=weight, bias=bias, use_mask=conv2d_use_mask)
            num_zeros += module_num_zeros
            num_elements += module_num_elements

        elif isinstance(module, torch.nn.Linear):

            module_num_zeros, module_num_elements, _ = measure_module_sparsity(
                module, weight=weight, bias=bias, use_mask=linear_use_mask)
            num_zeros += module_num_zeros
            num_elements += module_num_elements

    sparsity = num_zeros / num_elements

    return num_zeros, num_elements, sparsity

In [14]:
l1_regularization_strength = 0
l2_regularization_strength = 1e-4
learning_rate = 0.01
learning_rate_decay = 1

In [15]:
def evaluate_model(model, test_loader, device, criterion = None):

    model.eval()
    model.to(device)

    running_loss = 0
    running_corrects = 0

    for inputs, labels in test_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        if criterion is not None:
            loss = criterion(outputs, labels).item()
        else:
            loss = 0

        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)

    return eval_loss, eval_accuracy

In [16]:
def fine_tune_train_model(model, train_loader, test_loader, device, l1_regularization_strength = 0,
                l2_regularization_strength = 1e-4, learning_rate = 1e-1, num_epochs = 20):

    # The training configurations were not carefully selected.

    criterion = nn.CrossEntropyLoss()

    model.to(device)

    # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10-
    optimizer = torch.optim.SGD(
        model.parameters(), lr = learning_rate,
        momentum = 0.9, weight_decay = l2_regularization_strength
    )
    # optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    
    # Define learning rate scheduler-
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        # optimizer, milestones = [100, 150],
        optimizer, milestones = [8, 15],
        gamma = 0.1, last_epoch = -1)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500)
    

    # Evaluation-
    model.eval()
    eval_loss, eval_accuracy = evaluate_model(
        model = model, test_loader = test_loader,
        device = device, criterion = criterion)
    
    print(f"Pre fine-tuning: val_loss = {eval_loss:.3f} & val_accuracy = {eval_accuracy * 100:.3f}%")
    # print("Epoch: {:03d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(0, eval_loss, eval_accuracy))

    
    for epoch in range(num_epochs):

        # Training
        model.train()

        running_loss = 0
        running_corrects = 0

        for inputs, labels in train_loader:

            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            l1_reg = torch.tensor(0.).to(device)
            for module in model.modules():
                mask = None
                weight = None
                for name, buffer in module.named_buffers():
                    if name == "weight_mask":
                        mask = buffer
                for name, param in module.named_parameters():
                    if name == "weight_orig":
                        weight = param
                # We usually only want to introduce sparsity to weights and prune weights.
                # Do the same for bias if necessary.
                if mask is not None and weight is not None:
                    l1_reg += torch.norm(mask * weight, 1)

            loss += l1_regularization_strength * l1_reg

            loss.backward()
            optimizer.step()

            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(
            model = model, test_loader = test_loader,
            device = device, criterion = criterion)

        # Set learning rate scheduler
        scheduler.step()

        '''
        print(
            "Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}"
            .format(epoch + 1, train_loss, train_accuracy, eval_loss,
                    eval_accuracy))
        '''
        print(f"epoch = {epoch + 1} loss = {train_loss:.3f}, accuracy = {train_accuracy * 100:.3f}%, val_loss = {eval_loss:.3f}, val_accuracy = {eval_accuracy * 100:.3f}% & LR: {optimizer.param_groups[0]['lr']:.4f}")

    return model

In [17]:
def iterative_pruning_finetuning(
    model, train_loader, test_loader, device,
    learning_rate, l1_regularization_strength,
    l2_regularization_strength, learning_rate_decay = 0.1,
    conv2d_prune_amount = 0.2, linear_prune_amount = 0.1,
    num_iterations = 10, num_epochs_per_iteration = 10,
    model_filename_prefix = "pruned_model", model_dir = "saved_models",
    grouped_pruning = False):
    
    for i in range(num_iterations):

        print("\nPruning and Finetuning {}/{}".format(i + 1, num_iterations))

        print("Pruning...")


        # NOTE: For global pruning, linear/dense layer can also be pruned!
        if grouped_pruning == True:
            # grouped_pruning -> Global pruning
            parameters_to_prune = []
            for module_name, module in model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    parameters_to_prune.append((module, "weight"))
                elif isinstance(module, torch.nn.Linear):
                    parameters_to_prune.append((module, "weight"))
        
            # L1Unstructured - prune (currently unpruned) entries in a tensor by zeroing
            # out the ones with the lowest absolute magnitude-
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method = prune.L1Unstructured,
                amount = conv2d_prune_amount,
            )
        
        # layer-wise pruning-
        else:
            for module_name, module in model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    prune.l1_unstructured(
                        module, name = "weight",
                        amount = conv2d_prune_amount)
                elif isinstance(module, torch.nn.Linear):
                    prune.l1_unstructured(
                        module, name = "weight",
                        amount = linear_prune_amount)

        # Compute validation accuracy just after pruning-
        _, eval_accuracy = evaluate_model(
            model = model, test_loader = test_loader,
            device = device, criterion = None)

    

        # Compute global sparsity-
        num_zeros, num_elements, sparsity = measure_global_sparsity(
            model, weight = True,
            bias = False, conv2d_use_mask = True,
            linear_use_mask = False)
        
        print(f"Global sparsity = {sparsity * 100:.3f}% & val_accuracy = {eval_accuracy * 100:.3f}%")
        # print(model.conv1._forward_pre_hooks)
        
        
        print("\nFine-tuning...")
        
        fine_tuned_model = fine_tune_train_model(
            model = model, train_loader = train_loader,
            test_loader = test_loader, device = device,
            l1_regularization_strength = l1_regularization_strength,
            l2_regularization_strength = l2_regularization_strength,
            # i -> current pruning round-
            # learning_rate = learning_rate * (learning_rate_decay ** i),
            learning_rate = learning_rate,
            num_epochs = num_epochs_per_iteration)

        _, eval_accuracy = evaluate_model(
            model=model, test_loader = test_loader,
            device = device, criterion = None)


        num_zeros, num_elements, sparsity = measure_global_sparsity(
            # model,
            fine_tuned_model, weight = True,
            bias = False, conv2d_use_mask = True,
            linear_use_mask = False)

        print(f"Post fine-tuning: Global sparsity = {sparsity * 100:.3f}% & val_accuracy = {eval_accuracy * 100:.3f}%")
        model = remove_parameters(model = model)
        '''
        model_filename = "{}_{}.pt".format(model_filename_prefix, i + 1)
        model_filepath = os.path.join(model_dir, model_filename)
        save_model(model=model,
                   model_dir=model_dir,
                   model_filename=model_filename)
        model = load_model(model=model,
                           model_filepath=model_filepath,
                           device=device)
        '''
        
    return model

In [18]:
_, eval_accuracy = evaluate_model(
    model = standard_model, test_loader=test_loader,
    device = device, criterion = None)

In [19]:
num_zeros, num_elements, sparsity = measure_global_sparsity(standard_model)
print(f"Global sparsity = {sparsity:.3f}% & val_accuracy = {eval_accuracy * 100:.3f}%")

Global sparsity = 0.000% & val_accuracy = 10.810%


In [20]:
pruned_model = copy.deepcopy(standard_model)


# Prune and fine-tune trained model-
'''
num_iterations - number of pruning iterations/rounds
num_epochs_per_iteration - number of fine-tuning rounds
'''
pruned_model = iterative_pruning_finetuning(
        model = pruned_model, train_loader = train_loader,
        test_loader = test_loader, device = device,
        learning_rate = learning_rate, learning_rate_decay = learning_rate_decay,
        l1_regularization_strength = l1_regularization_strength, l2_regularization_strength = l2_regularization_strength,
        conv2d_prune_amount = 0.2, linear_prune_amount = 0.1,
        num_iterations = 5, num_epochs_per_iteration = 7,
        grouped_pruning = True)


Pruning and Finetuning 1/5
Pruning...
Global sparsity = 19.978% & val_accuracy = 10.920%

Fine-tuning...
Pre fine-tuning: val_loss = 2.384 & val_accuracy = 10.920%
epoch = 1 loss = 1.211, accuracy = 56.496%, val_loss = 0.554, val_accuracy = 80.920% & LR: 0.0100
epoch = 2 loss = 0.456, accuracy = 84.182%, val_loss = 0.356, val_accuracy = 87.510% & LR: 0.0100
epoch = 3 loss = 0.296, accuracy = 89.768%, val_loss = 0.270, val_accuracy = 90.520% & LR: 0.0100
epoch = 4 loss = 0.221, accuracy = 92.342%, val_loss = 0.245, val_accuracy = 91.770% & LR: 0.0100
epoch = 5 loss = 0.175, accuracy = 93.984%, val_loss = 0.232, val_accuracy = 92.220% & LR: 0.0100
epoch = 6 loss = 0.146, accuracy = 94.962%, val_loss = 0.230, val_accuracy = 92.470% & LR: 0.0100
epoch = 7 loss = 0.122, accuracy = 95.712%, val_loss = 0.214, val_accuracy = 92.940% & LR: 0.0100
Post fine-tuning: Global sparsity = 19.978% & val_accuracy = 92.940%

Pruning and Finetuning 2/5
Pruning...
Global sparsity = 20.000% & val_accuracy 

In [21]:
final_model = remove_parameters(model = pruned_model)

_, eval_accuracy = evaluate_model(
    model = pruned_model, test_loader = test_loader,
    device = device, criterion = None
)


num_zeros, num_elements, sparsity = measure_global_sparsity(final_model)


print(f"Global sparsity = {sparsity:.3f} & val_accuracy = {eval_accuracy:.3f}")

Global sparsity = 0.200 & val_accuracy = 0.950


In [22]:
torch.save(final_model.state_dict(), f"./ResNet50_trained_sparsity-{sparsity * 100:.3f}.pth")

In [23]:
print("%.2f MB" %(os.path.getsize("/home/aiteam/tykim/scratch/lightweight/pruning/ResNet50_trained_sparsity-20.000.pth")/1e6))

94.40 MB


In [26]:
standard_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
def print_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" %(os.path.getsize("tmp.pt")/1e6))
    os.remove('tmp.pt')
    
print_model_size(standard_model)

102.54 MB
