In [None]:
import copy
import os
import pickle
import random
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndim
import seaborn as sns
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms, datasets

from networks import Bootstrap_Net
from utils import *
from utils_plotting import *

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
print(device)

# Seed every RNG the notebook draws from: the bootstrap subsampling uses NumPy,
# while DataLoader sampling and torchvision transforms draw from torch (or, in
# older torchvision versions, Python's `random`).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Dataset

In [None]:
# Data: CIFAR-10 with train-time augmentation (random horizontal flip + padded
# random crop) and plain normalization at test time.
batch_size = 128
mean = (0.4914, 0.4822, 0.4465)  # per-channel normalization constants
std = (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
# NOTE: "valset" is the CIFAR-10 *test* split (train=False); it is used both for
# per-epoch validation and for the final evaluations below.
valset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

models_dir = 'Models/Ensemble_models'
results_dir = 'Results/Ensemble_results'
# Later cells write checkpoints, figures and arrays into these directories.
os.makedirs(models_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)

## Training

In [None]:
# Bootstrap-style ensemble: every net is trained on its own random subsample of
# the training set, and its final weights are collected in weight_set_samples.
Nruns = 25  # Number of nets in ensemble.
num_epochs = 200  # Epochs per net
p_subsample = 0.8  # Probability of keeping each training example for a given net
lr = 0.1  # Base learning rate fed to the learning_rate(lr, epoch) schedule
weight_set_samples = []  # Final-epoch state dict of every net
criterion = nn.CrossEntropyLoss()

# The validation loader is the same for every run, so build it once.
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=4)

for iii in range(Nruns):
    # One Bernoulli(p_subsample) draw per training example decides whether it is
    # kept for this net (subsampling without replacement, not a classic
    # with-replacement bootstrap).
    keep_idx = np.flatnonzero(np.random.binomial(1, p_subsample, size=len(trainset)))
    sampler = SubsetRandomSampler(keep_idx)
    # `shuffle` must stay False when a sampler is given; the sampler shuffles.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False,
                                              num_workers=4, sampler=sampler)

    net = Bootstrap_Net(18, 10)
    net.to(device)

    # Build the optimizer once per net so SGD momentum buffers survive across
    # epochs; only the learning rate is changed at the start of each epoch.
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate(lr, 0), momentum=0.9, weight_decay=5e-4)

    best_acc = 0
    start = time.time()
    for epoch in range(num_epochs):
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate(lr, epoch)

        # Train
        net.train()
        train_loss = 0
        correct = 0
        total = 0
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)               # Forward propagation
            loss = criterion(outputs, targets)  # Loss
            loss.backward()                     # Backward propagation
            optimizer.step()                    # Optimizer update

            train_loss += loss.item()
            total += targets.size(0)
            correct += outputs.argmax(dim=1).eq(targets).sum().item()

        # Validate: the whole forward pass runs without building an autograd graph.
        net.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in valloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                test_loss += criterion(outputs, targets).item()
                total += targets.size(0)
                correct += outputs.argmax(dim=1).eq(targets).sum().item()

        acc = 100. * correct / total
        # best_acc is the best validation (CIFAR-10 test split) accuracy seen over
        # the epochs; the weights appended below are the final-epoch weights.
        best_acc = max(best_acc, acc)

    print('\n Net %d: Testing model' % (iii))
    print('* Test results : Acc@1 = %.2f%%' % (best_acc))
    end = time.time()
    print('| Total time : %d:%02d:%02d' % (get_hms(end - start)))

    weight_set_samples.append(copy.deepcopy(net.state_dict()))


In [None]:
# Persist every ensemble member's state dict into one file.
state_dicts_path = models_dir + '/state_dicts.pkl'
save_object(weight_set_samples, state_dicts_path)

## Inference

In [None]:
# Fresh network; its weights are replaced by each entry of weight_set_samples
# during ensemble evaluation.
net = Bootstrap_Net(18, 10)
net.to(device)  # follow `device` instead of calling .cuda() unconditionally

