In [1]:
import copy, time, os, random

import numpy as np
import matplotlib.pyplot as plt
import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision import models

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

In [2]:
SEED = 24
BATCH_SIZE = 16
EPOCHS = 12
INIT_LR = 1e-3
MIN_LR = 1e-6
# MODEL_NAME = 'efficientnet_b0'
MODEL_NAME = 'alexnet'
TEACHER_NAME = 'efficientnet_b0'


PATH = f'checkpoints/best_cifar100_{MODEL_NAME}.ptn'
TPATH = f'checkpoints/best_cifar100_{TEACHER_NAME}.ptn'
DPATH = f'checkpoints/best_cifar100_distill_{MODEL_NAME}.ptn'


In [3]:
def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(SEED)

In [4]:
model = models.alexnet(pretrained=True)
model.classifier[4] = torch.nn.Linear(4096,1024)
model.classifier[6] = torch.nn.Linear(1024,100)
model.load_state_dict(torch.load(PATH))
model.to('cuda')


AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [4]:
# config = resolve_data_config({}, model=model)
# transform = create_transform(**config)

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [5]:
trainset = datasets.CIFAR100(
    root='data',
    train=True,
    transform=transform
)

testset = datasets.CIFAR100(
    root='data',
    train=False,
    transform=transform
)

loaders = {
    'train' : DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=False),
    'test' : DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
}

dataset_sizes={
    'train' : len(trainset),
    'test' : len(testset), 
}

In [161]:
criterion = torch.nn.CrossEntropyLoss()
criterion.to('cuda')

# optimizer = torch.optim.SGD(model.parameters(), lr=INIT_LR, momentum=0.9)
optimizer = torch.optim.SGD(distill_model.parameters(), lr=0.3, momentum=0.9)


# optimizer = torch.optim.Adam(model.parameters(), lr = INIT_LR)

# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS, eta_min=MIN_LR)
# sheduler = torch.optim.lr_scheduler.StepLR(optimizer, 4, 0.1)

