L1 pruning: torch.nn.utils.prune.l1_unstructured(module, name, amount, importance_scores=None)

Ln pruning: torch.nn.utils.prune.ln_structured(module, name, amount, n, dim, importance_scores=None)

In [None]:
import os
import copy
import time
import pandas as pd
from torchvision.io import read_image
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torch.utils.data import Dataset
from torch import autograd
import torchvision.transforms as T
import matplotlib.pyplot as plt
from torchvision.models import vgg16
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.utils.prune as prune
from heapq import nsmallest
import torch.optim.lr_scheduler as lr_scheduler
import torch.optim as optim

In [2]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
class Cub2011(Dataset):
    """CUB-200-2011 image dataset exposed as a map-style ``Dataset``.

    On construction the archive is optionally downloaded and extracted under
    ``root``.  The three metadata text files are then merged on ``img_id``
    into ``self.data`` and filtered to the requested split.  Indexing returns
    an ``(image, target)`` pair where ``target`` is the label read from
    ``image_class_labels.txt`` minus one.

    Args:
        root: Directory that holds (or will hold) the ``CUB_200_2011`` folder;
            ``~`` is expanded.
        train: Keep rows with ``is_training_img == 1`` when true, otherwise
            rows with ``is_training_img == 0``.
        transform: Optional callable applied to every loaded image.
        loader: Callable that turns an image path into an image object.
        download: When true, fetch and extract the archive unless the files
            of the requested split are already present.

    Raises:
        RuntimeError: If the metadata cannot be read or an image file of the
            requested split is missing (after the optional download).
    """

    base_folder = 'CUB_200_2011/images'
    url = 'https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, loader=default_loader, download=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        # Honour the caller-supplied loader; previously the argument was
        # accepted but silently replaced by ``default_loader``.
        self.loader = loader
        self.train = train

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

    def _load_metadata(self):
        """Read the metadata files and store the rows of the active split in ``self.data``."""
        meta_dir = os.path.join(self.root, 'CUB_200_2011')
        images = pd.read_csv(os.path.join(meta_dir, 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(meta_dir, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(meta_dir, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        wanted_flag = 1 if self.train else 0
        self.data = self.data[self.data.is_training_img == wanted_flag]

    def _check_integrity(self):
        """Return True when the metadata loads and every image of the split exists on disk."""
        try:
            self._load_metadata()
        except Exception:
            # Deliberately broad: any failure to read the metadata simply
            # means "not available yet" (e.g. before the first download).
            return False

        for rel_path in self.data.filepath:
            filepath = os.path.join(self.root, self.base_folder, rel_path)
            if not os.path.isfile(filepath):
                # Report the first missing file to help diagnose a partial extraction.
                print(filepath)
                return False
        return True

    def _download(self):
        """Download and extract the archive into ``self.root`` unless already present."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url(self.url, self.root, self.filename, self.tgz_md5)

        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        """Number of samples in the active split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, target)``."""
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1  # shift the label read from the file down by one
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

In [4]:
# Augmenting preprocessing pipeline: random 224-pixel resized crop and random
# horizontal flip, followed by tensor conversion and per-channel normalisation
# with the mean/std constants below.
transform = T.Compose(
    transforms=[
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [5]:
# Training split uses the augmenting `transform` defined above.
train_ds = Cub2011('.', train=True, transform=transform)

# Validation must be deterministic, otherwise the reported metrics (and the
# "best model" selection based on them) vary with the random crop/flip.
# Resize + centre crop replaces the random augmentation; normalisation is the
# same as for training.
val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Same root as the training split (the previous '.s' looked like a typo and
# required a second copy of the dataset on disk).
val_ds = Cub2011('.', train=False, transform=val_transform)

# Only the training loader is shuffled; validation order is irrelevant.
ds = {'train': DataLoader(train_ds, batch_size=32, shuffle=True),
      'val': DataLoader(val_ds, batch_size=32, shuffle=False)}

# Sample counts per split, used to average loss and accuracy per epoch.
ds_sizes = {'train': len(train_ds),
            'val': len(val_ds)}

Files already downloaded and verified
Files already downloaded and verified


In [6]:
def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes,nclas, num_epochs=25):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a ``'train'`` phase followed by a ``'val'`` phase.  For
    each phase the average loss, the accuracy and the balanced accuracy
    (per-class recall summed over classes that have samples, divided by
    ``nclas``) are printed and recorded.  The weights that reach the highest
    validation balanced accuracy are restored before returning.

    Args:
        model: Network to train; batches are moved to the device of its
            first parameter.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training phase.
        dataloaders: Mapping with ``'train'`` and ``'val'`` entries, each an
            iterable of ``(inputs, labels)`` batches.
        dataset_sizes: Mapping with the number of samples of each phase,
            used as the denominator of the epoch loss and accuracy.
        nclas: Number of classes (side length of the confusion matrix).
        num_epochs: Number of epochs to run.

    Returns:
        The same ``model`` object, with the best validation weights loaded.
    """
    since = time.time()

    # Use the device the model actually lives on instead of relying on a
    # module-level global; fall back to the CPU for a parameter-free model.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    best_model_wts = copy.deepcopy(model.state_dict())
    best_bal_acc = 0.0

    # Per-epoch history, stored as plain Python floats so the final printout
    # is readable.
    val_bal_acc = []
    val_acc = []
    val_loss = []

    train_bal_acc = []
    train_acc = []
    train_loss = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            CF = np.zeros((nclas, nclas))  # Confusion matrix: rows = true, cols = predicted

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(run_device)
                labels = labels.to(run_device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in train
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics: one host transfer per batch, then vectorised
                # updates (np.add.at accumulates repeated index pairs).
                running_loss += loss.item() * inputs.size(0)
                true_idx = labels.detach().cpu().numpy()
                pred_idx = preds.detach().cpu().numpy()
                running_corrects += int((true_idx == pred_idx).sum())
                np.add.at(CF, (true_idx, pred_idx), 1)

            if is_train:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            # Recall per class = diagonal / row sum.  Classes without samples
            # contribute 0, and the sum is still divided by nclas.
            support = CF.sum(axis=1)
            seen = support > 0
            epoch_bal_acc = float((np.diag(CF)[seen] / support[seen]).sum() / nclas)

            if phase == 'val':
                val_bal_acc.append(epoch_bal_acc)
                val_acc.append(epoch_acc)
                val_loss.append(epoch_loss)
            else:
                train_bal_acc.append(epoch_bal_acc)
                train_acc.append(epoch_acc)
                train_loss.append(epoch_loss)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} Balanced Acc: {epoch_bal_acc:.4f}')

            # deep copy the model whenever validation balanced accuracy improves
            if phase == 'val' and epoch_bal_acc > best_bal_acc:
                best_bal_acc = epoch_bal_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Balanced Acc: {best_bal_acc:.4f}')

    print('Validation:')
    print('Val_bal_acc:', val_bal_acc)
    print('Val_acc:', val_acc)
    print('Val_loss:', val_loss)

    print('Training:')
    print('Train_bal_acc:', train_bal_acc)
    print('Train_acc:', train_acc)
    print('Train_loss:', train_loss)

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

Pruning % of each layer (L1-Norm)

In [9]:
#If finetuning == False -> Feature extraction
def imageNetPruningCUBL1(pruningAmount, epochs, finetuning):
    """Prune an ImageNet-pretrained VGG16 layer by layer and train it on the 200-class task.

    The final classifier layer is replaced by a fresh ``Linear(4096, 200)``.
    Every layer of ``model.features`` and ``model.classifier`` that owns a
    ``weight`` parameter is then pruned with ``prune.l1_unstructured`` using
    the same ``amount``, and the model is trained with ``train_model`` on the
    module-level ``ds`` / ``ds_sizes``.

    Args:
        pruningAmount: ``amount`` passed to ``prune.l1_unstructured`` for
            each layer.
        epochs: Number of training epochs.
        finetuning: If true, all parameters are optimised.  If false
            (feature extraction), only the new final classifier layer is
            optimised and every other parameter is frozen.

    Returns:
        The trained model as returned by ``train_model``.
    """
    model = vgg16(weights='IMAGENET1K_V1')
    model.classifier[6] = nn.Linear(4096, 200)

    # Prune every layer that carries a weight parameter, features first and
    # then the classifier (the new head included).
    for block in (model.features, model.classifier):
        for layer in block:
            if isinstance(getattr(layer, 'weight', None), nn.Parameter):
                prune.l1_unstructured(layer, 'weight', amount=pruningAmount)

    if finetuning:
        trainable = model.parameters()
    else:
        # Feature extraction: freeze everything except the new head so that
        # no gradients are computed (or left accumulating) for layers the
        # optimizer never updates.
        head = model.classifier[6]
        for param in model.parameters():
            param.requires_grad_(False)
        for param in head.parameters():
            param.requires_grad_(True)
        trainable = head.parameters()

    optimizer_ft = optim.SGD(trainable, lr=0.001, momentum=0.9)

    criterion = nn.CrossEntropyLoss()

    # Multiply the learning rate by 0.1 every 7 epochs.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

    model = model.to(device)

    model_conv = train_model(model, criterion, optimizer_ft, exp_lr_scheduler, ds, ds_sizes, 200, num_epochs=epochs)
    # Hand the trained model back to the caller instead of discarding it.
    return model_conv

In [14]:
# Sweep the per-layer pruning ratio from 10% to 60% and train each pruned
# model in both regimes.  The amount passed to the function now matches the
# percentage that is logged, and the loop no longer stops after the first run.
for x in range(1, 7):
    pruningAmount = x / 10
    print("Pruning "+ str(x*10)+'% ... with Fine Tuning')
    imageNetPruningCUBL1(pruningAmount, 25, True)
    print("Pruning "+ str(x*10)+'% ... with Feature Extraction')
    imageNetPruningCUBL1(pruningAmount, 25, False)




Pruning 10% ... with Fine Tuning
Before:
Parameter containing:
tensor([[[[-5.5373e-01,  1.4270e-01,  5.2896e-01],
          [-5.8312e-01,  3.5655e-01,  7.6566e-01],
          [-6.9022e-01, -4.8019e-02,  4.8409e-01]],

         [[ 1.7548e-01,  9.8630e-03, -8.1413e-02],
          [ 4.4089e-02, -7.0323e-02, -2.6035e-01],
          [ 1.3239e-01, -1.7279e-01, -1.3226e-01]],

         [[ 3.1303e-01, -1.6591e-01, -4.2752e-01],
          [ 4.7519e-01, -8.2677e-02, -4.8700e-01],
          [ 6.3203e-01,  1.9308e-02, -2.7753e-01]]],


        [[[ 2.3254e-01,  1.2666e-01,  1.8605e-01],
          [-4.2805e-01, -2.4349e-01,  2.4628e-01],
          [-2.5066e-01,  1.4177e-01, -5.4864e-03]],

         [[-1.4076e-01, -2.1903e-01,  1.5041e-01],
          [-8.4127e-01, -3.5176e-01,  5.6398e-01],
          [-2.4194e-01,  5.1928e-01,  5.3915e-01]],

         [[-3.1432e-01, -3.7048e-01, -1.3094e-01],
          [-4.7144e-01, -1.5503e-01,  3.4589e-01],
          [ 5.4384e-02,  5.8683e-01,  4.9580e-01]]],


   

Pruning % of entire model (L1-Norm)