#### Code Modified from MIT 6.5940

In [None]:
# Install the `torchprofile` package that the import cell needs for
# `profile_macs`.  The leading `!` is an IPython shell escape, and `1>/dev/null`
# discards pip's stdout so only its stderr output reaches the cell.
print('Installing torchprofile...')
!pip install torchprofile 1>/dev/null
# NOTE(review): the success message below is printed unconditionally; the exit
# status of the pip command is never inspected.
print('All required packages have been successfully installed!')

In [None]:
import copy
import random
from typing import Union, List

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm

from torchprofile import profile_macs

# Fail fast when no GPU is visible: the rest of this file calls `.cuda()`
# unconditionally.  The two sentences of the message are separate string
# literals, so the first one must end with a space to keep them apart.
assert torch.cuda.is_available(), (
    "The current runtime does not have CUDA support. "
    "Please go to menu bar (Runtime - Change runtime type) and select GPU"
)

In [None]:
# Seed the three random number generators used in this file (Python's
# `random`, NumPy's global generator and PyTorch's default generator) with the
# same fixed value, presumably so that repeated runs give the same results.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

In [None]:
class LeNet(nn.Module):
    """LeNet-style CNN: three 5x5 convolutions followed by one linear layer.

    The layer sizes are consistent with a single-channel 32x32 input: with
    that input ``conv5`` yields a ``(N, 120, 1, 1)`` feature map, which is
    reduced to ``(N, 120)`` before ``fc6`` maps it to 10 outputs.  A ReLU is
    applied after every layer, including both pooling layers and the final
    linear layer, so the returned scores are non-negative.
    """

    def __init__(self):
        super().__init__()
        # Layers are numbered by their position in the network; the creation
        # order is kept stable so parameter order and initialisation match.
        self.conv1 = nn.Conv2d(1, 6, 5, 1)
        self.relu1 = nn.ReLU()

        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(6, 16, 5, 1)
        self.relu3 = nn.ReLU()

        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu4 = nn.ReLU()

        self.conv5 = nn.Conv2d(16, 120, 5, 1)
        self.relu5 = nn.ReLU()

        self.fc6 = nn.Linear(120, 10)
        self.relu6 = nn.ReLU()

    def forward(self, x):
        """Compute the 10 class scores for ``x``.

        Args:
            x: Input images, ``(N, 1, 32, 32)`` (or a single unbatched
                ``(1, 32, 32)`` image).

        Returns:
            A ``(N, 10)`` tensor of scores (``(10,)`` for unbatched input).
            The batch axis is preserved even when ``N == 1``.
        """
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.maxpool2(x))
        x = self.relu3(self.conv3(x))
        x = self.relu4(self.maxpool4(x))
        x = self.relu5(self.conv5(x))
        # Remove only the two trailing 1x1 spatial axes.  A bare ``squeeze()``
        # would also drop a batch axis of size 1, turning the output into a
        # 1-D tensor and breaking ``argmax(dim=1)`` / the loss for a batch
        # that holds a single sample.
        x = x.squeeze(-1).squeeze(-1)
        x = self.relu6(self.fc6(x))
        return x

In [None]:
def train(
  model: nn.Module,
  dataloader: DataLoader,
  criterion: nn.Module,
  optimizer: Optimizer,
  scheduler: LambdaLR,
  callbacks = None
) -> None:
    """Train ``model`` in place for one pass over ``dataloader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler, also stepped once per *batch*,
            so its schedule length has to be expressed in iterations.
        callbacks: Optional iterable of zero-argument callables invoked after
            every optimizer/scheduler step.
    """
    model.train()

    # Send each batch to the device the model lives on instead of assuming
    # CUDA.  A model without parameters keeps the previous CUDA behaviour.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda')

    for inputs, targets in tqdm(dataloader, desc='train', leave=False):
        # Move the data to the model's device
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Reset the gradients (from the last iteration)
        optimizer.zero_grad()

        # Forward inference
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward propagation
        loss.backward()

        # Update optimizer and LR scheduler
        optimizer.step()
        scheduler.step()

        if callbacks is not None:
            for callback in callbacks:
                callback()

