In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchvision
import torchvision.transforms.v2 as transforms
# from torchmetrics importo
from torcheval.metrics import MulticlassAccuracy
from torcheval.metrics.functional import multiclass_accuracy
from torch.ao.quantization import QuantStub, DeQuantStub

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

In [3]:
class DatasetWrapper(torch.utils.data.Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform
        
    def __getitem__(self, index):
        x, y = self.subset[index]
        # print(x)
        if self.transform:
            x = self.transform(x)
        return x, y
        
    def __len__(self):
        return len(self.subset)

In [4]:
transform_train = transforms.Compose([
    transforms.ColorJitter(),
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# trainset = torchvision.datasets.CIFAR10(
#     root='./CIFAR10', train=True, download=True, transform=transform_train)
trainset = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=True, download=True)

# split the train set into two
train_set_size = int(len(trainset) * 0.8)
valid_set_size = len(trainset) - train_set_size

seed = torch.Generator().manual_seed(42)
trainset, validset = torch.utils.data.random_split(trainset, [train_set_size, valid_set_size], generator=seed)

trainset = DatasetWrapper(trainset, transform_train)
validset = DatasetWrapper(validset, transform_test)

# Create train dataloader
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=512, shuffle=True, num_workers=24)
# Create validation dataloader
validloader = torch.utils.data.DataLoader(
    validset, batch_size=512, shuffle=False, num_workers=24)

testset = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=False, download=True, transform=transform_test)
# Create test dataloader
testloader = torch.utils.data.DataLoader(
    testset, batch_size=512, shuffle=False, num_workers=24)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')




Files already downloaded and verified
Files already downloaded and verified


In [5]:
class SeparableConv2d(nn.Module):
    '''Separable convolution'''
    def __init__(self, in_channels, out_channels, stride=1):
        super(SeparableConv2d, self).__init__()
        self.dw_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=False),
        )
        self.pw_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=False),
        )

    def forward(self, x):
        x = self.dw_conv(x)
        x = self.pw_conv(x)
        return x

In [6]:
class MyMobileNet(pl.LightningModule):
    cfg = [
        (32, 64, 1), 
        (64, 128, 2), 
        (128, 128, 1), 
        (128, 256, 2),
        (256, 256, 1),
        (256, 512, 2),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 1024, 2),
        (1024, 1024, 1),
    ]
    
    def __init__(self, num_classes=10, alpha: float = 1):
        super(MyMobileNet, self).__init__()
        conv_out = int(32 * alpha)
        self.conv = nn.Conv2d(3, conv_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(conv_out)
        self.relu = nn.ReLU(inplace=False)
        # self.accuracy = Accuracy("multiclass", num_classes=num_classes)
        # self.accuracy = MulticlassAccuracy()
        self.accuracy = multiclass_accuracy

        self.features = self.make_feature_extractor(alpha)
        self.linear = nn.Linear(1024, num_classes)

    def make_feature_extractor(self, alpha):
        layer_values = [(int(inp*alpha), int(out*alpha), chan) for inp, out, chan in self.cfg]
        layers = nn.Sequential(*[SeparableConv2d(*tup) for tup in layer_values])
        return layers

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        x = self.features(x)
        x = F.avg_pool2d(x, 2)
        x = x.view(x.size()[0], -1)
        x = self.linear(x)
        return x
        
    def step(self, batch, batch_idx):
    # def training_step(self):
        x, y = batch
        logits = self.forward(x)
        loss = self.compute_loss(logits, y)
        acc = self.accuracy(logits, y)
        return loss, acc

    def training_step(self, batch, batch_idx):
    # def training_step(self):
        loss, acc = self.step(batch, batch_idx)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_accuracy", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss
        
    def validation_step(self, batch, batch_idx):
    # def training_step(self):
        loss, acc = self.step(batch, batch_idx)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_accuracy", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss
        
    def test_step(self, batch, batch_idx):
        loss, acc = self.step(batch, batch_idx)
        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("test_accuracy", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss
        
    def configure_optimizers(self):
        self.lr = 0.02089
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.001, weight_decay=0.001)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.3, total_steps=MAX_EPOCHS * len(trainloader))
        return { "optimizer": optimizer, "lr_scheduler": scheduler }
        
    def compute_loss(self, logits, labels):
        return nn.functional.cross_entropy(logits, labels)

In [17]:
# from lightning.pytorch.tuner import Tuner
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar

# MAX_EPOCHS = 5 # leads to better than random guessing on quantized model with default_qconfig
MAX_EPOCHS = 50

model = MyMobileNet()

checkpoint_callback = ModelCheckpoint(
    # monitor='val_accuracy',
    dirpath='checkpoints/',
    filename='QUANT-{epoch:02d}-{val_accuracy:.2f}'
)
progress_callback = RichProgressBar(leave=True)

trainer = pl.Trainer(max_epochs=MAX_EPOCHS, callbacks=[checkpoint_callback, progress_callback])

