In [8]:
import torch
import os, time
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import random
from torch.optim import SGD, lr_scheduler
from dataloaders.cifar10 import cifar10_dataloaders
from dataloaders.svhn import svhn_dataloaders
from dataloaders.stl10 import stl10_dataloaders
from utils.utils import *
import multiprocessing as mp
# Start child processes with 'spawn' (presumably the DataLoader workers, since the
# loaders below are built with num_workers=args.cpus); force=True overrides any
# start method that was already set. Author's note: "to avoid Caffe Warnings".
mp.set_start_method('spawn', force=True)

################################################ Importing Models and Width Multiplier List
from models.slimmable_ops import width_mult_list
from models.cifar10.resnet_slimmable import SlimmableResNet34
from models.svhn.wide_resnet_slimmable import SlimmableWideResNet_16_8
from models.cifar10.mobilenet_v2_slimmable import SlimmableMobileNetV2

# Setting Device: first CUDA GPU if one is available, otherwise fall back to CPU.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
# first queried, so it should be configured before this cell runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print('GPU')
else:
    print('GPU not Detected - CPU Selected')
print(f"GPUs Count: {torch.cuda.device_count()}")  # number of CUDA devices visible to this process

GPU
GPUs Count: 1


In [9]:
def set_seed(seed=3):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) with ``seed``.

    Also pins cuDNN to deterministic kernels with autotuning disabled, and
    resets the default floating-point dtype to float32.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False      # autotuning may choose different kernels per run
    cudnn.deterministic = True

    torch.set_default_dtype(torch.float32)


set_seed()

In [None]:
#################### Setting Hyperparameters
class Arguments:
    """Configuration for this evaluation notebook.

    The cells below read ``gpu``, ``cpus``, ``dataset``, ``model_name``,
    ``batch_size`` and ``dim``. The optimisation fields (epochs, decay,
    opt, lr, momentum, wd) are not read by any evaluation cell shown here.
    """
    gpu = '0'
    cpus = 4                        # passed as num_workers to the data loaders
    dataset = 'cifar10'             # 'cifar10' | 'svhn' | 'stl10'
    model_name = 'ResNet_34'        # 'ResNet_34' | 'WideResNet_16_8' | 'MobileNetV2'
    batch_size = 128
    dim = 128
    epochs = 200
    decay_epochs = [50, 150]
    opt = 'sgd'
    decay = 'cos'
    lr = 0.01
    momentum = 0.9
    wd = 5e-4

args = Arguments()
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if set before CUDA is first
# initialised; torch.cuda was already queried in the device-selection cell, so this
# may have no effect here. Consider setting it at the very top of the notebook.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# Keep cuDNN autotuning off, consistent with set_seed(): with benchmark=True cuDNN
# may select different convolution algorithms from run to run, undermining the
# reproducibility the seed cell asks for.
torch.backends.cudnn.benchmark = False

############### DATA LOADERS:
if args.dataset == 'cifar10':
    train_loader, val_loader, test_loader = cifar10_dataloaders(train_batch_size=args.batch_size, num_workers=args.cpus)
elif args.dataset == 'svhn':
    train_loader, val_loader, test_loader = svhn_dataloaders(train_batch_size=args.batch_size, num_workers=args.cpus)
elif args.dataset == 'stl10':
    train_loader, val_loader = stl10_dataloaders(train_batch_size=args.batch_size, num_workers=args.cpus)
    # stl10_dataloaders returns only two loaders; without this alias `test_loader`
    # would be undefined and test_model() would fail with a NameError.
    # NOTE(review): assumes the second loader is the held-out evaluation split — confirm.
    test_loader = val_loader
else:
    # Fail loudly instead of leaving every loader undefined for later cells.
    raise ValueError(f"Unsupported dataset {args.dataset!r}; expected 'cifar10', 'svhn' or 'stl10'")

############### MODEL SELECTION:
FiLM_in_channels = args.dim  # NOTE(review): not read by any later cell shown here

if args.model_name == 'ResNet_34':
    model_fn = SlimmableResNet34
    model = model_fn()
elif args.model_name == 'WideResNet_16_8':
    model_fn = SlimmableWideResNet_16_8
    model = model_fn(depth=16, num_classes=10, widen_factor=8, dropRate=0.0)
elif args.model_name == 'MobileNetV2':
    model_fn = SlimmableMobileNetV2
    model = model_fn()
else:
    # Fail loudly instead of leaving `model` undefined for later cells.
    raise ValueError(f"Unsupported model {args.model_name!r}; expected 'ResNet_34', 'WideResNet_16_8' or 'MobileNetV2'")

# Move every architecture to the selected device the same way; `.to(device)` rather
# than `.cuda()` so the notebook also runs when no GPU is available.
model = model.to(device)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
./SNN_results/cifar10/SlimmableResNet34/Epochs200_BatchSize128_LR0.01


In [None]:
# width_mult_list must be compatible with the model being evaluated: one entry per sub-network width.
print("Number of Sub-Nets: {}".format(len(width_mult_list)))

Number of Sub-Nets: 32


#### **Load Trained Model**

In [None]:
############### For loading Best MODELs
model_dir = "./SNN_results/"      ##############  Provide Correct Model Path
model_name = "Best_Model.pth"     ##############  Use Correct Model Name

model_path = os.path.join(model_dir, model_name)
# map_location puts the tensors on the selected device, so a checkpoint saved on a
# GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True for plain state dicts).
ckpt = torch.load(model_path, map_location=device)
model.load_state_dict(ckpt)

<All keys matched successfully>

In [None]:
def test_model(model, sub_net_ID):
    ################################## TESTING
    model.eval()
    requires_grad_(model, False)
    test_accs = AverageMeter()
    for i, (imgs, labels) in enumerate(test_loader):
        imgs, labels = imgs.cuda(), labels.cuda()
        logits = model(imgs)
        test_accs.append((logits.argmax(1) == labels).float().mean().item())
    print(round(test_accs.avg * 100, 2))
    return

In [None]:
# Evaluate every sub-network at the exact widths listed in width_mult_list.
# Using the listed values (rather than accumulating k/len(width_mult_list)) stays
# correct for non-uniformly spaced lists and avoids floating-point drift from
# repeated addition. NOTE(review): the slimmable layers presumably match
# width_mult against these exact values — confirm in models.slimmable_ops.
test_accuracies = []
for sub_net_ID, sub_net_width in enumerate(width_mult_list, start=1):
    model.apply(lambda m: setattr(m, 'width_mult', sub_net_width))
    test_acc = test_model(model, sub_net_ID)
    test_accuracies.append(test_acc)