## Install Dependencies

In [3]:
!pip install -q -r /kaggle/input/spectrogrand-size-reduction/requirements.txt

In [4]:
!pip install -q sparseml

## Migrate code from HouseX repo

In [5]:
import numpy as np
import torch
import json

class log(object):
    def __init__(self) -> None:
        self.data = {
            'train_loss': [], 'val_loss': [], 'test_loss': [],
            'train_acc': [], 'val_acc': [], 'test_acc': []
        }
        
    def push(self, train_loss, train_acc, val_loss, val_acc, test_loss, test_acc):
        self.data['train_loss'] += [train_loss]
        self.data['val_loss'] += [val_loss]
        self.data['test_loss'] += [test_loss]
        self.data['train_acc'] += [train_acc]
        self.data['val_acc'] += [val_acc]
        self.data['test_acc'] += [test_acc]
        
    def save(self, tar_path):
        with open(tar_path, 'w') as f:
            json.dump(self.data, f)

In [6]:
import random
import os
import json
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, classification_report
from tqdm import tqdm, trange
import librosa
from librosa import display
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms, models
from torch.utils.tensorboard import SummaryWriter
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import argparse
from datetime import date
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import time 


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 8
data_path = "./"
song_types = ['future house', 'bass house', 'progressive house', 'melodic house']
EPOCHS = 20

2024-04-27 06:55:57.187153: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-27 06:55:57.187246: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-27 06:55:57.318934: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [7]:
def get_tensors(transform, path='./melspecgrams/', mode=None):
    # Collect data
    image_tensors = []
    label_tensors = []

    spec_dir = os.path.join(path, mode)
    img_list = [ele for ele in os.listdir(spec_dir) if '.jpg' in ele]
    for img in img_list:
        # song_type = img[:img.index('-')] # Expected file name example: bass house-1.jpg
        song_type = img.split('/')[-1].split("_")[0]
        # print(img, song_type)
        img_path = spec_dir + '/' + img
        img_tensor = transform(Image.open(img_path).convert('RGB'))
        image_tensors.append(img_tensor)
        label_tensors.append(song_types.index(song_type))

    return image_tensors, label_tensors


class MelSpectrogramDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # input, target
        return self.images[idx], torch.tensor(self.labels[idx], dtype=torch.long)

In [15]:
criterion = nn.CrossEntropyLoss()

def one_epoch(model, loader, mode, opt, device=DEVICE, epoch_id=None):
    def run(mode, opt, device=device):
        losses = []
        correct_preds = 0
        length = 0
        for batch in tqdm(loader, desc='Epoch '+str(epoch_id+1)+' '+mode):
            images, labels = batch[0].to(device), batch[1].to(device)
            length += images.shape[0]

            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            loss = criterion(outputs, labels)

            correct_preds += torch.sum(preds == labels).item()
            losses.append(loss.item())

            if mode == 'train':
                opt.zero_grad()
                loss.backward()
                opt.step()
        return {'loss': np.mean(losses), 'accuracy': correct_preds / length}

    if mode == 'train':
        model = model.train()
        return run('train', opt, device)
    else:
        model = model.eval()
        with torch.no_grad():
            return run('test', opt, device)


def train(train_loader, val_loader, test_loader, model, opt, epochs=EPOCHS, device=DEVICE, writer=None, eval_first=True):
    model = model.to(device)
    epoch = -1
    if eval_first:
        evaluate(model, device, val_loader, 'val', epoch)
        evaluate(model, device, test_loader, 'test', epoch)

    cur_log = log()
    for epoch in range(epochs):
        ret = one_epoch(model, train_loader, 'train', opt, device, epoch)
        train_loss = ret['loss']
        train_acc = ret['accuracy']
        print('Epoch {}: '.format(epoch+1))
        print(f"train loss: {train_loss}")
        print(f"train accuracy: {train_acc}")

        val_loss, val_acc = evaluate(model, device, val_loader, 'val', epoch)
        test_loss, test_acc = evaluate(model, device, test_loader, 'test', epoch)
        if writer:
            writer.add_scalar('loss/train', train_loss, epoch)
            writer.add_scalar('accuracy/train', train_acc, epoch)
            writer.add_scalar('loss/val', val_loss, epoch)
            writer.add_scalar('accuracy/val', val_acc, epoch)
            writer.add_scalar('loss/test', test_loss, epoch)
            writer.add_scalar('accuracy/test', test_acc, epoch)
            
            cur_log.push(train_loss, train_acc, val_loss, val_acc, test_loss, test_acc)
    model = model.cpu()
    return cur_log


