In [2]:
%%capture
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNGs so weight initialisation and the random pruning masks
# below are reproducible under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

class LeNet(nn.Module):
    """LeNet-style CNN: two (3x3 conv -> ReLU -> 2x2 max-pool) stages, then three
    fully connected layers producing 10 outputs.

    ``fc1`` expects a 16 x 5 x 5 feature map after the second pooling stage;
    e.g. a 1 x 28 x 28 input gives 28 -> 26 -> 13 -> 11 -> 5 spatially.
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels x 5x5 spatial map
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch ``(N, 1, H, W)`` or a single image ``(1, H, W)`` to logits
        of shape ``(N, 10)`` or ``(10,)`` respectively."""
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Collapse the trailing (channels, height, width) dims into one feature
        # vector per sample. Dividing nelement() by the batch size instead would
        # raise ZeroDivisionError on an empty batch and mis-shape unbatched input.
        x = torch.flatten(x, start_dim=-3)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Instantiate the network and move its parameters onto the selected device.
model = LeNet()
model = model.to(device=device)

In [6]:
# Inspect the first conv layer before any pruning: it owns plain `weight` and
# `bias` Parameters and has no buffers yet.
module = model.conv1
print([*module.named_parameters()])
print([*module.named_buffers()])

[('weight', Parameter containing:
tensor([[[[-0.0160,  0.0219,  0.1897],
          [-0.0759,  0.3223, -0.3320],
          [ 0.0951, -0.1537,  0.1104]]],


        [[[ 0.2471,  0.1803,  0.2617],
          [ 0.2884,  0.0868,  0.2820],
          [-0.1179, -0.3237, -0.0119]]],


        [[[ 0.1588, -0.0374,  0.2655],
          [-0.1122,  0.1555,  0.0783],
          [-0.0316,  0.2062,  0.0100]]],


        [[[-0.2645, -0.1667,  0.1493],
          [-0.1339, -0.1278, -0.1950],
          [-0.1236, -0.1453, -0.1251]]],


        [[[-0.2411, -0.1354,  0.2701],
          [-0.0532, -0.0379,  0.0947],
          [ 0.0527,  0.2622, -0.3064]]],


        [[[ 0.0064,  0.0065, -0.0417],
          [ 0.2583, -0.0291, -0.2060],
          [-0.2816, -0.0771,  0.2594]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1107,  0.0557,  0.1571,  0.1655,  0.2612, -0.0051], device='cuda:0',
       requires_grad=True))]
[]


In [7]:
# Zero a random 30% of conv1's weights. The untouched values move to the
# `weight_orig` Parameter and the 0/1 pattern is stored as the `weight_mask` buffer.
prune.random_unstructured(module=module, name="weight", amount=0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [8]:
# `weight` has left the parameter list; `weight_orig` now holds the unpruned values.
print([*module.named_parameters()])

[('bias', Parameter containing:
tensor([-0.1107,  0.0557,  0.1571,  0.1655,  0.2612, -0.0051], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.0160,  0.0219,  0.1897],
          [-0.0759,  0.3223, -0.3320],
          [ 0.0951, -0.1537,  0.1104]]],


        [[[ 0.2471,  0.1803,  0.2617],
          [ 0.2884,  0.0868,  0.2820],
          [-0.1179, -0.3237, -0.0119]]],


        [[[ 0.1588, -0.0374,  0.2655],
          [-0.1122,  0.1555,  0.0783],
          [-0.0316,  0.2062,  0.0100]]],


        [[[-0.2645, -0.1667,  0.1493],
          [-0.1339, -0.1278, -0.1950],
          [-0.1236, -0.1453, -0.1251]]],


        [[[-0.2411, -0.1354,  0.2701],
          [-0.0532, -0.0379,  0.0947],
          [ 0.0527,  0.2622, -0.3064]]],


        [[[ 0.0064,  0.0065, -0.0417],
          [ 0.2583, -0.0291, -0.2060],
          [-0.2816, -0.0771,  0.2594]]]], device='cuda:0', requires_grad=True))]


In [9]:
# The pruning mask is registered on the layer as the buffer `weight_mask`.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 0.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],


        [[[0., 1., 0.],
          [1., 0., 1.],
          [0., 1., 1.]]],


        [[[1., 1., 0.],
          [0., 1., 1.],
          [1., 0., 1.]]],


        [[[0., 0., 1.],
          [1., 1., 0.],
          [0., 1., 0.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]]]], device='cuda:0'))]


