<h1>Network Slimming</h1>
This notebook demonstrates an implementation of the paper <a href=https://arxiv.org/abs/1708.06519>Learning Efficient Convolutional Networks through Network Slimming</a>.
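The paper adds an L1 penalty on the scaling factors γ of the BatchNorm layers to the usual training loss. Here Γ is the set of all scaling factors in the network and λ balances the two terms. After training, channels whose |γ| is small are pruned and the slimmed network is fine-tuned.

```latex
L = \sum_{(x, y)} l\big(f(x, W), y\big) + \lambda \sum_{\gamma \in \Gamma} g(\gamma),
\qquad g(\gamma) = |\gamma|
```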
<h4>The steps to train a baseline model and then compress it to a given channel budget are as follows:</h4>
<ul>
    <li>Load the YAML file.</li>
    <li>Load the dataset and create the dataloaders.</li>
    <li>Create the model.</li>
    <li>Create a <b>Network_Slimming</b> object, passing it the model, the dataloaders and the parameters (as a dictionary).</li>
    <li>Call the <b>compress_model</b> method to obtain the compressed model.</li>
</ul>
Since this is a demo notebook, the number of epochs has been set to 2 for both training and fine-tuning.
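During the sparsity-training phase, the L1 term on γ contributes the subgradient λ·sign(γ) to the gradient of every BatchNorm weight. The sketch below shows that step in plain PyTorch. It assumes that the `s` entry of the YAML file loaded below is the coefficient λ; the helper `add_bn_l1_subgradient` is hypothetical and trailmet's internal implementation may differ.

```python
import torch
import torch.nn as nn


def add_bn_l1_subgradient(model: nn.Module, s: float) -> None:
    """Add the subgradient of s * sum(|gamma|) to every BatchNorm2d weight.

    Call after loss.backward() and before optimizer.step().
    """
    for module in model.modules():
        if (isinstance(module, nn.BatchNorm2d)
                and module.weight is not None
                and module.weight.grad is not None):
            # d/dgamma (s * |gamma|) = s * sign(gamma)
            module.weight.grad.add_(s * torch.sign(module.weight.detach()))


if __name__ == "__main__":
    torch.manual_seed(0)
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
    optimizer = torch.optim.SGD(net.parameters(), lr=0.002)
    loss = net(torch.randn(4, 3, 32, 32)).mean()
    optimizer.zero_grad()
    loss.backward()
    add_bn_l1_subgradient(net, s=0.003)  # sparsity term (assumed to be `s`)
    optimizer.step()
```

Applying the penalty through the gradient, rather than adding it to the loss, avoids differentiating |γ| at zero through autograd.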

In [1]:
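import sys
import os
sys.path.append("../../../")  # add the repository root so that `trailmet` can be imported
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # expose only GPU 1; must be set before CUDA is initialised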

import torch
from torchvision import transforms

import yaml

from trailmet.datasets.classification import DatasetFactory
from trailmet.models import ModelsFactory
from trailmet.algorithms.prune.network_slimming import Network_Slimming


<h3> Loading the YAML file. </h3>

In [8]:
with open("./resnet50_cifar100.yaml", 'r') as stream:
    data_loaded = yaml.safe_load(stream)
data_loaded['schema_root'] = "./"
data_loaded

{'num_classes': 100,
 'weight_decay': 0.0005,
 'net': 'resnet50',
 'dataset': 'c100',
 'epochs': 2,
 's': 0.003,
 'learning_rate': 0.002,
 'fine_tune_epochs': 2,
 'fine_tune_lr': 0.0004,
 'prune_ratio': 0.5,
 'wandb': True,
 'insize': 32,
 'schema_root': './'}
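`prune_ratio: 0.5` is presumably the fraction of BatchNorm channels to remove. The paper ranks the |γ| of all BatchNorm layers together and applies a single global threshold, which is why the per-layer prune percentages printed later differ from layer to layer. A pure-Python sketch of that threshold, assuming a channel is kept when its |γ| is at or above it; `global_threshold` is a hypothetical helper, not a trailmet function.

```python
from typing import Sequence


def global_threshold(scaling_factors: Sequence[float], prune_ratio: float) -> float:
    """Return the |gamma| cut-off shared by all BatchNorm layers.

    All scaling factors are ranked together by magnitude; with distinct values,
    exactly int(n * prune_ratio) of them fall strictly below the returned value.
    """
    if not 0.0 <= prune_ratio < 1.0:
        raise ValueError("prune_ratio must be in [0, 1)")
    ranked = sorted(abs(g) for g in scaling_factors)
    return ranked[int(len(ranked) * prune_ratio)]


# Hypothetical scaling factors pooled from several layers.
gammas = [0.9, -0.05, 0.4, 0.01, 0.7, 0.2, -0.3, 0.02]
thr = global_threshold(gammas, prune_ratio=0.5)
kept = [g for g in gammas if abs(g) >= thr]  # assumed keep rule: |gamma| >= threshold
print(thr, len(kept))  # 0.3 4
```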

<h3>Loading the CIFAR-100 dataset</h3>

In [3]:
transform_train = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Pad(4, padding_mode='reflect'),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomCrop(32),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                ]
            )

transform_test = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                ]
            )

transforms1 = {
    'train': transform_train, 
    'val': transform_test, 
    'test': transform_test}

target_transforms = {
    'train': None, 
    'val': None, 
    'test': None}

cifar_dataset = DatasetFactory.create_dataset(name = 'CIFAR100', 
                                        root = "./data",
                                        split_types = ['train', 'val', 'test'],
                                        val_fraction = 0.1,
                                        transform = transforms1,
                                        target_transform = target_transforms,
                                        random_seed=42
                                        )

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
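`transforms.Normalize` expects the per-channel mean and standard deviation of the training images. The values used above, (0.4914, 0.4822, 0.4465) and (0.247, 0.243, 0.261), appear to be the statistics usually quoted for CIFAR-10 rather than CIFAR-100. They can be recomputed from any loader that yields tensors after `ToTensor()` but before `Normalize()`; the sketch below checks the helper on random data, and `channel_mean_std` is a hypothetical name.

```python
from typing import Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(loader: DataLoader) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-channel mean and (population) std over every pixel in `loader`.

    Batches must start with a float tensor of shape (N, C, H, W).
    """
    total = 0
    channel_sum = channel_sq_sum = None
    for images, *_ in loader:
        images = images.double()  # accumulate in float64 to limit rounding error
        if channel_sum is None:
            channel_sum = torch.zeros(images.shape[1], dtype=torch.float64)
            channel_sq_sum = torch.zeros(images.shape[1], dtype=torch.float64)
        channel_sum += images.sum(dim=(0, 2, 3))
        channel_sq_sum += (images ** 2).sum(dim=(0, 2, 3))
        total += images.shape[0] * images.shape[2] * images.shape[3]
    mean = channel_sum / total
    std = (channel_sq_sum / total - mean ** 2).sqrt()  # sqrt(E[x^2] - E[x]^2)
    return mean, std


torch.manual_seed(0)
images = torch.rand(100, 3, 32, 32)
labels = torch.zeros(100, dtype=torch.long)
mean, std = channel_mean_std(DataLoader(TensorDataset(images, labels), batch_size=64))
print(mean, std)
```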


<h5>Creating the dataloaders</h5>

In [4]:
dataloaders = {
    'train': torch.utils.data.DataLoader(
        cifar_dataset['train'], batch_size=64,
        sampler=cifar_dataset['train_sampler'],
        num_workers=0,
    ),
    'val': torch.utils.data.DataLoader(
        cifar_dataset['val'], batch_size=64,
        sampler=cifar_dataset['val_sampler'],
        num_workers=0,
    ),
    'test': torch.utils.data.DataLoader(
        cifar_dataset['test'], batch_size=64,
        sampler=cifar_dataset['test_sampler'],
        num_workers=0,
    ),
}
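The training log further down reports 704 training steps and 79 validation steps per epoch. That is consistent with `val_fraction = 0.1` splitting the 50,000 CIFAR-100 training images into 45,000 training and 5,000 validation images at batch size 64. The sketch below reproduces that arithmetic with a seeded shuffle; the resulting index lists could be wrapped in `torch.utils.data.SubsetRandomSampler`, but how `DatasetFactory` actually builds its samplers is not shown here, and `split_indices` is hypothetical.

```python
import math
import random
from typing import List, Tuple


def split_indices(n: int, val_fraction: float, seed: int) -> Tuple[List[int], List[int]]:
    """Shuffle range(n) with a fixed seed and carve off a validation slice."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_val = int(n * val_fraction)
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(50_000, val_fraction=0.1, seed=42)
batch_size = 64
print(len(train_idx), len(val_idx))            # 45000 5000
print(math.ceil(len(train_idx) / batch_size))  # 704 (last batch is partial)
print(math.ceil(len(val_idx) / batch_size))    # 79
```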

<h3>Loading the model</h3>

In [9]:
model = ModelsFactory.create_model(name='resnet50', pretrained=False, **data_loaded)
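The pruning table printed by `compress_model` further down lists, for each BatchNorm layer, the kept/total channels and the share removed; for `model.bn1`, 34/64 kept gives (64 − 34)/64 = 46.875%, shown as 46.88%. The layer names suggest the ResNet-50 sits under a `model` attribute and can be walked with `named_modules()`. The sketch builds the same kind of summary for a toy network, assuming a channel survives when |γ| ≥ threshold; `bn_prune_summary` is a hypothetical helper.

```python
from typing import Dict, Tuple

import torch
import torch.nn as nn


def bn_prune_summary(model: nn.Module, threshold: float) -> Dict[str, Tuple[int, int, str]]:
    """For every BatchNorm2d layer, return (kept, total, prune percent)."""
    summary = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d) and module.weight is not None:
            gamma = module.weight.detach().abs()
            kept = int((gamma >= threshold).sum())  # assumed keep rule
            total = gamma.numel()
            summary[name] = (kept, total, f"{100.0 * (total - kept) / total:.2f}%")
    return summary


# Toy network with hand-set scaling factors, standing in for the ResNet-50.
toy = nn.Sequential(
    nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4),
    nn.Conv2d(4, 4, 3), nn.BatchNorm2d(4),
)
with torch.no_grad():
    toy[1].weight.copy_(torch.tensor([0.9, 0.01, 0.5, 0.02]))
    toy[3].weight.copy_(torch.tensor([0.03, -0.6, 0.04, 0.05]))

for name, (kept, total, pct) in bn_prune_summary(toy, threshold=0.1).items():
    print(name, f"{kept}/{total}", pct)
# 1 2/4 50.00%
# 3 1/4 75.00%
```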

<h3>Creating the Network_Slimming object and running compression</h3>

In [10]:
slim = Network_Slimming(model, dataloaders, **data_loaded)

[34m[1mwandb[0m: Currently logged in as: [33manimesh-007[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
slim.compress_model()

Training Epoch [0] (704 / 704 Steps) (batch time=0.26610s) (data time=0.19870s) (loss=4.57249): 100%|| 704/704 [01:41<00:00,  6.92it/s]
Validating Epoch [0] (79 / 79 Steps) (batch time=0.01835s) (loss=4.47627) (top1=0.00000) (top5=12.50000): 100%|| 79/79 [00:01<00:00, 40.26it/s]


 * acc@1 1.100 acc@5 5.900


Training Epoch [1] (704 / 704 Steps) (batch time=0.09909s) (data time=0.05468s) (loss=4.66276): 100%|| 704/704 [02:02<00:00,  5.77it/s]
Validating Epoch [1] (79 / 79 Steps) (batch time=0.01553s) (loss=4.55720) (top1=0.00000) (top5=25.00000): 100%|| 79/79 [00:01<00:00, 39.72it/s]


 * acc@1 1.160 acc@5 5.760
model.layer1.0.bn3: 124
model.layer1.0.downsample.1: 125
model.layer1.1.bn3: 128
model.layer1.2.bn3: 132
merged indexes length: 244
model.layer2.0.bn3: 249
model.layer2.0.downsample.1: 254
model.layer2.1.bn3: 262
model.layer2.2.bn3: 262
model.layer2.3.bn3: 267
merged indexes length: 500
model.layer3.0.bn3: 491
model.layer3.0.downsample.1: 535
model.layer3.1.bn3: 537
model.layer3.2.bn3: 522
model.layer3.3.bn3: 564
model.layer3.4.bn3: 564
model.layer3.5.bn3: 554
merged indexes length: 1019
model.layer4.0.bn3: 975
model.layer4.0.downsample.1: 981
model.layer4.1.bn3: 937
model.layer4.2.bn3: 950
merged indexes length: 1720

BatchNorm2d prune info
|    | name                        | channels   | prune percent   |
|---:|:----------------------------|:-----------|:----------------|
|  0 | model.bn1                   | 34/64      | 46.88%          |
|  1 | model.layer1.0.bn1          | 36/64      | 43.75%          |
|  2 | model.layer1.0.bn2          | 29/64      | 5

Training Epoch [0] (704 / 704 Steps) (batch time=0.32138s) (data time=0.17613s) (loss=4.79924): 100%|| 704/704 [01:47<00:00,  6.57it/s]
Validating Epoch [0] (79 / 79 Steps) (batch time=0.01556s) (loss=4.59786) (top1=0.00000) (top5=0.00000): 100%|| 79/79 [00:01<00:00, 42.06it/s] 


 * acc@1 1.760 acc@5 8.060


Training Epoch [1] (704 / 704 Steps) (batch time=0.14155s) (data time=0.05515s) (loss=4.93959): 100%|| 704/704 [02:00<00:00,  5.84it/s]
Validating Epoch [1] (79 / 79 Steps) (batch time=0.01560s) (loss=4.63713) (top1=0.00000) (top5=0.00000): 100%|| 79/79 [00:02<00:00, 38.10it/s] 


 * acc@1 2.080 acc@5 9.020
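The notebook does not show what `compress_model` returns. If it returns the slimmed network, as the step list at the top suggests, a quick way to compare it with the baseline is to count parameters. `count_parameters` is a hypothetical helper; the toy example shows how removing output channels from a 3×3 convolution shrinks it.

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> int:
    """Total number of elements across all parameters of `model`."""
    return sum(p.numel() for p in model.parameters())


full = nn.Conv2d(3, 64, kernel_size=3)       # 3*64*9 weights + 64 biases
slim_conv = nn.Conv2d(3, 34, kernel_size=3)  # 3*34*9 weights + 34 biases
print(count_parameters(full), count_parameters(slim_conv))  # 1792 952
```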