In [None]:
@torch.inference_mode()
def evaluate(
  model: nn.Module,
  dataloader: DataLoader,
  verbose=True,
) -> float:
    """Return the top-1 accuracy of ``model`` on ``dataloader`` in percent.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        dataloader: Iterable yielding ``(inputs, targets)`` batches, where
            ``targets`` holds class indices.
        verbose: When false, the progress bar is disabled.

    Returns:
        ``100 * correct / total`` as a Python float.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()

    # Send each batch to the device the model lives on instead of assuming
    # CUDA.  A model without parameters keeps the previous CUDA behaviour.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda')

    num_samples = 0
    num_correct = 0

    for inputs, targets in tqdm(dataloader, desc="eval", leave=False,
                                disable=not verbose):
        # Move the data to the model's device
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Inference, then convert the scores to class indices
        predictions = model(inputs).argmax(dim=1)

        # Update metrics; the running count stays a tensor so no host
        # synchronisation is forced inside the loop.
        num_samples += targets.size(0)
        num_correct += (predictions == targets).sum()

    return (num_correct / num_samples * 100).item()

In [None]:
# Load the pretrained parameters.  The file is read as one flat float32 array
# and cut into consecutive slices, each reshaped to the layout of the matching
# LeNet parameter and moved to the GPU:
#   conv1 weight 6*1*5*5 = 150, conv1 bias 6,
#   conv3 weight 16*6*5*5 = 2400, conv3 bias 16,
#   conv5 weight 120*16*5*5 = 48000, conv5 bias 120,
#   fc6 weight 10*120 = 1200, fc6 bias 10 (everything after offset 51892).
with open('/root/params.bin', 'rb') as f:
    params = np.fromfile(f, dtype=np.float32)
    conv1_weights = torch.from_numpy(params[:150].reshape(6, 1, 5, 5)).cuda()
    conv1_bias = torch.from_numpy(params[150:156].reshape(6)).cuda()
    conv3_weights = torch.from_numpy(params[156:156+2400].reshape(16, 6, 5, 5)).cuda()
    conv3_bias = torch.from_numpy(params[2556:2572].reshape(16)).cuda()
    conv5_weights = torch.from_numpy(params[2572:50572].reshape(120, 16, 5, 5)).cuda()
    conv5_bias = torch.from_numpy(params[50572:50692].reshape(120)).cuda()
    fc6_weights = torch.from_numpy(params[50692:51892].reshape(10, 120)).cuda()
    fc6_bias = torch.from_numpy(params[51892:].reshape(10)).cuda()

# Load the images as raw bytes.  The first 16 bytes are skipped (presumably a
# file header -- TODO confirm against the file format) and the rest is viewed
# as 28x28 single-channel images.  Each image is centred in a 32x32 canvas
# filled with -1 and its pixel values are rescaled from [0, 255] to [-1, 1].
# The resulting float32 tensor stays on the CPU.
with open('/root/images.bin', 'rb') as f:
    images_raw = np.fromfile(f, dtype=np.uint8)
    images_raw = images_raw[16:].reshape(-1, 1, 28, 28)
    images = np.ones((images_raw.shape[0], 1, 32, 32)) * -1
    images[:, :, 2:30, 2:30] = images_raw / 255.0 * 2.0 - 1.0
    images = torch.from_numpy(images).float()

# Load the labels as raw bytes, skipping the first 8 bytes (presumably a file
# header -- TODO confirm).
# NOTE(review): the labels remain a uint8 tensor; verify that the loss used
# for fine-tuning accepts uint8 class indices on the installed PyTorch version.
with open('/root/labels.bin', 'rb') as f:
    labels = np.fromfile(f, dtype=np.uint8)
    labels = torch.from_numpy(labels[8:])

# Build the network on the GPU and point every parameter at the pretrained
# tensors loaded above.  Assigning to `.data` rebinds the parameter's storage,
# so `model` shares memory with the `*_weights` / `*_bias` tensors.
model = LeNet().cuda()
model.conv1.weight.data = conv1_weights
model.conv1.bias.data = conv1_bias
model.conv3.weight.data = conv3_weights
model.conv3.bias.data = conv3_bias
model.conv5.weight.data = conv5_weights
model.conv5.bias.data = conv5_bias
model.fc6.weight.data = fc6_weights
model.fc6.bias.data = fc6_bias

def recover_model():
    """Build a fresh CUDA ``LeNet`` initialised with the pretrained parameters.

    Returns:
        The new model.  Every parameter holds its own copy of the pretrained
        tensor, so training or pruning the returned model in place cannot
        corrupt the stored pretrained values (or any other model built from
        them).  Use it as ``model = recover_model()``.
    """
    restored = LeNet().cuda()
    pretrained = {
        'conv1': (conv1_weights, conv1_bias),
        'conv3': (conv3_weights, conv3_bias),
        'conv5': (conv5_weights, conv5_bias),
        'fc6': (fc6_weights, fc6_bias),
    }
    for layer_name, (weight, bias) in pretrained.items():
        layer = getattr(restored, layer_name)
        layer.weight.data = weight.clone()
        layer.bias.data = bias.clone()
    # The model used to be built and then dropped; returning it is what makes
    # this function usable.
    return restored

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each entry of ``data`` with its label.

    Indexing returns a ``(sample, label)`` tuple whose two elements have each
    been moved to the GPU with ``.cuda()``.
    """

    def __init__(self, data, labels):
        # Both containers are stored as given; no copy or validation is made.
        self.data = data
        self.labels = labels

    def __len__(self):
        # The size is defined by ``data`` alone; ``labels`` is not consulted.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx].cuda()
        target = self.labels[idx].cuda()
        return (sample, target)