In [6]:
def fine_tune(model, loaders, criterion, optimizer, scheduler=None):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(EPOCHS):
        print(f'Epoch {epoch}/{EPOCHS - 1}')
        print('-' * 20)

        for phase in ['train', 'test']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loaders[phase]:
                inputs, labels = inputs.to('cuda'), labels.to('cuda')

                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1) # indices of max probs 
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            
            if phase == 'train' and scheduler:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print(f'{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'test' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60 :.0f}m {time_elapsed % 60 :.0f}s')
    print(f'Best test Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), PATH)

    return model


In [7]:
def train_epoch(model, loader, criterion, optimizer, device='cpu'):
    model.train()

    running_loss = 0.0
    running_corrects = 0
    
    bar = tqdm.tqdm(loader)
    for inputs, labels in bar:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    
    res_loss = running_loss / len(loader.dataset)
    res_accuracy = running_corrects.double() / len(loader.dataset)


    return res_loss, res_accuracy


def evaluate(model, loader, neval_batches=None, device='cpu'):
    model.eval()

    if neval_batches == None:
        neval_batches = len(loader)
    
    running_loss = 0.0
    running_corrects = 0
    cnt = 0

    bar = tqdm.tqdm(loader)
    with torch.no_grad():
        for inputs, labels in bar:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            cnt += 1
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
            
            if cnt >= neval_batches:
                test_loss = running_loss / len(loader.dataset)
                test_accuracy = running_corrects.double() / len(loader.dataset)
                return test_loss, test_accuracy
    
    test_loss = running_loss / len(loader.dataset)
    test_accuracy = running_corrects.double() / len(loader.dataset)


    return test_loss, test_accuracy

In [9]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

In [15]:
model = fine_tune(model, loaders, criterion, optimizer, scheduler=None)


Epoch 0/11
--------------------
Train Loss: 2.3345 Acc: 0.3938
Test Loss: 1.5734 Acc: 0.5557

Epoch 1/11
--------------------
Train Loss: 1.5375 Acc: 0.5666
Test Loss: 1.3884 Acc: 0.6074

Epoch 2/11
--------------------
Train Loss: 1.2459 Acc: 0.6372
Test Loss: 1.3802 Acc: 0.6164

Epoch 3/11
--------------------
Train Loss: 1.0450 Acc: 0.6891
Test Loss: 1.3323 Acc: 0.6260

Epoch 4/11
--------------------
Train Loss: 0.8804 Acc: 0.7347
Test Loss: 1.3440 Acc: 0.6294

Epoch 5/11
--------------------
Train Loss: 0.7520 Acc: 0.7679
Test Loss: 1.3040 Acc: 0.6476

Epoch 6/11
--------------------
Train Loss: 0.6507 Acc: 0.7980
Test Loss: 1.3361 Acc: 0.6417

Epoch 7/11
--------------------
Train Loss: 0.5577 Acc: 0.8231
Test Loss: 1.3726 Acc: 0.6443

Epoch 8/11
--------------------
Train Loss: 0.4926 Acc: 0.8415
Test Loss: 1.3714 Acc: 0.6508

Epoch 9/11
--------------------
Train Loss: 0.4387 Acc: 0.8581
Test Loss: 1.3722 Acc: 0.6564

Epoch 10/11
--------------------
Train Loss: 0.3974 Acc: 0.8

In [21]:
loss, acc = evaluate(model, loaders['test'], device='cuda')
print()
print(f'Test Loss: {loss:.4f} Acc: {acc:.4f}')

100%|█████████▉| 624/625 [00:22<00:00, 28.23it/s]
Test Loss: 1.4306 Acc: 0.6643



In [22]:
loss, acc = evaluate(teacher_model, loaders['test'], device='cuda')
print()
print(f'Test Loss: {loss:.4f} Acc: {acc:.4f}')

100%|█████████▉| 624/625 [00:27<00:00, 23.05it/s]
Test Loss: 1.5595 Acc: 0.7136



In [16]:
model.to('cpu')

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

## Quantization

In [17]:
import torch.quantization
from timm.models.layers.adaptive_avgmax_pool import SelectAdaptivePool2d

In [18]:
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Conv2d, torch.nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): DynamicQuantizedLinear(in_features=9216, out_features=40

In [13]:
class QuantizedNet(torch.nn.Module):
    def __init__(self, model, config):
        super(QuantizedNet, self).__init__()
        self.qconfig = config
        self.quant = torch.quantization.QuantStub(qconfig=self.qconfig)
        self.dequant = torch.quantization.DeQuantStub()
        self.model = copy.deepcopy(model)

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

In [14]:
config = torch.quantization.get_default_qconfig('fbgemm')
quant_model = QuantizedNet(pruned_model, config)
quant_model.to('cpu')

QuantizedNet(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (model): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
    (classifier): Sequent

In [24]:
print_size_of_model(model)
print_size_of_model(quant_model)
print_size_of_model(quantized_model)

Size (MB): 178.087279
Size (MB): 44.657919
Size (MB): 51.953055


In [None]:
# quant_model.qconfig
# quant_model.model.act1.qconfig = None
# quant_model.model.act2.qconfig = None


In [15]:
torch.quantization.prepare(quant_model, inplace=True)



QuantizedNet(
  (quant): QuantStub(
    (activation_post_process): HistogramObserver()
  )
  (dequant): DeQuantStub()
  (model): AlexNet(
    (features): Sequential(
      (0): Conv2d(
        3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)
        (activation_post_process): HistogramObserver()
      )
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(
        64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)
        (activation_post_process): HistogramObserver()
      )
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(
        192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (7): ReLU(inplace=True)
      (8): Conv2d(
        384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): Hi

In [16]:
evaluate(quant_model, loaders['train'], 64)

  2%|▏         | 63/3125 [00:19<15:34,  3.28it/s]


(0.0035561533105373383, tensor(0.0194, dtype=torch.float64))

In [17]:
torch.quantization.convert(quant_model, inplace=True)

QuantizedNet(
  (quant): Quantize(scale=tensor([0.0374]), zero_point=tensor([57]), dtype=torch.quint8)
  (dequant): DeQuantize()
  (model): AlexNet(
    (features): Sequential(
      (0): QuantizedConv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), scale=0.38988518714904785, zero_point=66, padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): QuantizedConv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), scale=1.194999098777771, zero_point=91, padding=(2, 2))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): QuantizedConv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), scale=1.3966145515441895, zero_point=90, padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): QuantizedConv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.8753744959831238, zero_point=86, padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Quant

In [18]:
evaluate(quant_model, loaders['test'])


100%|█████████▉| 624/625 [01:31<00:00,  6.80it/s]


(1.405473821568489, tensor(0.6585, dtype=torch.float64))

In [26]:
evaluate(quantized_model, loaders['test'])


100%|█████████▉| 624/625 [01:34<00:00,  6.60it/s]


(1.4312295955657959, tensor(0.6637, dtype=torch.float64))

## QAT

In [33]:
qat_model = models.alexnet(pretrained=False)
qat_model.classifier[4] = torch.nn.Linear(4096,1024)
qat_model.classifier[6] = torch.nn.Linear(1024,100)
qat_model.load_state_dict(torch.load(PATH))
qat_model.to('cpu')


