# **Template for Torch-Pruning**

This template is provided for your convenience.

You are not required to follow the steps and method given below.

In [None]:
!pip install --upgrade torch_pruning

Collecting torch_pruning
  Downloading torch_pruning-1.3.7-py3-none-any.whl (56 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/56.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.5/56.5 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->torch_pruning)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m45.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->torch_pruning)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m64.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->torch_pruning)
  Downloading nvidia_cuda_cupti_c

In [None]:
import torch
import torchvision
from torchvision.models import mobilenet_v2
import torch_pruning as tp

## A Minimal Example
In this section, you will perform channel pruning using the library [Torch-Pruning](https://github.com/VainF/Torch-Pruning).  

The pruner in Torch-Pruning has three main functions: sparse training (optional), importance estimation, and parameter removal.  
Torch-pruning offers two core features to support this process:

`tp.importance`: importance criteria used to measure the importance of weights.  

`tp.pruner`: pruners that perform the actual removal of parameters.  

For detailed information on this process, please refer to this [tutorial](https://github.com/VainF/Torch-Pruning/wiki/4.-High%E2%80%90level-Pruners/). Additionally, a more practical example is available [here](https://github.com/VainF/Torch-Pruning/blob/master/benchmarks/main.py).

### 1. Load model


In [None]:
# Load the checkpoint. The file holds a whole pickled nn.Module (the result is
# used directly as a model below), not just a state_dict, so full unpickling is
# required: PyTorch >= 2.6 defaults to weights_only=True and would refuse it.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints that
# come from a source you trust.
model = torch.load('./mobilenetv2_0.963.pth', map_location="cpu", weights_only=False)

# Move the model to the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

### 2. Prepare a pruner
For unwrapped parameters (`nn.Parameter` objects that do not belong to a standard layer such as `nn.Conv2d` or `nn.Linear`), Torch-Pruning automatically prunes the last non-singleton dimension by default. If you want to customize this behaviour, pass an `unwrapped_parameters` list to the pruner.

In [None]:
# Importance criterion used to decide which channels to remove.
imp = tp.importance.GroupNormImportance(p=2) # or GroupTaylorImportance(), GroupHessianImportance(), etc.

# Dummy input handed to the pruner; it is created on the same device as the model.
example_inputs = torch.randn(1, 3, 224, 224).to(device)

# Layers the pruner must leave untouched: the final 10-way linear classifier,
# so the number of output classes is preserved.
ignored_layers = [
    m for m in model.modules()
    if isinstance(m, torch.nn.Linear) and m.out_features == 10
]

# Each keyword is passed exactly once (a repeated keyword is a SyntaxError), and
# only names defined in this notebook are referenced.
pruner = tp.pruner.GroupNormPruner( # you can choose any pruner you like.
    model,
    example_inputs,
    importance=imp,   # Importance Estimator
    pruning_ratio=0.5, # remove 50% channels, ex :ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    # pruning_ratio_dict = {model.conv1: 0.2}, # manually set the pruning ratio of a specific layer
    iterative_steps=1,  # number of steps to achieve the target pruning_ratio.
    ignored_layers=ignored_layers,  # ignore some layers such as the final linear classifier
    # round_to=8,  # optional: round the remaining channel counts to a multiple of 8
)

### 3. Prune the model

In [None]:
# Model size before pruning. example_inputs is reused because it already lives on
# the same device as the model; a fresh CPU tensor would fail when the model is on GPU.
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)

model.eval()

if isinstance(imp, tp.importance.GroupTaylorImportance):
  # Taylor expansion requires gradients for importance estimation
  loss = model(example_inputs).sum() # A dummy loss, please replace this line with your loss function and data!
  loss.backward() # before pruner.step()

# prune
pruner.step()

# Parameter & MACs Counter (same device-matched input as the baseline measurement)
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(model, example_inputs)
MFLOPs = pruned_macs/1e6  # NOTE: this is the MACs count in millions, reported under the name "MFLOPs"
print("The pruned model:")
print(model)
print("Summary:")
print("MFLOPs: ")
print(MFLOPs)
# Before/after comparison so the effect of pruning is visible at a glance.
print(f"MACs: {base_macs/1e6:.2f} M -> {pruned_macs/1e6:.2f} M")
print(f"Params: {base_nparams/1e6:.2f} M -> {pruned_nparams/1e6:.2f} M")

# finetune the pruned model here
# finetune(model)
# ...