In [12]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

from transformers import AutoModelForSequenceClassification

In [35]:
# Load the pretrained checkpoint with a sequence-classification head on top.
checkpoint_name = 'bert-large-uncased'
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_name)

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint a

In [36]:
class Pruner():
    """Collect the prunable layers of a model and prune their ``weight`` tensors.

    Parameters
    ----------
    model : nn.Module
        Model whose sub-modules are scanned with ``named_modules()``.
    layer_type : container of str
        Which layer kinds to collect: ``'linear'`` selects ``nn.Linear``
        modules, ``'conv'`` selects ``nn.Conv2d`` modules.
    pct : float or int
        Default pruning amount, forwarded as ``amount`` to
        ``torch.nn.utils.prune``.

    The supported method names are ``'random'``, ``'l1'`` and ``'ln'``.
    """

    def __init__(self, model, layer_type, pct):
        self.model = model
        self.layer_type = layer_type
        self.pct = pct
        self._method_type = {
            'random': prune.RandomUnstructured,
            'l1': prune.L1Unstructured,
            'ln': prune.LnStructured,
        }
        # (module, parameter-name) pairs in the format torch's pruning API
        # expects, plus the dotted module names for reference.
        self._parameters_to_prune = []
        self._parameters_to_prune_names = []

        _lin_cnt, _conv_cnt = 0, 0
        for name, module in model.named_modules():
            if 'linear' in layer_type:
                if isinstance(module, nn.Linear):
                    self._parameters_to_prune.append((module, 'weight'))
                    self._parameters_to_prune_names.append(name)
                    _lin_cnt += 1
            if 'conv' in layer_type:
                if isinstance(module, nn.Conv2d):
                    self._parameters_to_prune.append((module, 'weight'))
                    self._parameters_to_prune_names.append(name)
                    _conv_cnt += 1
        print("Detected {} Linear layers".format(_lin_cnt))
        print("Detected {} Conv layers".format(_conv_cnt))

    def perform_pruning(self, method, **kwargs):
        """Prune the collected weights with the chosen method.

        Parameters
        ----------
        method : str
            One of ``'random'``, ``'l1'`` (both applied globally across all
            collected layers) or ``'ln'`` (structured, applied layer by layer).
        **kwargs
            Extra options forwarded to the torch pruning call. ``amount``
            overrides ``self.pct``. For ``'ln'``, ``n`` (norm order,
            default 2) and ``dim`` (pruned dimension, default 0) are honored.

        Raises
        ------
        KeyError
            If ``method`` is not a known method name.
        """
        if method not in self._method_type:
            raise KeyError(
                "Unknown pruning method {!r}; expected one of {}".format(
                    method, sorted(self._method_type)
                )
            )
        if not self._parameters_to_prune:
            # No matching layers were detected: nothing to prune. Returning
            # here avoids an opaque error from flattening an empty list.
            return

        chosen_method = self._method_type[method]
        if chosen_method.PRUNING_TYPE == 'unstructured':
            options = {'amount': self.pct}
            options.update(kwargs)
            prune.global_unstructured(
                self._parameters_to_prune,
                pruning_method=chosen_method,
                **options
            )
        else:
            # global_unstructured rejects structured methods, so structured
            # pruning is applied to each collected layer individually.
            options = dict(kwargs)
            amount = options.pop('amount', self.pct)
            n = options.pop('n', 2)
            dim = options.pop('dim', 0)
            for module, param_name in self._parameters_to_prune:
                prune.ln_structured(
                    module, param_name, amount=amount, n=n, dim=dim, **options
                )

    def make_permanent(self):
        """Remove the pruning re-parametrization from every pruned layer.

        Layers that are not currently pruned are skipped, so calling this
        before pruning, or more than once, is a harmless no-op.
        """
        for module, param_name in self._parameters_to_prune:
            # A pruned parameter is stored as `<name>_orig`; its absence
            # means there is nothing for prune.remove to undo.
            if hasattr(module, param_name + '_orig'):
                prune.remove(module, param_name)

In [37]:
# Set up a pruner that targets only the nn.Linear layers, with a pruning amount of 0.675.
prune_proc = Pruner(model=model, layer_type=['linear'], pct=0.675)

Detected 146 Linear layers
Detected 0 Conv layers


In [38]:
# Seed torch's RNG first so the random pruning mask is identical on every
# Restart & Run All; without it each run prunes a different set of weights.
torch.manual_seed(0)
prune_proc.perform_pruning('random')

In [43]:
# Finalize the pruning: Pruner.make_permanent calls prune.remove on the
# 'weight' parameter of every layer it collected.
prune_proc.make_permanent()

In [44]:
# Pruner stores a reference to the model it was constructed with, so this
# rebinds `model` to that same module object (assuming the cells above ran in order).
model = prune_proc.model

In [64]:
class LeNet(nn.Module):
    """LeNet-style CNN: two 3x3 convolutions followed by three linear layers.

    Takes single-channel images and produces 10 output scores per sample.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 input channel -> 6 -> 16 feature maps, 3x3 kernels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3)
        # Classifier head; fc1 expects 16 feature maps of size 5x5.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Run conv -> ReLU -> 2x2 max-pool twice, flatten, then apply the linear head."""
        pooled = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2)
        pooled = F.max_pool2d(F.relu(self.conv2(pooled)), kernel_size=2)
        # Flatten everything except the batch dimension.
        features_per_sample = pooled.nelement() // pooled.shape[0]
        flat = pooled.view(-1, features_per_sample)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

# NOTE: this rebinds `model`, which earlier cells use for the BERT classifier,
# to a freshly initialized LeNet.
model = LeNet()

In [66]:
# Set up a pruner that targets both the nn.Conv2d and nn.Linear layers, with a pruning amount of 0.675.
prune_proc = Pruner(model=model, layer_type=['conv', 'linear'], pct=0.675)

Detected 3 Linear layers
Detected 2 Conv layers


In [67]:
# Seed torch's RNG first so the random pruning mask is identical on every
# Restart & Run All; without it each run prunes a different set of weights.
torch.manual_seed(0)
prune_proc.perform_pruning('random')

In [69]:
# Sanity check: while pruning is active, conv1 is expected to carry exactly one
# forward pre-hook (installed by torch.nn.utils.prune to re-apply the mask).
conv1_pre_hooks = prune_proc.model.conv1._forward_pre_hooks
len(conv1_pre_hooks) == 1

1

In [70]:
# Finalize the pruning: Pruner.make_permanent calls prune.remove on the
# 'weight' parameter of every layer it collected.
prune_proc.make_permanent()

In [71]:
# Sanity check: once the pruning has been made permanent, conv1 should have no
# forward pre-hooks left.
conv1_pre_hooks = prune_proc.model.conv1._forward_pre_hooks
len(conv1_pre_hooks) == 0

0