optimizer = torch.optim.SGD(qat_model.parameters(), lr=INIT_LR, momentum=0.9)
qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(qat_model, inplace=True)

criterion = torch.nn.CrossEntropyLoss()
criterion.to('cpu')

 

CrossEntropyLoss()

In [34]:
since = time.time()


best_model_wts = copy.deepcopy(qat_model.state_dict())
best_acc = 0.0

for epoch in range(6):
    train_epoch(qat_model, loaders['train'], criterion, optimizer)
    if epoch > 3:
        qat_model.apply(torch.quantization.disable_observer)
    qzed_model = torch.quantization.convert(qat_model.eval(), inplace=False)
    qzed_model.eval()
    loss, acc = evaluate(qzed_model, loaders['test'])
    
    if acc > best_acc:
        best_acc = acc
        best_model_wts = copy.deepcopy(qzed_model.state_dict())
    
    print (f'Epoch {epoch}: Eval Acc {acc} ')

time_elapsed = time.time() - since
print(f'Training complete in {time_elapsed // 60 :.0f}m {time_elapsed % 60 :.0f}s')
print(f'Best test Acc: {best_acc:.4f}')

qat_model.load_state_dict(best_model_wts)
# torch.save(model.state_dict(), PATH)


100%|██████████| 3125/3125 [1:19:21<00:00,  1.52s/it]
  0%|          | 0/625 [00:00<?, ?it/s]


RuntimeError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::conv2d.new' is only available for these backends: [QuantizedCPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, Tracer, Autocast, Batched, VmapMode].

QuantizedCPU: registered at ..\aten\src\ATen\native\quantized\cpu\qconv.cpp:873 [kernel]
BackendSelect: fallthrough registered at ..\aten\src\ATen\core\BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at ..\aten\src\ATen\core\NamedRegistrations.cpp:7 [backend fallback]
AutogradOther: fallthrough registered at ..\aten\src\ATen\core\VariableFallbackKernel.cpp:35 [backend fallback]
AutogradCPU: fallthrough registered at ..\aten\src\ATen\core\VariableFallbackKernel.cpp:39 [backend fallback]
AutogradCUDA: fallthrough registered at ..\aten\src\ATen\core\VariableFallbackKernel.cpp:43 [backend fallback]
AutogradXLA: fallthrough registered at ..\aten\src\ATen\core\VariableFallbackKernel.cpp:47 [backend fallback]
Tracer: fallthrough registered at ..\torch\csrc\jit\frontend\tracer.cpp:999 [backend fallback]
Autocast: fallthrough registered at ..\aten\src\ATen\autocast_mode.cpp:250 [backend fallback]
Batched: registered at ..\aten\src\ATen\BatchingRegistrations.cpp:1016 [backend fallback]
VmapMode: fallthrough registered at ..\aten\src\ATen\VmapModeRegistrations.cpp:33 [backend fallback]


## Pruning

In [8]:
pruned_model = models.alexnet(pretrained=True)
pruned_model.classifier[4] = torch.nn.Linear(4096,1024)
pruned_model.classifier[6] = torch.nn.Linear(1024,100)
pruned_model.load_state_dict(torch.load(PATH))
pruned_model.to('cuda')

criterion = torch.nn.CrossEntropyLoss()
criterion.to('cuda')

CrossEntropyLoss()

In [9]:
parameters_to_prune = (
    (pruned_model.features[0], 'weight'),
    (pruned_model.features[3], 'weight'),
    (pruned_model.features[6], 'weight'),
    (pruned_model.features[8], 'weight'),
    (pruned_model.features[10], 'weight'),
    (pruned_model.classifier[1], 'weight'),
    (pruned_model.classifier[4], 'weight'),
    (pruned_model.classifier[6], 'weight'),
)

In [10]:
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.5
)

In [12]:
evaluate(pruned_model, loaders['test'], device='cuda')

100%|█████████▉| 624/625 [00:29<00:00, 20.87it/s]


(1.3934545657634736, tensor(0.6605, device='cuda:0', dtype=torch.float64))

In [74]:
evaluate(model, loaders['test'], device='cuda')


100%|█████████▉| 624/625 [00:26<00:00, 23.36it/s]


(1.4306268773555755, tensor(0.6643, device='cuda:0', dtype=torch.float64))

In [50]:
pruned_model.state_dict().keys()

odict_keys(['features.0.bias', 'features.0.weight', 'features.3.bias', 'features.3.weight', 'features.6.bias', 'features.6.weight', 'features.8.bias', 'features.8.weight', 'features.10.bias', 'features.10.weight', 'classifier.1.bias', 'classifier.1.weight', 'classifier.4.bias', 'classifier.4.weight', 'classifier.6.bias', 'classifier.6.weight'])

