In [None]:
!pip install git+https://github.com/PyTorchLightning/pytorch-lightning
import pytorch_lightning as pl

Collecting git+https://github.com/PyTorchLightning/pytorch-lightning
  Cloning https://github.com/PyTorchLightning/pytorch-lightning to /tmp/pip-req-build-h7apm12q
  Running command git clone -q https://github.com/PyTorchLightning/pytorch-lightning /tmp/pip-req-build-h7apm12q
  Running command git submodule update --init --recursive -q
  From https://github.com/PyTorchLightning/lightning-tutorials
   * branch            290fb466de1fcc2ac6025f74b56906592911e856 -> FETCH_HEAD
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting pyDeprecate<0.4.0,>=0.3.1
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Collecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.7.3-py3-none-any.whl (398 kB)
[K     |████████████████████████████████| 398 kB 5.3 MB/s 
[?25hCollecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2022.2.0-py3-none-any.whl (134 kB)
[K   

In [None]:
!pip install wandb -qqq
import wandb

[K     |████████████████████████████████| 1.7 MB 5.5 MB/s 
[K     |████████████████████████████████| 181 kB 47.5 MB/s 
[K     |████████████████████████████████| 144 kB 44.9 MB/s 
[K     |████████████████████████████████| 63 kB 1.5 MB/s 
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [None]:
'''
This program implements the ResNet architecture.
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable


# Architecture names understood by ResNet(); test() iterates over this list.
# The first six are CIFAR-style ResNetSmall models, the rest ResNetLarge models.
__all__ = ['resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202',
           'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']


def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    ''' Adapts an arbitrary callable into an ``nn.Module`` with no parameters. '''

    def __init__(self, lambd):
        '''
        Args:
            lambd (callable): Function applied to the input in ``forward``.
        '''
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        ''' Returns the wrapped callable applied to ``x``. '''
        fn = self.lambd
        return fn(x)


class _PadShortcut(nn.Module):
    '''
    Parameter-free "option A" shortcut used by BasicBlockSmall.

    Subsamples the input spatially by ``stride`` and zero-pads the channel
    dimension from ``in_planes`` up to ``planes``. It is a named module rather
    than a ``LambdaLayer`` around a lambda so that models containing it can be
    pickled.
    '''

    def __init__(self, in_planes, planes, stride):
        '''
        Args:
            in_planes (int): Number of input channels.
            planes (int): Number of output channels (must be >= ``in_planes``).
            stride (int): Spatial subsampling step.

        Raises:
            ValueError: If ``planes < in_planes`` (zero-padding cannot shrink channels).
        '''
        super(_PadShortcut, self).__init__()
        extra = planes - in_planes
        if extra < 0:
            raise ValueError("Option A shortcut cannot reduce channels (%d -> %d)." % (in_planes, planes))
        self.stride = stride
        # Split the extra channels between front and back; for the usual
        # doubling (planes == 2 * in_planes) this is planes // 4 on each side.
        self.pad_front = extra // 2
        self.pad_back = extra - extra // 2

    def forward(self, x):
        ''' Subsamples H and W by ``stride`` and zero-pads the channel dimension. '''
        x = x[:, :, ::self.stride, ::self.stride]
        return F.pad(x, (0, 0, 0, 0, self.pad_front, self.pad_back), "constant", 0)


class BasicBlockSmall(nn.Module):
    '''
    A basic block for small ResNet architectures.

    Two 3x3 conv + BN layers with a residual connection. When the output shape
    differs from the input (``stride != 1`` or ``in_planes != planes``) the
    shortcut is option 'A' (parameter-free subsample + zero channel padding)
    or option 'B' (1x1 projection conv + BN).
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride = 1, option = 'A'):
        '''
        Args:
            in_planes (int): Number of input channels.
            planes (int): Number of output channels.
            stride (int): Stride of the first convolution and of the shortcut.
            option (str): Shortcut type used when a shape change occurs, 'A' or 'B'.

        Raises:
            ValueError: If a shortcut is needed and ``option`` is neither 'A' nor 'B'.
        '''
        super(BasicBlockSmall, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size = 3, stride = stride, padding = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size = 3, stride = 1, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10 ResNet paper, uses option A.
                self.shortcut = _PadShortcut(in_planes, planes, stride)
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size = 1, stride = stride, bias = False),
                     nn.BatchNorm2d(self.expansion * planes)
                )
            else:
                # Previously an unknown option silently kept the identity
                # shortcut, which then failed with a shape mismatch in forward.
                raise ValueError("Unknown shortcut option: %r" % (option,))

    def forward(self, x):
        ''' conv-bn-relu, conv-bn, add shortcut, relu. '''
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNetSmall(nn.Module):
    '''
    Small ResNet architectures.

    A 3x3 stem with 16 channels, three stages of 16, 32 and 64 channels
    (the last two downsample by 2), global average pooling and a linear
    classifier. Weights are initialized with ``_weights_init``.
    '''

    def __init__(self, block, num_blocks, num_classes = 10):
        '''
        Args:
            block (type): Residual block class exposing ``expansion`` and accepting
                ``(in_planes, planes, stride)``.
            num_blocks (list[int]): Number of blocks in each of the three stages.
            num_classes (int): Number of output classes.
        '''
        super(ResNetSmall, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size = 3, stride = 1, padding = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride = 1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride = 2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride = 2)
        # layer3 emits 64 * expansion channels (expansion is 1 for BasicBlockSmall).
        self.linear = nn.Linear(64 * block.expansion, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        ''' Builds one stage; only its first block uses ``stride``. '''
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        ''' Returns class logits of shape (batch, num_classes). '''
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling. Pooling with a kernel equal to the width only
        # yields a 1x1 map for square inputs; adaptive pooling handles any H x W.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out



def conv3x3(in_planes, out_planes, stride = 1):
    '''
    Returns a bias-free 3x3 convolution with padding 1, so the spatial size is
    preserved when ``stride == 1``.
    '''
    layer = nn.Conv2d(
        in_channels = in_planes,
        out_channels = out_planes,
        kernel_size = 3,
        stride = stride,
        padding = 1,
        bias = False,
    )
    return layer


class BasicBlockLarge(nn.Module):
    ''' Basic residual block (two 3x3 convolutions) for the large ResNet architectures. '''
    expansion = 1

    def __init__(self, in_planes, planes, stride = 1):
        '''
        Args:
            in_planes (int): Number of input channels.
            planes (int): Number of output channels.
            stride (int): Stride of the first convolution and of the projection shortcut.
        '''
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a 1x1 projection followed by BN is used.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size = 1, stride = stride, bias = False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        ''' conv-bn-relu, conv-bn, add shortcut, relu. '''
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))


class PreActBlockLarge(nn.Module):
    ''' Pre-activation variant of BasicBlockLarge: BN and ReLU come before each convolution. '''
    expansion = 1

    def __init__(self, in_planes, planes, stride = 1):
        '''
        Args:
            in_planes (int): Number of input channels.
            planes (int): Number of output channels.
            stride (int): Stride of the first convolution and of the projection shortcut.
        '''
        super().__init__()
        out_planes = self.expansion * planes
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        needs_projection = stride != 1 or in_planes != out_planes
        self.shortcut = (
            nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size = 1, stride = stride, bias = False))
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        ''' bn-relu shared by both paths, then (conv, bn-relu, conv) + shortcut; no final ReLU. '''
        pre = F.relu(self.bn1(x))
        # The shortcut (identity or projection) receives the pre-activated input.
        skip = self.shortcut(pre)
        hidden = self.conv1(pre)
        hidden = self.conv2(F.relu(self.bn2(hidden)))
        return hidden + skip


class BottleneckLarge(nn.Module):
    ''' Bottleneck residual block (1x1 -> 3x3 -> 1x1) whose output has ``4 * planes`` channels. '''
    expansion = 4

    def __init__(self, in_planes, planes, stride = 1):
        '''
        Args:
            in_planes (int): Number of input channels.
            planes (int): Width of the inner 1x1/3x3 convolutions.
            stride (int): Stride of the 3x3 convolution and of the projection shortcut.
        '''
        super().__init__()
        wide = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size = 3, stride = stride, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, wide, kernel_size = 1, bias = False)
        self.bn3 = nn.BatchNorm2d(wide)

        if stride == 1 and in_planes == wide:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, wide, kernel_size = 1, stride = stride, bias = False),
                nn.BatchNorm2d(wide),
            )

    def forward(self, x):
        ''' Two conv-bn-relu stages, a conv-bn expansion, add shortcut, relu. '''
        hidden = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = F.relu(bn(conv(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        return F.relu(hidden + self.shortcut(x))


class PreActBottleneckLarge(nn.Module):
    ''' Pre-activation variant of BottleneckLarge: BN and ReLU come before each convolution. '''
    expansion = 4

    def __init__(self, in_planes, planes, stride = 1):
        '''
        Args:
            in_planes (int): Number of input channels.
            planes (int): Width of the inner 1x1/3x3 convolutions.
            stride (int): Stride of the 3x3 convolution and of the projection shortcut.
        '''
        super().__init__()
        wide = self.expansion * planes
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size = 3, stride = stride, padding = 1, bias = False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, wide, kernel_size = 1, bias = False)

        needs_projection = stride != 1 or in_planes != wide
        self.shortcut = (
            nn.Sequential(nn.Conv2d(in_planes, wide, kernel_size = 1, stride = stride, bias = False))
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        ''' bn-relu shared by both paths, then three (bn-relu, conv) steps + shortcut; no final ReLU. '''
        pre = F.relu(self.bn1(x))
        # The shortcut (identity or projection) receives the pre-activated input.
        skip = self.shortcut(pre)
        hidden = self.conv1(pre)
        for bn, conv in ((self.bn2, self.conv2), (self.bn3, self.conv3)):
            hidden = conv(F.relu(bn(hidden)))
        return hidden + skip


class ResNetLarge(nn.Module):
    '''
    Large ResNet architectures (ResNet-18/34/50/101/152 stage layout) with a
    3x3, stride-1 stem, four stages of 64/128/256/512 base channels, global
    average pooling and a linear classifier.
    '''

    def __init__(self, block, num_blocks, num_classes = 10):
        '''
        Args:
            block (type): Residual block class exposing ``expansion`` and accepting
                ``(in_planes, planes, stride)``.
            num_blocks (list[int]): Number of blocks in each of the four stages.
            num_classes (int): Number of output classes.
        '''
        super(ResNetLarge, self).__init__()
        self.in_planes = 64

        self.conv1 = conv3x3(3,64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride = 1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride = 2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride = 2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride = 2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        ''' Builds one stage; only its first block uses ``stride``. '''
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, lin = 0, lout = 5):
        '''
        Runs a contiguous slice of the network.

        Stage 0 is the stem (conv1, bn1, relu), stages 1-4 are ``layer1``-``layer4``
        and stage 5 is global average pooling + the linear classifier. For integer
        ``lin``/``lout``, stage k (0 <= k <= 4) runs when ``lin <= k <= lout``; the
        classifier stage runs whenever ``lout >= 5``. The defaults run everything.

        Args:
            x (Tensor): Input to the first stage that is run.
            lin (int): First stage to run.
            lout (int): Last stage to run.

        Returns:
            Tensor: Output of the last stage that was run.
        '''
        out = x
        if lin < 1 and lout > -1:
            out = self.conv1(out)
            out = self.bn1(out)
            out = F.relu(out)
        if lin < 2 and lout > 0:
            out = self.layer1(out)
        if lin < 3 and lout > 1:
            out = self.layer2(out)
        if lin < 4 and lout > 2:
            out = self.layer3(out)
        if lin < 5 and lout > 3:
            out = self.layer4(out)
        if lout > 4:
            # Global average pooling. A fixed 4x4 kernel only reduces the map to
            # 1x1 for 32x32 inputs; larger inputs (e.g. 96x96 STL-10) left a
            # bigger map whose flattened size no longer matched self.linear.
            out = F.adaptive_avg_pool2d(out, 1)
            out = out.view(out.size(0), -1)
            out = self.linear(out)
        return out



def ResNet(arch, num_classes = 10):
    '''
    Constructs a ResNet model by name.

    Args:
        arch (str): Architecture name, one of ``__all__``.
        num_classes (int): Number of output classes.

    Returns:
        nn.Module: ``ResNetSmall`` for resnet20/32/44/56/110/1202, otherwise
        ``ResNetLarge`` (BasicBlockLarge for resnet18/34, BottleneckLarge for
        resnet50/101/152).

    Raises:
        ValueError: If ``arch`` is not a known architecture. (``ValueError`` is an
            ``Exception`` subclass, so existing ``except Exception`` handlers still work.)
    '''
    # CIFAR-style depths are 6n + 2 with n blocks in each of the three stages.
    small_stages = {
        'resnet20': [3, 3, 3],
        'resnet32': [5, 5, 5],
        'resnet44': [7, 7, 7],
        'resnet56': [9, 9, 9],
        'resnet110': [18, 18, 18],
        'resnet1202': [200, 200, 200],
    }
    basic_stages = {
        'resnet18': [2, 2, 2, 2],
        'resnet34': [3, 4, 6, 3],
    }
    bottleneck_stages = {
        'resnet50': [3, 4, 6, 3],
        'resnet101': [3, 4, 23, 3],
        'resnet152': [3, 8, 36, 3],
    }

    if arch in small_stages:
        return ResNetSmall(BasicBlockSmall, small_stages[arch], num_classes)
    if arch in basic_stages:
        return ResNetLarge(BasicBlockLarge, basic_stages[arch], num_classes)
    if arch in bottleneck_stages:
        return ResNetLarge(BottleneckLarge, bottleneck_stages[arch], num_classes)
    raise ValueError("Invalid architecture: %r." % (arch,))


def test():
    '''
    Smoke-tests every architecture in ``__all__``.

    Builds each model, runs one random 32x32 RGB image through it in eval mode
    without autograd (``Variable`` is a deprecated no-op wrapper, and building a
    graph for resnet1202 only wastes memory), and prints the output size.
    '''
    with torch.no_grad():
        for arch in __all__:
            net = ResNet(arch)
            net.eval()
            y = net(torch.randn(1, 3, 32, 32))
            print(arch, y.size())
            print()
        


In [None]:


import torch
import numpy as np


class Mixup(object):
    ''' Mixup regularization, which adds convex combinations of training examples. '''

    def __init__(self):
        '''
        Initializes a Mixup object.

        ``self.device`` is kept for backward compatibility only; ``mixup_data``
        builds its permutation on the device of its input instead.
        '''
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def mixup_data(self, x, y, alpha = 1.0):
        '''
        Adds convex combinations of training examples.

        Args:
            x (Tensor): Inputs; the first dimension is the batch.
            y (Tensor): Target labels, indexed along the same first dimension.
            alpha (float): Beta-distribution parameter. ``lam ~ Beta(alpha, alpha)``
                when ``alpha > 0``; otherwise ``lam = 1`` (no mixing).

        Returns:
            tuple: ``(mixed_x, y_a, y_b, lam)`` where
            ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a = y`` and ``y_b = y[perm]``.
        '''
        lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1

        batch_size = x.size(0)
        # Create the permutation where x lives: the Trainer, not this object,
        # decides the device, and indexing with an index tensor on another
        # device raises an error.
        index = torch.randperm(batch_size, device = x.device)

        mixed_x = lam * x + (1 - lam) * x[index]
        y_a, y_b = y, y[index.to(y.device)]
        return mixed_x, y_a, y_b, lam

    def mixup_criterion(self, criterion, pred, y_a, y_b, lam):
        '''
        Loss function for mixup.

        Args:
            criterion (Loss): A loss function.
            pred (Tensor): Predicted outputs.
            y_a, y_b (Tensor): Mixup labels.
            lam (float): lambda coefficient.

        Returns:
            Tensor: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
        '''
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


class Cutout(object):
    ''' Cutout regularization: zeroes one or more square patches of an image. '''

    def __init__(self, n_holes, length):
        '''
        Initializes a Cutout object.

        Args:
            n_holes (int): The number of patches to cut out of each image.
            length (int): The nominal side length (in pixels) of each patch.
        '''
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        '''
        Zeroes ``n_holes`` randomly centred patches of ``img``.

        Each patch spans ``length // 2`` pixels on either side of its centre
        (so it is ``length - 1`` wide for odd ``length``) and is clipped at the
        image border. Centres are drawn from NumPy's global RNG, row first.

        Args:
            img (Tensor): A Tensor image of size (C, H, W).

        Returns:
            Tensor: ``img`` multiplied by the 0/1 patch mask, broadcast over channels.
        '''
        height, width = img.size(1), img.size(2)
        half = self.length // 2
        keep = np.ones((height, width), np.float32)

        for _ in range(self.n_holes):
            centre_row = np.random.randint(height)
            centre_col = np.random.randint(width)
            top, bottom = np.clip([centre_row - half, centre_row + half], 0, height)
            left, right = np.clip([centre_col - half, centre_col + half], 0, width)
            keep[top:bottom, left:right] = 0.

        return img * torch.from_numpy(keep).expand_as(img)


In [None]:


import torch
import numpy as np


class Mask(object):
    '''
    A Mask that performs soft filter pruning.

    Parameters are addressed by their position in ``model.parameters()``.
    ``init_rate`` decides, from ``args``, which positions may be pruned and the
    fraction of filters to keep at each. ``init_mask``/``do_mask`` then zero the
    lowest-L2-norm filters of those 4-D (convolution) weights.
    '''

    def __init__(self, model, args):
        '''
        Initializes the mask and moves ``model`` to CUDA when it is available.

        Args:
            model (nn.Module): Network to prune.
            args: Namespace providing ``layer_begin``, ``layer_end``,
                ``layer_inter``, ``arch`` and ``skip_downsample``.
        '''
        self.model_size = {}      # parameter index -> torch.Size
        self.model_length = {}    # parameter index -> number of elements
        self.compress_rate = {}   # parameter index -> fraction of filters to keep
        self.mat = {}             # parameter index -> flat 0/1 mask tensor
        self.mask_index = []      # parameter indices eligible for masking
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = model.to(self.device)
        self.args = args

    def get_codebook(self, weight_torch, compress_rate, length):
        '''
        Builds an element-wise 0/1 codebook keeping the largest-magnitude weights.

        Args:
            weight_torch (Tensor): Weights with ``length`` elements (left unmodified).
            compress_rate (float): Fraction of weights to keep; expected in (0, 1].
            length (int): Number of elements in ``weight_torch``.

        Returns:
            np.ndarray: Flat array, 1 where ``|w| >= threshold`` and 0 elsewhere.
        '''
        # np.abs produces a fresh array. The old code wrote 0/1 into
        # ``.cpu().numpy()``, which shares storage with CPU tensors and thus
        # overwrote the weights. It also used 1 as a sentinel, so a weight equal
        # to 1.0 below the threshold was wrongly kept.
        weight_abs = np.abs(weight_torch.detach().reshape(length).cpu().numpy())
        threshold = np.sort(weight_abs)[int(length * (1 - compress_rate))]
        codebook = (weight_abs >= threshold).astype(weight_abs.dtype)

        print("Codebook done.")
        return codebook

    def get_filter_codebook(self, weight_torch, compress_rate, length):
        '''
        Builds a 0/1 codebook zeroing the lowest-L2-norm filters of a 4-D weight.

        Args:
            weight_torch (Tensor): Weight tensor; only 4-D tensors are pruned.
            compress_rate (float): Fraction of filters (dim 0) to keep.
            length (int): Number of elements in ``weight_torch``.

        Returns:
            np.ndarray: Flat float64 array of ``length`` entries; all ones for
            non-4-D weights.
        '''
        codebook = np.ones(length)
        if weight_torch.dim() == 4:
            n_filters = weight_torch.size(0)
            filter_pruned_num = int(n_filters * (1 - compress_rate))
            norm2 = torch.norm(weight_torch.detach().reshape(n_filters, -1), 2, 1)
            filter_index = norm2.cpu().numpy().argsort()[:filter_pruned_num]
            # Each filter owns one contiguous block of in_ch * kh * kw entries.
            codebook = codebook.reshape(n_filters, -1)
            codebook[filter_index] = 0
            codebook = codebook.reshape(-1)

            print("Filter codebook done.")
        return codebook

    def convert2tensor(self, x):
        ''' Converts an input to a float32 PyTorch tensor. '''
        return torch.FloatTensor(x)

    def init_length(self):
        ''' Records the size and element count of every model parameter. '''
        for index, item in enumerate(self.model.parameters()):
            self.model_size[index] = item.size()
            self.model_length[index] = item.numel()

    def init_rate(self, layer_rate):
        '''
        Initializes the compression rate of each layer and the maskable indices.

        Every parameter gets rate 1 (keep all), except indices in
        ``range(layer_begin, layer_end + 1, layer_inter)`` which get ``layer_rate``.
        Maskable indices are every third index below the architecture's final
        linear weight; with ``skip_downsample == 1`` the projection-shortcut
        convolutions are excluded. Unknown architectures mask nothing.

        Args:
            layer_rate (float): Fraction of filters to keep in the selected layers.
        '''
        for index, _ in enumerate(self.model.parameters()):
            self.compress_rate[index] = 1
        for key in range(self.args.layer_begin, self.args.layer_end + 1, self.args.layer_inter):
            self.compress_rate[key] = layer_rate

        # Last index includes last fully connected layer.
        last_index_by_arch = {
            'resnet20': 57, 'resnet32': 93, 'resnet44': 129, 'resnet56': 165,
            'resnet110': 327, 'resnet1202': 3603,
            'resnet18': 60, 'resnet34': 108, 'resnet50': 159,
            'resnet101': 312, 'resnet152': 465,
        }
        # Parameter indices of the downsample (projection shortcut) convolutions.
        skip_by_arch = {
            'resnet18': [21, 36, 51],
            'resnet34': [27, 54, 93],
            'resnet50': [12, 42, 81, 138],
            'resnet101': [12, 42, 81, 291],
            'resnet152': [12, 42, 117, 444],
        }
        last_index = last_index_by_arch.get(self.args.arch, 0)
        skip_list = skip_by_arch.get(self.args.arch, [])

        self.mask_index = list(range(0, last_index, 3))

        # Skips downsample layer.
        if self.args.skip_downsample == 1:
            for x in skip_list:
                self.compress_rate[x] = 1
                self.mask_index.remove(x)

    def init_mask(self, layer_rate):
        '''
        Recomputes the filter masks from the current weights.

        Args:
            layer_rate (float): Fraction of filters to keep in the selected layers.
        '''
        self.init_rate(layer_rate)
        maskable = set(self.mask_index)
        for index, item in enumerate(self.model.parameters()):
            if index in maskable:
                codebook = self.get_filter_codebook(item.data, self.compress_rate[index],
                                                    self.model_length[index])
                # Keep each mask next to the weight it multiplies.
                self.mat[index] = self.convert2tensor(codebook).to(item.device)
        print("Mask ready.")

    def do_mask(self):
        ''' Performs pruning by multiplying each maskable weight by its mask. '''
        maskable = set(self.mask_index)
        for index, item in enumerate(self.model.parameters()):
            if index in maskable:
                flat = item.data.view(self.model_length[index])
                # The model may have been moved (e.g. by a Trainer) after
                # init_mask, so align the mask's device and dtype with the weight.
                mask = self.mat[index].to(device = flat.device, dtype = flat.dtype)
                item.data = (flat * mask).view(self.model_size[index])
        print("Mask done.")

    def if_zero(self):
        ''' Prints nonzero/zero weight counts for each layer selected by ``args``. '''
        watched = range(self.args.layer_begin, self.args.layer_end + 1, self.args.layer_inter)
        for index, item in enumerate(self.model.parameters()):
            if index in watched:
                b = item.data.view(self.model_length[index]).cpu().numpy()
                nonzero = np.count_nonzero(b)

                print("Layer: %d, number of nonzero weight is %d, zero is %d" % (
                      index, nonzero, len(b) - nonzero))


In [None]:
'''
This program implements a Lightning wrapper for ResNet and trains the model.
'''

from __future__ import print_function

import numpy as np
import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets


from pytorch_lightning import Trainer


# from resnet import ResNet
# from regularization import *
# from soft_filter_pruning import Mask

import warnings
warnings.filterwarnings('ignore')


class Net(pl.LightningModule):
    ''' Lightning wrapper for a ResNet model. '''

    # Per-channel normalization statistics used for every dataset and split.
    _NORM_MEAN = (0.4914, 0.4822, 0.4465)
    _NORM_STD = (0.2023, 0.1994, 0.2010)
    # Download/cache root passed to the torchvision datasets.
    _DATA_ROOT = '~/data'

    def __init__(self, arch, criterion, args):
        '''
        Initializes the model.

        Builds ``ResNet(arch, num_classes)`` (100 classes for cifar100, 10
        otherwise), a Mixup helper when ``args.regularize == 'mixup'`` and, when
        ``args.prune == 'soft_filter'``, a ``Mask`` that prunes the fresh network once.

        Args:
            arch (str): ResNet architecture.
            criterion (Loss): Loss function.
            args (Args): Arguments.
        '''
        super(Net, self).__init__()
        self.args = args
        if self.args.seed != 0:
            torch.manual_seed(self.args.seed)

        num_classes = 100 if self.args.dataset == 'cifar100' else 10
        self.net = ResNet(arch, num_classes)
        self.criterion = criterion
        self.mixup = Mixup() if self.args.regularize == 'mixup' else None

        if self.args.prune == 'soft_filter':
            self.mask = Mask(self.net, self.args)
            self.mask.init_length()
            self.mask.model = self.net
            self.mask.init_mask(self.args.pruning_rate)
            self.mask.do_mask()
            self.net = self.mask.model
        else:
            self.mask = None

    def forward(self, x):
        '''
        Performs a forward pass through the network.

        Args:
            x (Tensor): An input image batch.

        Returns:
            Tensor: Class logits.
        '''
        return self.net(x)

    def training_step(self, batch, batch_nb):
        '''
        Trains the model on a batch, applying mixup when configured.

        Args:
            batch (tuple): ``(inputs, targets)``.
            batch_nb (int): A batch index.

        Returns:
            dict: ``loss`` plus the same value under ``progress_bar``/``log``.
        '''
        inputs, targets = batch
        if self.args.regularize == 'mixup':
            inputs, targets_a, targets_b, lam = self.mixup.mixup_data(inputs, targets, self.args.alpha_mixup)
            outputs = self.forward(inputs)
            loss = self.mixup.mixup_criterion(self.criterion, outputs, targets_a, targets_b, lam)
        else:
            outputs = self.forward(inputs)
            loss = self.criterion(outputs, targets)

        # Current Lightning logs through self.log(); the 'progress_bar'/'log'
        # keys below are kept only for compatibility with the old dict API.
        self.log('train_loss', loss, prog_bar = True)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'progress_bar': tensorboard_logs, 'log': tensorboard_logs}

    def on_epoch_end(self):
        '''
        Re-applies soft filter pruning every ``args.epoch_prune`` epochs and at the
        last epoch (``args.epochs - 1``), printing per-layer zero counts before
        and after.

        NOTE(review): Lightning 1.x may also call ``on_epoch_end`` after validation
        epochs; re-pruning unchanged weights should select the same filters, but
        confirm this is intended.
        '''
        if self.args.prune == 'soft_filter':
            if self.current_epoch % self.args.epoch_prune == 0 or self.current_epoch == self.args.epochs - 1:
                self.mask.model = self.net
                self.mask.if_zero()
                self.mask.init_mask(self.args.pruning_rate)
                self.mask.do_mask()
                self.mask.if_zero()
                self.net = self.mask.model

    def validation_step(self, batch, batch_nb):
        '''
        Evaluates the model on a batch.

        Args:
            batch (tuple): ``(inputs, targets)``.
            batch_nb (int): A batch index.

        Returns:
            dict: ``val_loss`` and ``val_acc`` (fraction correct) for this batch.
        '''
        inputs, targets = batch
        outputs = self.forward(inputs)
        loss = self.criterion(outputs, targets)

        _, predicted = torch.max(outputs, 1)
        correct = predicted.eq(targets).sum()
        acc = correct.float() / targets.size(0)
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        '''
        Averages the per-batch validation results and logs them to Lightning and wandb.

        Lightning calls this hook after each validation epoch; the older name
        ``validation_end`` is no longer invoked, which previously meant the wandb
        metrics were never logged.

        Args:
            outputs (list[dict]): ``validation_step`` results for the whole split.

        Returns:
            dict: ``avg_val_loss`` plus the averages under ``progress_bar``/``log``
            (empty if there were no batches).
        '''
        if not outputs:
            return {}
        val_loss_mean = torch.stack([output['val_loss'] for output in outputs]).mean()
        val_acc_mean = torch.stack([output['val_acc'] for output in outputs]).mean()

        self.log('val_loss', val_loss_mean, prog_bar = True)
        self.log('val_acc', val_acc_mean, prog_bar = True)
        wandb.log({"Test Accuracy": val_acc_mean.item(), "Test Loss": val_loss_mean.item()})

        tensorboard_logs = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
        return {'avg_val_loss': val_loss_mean, 'progress_bar': tensorboard_logs, 'log': tensorboard_logs}

    # Old hook name kept so existing direct callers keep working.
    validation_end = validation_epoch_end

    def configure_optimizers(self):
        '''
        Configures SGD and a MultiStepLR schedule.

        Milestones/gamma: [80, 120]/0.1 for svhn, [60, 120, 160]/0.2 with cutout,
        [100, 150]/0.1 otherwise.

        NOTE(review): PyTorch schedulers require 'initial_lr' in each param group
        when ``last_epoch != -1``, so ``args.start_epoch > 0`` presumably fails here
        unless that key is provided — confirm before resuming runs.
        '''
        optimizer = optim.SGD(self.parameters(), lr = self.args.lr,
                              momentum = self.args.momentum,
                              weight_decay = self.args.decay)
        if self.args.dataset == 'svhn':
            milestones, gamma = [80, 120], 0.1
        elif self.args.regularize == 'cutout':
            milestones, gamma = [60, 120, 160], 0.2
        else:
            milestones, gamma = [100, 150], 0.1
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones = milestones, gamma = gamma,
                                                   last_epoch = self.args.start_epoch - 1)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        ''' Loads the shuffled training set (optional crop/flip augmentation and cutout). '''
        steps = []
        if self.args.augment:
            steps += [transforms.RandomCrop(32, padding = 4),
                      transforms.RandomHorizontalFlip()]
        steps += [transforms.ToTensor(),
                  transforms.Normalize(self._NORM_MEAN, self._NORM_STD)]
        if self.args.regularize == 'cutout':
            # Cutout operates on the normalized tensor, so it comes last.
            steps.append(Cutout(n_holes = self.args.n_holes_cutout,
                                length = self.args.length_cutout))
        transform_train = transforms.Compose(steps)

        return self.load_dataset(dataset = self.args.dataset, train = True,
                                 transform = transform_train, shuffle = True,
                                 batch_size = self.args.batch_size)

    def val_dataloader(self):
        ''' Loads the validation set (the dataset's held-out split). '''
        return self._held_out_loader()

    def test_dataloader(self):
        ''' Loads the test set (the same held-out split as validation). '''
        return self._held_out_loader()

    def _held_out_loader(self):
        ''' Builds the unaugmented, unshuffled held-out loader with batch size 100. '''
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self._NORM_MEAN, self._NORM_STD),
        ])
        return self.load_dataset(dataset = self.args.dataset, train = False,
                                 transform = transform_test, shuffle = False,
                                 batch_size = 100)

    def load_dataset(self, dataset, train, transform, shuffle, batch_size):
        '''
        Loads a dataset.

        Args:
            dataset (str): The name of the dataset: 'cifar10', 'cifar100', 'stl10' or 'svhn'.
            train (bool): Whether to load the training set or not.
            transform (Transform): Image transformations.
            shuffle (bool): Whether to shuffle the data or not.
            batch_size (int): The size of a batch.

        Returns:
            DataLoader: A DataLoader object containing the dataset.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        '''
        name = dataset
        root = self._DATA_ROOT
        if name == 'cifar10':
            data = datasets.CIFAR10(root = root, train = train,
                                    transform = transform, download = True)
        elif name == 'cifar100':
            data = datasets.CIFAR100(root = root, train = train,
                                     transform = transform, download = True)
        elif name == 'stl10':
            split = 'train' if train else 'test'
            data = datasets.STL10(root = root, split = split,
                                  transform = transform, download = True)
        elif name == 'svhn':
            if train:
                data = datasets.SVHN(root = root, split = 'train',
                                     transform = transform, download = True)
                extra_dataset = datasets.SVHN(root = root, split = 'extra',
                                              transform = transform, download = True)
                # Train on 'train' + 'extra' by concatenating the raw arrays.
                data.data = np.concatenate([data.data, extra_dataset.data], axis = 0)
                data.labels = np.concatenate([data.labels, extra_dataset.labels], axis = 0)
            else:
                data = datasets.SVHN(root = root, split = 'test',
                                     transform = transform, download = True)
        else:
            # Previously the DataLoader below was handed the name string itself.
            raise ValueError("Unsupported dataset: %r" % (name,))

        return torch.utils.data.DataLoader(dataset = data,
                                           batch_size = batch_size,
                                           shuffle = shuffle,
                                           pin_memory = True,
                                           num_workers = 2)


