In [12]:
import torch
import os, time
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import random
from torch.optim import SGD, lr_scheduler
from dataloaders.cifar10 import cifar10_dataloaders
from dataloaders.svhn import svhn_dataloaders
from dataloaders.stl10 import stl10_dataloaders
from utils.utils import *
from utils.context import ctx_noparamgrad_and_eval
from utils.sample_lambda import element_wise_sample_lambda
from attacks.pgd import PGD
import multiprocessing as mp
# Use the 'spawn' start method for multiprocessing; force=True replaces any start
# method that was already set in this interpreter. (Author's note: avoids Caffe warnings.)
mp.set_start_method(method='spawn', force=True)

################################################ Importing Models and Width Multiplier List
from models.slimmable_ops import width_mult_list
from models.cifar10.resnet_slimmable_OAT import SlimmableResNet34OAT, SlimmableResNet34OAT_SOL
from models.svhn.wide_resnet_slimmable_OAT import SlimmableWideResNet_16_8_OAT, SlimmableWideResNet_16_8_OAT_SOL
from models.stl10.wide_resnet_slimmable_OAT import SlimmableWideResNet_40_2_OAT, SlimmableWideResNet_40_2_OAT_SOL

# Setting Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('GPU') if str(device) == "cuda:0" else print('GPU not Detected - CPU Selected')
print(f"GPUs Count: {torch.cuda.device_count()}") # Show how many GPUs are available

GPU
GPUs Count: 1