In [11]:
list(map(lambda x: prune.remove(x[0],x[1]), parameters_to_prune))

[Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)),
 Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)),
 Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Linear(in_features=9216, out_features=4096, bias=True),
 Linear(in_features=4096, out_features=1024, bias=True),
 Linear(in_features=1024, out_features=100, bias=True)]

In [51]:
pruned_model.features[0].weight

Parameter containing:
tensor(indices=tensor([[ 0,  0,  0,  ..., 63, 63, 63],
                       [ 0,  0,  0,  ...,  2,  2,  2],
                       [ 0,  0,  0,  ..., 10, 10, 10],
                       [ 0,  1,  2,  ...,  8,  9, 10]]),
       values=tensor([ 0.1653,  0.1319,  0.1154,  ..., -0.0641,  0.0840,
                       0.0720]),
       device='cuda:0', size=(64, 3, 11, 11), nnz=21527, layout=torch.sparse_coo,
       requires_grad=True)

In [58]:
for module, _ in parameters_to_prune:
    module.weight = torch.nn.Parameter(module.weight.data.to_sparse())

In [59]:
# print_size_of_model(model)
print_size_of_model(pruned_model)

Size (MB): 203.334511


## Distillation

In [234]:
def KD_loss(outputs, labels, teacher_outputs, alpha=0.1, T=2):
    # student_loss = F.cross_entropy(F.softmax(outputs, dim=1), labels)
    # distill_loss = nn.KLDivLoss(reduction='batchmean')(
    #     F.log_softmax(outputs / T, dim=1),
    #     F.softmax(teacher_outputs / T, dim=1)
    # )
    # loss = alpha * student_loss + (1. - alpha) * distill_loss

    soft_student_out = F.softmax(outputs / T, dim=1)
    soft_teacher_out = F.softmax(teacher_outputs / T, dim=1)

    loss = (1 - alpha) * F.cross_entropy(outputs, labels)
    loss += (alpha * T * T) * nn.MSELoss()(soft_student_out, soft_teacher_out)

    return loss

def distill(model, teacher_model, loaders, loss_fn, optimizer, scheduler=None):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(EPOCHS):
        print(f'Epoch {epoch + 1}/{EPOCHS}')
        print('-' * 20)

        for phase in ['train', 'test']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            teacher_model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loaders[phase]:
                inputs, labels = inputs.to('cuda'), labels.to('cuda')

                optimizer.zero_grad()
                
                with torch.no_grad():
                    teacher_outputs = teacher_model(inputs)

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1) # indices of max probs 
                    
                    loss = loss_fn(outputs, labels, teacher_outputs)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            
            if phase == 'train' and scheduler:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print(f'{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'test' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60 :.0f}m {time_elapsed % 60 :.0f}s')
    print(f'Best test Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), DPATH)

    return model

In [239]:
distill_model = models.alexnet(pretrained=True)
distill_model.classifier[4] = torch.nn.Linear(4096,1024)
distill_model.classifier[6] = torch.nn.Linear(1024,100)
distill_model.to('cuda')

teacher_model = timm.create_model(TEACHER_NAME, pretrained=False, num_classes=100)
teacher_model.load_state_dict(torch.load(TPATH))
teacher_model.to('cuda')

optim = torch.optim.SGD(distill_model.parameters(), lr=0.1, momentum=0.9)
# optimizer = torch.optim.Adam(distill_model.parameters(), lr = 0.1)
# optimizer = torch.optim.RMSprop()

# optimizer = torch.optim.SGD(distill_model.parameters(), lr=0.3, momentum=0.9)


# optimizer = torch.optim.Adam(model.parameters(), lr = INIT_LR)

# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS, eta_min=MIN_LR)
# sheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, 0.3)

In [236]:
distill_model = distill(distill_model, teacher_model, loaders, KD_loss, optimizer)


Epoch 1/12
--------------------


RuntimeError: CUDA out of memory. Tried to allocate 74.00 MiB (GPU 0; 6.00 GiB total capacity; 4.27 GiB already allocated; 0 bytes free; 4.37 GiB reserved in total by PyTorch)

In [64]:
torch.cuda.empty_cache()

In [65]:
import gc
gc.collect()

315

In [233]:
evaluate(teacher_model, loaders['test'], device='cuda')


100%|█████████▉| 624/625 [00:36<00:00, 16.91it/s]


(1.5594913409113884, tensor(0.7136, device='cuda:0', dtype=torch.float64))