In [None]:
'''
This program implements the ResNet architecture with mixup and cutout regularizations
and soft filter pruning.
'''

import os, sys
sys.path.append('.')
sys.path.append('./src')
import argparse

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
import wandb

# from model import Net

import warnings
warnings.filterwarnings('ignore')


arch_options = ['resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202',
                'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
dataset_options = ['cifar10', 'cifar100', 'svhn', 'stl10', 'imagenet']
regularize_options = [None, 'mixup', 'cutout']
prune_options = [None, 'soft_filter']
layers_end = {'resnet20': 54, 'resnet32': 90, 'resnet44': 126, 'resnet56': 162, 'resnet110': 324, 'resnet1202': 3600,
              'resnet18': 57, 'resnet34': 105, 'resnet50': 156, 'resnet101': 309, 'resnet152': 462}


def parseArgs():
    ''' Reads command line arguments. '''
    parser = argparse.ArgumentParser(description = 'PyTorch ResNet Training.',
                                     formatter_class = argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--arch', type = str, default = 'resnet20',
                        help = 'ResNet architecture.', choices = arch_options)
    parser.add_argument('--dataset', type = str, default = 'cifar10',
                        help = 'Dataset.', choices = dataset_options)
    parser.add_argument('--regularize', type = str, default = "cutout",
                        help = 'Regularization.', choices = regularize_options)
    parser.add_argument('--prune', type = str, default = "soft_filter",
                        help = 'Pruning.', choices = prune_options)
    
    # Arguments for training.
    parser.add_argument('--batch-size', type = int, default = 128, help = 'Batch size.')
    parser.add_argument('--lr', type = float, default = 0.1, help = 'Learning rate.')
    parser.add_argument('--start-epoch', type = int, default = 0, help = 'Starting epoch.')
    parser.add_argument('--epochs', type = int, default = 20, help = 'Number of epochs.')
    parser.add_argument('--augment', action = 'store_true', default = False,
                        help = 'Augment data by flipping and cropping.')
    parser.add_argument('--decay', type = float, default = 1e-4, help = 'Weight decay.')
    parser.add_argument('--momentum', default = 0.9, type = float,
                        metavar = 'M', help = 'Momentum.')
    parser.add_argument('--seed', type = int, default = 0, help = 'Random seed.')
    parser.add_argument('--resume', action = 'store_true', default = False,
                        help = 'Resume from checkpoint.')
    
    # Arguments for regularization.
    parser.add_argument('--alpha-mixup', type = float, default = 1.,
                        help = 'Mixup interpolation coefficient.')
    parser.add_argument('--n-holes-cutout', type = int, default = 1,
                        help = 'Number of holes to cut out from image.')
    parser.add_argument('--length-cutout', type = int, default = 16,
                        help = 'Length of the holes in cutout.')
    
     # Arguments for pruning.
    parser.add_argument('--pruning-rate', type = float, default = 0.9,
                        help = 'Compress rate of model.')
    parser.add_argument('--epoch-prune', type = int, default = 1,
                        help = 'Frequency of pruning.')
    parser.add_argument('--skip-downsample', type = int, default = 1,
                        help = 'Compress layer of model.')
    
    args = parser.parse_known_args()[0]
    args.layer_begin = 0
    args.layer_end = layers_end[args.arch]
    args.layer_inter = 3
    return args


def main():
    '''
    Main program: trains ``Net`` with PyTorch Lightning and logs the run to Weights & Biases.

    Options come from ``parseArgs()``. ``Net`` is not imported in this cell (its
    import is commented out), so it must be defined by an earlier notebook cell.
    Side effects: seeds the RNGs, writes ``model.ckpt`` to the working directory,
    uploads it to the W&B run, and always closes the W&B run on exit.
    '''
    print("Welcome to Our CNN Program.")
    args = parseArgs()

    # --seed was parsed but never applied. Seed Python, NumPy and torch before the
    # model is constructed so weight initialisation is reproducible.
    pl.seed_everything(args.seed)

    model = Net(arch = args.arch, criterion = nn.CrossEntropyLoss(), args = args)
    wandb.init(project = "ResNet-Regularization-Pruning", tags = [args.arch], name = args.arch)
    try:
        wandb.watch(model)

        gpus = 1 if torch.cuda.is_available() else 0
        # NOTE(review): `gpus=` is deprecated in newer Lightning releases (removed in 2.0)
        # in favour of `accelerator=` / `devices=`; update if the Lightning version changes.
        trainer = Trainer(gpus = gpus, min_epochs = 1, max_epochs = args.epochs,
                          check_val_every_n_epoch = 1)

        trainer.fit(model)

        # Nothing in this cell writes 'model.h5', so uploading that name saved no model.
        # Write a Lightning checkpoint first, then upload that file.
        checkpoint_path = 'model.ckpt'
        trainer.save_checkpoint(checkpoint_path)
        wandb.save(checkpoint_path)

        print('View tensorboard logs by running\ntensorboard --logdir %s' % os.getcwd())
        print('and going to http://localhost:6006 on your browser')

        # TODO: the author flagged an error around this point ("Idhar error h",
        # i.e. "the error is here"); its cause is not identified in this cell.
        # Evaluation on the test split is disabled until that is resolved:
        # trainer.test(model)
    finally:
        # Close the W&B run even if training raises, so the run is not left open.
        wandb.finish()


if __name__ == '__main__':
    main()


Welcome to Our CNN Program.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Filter codebook done.
Mask ready.
Mask done.





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /content/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | net       | ResNetSmall      | 269 K 
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
269 K     Trainable params
0         Non-trainable params
269 K     Total params
1.079     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Files already downloaded and verified
Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2160, zero is 144
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2160, zero is 144
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2160, zero is 144
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8352, zero is 864
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 8352, zero is 864
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 8352, zero is 864
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 33408, zero is 3456
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 33408, zero

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

Validation: 0it [00:00, ?it/s]

Layer: 0, number of nonzero weight is 405, zero is 27
Layer: 3, number of nonzero weight is 2160, zero is 144
Layer: 6, number of nonzero weight is 2295, zero is 9
Layer: 9, number of nonzero weight is 2160, zero is 144
Layer: 12, number of nonzero weight is 2295, zero is 9
Layer: 15, number of nonzero weight is 2160, zero is 144
Layer: 18, number of nonzero weight is 2295, zero is 9
Layer: 21, number of nonzero weight is 4176, zero is 432
Layer: 24, number of nonzero weight is 8613, zero is 603
Layer: 27, number of nonzero weight is 8352, zero is 864
Layer: 30, number of nonzero weight is 9135, zero is 81
Layer: 33, number of nonzero weight is 8352, zero is 864
Layer: 36, number of nonzero weight is 9135, zero is 81
Layer: 39, number of nonzero weight is 16704, zero is 1728
Layer: 42, number of nonzero weight is 34452, zero is 2412
Layer: 45, number of nonzero weight is 33408, zero is 3456
Layer: 48, number of nonzero weight is 36540, zero is 324
Layer: 51, number of nonzero weight is

In [None]:
'''
This program calculates the number of floating point operations (FLOPs)
for various network architectures and pruning rates.


'''

def cifar_resnet_flop(layer = 110, prune_rate = 1):
    '''
    Counts the convolution FLOPs of a CIFAR ResNet whose filters are pruned.

    Only 3x3 convolutions are counted (the ``9`` factor). The final
    fully-connected layer, batch norm and shortcut additions are ignored.
    The network is assumed to follow the CIFAR layout: depth ``6n + 2``,
    three stages of 16/32/64 channels on 32x32/16x16/8x8 feature maps.

    Args:
        layer (int): Depth of the CIFAR ResNet (e.g. 20, 32, 44, 56, 110).
        prune_rate (float): Fraction of filters kept in each pruned layer;
            1 means the unpruned baseline.

    Returns:
        int or float: The FLOP count (an int when ``prune_rate`` is an int).
    '''
    channels = [16, 32, 64]
    widths = [32, 16, 8]
    layers_per_stage = int(layer / 3)

    def conv_flop(stage, in_channels, keep):
        # out channels * out height * out width * (3x3 kernel) * in channels, scaled by `keep`.
        return channels[stage] * widths[stage] * widths[stage] * 9 * in_channels * keep

    total = 0
    for idx in range(layer):
        if idx == 0:
            # Stem conv on the 3-channel input: only its output filters are pruned.
            total += conv_flop(0, 3, prune_rate)
            continue
        if idx in (1, 2):
            # First block of the first stage: input and output channels both pruned.
            total += conv_flop(0, channels[0], prune_rate ** 2)
            continue

        if idx <= layers_per_stage:
            stage = 0
        elif idx <= 2 * layers_per_stage:
            stage = 1
        elif idx <= 3 * layers_per_stage:
            stage = 2
        else:
            # Past the last stage (the FC layer for depth 6n + 2): not counted.
            continue

        # Odd index: first conv of a block, only output channels reduced.
        # Even index: second conv of a block, input and output channels reduced.
        keep = prune_rate if idx % 2 else prune_rate ** 2
        total += conv_flop(stage, channels[stage], keep)

    # The loop charged the first conv of stages 2 and 3 as if its input already had
    # that stage's channel count; it actually receives the previous stage's channels.
    offset1 = conv_flop(1, channels[1], prune_rate) - conv_flop(1, channels[0], prune_rate)
    offset2 = conv_flop(2, channels[2], prune_rate) - conv_flop(2, channels[1], prune_rate)
    return total - offset1 - offset2


def cal_cifar_resnet_flop(layer, prune_rate):
    '''
    Prints a one-line FLOP comparison between a pruned CIFAR ResNet and its baseline.

    Args:
        layer (int): Depth of the CIFAR ResNet (e.g. 20, 32, 56, 110).
        prune_rate (float): Fraction of filters kept; 1 means the unpruned baseline.

    The reported reduction rate is ``1 - pruned / baseline``.
    '''
    baseline = cifar_resnet_flop(layer, 1)
    pruned = cifar_resnet_flop(layer, prune_rate)
    reduction = 1 - pruned / baseline

    template = ("Pruning rate of layer {:d} is {:.1f}, Pruned FLOP is {:.0f}, "
                "Baseline FLOP is {:.0f}, FLOP reduction rate is {:.4f}")
    print(template.format(layer, prune_rate, pruned, baseline, reduction))


def main(layer_list = None, pruning_rate_list = None):
    '''
    Main program: prints pruned vs. baseline FLOPs for several depths and pruning rates.

    Args:
        layer_list (list of int, optional): ResNet depths to evaluate. Defaults
            to the depths offered by the training script.
        pruning_rate_list (list of float, optional): Fractions of filters kept.
            Defaults to [0.9, 0.8, 0.7].

    The FLOP model only describes CIFAR ResNets (depth 6n + 2, three stages of
    16/32/64 channels). Other depths, including the ImageNet-style
    resnet18/34/50/101/152, are reported as skipped instead of printing
    meaningless numbers.
    '''
    if layer_list is None:
        layer_list = [20, 32, 44, 56, 110, 1202,
                      18, 34, 50, 101, 152]
    if pruning_rate_list is None:
        pruning_rate_list = [0.9, 0.8, 0.7]

    # 50 and 152 happen to be of the form 6n + 2, but those architectures use a
    # different block/channel layout, so they are excluded explicitly.
    imagenet_depths = (18, 34, 50, 101, 152)

    for layer in layer_list:
        if layer in imagenet_depths or layer < 8 or (layer - 2) % 6 != 0:
            print("Skipping layer {:d}: FLOP model only covers CIFAR ResNets "
                  "(depth 6n+2, 16/32/64 channels)".format(layer))
            continue
        for pruning_rate in pruning_rate_list:
            cal_cifar_resnet_flop(layer, pruning_rate)


if __name__ == '__main__':
    main()


Pruning rate of layer 20 is 0.9, Pruned FLOP is 34371994, Baseline FLOP is 40550400, FLOP reduction rate is 0.1524
Pruning rate of layer 20 is 0.8, Pruned FLOP is 28665446, Baseline FLOP is 40550400, FLOP reduction rate is 0.2931
Pruning rate of layer 20 is 0.7, Pruned FLOP is 23430758, Baseline FLOP is 40550400, FLOP reduction rate is 0.4222
Pruning rate of layer 32 is 0.9, Pruned FLOP is 58578371, Baseline FLOP is 68861952, FLOP reduction rate is 0.1493
Pruning rate of layer 32 is 0.8, Pruned FLOP is 49049764, Baseline FLOP is 68861952, FLOP reduction rate is 0.2877
Pruning rate of layer 32 is 0.7, Pruned FLOP is 40276132, Baseline FLOP is 68861952, FLOP reduction rate is 0.4151
Pruning rate of layer 44 is 0.9, Pruned FLOP is 82784748, Baseline FLOP is 97173504, FLOP reduction rate is 0.1481
Pruning rate of layer 44 is 0.8, Pruned FLOP is 69434081, Baseline FLOP is 97173504, FLOP reduction rate is 0.2855
Pruning rate of layer 44 is 0.7, Pruned FLOP is 57121505, Baseline FLOP is 97173