In [None]:
 # Test: accuracy of the ensemble (weight_set_samples) on the CIFAR-10 test split.
# Build the loader here instead of relying on the one left over from the training loop.
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=4)
net.eval()

wrong = 0
# Run the whole evaluation, including sample_eval's forward passes, without autograd.
with torch.no_grad():
    for inputs, targets in valloader:
        inputs, targets = inputs.to(device), targets.to(device)
        cost, err, probs = net.sample_eval(inputs, targets, weight_set_samples, logits=False)
        # err is summed and divided by the dataset size below, i.e. treated as a
        # per-batch count of misclassified inputs.
        wrong += err.item()

acc = 1 - (wrong / len(valset))
print(acc)

In [None]:
# Reload the ensemble weights written by the training section.
# NOTE(review): a .pkl file is presumably unpickled here; only load files this notebook produced.
state_dicts_path = models_dir + '/state_dicts.pkl'
weight_set_samples = load_object(state_dicts_path)

### CIFAR-10 Rotations

In [None]:
# Rotations: materialize the whole normalized test split as NumPy arrays.
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False,
                                        pin_memory=True, num_workers=3)
batches = [(images.cpu().numpy(), labels.cpu().numpy()) for images, labels in valloader]
x_dev = np.concatenate([images for images, _ in batches])
y_dev = np.concatenate([labels for _, labels in batches])
print(x_dev.shape)
print(y_dev.shape)

In [None]:
# Rotate one test image through [s_rot, end_rot] and plot how the ensemble's
# predictive probabilities change with the rotation angle.
Nsamples = len(weight_set_samples)
im_ind = 23
steps = 16
# Raw (un-normalized) test images; newer torchvision exposes them as `data`
# instead of `test_data`.
im_list = valset.data if hasattr(valset, 'data') else valset.test_data

plt.figure()
plt.imshow(im_list[im_ind])
plt.title('original image')
plt.savefig(results_dir + '/sample_image.png', bbox_inches='tight')

s_rot = 0
end_rot = 179
rotations = np.linspace(s_rot, end_rot, steps).astype(int)

ims = []
x, y = x_dev[im_ind], y_dev[im_ind]

# Layout: a 3 x steps grid. The probability curves span the top two rows and
# the bottom row holds one rotated thumbnail per angle.
fig = plt.figure(figsize=(steps, 8), dpi=80)

for i, angle in enumerate(rotations):
    # x_dev holds channel-first tensors, so rotate in the spatial plane (axes 1, 2).
    x_rot = ndim.rotate(x, angle, axes=(1, 2), reshape=False, mode='nearest')
    ims.append(x_rot)

    # Subplot indices are 1-based; the bottom row is 2*steps+1 ... 3*steps.
    ax = fig.add_subplot(3, steps, 2 * steps + i + 1)
    # Show the raw channel-last image rotated by the same angle.
    ax.imshow(ndim.rotate(im_list[im_ind], angle, axes=(0, 1), reshape=False, mode='nearest'))
    ax.axis('off')

ims = np.array(ims)
print(ims.shape)
# Integer class labels (the CIFAR-10-C cell also passes long labels to
# sample_eval), one per rotated copy.
y = np.full(ims.shape[0], y, dtype=np.int64)

with torch.no_grad():
    cost, err, probs = net.sample_eval(torch.from_numpy(ims), torch.from_numpy(y), weight_set_samples, logits=False)

predictions = probs.cpu().numpy()
textsize = 15
lw = 5

c = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
     '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

ax0 = plt.subplot2grid((3, steps), (0, 0), rowspan=2, colspan=steps)
ax0.set_prop_cycle(color=c)
ax0.plot(rotations, predictions, linewidth=lw)

