# Customize Your Own Pruners

Torch-Pruning is an extensible tool that lets you build your own pruners with customized importance criteria and pruning schemes. For instance, you can use Torch-Pruning to implement the [Slimming pruner](https://arxiv.org/abs/1708.06519), which uses the scaling parameters in Batch Normalization (BN) layers to identify and remove unimportant channels.

In [1]:
import warnings
warnings.filterwarnings('ignore')
import sys, os
sys.path.append(os.path.abspath("../"))

import torch
import torch.nn as nn
from torchvision.models import resnet18
import torch_pruning as tp

### 1. Pruner Definition

Slimming Pruner leverages the scaling factor in Batch Normalization (BN) layers to determine the importance score of different channels. This technique follows a "training-pruning-fine-tuning" paradigm, which involves sparse training of the original model. In Torch-Pruning, the base class ``tp.pruner.MetaPruner`` provides a convenient ``.regularize(model)`` method for sparse training. Our first task is to implement such an interface to enable efficient regularization of BN parameters.

In [2]:
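class MySlimmingPruner(tp.pruner.MetaPruner):
    def regularize(self, model, reg):
        # Add the L1 (Lasso) subgradient to the gradient of every affine BN scale.
        for m in model.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and m.affine:
                if m.weight.grad is None:  # skip layers that received no gradient
                    continue
                m.weight.grad.data.add_(reg * torch.sign(m.weight.data))  # Lasso for sparsity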
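
To see why adding `reg * sign(weight)` to the gradient acts as a Lasso penalty, the following minimal sketch compares it with explicitly adding `reg * |weight|.sum()` to the loss. It uses plain PyTorch only; the helper `bn_grad`, the toy loss, and the value of `reg` are illustrative assumptions and not part of Torch-Pruning.

```python
import torch
import torch.nn as nn

reg = 1e-2  # hypothetical regularization strength


def bn_grad(add_l1_to_loss: bool) -> torch.Tensor:
    """Return the BN scale gradient with the L1 term applied in one of two ways."""
    torch.manual_seed(0)  # same input for both variants
    bn = nn.BatchNorm2d(4)
    with torch.no_grad():
        bn.weight.copy_(torch.tensor([0.5, -0.3, 1.2, -0.8]))
    x = torch.randn(8, 4, 5, 5)
    loss = bn(x).pow(2).mean()  # toy task loss
    if add_l1_to_loss:
        # Variant A: put the L1 penalty into the loss and let autograd handle it.
        loss = loss + reg * bn.weight.abs().sum()
    loss.backward()
    grad = bn.weight.grad.clone()
    if not add_l1_to_loss:
        # Variant B: patch the gradient after backward(), as regularize() does.
        grad.add_(reg * torch.sign(bn.weight.detach()))
    return grad


grad_from_loss = bn_grad(add_l1_to_loss=True)
grad_from_sign = bn_grad(add_l1_to_loss=False)
print(torch.allclose(grad_from_loss, grad_from_sign, atol=1e-6))
```

Both variants should produce the same gradient for nonzero scales. Patching the gradient directly keeps the training loss unchanged, so the regularizer can be switched on without touching the loss code.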

### 2. Importance function
Next, we need an importance criterion for slimming that compares the magnitudes of the BN scaling parameters. In Torch-Pruning, an importance criterion is a callable function or object that accepts a ``tp.PruningGroup`` as input. A ``tp.PruningGroup`` records all coupled layers together with their pruning indices. We can iterate over the group to build our own importance function as follows:

In [6]:
class MySlimmingImportance(tp.importance.Importance):
    def __call__(self, group, **kwargs):
        # Note that a group may contain multiple BN layers.
        # We store layer-wise scores in a list and then reduce them to get the final result.
        group_imp = [] # (num_bns, num_channels)
        # 1. iterate over the group to estimate importance
        for dep, idxs in group:
            layer = dep.target.module # get the target module
            prune_fn = dep.handler    # get the pruning function of the target module, unused in this example
            if isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and layer.affine:
                local_imp = torch.abs(layer.weight.data)
                group_imp.append(local_imp)
        if len(group_imp)==0: return None # return None if the group contains no BN layer
        # 2. reduce the group importance to a 1-D score vector. Here we use the average score across layers.
        group_imp = torch.stack(group_imp, dim=0).mean(dim=0)
        return group_imp # (num_channels, )

# You can implement any importance function, as long as it maps a group to a 1-D score vector.
class RandomImportance(tp.importance.Importance):
    @torch.no_grad()
    def __call__(self, group, **kwargs):
        _, idxs = group[0]
        return torch.rand(len(idxs))
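
The reduction step can be checked in isolation. The sketch below uses two made-up BN scale vectors (hypothetical values, not taken from any real model) to show how layer-wise magnitudes are averaged into one score per channel and how the lowest-scoring channels could then be picked as pruning candidates. The selection by `argsort` is an illustration of the idea, not necessarily how the pruner selects channels internally.

```python
import torch

# Hypothetical BN scaling factors of two coupled layers with 4 channels each.
bn_a = torch.tensor([0.9, -0.1, 0.5, 0.3])
bn_b = torch.tensor([0.7, 0.3, -0.5, 0.1])

# Layer-wise importance: magnitude of each scaling factor.
layer_scores = [bn_a.abs(), bn_b.abs()]

# Reduce (num_layers, num_channels) -> (num_channels,) by averaging over layers.
group_scores = torch.stack(layer_scores, dim=0).mean(dim=0)

# Treat the lowest-scoring half of the channels as pruning candidates.
n_pruned = 2
pruning_idxs = torch.argsort(group_scores)[:n_pruned].tolist()

print(group_scores)
print(sorted(pruning_idxs))
```

Averaging treats every coupled layer equally; other reductions such as `max` or `sum` over `dim=0` would fit the same interface, since the only requirement is a 1-D vector with one score per channel.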

### 3. Pruning
Now let's use the customized pruner to slim a ResNet-18 model.

In [7]:
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

In [8]:
# 0. importance criterion 
imp = MySlimmingImportance()

# 1. ignore some layers that should not be pruned, e.g., the final classifier layer.
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

# 2. Pruner initialization
iterative_steps = 5 # You can prune your model to the target pruning ratio iteratively.
pruner = MySlimmingPruner(
    model, 
    example_inputs, 
    global_pruning=False, # If False, a uniform pruning ratio will be assigned to different layers.
    importance=imp, # importance criterion for parameter selection
    iterative_steps=iterative_steps, # the number of iterations to achieve target pruning ratio
    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)
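
To get a feel for what `iterative_steps=5` and `pruning_ratio=0.5` imply, here is a small pure-Python sketch. It assumes the per-step target ratios are spaced linearly between 0 and the final ratio and that the number of kept channels is rounded down. Both are assumptions made for illustration, and the helper names are hypothetical. The resulting widths for a 64-channel layer are consistent with the 57, 44, and 32 channels visible in the printed models further down.

```python
from typing import List


def linear_schedule(pruning_ratio: float, steps: int) -> List[float]:
    """Assumed schedule: cumulative target ratio after each pruning step."""
    return [pruning_ratio * (i + 1) / steps for i in range(steps)]


def remaining_channels(channels: int, ratio: float) -> int:
    """Channels kept at a given cumulative ratio (rounded down, by assumption)."""
    return int(channels * (1 - ratio))


ratios = linear_schedule(pruning_ratio=0.5, steps=5)
widths = [remaining_channels(64, r) for r in ratios]
print(widths)  # -> [57, 51, 44, 38, 32]
```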

Sparse training with ``pruner.regularize``. Remember to call it after ``loss.backward()`` and before ``optimizer.step()``.

In [6]:
# Training
for _ in range(100):
    pass
    # optimizer.zero_grad()
    # ...
    # loss.backward()
    # pruner.regularize(model, reg=1e-5)
    # optimizer.step()
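
The ordering of the calls can be tried out on a toy model. In the sketch below, `l1_regularize_bn` is a hypothetical stand-in for `pruner.regularize` so that the example runs with plain PyTorch. The tiny network, the random data, and the hyperparameters are likewise assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

lr, reg = 0.1, 1e-2  # hypothetical hyperparameters


def l1_regularize_bn(model: nn.Module, reg: float) -> None:
    """Stand-in for pruner.regularize: add reg * sign(gamma) to BN scale gradients."""
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d) and m.affine and m.weight.grad is not None:
            m.weight.grad.add_(reg * torch.sign(m.weight.detach()))


def train_one_step(reg: float) -> torch.Tensor:
    torch.manual_seed(0)  # identical initialization and data for every call
    model = nn.Sequential(
        nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2),
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    x, y = torch.randn(16, 3, 8, 8), torch.randint(0, 2, (16,))

    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    l1_regularize_bn(model, reg)  # after backward(), before step()
    optimizer.step()
    return model[1].weight.detach().clone()  # BN scaling factors


gamma_plain = train_one_step(reg=0.0)
gamma_sparse = train_one_step(reg=reg)
shift = gamma_plain - gamma_sparse
print(shift)
```

Because the BN scales start at 1, one plain SGD step with the regularizer should pull every scale down by roughly an extra `lr * reg`. If the regularizer were called after `optimizer.step()`, its contribution would be discarded by the next `optimizer.zero_grad()`.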

Pruning and finetuning

In [7]:
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()

    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(model)
    print(model(example_inputs).shape)
    print(
        "  Iter %d/%d, Params: %.2f M => %.2f M"
        % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
    )
    print(
        "  Iter %d/%d, MACs: %.2f G => %.2f G"
        % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9)
    )
    print("="*16)
    # finetune your model here
    # finetune(model)
    # ...

ResNet(
  (conv1): Conv2d(3, 57, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(57, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(57, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(57, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

torch.Size([1, 1000])
  Iter 2/5, Params: 11.69 M => 7.53 M
  Iter 2/5, MACs: 1.82 G => 1.18 G
ResNet(
  (conv1): Conv2d(3, 44, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(44, eps=

torch.Size([1, 1000])
  Iter 4/5, Params: 11.69 M => 4.32 M
  Iter 4/5, MACs: 1.82 G => 0.68 G
ResNet(
  (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=