In [3]:
import os
import sys

# Make modules two directories above the kernel's working directory importable
# (presumably the repository root that contains `trailmet`). The path is made
# absolute so a later os.chdir cannot change what it points to, and it is only
# appended once so re-running this cell does not keep growing sys.path.
REPO_ROOT = os.path.abspath("../../")
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [6]:
import torch
from torchvision import transforms as transforms
from torchvision import datasets as datasets

In [7]:
def build_imagenet_data(data_path: str = '', input_size: int = 224, batch_size: int = 64, workers: int = 4,
                        dist_sample: bool = False, resize_size: int = None):
    """Build ImageNet-style train/val DataLoaders from an ImageFolder directory tree.

    Args:
        data_path: Root directory containing ``train/`` and ``val/`` subdirectories,
            each laid out as expected by ``torchvision.datasets.ImageFolder``.
        input_size: Side length of the square crops produced for the model.
        batch_size: Batch size used by both loaders.
        workers: Number of DataLoader worker processes.
        dist_sample: If True, wrap both datasets in ``DistributedSampler``
            (this needs an initialised ``torch.distributed`` process group).
        resize_size: Size passed to ``Resize`` before the validation centre crop.
            Defaults to ``int(input_size / 0.875)``, i.e. 256 for 224 crops.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The train loader shuffles (or uses
        the distributed sampler); the val loader iterates in a fixed order.
    """
    print('==> Using Imagenet Dataset')

    traindir = os.path.join(data_path, 'train')
    valdir = os.path.join(data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if resize_size is None:
        # 0.875 crop ratio: reproduces the classic Resize(256) for 224 crops and keeps
        # the centre crop inside the resized image for larger input sizes.
        resize_size = int(input_size / 0.875)

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ]))

    if dist_sample:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        # DistributedSampler shuffles by default; evaluation order should be deterministic.
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None

    # Pinned host memory only speeds up host-to-CUDA copies; skip it when no CUDA device exists.
    pin_memory = torch.cuda.is_available()

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=pin_memory, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=pin_memory, sampler=val_sampler)
    return train_loader, val_loader

In [8]:
# Root of an ImageFolder tree with `train/` and `val/` splits. Override with the
# IMAGENET_PATH environment variable instead of editing a hard-coded absolute path.
IMAGENET_PATH = os.environ.get('IMAGENET_PATH', os.path.expanduser('~/Desktop/sample_imagenet'))

train_loader, val_loader = build_imagenet_data(data_path=IMAGENET_PATH)
dataloaders = {'train': train_loader, 'val': val_loader}

==> Using Imagenet Dataset


In [17]:
from torch.hub import load_state_dict_from_url
from trailmet.models.resnet import make_resnet50
from trailmet.algorithms.quantize.brecq import BRECQ

In [18]:
# ResNet-50 ImageNet checkpoint published on the BRECQ GitHub release page.
load_url = 'https://github.com/yhhhli/BRECQ/releases/download/v1.0/resnet50_imagenet.pth.tar'

# Full-precision model: 1000 ImageNet classes, 224x224 inputs.
cnn = make_resnet50(num_classes=1000, insize=224)

# Fetch the state dict with every tensor mapped onto the CPU.
checkpoint = load_state_dict_from_url(load_url, map_location='cpu', progress=True)

# Kept as the cell's last expression so the key-matching report is displayed.
cnn.load_state_dict(checkpoint)


<All keys matched successfully>

In [13]:
# Requires `cnn` (model cell) and `dataloaders` (data cell) to have been run first.
# Extra keyword options for BRECQ; an empty dict keeps the library defaults.
kwargs = {}
qnn = BRECQ(cnn, dataloaders, **kwargs)
# NOTE(review): `compress_clasifier` looks misspelled — confirm the method name
# against the installed trailmet BRECQ API.
qnn.compress_clasifier()