In [10]:
# After pruning, `module.weight` is no longer a Parameter: it is the product of
# weight_orig and weight_mask (hence grad_fn=<MulBackward0>), so masked entries
# print as 0.
print(module.weight)

tensor([[[[-0.0160,  0.0219,  0.1897],
          [-0.0759,  0.3223, -0.3320],
          [ 0.0000, -0.1537,  0.0000]]],


        [[[ 0.2471,  0.1803,  0.2617],
          [ 0.2884,  0.0868,  0.2820],
          [-0.0000, -0.3237, -0.0119]]],


        [[[ 0.0000, -0.0374,  0.0000],
          [-0.1122,  0.0000,  0.0783],
          [-0.0000,  0.2062,  0.0100]]],


        [[[-0.2645, -0.1667,  0.0000],
          [-0.0000, -0.1278, -0.1950],
          [-0.1236, -0.0000, -0.1251]]],


        [[[-0.0000, -0.0000,  0.2701],
          [-0.0532, -0.0379,  0.0000],
          [ 0.0000,  0.2622, -0.0000]]],


        [[[ 0.0064,  0.0065, -0.0417],
          [ 0.2583, -0.0291, -0.2060],
          [-0.2816, -0.0771,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)


In [11]:
# `_forward_pre_hooks` is a private nn.Module attribute; it lists the
# RandomUnstructured object that the pruning call registered on this layer.
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fa727e55a00>)])


In [15]:
# An int `amount` is an absolute count rather than a fraction: the 3 bias
# entries with the smallest absolute value are zeroed.
prune.l1_unstructured(module=module, name="bias", amount=3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [16]:
# Both pruned tensors are now stored as `*_orig` Parameters.
print([*module.named_parameters()])

[('weight_orig', Parameter containing:
tensor([[[[-0.0160,  0.0219,  0.1897],
          [-0.0759,  0.3223, -0.3320],
          [ 0.0951, -0.1537,  0.1104]]],


        [[[ 0.2471,  0.1803,  0.2617],
          [ 0.2884,  0.0868,  0.2820],
          [-0.1179, -0.3237, -0.0119]]],


        [[[ 0.1588, -0.0374,  0.2655],
          [-0.1122,  0.1555,  0.0783],
          [-0.0316,  0.2062,  0.0100]]],


        [[[-0.2645, -0.1667,  0.1493],
          [-0.1339, -0.1278, -0.1950],
          [-0.1236, -0.1453, -0.1251]]],


        [[[-0.2411, -0.1354,  0.2701],
          [-0.0532, -0.0379,  0.0947],
          [ 0.0527,  0.2622, -0.3064]]],


        [[[ 0.0064,  0.0065, -0.0417],
          [ 0.2583, -0.0291, -0.2060],
          [-0.2816, -0.0771,  0.2594]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1107,  0.0557,  0.1571,  0.1655,  0.2612, -0.0051], device='cuda:0',
       requires_grad=True))]


In [17]:
# One mask buffer per pruned tensor: `weight_mask` and `bias_mask`.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 0.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],


        [[[0., 1., 0.],
          [1., 0., 1.],
          [0., 1., 1.]]],


        [[[1., 1., 0.],
          [0., 1., 1.],
          [1., 0., 1.]]],


        [[[0., 0., 1.],
          [1., 1., 0.],
          [0., 1., 0.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 1., 0.], device='cuda:0'))]


In [18]:
# `bias` is now computed from bias_orig and bias_mask (grad_fn=<MulBackward0>);
# the three smallest-magnitude entries read as 0.
print(module.bias)

tensor([-0.0000, 0.0000, 0.1571, 0.1655, 0.2612, -0.0000], device='cuda:0',
       grad_fn=<MulBackward0>)


In [19]:
# Each pruning call on this layer registered its own forward pre-hook
# (RandomUnstructured for `weight`, L1Unstructured for `bias`), in call order.
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fa727e55a00>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7fa727e5aee0>)])


In [20]:
# Make conv1's weight pruning permanent: `weight` becomes an ordinary Parameter
# holding the masked values, and `weight_orig` / `weight_mask` disappear.
prune.remove(module=module, name='weight')
print([*module.named_parameters()])