# Mark, at every angle, the class(es) with the highest predicted probability.
is_max = predictions == predictions.max(axis=1, keepdims=True)
for i in range(predictions.shape[1]):
    for n in np.flatnonzero(is_max[:, i]):
        ax0.plot(rotations[n], predictions[n, i], 'o', c=c[i], markersize=15.0)

lgd = ax0.legend(['airplane', 'automobile', 'bird',
                  'cat', 'deer', 'dog',
                  'frog', 'horse', 'ship',
                  'truck'], loc='upper right', prop={'size': textsize, 'weight': 'normal'}, bbox_to_anchor=(1.35, 1))
ax0.set_xlabel('rotation angle')
ax0.set_title('True class: %d, Nsamples %d' % (y[0], Nsamples))
fig.tight_layout()
ax0.autoscale(enable=True, axis='x', tight=True)
fig.subplots_adjust(wspace=0, hspace=0)

for item in ([ax0.title, ax0.xaxis.label, ax0.yaxis.label] +
             ax0.get_xticklabels() + ax0.get_yticklabels()):
    item.set_fontsize(textsize)
    item.set_weight('normal')


In [None]:
# Pre-rotated copies of the test set; the evaluation loop below indexes this as
# [rotation, image, H, W, C].
rotated_path = "data/CIFAR10_rotated.npy"
data_rotated = np.load(rotated_path)
data_rotated.shape

In [None]:
Nsamples = len(weight_set_samples)

# Rotation sweep: `steps` evenly spaced integer angles from s_rot to end_rot.
s_rot, end_rot, steps = 0, 179, 16
rotations = np.linspace(s_rot, end_rot, steps).astype(int)