In [13]:
def set_seed(seed=3):
    """Seed every RNG used here and pin cuDNN to deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA
    devices) with the same value, turns on cuDNN deterministic mode with
    autotuning off, and sets torch's default floating dtype to float32.

    Args:
        seed: integer seed shared by all generators (default 3).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Deterministic cuDNN: no benchmark-based algorithm selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_default_dtype(torch.float32)

set_seed()

In [None]:
#################### Setting Hyperparameters
class Arguments:
    """Notebook configuration, stored as plain class attributes.

    In the visible cells, `gpu` feeds CUDA_VISIBLE_DEVICES, and `cpus` becomes
    the dataloader worker count. `dim` is both the FiLM input width and the
    size of the lambda encoding matrix. `eps`/`steps` configure PGD, and
    `distribution`/`use2BN` steer test_model. The optimisation fields (epochs,
    lr, momentum, wd) and `probs`/`efficient` are not read by the visible
    evaluation cells.
    """
    # --- hardware / data
    gpu = '0'
    cpus = 4
    dataset = 'cifar10'
    batch_size = 128
    dim = 128

    # --- optimisation (training-time values)
    epochs = 120
    lr = 0.1
    momentum = 0.9
    wd = 5e-4

    # --- PGD attack: L-inf radius is eps/255, with `steps` iterations
    eps = 8
    steps = 7

    # --- lambda conditioning
    distribution = 'disc'
    probs = -1
    lambda_choices = [0.0, 0.1, 0.2, 0.3, 0.4, 1.0]

    # --- architecture switches
    use2BN = True    # dual batch norm
    efficient = True

args = Arguments()
# NOTE(review): CUDA was already queried in the first cell (torch.cuda.is_available /
# torch.cuda.device_count), so setting CUDA_VISIBLE_DEVICES here may not change which
# GPU this process uses. Set it before any CUDA call (e.g. at kernel start) to be reliable.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# NOTE: this re-enables the cuDNN autotuning that set_seed() switched off, so the
# conv algorithm choice (and tiny numeric details) may vary between runs.
torch.backends.cudnn.benchmark = True

############### DATA LOADERS:
if args.dataset == 'cifar10':
    train_loader, val_loader, test_loader = cifar10_dataloaders(train_batch_size=args.batch_size, num_workers=args.cpus)
elif args.dataset == 'svhn':
    train_loader, val_loader, test_loader = svhn_dataloaders(train_batch_size=args.batch_size, num_workers=args.cpus)
elif args.dataset == 'stl10':
    train_loader, val_loader = stl10_dataloaders(train_batch_size=args.batch_size, num_workers=args.cpus)
    # stl10_dataloaders returns only two loaders, but the evaluation cell needs
    # `test_loader`; evaluate on val_loader instead of crashing with NameError.
    # TODO(review): confirm val_loader is the held-out STL-10 test split.
    test_loader = val_loader
else:
    raise ValueError(f"Unsupported dataset: {args.dataset!r} (expected 'cifar10', 'svhn' or 'stl10')")

############### MODEL SELECTION:
FiLM_in_channels = args.dim   # FiLM conditioning width == lambda encoding width

# All three architectures honour args.use2BN (svhn/stl10 previously hard-coded True).
if args.dataset == 'cifar10':
    model_fn = SlimmableResNet34OAT
    model = model_fn(use2BN=args.use2BN, FiLM_in_channels=FiLM_in_channels)
elif args.dataset == 'svhn':
    model_fn = SlimmableWideResNet_16_8_OAT
    model = model_fn(depth=16, num_classes=10, widen_factor=8, dropRate=0.0, use2BN=args.use2BN, FiLM_in_channels=FiLM_in_channels)
elif args.dataset == 'stl10':
    model_fn = SlimmableWideResNet_40_2_OAT
    model = model_fn(depth=40, num_classes=10, widen_factor=2, dropRate=0.0, use2BN=args.use2BN, FiLM_in_channels=FiLM_in_channels)
model = model.cuda()

############### LAMBDA Encoding Matrix:
# Random orthogonal args.dim x args.dim matrix (Q factor of a QR decomposition);
# test_model selects one row per evaluated lambda. It is drawn from NumPy's global
# RNG, so it can only equal the training-time matrix if the RNG state here matches
# (seeded by set_seed() and not consumed in between).
# TODO(review): presumably training built it the same way; saving/loading the
# training-time encoding_mat with the checkpoint would remove this fragility.
rand_mat = np.random.randn(args.dim, args.dim)
rand_otho_mat, _ = np.linalg.qr(rand_mat)
encoding_mat = rand_otho_mat

############### ATTACKER
# PGD with L-inf radius args.eps/255 and args.steps iterations, conditioned via FiLM.
attacker = PGD(eps=args.eps/255, steps=args.steps, use_FiLM=True)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
./OATS_results/cifar10/SlimmableResNet34OAT/PGD_7_Epochs120_BatchSize128_LR0.1


In [None]:
# width_mult_list must match the widths the evaluated model was trained with.
print("Number of Sub-Nets: {}".format(len(width_mult_list)))

#### **Load Trained Model**

In [None]:
############### For loading Best MODELs
model_path = "./OATS_results/"   ##############  Provide Correct Model Path
model_name = "Best_Model.pth"    ##############  Use Correct Model Name

model_path = os.path.join(model_path, model_name)
# map_location='cpu': deserialize on the host; load_state_dict then copies each tensor
# into the already-.cuda() model. Without it, a checkpoint saved from a GPU index that
# is not visible in this process fails to load.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True to restrict this).
ckpt = torch.load(model_path, map_location='cpu')
model.load_state_dict(ckpt)

<All keys matched successfully>

In [11]:
def test_model(model, test_loader, test_lambdas=None):
    """Report clean (TA) and PGD-robust (ATA) accuracy of `model` for each lambda.

    Uses the notebook globals `args`, `encoding_mat` and `attacker` from the
    setup cell. Accuracy is computed exactly as correct / total samples, so a
    smaller final batch is not over-weighted.

    Args:
        model: the OAT model, already switched to the desired width.
        test_loader: iterable of (imgs, labels) batches.
        test_lambdas: lambda values to evaluate; defaults to args.lambda_choices.
            With the 'disc' encoding, each value must appear in
            args.lambda_choices, because its index there selects the row of
            encoding_mat fed to the model.

    Returns:
        (lambdas, accuracies, robustness): parallel lists; accuracies[k] and
        robustness[k] are the clean and adversarial accuracy (fractions in
        [0, 1]) for lambdas[k]. The two accuracy lists are also printed.

    Raises:
        ValueError: if test_loader yields no samples, or a lambda has no
            encoding row under the 'disc' encoding.
    """
    ################################## TESTING
    model.eval()
    requires_grad_(model, False)
    if test_lambdas is None:
        test_lambdas = list(args.lambda_choices)

    use_encoding = args.distribution == 'disc' and encoding_mat is not None
    if use_encoding:
        missing = [lam for lam in test_lambdas if lam not in args.lambda_choices]
        if missing:
            raise ValueError(f"lambdas {missing} are not in args.lambda_choices; no encoding row for them")
        # Encoding row for each lambda = its index in the lambda choices
        # (equals the loop position when test_lambdas is the default list).
        enc_rows = {lam: args.lambda_choices.index(lam) for lam in test_lambdas}

    correct = {lam: 0 for lam in test_lambdas}
    correct_adv = {lam: 0 for lam in test_lambdas}
    total = 0

    for imgs, labels in test_loader:
        imgs, labels = imgs.cuda(), labels.cuda()
        batch_size = labels.size(0)
        total += batch_size
        for test_lambda in test_lambdas:
            # Per-sample lambda input, shape (B, 1, dim) when encoded, (B, 1) otherwise.
            if use_encoding:
                rows = np.full((batch_size, 1), enc_rows[test_lambda], dtype=np.intp)
                _lambda = encoding_mat[rows, :]
            else:
                _lambda = np.full((batch_size, 1), test_lambda)
            _lambda = torch.from_numpy(_lambda).float().cuda()

            if args.use2BN:
                # idx2BN is the batch size when lambda == 0, else 0 — presumably the split
                # point between the clean and adversarial BN branches (TODO confirm in DualBN).
                idx2BN = batch_size if test_lambda == 0 else 0
            else:
                idx2BN = None

            ##### TA (clean accuracy):
            with torch.no_grad():
                logits = model(imgs, _lambda, idx2BN)
            correct[test_lambda] += (logits.argmax(1) == labels).sum().item()

            ##### ATA (accuracy on PGD examples; the attack itself needs input gradients):
            with ctx_noparamgrad_and_eval(model):
                imgs_adv = attacker.attack(model, imgs, labels=labels, _lambda=_lambda, idx2BN=idx2BN)
            with torch.no_grad():
                logits_adv = model(imgs_adv.detach(), _lambda, idx2BN)
            correct_adv[test_lambda] += (logits_adv.argmax(1) == labels).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    lambdas = list(test_lambdas)
    accuracies = [correct[lam] / total for lam in lambdas]
    robustness = [correct_adv[lam] / total for lam in lambdas]
    print(accuracies)
    print(robustness)
    print("\n")
    return lambdas, accuracies, robustness

In [None]:
test_accuracies = []   # (width, test_model result) for each evaluated sub-net

# Evaluate every sub-network in width_mult_list order. The width is taken directly
# from the list: the previous running sum of 1/len(width_mult_list) drifts in floating
# point (ten steps of 0.1 give 0.9999999999999999) and only reproduces the list when it
# is exactly [1/n, 2/n, ..., 1].
for subnet_id, sub_net_width in enumerate(width_mult_list, start=1):
    print(f"SubNet-ID: {subnet_id}")
    # Bind the width explicitly; apply() runs immediately on every submodule.
    model.apply(lambda m, w=sub_net_width: setattr(m, 'width_mult', w))
    test_accuracies.append((sub_net_width, test_model(model, test_loader)))