[('bias_orig', Parameter containing:
tensor([-0.1107,  0.0557,  0.1571,  0.1655,  0.2612, -0.0051], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.0160,  0.0219,  0.1897],
          [-0.0759,  0.3223, -0.3320],
          [ 0.0000, -0.1537,  0.0000]]],


        [[[ 0.2471,  0.1803,  0.2617],
          [ 0.2884,  0.0868,  0.2820],
          [-0.0000, -0.3237, -0.0119]]],


        [[[ 0.0000, -0.0374,  0.0000],
          [-0.1122,  0.0000,  0.0783],
          [-0.0000,  0.2062,  0.0100]]],


        [[[-0.2645, -0.1667,  0.0000],
          [-0.0000, -0.1278, -0.1950],
          [-0.1236, -0.0000, -0.1251]]],


        [[[-0.0000, -0.0000,  0.2701],
          [-0.0532, -0.0379,  0.0000],
          [ 0.0000,  0.2622, -0.0000]]],


        [[[ 0.0064,  0.0065, -0.0417],
          [ 0.2583, -0.0291, -0.2060],
          [-0.2816, -0.0771,  0.0000]]]], device='cuda:0', requires_grad=True))]


In [21]:
# Only the bias mask remains; the weight mask was folded in by prune.remove.
print([*module.named_buffers()])

[('bias_mask', tensor([0., 0., 1., 1., 1., 0.], device='cuda:0'))]


In [22]:
# Layer-type-specific L1 (magnitude) pruning ratios.
CONV_PRUNE_AMOUNT = 0.2    # prune 20% of connections in all 2D-conv layers
LINEAR_PRUNE_AMOUNT = 0.4  # prune 40% of connections in all linear layers

new_model = LeNet()
# The loop variable is deliberately not named `module`: that global refers to
# `model.conv1` in the earlier cells and must not be silently rebound here.
for submodule in new_model.modules():
    if isinstance(submodule, nn.Conv2d):
        prune.l1_unstructured(submodule, name='weight', amount=CONV_PRUNE_AMOUNT)
    elif isinstance(submodule, nn.Linear):
        prune.l1_unstructured(submodule, name='weight', amount=LINEAR_PRUNE_AMOUNT)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])


In [23]:
# Make every pruning re-parametrisation on `new_model` permanent. Each pruned
# tensor `<p>` left a `<p>_mask` buffer on its owning layer; prune.remove folds
# the mask into `<p>` and drops `<p>_orig`, `<p>_mask` and the pruning hook.
# Only tensors that were actually pruned are passed to prune.remove, so no
# exception has to be swallowed (a bare `except:` would also hide genuine
# errors and even KeyboardInterrupt).
for submodule in new_model.modules():
    # Snapshot this layer's own buffer names first: prune.remove deletes buffers.
    own_buffer_names = [buf_name for buf_name, _ in submodule.named_buffers(recurse=False)]
    for buf_name in own_buffer_names:
        if buf_name.endswith('_mask'):
            prune.remove(submodule, buf_name[:-len('_mask')])

In [24]:
# With every mask folded in, the model should carry no buffers at all.
print({buf_name: buf for buf_name, buf in new_model.named_buffers()}.keys())

dict_keys([])


In [25]:
# Fresh, unpruned network for the global-pruning experiment.
model = LeNet()

# Every weight tensor that takes part in global magnitude pruning.
parameters_to_prune = tuple(
    (getattr(model, layer_name), 'weight')
    for layer_name in ('conv1', 'conv2', 'fc1', 'fc2', 'fc3')
)

# Rank all listed weights together by absolute value and zero the smallest 20%
# overall, so individual layers can end up with different sparsities.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

In [27]:
# Per-layer sparsity: the percentage of exactly-zero entries in each weight.
# Layers differ because global pruning ranks all weights together.
for layer_name in ("conv1", "conv2", "fc1", "fc2", "fc3"):
    layer_weight = getattr(model, layer_name).weight
    print(
        "Sparsity in {}.weight: {:.2f}%".format(
            layer_name,
            100. * float(torch.sum(layer_weight == 0))
            / float(layer_weight.nelement()),
        )
    )

Sparsity in conv1.weight: 5.56%
Sparsity in conv2.weight: 7.29%
Sparsity in fc1.weight: 22.06%
Sparsity in fc2.weight: 12.23%
Sparsity in fc3.weight: 9.64%


In [28]:
# Global sparsity over exactly the tensors handed to `global_unstructured`, so
# the report cannot drift from what was pruned; it should match `amount`.
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            sum(torch.sum(getattr(layer, param_name) == 0)
                for layer, param_name in parameters_to_prune)
        )
        / float(
            sum(getattr(layer, param_name).nelement()
                for layer, param_name in parameters_to_prune)
        )
    )
)

Global sparsity: 20.00%
