In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [8]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class LeNet(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    ``fc1`` expects ``16 * 5 * 5`` features per sample. With the unpadded 5x5
    convolutions and 2x2 pooling used here, that corresponds to single-channel
    32x32 inputs. The network returns 10 raw scores (no softmax) per sample.
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels of 5x5 feature maps
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute the class scores for a batch of images.

        Args:
            x: tensor of shape ``(batch, 1, H, W)``; ``H = W = 32`` yields the
                400 features per sample that ``fc1`` requires.

        Returns:
            Tensor of shape ``(batch, 10)`` holding the output of ``fc3``.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension. Unlike dividing the
        # element count by the batch size, this also works for an empty batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the network, then place its parameters on the selected device.
model = LeNet()
model = model.to(device)

In [9]:
# Look at the parameters registered on the first convolution layer.
module = model.conv1
print([*module.named_parameters()])

[('weight', Parameter containing:
tensor([[[[ 0.0452, -0.1287,  0.0023,  0.1995,  0.0965],
          [ 0.0974,  0.1561, -0.1720, -0.1242,  0.0905],
          [-0.1531,  0.0870, -0.1435, -0.0851, -0.0249],
          [ 0.0067,  0.0284, -0.1803, -0.0511,  0.0609],
          [ 0.1552,  0.1132,  0.0371, -0.0080,  0.0286]]],


        [[[-0.1255,  0.0127,  0.0372, -0.0412, -0.0561],
          [-0.0597, -0.0685,  0.1560, -0.1623,  0.0015],
          [-0.0826, -0.1312, -0.0195,  0.1474, -0.0230],
          [ 0.1299, -0.0080,  0.0831, -0.0853, -0.0879],
          [-0.1337,  0.0143, -0.1669, -0.1887,  0.1385]]],


        [[[-0.0653,  0.1057, -0.0347,  0.1890, -0.1790],
          [-0.1270, -0.1675,  0.0216,  0.1094, -0.0347],
          [-0.0262, -0.1717,  0.1717, -0.0319,  0.0837],
          [ 0.0965, -0.1510,  0.1830, -0.0457,  0.0981],
          [-0.1108,  0.0782, -0.0769,  0.0453,  0.0116]]],


        [[[ 0.0885,  0.0246, -0.0109,  0.0586,  0.0629],
          [-0.0224,  0.0451,  0.1404, -0.0

In [10]:
# Start again from a freshly initialised network.
model = LeNet()

# (module, parameter name) pairs for every weight tensor that takes part
# in the pruning call below.
parameters_to_prune = tuple(
    (getattr(model, layer_name), 'weight')
    for layer_name in ('conv1', 'conv2', 'fc1', 'fc2', 'fc3')
)

# Prune 20% of the listed weights by L1 magnitude. The selection is made
# across all listed tensors together, so per-layer sparsity may differ.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

In [11]:
def _weight_zero_stats(layer):
    """Return ``(zero_count, element_count)`` for ``layer.weight`` as plain ints."""
    weight = layer.weight
    return int(torch.sum(weight == 0)), weight.nelement()


# Report the share of zero-valued weights per layer, then over all layers.
# One loop replaces six copy-pasted print statements, so the per-layer and
# global figures are guaranteed to be computed from the same counts.
total_zeros = 0
total_elements = 0
for layer_name in ("conv1", "conv2", "fc1", "fc2", "fc3"):
    zeros, elements = _weight_zero_stats(getattr(model, layer_name))
    print(
        "Sparsity in {}.weight: {:.2f}%".format(
            layer_name, 100. * zeros / elements
        )
    )
    total_zeros += zeros
    total_elements += elements

print(
    "Global sparsity: {:.2f}%".format(100. * total_zeros / total_elements)
)

Sparsity in conv1.weight: 4.00%
Sparsity in conv2.weight: 13.12%
Sparsity in fc1.weight: 22.13%
Sparsity in fc2.weight: 12.41%
Sparsity in fc3.weight: 11.79%
Global sparsity: 20.00%
