In [63]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [64]:
# Configuration: seed the global torch RNG so model initialisation and
# prune.random_unstructured are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device={device}")
class LeNet(nn.Module):
    """LeNet-style CNN used as the pruning playground in this notebook.

    Two conv(5x5) -> ReLU -> 2x2 max-pool stages, followed by three fully
    connected layers that produce 10 logits.

    ``fc1`` has a hard-coded ``16 * 5 * 5`` input size. The feature map after
    the second pooling stage must therefore be 16x5x5, e.g. a 1x32x32 image.
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        # 6 input channels, 16 output channels, 5x5 square conv kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels x 5x5 spatial map after the second pooling stage
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10).

        ``x`` is expected as (batch, 1, H, W), with H and W such that the map
        after both conv/pool stages is 5x5 (e.g. 32x32).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension. Unlike
        # view(-1, nelement / batch), this needs no float division and does
        # not divide by zero on an empty batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Instantiate the network on the selected device and show its layer summary.
model = LeNet().to(device)
print(f"model {model}")

device= cuda
model LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [65]:
# Work on the first conv layer; list its parameters before any pruning.
module = model.conv1
print([*module.named_parameters()])

[('weight', Parameter containing:
tensor([[[[ 0.0354,  0.0517, -0.1372, -0.1993,  0.0417],
          [-0.0969, -0.1600, -0.1861,  0.0812, -0.0993],
          [ 0.1761,  0.1753,  0.0174, -0.1056,  0.0304],
          [-0.0392,  0.0888, -0.0373,  0.0320, -0.1884],
          [ 0.0716, -0.1792,  0.0240, -0.1539, -0.1080]]],


        [[[ 0.1978,  0.0198,  0.0279,  0.1849,  0.1703],
          [ 0.1940, -0.0396,  0.1681,  0.0322,  0.0005],
          [ 0.1948, -0.0355, -0.1441,  0.0781,  0.1440],
          [-0.1053, -0.1539,  0.1229,  0.0063,  0.0485],
          [ 0.0616, -0.1104,  0.1248,  0.0036,  0.1073]]],


        [[[-0.0493,  0.1004,  0.1612,  0.0756,  0.1108],
          [ 0.1421,  0.0612,  0.1483,  0.0781,  0.1538],
          [ 0.1083, -0.0422, -0.0057, -0.1140, -0.0992],
          [ 0.0655,  0.1091,  0.0163,  0.0197, -0.0312],
          [-0.1367,  0.1911, -0.1769,  0.1613, -0.1197]]],


        [[[ 0.1183, -0.1553, -0.0313,  0.0050, -0.1785],
          [ 0.1213,  0.0587, -0.0251, -0.0

In [66]:
# No buffers yet: nothing has been pruned.
print([*module.named_buffers()])

[]


In [61]:
# Randomly zero 30% of conv1's weight entries. Pruning reparametrizes the
# module: the original values move to the `weight_orig` parameter and a binary
# `weight_mask` buffer is added; `weight` becomes their masked product.
prune.random_unstructured(module, name="weight", amount=0.3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [62]:
# `weight` is no longer a parameter; `weight_orig` holds the unpruned values.
print([*module.named_parameters()])

[('bias', Parameter containing:
tensor([ 0.1152, -0.0103,  0.0359,  0.1339, -0.0496,  0.0860], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.1024, -0.1389,  0.1927,  0.1138, -0.0964],
          [ 0.1144,  0.0619,  0.1473, -0.1015, -0.1579],
          [-0.1148, -0.1109,  0.1737,  0.1895,  0.1797],
          [-0.1551, -0.0859, -0.0569, -0.1433, -0.1217],
          [ 0.0060, -0.1359, -0.0809,  0.1003, -0.0180]]],


        [[[-0.0439,  0.1643,  0.1914, -0.0830, -0.1992],
          [ 0.1197,  0.1506, -0.1293,  0.1289, -0.1525],
          [ 0.0532, -0.1427,  0.1831, -0.0189, -0.1365],
          [ 0.1790, -0.0439,  0.1290, -0.1302, -0.1978],
          [ 0.1035, -0.0024,  0.0538,  0.0215, -0.0246]]],


        [[[-0.0109,  0.1125, -0.1276,  0.1830, -0.0021],
          [ 0.0772, -0.1924,  0.1190,  0.1248,  0.1855],
          [-0.1774,  0.1524,  0.1892,  0.1927, -0.1238],
          [ 0.0920,  0.0325, -0.1764, -0.0742, -0.1683],
          [-0.1

In [52]:
# Pruning registered the binary `weight_mask` buffer.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[1., 1., 1., 0., 1.],
          [1., 1., 1., 0., 1.],
          [0., 1., 1., 1., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.]]],


        [[[1., 0., 1., 1., 1.],
          [0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[0., 1., 1., 1., 0.],
          [1., 1., 0., 0., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[0., 1., 1., 0., 1.],
          [0., 1., 1., 0., 1.],
          [0., 0., 0., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[0., 0., 0., 0., 1.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 0., 1., 1.],
          [0., 1., 0., 1., 1.]]],


        [[[1., 1., 0., 1., 1.],
          [1., 1., 1., 0., 1.],
          [1., 0., 1., 1., 0.],
          [0., 1., 0., 1., 1.],
          [1., 0., 0., 1., 0.]]]], 

In [53]:
# The effective `weight` has zeros wherever `weight_mask` is 0.
print(module.weight)

tensor([[[[-0.1141, -0.0872, -0.0636, -0.0000, -0.1176],
          [-0.0661,  0.0158,  0.0069, -0.0000,  0.1685],
          [-0.0000, -0.0427,  0.1277, -0.0882, -0.1116],
          [-0.0000,  0.1149, -0.0631, -0.0841,  0.1209],
          [-0.0792, -0.1951,  0.0000, -0.0291, -0.1651]]],


        [[[ 0.1238,  0.0000,  0.1294,  0.0337,  0.1134],
          [ 0.0000, -0.0560,  0.0344, -0.0524,  0.0000],
          [ 0.0225,  0.1764, -0.0000,  0.1337, -0.1678],
          [-0.1911, -0.0982,  0.1630,  0.0782,  0.0365],
          [ 0.1318,  0.0365, -0.1186,  0.0000,  0.0301]]],


        [[[ 0.0000,  0.0758, -0.0975,  0.0687,  0.0000],
          [ 0.1762, -0.1359, -0.0000, -0.0000,  0.0021],
          [-0.0000,  0.0671,  0.0768,  0.0851,  0.0748],
          [-0.1724, -0.0910, -0.1004, -0.0000, -0.1034],
          [ 0.0631,  0.1546, -0.1642, -0.0000, -0.0381]]],


        [[[ 0.0000, -0.0892, -0.1225,  0.0000,  0.1097],
          [ 0.0000,  0.1599, -0.1959, -0.0000,  0.0365],
          [ 0.0000,

In [54]:
# Pruning installs its method object (here RandomUnstructured) as a forward
# pre-hook. Per the torch.nn.utils.prune docs, it recomputes `weight` from
# `weight_orig` and `weight_mask` before each forward pass.
print(module._forward_pre_hooks)

OrderedDict([(7, <torch.nn.utils.prune.RandomUnstructured object at 0x7f7fc1f0c710>)])


In [None]:
# Prune the 30% of bias entries with the smallest absolute value (L1 magnitude).
# NOTE(review): the saved execution count is In [None], so this cell appears
# not to be part of the final run; the next named_parameters output still
# lists an unpruned `bias`. Re-run top to bottom to get consistent outputs.
prune.l1_unstructured(module, name="bias", amount=0.3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [55]:
# Re-list the parameters after the bias-pruning step.
print([*module.named_parameters()])

[('bias', Parameter containing:
tensor([ 0.0386, -0.0401, -0.0324,  0.0289,  0.0745, -0.0778], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.1141, -0.0872, -0.0636, -0.1292, -0.1176],
          [-0.0661,  0.0158,  0.0069, -0.1852,  0.1685],
          [-0.0706, -0.0427,  0.1277, -0.0882, -0.1116],
          [-0.1895,  0.1149, -0.0631, -0.0841,  0.1209],
          [-0.0792, -0.1951,  0.1391, -0.0291, -0.1651]]],


        [[[ 0.1238,  0.1868,  0.1294,  0.0337,  0.1134],
          [ 0.1591, -0.0560,  0.0344, -0.0524,  0.0565],
          [ 0.0225,  0.1764, -0.0216,  0.1337, -0.1678],
          [-0.1911, -0.0982,  0.1630,  0.0782,  0.0365],
          [ 0.1318,  0.0365, -0.1186,  0.0133,  0.0301]]],


        [[[ 0.1231,  0.0758, -0.0975,  0.0687,  0.1713],
          [ 0.1762, -0.1359, -0.0458, -0.1141,  0.0021],
          [-0.0439,  0.0671,  0.0768,  0.0851,  0.0748],
          [-0.1724, -0.0910, -0.1004, -0.0114, -0.1034],
          [ 0.0

In [67]:
# Structured L2 pruning along dim=0 zeroes whole output channels: 50% of the
# 6 channels, i.e. 3 of them. It is combined with any mask already on
# `weight`, so earlier pruning is preserved.
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
print(module, module.weight, sep="\n")

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
tensor([[[[ 0.0354,  0.0517, -0.1372, -0.1993,  0.0417],
          [-0.0969, -0.1600, -0.1861,  0.0812, -0.0993],
          [ 0.1761,  0.1753,  0.0174, -0.1056,  0.0304],
          [-0.0392,  0.0888, -0.0373,  0.0320, -0.1884],
          [ 0.0716, -0.1792,  0.0240, -0.1539, -0.1080]]],


        [[[ 0.1978,  0.0198,  0.0279,  0.1849,  0.1703],
          [ 0.1940, -0.0396,  0.1681,  0.0322,  0.0005],
          [ 0.1948, -0.0355, -0.1441,  0.0781,  0.1440],
          [-0.1053, -0.1539,  0.1229,  0.0063,  0.0485],
          [ 0.0616, -0.1104,  0.1248,  0.0036,  0.1073]]],


        [[[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000, -0.0000]]],


        [[[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,

In [68]:
# `weight_orig` still holds the unpruned values; only the mask has changed.
print([*module.named_parameters()])

[('bias', Parameter containing:
tensor([-0.1706, -0.0999, -0.1277,  0.1704, -0.0590,  0.1991], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.0354,  0.0517, -0.1372, -0.1993,  0.0417],
          [-0.0969, -0.1600, -0.1861,  0.0812, -0.0993],
          [ 0.1761,  0.1753,  0.0174, -0.1056,  0.0304],
          [-0.0392,  0.0888, -0.0373,  0.0320, -0.1884],
          [ 0.0716, -0.1792,  0.0240, -0.1539, -0.1080]]],


        [[[ 0.1978,  0.0198,  0.0279,  0.1849,  0.1703],
          [ 0.1940, -0.0396,  0.1681,  0.0322,  0.0005],
          [ 0.1948, -0.0355, -0.1441,  0.0781,  0.1440],
          [-0.1053, -0.1539,  0.1229,  0.0063,  0.0485],
          [ 0.0616, -0.1104,  0.1248,  0.0036,  0.1073]]],


        [[[-0.0493,  0.1004,  0.1612,  0.0756,  0.1108],
          [ 0.1421,  0.0612,  0.1483,  0.0781,  0.1538],
          [ 0.1083, -0.0422, -0.0057, -0.1140, -0.0992],
          [ 0.0655,  0.1091,  0.0163,  0.0197, -0.0312],
          [-0.1

In [70]:
# Pruned tensors appear in the state dict as `<name>_orig` plus `<name>_mask`
# (here conv1.weight_orig / conv1.weight_mask) instead of `<name>`.
print(model.state_dict().keys())

odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


In [71]:
# Make the pruning permanent: `weight` becomes a plain parameter holding the
# masked values, and the reparametrization is dropped.
prune.remove(module, name='weight')
print([*module.named_parameters()])

[('bias', Parameter containing:
tensor([-0.1706, -0.0999, -0.1277,  0.1704, -0.0590,  0.1991], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0354,  0.0517, -0.1372, -0.1993,  0.0417],
          [-0.0969, -0.1600, -0.1861,  0.0812, -0.0993],
          [ 0.1761,  0.1753,  0.0174, -0.1056,  0.0304],
          [-0.0392,  0.0888, -0.0373,  0.0320, -0.1884],
          [ 0.0716, -0.1792,  0.0240, -0.1539, -0.1080]]],


        [[[ 0.1978,  0.0198,  0.0279,  0.1849,  0.1703],
          [ 0.1940, -0.0396,  0.1681,  0.0322,  0.0005],
          [ 0.1948, -0.0355, -0.1441,  0.0781,  0.1440],
          [-0.1053, -0.1539,  0.1229,  0.0063,  0.0485],
          [ 0.0616, -0.1104,  0.1248,  0.0036,  0.1073]]],


        [[[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000, 

In [72]:
# After prune.remove the `weight_mask` buffer is gone.
print([*module.named_buffers()])

[]


In [None]:
new_model = LeNet()
# Prune each layer of a fresh model in place by L1 magnitude:
# conv weights at 20%, linear weights at 40%.
# The loop variable is `layer`, not `module`, so the notebook-level `module`
# (model.conv1, used by the earlier cells) is not overwritten.
for layer in new_model.modules():
    if isinstance(layer, nn.Conv2d):
        prune.l1_unstructured(layer, name='weight', amount=0.2)
    elif isinstance(layer, nn.Linear):
        prune.l1_unstructured(layer, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

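Global pruning: instead of pruning each layer by its own fraction, remove the 20% of weights with the smallest L1 magnitude across all listed layers at once, so individual layers can end up with different sparsities.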

In [73]:
model = LeNet()

# (module, parameter-name) pairs that share one global L1 threshold.
parameters_to_prune = tuple(
    (getattr(model, layer_name), 'weight')
    for layer_name in ('conv1', 'conv2', 'fc1', 'fc2', 'fc3')
)

prune.global_unstructured(
    parameters_to_prune,
    amount=0.2,
    pruning_method=prune.L1Unstructured,
)

In [74]:
def print_sparsity(net, layer_names=("conv1", "conv2", "fc1", "fc2", "fc3")):
    """Print the percentage of exactly-zero entries per layer weight and overall.

    Args:
        net: a module with the named sub-layers, each exposing ``.weight``.
        layer_names: attribute names of the layers to report on.
    """
    total_zeros = 0
    total_elems = 0
    for layer_name in layer_names:
        weight = getattr(net, layer_name).weight
        zeros = int(torch.sum(weight == 0))
        elems = weight.nelement()
        print(
            "Sparsity in {}.weight: {:.2f}%".format(
                layer_name, 100. * zeros / elems
            )
        )
        total_zeros += zeros
        total_elems += elems
    # Per-layer sparsities differ; the global figure reflects the single
    # threshold applied across all listed weights.
    print("Global sparsity: {:.2f}%".format(100. * total_zeros / total_elems))


print_sparsity(model)

Sparsity in conv1.weight: 4.67%
Sparsity in conv2.weight: 13.25%
Sparsity in fc1.weight: 22.20%
Sparsity in fc2.weight: 12.08%
Sparsity in fc3.weight: 11.07%
Global sparsity: 20.00%