def preprocess_test(X, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Turn a batch of channel-last images into normalized channel-first tensors.

    Each image goes through ``ToTensor`` followed by ``Normalize(mean, std)``;
    with the default constants this matches the notebook's ``transform_test``.

    Args:
        X: array of shape (N, H, W, C); ``X[n]`` is one image.
        mean: per-channel means used for normalization.
        std: per-channel standard deviations used for normalization.

    Returns:
        A float tensor of shape (N, C, H, W).
    """
    N, H, W, C = X.shape
    Y = torch.zeros(N, C, H, W)
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    for n, image in enumerate(X):
        Y[n] = to_normalized_tensor(image)

    return Y

# Ensemble predictions for every test image at every rotation:
#   all_preds[i, s, k]           -> ensemble-mean probability of class k
#   all_sample_preds[i, m, s, k] -> probability from ensemble member m
all_preds = np.zeros((len(x_dev), steps, 10))
all_sample_preds = np.zeros((len(x_dev), Nsamples, steps, 10))

# Each image slice of data_rotated must contain exactly one copy per angle.
assert data_rotated.shape[0] == steps, data_rotated.shape

# DO ROTATIONS ON OUR IMAGE
for im_ind in range(len(x_dev)):

    if im_ind % 500 == 0:
        print(im_ind)

    ims = preprocess_test(data_rotated[:, im_ind, :, :, :])
    # Labels are built from this image's own rotations (after `ims` is set), one
    # integer label per rotated copy.
    y = np.full(ims.shape[0], y_dev[im_ind], dtype=np.int64)

    with torch.no_grad():
        sample_probs = net.all_sample_eval(ims, torch.from_numpy(y), weight_set_samples)
    probs = sample_probs.mean(dim=0)

    all_sample_preds[im_ind] = sample_probs.cpu().numpy()
    predictions = probs.cpu().numpy()
    all_preds[im_ind] = predictions

# correct_preds[i, s]: probability assigned to image i's true class at rotation s.
correct_preds = all_preds[np.arange(len(x_dev)), :, y_dev]

In [None]:
# Save the rotation-experiment arrays into results_dir.
for fname, array in [('correct_preds.npy', correct_preds),
                     ('all_preds.npy', all_preds),
                     ('all_sample_preds.npy', all_sample_preds)]:
    np.save(results_dir + '/' + fname, array)

In [None]:
# Plot the rotation results (true-class probability and full predictive distribution).
rotations = np.linspace(0, 179, steps).astype(int)
N = len(all_preds)  # number of evaluated test images

# correct_preds[i, s]: probability assigned to image i's true class at rotation s.
correct_preds = all_preds[np.arange(N), :, y_dev[:N]]

plot_predictive_entropy(correct_preds, all_preds, rotations, results_dir)

### CIFAR-10-C

In [None]:
# Reload the ensemble weights written by the training section.
# NOTE(review): a .pkl file is presumably unpickled here; only load files this notebook produced.
state_dicts_path = models_dir + '/state_dicts.pkl'
weight_set_samples = load_object(state_dicts_path)

In [None]:
# Fresh network; its weights are replaced by each entry of weight_set_samples
# during ensemble evaluation.
net = Bootstrap_Net(18, 10)
net.to(device)  # follow `device` instead of calling .cuda() unconditionally

In [None]:
chalPath = 'data/CIFAR-10-C/'
# The CIFAR-10-C release stores `labels.npy` next to the corruption arrays;
# keep only the corruption files so labels are not evaluated as a "challenge".
chals = sorted(f for f in os.listdir(chalPath)
               if f.endswith('.npy') and f != 'labels.npy')

# Test-split labels; newer torchvision exposes them as `targets` instead of `test_labels`.
chal_labels = valset.targets if hasattr(valset, 'targets') else valset.test_labels
chal_labels = torch.as_tensor(chal_labels, dtype=torch.long)

def preprocess_test(X, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Turn a batch of channel-last images into normalized channel-first tensors.

    Redefined here so the CIFAR-10-C section runs on its own. Each image goes
    through ``ToTensor`` followed by ``Normalize(mean, std)``; with the default
    constants this matches the notebook's ``transform_test``.

    Args:
        X: array of shape (N, H, W, C); ``X[n]`` is one image.
        mean: per-channel means used for normalization.
        std: per-channel standard deviations used for normalization.

    Returns:
        A float tensor of shape (N, C, H, W).
    """
    N, H, W, C = X.shape
    Y = torch.zeros(N, C, H, W)
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    for n, image in enumerate(X):
        Y[n] = to_normalized_tensor(image)

    return Y

preds_list = []
net.eval()
avg_list = []

n_test = len(chal_labels)  # images per severity block (size of the test split)
n_severities = 5  # each corruption file holds this many consecutive severity blocks

for chal_name in chals:
    chal_data = np.load(chalPath + chal_name)
    # Fail loudly on files that do not have the expected severity-block layout.
    assert len(chal_data) == n_severities * n_test, (chal_name, chal_data.shape)

    avg = 0
    for j in range(n_severities):
        chal_temp_data = preprocess_test(chal_data[j * n_test:(j + 1) * n_test])

        chal_dataset = torch.utils.data.TensorDataset(chal_temp_data, chal_labels)
        chal_loader = torch.utils.data.DataLoader(chal_dataset, batch_size=100)
        chal_error = 0

        with torch.no_grad():
            for x, y in chal_loader:
                cost, err, probs = net.sample_eval(x, y, weight_set_samples=weight_set_samples, logits=False)
                preds_list.append(probs.cpu().numpy())
                chal_error += err.cpu().numpy()

        chal_acc = 1 - (chal_error / len(chal_dataset))
        avg += chal_acc
        print(chal_acc)

    avg /= n_severities
    avg_list.append(avg)
    print("Average:", avg, " ", chal_name)

print("Mean: ", np.mean(avg_list))

In [None]:
# Stack per-batch probability arrays into one array and save the CIFAR-10-C results.
preds_list = np.vstack(preds_list)
for fname, array in [('preds_CIFAR-10-C.npy', preds_list),
                     ('avg_list_CIFAR-10-C.npy', avg_list)]:
    np.save(results_dir + '/' + fname, array)