# Load data

In [1]:
sys.path.insert(1, '/home/sidtandon/Sid/GitRepo/FastShapExt')

In [2]:
import torchvision.datasets as dsets
import torchvision.transforms as transforms

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os.path
from copy import deepcopy
from resnet import ResNet18

from fastshap import ImageSurrogate
from fastshap.image_surrogate  import generate_labels
from fastshap.image_imputers import ImageImputer
from fastshap.utils import MaskLayer2d, MaskLayer2dSCL, KLDivLoss, DatasetInputOnly
from scl.networks.resnet_big import SupConResNet, LinearClassifier
from scl.util import AverageMeter
import torch.backends.cudnn as cudnn
import time
from torch.utils.data import RandomSampler, BatchSampler, DataLoader, TensorDataset
from fastshap.utils import UniformSampler, DatasetRepeat
import sys
from scl.util import set_optimizer

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Select device
device = torch.device('cuda')
torch.cuda.is_available()

True

In [5]:
import argparse
import math
def parse_option():
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=512,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.1,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.2,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet18')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100'], help='dataset')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')

    parser.add_argument('--ckpt', type=str, default='./scl_models/SupCon/cifar10_models/SupCon_cifar10_resnet18_lr_0.05_decay_0.0001_bsz_256_temp_0.07_trial_0_cosine/ckpt_epoch_500.pth',
                        help='path to pre-trained model')

    opt = parser.parse_args("")

    # set the path according to the environment
    opt.data_folder = './datasets/'

    iterations = opt.lr_decay_epochs.split(',')
    opt.lr_decay_epochs = list([])
    for it in iterations:
        opt.lr_decay_epochs.append(int(it))

    opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.\
        format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
               opt.batch_size)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up for large-batch training,
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    if opt.dataset == 'cifar10':
        opt.n_cls = 10
    elif opt.dataset == 'cifar100':
        opt.n_cls = 100
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))

    return opt

In [6]:
def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    if args.warm and epoch <= args.warm_epochs:
        p = (batch_id + (epoch - 1) * total_batches) / \
            (args.warm_epochs * total_batches)
        lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

def adjust_learning_rate(args, optimizer, epoch):
    lr = args.learning_rate
    if args.cosine:
        eta_min = lr * (args.lr_decay_rate ** 3)
        lr = eta_min + (lr - eta_min) * (
                1 + math.cos(math.pi * epoch / args.epochs)) / 2
    else:
        steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        if steps > 0:
            lr = lr * (args.lr_decay_rate ** steps)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [7]:
def set_loader(opt, original_model, num_players):
    # construct data loader
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Load train set
    train_set = dsets.CIFAR10('../', train=True, download=True, transform=transform_train)

    # Load test set (using as validation)
    val_set = dsets.CIFAR10('../', train=False, download=True, transform=transform_test)

    train_set = DatasetInputOnly(train_set)
    val_set = DatasetInputOnly(val_set)
    
    random_sampler = RandomSampler(
        train_set, replacement=True,
        num_samples=int(np.ceil(len(train_set) / opt.batch_size))*opt.batch_size)
    batch_sampler = BatchSampler(
        random_sampler, batch_size=opt.batch_size, drop_last=True)
    train_loader = DataLoader(train_set, batch_sampler=batch_sampler,
                                pin_memory=True, num_workers=opt.num_workers)


    sampler = UniformSampler(num_players)
    torch.manual_seed(1)
    S_val = sampler.sample(len(val_set))
    validation_batch_size = opt.batch_size


        # Generate validation labels.
    y_val = generate_labels(val_set, original_model,
                            validation_batch_size, opt.num_workers)
    
    # Create dataset.
    val_set = DatasetRepeat(
        [val_set, TensorDataset(y_val, S_val)])
   
    val_loader = DataLoader(val_set, batch_size=validation_batch_size,
                            pin_memory=True, num_workers=opt.num_workers)

    return train_loader, val_loader

