<a href="https://colab.research.google.com/github/yating-zh/model_compression/blob/main/Copy_of_pruning_quick_start.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Pruning Quickstart

Model pruning reduces model size and computation by shrinking the weights or the intermediate states of a network.
There are three common practices for pruning a DNN model:

1. Pre-training a model -> Pruning the model -> Fine-tuning the pruned model
2. Pruning a model during training (i.e., pruning-aware training) -> Fine-tuning the pruned model
3. Pruning a model -> Training the pruned model from scratch

NNI supports all of the above pruning practices by working on the key pruning stage.
Follow this tutorial for a quick look at how to use NNI to prune a model in a common practice.


## Preparation

In this tutorial, we use a simple model and pre-train it on the MNIST dataset.
If you are familiar with defining and training a model in PyTorch, you can skip directly to the Pruning Model section.



In [None]:
# Downgrade PyTorch to 2.1.0; otherwise ModelSpeedup does not work.
!pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121



In [None]:
# install nni (Neural Network Intelligence)
! pip install nni



In [None]:
import torch
import torch.nn.functional as F
from torch.optim import SGD

from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device

# define the model
model = TorchModel().to(device)

# show the model structure, note that pruner will wrap the model layer.
print(model)

TorchModel(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)


In [None]:
# save the model before compression
torch.save(model.state_dict(), 'original_model.pth')

In [None]:
# define the optimizer and criterion for pre-training

optimizer = SGD(model.parameters(), 1e-2) # 0.01 is the learning rate
criterion = F.nll_loss


In [None]:
# Start timing
import time
start_time = time.time()

In [None]:
# pre-train and evaluate the model on MNIST dataset
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)

Average test loss: 0.5561, Accuracy: 8485/10000 (85%)
Average test loss: 0.2397, Accuracy: 9281/10000 (93%)
Average test loss: 0.1655, Accuracy: 9495/10000 (95%)


In [None]:
# End timing
end_time = time.time()
total_time_original = end_time - start_time


## Pruning Model

Use `L1NormPruner` to prune the model and generate the masks.
Usually, a pruner requires the original model and a `config_list` as its inputs.
For details on how to write a `config_list`, refer to the compression config specification in the NNI documentation.

The following `config_list` prunes every layer whose type is `Linear` or `Conv2d`,
except the layer named `fc3`, which is listed in `exclude_op_names`.
The final sparsity ratio for each pruned layer is 50%.



In [None]:
config_list = [{
    'op_types': ['Linear', 'Conv2d'],
    'exclude_op_names': ['fc3'],
    'sparse_ratio': 0.5
}]
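
As a rough illustration of the matching rule described above, the sketch below filters layer names and types copied from the printed model structure. The helper `select_layers` is hypothetical and only mimics the rule "type is listed and name is not excluded"; it is not NNI's implementation.

```python
# Layer names and types copied from the printed TorchModel structure
# (only a few of the ReLU/pooling layers are listed, for brevity).
layers = [
    ("conv1", "Conv2d"), ("conv2", "Conv2d"),
    ("fc1", "Linear"), ("fc2", "Linear"), ("fc3", "Linear"),
    ("relu1", "ReLU"), ("pool1", "MaxPool2d"),
]

config = {
    "op_types": ["Linear", "Conv2d"],
    "exclude_op_names": ["fc3"],
    "sparse_ratio": 0.5,
}


def select_layers(layers, config):
    """Hypothetical helper: keep layers whose type is listed and whose name is not excluded."""
    return [
        name
        for name, op_type in layers
        if op_type in config["op_types"] and name not in config["exclude_op_names"]
    ]


selected = select_layers(layers, config)
print(selected)  # ['conv1', 'conv2', 'fc1', 'fc2']
```

These are the same four layers that appear wrapped once the pruner is created.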

Pruners usually require `model` and `config_list` as input arguments.



In [None]:
# Create the pruner. It wraps each configured layer so that masks can be applied to it;
# in other words, the pruner acts as a wrapper around the model.
from nni.compression.pruning import L1NormPruner
pruner = L1NormPruner(model, config_list)  # pruner ~= wrapper


# Show the wrapped model structure: `ModuleWrapper` has wrapped the layers configured in config_list.
print(model)

TorchModel(
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (_nni_wrapper): ModuleWrapper(module=Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1)), module_name=conv1)
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (_nni_wrapper): ModuleWrapper(module=Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)), module_name=conv2)
  )
  (fc1): Linear(
    in_features=256, out_features=120, bias=True
    (_nni_wrapper): ModuleWrapper(module=Linear(in_features=256, out_features=120, bias=True), module_name=fc1)
  )
  (fc2): Linear(
    in_features=120, out_features=84, bias=True
    (_nni_wrapper): ModuleWrapper(module=Linear(in_features=120, out_features=84, bias=True), module_name=fc2)
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), str

In [None]:
# compress the model and generate the masks
_, masks = pruner.compress()
# show the sparsity of each mask (fraction of zeroed weights; a mask value of 1 means "keep")
for name, mask in masks.items():
    print(name, ' sparsity : ', '{:.2}'.format(1 - mask['weight'].sum().item() / mask['weight'].numel()))

fc1  sparsity :  0.5
conv1  sparsity :  0.5
fc2  sparsity :  0.5
conv2  sparsity :  0.5
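
To make the 50% figure more concrete, here is a minimal sketch of L1-norm filter selection in plain PyTorch. It assumes masking is done per output filter and uses a random tensor as a stand-in for `conv1`'s weight; it illustrates the idea only and is not how `L1NormPruner` is implemented internally.

```python
import torch

torch.manual_seed(0)

# Stand-in for conv1's weight: 6 output filters, 1 input channel, 5x5 kernels.
weight = torch.randn(6, 1, 5, 5)
sparse_ratio = 0.5

# L1 norm of each output filter (sum of absolute values over the other dims).
filter_norms = weight.abs().sum(dim=(1, 2, 3))

# Select the filters with the smallest norms for removal.
num_pruned = int(weight.shape[0] * sparse_ratio)
pruned_idx = torch.topk(filter_norms, num_pruned, largest=False).indices

# Build a 0/1 mask with the same shape as the weight: 0 = pruned, 1 = kept.
mask = torch.ones_like(weight)
mask[pruned_idx] = 0.0

sparsity = 1 - mask.sum().item() / mask.numel()
print(f"pruned filters: {sorted(pruned_idx.tolist())}, sparsity: {sparsity:.2f}")
```

Zeroing whole filters rather than individual weights is what later allows the layer to be replaced by a smaller dense one.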


Speed up the original model with the masks. Note that `ModelSpeedup` requires an unwrapped model.
The model becomes smaller after speedup
and reaches a higher sparsity ratio, because `ModelSpeedup` propagates the masks across layers.
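
The effect of mask propagation can be checked with simple arithmetic. The sketch below uses the layer sizes before and after speedup as they appear in the model printouts of this tutorial; the helper `weight_sparsity` is a hypothetical name introduced for illustration. When a layer loses half of its outputs and the preceding layer also loses half of its outputs, only a quarter of the weight entries remain.

```python
# (in, out) sizes before and after speedup, taken from the printed model structures.
# For the conv layers the 5x5 kernel size is unchanged, so it cancels out of the ratio.
shapes = {
    "conv1": ((1, 6), (1, 3)),
    "conv2": ((6, 16), (3, 8)),
    "fc1": ((256, 120), (128, 60)),
    "fc2": ((120, 84), (60, 42)),
}


def weight_sparsity(before, after):
    """Fraction of weight entries removed when the input and output sizes shrink."""
    (in_before, out_before), (in_after, out_after) = before, after
    return 1 - (in_after * out_after) / (in_before * out_before)


propagated = {name: weight_sparsity(b, a) for name, (b, a) in shapes.items()}
for name, value in propagated.items():
    print(f"{name}: {value:.4f}")
```

`conv1` stays at 0.5 because its single input channel cannot be pruned, while the other three layers reach 0.75, which agrees with the weight sparsity values reported in the speedup log.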



In [None]:
# The model must be unwrapped before speedup.
pruner.unwrap_model()

# Speed up the model. For more information, see the pruning speedup documentation in NNI.
from nni.compression.speedup import ModelSpeedup

# The dummy input has shape (3, 1, 28, 28):
#   3: batch size, i.e. three images are processed at once.
#   1: number of channels per image. MNIST is grayscale, so this is 1; an RGB image would have 3.
#   28, 28: height and width of each MNIST image in pixels.
m_speedup = ModelSpeedup(model, torch.rand(3, 1, 28, 28).to(device), masks)
m_speedup.speedup_model()


[2024-05-06 02:21:02] Start to speedup the model...
[2024-05-06 02:21:02] Resolve the mask conflict before mask propagate...
[2024-05-06 02:21:02] dim0 sparsity: 0.500000
[2024-05-06 02:21:02] dim1 sparsity: 0.000000
0 Filter
[2024-05-06 02:21:02] dim0 sparsity: 0.500000
[2024-05-06 02:21:02] dim1 sparsity: 0.000000
[2024-05-06 02:21:02] Infer module masks...
[2024-05-06 02:21:02] Propagate original variables
[2024-05-06 02:21:02] Propagate variables for placeholder: x, output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_module: relu1, , output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_module: pool1, , output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_module: conv2, weight:  0.5000 bias:  0.5000 , output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_module: relu2, , output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_module: pool2, , output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_function: flatten, output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_module: fc1, weight:  0.5000 bias:  0.5000 , output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_module: relu3, , output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_module: fc2, weight:  0.5000 bias:  0.5000 , output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_module: relu4, , output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_module: fc3, , output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for call_function: log_softmax, output mask:  0.0000
[2024-05-06 02:21:02] Propagate variables for output: output, output mask:  0.0000
[2024-05-06 02:21:02] Update direct sparsity...
[2024-05-06 02:21:02] Update direct mask for placeholder: x, output mask:  0.0000
[2024-05-06 02:21:02] Update direct mask for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.5000
[2024-05-06 02:21:02] Update direct mask for call_module: relu1, , output mask:  0.5000
[2024-05-06 02:21:02] Update direct mask for call_module: pool1, , output mask:  0.5000
[2024-05-06 02:21:02] Update direct mask for call_module: conv2, weight:  0.5000 bias:  0.5000 , output mask:  0.5000
[2024-05-06 02:21:02] Update direct mask for call_module: relu2, , output mask:  0.5000
[2024-05-06 02:21:02] Update direct mask for call_module: pool2, , output mask:  0.5000
[2024-05-06 02:21:02] Update direct mask for call_function: flatten, output mask:  0.5000
[2024-05-06 02:21:03] Update direct mask for call_module: fc1, weight:  0.5000 bias:  0.5000 , output mask:  0.5000
[2024-05-06 02:21:03] Update direct mask for call_module: relu3, , output mask:  0.5000
[2024-05-06 02:21:03] Update direct mask for call_module: fc2, weight:  0.5000 bias:  0.5000 , output mask:  0.5000
[2024-05-06 02:21:03] Update direct mask for call_module: relu4, , output mask:  0.5000
[2024-05-06 02:21:03] Update direct mask for call_module: fc3, , output mask:  0.0000
[2024-05-06 02:21:03] Update direct mask for call_function: log_softmax, output mask:  0.0000
[2024-05-06 02:21:03] Update direct mask for output: output, output mask:  0.0000
[2024-05-06 02:21:03] Update indirect sparsity...
[2024-05-06 02:21:03] Update indirect mask for output: output, output mask:  0.0000
[2024-05-06 02:21:03] Update indirect mask for call_function: log_softmax, output mask:  0.0000
[2024-05-06 02:21:03] Update indirect mask for call_module: fc3, , output mask:  0.0000
[2024-05-06 02:21:03] Update indirect mask for call_module: relu4, , output mask:  0.5000
[2024-05-06 02:21:03] Update indirect mask for call_module: fc2, weight:  0.7500 bias:  0.5000 , output mask:  0.5000
[2024-05-06 02:21:03] Update indirect mask for call_module: relu3, , output mask:  0.5000
[2024-05-06 02:21:03] Update indirect mask for call_module: fc1, weight:  0.7500 bias:  0.5000 , output mask:  0.5000
[2024-05-06 02:21:03] Update indirect mask for call_function: flatten, output mask:  0.5000
[2024-05-06 02:21:03] Update indirect mask for call_module: pool2, , output mask:  0.5000
[2024-05-06 02:21:03] Update indirect mask for call_module: relu2, , output mask:  0.5391
[2024-05-06 02:21:03] Update indirect mask for call_module: conv2, weight:  0.7500 bias:  0.5000 , output mask:  0.5391
[2024-05-06 02:21:03] Update indirect mask for call_module: pool1, , output mask:  0.5000
[2024-05-06 02:21:03] Update indirect mask for call_module: relu1, , output mask:  0.5475
[2024-05-06 02:21:03] Update indirect mask for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.5475
[2024-05-06 02:21:03] Update indirect mask for placeholder: x, output mask:  0.0000
[2024-05-06 02:21:03] Resolve the mask conflict after mask propagate...
[2024-05-06 02:21:03] dim0 sparsity: 0.500000
[2024-05-06 02:21:03] dim1 sparsity: 0.428571
0 Filter
[2024-05-06 02:21:03] dim0 sparsity: 0.500000
[2024-05-06 02:21:03] dim1 sparsity: 0.428571
[2024-05-06 02:21:03] Replace compressed modules...
[2024-05-06 02:21:03] replace module (name: conv1, op_type: Conv2d)
[2024-05-06 02:21:03] replace conv2d with in_channels: 1, out_channels: 3
[2024-05-06 02:21:03] replace module (name: relu1, op_type: ReLU)
[2024-05-06 02:21:03] replace module (name: pool1, op_type: MaxPool2d)
[2024-05-06 02:21:03] replace module (name: conv2, op_type: Conv2d)
[2024-05-06 02:21:03] replace conv2d with in_channels: 3, out_channels: 8
[2024-05-06 02:21:03] replace module (name: relu2, op_type: ReLU)
[2024-05-06 02:21:03] replace module (name: pool2, op_type: MaxPool2d)
[2024-05-06 02:21:03] replace module (name: fc1, op_type: Linear)
[2024-05-06 02:21:03] replace linear with new in_features: 128, out_features: 60
[2024-05-06 02:21:03] replace module (name: relu3, op_type: ReLU)
[2024-05-06 02:21:03] replace module (name: fc2, op_type: Linear)
[2024-05-06 02:21:03] replace linear with new in_features: 60, out_features: 42
[2024-05-06 02:21:03] replace module (name: relu4, op_type: ReLU)
[2024-05-06 02:21:03] replace module (name: fc3, op_type: Linear)
[2024-05-06 02:21:03] replace linear with new in_features: 42, out_features: 10
[2024-05-06 02:21:03] Speedup done.


TorchModel(
  (conv1): Conv2d(1, 3, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=128, out_features=60, bias=True)
  (fc2): Linear(in_features=60, out_features=42, bias=True)
  (fc3): Linear(in_features=42, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)

After speedup, the model is physically smaller: the masked layers have been replaced by smaller dense layers.



In [None]:
print(model)
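# This is the pruned model. After pruning:
# the number of layers and the layer types are unchanged,
# but because some weights were pruned, each layer has fewer output weights. As a result, the architecture
# no longer matches TorchModel(), so TorchModel() has to be redefined before the final evaluation.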

TorchModel(
  (conv1): Conv2d(1, 3, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=128, out_features=60, bias=True)
  (fc2): Linear(in_features=60, out_features=42, bias=True)
  (fc3): Linear(in_features=42, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)


In [None]:
# save the model after compression
torch.save(model.state_dict(), 'compressed_model.pth')

## Fine-tuning Compacted Model
Note that if the model has been sped up, you need to initialize a new optimizer for fine-tuning,
because speedup replaces the masked large layers with small dense ones.
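
The reason a fresh optimizer is needed can be shown with a small stand-alone sketch. It uses a toy model rather than the tutorial's `TorchModel`, mimics speedup by manually swapping layers, and introduces a hypothetical helper `covers_all_params`; none of this is NNI code.

```python
import torch.nn as nn
from torch.optim import SGD

# A toy model; the names here are illustrative only.
toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
old_optimizer = SGD(toy.parameters(), lr=1e-2)

# Mimic what speedup does: swap layers for smaller dense ones.
toy[0] = nn.Linear(8, 2)
toy[2] = nn.Linear(2, 2)


def covers_all_params(optimizer, model):
    """Return True if every parameter of the model is tracked by the optimizer."""
    tracked = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return all(id(p) in tracked for p in model.parameters())


print(covers_all_params(old_optimizer, toy))  # False: it still points at the old tensors

new_optimizer = SGD(toy.parameters(), lr=1e-2)
print(covers_all_params(new_optimizer, toy))  # True
```

An optimizer created before the swap would keep updating tensors that are no longer part of the model, so fine-tuning would have no effect on the replaced layers.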



In [None]:
# Start timing
start_time = time.time()

In [None]:
optimizer = SGD(model.parameters(), 1e-2)
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)

Average test loss: 0.2120, Accuracy: 9388/10000 (94%)
Average test loss: 0.1459, Accuracy: 9568/10000 (96%)
Average test loss: 0.1438, Accuracy: 9558/10000 (96%)


In [None]:
# End timing
end_time = time.time()
total_time_compressed = end_time - start_time


##  Validation: Model compression
1. the model size
2. execution time


1. the model size

In [None]:
import os

size_original = os.path.getsize('original_model.pth')
size_compressed = os.path.getsize('compressed_model.pth')
print(f'Original Model Size: {size_original} bytes')
print(f'Compressed Model Size: {size_compressed} bytes')
print(f'Reduction in Size: {size_original - size_compressed} bytes')


Original Model Size: 181466 bytes
Compressed Model Size: 49334 bytes
Reduction in Size: 132132 bytes
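
As a sanity check on these file sizes, the parameter counts can be derived from the layer shapes printed before and after speedup. The sketch below assumes 32-bit floats (the PyTorch default) and 5x5 convolution kernels, as shown in the model printouts; the helper names are introduced here for illustration.

```python
def conv_params(in_channels, out_channels, kernel=5):
    """Weights plus biases of a Conv2d layer with a square kernel."""
    return in_channels * out_channels * kernel * kernel + out_channels


def linear_params(in_features, out_features):
    """Weights plus biases of a Linear layer."""
    return in_features * out_features + out_features


original = (conv_params(1, 6) + conv_params(6, 16)
            + linear_params(256, 120) + linear_params(120, 84) + linear_params(84, 10))
compressed = (conv_params(1, 3) + conv_params(3, 8)
              + linear_params(128, 60) + linear_params(60, 42) + linear_params(42, 10))

BYTES_PER_FLOAT32 = 4
print(original, original * BYTES_PER_FLOAT32)      # 44426 177704
print(compressed, compressed * BYTES_PER_FLOAT32)  # 11418 45672
```

Both estimates fall a few kilobytes below the measured file sizes, which is consistent with the checkpoint files carrying some serialization overhead on top of the raw tensors.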


2. execution time

In [None]:
# Note: these timings cover three epochs of training plus evaluation, not inference alone.
training_time_reduction = total_time_original - total_time_compressed
print(f'Original Model Execution Time: {total_time_original} (s)')
print(f'Compressed Model Execution Time: {total_time_compressed} (s)')
print(f'Reduction in Execution Time: {training_time_reduction} (s)')

Original Model Execution Time: 91.71007585525513 (s)
Compressed Model Execution Time: 75.06762886047363 (s)
Reduction in Execution Time: 16.642446994781494 (s)


In [None]:
# execution time

import time

# Function to measure inference time
def measure_inference_time(model, data_loader, device):
    model.eval()  # Set the model to evaluation mode
    start_time = time.time()
    with torch.no_grad():
        for data, _ in data_loader:
            data = data.to(device)
            _ = model(data)
    end_time = time.time()
    return end_time - start_time

# Assuming data_loader is defined and contains the MNIST test dataset
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

# Prepare DataLoader for performance test
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)
data_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

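# Measure inference time before compression.
# `model` was pruned in place by ModelSpeedup, so rebuild the original network from the saved weights.
original_model = TorchModel().to(device)
original_model.load_state_dict(torch.load('original_model.pth'))
time_original = measure_inference_time(original_model, data_loader, device)

# The compressed weights cannot be loaded into the default TorchModel() because the layer shapes differ;
# a resized model is defined in the following cells.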



In [None]:
# Define the model parameters as a dictionary
model_params = {
    "conv1_out_channels": 3,  # Adjusted output channels after compression
    "conv2_out_channels": 8,
    "fc1_out_features": 60,
    "fc2_out_features": 42,
    "fc3_out_features": 10
}

In [None]:
import torch.nn as nn

class TorchModel(nn.Module):
    def __init__(self, conv1_out_channels=6, conv2_out_channels=16, fc1_out_features=120, fc2_out_features=84, fc3_out_features=10):
        super(TorchModel, self).__init__()
        self.conv1 = nn.Conv2d(1, conv1_out_channels, kernel_size=(5, 5), stride=(1, 1))
        self.conv2 = nn.Conv2d(conv1_out_channels, conv2_out_channels, kernel_size=(5, 5), stride=(1, 1))
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Adjust input features based on the output from the last pool2 layer
        self.fc1 = nn.Linear(conv2_out_channels * 4 * 4, fc1_out_features)
        self.fc2 = nn.Linear(fc1_out_features, fc2_out_features)
        self.fc3 = nn.Linear(fc2_out_features, fc3_out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.relu(self.conv2(x))
        x = self.pool2(x)
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
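        x = self.fc3(x)
        # The original network ends with log_softmax (see the speedup log), so apply it here as well.
        return torch.log_softmax(x, dim=1)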



compressed_model = TorchModel(
    conv1_out_channels=model_params['conv1_out_channels'],
    conv2_out_channels=model_params['conv2_out_channels'],
    fc1_out_features=model_params['fc1_out_features'],
    fc2_out_features=model_params['fc2_out_features'],
    fc3_out_features=model_params['fc3_out_features']
).to(device)


compressed_model.load_state_dict(torch.load('compressed_model.pth'))


<All keys matched successfully>

In [None]:

# Measure inference time after compression
time_compressed = measure_inference_time(compressed_model, data_loader, device)

print(f'Original Inference Time: {time_original} seconds')
print(f'Compressed Inference Time: {time_compressed} seconds')


Original Inference Time: 2.6704070568084717 seconds
Compressed Inference Time: 3.768873929977417 seconds
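
Wall-clock numbers like these are sensitive to how they are measured: the loop above also times data loading, and on a GPU it does not wait for queued kernels to finish. The following is a minimal sketch of a more controlled measurement, assuming a fixed in-memory batch, a few untimed warm-up passes, and explicit synchronization. The function `time_forward` and the stand-in model are hypothetical and not part of the tutorial's code.

```python
import time

import torch
import torch.nn as nn


def time_forward(model, batch, repeats=20, warmup=3):
    """Average seconds per forward pass on a fixed in-memory batch."""
    model.eval()
    use_cuda = batch.is_cuda
    with torch.no_grad():
        for _ in range(warmup):       # warm-up passes are not timed
            model(batch)
        if use_cuda:
            torch.cuda.synchronize()  # wait for queued GPU kernels
        start = time.perf_counter()
        for _ in range(repeats):
            model(batch)
        if use_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return elapsed / repeats


# Stand-in model and random MNIST-shaped batch, for illustration only.
demo_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
demo_batch = torch.rand(100, 1, 28, 28)
avg_seconds = time_forward(demo_model, demo_batch)
print(f"{avg_seconds * 1e3:.3f} ms per batch")
```

To compare the two networks, the same function could be called once with the original model and once with the compressed one, using the same batch on the same device.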
