
## **ChipNet**

This notebook demonstrates an implementation of the paper **ChipNet: Budget-Aware Pruning with Heaviside Continuous Approximations**.
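The paper's title refers to two smooth functions that turn a learnable parameter ψ for each channel into a nearly binary keep/prune gate. Below is a minimal sketch, assuming the forms usually quoted for ChipNet: a logistic curve `xi = 1 / (1 + exp(-beta * psi))`, a continuous Heaviside approximation `z = 1 - exp(-gamma * xi) + xi * exp(-gamma)`, and a crispness loss that sums `(xi - z) ** 2` over the gates. The function names are illustrative and are not trailmet's API, and a real implementation would operate on tensors rather than Python floats.

```python
import math

def logistic(psi, beta):
    """Soft gate in (0, 1); a larger beta makes the curve steeper."""
    return 1.0 / (1.0 + math.exp(-beta * psi))

def continuous_heaviside(xi, gamma):
    """Smooth step on [0, 1]: 0 at xi=0, 1 at xi=1, sharper as gamma grows."""
    return 1.0 - math.exp(-gamma * xi) + xi * math.exp(-gamma)

def crispness_loss(psis, beta, gamma):
    """Sum of squared gaps between logistic and Heaviside outputs over all gates."""
    total = 0.0
    for psi in psis:
        xi = logistic(psi, beta)
        z = continuous_heaviside(xi, gamma)
        total += (xi - z) ** 2
    return total

# Gate values for three hypothetical latent parameters at two sharpness levels.
psis = [-3.0, 0.0, 3.0]
gates_soft = [continuous_heaviside(logistic(p, beta=1.0), gamma=2.0) for p in psis]
gates_sharp = [continuous_heaviside(logistic(p, beta=1.0), gamma=32.0) for p in psis]
print([round(g, 3) for g in gates_soft])
print([round(g, 3) for g in gates_sharp])
```

Because the Heaviside approximation is concave with fixed endpoints at 0 and 1, the gap `z - xi` is positive for any logistic output strictly between 0 and 1 and shrinks as that output approaches either end. Minimising the crispness term alongside the task loss therefore pushes the gates toward binary decisions.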

The steps to train a baseline model and then compress it to a given budget are as follows:
*   Load the YAML configuration file.
*   Load the model.
*   Load the dataset and create the dataloaders.
*   Create a ChipNet object, passing the model, a dictionary of dataloaders, and the configuration parameters.
*   Call the compress_model method, which runs pretraining, pruning and finetuning in sequence.

Since this is a demo notebook, the number of epochs has been set to 1, 1 and 2 for pretraining, pruning and finetuning, respectively.



In [4]:
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from trailmet.models import ModelsFactory
from trailmet.datasets.classification import DatasetFactory
import yaml

In [5]:
root = "/content/trailmet/experiments/chipnet"

### Loading the YAML configuration file.

In [7]:
import os
with open(os.path.join(root, "resnet50_cifar100.yaml"), 'r') as stream:
    data_loaded = yaml.safe_load(stream)
print(data_loaded)

{'CHIPNET_ARGS': {'A': 10}, 'PRETRAIN': {'EPOCHS': 1}, 'PRUNE': {'EPOCHS': 1}, 'FINETUNE': {'EPOCHS': 2}}
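The configuration is a nested dictionary. When it is later passed as `**data_loaded`, each top-level key (`CHIPNET_ARGS`, `PRETRAIN`, `PRUNE`, `FINETUNE`) becomes its own keyword argument of `ChipNet`. The sketch below reproduces that unpacking with a hypothetical stand-in function (`fake_chipnet` is not part of trailmet) and shows how the demo epoch counts could be raised for a longer run on a deep copy of the config; the values 200, 20 and 50 are arbitrary placeholders, not recommended settings.

```python
import copy

# Configuration as printed by the cell above.
data_loaded = {
    'CHIPNET_ARGS': {'A': 10},
    'PRETRAIN': {'EPOCHS': 1},
    'PRUNE': {'EPOCHS': 1},
    'FINETUNE': {'EPOCHS': 2},
}

def fake_chipnet(model, dataloaders, **kwargs):
    """Hypothetical stand-in: returns the keyword arguments it received."""
    return kwargs

# Each top-level YAML key arrives as a separate keyword argument.
received = fake_chipnet(None, {}, **data_loaded)
print(sorted(received))

# Raise the epoch counts on a deep copy so the original config stays untouched.
full_run = copy.deepcopy(data_loaded)
for stage, epochs in {'PRETRAIN': 200, 'PRUNE': 20, 'FINETUNE': 50}.items():
    full_run[stage]['EPOCHS'] = epochs

# Total epochs in the demo run: 1 + 1 + 2.
total_demo_epochs = sum(data_loaded[s]['EPOCHS'] for s in ('PRETRAIN', 'PRUNE', 'FINETUNE'))
print(total_demo_epochs)
```

A deep copy matters here because the stage entries are themselves dictionaries; a shallow copy would share them and silently modify the original configuration.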


### Loading the model.

In [8]:
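# ResNet-50 with 100 output classes for CIFAR-100, no pretrained weights (False),
# and insize=32 to match the 32x32 CIFAR images.
model = ModelsFactory.create_model('resnet50', 100, False, insize=32)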

### Loading the dataset.

In [9]:
from trailmet.datasets.classification import DatasetFactory
data_root = "/content/data_dir"

In [10]:
# Create the data directory only if it does not already exist.
os.makedirs(data_root, exist_ok=True)


In [11]:
# Only conversion to tensors is applied; no augmentation or normalization.
train_transform = transforms.Compose([transforms.ToTensor()])
val_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])

transforms1 = {
    'train': train_transform,
    'val': val_transform,
    'test': test_transform}

# Identity label transforms. They are not used below: passing None
# also leaves the labels unchanged.
def train_target_transform(label):
    return label

def val_target_transform(label):
    return label

def test_target_transform(label):
    return label

target_transforms = {
    'train': None,
    'val': None,
    'test': None}


cifar_dataset = DatasetFactory.create_dataset(
    name='CIFAR100',
    root=data_root,
    split_types=['train', 'val', 'test'],
    val_fraction=0.2,  # hold out 20% of the training images for validation
    transform=transforms1,
    target_transform=target_transforms,
)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
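With `val_fraction=0.2`, a fifth of the 50,000 CIFAR-100 training images serves as the validation set; the batch counts in the training log (625 and 157 per epoch at batch size 64) are consistent with a 40,000 / 10,000 split. The sketch below shows one common way to produce such a split, shuffling the indices and slicing off the validation share. The helper `split_indices` and its fixed seed are hypothetical, and trailmet's own splitting code may differ. Index lists like these are what samplers such as `torch.utils.data.SubsetRandomSampler` consume, which fits the `train_sampler` and `val_sampler` entries used when building the dataloaders.

```python
import random

def split_indices(num_samples, val_fraction, seed=0):
    """Shuffle indices and hold out the first val_fraction of them for validation."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded for a reproducible split
    num_val = int(val_fraction * num_samples)
    return indices[num_val:], indices[:num_val]

train_idx, val_idx = split_indices(50_000, val_fraction=0.2)
print(len(train_idx), len(val_idx))  # 40000 10000
```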


### Creating the dataloaders.

In [12]:
train_loader = torch.utils.data.DataLoader(
        cifar_dataset['train'], batch_size=64, 
        sampler=cifar_dataset['train_sampler'],
        num_workers=0
    )
val_loader = torch.utils.data.DataLoader(
        cifar_dataset['val'], batch_size=64, 
        sampler=cifar_dataset['val_sampler'],
        num_workers=0
    )
test_loader = torch.utils.data.DataLoader(
        cifar_dataset['test'], batch_size=64, 
        sampler=cifar_dataset['test_sampler'],
        num_workers=0
    )
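The progress bars in the training log show 625 iterations per training epoch and 157 per validation or test pass. A `DataLoader` that keeps the last incomplete batch (the default, `drop_last=False`) makes `ceil(num_samples / batch_size)` iterations per epoch, so these counts follow from 40,000 training and 10,000 validation samples (assuming the 0.2 validation fraction is taken from the 50,000-image training set) plus the 10,000-image CIFAR-100 test set at `batch_size=64`. The helper `num_batches` is hypothetical.

```python
import math

def num_batches(num_samples, batch_size, drop_last=False):
    """Iterations a DataLoader makes per epoch over num_samples items."""
    if drop_last:
        return num_samples // batch_size  # the incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)  # the final batch may be smaller

print(num_batches(40_000, 64))  # 625
print(num_batches(10_000, 64))  # 157
```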

### Creating the ChipNet object and compressing the model.

In [13]:
from trailmet.algorithms.prune.chipnet import ChipNet

In [14]:
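# The dataloaders are passed as a dict keyed by split; each top-level key of the
# YAML config (CHIPNET_ARGS, PRETRAIN, PRUNE, FINETUNE) becomes a keyword argument.
a = ChipNet(model, {'train': train_loader, 'val': val_loader, 'test': test_loader}, **data_loaded)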

In [15]:
# Runs pretraining, pruning and finetuning in sequence, logging progress below.
a.compress_model()

100%|██████████| 625/625 [00:56<00:00, 11.13it/s, loss=4.49]
100%|██████████| 157/157 [00:04<00:00, 31.72it/s, acc1=8.03, acc5=28, loss=4.3]


**Saving model**


100%|██████████| 157/157 [00:05<00:00, 29.40it/s, acc1=8.35, acc5=28.6, loss=4.33]


Test Accuracy: 8.349920382165605 | Valid Accuracy: 8.031449044585987
preparing model for pruning
Starting epoch 1 / 1


100%|██████████| 625/625 [01:56<00:00,  5.36it/s, loss=7.18]


[1 / 1] Validation before pruning


100%|██████████| 157/157 [00:12<00:00, 12.19it/s, acc1=19.5, acc5=48.7, loss=4.42]


[1 / 1] Validation after pruning


100%|██████████| 157/157 [00:09<00:00, 15.98it/s, acc1=1.01, acc5=5.21, loss=4.89e+5]


Changed beta to 1.02 changed gamma to 2.8284271247461903
**Saving checkpoint**


100%|██████████| 625/625 [01:00<00:00, 10.29it/s, loss=3.03]
100%|██████████| 157/157 [00:07<00:00, 22.39it/s, acc1=20.8, acc5=50.3, loss=3.44]


**Saving model**


100%|██████████| 625/625 [01:00<00:00, 10.35it/s, loss=2.77]
100%|██████████| 157/157 [00:07<00:00, 22.15it/s, acc1=26.9, acc5=57.9, loss=2.9]


**Saving model**


100%|██████████| 157/157 [00:06<00:00, 23.59it/s, acc1=27.7, acc5=58.8, loss=2.87]

Test Accuracy: 27.726910828025478 | Valid Accuracy: 26.930732484076433
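After the single pruning epoch, the log above reports `Changed beta to 1.02 changed gamma to 2.8284271247461903`. These values are consistent with β starting at 1.0 and being multiplied by 1.02, and γ starting at 2.0 and being multiplied by √2 (2√2 ≈ 2.8284271247461903). This schedule is an inference from the log, not taken from trailmet's source; the initial values and growth rates in the sketch below are assumptions.

```python
import math

def steepness_schedule(epochs, beta0=1.0, gamma0=2.0,
                       beta_rate=1.02, gamma_rate=math.sqrt(2)):
    """Return (beta, gamma) after each epoch under an assumed geometric schedule."""
    beta, gamma = beta0, gamma0
    history = []
    for _ in range(epochs):
        beta *= beta_rate    # logistic steepness grows slowly
        gamma *= gamma_rate  # Heaviside sharpness grows faster
        history.append((beta, gamma))
    return history

for epoch, (beta, gamma) in enumerate(steepness_schedule(3), start=1):
    print(f"epoch {epoch}: beta={beta:.4f}, gamma={gamma:.4f}")
```

Under such a schedule, a longer PRUNE stage would harden the gates progressively rather than all at once, which helps keep the optimisation stable early in pruning.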