In [8]:
def set_model(opt):

    mask_layer_model = MaskLayer2dSCL(value=0, append=True, include_second_coalition=False)
    supcon_resnet_model = SupConResNet(name=opt.model)

    criterion = torch.nn.CrossEntropyLoss()

    classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)

    ckpt = torch.load(opt.ckpt, map_location='cpu')
    state_dict = ckpt['model']

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            supcon_resnet_model.encoder = torch.nn.DataParallel(supcon_resnet_model.encoder)
        else:
            new_state_dict = {}
            for k, v in state_dict.items():
                k = k.replace("1.encoder", "encoder")
                k = k.replace("1.head", "head")
                new_state_dict[k] = v
            state_dict = new_state_dict
        supcon_resnet_model = supcon_resnet_model.cuda()
        classifier = classifier.cuda()
        criterion = criterion.cuda()
        mask_layer_model = mask_layer_model.cuda()
        cudnn.benchmark = True

        supcon_resnet_model.load_state_dict(state_dict)

    return mask_layer_model, supcon_resnet_model, classifier, criterion

In [9]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

class KLDivLoss(nn.Module):
    '''
    KL divergence loss that applies log softmax operation to predictions.

    Args:
      reduction: how to reduce loss value (e.g., 'batchmean').
      log_target: whether the target is expected as a log probabilities (or as
        probabilities).
    '''

    def __init__(self, reduction='batchmean', log_target=False):
        super().__init__()
        self.kld = nn.KLDivLoss(reduction=reduction, log_target=log_target)

    def forward(self, pred, target):
        '''
        Evaluate loss.

        Args:
          pred:
          target:
        '''
        return self.kld(pred.log_softmax(dim=1), target)

In [18]:
def train(train_loader, mask_layer_model, supcon_resnet_model, original_model, classifier, criterion, optimizer, epoch, opt
            , image_imputer: ImageImputer):
    """one epoch training"""
    mask_layer_model.eval()
    supcon_resnet_model.eval()
    classifier.train()

    kldiv_criterion = KLDivLoss()
    sampler = UniformSampler(image_imputer.num_players)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    kl_top1= AverageMeter()

    end = time.time()
    for idx, (images,) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = images.cuda(non_blocking=True)
        bsz = images.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss
        with torch.no_grad():
            S = sampler.sample(opt.batch_size).to(device=device)
            S = image_imputer.resize(S)
            mod_images =  mask_layer_model((images, S))
            features = supcon_resnet_model.encoder(mod_images)
            y_for_klloss = original_model(images)
            labels = torch.argmax(original_model(images), dim = 1)
            
        output = classifier(features.detach())
        loss = criterion(output, labels)

        with torch.no_grad():
            klloss = kldiv_criterion(output, y_for_klloss)
            
        # update metric
        losses.update(loss.item(), bsz)
        acc1 = accuracy(output, labels, topk=[1]) # Mod by Sid: topk=(1,5). Similarly in validate function
        top1.update(acc1[0], bsz)
        kl_top1.update(klloss, bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'KL loss {klloss.val:.3f} ({klloss.avg:.3f})\t'
                  'Acc@1 {top1.val[0]:.3f} ({top1.avg[0]:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, klloss = kl_top1,top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg[0]


def validate(val_loader, mask_layer_model, supcon_resnet_model, classifier, criterion, opt, image_imputer: ImageImputer):
    """validation"""
    supcon_resnet_model.eval()
    mask_layer_model.eval()
    classifier.eval()

    kldiv_criterion = KLDivLoss()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    kl_top1= AverageMeter()

    with torch.no_grad():
        end = time.time()
        for idx, (images, labels, S) in enumerate(val_loader):
            images = images.float().cuda()
            labels = labels.cuda()
            S = S.cuda()
            bsz = labels.shape[0]

            # forward
            S = image_imputer.resize(S)
            mod_images =  mask_layer_model((images, S))
            output = classifier(supcon_resnet_model.encoder(mod_images))
            labels_index = torch.argmax(labels, dim = 1)
            loss = criterion(output, labels_index)
            klloss = kldiv_criterion(output, labels)

            # update metric
            losses.update(loss.item(), bsz)
            acc1 = accuracy(output, labels_index, topk=[1])
            top1.update(acc1[0], bsz)
            kl_top1.update(klloss, bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'KL loss {klloss.val:.3f} ({klloss.avg:.3f})\t'
                      'Acc@1 {top1.val[0]:.3f} ({top1.avg[0]:.3f})'.format(
                       idx, len(val_loader), batch_time=batch_time,
                       loss=losses, klloss = kl_top1,top1=top1))

    print(' * Acc@1 {top1.avg[0]:.3f}'.format(top1=top1))
    return losses.avg, top1.avg[0]

In [19]:
best_acc = 0
opt = parse_option()

print('Loading saved model')
resnet_orig_model = ResNet18(num_classes=10)
resnet_orig_model.load_state_dict(torch.load('cifar resnet.pt'))
resnet_orig_model.to(device)
original_model = nn.Sequential(resnet_orig_model, nn.Softmax(dim=1))

image_imputer = ImageImputer(width=32, height=32, superpixel_size=2)

# build data loader
train_loader, val_loader = set_loader(opt, original_model,image_imputer.num_players )

# build model and criterion
mask_layer_model, supcon_resnet_model, classifier, criterion = set_model(opt)

# build optimizer
optimizer = set_optimizer(opt, classifier)

# training routine
for epoch in range(1, opt.epochs + 1):
    adjust_learning_rate(opt, optimizer, epoch)

    # train for one epoch
    time1 = time.time()
    loss, acc = train(train_loader, mask_layer_model, supcon_resnet_model, original_model, classifier, criterion,
                        optimizer, epoch, opt, image_imputer)
    time2 = time.time()
    print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
        epoch, time2 - time1, acc))

    # eval for one epoch
    loss, val_acc = validate(val_loader, mask_layer_model, supcon_resnet_model, classifier, criterion, opt, image_imputer)
    if val_acc > best_acc:
        best_acc = val_acc

