Prepare the data:

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import torch.optim as optim
import torchvision.datasets as datasets
from torch.autograd import Variable
from torchsummary import summary

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
torch.cuda.is_available()

cpu


False

In [3]:
class CIFAR10RandomLabels(datasets.CIFAR10):
    """CIFAR10 dataset, with support for randomly corrupt labels.
    Params
    ------  
    corrupt_prob: float
    Default 0.0. The probability of a label being replaced with
    random label.
    num_classes: int
    Default 10. The number of classes in the dataset.
    """
    def __init__(self, corrupt_prob=0.0, num_classes=10, **kwargs):
        super(CIFAR10RandomLabels, self).__init__(**kwargs)
        self.n_classes = num_classes
        if corrupt_prob > 0:
            self.corrupt_labels(corrupt_prob)

    def corrupt_labels(self, corrupt_prob):
        labels = np.array(self.targets)
        np.random.seed(12345)
        mask = np.random.rand(len(labels)) <= corrupt_prob
        rnd_labels = np.random.choice(self.n_classes, mask.sum())
        labels[mask] = rnd_labels
        # we need to explicitly cast the labels from npy.int64 to
        # builtin int type, otherwise pytorch will fail...
        labels = [int(x) for x in labels]
        self.targets = torch.tensor(labels)

In [4]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainloader = torch.utils.data.DataLoader(
        CIFAR10RandomLabels(root='./data', train=True, download=True, 
                            transform=transform, corrupt_prob=0.1),
                            batch_size=128, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(
        CIFAR10RandomLabels(root='./data', train=False, download=True,
                            transform=transform, corrupt_prob=0.1),
                            batch_size=100, shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class MaskedLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super(MaskedLinear, self).__init__(in_features, out_features, bias)
        self.mask_flag = False
    
    def set_mask(self, mask):
        self.mask = Variable(mask, requires_grad=False)
        self.weight.data = self.weight.data*self.mask.data
        self.mask_flag = True
    
    def get_mask(self):
        print(self.mask_flag)
        return self.mask
    
    def prune(self, threshold):
        # generate mask
        for params in self.parameters():
            if len(params.data.size()) != 1: # Not bias
                mask = params.data.abs() > threshold
                self.set_mask(mask.float()) 
    
    def forward(self, x):
        if self.mask_flag == True:
            weight = self.weight*self.mask
            return F.linear(x, weight, self.bias)
        else:
            return F.linear(x, self.weight, self.bias)
        

        
class MaskedConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(MaskedConv2d, self).__init__(in_channels, out_channels, 
            kernel_size, stride, padding, dilation, groups, bias)
        self.mask_flag = False
    
    def set_mask(self, mask):
        self.mask = Variable(mask, requires_grad=False)
        self.weight.data = self.weight.data*self.mask.data
        self.mask_flag = True
    
    def get_mask(self):
        print(self.mask_flag)
        return self.mask
    
    def prune(self, threshold):
        # generate mask
        for params in self.parameters():
            if len(params.data.size()) != 1: # Not bias
                mask = params.data.abs() > threshold
                self.set_mask(mask.float()) 
    
    def forward(self, x):
        if self.mask_flag == True:
            weight = self.weight*self.mask
            return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
        else:
            return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

Define the network:

In [20]:
class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        conv2d = MaskedConv2d
        self.conv1 = conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        conv2d = MaskedConv2d
        self.conv1 = conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        linear = MaskedLinear
        conv2d = MaskedConv2d
        self.in_planes = 64

        self.conv1 = conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
    
    def weight_prune(self, pruning_perc):
        all_weights = []
        for params in self.parameters():
            if len(params.data.size()) != 1:
                all_weights += list(params.cpu().data.abs().numpy().flatten())
        threshold = np.percentile(np.array(all_weights), pruning_perc)
        print(f'Pruning with threshold : %.4f' % threshold)
        
        # Module here refers to layer
        prune_name = ['conv1','conv2','linear']
        for name, module in self.named_modules():
            if any(x in name for x in prune_name):
                module.prune(threshold)   

def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])


