Prepare the data:

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.nn import Parameter
from torch.nn.modules.module import Module

In [2]:
# Convert PIL images to tensors, then shift/scale every RGB channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split: reshuffled on every pass, served in mini-batches of 50.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=50, shuffle=True, num_workers=2,
)

# Test split: same preprocessing, fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=50, shuffle=False, num_workers=2,
)

# Human-readable label names, indexed by class id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:

class PruningModule(Module):
    """Base class that adds magnitude-based pruning helpers to a network.

    The actual pruning is delegated to child layers: every submodule whose
    name (as reported by ``named_modules()``) is listed in ``PRUNABLE_LAYERS``
    must provide a ``prune(threshold)`` method.
    """

    # Names of the submodules that get pruned. Subclasses can override this
    # tuple instead of editing the pruning methods.
    PRUNABLE_LAYERS = ('fc1', 'fc2')

    def _prunable_modules(self):
        """Yield ``(name, module)`` for each submodule named in ``PRUNABLE_LAYERS``."""
        for name, module in self.named_modules():
            if name in self.PRUNABLE_LAYERS:
                yield name, module

    def prune_by_percentile(self, q=5.0, **kwargs):
        """Prune the prunable layers with one global magnitude threshold.

        Note:
            The threshold is the ``q``-th percentile of the absolute values of
            all non-zero parameters concatenated. Only parameters whose name
            contains ``bias`` or ``mask`` are excluded, so parameters of
            layers that are not themselves pruned still contribute.

        Args:
            q (float): percentile in the range [0, 100].
            **kwargs: accepted for backward compatibility and ignored.

        Raises:
            ValueError: if the module has no parameter eligible for the
                percentile computation.
        """
        # Collect the surviving (non-zero) entries of every eligible parameter.
        alive_parameters = []
        for name, p in self.named_parameters():
            # We do not prune bias terms, and masks are not weights.
            if 'bias' in name or 'mask' in name:
                continue
            tensor = p.detach().cpu().numpy()
            alive_parameters.append(tensor[np.nonzero(tensor)])  # flattened non-zero values

        if not alive_parameters:
            raise ValueError('prune_by_percentile found no weight parameters to rank')

        all_alives = np.concatenate(alive_parameters)
        percentile_value = np.percentile(np.abs(all_alives), q)
        print(f'Pruning with threshold : {percentile_value}')

        # Apply the shared threshold to each prunable layer.
        for _, module in self._prunable_modules():
            module.prune(threshold=percentile_value)

    def prune_by_std(self, s=0.25):
        """Prune each prunable layer with a threshold of ``s`` times its weight std.

        `s` is the quality parameter / sensitivity value from Song Han's paper
        (Learning both Weights and Connections for Efficient Neural Networks):
        'The pruning threshold is chosen as a quality parameter multiplied by
        the standard deviation of a layer's weights'.

        The default of 0.25 was chosen empirically by the original author to
        match the paper's compression rate. The paper uses different
        sensitivity values for different layers; here one value is shared.

        Args:
            s (float): multiplier applied to the standard deviation of each
                layer's full weight tensor (already-zeroed entries included).
        """
        for name, module in self._prunable_modules():
            threshold = np.std(module.weight.data.cpu().numpy()) * s
            print(f'Pruning with threshold : {threshold} for layer {name}')
            module.prune(threshold)


class MaskedLinear(Module):
    r"""Applies a masked linear transformation to the incoming data: :math:`y = (A * M)x + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            (out_features x in_features)
        bias:   the learnable bias of the module of shape (out_features)
        mask: the unlearnable mask for the weight.
            It has the same shape as weight (out_features x in_features)

    """
    def __init__(self, in_features, out_features, bias=True):
        super(MaskedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Storage only; the values are filled in by reset_parameters().
        self.weight = Parameter(torch.empty(out_features, in_features))
        # The mask starts as all ones (nothing pruned). It stays a Parameter
        # so it keeps its position in ``parameters()`` and in ``state_dict()``,
        # but it is excluded from gradient computation.
        self.mask = Parameter(torch.ones([out_features, in_features]), requires_grad=False)
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias uniformly from ``[-1/sqrt(in_features), 1/sqrt(in_features)]``."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)
            if self.bias is not None:
                self.bias.uniform_(-stdv, stdv)

    def forward(self, input):
        """Return ``input @ (weight * mask).T + bias``.

        Multiplying by the mask here also zeroes the gradient that reaches
        masked weight entries.
        """
        return F.linear(input, self.weight * self.mask, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in_features=' + str(self.in_features) \
            + ', out_features=' + str(self.out_features) \
            + ', bias=' + str(self.bias is not None) + ')'

    def prune(self, threshold):
        """Zero every weight whose magnitude is strictly below ``threshold``.

        The mask is updated cumulatively: entries masked by an earlier call
        stay masked. Weight and mask are modified in place, keeping their
        dtype, device and storage.

        Args:
            threshold: magnitude cut-off; any value accepted by ``float()``
                (Python number, NumPy scalar, 0-d tensor).
        """
        threshold = float(threshold)
        with torch.no_grad():
            below = self.weight.abs() < threshold
            # The .to() calls are no-ops when weight and mask already share a
            # device and dtype.
            self.mask.masked_fill_(below.to(self.mask.device), 0.0)
            self.weight.mul_(self.mask.to(device=self.weight.device, dtype=self.weight.dtype))


Define the network:

In [4]:
class Net(PruningModule):
    """Two conv/pool stages followed by two masked fully connected layers.

    The flatten step expects a 16 x 5 x 5 feature map after the second
    pooling stage, which is what 3 x 32 x 32 inputs produce with these
    kernel sizes.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (one pooling layer shared by both stages).
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Prunable classifier head.
        self.fc1 = MaskedLinear(16 * 5 * 5, 120, bias=True)
        self.fc2 = MaskedLinear(120, 10)

    def forward(self, x):
        """Return the 10 raw class scores (no softmax) for a batch of images."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        flat = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Build the network: randomly initialised weights, all-ones pruning masks.
net = Net()

# --- Disabled experiments below (kept for reference; none of this runs) ---

# Freeze the first fully connected layer
#for param in net.fc1.parameters():
#    param.requires_grad = False

# Initialize the first layer
# NOTE(review): MaskedLinear derives from Module, not nn.Linear, so as written
# this isinstance check would skip fc1 and fc2.
#def weights_init(m):
#    if isinstance(m, nn.Linear):
#        m.weight.data.normal_(0, 0.01)
    
#net.apply(weights_init)

# Prune before any training
#net.prune_by_std()

# Show the layer structure.
print(net)

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): MaskedLinear(in_features=400, out_features=120, bias=True)
  (fc2): MaskedLinear(in_features=120, out_features=10, bias=True)
)


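`params[4]` holds the weights of the first fully connected layer (`fc1`) and `params[5]` holds its pruning mask; they are followed by `fc1`'s bias and then the same weight/mask/bias triple for `fc2`.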

In [5]:
# Materialise the parameter generator so entries can be addressed by position.
params = [*net.parameters()]
print(len(params))
# Position 4 is fc1.weight: it follows the weight/bias pairs of conv1 and conv2.
print(params[4])

10
Parameter containing:
tensor([[-0.0027, -0.0318, -0.0415,  ..., -0.0132,  0.0069, -0.0323],
        [ 0.0305,  0.0495, -0.0110,  ..., -0.0285,  0.0276,  0.0133],
        [ 0.0079,  0.0090,  0.0116,  ..., -0.0286,  0.0186, -0.0427],
        ...,
        [-0.0296,  0.0463,  0.0181,  ..., -0.0282,  0.0156, -0.0391],
        [ 0.0451, -0.0375,  0.0385,  ..., -0.0035,  0.0463,  0.0251],
        [ 0.0463,  0.0081,  0.0022,  ..., -0.0094,  0.0477,  0.0154]],
       requires_grad=True)


In [6]:
import torch.optim as optim

# Softmax is built into CrossEntropyLoss, so the network's last layer emits raw scores.
criterion = nn.CrossEntropyLoss()
# SGD with momentum over everything net.parameters() yields.
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [7]:
def train(num_epochs=150, log_every=1000, model=None, loader=None, loss_fn=None, opt=None):
    """Run the training loop and print the running average loss.

    Called with no arguments this behaves like the original notebook cell:
    150 epochs over ``trainloader``, updating ``net`` with ``optimizer`` on
    ``criterion``, and logging every 1000 mini-batches.

    Args:
        num_epochs (int): number of full passes over ``loader``.
        log_every (int): print the mean loss of the last ``log_every``
            mini-batches each time that many have been processed. The
            running sum restarts at the beginning of every epoch.
        model: network to train; defaults to the notebook-level ``net``.
        loader: iterable of ``(inputs, labels)`` batches; defaults to the
            notebook-level ``trainloader``.
        loss_fn: callable ``(outputs, labels) -> loss``; defaults to the
            notebook-level ``criterion``.
        opt: optimizer bound to ``model``'s parameters; defaults to the
            notebook-level ``optimizer``.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError('log_every must be a positive integer')

    # Fall back to the notebook globals so a bare ``train()`` keeps working.
    model = net if model is None else model
    loader = trainloader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt

    # Make sure mode-dependent layers (if the model has any) are in training mode.
    model.train()

    for epoch in range(num_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(loader):
            # zero the parameter gradients
            opt.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            opt.step()

            # print statistics
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0

    print('Finished Training')

In [8]:
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))

In [9]:
# Initial training pass with all masks still set to one (nothing pruned yet),
# followed by a test-set accuracy check.
train()
test()

[1,  1000] loss: 2.175
[2,  1000] loss: 1.767
[3,  1000] loss: 1.590
[4,  1000] loss: 1.498
[5,  1000] loss: 1.428
[6,  1000] loss: 1.361
[7,  1000] loss: 1.303
[8,  1000] loss: 1.253
[9,  1000] loss: 1.208
[10,  1000] loss: 1.171
[11,  1000] loss: 1.138
[12,  1000] loss: 1.109
[13,  1000] loss: 1.082
[14,  1000] loss: 1.053
[15,  1000] loss: 1.034
[16,  1000] loss: 1.010
[17,  1000] loss: 0.989
[18,  1000] loss: 0.969
[19,  1000] loss: 0.949
[20,  1000] loss: 0.932
[21,  1000] loss: 0.915
[22,  1000] loss: 0.895
[23,  1000] loss: 0.880
[24,  1000] loss: 0.863
[25,  1000] loss: 0.848
[26,  1000] loss: 0.835
[27,  1000] loss: 0.818
[28,  1000] loss: 0.804
[29,  1000] loss: 0.791
[30,  1000] loss: 0.778
[31,  1000] loss: 0.767
[32,  1000] loss: 0.749
[33,  1000] loss: 0.739
[34,  1000] loss: 0.726
[35,  1000] loss: 0.714
[36,  1000] loss: 0.708
[37,  1000] loss: 0.694
[38,  1000] loss: 0.680
[39,  1000] loss: 0.672
[40,  1000] loss: 0.658
[41,  1000] loss: 0.648
[42,  1000] loss: 0.636
[

In [10]:
# Prune fc1 and fc2 with the default sensitivity (s=0.25), then measure
# accuracy immediately, before any retraining.
net.prune_by_std()
#print(params[4])
test()

Pruning with threshold : 0.02721398137509823 for layer fc1
Pruning with threshold : 0.07363494485616684 for layer fc2
Accuracy of the network on the 10000 test images: 59 %


In [11]:
# Retrain after pruning. MaskedLinear multiplies weights by the mask in its
# forward pass, so pruned connections contribute nothing to the output.
# This reuses the same `optimizer` object as the first training run.
train()
#print(params[4])

[1,  1000] loss: 0.163
[2,  1000] loss: 0.129
[3,  1000] loss: 0.096
[4,  1000] loss: 0.081
[5,  1000] loss: 0.068
[6,  1000] loss: 0.062
[7,  1000] loss: 0.056
[8,  1000] loss: 0.051
[9,  1000] loss: 0.042
[10,  1000] loss: 0.039
[11,  1000] loss: 0.039
[12,  1000] loss: 0.031
[13,  1000] loss: 0.024
[14,  1000] loss: 0.021
[15,  1000] loss: 0.017
[16,  1000] loss: 0.015
[17,  1000] loss: 0.013
[18,  1000] loss: 0.012
[19,  1000] loss: 0.011
[20,  1000] loss: 0.010
[21,  1000] loss: 0.009
[22,  1000] loss: 0.009
[23,  1000] loss: 0.008
[24,  1000] loss: 0.008
[25,  1000] loss: 0.007
[26,  1000] loss: 0.007
[27,  1000] loss: 0.007
[28,  1000] loss: 0.006
[29,  1000] loss: 0.006
[30,  1000] loss: 0.006
[31,  1000] loss: 0.006
[32,  1000] loss: 0.005
[33,  1000] loss: 0.005
[34,  1000] loss: 0.005
[35,  1000] loss: 0.005
[36,  1000] loss: 0.005
[37,  1000] loss: 0.005
[38,  1000] loss: 0.004
[39,  1000] loss: 0.004
[40,  1000] loss: 0.004
[41,  1000] loss: 0.004
[42,  1000] loss: 0.004
[

In [12]:
# Accuracy after pruning + retraining. In the recorded run the training loss
# had dropped to about 0.004 while test accuracy is 60 %, which suggests the
# retrained network has overfit the training set.
test()

Accuracy of the network on the 10000 test images: 60 %