print('best accuracy: {:.2f}'.format(best_acc))

Loading saved model
Files already downloaded and verified
Files already downloaded and verified
Train: [1][10/98]	BT 0.668 (0.666)	DT 0.278 (0.276)	loss 1.642 (3.535)	KL loss 1.682 (3.563)	Acc@1 90.820 (73.867)
Train: [1][20/98]	BT 0.668 (0.668)	DT 0.275 (0.277)	loss 1.279 (2.521)	KL loss 1.252 (2.580)	Acc@1 94.336 (83.418)
Train: [1][30/98]	BT 0.662 (0.668)	DT 0.269 (0.276)	loss 1.488 (2.200)	KL loss 1.541 (2.266)	Acc@1 93.750 (86.934)
Train: [1][40/98]	BT 0.667 (0.670)	DT 0.272 (0.278)	loss 1.196 (2.026)	KL loss 1.317 (2.117)	Acc@1 92.773 (88.638)
Train: [1][50/98]	BT 0.672 (0.670)	DT 0.278 (0.278)	loss 1.745 (1.878)	KL loss 1.869 (1.981)	Acc@1 93.945 (89.816)
Train: [1][60/98]	BT 0.666 (0.670)	DT 0.271 (0.277)	loss 1.327 (1.753)	KL loss 1.560 (1.875)	Acc@1 94.727 (90.602)
Train: [1][70/98]	BT 0.672 (0.675)	DT 0.277 (0.282)	loss 1.270 (1.696)	KL loss 1.394 (1.814)	Acc@1 94.922 (91.152)
Train: [1][80/98]	BT 0.669 (0.676)	DT 0.273 (0.282)	loss 1.149 (1.619)	KL loss 1.250 (1.735)	Acc@1 

In [3]:
MaskLayer2dSCL(value=0,append=0,include_second_coalition= False)

MaskLayer2dSCL()