In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG so that weight initialisation and the random pruning mask
# below are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

class LeNet(nn.Module):
    """Small LeNet-style CNN: two (3x3 conv -> ReLU -> 2x2 max-pool) stages, then three linear layers.

    ``fc1`` expects ``16 * 5 * 5`` flattened features, i.e. a 5x5 spatial map
    after the two conv/pool stages. That holds for single-channel inputs whose
    height and width are between 26 and 29 pixels (e.g. 28x28). Other sizes make
    ``fc1`` fail with a shape mismatch.
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        # 6 input channels, 16 output channels, 3x3 square conv kernel
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels x 5x5 spatial map
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map an (N, 1, H, W) batch to raw (N, 10) ``fc3`` outputs (no final activation)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten every dimension except the batch one. Unlike dividing
        # nelement() by the batch size, this is exact integer arithmetic and
        # also works for an empty batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Build the network and move its parameters onto the selected device.
model = LeNet().to(device)

In [3]:
# Inspect conv1 before pruning: it exposes only its `weight` and `bias` parameters.
module = model.conv1
print([*module.named_parameters()])

[('weight', Parameter containing:
tensor([[[[ 0.1188,  0.0299, -0.2168],
          [ 0.0082, -0.1437, -0.2159],
          [ 0.1061,  0.0975,  0.2083]]],


        [[[-0.0437,  0.0198, -0.0922],
          [ 0.2921,  0.1690, -0.1004],
          [ 0.1769,  0.3268, -0.2383]]],


        [[[ 0.0128,  0.1592,  0.3118],
          [-0.1280,  0.1068,  0.1134],
          [ 0.0599, -0.1745,  0.2205]]],


        [[[ 0.0857,  0.1030,  0.0239],
          [-0.1819, -0.1736, -0.2166],
          [-0.2796, -0.2213, -0.1616]]],


        [[[-0.0593, -0.1760, -0.1008],
          [-0.2283, -0.1946,  0.0400],
          [-0.1087, -0.1572,  0.1258]]],


        [[[ 0.1060,  0.2579, -0.0451],
          [ 0.0119, -0.3029, -0.3019],
          [ 0.2623,  0.2487, -0.3102]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.3224, -0.0993, -0.0392,  0.2575,  0.1557,  0.3288],
       requires_grad=True))]


In [4]:
# Zero out a random 30% of conv1's weight entries (rounded to a whole number of
# entries). The call modifies `module` in place and returns it, which is why the
# cell displays the Conv2d repr.
prune.random_unstructured(module, "weight", amount=0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [5]:
# After pruning, the untouched original tensor is kept as the `weight_orig`
# parameter, and `weight` is no longer listed among the parameters.
print([*module.named_parameters()])

[('bias', Parameter containing:
tensor([ 0.3224, -0.0993, -0.0392,  0.2575,  0.1557,  0.3288],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1188,  0.0299, -0.2168],
          [ 0.0082, -0.1437, -0.2159],
          [ 0.1061,  0.0975,  0.2083]]],


        [[[-0.0437,  0.0198, -0.0922],
          [ 0.2921,  0.1690, -0.1004],
          [ 0.1769,  0.3268, -0.2383]]],


        [[[ 0.0128,  0.1592,  0.3118],
          [-0.1280,  0.1068,  0.1134],
          [ 0.0599, -0.1745,  0.2205]]],


        [[[ 0.0857,  0.1030,  0.0239],
          [-0.1819, -0.1736, -0.2166],
          [-0.2796, -0.2213, -0.1616]]],


        [[[-0.0593, -0.1760, -0.1008],
          [-0.2283, -0.1946,  0.0400],
          [-0.1087, -0.1572,  0.1258]]],


        [[[ 0.1060,  0.2579, -0.0451],
          [ 0.0119, -0.3029, -0.3019],
          [ 0.2623,  0.2487, -0.3102]]]], requires_grad=True))]


In [6]:
# The binary pruning mask lives in the `weight_mask` buffer (0 marks a pruned entry).
print([*module.named_buffers()])

[('weight_mask', tensor([[[[0., 1., 1.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[0., 1., 1.],
          [0., 1., 1.],
          [0., 0., 1.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 0., 0.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [1., 1., 1.],
          [1., 0., 1.]]]]))]


In [7]:
# `weight` is still reachable as an attribute and equals weight_orig * weight_mask
# (hence grad_fn=<MulBackward0>), so pruned entries print as zero.
print(getattr(module, "weight"))

tensor([[[[ 0.0000,  0.0299, -0.2168],
          [ 0.0082, -0.0000, -0.2159],
          [ 0.1061,  0.0975,  0.2083]]],


        [[[-0.0437,  0.0198, -0.0922],
          [ 0.0000,  0.1690, -0.0000],
          [ 0.1769,  0.0000, -0.2383]]],


        [[[ 0.0000,  0.1592,  0.3118],
          [-0.0000,  0.1068,  0.1134],
          [ 0.0000, -0.0000,  0.2205]]],


        [[[ 0.0857,  0.1030,  0.0239],
          [-0.0000, -0.1736, -0.2166],
          [-0.2796, -0.2213, -0.1616]]],


        [[[-0.0593, -0.1760, -0.1008],
          [-0.2283, -0.0000,  0.0000],
          [-0.1087, -0.1572,  0.1258]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0119, -0.3029, -0.3019],
          [ 0.2623,  0.0000, -0.3102]]]], grad_fn=<MulBackward0>)


In [8]:
# The pruning method object is registered on the module as a forward pre-hook.
print(repr(module._forward_pre_hooks))


OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x123b77198>)])


In [9]:
# Report the percentage of weight entries that are exactly zero, per layer and
# over all five layers combined. Reading `.weight` (not `weight_orig`) means
# pruned entries are counted. The global figure is total zeros / total entries,
# so it is weighted by tensor size.
SPARSITY_LAYERS = ("conv1", "conv2", "fc1", "fc2", "fc3")
zeros_total = 0
numel_total = 0
for layer_name in SPARSITY_LAYERS:
    layer_weight = getattr(model, layer_name).weight
    layer_zeros = int(torch.sum(layer_weight == 0))
    layer_numel = layer_weight.nelement()
    zeros_total += layer_zeros
    numel_total += layer_numel
    print(
        "Sparsity in {}.weight: {:.2f}%".format(
            layer_name, 100. * float(layer_zeros) / float(layer_numel)
        )
    )
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(zeros_total) / float(numel_total)
    )
)

Sparsity in conv1.weight: 29.63%
Sparsity in conv2.weight: 0.00%
Sparsity in fc1.weight: 0.00%
Sparsity in fc2.weight: 0.00%
Sparsity in fc3.weight: 0.00%
Global sparsity: 0.03%


In [10]:
# Start from a fresh, unpruned network, so these results are not mixed with the
# random pruning applied to the previous model's conv1. Place it on the same
# device as the first model.
model = LeNet().to(device)

# (module, parameter name) pairs that share a single, global pruning budget.
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

# Prune 20% of all listed weights by L1 magnitude, ranked across the tensors
# together rather than per layer. Individual layers therefore end up with
# different sparsity levels while the global sparsity is 20%.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

In [11]:
# Report the percentage of weight entries that are exactly zero, per layer and
# over all five layers combined. Reading `.weight` (not `weight_orig`) means
# pruned entries are counted. The global figure is total zeros / total entries,
# so it is weighted by tensor size.
SPARSITY_LAYERS = ("conv1", "conv2", "fc1", "fc2", "fc3")
zeros_total = 0
numel_total = 0
for layer_name in SPARSITY_LAYERS:
    layer_weight = getattr(model, layer_name).weight
    layer_zeros = int(torch.sum(layer_weight == 0))
    layer_numel = layer_weight.nelement()
    zeros_total += layer_zeros
    numel_total += layer_numel
    print(
        "Sparsity in {}.weight: {:.2f}%".format(
            layer_name, 100. * float(layer_zeros) / float(layer_numel)
        )
    )
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(zeros_total) / float(numel_total)
    )
)

Sparsity in conv1.weight: 1.85%
Sparsity in conv2.weight: 6.60%
Sparsity in fc1.weight: 22.10%
Sparsity in fc2.weight: 12.14%
Sparsity in fc3.weight: 9.05%
Global sparsity: 20.00%