In [7]:
net = ResNet18()
net.to(device)
summary(net, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
      MaskedConv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
      MaskedConv2d-3           [-1, 64, 32, 32]          36,864
       BatchNorm2d-4           [-1, 64, 32, 32]             128
      MaskedConv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
        BasicBlock-7           [-1, 64, 32, 32]               0
      MaskedConv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
     MaskedConv2d-10           [-1, 64, 32, 32]          36,864
      BatchNorm2d-11           [-1, 64, 32, 32]             128
       BasicBlock-12           [-1, 64, 32, 32]               0
     MaskedConv2d-13          [-1, 128, 16, 16]          73,728
      BatchNorm2d-14          [-1, 128,

params[4] are weights in the first layer, params[5] are the masks, etc

In [19]:
#params = list(net.parameters())
#for name, module in net.named_modules():
#    print(name)
#print(params[4])


conv1
bn1
layer1
layer1.0
layer1.0.conv1
layer1.0.bn1
layer1.0.conv2
layer1.0.bn2
layer1.0.shortcut
layer1.1
layer1.1.conv1
layer1.1.bn1
layer1.1.conv2
layer1.1.bn2
layer1.1.shortcut
layer2
layer2.0
layer2.0.conv1
layer2.0.bn1
layer2.0.conv2
layer2.0.bn2
layer2.0.shortcut
layer2.0.shortcut.0
layer2.0.shortcut.1
layer2.1
layer2.1.conv1
layer2.1.bn1
layer2.1.conv2
layer2.1.bn2
layer2.1.shortcut
layer3
layer3.0
layer3.0.conv1
layer3.0.bn1
layer3.0.conv2
layer3.0.bn2
layer3.0.shortcut
layer3.0.shortcut.0
layer3.0.shortcut.1
layer3.1
layer3.1.conv1
layer3.1.bn1
layer3.1.conv2
layer3.1.bn2
layer3.1.shortcut
layer4
layer4.0
layer4.0.conv1
layer4.0.bn1
layer4.0.conv2
layer4.0.bn2
layer4.0.shortcut
layer4.0.shortcut.0
layer4.0.shortcut.1
layer4.1
layer4.1.conv1
layer4.1.bn1
layer4.1.conv2
layer4.1.bn2
layer4.1.shortcut
linear


In [9]:
criterion = nn.CrossEntropyLoss() # Softmax is built in it so you do not need add that on the last layer
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
initial_optimizer_state_dict = optimizer.state_dict()

In [10]:
def train(model, epochs):
    for epoch in range(epochs):  # loop over the dataset multiple times
        
        train_correct = 0
        train_total = 0
        running_loss = 0.0
        
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)
    
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 200 == 199:    # print every 200 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0
                
            # training accuracy
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            
        print('Train Accuracy: %.3f %%' % (100 * train_correct / train_total))
        
    print('Finished Training')

In [11]:
def test(model):
    correct = 0
    total = 0
    test_loss = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            test_loss += criterion(outputs, labels).item()
    
    print('Accuracy of the network on the 10000 test images: %.2f %%' % (
        100 * correct / total))
    print('Test loss: %.4f ' % (test_loss / 100))

In [12]:
def calculate_size(model):
    non_zeros = 0
    for param in model.parameters():
        if param.requires_grad==True:
            non_zeros += param.nonzero().size(0)
    params_size = non_zeros* 4. / (1024 ** 2)
    print("Params size (MB): %0.3f" % params_size)

In [14]:
#train(net, 2)
torch.save(net.state_dict(), "models/ResNet.pt")
test(net)
calculate_size(net)

Accuracy of the network on the 10000 test images: 61.06 %
Test loss: 1.2373 
Params size (MB): 42.625


In [22]:
net_p = ResNet18()
net_p.load_state_dict(torch.load("models/ResNet.pt"))
net_p.to(device)
net_p.eval()
net_p.weight_prune(90)
test(net_p)
#summary(net, (3, 32, 32))
calculate_size(net_p)

Pruning with threshold : 0.0170
Accuracy of the network on the 10000 test images: 11.70 %
Test loss: 2.6761 
Params size (MB): 4.460


In [23]:
optimizer = optim.SGD(net_p.parameters(), lr=0.01, momentum=0.9)
train(net_p, 1)
test(net_p)

[1,   200] loss: 1.943
[1,   400] loss: 1.690
[1,   600] loss: 1.570
Train Accuracy: 41.732 %
Finished Training
Accuracy of the network on the 10000 test images: 48.20 %
Test loss: 1.5034 