In [None]:
# Wrap the loaded images/labels and split them 80% / 20% at random into a
# training subset and a test subset.
mnist_dataset = CustomDataset(images, labels)
train_length = int(0.8 * len(mnist_dataset))
test_length = len(mnist_dataset) - train_length

train_dataset, test_dataset = torch.utils.data.random_split(
    mnist_dataset, (train_length, test_length)
)

dataset = {"train": train_dataset, "test": test_dataset}

# One loader per split, both with batches of 10 samples loaded in the main
# process (num_workers=0) and without shuffling.
dataloader = {
    split: torch.utils.data.DataLoader(subset, batch_size=10, num_workers=0)
    for split, subset in dataset.items()
}

In [None]:
def get_num_channels_to_keep(channels: int, prune_ratio: float) -> int:
    """Return how many of ``channels`` remain after pruning ``prune_ratio``.

    The result is ``channels * (1 - prune_ratio)`` rounded with Python's
    built-in ``round`` (exact ties go to the nearest even integer).
    """
    return int(round(channels * (1 - prune_ratio)))


@torch.no_grad()
def channel_prune(model: nn.Module,
                  prune_ratio: Union[List, float]) -> nn.Module:
    """Return a copy of ``model`` whose conv layers keep only leading channels.

    The ``Conv2d`` modules are taken in ``model.modules()`` order and treated
    as a chain ``conv0 - ratio0 - conv1 - ratio1 - ...``: ratio ``i`` removes
    trailing output channels of conv ``i`` together with the matching input
    channels of conv ``i + 1``.  The last conv's outputs are never pruned.
    Channels are dropped by position, so sort them by importance first if the
    most important ones should survive.

    Args:
        model: Network to prune; it is deep-copied and left untouched.
        prune_ratio: Either one number applied to every conv pair, or a
            list/tuple with exactly ``n_conv - 1`` per-pair ratios.

    Returns:
        The pruned copy.  ``out_channels`` / ``in_channels`` of the affected
        layers are updated to match the sliced weights.

    Raises:
        TypeError: If ``prune_ratio`` is neither a number nor a list/tuple.
        ValueError: If a list/tuple has the wrong number of ratios.
    """
    if not isinstance(prune_ratio, (int, float, list, tuple)):
        raise TypeError(
            "prune_ratio must be a number or a list of numbers, got "
            f"{type(prune_ratio).__name__}")

    n_conv = sum(1 for m in model.modules() if isinstance(m, nn.Conv2d))
    if isinstance(prune_ratio, (list, tuple)):
        if len(prune_ratio) != n_conv - 1:
            raise ValueError(
                f"expected {n_conv - 1} prune ratios for {n_conv} conv "
                f"layers, got {len(prune_ratio)}")
        ratios = list(prune_ratio)
    else:  # broadcast the single ratio to every conv pair
        ratios = [prune_ratio] * (n_conv - 1)

    pruned = copy.deepcopy(model)  # never overwrite the caller's model
    all_convs = [m for m in pruned.modules() if isinstance(m, nn.Conv2d)]

    for prev_conv, next_conv, ratio in zip(all_convs, all_convs[1:], ratios):
        # Read the channel count from the weight itself so that a model whose
        # metadata is stale (e.g. pruned earlier) is still handled correctly.
        original_channels = prev_conv.weight.shape[0]
        n_keep = get_num_channels_to_keep(original_channels, ratio)

        # Drop trailing output channels of the previous conv ...
        prev_conv.weight.set_(prev_conv.weight.detach()[:n_keep])
        if prev_conv.bias is not None:  # convs built with bias=False
            prev_conv.bias.set_(prev_conv.bias.detach()[:n_keep])
        # ... and the matching input channels of the next conv.
        next_conv.weight.set_(next_conv.weight.detach()[:, :n_keep])

        # Keep the module metadata in sync with the new weight shapes so that
        # repr() and any code reading these attributes see the pruned sizes.
        prev_conv.out_channels = prev_conv.weight.shape[0]
        next_conv.in_channels = next_conv.weight.shape[1] * next_conv.groups

    return pruned


