In [1]:
import torchvision.models as models
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from tqdm import tqdm
import math
import numpy as np

In [2]:
class Vgg11(nn.Module):
    def __init__(self,cfg, init_weights=True):
        super(Vgg11, self).__init__()
       
        self.cfg = cfg
        self.feature = self.make_layers(self.cfg, True)

        self.num_classes = 10
        self.classifier =  nn.Linear(self.cfg[-1], 10)


        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

In [3]:
oricfg =  [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512]
vgg11 = Vgg11(oricfg)

resume = False
if resume:
    checkpoint = torch.load('./vgg11.pth')
    model_dict = vgg11.state_dict()
    pretrained_dict = {k: v for k, v in checkpoint['net'].items() if k in model_dict  }
    model_dict.update(pretrained_dict)
    vgg11.load_state_dict(model_dict)
vgg11.cuda()

Vgg11(
  (feature): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  

In [4]:
batchsize = 64

In [5]:
train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data/', train=True, download=False,
                       transform=transforms.Compose([
                           transforms.Pad(4),
                           transforms.RandomCrop(32),
                           transforms.RandomHorizontalFlip(),
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                       ])),
        batch_size=batchsize, shuffle=True)

test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data/', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                       ])),
        batch_size=batchsize, shuffle=False)

In [None]:
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vgg11.parameters(), lr=1e-4)

s = 0.0001

def updateBN():
    for m in vgg11.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(s*torch.sign(m.weight.data))  # L1


def train(epoch):
    vgg11.train()
 #   prbar = tqdm(total=len(train_loader))
 #   prbar.set_description("training epoch"+str(epoch))
    for data, target in train_loader:
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output =  F.softmax(vgg11(data),dim=-1)
        loss = criterion(output, target)


        optimizer.zero_grad()
        loss.backward()
        updateBN()
        optimizer.step()
        
   #     prbar.update(1)
  #      prbar.set_postfix(loss=loss.item())
  #  prbar.close()
    
def test(model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            output = F.softmax(model(data),dim=-1)
            pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('Test :  Accuracy: {}/{} ({:.2f}%)\n'.format(
        correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
for epoch in range(0,100):
    train(epoch)
    test(vgg11)
    if epoch %10 == 1:
        torch.save({'net':vgg11.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch},'./vgg11.pth')

Test :  Accuracy: 5671/10000 (56.71%)

Test :  Accuracy: 6265/10000 (62.65%)

Test :  Accuracy: 6750/10000 (67.50%)

Test :  Accuracy: 7006/10000 (70.06%)



In [None]:
total = 0
for m in vgg11.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]
bn = torch.zeros(total)
index = 0
for m in vgg11.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size


In [None]:
percent = 0.5
y, i = torch.sort(bn)
thre_index = int(total * percent)
thre = y[thre_index]
print("'gamma' less than {} is thrown away".format(thre.item()))

In [None]:
pruned = 0
cfg = []
cfg_mask = []


for k, m in enumerate(vgg11.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.clone()
        mask = weight_copy.abs().gt(thre).float().cuda()
        pruned = pruned + mask.shape[0] - torch.sum(mask)
      #  m.weight.data.mul_(mask)
     #   m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

In [None]:
newvgg11 = Vgg11(cfg)
newvgg11.cuda()

In [None]:
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
for [m0, m1] in zip(vgg11.modules(), newvgg11.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        m1.weight.data = m0.weight.data[idx1].clone()
        m1.bias.data = m0.bias.data[idx1].clone()
        m1.running_mean = m0.running_mean[idx1].clone()
        m1.running_var = m0.running_var[idx1].clone()
       
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
        w = m0.weight.data[:, idx0, :, :].clone()
        w = w[idx1, :, :, :].clone()
        m1.weight.data = w.clone()
        
     #   m1.bias.data = m0.bias.data[idx1].clone() # In some conv layers there is no bias.
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        m1.weight.data = m0.weight.data[:, idx0].clone()

In [None]:
test(newvgg11)