Prepare the data:

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.nn import Parameter
from torch.nn.modules.module import Module

In [2]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(device)

cpu


In [3]:
# Per-image preprocessing: convert to a tensor, then normalise each of the
# three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split: downloaded into ./data if missing, reshuffled every epoch.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2,
)

# Test split: same preprocessing, fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2,
)

# Human-readable class names, indexed by label id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:

class PruningModule(Module):
    """Base class that adds magnitude-based pruning to a network.

    Both pruning methods only act on sub-modules registered under the names
    ``fc1``, ``fc2`` and ``fc3``; each of those must expose a
    ``prune(threshold)`` method.
    """

    def prune_by_percentile(self, q=5.0, **kwargs):
        """Prune with one global threshold taken from a percentile.

        Note:
            The threshold is the ``q``-th percentile of the absolute values of
            all non-zero parameters of the whole network concatenated
            (parameters whose name contains ``bias`` or ``mask`` are left out).
        Args:
            q (float): percentile in float
            **kwargs: accepted but not used by this method
        """
        # Collect the non-zero entries of every eligible parameter as flat arrays.
        surviving = [
            values[np.nonzero(values)]
            for values in (
                param.data.cpu().numpy()
                for param_name, param in self.named_parameters()
                # We do not prune bias terms, and masks are not weights.
                if not ('bias' in param_name or 'mask' in param_name)
            )
        ]
        cutoff = np.percentile(np.abs(np.concatenate(surviving)), q)
        print(f'Pruning with threshold : {cutoff}')

        # Apply the shared threshold to each prunable layer (fc1, fc2, fc3).
        for layer_name, layer in self.named_modules():
            if layer_name in ('fc1', 'fc2', 'fc3'):
                layer.prune(threshold=cutoff)

    def prune_by_std(self, s=0.25):
        """Prune each layer with a threshold of ``s`` times its weight std.

        ``s`` is the quality parameter / sensitivity value from Song Han's
        paper "Learning both Weights and Connections for Efficient Neural
        Networks": 'The pruning threshold is chosen as a quality parameter
        multiplied by the standard deviation of a layer’s weights'.

        Original author's note: several values were tried and, empirically,
        0.25 matches the paper's compression rate and number of parameters.
        In the paper the authors used different sensitivity values for
        different layers; here a single value is used for all of them.
        """
        for layer_name, layer in self.named_modules():
            if layer_name not in ('fc1', 'fc2', 'fc3'):
                continue
            layer_weights = layer.weight.data.cpu().numpy()
            threshold = np.std(layer_weights) * s
            print(f'Pruning with threshold : {threshold} for layer {layer_name}')
            layer.prune(threshold)


class MaskedLinear(Module):
    r"""Applies a masked linear transformation to the incoming data: :math:`y = (A * M)x + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            (out_features x in_features)
        bias:   the learnable bias of the module of shape (out_features)
        mask: the unlearnable mask for the weight.
            It has the same shape as weight (out_features x in_features)

    """
    # The docstring is a raw string because it contains ``\_``, which is an
    # invalid escape sequence in a normal string literal.

    def __init__(self, in_features, out_features, bias=True):
        super(MaskedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Uninitialised storage; reset_parameters() fills it in below.
        self.weight = Parameter(torch.empty(out_features, in_features))
        # Initialize the mask with 1 (nothing pruned). It is registered as a
        # Parameter with requires_grad=False, so it appears in parameters()
        # but never receives a gradient.
        self.mask = Parameter(torch.ones([out_features, in_features]), requires_grad=False)
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias uniformly from [-1/sqrt(in_features), 1/sqrt(in_features)]."""
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``input @ (weight * mask).T + bias``."""
        # Masking inside forward keeps pruned connections out of the output
        # regardless of the value currently stored in ``weight``.
        return F.linear(input, self.weight * self.mask, self.bias)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}('
            f'in_features={self.in_features}'
            f', out_features={self.out_features}'
            f', bias={self.bias is not None})'
        )

    def prune(self, threshold):
        """Zero every weight whose magnitude is below ``threshold``.

        Mask entries for those weights are set to 0; entries that were
        already 0 stay 0. Both ``weight.data`` and ``mask.data`` are replaced
        by new tensors on their original devices.

        Args:
            threshold: magnitude cut-off; weights with ``abs(w) < threshold``
                are pruned.
        """
        weight_dev = self.weight.device
        mask_dev = self.mask.device
        # Convert Tensors to numpy and calculate
        weight_np = self.weight.data.cpu().numpy()
        mask_np = self.mask.data.cpu().numpy()
        new_mask = np.where(np.abs(weight_np) < threshold, 0.0, mask_np)
        # Pin the mask dtype so that weight * mask in forward() cannot be
        # promoted by NumPy's scalar casting rules.
        new_mask = new_mask.astype(mask_np.dtype, copy=False)
        # Apply new weight and mask
        self.weight.data = torch.from_numpy(weight_np * new_mask).to(weight_dev)
        self.mask.data = torch.from_numpy(new_mask).to(mask_dev)


Define the network:

In [5]:
class Net(PruningModule):
    """Small convolutional classifier with prunable fully-connected layers.

    Two conv + max-pool stages feed three ``MaskedLinear`` layers named
    ``fc1``/``fc2``/``fc3`` (the names the pruning methods look for). The
    final layer emits 10 raw scores with no softmax.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 6 -> 16 channels with 5x5 kernels; the same
        # 2x2 / stride-2 pooling layer is reused after both convolutions.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head: 400 -> 120 -> 84 -> 10, every layer maskable.
        self.fc1 = MaskedLinear(16 * 5 * 5, 120)
        self.fc2 = MaskedLinear(120, 84)
        self.fc3 = MaskedLinear(84, 10)

    def forward(self, x):
        """Return the 10 class scores for a batch of images ``x``."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Flatten to (batch, 400); presumably expects 3x32x32 inputs so that
        # the feature map is 16x5x5 here — verify if the input size changes.
        flat = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Build the model and place its parameters on the selected device
# (Module.to() returns the module itself, so the two steps can be chained).
net = Net().to(device)

# Optional experiments, intentionally left disabled:
#
# Freeze the first layer
#for param in net.fc1.parameters():
#    param.requires_grad = False
#
# Initialize the first layer
#def weights_init(m):
#    if isinstance(m, nn.Linear):
#        m.weight.data.normal_(0, 0.01)
#
#net.apply(weights_init)
#
#net.prune_by_std()

# Show the layer structure.
print(net)

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): MaskedLinear(in_features=400, out_features=120, bias=True)
  (fc2): MaskedLinear(in_features=120, out_features=84, bias=True)
  (fc3): MaskedLinear(in_features=84, out_features=10, bias=True)
)


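`params[0]`–`params[3]` are the weights and biases of the two convolutional layers. `params[4]` holds the weights of the first fully-connected layer (`fc1`), `params[5]` is its pruning mask and `params[6]` its bias; `fc2` and `fc3` follow in the same weight/mask/bias order.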

In [6]:
# Materialise the parameter iterator so individual tensors can be indexed.
params = [tensor for tensor in net.parameters()]
print(len(params))
# Index 4 is the weight matrix of fc1 (after conv1/conv2 weights and biases).
print(params[4])

13
Parameter containing:
tensor([[-0.0289,  0.0409,  0.0139,  ...,  0.0412, -0.0190,  0.0167],
        [ 0.0203, -0.0047, -0.0322,  ...,  0.0252, -0.0394, -0.0326],
        [-0.0003, -0.0068, -0.0043,  ...,  0.0417, -0.0143,  0.0086],
        ...,
        [-0.0047,  0.0422,  0.0188,  ...,  0.0183,  0.0103, -0.0428],
        [ 0.0191, -0.0187,  0.0427,  ...,  0.0431, -0.0050,  0.0374],
        [-0.0170,  0.0247, -0.0037,  ...,  0.0441,  0.0049,  0.0115]],
       requires_grad=True)


In [7]:
import torch.optim as optim

# SGD with momentum over everything net.parameters() yields.
optimizer = optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
# Softmax is built into CrossEntropyLoss, so the network's last layer does
# not need to apply one.
criterion = nn.CrossEntropyLoss()

In [8]:
def train(epochs=200, log_every=200):
    """Train the global ``net`` on ``trainloader``.

    Uses the notebook-level ``net``, ``trainloader``, ``optimizer``,
    ``criterion`` and ``device``. Calling ``train()`` with no arguments
    behaves as before: 200 epochs, with the mean loss printed every 200
    mini-batches.

    Args:
        epochs (int): number of full passes over ``trainloader``.
        log_every (int): print the running mean loss after this many
            mini-batches; must be a positive integer.
    """
    for epoch in range(epochs):  # loop over the dataset multiple times

        running_loss = 0.0
        for i, batch in enumerate(trainloader):
            # each batch is indexable as [inputs, labels]
            inputs, labels = batch[0].to(device), batch[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics: mean loss over the last `log_every` mini-batches
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0

    print('Finished Training')

In [9]:
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))

In [10]:
# Train the dense (not yet pruned) network, then print its accuracy on the test loader.
train()
test()

[1,   200] loss: 2.305
[1,   400] loss: 2.303
[1,   600] loss: 2.302
[2,   200] loss: 2.300
[2,   400] loss: 2.297
[2,   600] loss: 2.292
[3,   200] loss: 2.236
[3,   400] loss: 2.185
[3,   600] loss: 2.137
[4,   200] loss: 2.059
[4,   400] loss: 1.992
[4,   600] loss: 1.910
[5,   200] loss: 1.777
[5,   400] loss: 1.734
[5,   600] loss: 1.695
[6,   200] loss: 1.638
[6,   400] loss: 1.619
[6,   600] loss: 1.602
[7,   200] loss: 1.575
[7,   400] loss: 1.538
[7,   600] loss: 1.522
[8,   200] loss: 1.493
[8,   400] loss: 1.486
[8,   600] loss: 1.488
[9,   200] loss: 1.438
[9,   400] loss: 1.440
[9,   600] loss: 1.436
[10,   200] loss: 1.418
[10,   400] loss: 1.399
[10,   600] loss: 1.386
[11,   200] loss: 1.377
[11,   400] loss: 1.360
[11,   600] loss: 1.350
[12,   200] loss: 1.340
[12,   400] loss: 1.331
[12,   600] loss: 1.304
[13,   200] loss: 1.298
[13,   400] loss: 1.294
[13,   600] loss: 1.284
[14,   200] loss: 1.260
[14,   400] loss: 1.252
[14,   600] loss: 1.254
[15,   200] loss: 1

[114,   600] loss: 0.236
[115,   200] loss: 0.192
[115,   400] loss: 0.194
[115,   600] loss: 0.212
[116,   200] loss: 0.157
[116,   400] loss: 0.185
[116,   600] loss: 0.188
[117,   200] loss: 0.167
[117,   400] loss: 0.196
[117,   600] loss: 0.212
[118,   200] loss: 0.170
[118,   400] loss: 0.166
[118,   600] loss: 0.202
[119,   200] loss: 0.158
[119,   400] loss: 0.183
[119,   600] loss: 0.202
[120,   200] loss: 0.168
[120,   400] loss: 0.182
[120,   600] loss: 0.201
[121,   200] loss: 0.158
[121,   400] loss: 0.172
[121,   600] loss: 0.199
[122,   200] loss: 0.164
[122,   400] loss: 0.169
[122,   600] loss: 0.190
[123,   200] loss: 0.165
[123,   400] loss: 0.164
[123,   600] loss: 0.185
[124,   200] loss: 0.148
[124,   400] loss: 0.188
[124,   600] loss: 0.207
[125,   200] loss: 0.172
[125,   400] loss: 0.193
[125,   600] loss: 0.207
[126,   200] loss: 0.162
[126,   400] loss: 0.162
[126,   600] loss: 0.168
[127,   200] loss: 0.138
[127,   400] loss: 0.150
[127,   600] loss: 0.182


In [11]:
# Dump every parameter, prune the fully-connected layers by standard
# deviation, then re-evaluate the pruned network.
print('parameters before pruning:')
# The model instance in this notebook is `net`; the name `model` is never
# defined and raised NameError on a fresh kernel.
for parameter in net.parameters():
    print(parameter)
net.prune_by_std()
print('parameters after pruning:')
#print(params[4])
test()

Pruning with threshold : 0.025127505883574486 for layer fc1
Pruning with threshold : 0.03324936330318451 for layer fc2
Pruning with threshold : 0.0784255862236023 for layer fc3
Accuracy of the network on the 10000 test images: 61 %


In [12]:
# Retrain after pruning. MaskedLinear.forward multiplies the weights by the
# mask, so pruned connections contribute nothing to the output during this run.
train()
#print(params[4])

[1,   200] loss: 0.071
[1,   400] loss: 0.105
[1,   600] loss: 0.118
[2,   200] loss: 0.117
[2,   400] loss: 0.118
[2,   600] loss: 0.120
[3,   200] loss: 0.066
[3,   400] loss: 0.072
[3,   600] loss: 0.094
[4,   200] loss: 0.085
[4,   400] loss: 0.089
[4,   600] loss: 0.086
[5,   200] loss: 0.055
[5,   400] loss: 0.059
[5,   600] loss: 0.047
[6,   200] loss: 0.020
[6,   400] loss: 0.017
[6,   600] loss: 0.022
[7,   200] loss: 0.012
[7,   400] loss: 0.008
[7,   600] loss: 0.007
[8,   200] loss: 0.004
[8,   400] loss: 0.004
[8,   600] loss: 0.004
[9,   200] loss: 0.003
[9,   400] loss: 0.003
[9,   600] loss: 0.003
[10,   200] loss: 0.003
[10,   400] loss: 0.003
[10,   600] loss: 0.003
[11,   200] loss: 0.002
[11,   400] loss: 0.003
[11,   600] loss: 0.003
[12,   200] loss: 0.002
[12,   400] loss: 0.002
[12,   600] loss: 0.002
[13,   200] loss: 0.002
[13,   400] loss: 0.002
[13,   600] loss: 0.002
[14,   200] loss: 0.002
[14,   400] loss: 0.002
[14,   600] loss: 0.002
[15,   200] loss: 0

[114,   600] loss: 0.000
[115,   200] loss: 0.000
[115,   400] loss: 0.000
[115,   600] loss: 0.000
[116,   200] loss: 0.000
[116,   400] loss: 0.000
[116,   600] loss: 0.000
[117,   200] loss: 0.000
[117,   400] loss: 0.000
[117,   600] loss: 0.000
[118,   200] loss: 0.000
[118,   400] loss: 0.000
[118,   600] loss: 0.000
[119,   200] loss: 0.000
[119,   400] loss: 0.000
[119,   600] loss: 0.000
[120,   200] loss: 0.000
[120,   400] loss: 0.000
[120,   600] loss: 0.000
[121,   200] loss: 0.000
[121,   400] loss: 0.000
[121,   600] loss: 0.000
[122,   200] loss: 0.000
[122,   400] loss: 0.000
[122,   600] loss: 0.000
[123,   200] loss: 0.000
[123,   400] loss: 0.000
[123,   600] loss: 0.000
[124,   200] loss: 0.000
[124,   400] loss: 0.000
[124,   600] loss: 0.000
[125,   200] loss: 0.000
[125,   400] loss: 0.000
[125,   600] loss: 0.000
[126,   200] loss: 0.000
[126,   400] loss: 0.000
[126,   600] loss: 0.000
[127,   200] loss: 0.000
[127,   400] loss: 0.000
[127,   600] loss: 0.000


In [13]:
# Print the test accuracy of the pruned network after retraining.
test()

Accuracy of the network on the 10000 test images: 62 %
