<a href="https://colab.research.google.com/github/yating-zh/model_compression/blob/main/MNIST_pruning0507.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [87]:
!pip install torch  torchvision==0.16.0 torchaudio==2.1.0

# An error must downgrade to torch 2.1.0
# AttributeError: 'NoneType' object has no attribute 'startswith' at the SpeedUp step

Collecting torchvision==0.16.0
  Downloading torchvision-0.16.0-cp310-cp310-manylinux1_x86_64.whl (6.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio==2.1.0
  Downloading torchaudio-2.1.0-cp310-cp310-manylinux1_x86_64.whl (3.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m43.3 MB/s[0m eta [36m0:00:00[0m
Collecting torch
  Downloading torch-2.1.0-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m670.2/670.2 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105

In [88]:
!pip install nni

^C


In [1]:
import numpy as np

# import pytorch libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import SGD
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# NNI model-compression tools (pruning + speedup)
from nni.compression.pruning import L1NormPruner
from nni.compression.speedup import ModelSpeedup

print(f'{torch.__version__}')  # confirm which PyTorch build the kernel actually loaded (pinned to 2.1.0 above)

2.1.0+cu121


## 1. Pretrain a model on the MNIST dataset

In [2]:
# Optional: report whether a CUDA GPU is available to this kernel, and if so which one.
cuda_ok = torch.cuda.is_available()
print(f'Cuda Available : {cuda_ok}')
if cuda_ok:
    print(f'GPU - {torch.cuda.get_device_name()}')

Cuda Available : True
GPU - Tesla T4


In [10]:
# Configuration: device, hyperparameters, random seed and the MNIST data pipeline.

# Device configuration: use the GPU when one is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
num_epochs = 3
batch_size = 64
learning_rate = 0.01
seed = 0  # makes weight initialisation and training-set shuffle order identical on every rerun

# Seed PyTorch's default generators before any model is built or any data is shuffled.
torch.manual_seed(seed)

# MNIST dataset (downloaded into ./data on first use); ToTensor turns each image into a float tensor
train_dataset = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor(), download=True)

# Data loaders: shuffle only the training set so test batches keep a fixed order
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Define a convolutional neural network model for MNIST
class TorchModel(nn.Module):
    """LeNet-style CNN for 1x28x28 MNIST images that outputs 10 class logits.

    Shape flow for a 28x28 input: conv1 (5x5) -> 6x24x24, pool1 -> 6x12x12,
    conv2 (5x5) -> 16x8x8, pool2 -> 16x4x4, flattened to 256 features for fc1.
    """

    def __init__(self):
        super(TorchModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
        self.conv2 = nn.Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
        self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 16 channels x 4 x 4 spatial positions after pool2
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        self.relu4 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

    def forward(self, x):
        """Map a batch of images (B, 1, 28, 28) to logits (B, 10)."""
        x = self.relu1(self.conv1(x))
        x = self.pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.pool2(x)
        # Flatten everything except the batch dimension. A hard-coded view(-1, 16 * 4 * 4)
        # breaks once pruning shrinks conv2 (e.g. to 8 channels -> 128 features): it would
        # fold pairs of samples into one row and fc1 would reject the shape.
        x = torch.flatten(x, 1)
        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the network, move its parameters to the selected device, and show the layer summary.
model = TorchModel()
model = model.to(device)  # nn.Module.to returns the same (now moved) module
print(model)


TorchModel(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)


In [11]:
# Save the model weights before compression.
# NOTE(review): run top to bottom, this cell executes before the training cell, so the
# checkpoint written here holds the randomly initialised weights, not the trained baseline.
pre_compression_state = model.state_dict()
torch.save(pre_compression_state, 'original_model.pth')

In [4]:

# Loss and optimizer: cross-entropy over the class logits, plain SGD over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate)

# Function to train the model
def train(model, train_loader, optimizer, criterion):
    """Run one training epoch over ``train_loader``.

    Args:
        model: network to train; it is switched to training mode.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(output, target)``.

    Returns:
        Mean per-sample training loss over the epoch (0.0 if the loader is empty).
    """
    # Send batches to wherever the model's parameters live instead of relying on a
    # global `device` defined in another cell.
    device = next(model.parameters()).device
    model.train()
    total_loss = 0.0
    total_samples = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Assumes criterion averages over the batch (reduction='mean'), so weight by batch size
        # to get a true per-sample mean when the last batch is smaller.
        batch_size = target.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size
    return total_loss / total_samples if total_samples else 0.0

# Function to evaluate the model
def evaluate(model, test_loader):
    """Print and return top-1 accuracy (in percent) of ``model`` on ``test_loader``.

    Args:
        model: network to evaluate; it is switched to eval mode.
        test_loader: iterable of ``(data, target)`` batches.

    Returns:
        Accuracy in percent, so baseline and pruned models can be compared in code.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    # Use the model's own device rather than a global `device` from another cell.
    device = next(model.parameters()).device
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        raise ValueError('test_loader yielded no samples')
    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy:.2f}%')
    return accuracy

# Training and evaluation loop: one training epoch followed by a test-set accuracy report.
for epoch in range(num_epochs):
    train(model, train_loader, optimizer, criterion)
    print(f'Epoch {epoch + 1}/{num_epochs}', end=' - ')
    evaluate(model, test_loader)

# Save the trained, not-yet-compressed weights as the pre-compression baseline. The earlier
# 'original_model.pth' save runs before this cell, so on its own it would hold untrained weights.
torch.save(model.state_dict(), 'original_model.pth')

Accuracy: 57.70%
Accuracy: 91.18%
Accuracy: 94.26%


In [5]:
# Pruning configuration: mask 50% of every Linear and Conv2d layer, except the output
# layer fc3, whose 10 outputs are the class logits.
pruning_rule = {
    'op_types': ['Linear', 'Conv2d'],
    'exclude_op_names': ['fc3'],
    'sparse_ratio': 0.5,
}
config_list = [pruning_rule]

# Apply L1NormPruner to the model with the rule above.
pruner = L1NormPruner(model, config_list)



In [6]:
# Compress the model and collect the generated masks (the returned model is not needed here).
_, masks = pruner.compress()

# Show the sparsity of each weight mask. NNI masks hold 1 for kept weights and 0 for pruned
# ones, so the mask mean is the kept fraction (density); sparsity is the pruned fraction.
for name, mask in masks.items():
    weight_mask = mask['weight']
    sparsity = 1.0 - weight_mask.sum().item() / weight_mask.numel()
    print(name, ' sparsity : ', '{:.2}'.format(sparsity))

conv2  sparsity :  0.5
conv1  sparsity :  0.5
fc1  sparsity :  0.5
fc2  sparsity :  0.5


In [7]:
# The pruner may have wrapped the model's layers; unwrap it before running speedup.
pruner.unwrap_model()

# Speedup replaces masked layers with physically smaller ones (see the pruning_speedup docs).
# The dummy input is only used to trace the model. Its shape (3, 1, 28, 28) means
# 3 samples in the batch, 1 (grayscale) channel, and 28 x 28 pixels, as in MNIST.
dummy_input = torch.rand(3, 1, 28, 28).to(device)
m_speedup = ModelSpeedup(model, dummy_input, masks)
m_speedup.speedup_model()


[2024-05-07 09:24:48] [32mStart to speedup the model...[0m


INFO:nni.compression.speedup.model_speedup:Start to speedup the model...


[2024-05-07 09:24:48] [32mResolve the mask conflict before mask propagate...[0m


INFO:nni.compression.speedup.model_speedup:Resolve the mask conflict before mask propagate...


[2024-05-07 09:24:48] [32mdim0 sparsity: 0.500000[0m


INFO:nni.compression.speedup.mask_conflict:dim0 sparsity: 0.500000


[2024-05-07 09:24:48] [32mdim1 sparsity: 0.000000[0m


INFO:nni.compression.speedup.mask_conflict:dim1 sparsity: 0.000000


0 Filter
[2024-05-07 09:24:48] [32mdim0 sparsity: 0.500000[0m


INFO:nni.compression.speedup.mask_conflict:dim0 sparsity: 0.500000


[2024-05-07 09:24:48] [32mdim1 sparsity: 0.000000[0m


INFO:nni.compression.speedup.mask_conflict:dim1 sparsity: 0.000000


[2024-05-07 09:24:48] [32mInfer module masks...[0m


INFO:nni.compression.speedup.model_speedup:Infer module masks...


[2024-05-07 09:24:48] [32mPropagate original variables[0m


INFO:nni.compression.speedup.model_speedup:Propagate original variables


[2024-05-07 09:24:48] [32mPropagate variables for placeholder: x, output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for placeholder: x, output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_module: relu1, , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_module: relu1, , output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_module: pool1, , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_module: pool1, , output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_module: conv2, weight:  0.5000 bias:  0.5000 , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_module: conv2, weight:  0.5000 bias:  0.5000 , output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_module: relu2, , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_module: relu2, , output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_module: pool2, , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_module: pool2, , output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_method: view, output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_method: view, output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_module: fc1, weight:  0.5000 bias:  0.5000 , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_module: fc1, weight:  0.5000 bias:  0.5000 , output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_module: relu3, , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_module: relu3, , output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_module: fc2, weight:  0.5000 bias:  0.5000 , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_module: fc2, weight:  0.5000 bias:  0.5000 , output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_module: relu4, , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_module: relu4, , output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for call_module: fc3, , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for call_module: fc3, , output mask:  0.0000 


[2024-05-07 09:24:48] [32mPropagate variables for output: output, output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Propagate variables for output: output, output mask:  0.0000 


[2024-05-07 09:24:48] [32mUpdate direct sparsity...[0m


INFO:nni.compression.speedup.model_speedup:Update direct sparsity...


[2024-05-07 09:24:48] [32mUpdate direct mask for placeholder: x, output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for placeholder: x, output mask:  0.0000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_module: relu1, , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_module: relu1, , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_module: pool1, , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_module: pool1, , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_module: conv2, weight:  0.5000 bias:  0.5000 , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_module: conv2, weight:  0.5000 bias:  0.5000 , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_module: relu2, , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_module: relu2, , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_module: pool2, , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_module: pool2, , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_method: view, output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_method: view, output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_module: fc1, weight:  0.5000 bias:  0.5000 , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_module: fc1, weight:  0.5000 bias:  0.5000 , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_module: relu3, , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_module: relu3, , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_module: fc2, weight:  0.5000 bias:  0.5000 , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_module: fc2, weight:  0.5000 bias:  0.5000 , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_module: relu4, , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_module: relu4, , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate direct mask for call_module: fc3, , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for call_module: fc3, , output mask:  0.0000 


[2024-05-07 09:24:48] [32mUpdate direct mask for output: output, output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Update direct mask for output: output, output mask:  0.0000 


[2024-05-07 09:24:48] [32mUpdate indirect sparsity...[0m


INFO:nni.compression.speedup.model_speedup:Update indirect sparsity...


[2024-05-07 09:24:48] [32mUpdate indirect mask for output: output, output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for output: output, output mask:  0.0000 


[2024-05-07 09:24:48] [32mUpdate indirect mask for call_module: fc3, , output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_module: fc3, , output mask:  0.0000 


[2024-05-07 09:24:48] [32mUpdate indirect mask for call_module: relu4, , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_module: relu4, , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate indirect mask for call_module: fc2, weight:  0.7500 bias:  0.5000 , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_module: fc2, weight:  0.7500 bias:  0.5000 , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate indirect mask for call_module: relu3, , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_module: relu3, , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate indirect mask for call_module: fc1, weight:  0.7500 bias:  0.5000 , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_module: fc1, weight:  0.7500 bias:  0.5000 , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate indirect mask for call_method: view, output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_method: view, output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate indirect mask for call_module: pool2, , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_module: pool2, , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate indirect mask for call_module: relu2, , output mask:  0.5498 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_module: relu2, , output mask:  0.5498 


[2024-05-07 09:24:48] [32mUpdate indirect mask for call_module: conv2, weight:  0.7500 bias:  0.5000 , output mask:  0.5498 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_module: conv2, weight:  0.7500 bias:  0.5000 , output mask:  0.5498 


[2024-05-07 09:24:48] [32mUpdate indirect mask for call_module: pool1, , output mask:  0.5000 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_module: pool1, , output mask:  0.5000 


[2024-05-07 09:24:48] [32mUpdate indirect mask for call_module: relu1, , output mask:  0.5486 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_module: relu1, , output mask:  0.5486 


[2024-05-07 09:24:49] [32mUpdate indirect mask for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.5486 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.5486 


[2024-05-07 09:24:49] [32mUpdate indirect mask for placeholder: x, output mask:  0.0000 [0m


INFO:nni.compression.speedup.model_speedup:Update indirect mask for placeholder: x, output mask:  0.0000 


[2024-05-07 09:24:49] [32mResolve the mask conflict after mask propagate...[0m


INFO:nni.compression.speedup.model_speedup:Resolve the mask conflict after mask propagate...


[2024-05-07 09:24:49] [32mdim0 sparsity: 0.500000[0m


INFO:nni.compression.speedup.mask_conflict:dim0 sparsity: 0.500000


[2024-05-07 09:24:49] [32mdim1 sparsity: 0.428571[0m


INFO:nni.compression.speedup.mask_conflict:dim1 sparsity: 0.428571






0 Filter
[2024-05-07 09:24:49] [32mdim0 sparsity: 0.500000[0m


INFO:nni.compression.speedup.mask_conflict:dim0 sparsity: 0.500000


[2024-05-07 09:24:49] [32mdim1 sparsity: 0.428571[0m


INFO:nni.compression.speedup.mask_conflict:dim1 sparsity: 0.428571






[2024-05-07 09:24:49] [32mReplace compressed modules...[0m


INFO:nni.compression.speedup.model_speedup:Replace compressed modules...


[2024-05-07 09:24:49] [32mreplace module (name: conv1, op_type: Conv2d)[0m


INFO:nni.compression.speedup.replacer:replace module (name: conv1, op_type: Conv2d)


[2024-05-07 09:24:49] [32mreplace conv2d with in_channels: 1, out_channels: 3[0m


INFO:nni.compression.speedup.replacement:replace conv2d with in_channels: 1, out_channels: 3


[2024-05-07 09:24:49] [32mreplace module (name: relu1, op_type: ReLU)[0m


INFO:nni.compression.speedup.replacer:replace module (name: relu1, op_type: ReLU)


[2024-05-07 09:24:49] [32mreplace module (name: pool1, op_type: MaxPool2d)[0m


INFO:nni.compression.speedup.replacer:replace module (name: pool1, op_type: MaxPool2d)


[2024-05-07 09:24:49] [32mreplace module (name: conv2, op_type: Conv2d)[0m


INFO:nni.compression.speedup.replacer:replace module (name: conv2, op_type: Conv2d)


[2024-05-07 09:24:49] [32mreplace conv2d with in_channels: 3, out_channels: 8[0m


INFO:nni.compression.speedup.replacement:replace conv2d with in_channels: 3, out_channels: 8


[2024-05-07 09:24:49] [32mreplace module (name: relu2, op_type: ReLU)[0m


INFO:nni.compression.speedup.replacer:replace module (name: relu2, op_type: ReLU)


[2024-05-07 09:24:49] [32mreplace module (name: pool2, op_type: MaxPool2d)[0m


INFO:nni.compression.speedup.replacer:replace module (name: pool2, op_type: MaxPool2d)


[2024-05-07 09:24:49] [32mreplace module (name: fc1, op_type: Linear)[0m


INFO:nni.compression.speedup.replacer:replace module (name: fc1, op_type: Linear)


[2024-05-07 09:24:49] [32mreplace linear with new in_features: 128, out_features: 60[0m


INFO:nni.compression.speedup.replacement:replace linear with new in_features: 128, out_features: 60


[2024-05-07 09:24:49] [32mreplace module (name: relu3, op_type: ReLU)[0m


INFO:nni.compression.speedup.replacer:replace module (name: relu3, op_type: ReLU)


[2024-05-07 09:24:49] [32mreplace module (name: fc2, op_type: Linear)[0m


INFO:nni.compression.speedup.replacer:replace module (name: fc2, op_type: Linear)


[2024-05-07 09:24:49] [32mreplace linear with new in_features: 60, out_features: 42[0m


INFO:nni.compression.speedup.replacement:replace linear with new in_features: 60, out_features: 42


[2024-05-07 09:24:49] [32mreplace module (name: relu4, op_type: ReLU)[0m


INFO:nni.compression.speedup.replacer:replace module (name: relu4, op_type: ReLU)


[2024-05-07 09:24:49] [32mreplace module (name: fc3, op_type: Linear)[0m


INFO:nni.compression.speedup.replacer:replace module (name: fc3, op_type: Linear)


[2024-05-07 09:24:49] [32mreplace linear with new in_features: 42, out_features: 10[0m


INFO:nni.compression.speedup.replacement:replace linear with new in_features: 42, out_features: 10


[2024-05-07 09:24:49] [32mSpeedup done.[0m


INFO:nni.compression.speedup.model_speedup:Speedup done.


TorchModel(
  (conv1): Conv2d(1, 3, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=128, out_features=60, bias=True)
  (fc2): Linear(in_features=60, out_features=42, bias=True)
  (fc3): Linear(in_features=42, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

In [8]:
# This is the pruned model. After pruning:
# - the number and types of layers are unchanged;
# - but because weights were pruned, the layers have fewer output channels/features, so the
#   module no longer matches a freshly constructed TorchModel(). Reloading 'compressed_model.pth'
#   later therefore requires a TorchModel definition with these reduced layer sizes.
pruned_summary = str(model)
print(pruned_summary)

TorchModel(
  (conv1): Conv2d(1, 3, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=128, out_features=60, bias=True)
  (fc2): Linear(in_features=60, out_features=42, bias=True)
  (fc3): Linear(in_features=42, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)


In [9]:
# Save the model after compression (the state dict carries the reduced layer shapes).
compressed_state = model.state_dict()
torch.save(compressed_state, 'compressed_model.pth')

## 2. Prune the model with the NNI library (compression)