In [None]:
# function to sort the channels from important to non-important
def get_input_channel_importance(weight):
    """Return the L2 norm of the weights attached to each input channel.

    Args:
        weight: Weight tensor whose axis 1 indexes input channels, e.g. a
            conv weight laid out as ``(out, in, kh, kw)``.

    Returns:
        A 1-D tensor with one importance value per input channel.
    """
    # Bring the input-channel axis to the front and collapse everything else,
    # so one vectorised norm replaces a Python loop over the channels.
    per_channel = weight.detach().transpose(0, 1).flatten(start_dim=1)
    return torch.linalg.vector_norm(per_channel, dim=1)


@torch.no_grad()
def apply_channel_sorting(model):
    """Return a copy of ``model`` with conv channels ordered by importance.

    For every pair of consecutive ``Conv2d`` modules (in ``model.modules()``
    order) the input channels of the second conv are ranked by
    ``get_input_channel_importance``, most important first.  The same
    permutation is applied to the output channels (weight and bias) of the
    first conv and to the input channels of the second conv, so both layers
    stay consistent with each other.

    Args:
        model: Network to reorder; it is deep-copied and left untouched.

    Returns:
        The reordered copy.
    """
    model = copy.deepcopy(model)  # do not modify the original model
    all_convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]

    for prev_conv, next_conv in zip(all_convs, all_convs[1:]):
        # importance is always measured on the consumer's input channels
        importance = get_input_channel_importance(next_conv.weight)
        # sorting from large to small
        sort_idx = torch.argsort(importance, descending=True)

        # permute the producer's output channels ...
        prev_conv.weight.copy_(torch.index_select(
            prev_conv.weight.detach(), 0, sort_idx))
        if prev_conv.bias is not None:  # convs built with bias=False
            prev_conv.bias.copy_(torch.index_select(
                prev_conv.bias.detach(), 0, sort_idx))

        # ... and the consumer's input channels with the same order
        next_conv.weight.copy_(torch.index_select(
            next_conv.weight.detach(), 1, sort_idx))

    return model

In [None]:
# Compare pruning with and without importance-based channel sorting.
# The ratio list [0, channel_pruning_ratio] has one entry per pair of
# consecutive convs: the first pair (conv1 -> conv3) is left unpruned and the
# second pair (conv3 -> conv5) loses `channel_pruning_ratio` of its channels.
channel_pruning_ratio = 0.95  # pruned-out ratio, 0.3, 0.5, 0.7, 0.8, 0.9

# Baseline: prune the model as loaded, keeping whichever channels come first.
print(" * Without sorting...")
pruned_model = channel_prune(model, [0, channel_pruning_ratio])
pruned_model_accuracy = evaluate(pruned_model, dataloader['test'])
print(f"pruned model has accuracy={pruned_model_accuracy:.2f}%")


# Sort the channels by importance first so the leading (kept) channels are the
# most important ones.  `pruned_model` is rebound here, so the sorted-and-pruned
# model is the one that any code after this cell sees.
print(" * With sorting...")
sorted_model = apply_channel_sorting(model)
pruned_model = channel_prune(sorted_model, [0, channel_pruning_ratio])
pruned_model_accuracy = evaluate(pruned_model, dataloader['test'])
print(f"pruned model has accuracy={pruned_model_accuracy:.2f}%")


In [None]:
# Fine-tune the pruned model for a few epochs to recover accuracy.
num_finetune_epochs = 5
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# `train` calls `scheduler.step()` after every batch, so the cosine schedule
# length (T_max) must be given in optimizer steps, not epochs.  With
# T_max=num_finetune_epochs the learning rate would swing up and down every
# few iterations instead of annealing once over the whole fine-tuning run.
num_finetune_steps = max(1, num_finetune_epochs * len(dataloader['train']))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_finetune_steps)
criterion = nn.CrossEntropyLoss()

# Track the best test accuracy seen across epochs.
best_accuracy = 0
for epoch in range(num_finetune_epochs):
    train(pruned_model, dataloader['train'], criterion, optimizer, scheduler)
    accuracy = evaluate(pruned_model, dataloader['test'])
    is_best = accuracy > best_accuracy
    if is_best:
        best_accuracy = accuracy
    print(f'Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')

In [None]:
# Show every parameter's name, shape and dtype for the unpruned model ...
print("Before pruning")
for name, param in model.named_parameters():
    print(name, param.shape, param.dtype)

# ... and for the pruned model, to make the size reduction visible.
print("After pruning")
for name, param in pruned_model.named_parameters():
    print(name, param.shape, param.dtype)

# Write each pruned parameter as raw values to "<parameter name>.bin" in the
# current working directory.
print("Saving pruned parameters")
for name, param in pruned_model.named_parameters():
    print("Saving", name)
    values = param.detach().cpu().numpy()
    values.tofile(f"{name}.bin")