# tune = Tuner(trainer)
# optimal_lr = tune.lr_find(model, train_dataloaders=trainloader, val_dataloaders=validloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [18]:
trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=validloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

`Trainer.fit` stopped: `max_epochs=50` reached.


In [19]:
trainer.test(model, testloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

[{'test_loss': 0.5298764705657959, 'test_accuracy': 0.8281000256538391}]

In [20]:
MODEL_WEIGHTS_PATH = './checkpoints/test.pth'

In [21]:
print('==> saving model')
state = {
    'net': model.state_dict(),
    'acc': '0.747299',
    'epoch': 50,
}

torch.save(state, MODEL_WEIGHTS_PATH)

==> saving model


In [22]:
# quantized_model.to("cuda")
saved = torch.load(MODEL_WEIGHTS_PATH)
# model.load_state_dict(saved['state_dict'])
model.load_state_dict(saved['net'])
trainer.test(model, testloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

[{'test_loss': 0.5298764705657959, 'test_accuracy': 0.8280999660491943}]

# Quantization!

In [27]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')


class QuantizedMobileNet(MyMobileNet):
    def __init__(self, num_classes=10, alpha: float = 1):
        super(QuantizedMobileNet, self).__init__(num_classes, alpha)
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        x = self.features(x)
        x = F.avg_pool2d(x, 2)
        x = x.view(x.size()[0], -1)
        x = self.linear(x)
        x = self.dequant(x)
        return x
        
    # Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
    # This operation does not change the numerics
    def fuse_model(self, is_qat=False):
        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
        fuse_modules(self, ["conv", "bn", 'relu'], inplace=True)
        for m in self.modules():
            if type(m) == SeparableConv2d:
                fuse_modules(m.dw_conv, ['0', '1', '2'], inplace=True)
                fuse_modules(m.pw_conv, ['0', '1', '2'], inplace=True)

In [28]:
# Baseline
model = QuantizedMobileNet()
saved = torch.load(MODEL_WEIGHTS_PATH)
# model.load_state_dict(saved['state_dict'])
model.load_state_dict(saved['net'])
print_size_of_model(model)

progress_callback = RichProgressBar(leave=True)

trainer = pl.Trainer(callbacks=[progress_callback])

trainer.test(model, testloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Size (MB): 13.002226


Output()

[{'test_loss': 0.5298764705657959, 'test_accuracy': 0.8281000256538391}]

In [29]:
# Fused model
model = QuantizedMobileNet().to("cpu")
saved = torch.load(MODEL_WEIGHTS_PATH)
# model.load_state_dict(saved['state_dict'])
model.load_state_dict(saved['net'])
model.eval()
model.fuse_model()
print_size_of_model(model)

progress_callback = RichProgressBar(leave=True)

trainer = pl.Trainer(accelerator = "cpu", callbacks=[progress_callback])

trainer.test(model, testloader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Size (MB): 12.846458


Output()

[{'test_loss': 0.5299545526504517, 'test_accuracy': 0.8281000256538391}]

In [30]:
backend = "fbgemm"
model = QuantizedMobileNet()
checkpoint = torch.load(MODEL_WEIGHTS_PATH)
# model.load_state_dict(checkpoint['state_dict'])
model.load_state_dict(checkpoint['net'])
model.eval()
model.fuse_model()
# model.qconfig = torch.ao.quantization.default_qconfig # This leads to bad results.
model.qconfig = torch.quantization.get_default_qconfig(backend)

print(model.qconfig)
torch.ao.quantization.prepare(model, inplace=True)

# Calibrate first
print('Post Training Quantization Prepare: Inserting Observers')
print('\n SeparableConv Block:After observer insertion \n\n', model.features[1].dw_conv)

progress_callback = RichProgressBar(leave=True)
trainer = pl.Trainer(accelerator="cpu", callbacks=[progress_callback])

trainer.validate(model, validloader)

torch.ao.quantization.convert(model, inplace=True)

print('Post Training Quantization: Convert done')
print('\n SeparableConv Block: After fusion and quantization, note fused modules: \n\n', model.features[1].dw_conv)
print_size_of_model(model)

progress_callback = RichProgressBar(leave=True)
trainer = pl.Trainer(accelerator="cpu", callbacks=[progress_callback])

trainer.validate(model, validloader)
trainer.test(model, testloader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
Post Training Quantization Prepare: Inserting Observers

 SeparableConv Block:After observer insertion 

 Sequential(
  (0): ConvReLU2d(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64)
    (1): ReLU()
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (1): Identity()
  (2): Identity()
)


Output()

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Post Training Quantization: Convert done

 SeparableConv Block: After fusion and quantization, note fused modules: 

 Sequential(
  (0): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(2, 2), scale=0.13965807855129242, zero_point=0, padding=(1, 1), groups=64)
  (1): Identity()
  (2): Identity()
)
Size (MB): 3.461908


Output()

Output()

[{'test_loss': 0.5302149653434753, 'test_accuracy': 0.829800009727478}]