In [None]:
# Notebook-wide settings; later cells read these names.
BATCH_SIZE = 256  # mini-batch size used by both data loaders
DEVICE = "cuda:2"  # torch device string used for GPU selection and checkpoint loading
# NOTE(review): DATA_PATH is not referenced by the visible cells (the CIFAR-100
# datasets are read from './data') -- confirm whether it is still needed.
DATA_PATH = "/workspace/code/Akash/ImageNet"

In [None]:
import os
import sys
import torch
import torchvision.datasets as Datasets
from torchvision import transforms as tfms
from torch.utils.data import DataLoader
# Add the directory two levels above the working directory to the import path
# (presumably where the `trailmet` package lives -- confirm).
sys.path.append("../../")
# Make DEVICE the current CUDA device. Passing a torch.device lets PyTorch parse
# the index itself; taking only the last character of the string would turn
# 'cuda:10' into GPU 0.
torch.cuda.set_device(torch.device(DEVICE))

In [None]:
# Per-channel (mean, std) pair handed to Normalize in both pipelines.
stats = (
    (0.5071, 0.4867, 0.4408),  # means
    (0.2675, 0.2565, 0.2761),  # standard deviations
)

# Augmenting pipeline: reflect-padded random 32x32 crop and random horizontal
# flip, followed by tensor conversion and in-place normalisation.
train_tfms = tfms.Compose(
    [
        tfms.RandomCrop(size=32, padding=4, padding_mode="reflect"),
        tfms.RandomHorizontalFlip(),
        tfms.ToTensor(),
        tfms.Normalize(mean=stats[0], std=stats[1], inplace=True),
    ]
)

# Deterministic pipeline: tensor conversion and normalisation only.
test_tfms = tfms.Compose(
    [
        tfms.ToTensor(),
        tfms.Normalize(mean=stats[0], std=stats[1]),
    ]
)

In [None]:
# CIFAR-100 splits, downloaded into ./data when not already present.
# NOTE(review): the train split is built with `test_tfms` (no augmentation) and
# `train_tfms` is never used in the visible cells -- confirm this is intentional.
cifar100_train = Datasets.CIFAR100(
    root="./data", train=True, transform=test_tfms, download=True
)
cifar100_test = Datasets.CIFAR100(
    root="./data", train=False, transform=test_tfms, download=True
)

# Only the training loader shuffles; both use one worker and BATCH_SIZE samples per batch.
train_loader = DataLoader(
    cifar100_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=1
)
test_loader = DataLoader(
    cifar100_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=1
)

# Loader mapping handed to the quantizer further down.
dataloaders = {"train": train_loader, "val": test_loader}

In [None]:
# import libraries
from trailmet.models import resnet
from trailmet.algorithms.quantize.lapq import LAPQ

In [None]:
# Build the ResNet-50 and restore its pretrained weights.
# NOTE(review): the positional arguments (100, 32) presumably mean number of
# classes and input resolution -- confirm against `resnet.make_resnet50`.
cnn = resnet.make_resnet50(100, 32)
# Checkpoint tensors are mapped onto DEVICE while loading.
checkpoint = torch.load(
    "./resnet50_cifar100-pretrained.pth",
    map_location=DEVICE,
)
# Kept as the cell's last expression so the load result is displayed.
cnn.load_state_dict(checkpoint["state_dict"])

In [None]:
# Evaluate the full-precision model on the test loader before quantizing.
from trailmet.algorithms.algorithms import BaseAlgorithm

# Kept as the cell's last expression so whatever `test` returns is displayed.
BaseAlgorithm().test(
    model=cnn,
    dataloader=test_loader,
    device=torch.device(DEVICE),
)

In [None]:
# Quantize the model with LAPQ.
# Option semantics are defined by trailmet's LAPQ implementation; the inline
# notes below are the reviewer's reading of the key names -- confirm there.
kwargs = {
    'W_BITS': 4,        # presumably weight bit-width
    'A_BITS': 4,        # presumably activation bit-width
    'ACT_QUANT': True,
    # Integer number of calibration batches (about 1024 samples); clamped to at
    # least one so a BATCH_SIZE above 1024 does not yield zero batches.
    'CALIB_BATCHES': max(1, 1024 // BATCH_SIZE),
    'MAX_ITER': 1,
    'MAX_FEV': 1,
    'VERBOSE': True,
    'PRINT_FREQ': 20,
    # Let torch parse the device index: reading only the last character of the
    # string would turn 'cuda:10' into GPU 0.
    'GPU_ID': torch.device(DEVICE).index,
    'SEED': 42,
}
qnn = LAPQ(cnn, dataloaders, **kwargs)
qnn.compress_model()