[参考サイト: PyTorch Pruning Tutorial](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html)

In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class LeNet(nn.Module):
    """Small LeNet-style CNN used as the pruning target in this notebook.

    The ``16 * 5 * 5`` input size of ``fc1`` implies 1-channel 32x32 images:
    32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5. ``forward`` returns
    10 outputs per sample (no final activation).
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 feature maps of 5x5
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten every dimension except the batch one. This never divides by
        # the batch size (so an empty batch works) and, unlike ``view``, it
        # also accepts non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Build the network and move its parameters to the selected device.
model = LeNet().to(device)

In [3]:
# Inspect the parameters of conv1 before any pruning is applied.
module = model.conv1
print([*module.named_parameters()])

[('weight', Parameter containing:
tensor([[[[ 0.1935,  0.0434,  0.1336, -0.1980, -0.0981],
          [-0.0811, -0.1568,  0.0636,  0.1082,  0.0504],
          [ 0.0717,  0.1924,  0.0303, -0.0222, -0.1819],
          [ 0.1877,  0.1397, -0.0855, -0.1898, -0.0787],
          [-0.1723, -0.0803, -0.1679, -0.1421,  0.0557]]],


        [[[-0.0643,  0.1876, -0.1257, -0.0872,  0.1297],
          [ 0.1025, -0.0335,  0.1541, -0.1317, -0.1372],
          [ 0.0841, -0.1767, -0.0190,  0.0504,  0.0929],
          [ 0.1461, -0.0231,  0.0979,  0.1938, -0.1650],
          [ 0.1398, -0.1144,  0.0556, -0.0178,  0.1462]]],


        [[[-0.1531,  0.1097,  0.1085,  0.0709,  0.1582],
          [ 0.0869,  0.0509,  0.1763,  0.0884, -0.0649],
          [-0.1121,  0.1561,  0.0061, -0.0467, -0.0112],
          [-0.0755,  0.0282, -0.0176,  0.0711, -0.1489],
          [ 0.0183,  0.0159, -0.1254,  0.0253,  0.0347]]],


        [[[ 0.1984, -0.1126,  0.1715, -0.1774, -0.0907],
          [ 0.1051,  0.1508, -0.1042,  0.0

In [4]:
# Before pruning, the module has no registered buffers.
print([*module.named_buffers()])

[]


In [5]:
# Randomly prune 30% of the entries of conv1.weight. Kept as the cell's last
# expression so the notebook echoes the (pruned) module.
prune.random_unstructured(module, "weight", amount=0.3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [6]:
# After pruning, the original tensor is kept as the ``weight_orig`` parameter.
print([*module.named_parameters()])

[('bias', Parameter containing:
tensor([-0.0909,  0.1756, -0.0428,  0.0359, -0.0734, -0.0719],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1935,  0.0434,  0.1336, -0.1980, -0.0981],
          [-0.0811, -0.1568,  0.0636,  0.1082,  0.0504],
          [ 0.0717,  0.1924,  0.0303, -0.0222, -0.1819],
          [ 0.1877,  0.1397, -0.0855, -0.1898, -0.0787],
          [-0.1723, -0.0803, -0.1679, -0.1421,  0.0557]]],


        [[[-0.0643,  0.1876, -0.1257, -0.0872,  0.1297],
          [ 0.1025, -0.0335,  0.1541, -0.1317, -0.1372],
          [ 0.0841, -0.1767, -0.0190,  0.0504,  0.0929],
          [ 0.1461, -0.0231,  0.0979,  0.1938, -0.1650],
          [ 0.1398, -0.1144,  0.0556, -0.0178,  0.1462]]],


        [[[-0.1531,  0.1097,  0.1085,  0.0709,  0.1582],
          [ 0.0869,  0.0509,  0.1763,  0.0884, -0.0649],
          [-0.1121,  0.1561,  0.0061, -0.0467, -0.0112],
          [-0.0755,  0.0282, -0.0176,  0.0711, -0.1489],
          [ 0.0183,  0.0159, -0.

In [7]:
# The binary pruning mask is registered as the ``weight_mask`` buffer.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[0., 1., 1., 0., 1.],
          [1., 1., 0., 0., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 0., 1., 1., 0.],
          [0., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[0., 1., 0., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 0., 1., 0.],
          [0., 1., 1., 1., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[0., 0., 0., 1., 0.],
          [0., 1., 1., 0., 1.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 1., 1.],
          [0., 0., 1., 1., 1.]]],


        [[[0., 1., 1., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [0., 0., 1., 0., 1.],
          [1., 0., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]]]))

In [8]:
# ``weight`` now holds the masked values: entries whose ``weight_mask`` value
# is 0 print as (signed) zero, the rest match ``weight_orig``.
print(module.weight)

tensor([[[[ 0.0000,  0.0434,  0.1336, -0.0000, -0.0981],
          [-0.0811, -0.1568,  0.0000,  0.0000,  0.0000],
          [ 0.0717,  0.1924,  0.0000, -0.0222, -0.1819],
          [ 0.1877,  0.1397, -0.0855, -0.1898, -0.0787],
          [-0.1723, -0.0803, -0.0000, -0.1421,  0.0557]]],


        [[[-0.0643,  0.1876, -0.1257, -0.0872,  0.1297],
          [ 0.0000, -0.0000,  0.1541, -0.1317, -0.0000],
          [ 0.0000, -0.1767, -0.0000,  0.0504,  0.0929],
          [ 0.1461, -0.0231,  0.0979,  0.1938, -0.1650],
          [ 0.1398, -0.1144,  0.0556, -0.0000,  0.1462]]],


        [[[-0.0000,  0.1097,  0.0000,  0.0709,  0.1582],
          [ 0.0869,  0.0509,  0.1763,  0.0884, -0.0000],
          [-0.1121,  0.1561,  0.0000, -0.0467, -0.0000],
          [-0.0000,  0.0282, -0.0176,  0.0711, -0.0000],
          [ 0.0183,  0.0000, -0.1254,  0.0253,  0.0347]]],


        [[[ 0.0000, -0.0000,  0.0000, -0.1774, -0.0000],
          [ 0.0000,  0.1508, -0.1042,  0.0000,  0.0501],
          [-0.0171,

In [9]:
# The pruning method object is registered among the module's forward pre-hooks.
print(module._forward_pre_hooks)

OrderedDict({0: <torch.nn.utils.prune.RandomUnstructured object at 0x168b987a0>})


In [10]:
# Prune the 3 bias entries with the smallest absolute value (an integer
# ``amount`` is an absolute count, not a fraction). Left as the trailing
# expression so the notebook echoes the module.
prune.l1_unstructured(module, "bias", amount=3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [11]:
# Parameters after pruning the bias as well.
print([*module.named_parameters()])

[('weight_orig', Parameter containing:
tensor([[[[ 0.1935,  0.0434,  0.1336, -0.1980, -0.0981],
          [-0.0811, -0.1568,  0.0636,  0.1082,  0.0504],
          [ 0.0717,  0.1924,  0.0303, -0.0222, -0.1819],
          [ 0.1877,  0.1397, -0.0855, -0.1898, -0.0787],
          [-0.1723, -0.0803, -0.1679, -0.1421,  0.0557]]],


        [[[-0.0643,  0.1876, -0.1257, -0.0872,  0.1297],
          [ 0.1025, -0.0335,  0.1541, -0.1317, -0.1372],
          [ 0.0841, -0.1767, -0.0190,  0.0504,  0.0929],
          [ 0.1461, -0.0231,  0.0979,  0.1938, -0.1650],
          [ 0.1398, -0.1144,  0.0556, -0.0178,  0.1462]]],


        [[[-0.1531,  0.1097,  0.1085,  0.0709,  0.1582],
          [ 0.0869,  0.0509,  0.1763,  0.0884, -0.0649],
          [-0.1121,  0.1561,  0.0061, -0.0467, -0.0112],
          [-0.0755,  0.0282, -0.0176,  0.0711, -0.1489],
          [ 0.0183,  0.0159, -0.1254,  0.0253,  0.0347]]],


        [[[ 0.1984, -0.1126,  0.1715, -0.1774, -0.0907],
          [ 0.1051,  0.1508, -0.1042,

In [12]:
# Buffers after pruning the bias as well; weight_mask is unaffected by it.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[0., 1., 1., 0., 1.],
          [1., 1., 0., 0., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 0., 1., 1., 0.],
          [0., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[0., 1., 0., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 0., 1., 0.],
          [0., 1., 1., 1., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[0., 0., 0., 1., 0.],
          [0., 1., 1., 0., 1.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 1., 1.],
          [0., 0., 1., 1., 1.]]],


        [[[0., 1., 1., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [0., 0., 1., 0., 1.],
          [1., 0., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]]]))

In [13]:
# ``bias`` is masked too: its 3 smallest-magnitude entries are now zero, and
# it is a computed tensor (note the grad_fn) rather than a Parameter.
print(module.bias)

tensor([-0.0909,  0.1756, -0.0000,  0.0000, -0.0734, -0.0000],
       grad_fn=<MulBackward0>)


In [14]:
# One hook per pruned tensor: RandomUnstructured for weight, L1Unstructured for bias.
print(module._forward_pre_hooks)

OrderedDict({0: <torch.nn.utils.prune.RandomUnstructured object at 0x168b987a0>, 1: <torch.nn.utils.prune.L1Unstructured object at 0x16974fce0>})


In [15]:
# Structured pruning: remove 50% of conv1's output channels (dim=0), ranked
# by their L2 norm (n=2). The new mask is combined with the earlier random
# mask, so previously pruned entries stay pruned while 3 of the 6 channels
# are zeroed entirely.
prune.ln_structured(module, "weight", amount=0.5, n=2, dim=0)
print(module.weight)

tensor([[[[ 0.0000,  0.0434,  0.1336, -0.0000, -0.0981],
          [-0.0811, -0.1568,  0.0000,  0.0000,  0.0000],
          [ 0.0717,  0.1924,  0.0000, -0.0222, -0.1819],
          [ 0.1877,  0.1397, -0.0855, -0.1898, -0.0787],
          [-0.1723, -0.0803, -0.0000, -0.1421,  0.0557]]],


        [[[-0.0643,  0.1876, -0.1257, -0.0872,  0.1297],
          [ 0.0000, -0.0000,  0.1541, -0.1317, -0.0000],
          [ 0.0000, -0.1767, -0.0000,  0.0504,  0.0929],
          [ 0.1461, -0.0231,  0.0979,  0.1938, -0.1650],
          [ 0.1398, -0.1144,  0.0556, -0.0000,  0.1462]]],


        [[[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000,

In [16]:
# Find the forward pre-hook that manages pruning of ``weight``; after two
# pruning calls on that tensor, iterating it yields every pruning method
# applied so far. ``next`` with a default replaces the bare for/break loop,
# which on a miss silently left ``hook`` bound to the last (unrelated) hook,
# or unbound when there were no hooks. ``getattr`` skips hooks that are not
# pruning methods and therefore have no ``_tensor_name``.
hook = next(
    (
        h
        for h in module._forward_pre_hooks.values()
        if getattr(h, "_tensor_name", None) == "weight"
    ),
    None,
)
if hook is None:
    raise LookupError("no pruning hook found for 'weight'")

print(list(hook))  # pruning history in the container

[<torch.nn.utils.prune.RandomUnstructured object at 0x168b987a0>, <torch.nn.utils.prune.LnStructured object at 0x168c8a420>]


In [17]:
# Pruned tensors appear in the state dict as ``*_orig`` parameters plus
# ``*_mask`` buffers instead of their plain names; unpruned layers are unchanged.
print(model.state_dict().keys())

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


In [18]:
# weight_orig still holds the unpruned values after the structured step.
print([*module.named_parameters()])

[('weight_orig', Parameter containing:
tensor([[[[ 0.1935,  0.0434,  0.1336, -0.1980, -0.0981],
          [-0.0811, -0.1568,  0.0636,  0.1082,  0.0504],
          [ 0.0717,  0.1924,  0.0303, -0.0222, -0.1819],
          [ 0.1877,  0.1397, -0.0855, -0.1898, -0.0787],
          [-0.1723, -0.0803, -0.1679, -0.1421,  0.0557]]],


        [[[-0.0643,  0.1876, -0.1257, -0.0872,  0.1297],
          [ 0.1025, -0.0335,  0.1541, -0.1317, -0.1372],
          [ 0.0841, -0.1767, -0.0190,  0.0504,  0.0929],
          [ 0.1461, -0.0231,  0.0979,  0.1938, -0.1650],
          [ 0.1398, -0.1144,  0.0556, -0.0178,  0.1462]]],


        [[[-0.1531,  0.1097,  0.1085,  0.0709,  0.1582],
          [ 0.0869,  0.0509,  0.1763,  0.0884, -0.0649],
          [-0.1121,  0.1561,  0.0061, -0.0467, -0.0112],
          [-0.0755,  0.0282, -0.0176,  0.0711, -0.1489],
          [ 0.0183,  0.0159, -0.1254,  0.0253,  0.0347]]],


        [[[ 0.1984, -0.1126,  0.1715, -0.1774, -0.0907],
          [ 0.1051,  0.1508, -0.1042,

In [19]:
# Combined mask: three channels are now fully zeroed on top of the random mask.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[0., 1., 1., 0., 1.],
          [1., 1., 0., 0., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 0., 1., 1., 0.],
          [0., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [0., 0., 1., 0., 1.],
          [1., 0., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]]]))

In [20]:
# Masked weight after the structured step: the channels removed by
# ln_structured are entirely zero.
print(module.weight)

tensor([[[[ 0.0000,  0.0434,  0.1336, -0.0000, -0.0981],
          [-0.0811, -0.1568,  0.0000,  0.0000,  0.0000],
          [ 0.0717,  0.1924,  0.0000, -0.0222, -0.1819],
          [ 0.1877,  0.1397, -0.0855, -0.1898, -0.0787],
          [-0.1723, -0.0803, -0.0000, -0.1421,  0.0557]]],


        [[[-0.0643,  0.1876, -0.1257, -0.0872,  0.1297],
          [ 0.0000, -0.0000,  0.1541, -0.1317, -0.0000],
          [ 0.0000, -0.1767, -0.0000,  0.0504,  0.0929],
          [ 0.1461, -0.0231,  0.0979,  0.1938, -0.1650],
          [ 0.1398, -0.1144,  0.0556, -0.0000,  0.1462]]],


        [[[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000,

In [21]:
# Make the pruning of ``weight`` permanent: weight_orig and weight_mask are
# dropped and ``weight`` becomes an ordinary parameter holding the masked values.
prune.remove(module, 'weight')
print([*module.named_parameters()])

[('bias_orig', Parameter containing:
tensor([-0.0909,  0.1756, -0.0428,  0.0359, -0.0734, -0.0719],
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000,  0.0434,  0.1336, -0.0000, -0.0981],
          [-0.0811, -0.1568,  0.0000,  0.0000,  0.0000],
          [ 0.0717,  0.1924,  0.0000, -0.0222, -0.1819],
          [ 0.1877,  0.1397, -0.0855, -0.1898, -0.0787],
          [-0.1723, -0.0803, -0.0000, -0.1421,  0.0557]]],


        [[[-0.0643,  0.1876, -0.1257, -0.0872,  0.1297],
          [ 0.0000, -0.0000,  0.1541, -0.1317, -0.0000],
          [ 0.0000, -0.1767, -0.0000,  0.0504,  0.0929],
          [ 0.1461, -0.0231,  0.0979,  0.1938, -0.1650],
          [ 0.1398, -0.1144,  0.0556, -0.0000,  0.1462]]],


        [[[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.

In [22]:
# Only bias_mask remains; weight_mask was removed together with weight_orig.
print([*module.named_buffers()])

[('bias_mask', tensor([1., 1., 0., 0., 1., 0.]))]


In [23]:
new_model = LeNet()
# Use a distinct loop name so the notebook-level ``module`` (conv1 of
# ``model``) is not silently rebound to the last layer of ``new_model``;
# the submodule names were never used, so iterate ``modules()`` directly.
for submodule in new_model.modules():
    if isinstance(submodule, nn.Conv2d):
        # prune 20% of connections in all 2D-conv layers
        prune.l1_unstructured(submodule, name='weight', amount=0.2)
    elif isinstance(submodule, nn.Linear):
        # prune 40% of connections in all linear layers
        prune.l1_unstructured(submodule, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])


In [24]:
model = LeNet()

# Prune the 20% of weights with the smallest L1 magnitude, ranked jointly
# across all listed layers rather than separately within each layer.
parameters_to_prune = tuple(
    (getattr(model, layer_name), 'weight')
    for layer_name in ('conv1', 'conv2', 'fc1', 'fc2', 'fc3')
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

In [25]:
def print_weight_sparsity(net, layer_names=("conv1", "conv2", "fc1", "fc2", "fc3")):
    """Print the percentage of exactly-zero ``weight`` entries per layer and overall.

    Args:
        net (nn.Module): model whose layers are looked up by attribute name.
        layer_names (tuple of str): attribute names of the layers to report,
            in print order.
    """
    total_zeros = 0
    total_elems = 0
    for layer_name in layer_names:
        weight = getattr(net, layer_name).weight
        zeros = int(torch.sum(weight == 0))
        elems = weight.nelement()
        print(
            "Sparsity in {}.weight: {:.2f}%".format(layer_name, 100. * zeros / elems)
        )
        total_zeros += zeros
        total_elems += elems
    # The global figure pools all listed layers, so each layer is weighted by
    # its size instead of averaging the per-layer percentages.
    print("Global sparsity: {:.2f}%".format(100. * total_zeros / total_elems))


print_weight_sparsity(model)

Sparsity in conv1.weight: 7.33%
Sparsity in conv2.weight: 14.38%
Sparsity in fc1.weight: 22.19%
Sparsity in fc2.weight: 11.98%
Sparsity in fc3.weight: 9.40%
Global sparsity: 20.00%


In [26]:
class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor.

    Entries are counted in row-major (flattened) order starting at index 0,
    so a 1-D tensor gets the mask ``[0, 1, 0, 1, ...]``. Entries already
    zeroed in ``default_mask`` stay zeroed.
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        # ``clone()`` keeps the source strides, so a non-contiguous
        # ``default_mask`` would make ``view(-1)`` raise. A contiguous copy
        # guarantees the flat view aliases ``mask`` and can be written through.
        mask = default_mask.clone(memory_format=torch.contiguous_format)
        mask.view(-1)[::2] = 0
        return mask

In [27]:
def foobar_unstructured(module, name):
    """Prune every other entry of the parameter ``name`` of ``module``, in place.

    Applies ``FooBarPruningMethod`` to the tensor, which

    * registers a buffer called ``name + '_mask'`` holding the binary mask, and
    * moves the original (unpruned) tensor to a new parameter called
      ``name + '_orig'``, while ``name`` itself exposes the pruned version.

    Args:
        module (nn.Module): module that owns the tensor to prune.
        name (str): name of the parameter inside ``module`` that pruning
            acts on.

    Returns:
        nn.Module: the same ``module`` object, now pruned.

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    target = module
    FooBarPruningMethod.apply(target, name)
    return target

In [28]:
# Apply the custom method to fc3.bias of a fresh network and show the mask:
# every other entry is zeroed, starting at index 0.
model = LeNet()
foobar_unstructured(model.fc3, 'bias')
print(model.fc3.bias_mask)

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
