In [None]:
import os
import sys
# Put "../../" (resolved against the kernel's working directory) on the module
# search path -- presumably so the local `trailmet` package imported further
# down can be found; confirm against the repository layout.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry each time.
if "../../" not in sys.path:
    sys.path.append("../../")

In [None]:
import torch
from torchvision import transforms as transforms
from torchvision import datasets as datasets
# Index of the CUDA device this notebook runs on. Keep it in sync with the
# 'GPU_ID' entry of the LAPQ kwargs further down, which carries the same index.
GPU_ID = 7
# Fail early with an actionable message instead of an opaque CUDA error when
# the host has no usable GPU or fewer than GPU_ID + 1 of them.
if not torch.cuda.is_available():
    raise RuntimeError(
        "CUDA is not available, but this notebook selects GPU %d." % GPU_ID)
if GPU_ID >= torch.cuda.device_count():
    raise RuntimeError(
        "GPU %d requested, but only %d CUDA device(s) are visible."
        % (GPU_ID, torch.cuda.device_count()))
torch.cuda.set_device(GPU_ID)

In [None]:
def build_imagenet_data(data_path: str = '', input_size: int = 224, batch_size: int = 64, workers: int = 4,
                        dist_sample: bool = False):
    """Create the ImageNet training and validation data loaders.

    Both datasets are ``torchvision.datasets.ImageFolder`` instances rooted at
    ``<data_path>/train`` and ``<data_path>/val`` respectively.

    Args:
        data_path: Directory that contains the ``train`` and ``val`` folders.
        input_size: Side length of the square crop produced for each image.
        batch_size: Batch size used by both loaders.
        workers: Number of worker processes used by each loader.
        dist_sample: When True, both datasets are wrapped in a
            ``DistributedSampler`` and loader-level shuffling of the training
            set is switched off because the sampler takes over.

    Returns:
        A ``(train_loader, val_loader)`` tuple of
        ``torch.utils.data.DataLoader`` objects.
    """
    print('==> Using Imagenet Dataset')

    traindir = os.path.join(data_path, 'train')
    valdir = os.path.join(data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Scale the validation resize with the requested crop, keeping the usual
    # 256 -> 224 ratio. A fixed Resize(256) ignores input_size: crops larger
    # than 256 would be zero-padded by CenterCrop and smaller crops would see
    # a much narrower field of view than intended. For the default
    # input_size=224 this evaluates to exactly 256, as before.
    val_resize = int(round(input_size * 256 / 224))

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(val_resize),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ]))

    if dist_sample:
        # NOTE(review): DistributedSampler shuffles by default, so the
        # validation order is shuffled too in this mode -- confirm that this
        # is acceptable for the evaluation code consuming val_loader.
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    else:
        train_sampler = None
        val_sampler = None

    # DataLoader rejects shuffle=True together with an explicit sampler, hence
    # shuffling is only requested when no sampler is in use.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True, sampler=val_sampler)
    return train_loader, val_loader

In [None]:
# import libraries
from trailmet.models import resnet
from trailmet.algorithms.quantize.lapq import LAPQ
from trailmet.algorithms.quantize.quantize import BaseQuantization

In [None]:
# load data
# Root of the ImageNet directory tree. Set the IMAGENET_DIR environment
# variable to point somewhere else instead of editing this cell; the default
# keeps the location the notebook was originally written against.
IMAGENET_DIR = os.environ.get('IMAGENET_DIR', '/workspace/code/Akash/ImageNet')
train_loader, val_loader = build_imagenet_data(data_path=IMAGENET_DIR)
# LAPQ receives both loaders through a dict keyed by split name.
dataloaders = {'train': train_loader, 'val': val_loader}

In [None]:
# load model
# Build the ResNet-50 network that is handed to LAPQ below, requesting
# pretrained weights.
# NOTE(review): the positional arguments 1000 and 224 are presumably the number
# of classes and the input resolution -- confirm against trailmet's
# get_resnet_model signature.
cnn = resnet.get_resnet_model('resnet50', 1000, 224, pretrained=True)

In [None]:
# quantize model
# Settings forwarded to LAPQ as keyword arguments. What each key means is
# defined by trailmet's LAPQ implementation, which is not visible in this
# notebook -- consult it before changing values.
kwargs = dict(
    W_BITS=8,
    A_BITS=8,
    ACT_QUANT=True,
    SYMM=True,
    UINT=True,
    SET_8BIT_HEAD_STEM=False,
    SKIP_HEAD_STEM=True,
    NUM_SAMPLES=1024,
    MAX_ITER=500,
    VERBOSE=True,
    PRINT_FREQ=20,
    GPU_ID=7,
    SEED=42,
)
qnn = LAPQ(cnn, dataloaders, **kwargs)
qnn.compress_model()