def evaluate(model, device=DEVICE, loader=None, comment='val', epoch_id=None):
    model = model.to(device)
    ret = one_epoch(model, loader, 'test', None, device, epoch_id)
    loss = ret['loss']
    accuracy = ret['accuracy']

    print(f"{comment} loss: {loss}")
    print(f"{comment} accuracy: {accuracy}")
    return loss, accuracy

## Load Data

In [16]:
train_transform = transforms.Compose([
            transforms.Resize((96, 96)),
            transforms.RandomPosterize(2, p = 0.25),
            transforms.ColorJitter(brightness = (0.50, 1.00)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

test_transform = transforms.Compose([
            transforms.Resize((96, 96)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [10]:
DATA_DIR="/kaggle/input/housex-spectrograms/melspecgrams_data"

train_set = MelSpectrogramDataset(*get_tensors(train_transform, DATA_DIR, mode='train'))
val_set = MelSpectrogramDataset(*get_tensors(test_transform, DATA_DIR, mode='val'))
test_set = MelSpectrogramDataset(*get_tensors(test_transform, DATA_DIR, mode='test'))

train_loader = DataLoader(train_set, BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, BATCH_SIZE, shuffle=True)

print('dataset length:', len(train_set), len(val_set), len(test_set))
print('dataloader length:', len(train_loader), len(val_loader), len(test_loader))


dataset length: 3440 537 337
dataloader length: 430 68 43


## Load Resnet101 model

In [11]:
BASE_MODEL_PATH = "/kaggle/input/spectrogrand-size-reduction/resnet_finetuned_full.pth"

base_model = torch.load(BASE_MODEL_PATH)
base_model.to(DEVICE)
base_model.eval()

Sequential(
  (0): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0)

In [None]:
# Ref: https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html
def print_size_of_model(model, label=""):
    torch.save(model.state_dict(), "temp.p")
    size=os.path.getsize("temp.p")
    print("model: ",label,' \t','Size (KB):', size/1e3)
    os.remove('temp.p')
    return size

_ = print_size_of_model(base_model, "base")

In [17]:
with torch.no_grad():
    loss, accuracy = evaluate(base_model, device=DEVICE, loader=val_loader, comment='val', epoch_id=0)
    print('='*90)
    loss, accuracy = evaluate(base_model, device=DEVICE, loader=test_loader, comment='test', epoch_id=0)

Epoch 1 test:   0%|          | 0/68 [00:00<?, ?it/s]

val loss: 1.3374779790653573
val accuracy: 0.7523277467411545


Epoch 1 test:   0%|          | 0/43 [00:00<?, ?it/s]

test loss: 1.5709080062747278
test accuracy: 0.7566765578635015


## List out the prunable layers

In [18]:
import torch
import sparseml
import torchvision
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import TensorBoardLogger, ModuleExporter, get_prunable_layers, tensor_sparsity
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torchvision import transforms
from tqdm.auto import tqdm
import math

In [19]:
for (name, layer) in get_prunable_layers(base_model):
    print(f"{name}")

0.conv1
0.layer1.0.conv1
0.layer1.0.conv2
0.layer1.0.conv3
0.layer1.0.downsample.0
0.layer1.1.conv1
0.layer1.1.conv2
0.layer1.1.conv3
0.layer1.2.conv1
0.layer1.2.conv2
0.layer1.2.conv3
0.layer2.0.conv1
0.layer2.0.conv2
0.layer2.0.conv3
0.layer2.0.downsample.0
0.layer2.1.conv1
0.layer2.1.conv2
0.layer2.1.conv3
0.layer2.2.conv1
0.layer2.2.conv2
0.layer2.2.conv3
0.layer2.3.conv1
0.layer2.3.conv2
0.layer2.3.conv3
0.layer3.0.conv1
0.layer3.0.conv2
0.layer3.0.conv3
0.layer3.0.downsample.0
0.layer3.1.conv1
0.layer3.1.conv2
0.layer3.1.conv3
0.layer3.2.conv1
0.layer3.2.conv2
0.layer3.2.conv3
0.layer3.3.conv1
0.layer3.3.conv2
0.layer3.3.conv3
0.layer3.4.conv1
0.layer3.4.conv2
0.layer3.4.conv3
0.layer3.5.conv1
0.layer3.5.conv2
0.layer3.5.conv3
0.layer3.6.conv1
0.layer3.6.conv2
0.layer3.6.conv3
0.layer3.7.conv1
0.layer3.7.conv2
0.layer3.7.conv3
0.layer3.8.conv1
0.layer3.8.conv2
0.layer3.8.conv3
0.layer3.9.conv1
0.layer3.9.conv2
0.layer3.9.conv3
0.layer3.10.conv1
0.layer3.10.conv2
0.layer3.10.conv3

## Specify Regex for pruning

```
- '0.conv1.weight'
- 're:0.layer1.*.conv1.weight'
- 're:0.layer1.*.conv2.weight'
- 're:0.layer1.*.conv3.weight'
- 're:0.layer1.0.downsample.0.weight'
- 're:0.layer2.*.conv1.weight'
- 're:0.layer2.*.conv2.weight'
- 're:0.layer2.*.conv3.weight'
- 're:0.layer2.0.downsample.0.weight'
- 're:0.layer3.*.conv1.weight'
- 're:0.layer3.*.conv2.weight'
- 're:0.layer3.*.conv3.weight'
- 're:0.layer3.0.downsample.0.weight'
- 're:0.layer4.*.conv1.weight'
- 're:0.layer4.*.conv2.weight'
- 're:0.layer4.*.conv3.weight'
- 're:0.layer4.0.downsample.0.weight'
```

## Write config into YAML file

In [20]:
%%writefile -a config3.yaml

# Epoch hyperparams
stabilization_epochs: 1.0
pruning_epochs: 7.0
finetuning_epochs: 7.0

# Learning rate hyperparams
init_lr: 0.0004
final_lr: 0.0001

# Pruning hyperparams
init_sparsity: 0.05
final_sparsity: 0.75

# Stabalization Stage
training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(stabilization_epochs + pruning_epochs + finetuning_epochs)
  
  - !SetLearningRateModifier
    start_epoch: 0.0
    learning_rate: eval(init_lr)

# Pruning Stage
pruning_modifiers:
  - !LearningRateFunctionModifier
    init_lr: eval(init_lr)
    final_lr: eval(final_lr)
    lr_func: cosine
    start_epoch: eval(stabilization_epochs)
    end_epoch: eval(stabilization_epochs + pruning_epochs)
    
  - !GlobalMagnitudePruningModifier
    init_sparsity: eval(init_sparsity)
    final_sparsity: eval(final_sparsity)
    start_epoch: eval(stabilization_epochs)
    end_epoch: eval(stabilization_epochs + pruning_epochs)
    update_frequency: 0.5
    params:        
        - '0.conv1.weight'
        - 're:0.layer1.*.conv1.weight'
        - 're:0.layer1.*.conv2.weight'
        - 're:0.layer1.*.conv3.weight'
        - 're:0.layer1.0.downsample.0.weight'
        - 're:0.layer2.*.conv1.weight'
        - 're:0.layer2.*.conv2.weight'
        - 're:0.layer2.*.conv3.weight'
        - 're:0.layer2.0.downsample.0.weight'
        - 're:0.layer3.*.conv1.weight'
        - 're:0.layer3.*.conv2.weight'
        - 're:0.layer3.*.conv3.weight'
        - 're:0.layer3.0.downsample.0.weight'
        - 're:0.layer4.*.conv1.weight'
        - 're:0.layer4.*.conv2.weight'
        - 're:0.layer4.*.conv3.weight'
        - 're:0.layer4.0.downsample.0.weight'
    leave_enabled: True

# Finetuning Stage
finetuning_modifiers:
  - !LearningRateFunctionModifier
    init_lr: eval(init_lr)
    final_lr: eval(final_lr)
    lr_func: cosine
    start_epoch: eval(stabilization_epochs + pruning_epochs)
    end_epoch: eval(stabilization_epochs + pruning_epochs + finetuning_epochs)


Writing config3.yaml


In [21]:
pruning_recipe_path = "/kaggle/working/config3.yaml"

In [22]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(base_model.parameters(), lr=4e-3)


manager = ScheduledModifierManager.from_yaml(pruning_recipe_path)
optimizer = manager.modify(base_model, optimizer, steps_per_epoch=len(train_loader))

2024-04-27 07:02:07 sparseml.pytorch.utils.logger INFO     Logging all SparseML modifier-level logs to sparse_logs/27-04-2024_07.02.07.log


In [23]:
# run GMP algorithm
epoch = 0
for epoch in range(manager.max_epochs):
    # run training loop
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"
    print(f"Running Training Epoch {epoch_name}")
    train_log = one_epoch(base_model, train_loader, "train", optimizer, device=DEVICE, epoch_id=epoch)
    print(f"Training Epoch: {epoch_name}\nTraining Log: {train_log}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_log = one_epoch(base_model, val_loader, "test", optimizer, device=DEVICE, epoch_id=epoch)
    print(f"Validation Epoch: {epoch_name}\nValidation Log: {val_log}\n")
        
manager.finalize(base_model)

Running Training Epoch 1/15


Epoch 1 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 1/15
Training Log: {'loss': 0.1231679953880937, 'accuracy': 0.9613372093023256}

Running Validation Epoch 1/15


Epoch 1 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 1/15
Validation Log: {'loss': 1.2161188360056685, 'accuracy': 0.7486033519553073}

Running Training Epoch 2/15


Epoch 2 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 2/15
Training Log: {'loss': 0.07454859348709274, 'accuracy': 0.973546511627907}

Running Validation Epoch 2/15


Epoch 2 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 2/15
Validation Log: {'loss': 1.1614989555361235, 'accuracy': 0.7486033519553073}

Running Training Epoch 3/15


Epoch 3 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 3/15
Training Log: {'loss': 0.054143060485596854, 'accuracy': 0.9819767441860465}

Running Validation Epoch 3/15


Epoch 3 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 3/15
Validation Log: {'loss': 1.2161715898852008, 'accuracy': 0.7541899441340782}

Running Training Epoch 4/15


Epoch 4 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 4/15
Training Log: {'loss': 0.05865261051627851, 'accuracy': 0.9816860465116279}

Running Validation Epoch 4/15


Epoch 4 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 4/15
Validation Log: {'loss': 1.1778810402215374, 'accuracy': 0.7430167597765364}

Running Training Epoch 5/15


Epoch 5 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 5/15
Training Log: {'loss': 0.0802261143511968, 'accuracy': 0.9718023255813953}

Running Validation Epoch 5/15


Epoch 5 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 5/15
Validation Log: {'loss': 1.297752298784497, 'accuracy': 0.7225325884543762}

Running Training Epoch 6/15


Epoch 6 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 6/15
Training Log: {'loss': 0.0964152697397959, 'accuracy': 0.9642441860465116}

Running Validation Epoch 6/15


Epoch 6 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 6/15
Validation Log: {'loss': 2.5714671935009608, 'accuracy': 0.7039106145251397}

Running Training Epoch 7/15


Epoch 7 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 7/15
Training Log: {'loss': 0.11625398937991799, 'accuracy': 0.9607558139534884}

Running Validation Epoch 7/15


Epoch 7 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 7/15
Validation Log: {'loss': 1.4736720979871119, 'accuracy': 0.6964618249534451}

Running Training Epoch 8/15


Epoch 8 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 8/15
Training Log: {'loss': 0.10338873315729785, 'accuracy': 0.9651162790697675}

Running Validation Epoch 8/15


Epoch 8 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 8/15
Validation Log: {'loss': 1.739689396584735, 'accuracy': 0.7057728119180633}

Running Training Epoch 9/15


Epoch 9 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 9/15
Training Log: {'loss': 0.09369476662372066, 'accuracy': 0.9680232558139535}

Running Validation Epoch 9/15


Epoch 9 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 9/15
Validation Log: {'loss': 2.6545438084439787, 'accuracy': 0.7039106145251397}

Running Training Epoch 10/15


Epoch 10 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 10/15
Training Log: {'loss': 0.0863599643486533, 'accuracy': 0.9709302325581395}

Running Validation Epoch 10/15


Epoch 10 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 10/15
Validation Log: {'loss': 1.3422448064255363, 'accuracy': 0.7299813780260708}

Running Training Epoch 11/15


Epoch 11 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 11/15
Training Log: {'loss': 0.07195321180151104, 'accuracy': 0.9802325581395349}

Running Validation Epoch 11/15


Epoch 11 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 11/15
Validation Log: {'loss': 1.409050850526375, 'accuracy': 0.7243947858472998}

Running Training Epoch 12/15


Epoch 12 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 12/15
Training Log: {'loss': 0.060283472305596913, 'accuracy': 0.9802325581395349}

Running Validation Epoch 12/15


Epoch 12 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 12/15
Validation Log: {'loss': 2.757637572748696, 'accuracy': 0.6983240223463687}

Running Training Epoch 13/15


Epoch 13 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 13/15
Training Log: {'loss': 0.05295939687114929, 'accuracy': 0.9828488372093023}

Running Validation Epoch 13/15


Epoch 13 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 13/15
Validation Log: {'loss': 2.8948284517151905, 'accuracy': 0.7094972067039106}

Running Training Epoch 14/15


Epoch 14 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 14/15
Training Log: {'loss': 0.04551049604758804, 'accuracy': 0.9851744186046512}

Running Validation Epoch 14/15


Epoch 14 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 14/15
Validation Log: {'loss': 1.8778793616833462, 'accuracy': 0.7039106145251397}

Running Training Epoch 15/15


Epoch 15 train:   0%|          | 0/430 [00:00<?, ?it/s]

Training Epoch: 15/15
Training Log: {'loss': 0.05710619653664464, 'accuracy': 0.9805232558139535}

Running Validation Epoch 15/15


Epoch 15 test:   0%|          | 0/68 [00:00<?, ?it/s]

Validation Epoch: 15/15
Validation Log: {'loss': 1.4601740945294939, 'accuracy': 0.7132216014897579}



In [24]:
for (name, layer) in get_prunable_layers(base_model):
    print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}")

0.conv1.weight: 0.3969
0.layer1.0.conv1.weight: 0.5352
0.layer1.0.conv2.weight: 0.5834
0.layer1.0.conv3.weight: 0.5789
0.layer1.0.downsample.0.weight: 0.5618
0.layer1.1.conv1.weight: 0.7708
0.layer1.1.conv2.weight: 0.8780
0.layer1.1.conv3.weight: 0.8932
0.layer1.2.conv1.weight: 0.5703
0.layer1.2.conv2.weight: 0.5425
0.layer1.2.conv3.weight: 0.7548
0.layer2.0.conv1.weight: 0.4749
0.layer2.0.conv2.weight: 0.5649
0.layer2.0.conv3.weight: 0.6556
0.layer2.0.downsample.0.weight: 0.6838
0.layer2.1.conv1.weight: 0.8167
0.layer2.1.conv2.weight: 0.8162
0.layer2.1.conv3.weight: 0.8246
0.layer2.2.conv1.weight: 0.6136
0.layer2.2.conv2.weight: 0.6217
0.layer2.2.conv3.weight: 0.6790
0.layer2.3.conv1.weight: 0.5218
0.layer2.3.conv2.weight: 0.5136
0.layer2.3.conv3.weight: 0.6672
0.layer3.0.conv1.weight: 0.4445
0.layer3.0.conv2.weight: 0.6324
0.layer3.0.conv3.weight: 0.6348
0.layer3.0.downsample.0.weight: 0.7661
0.layer3.1.conv1.weight: 0.9581
0.layer3.1.conv2.weight: 0.9708
0.layer3.1.conv3.weight: 0.9

## Benchmark performance of the pruned model

In [29]:
base_model.eval()
with torch.no_grad():
    loss, accuracy = evaluate(base_model, device=DEVICE, loader=val_loader, comment='val:pruned', epoch_id=0)
    print('='*90)
    loss, accuracy = evaluate(base_model, device=DEVICE, loader=test_loader, comment='test:pruned', epoch_id=0)

Epoch 1 test:   0%|          | 0/68 [00:00<?, ?it/s]

val:pruned loss: 1.4605824131080334
val:pruned accuracy: 0.7132216014897579


Epoch 1 test:   0%|          | 0/43 [00:00<?, ?it/s]

test:pruned loss: 1.7651157700131799
test:pruned accuracy: 0.7477744807121661


## Save pruned model

- Note: Since we are doing Unstructured Pruning, the model size will remain the same

In [27]:
torch.save(base_model, "pruned_config3.pth")

## Get average pruning sparsity

In [34]:
all_pruning_values = []

for (name, layer) in get_prunable_layers(base_model):
    all_pruning_values.append(tensor_sparsity(layer.weight).item())
    
print(f"Average sparsity value: {sum(all_pruning_values)/len(all_pruning_values)}")

Average sparsity value: 0.7140883775252216


## Post-training dynamic quantisation

In [35]:
base_model.to("cpu")
model_int8 = torch.ao.quantization.quantize_dynamic(
    base_model,  # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights
model_int8.eval()
model_int8.to("cpu")

Sequential(
  (0): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0)

In [37]:
model_int8.eval()
with torch.no_grad():
    loss, accuracy = evaluate(model_int8, device="cpu", loader=val_loader, comment='val:pruned:ptdq', epoch_id=0)
    print('='*90)
    loss, accuracy = evaluate(model_int8, device="cpu", loader=test_loader, comment='test:pruned:ptdq', epoch_id=0)

Epoch 1 test:   0%|          | 0/68 [00:00<?, ?it/s]

val:pruned:ptdq loss: 1.464708717985024
val:pruned:ptdq accuracy: 0.7113594040968343


Epoch 1 test:   0%|          | 0/43 [00:00<?, ?it/s]

test:pruned:ptdq loss: 1.8043933983805567
test:pruned:ptdq accuracy: 0.7388724035608308


## Save Pruned+Quantised model

In [38]:
torch.save(model_int8, "pruned_config3_ptdq.pth")