In [None]:
import copy
import math
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from scipy.special import softmax
from torch.nn import functional as F
from torchvision import transforms, datasets

from networks import ResNet
from utils import *
from utils_plotting import *
import tent

In [None]:
# Detect CUDA once and derive everything else from that flag.
use_cuda = torch.cuda.is_available()
print(use_cuda)

# The device queries below raise on a CPU-only machine, so only run them when
# a GPU is actually present.
if use_cuda:
    print(torch.cuda.current_device())
    print(torch.cuda.device_count())
    print(torch.cuda.get_device_name(0))

# Device used by the rest of the notebook: first GPU if available, else CPU.
device = torch.device('cuda:0' if use_cuda else 'cpu')

## Dataset

In [None]:
batch_size = 128
# Per-channel mean / std passed to transforms.Normalize for every image.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random flip + padded random crop, then tensor conversion
# and normalisation.
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                    transforms.RandomCrop(32, padding=4),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean, std)
                                    ])
# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
transform_test = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean, std)])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
valset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=4)

# Older torchvision releases expose the raw images as `train_data`, newer ones
# as `data`; support both so the cell runs on either.
print((trainset.data if hasattr(trainset, 'data') else trainset.train_data).shape)

models_dir = "Models/" + 'Tent_models'
results_dir = "Results/" + 'Tent_results'
# Checkpoints, figures and .npy results are written into these folders later
# on; create them up front so those writes do not fail on a fresh checkout.
os.makedirs(models_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)

## Network

In [None]:
# Instantiate the network on the selected device and report its size.
# NOTE(review): the arguments (18, 10) are presumably depth and number of
# classes - confirm against networks.ResNet.
model = ResNet(18, 10).to(device)
print(sum(map(torch.numel, model.parameters())) / 1000000, "M parameters")

## Training

In [None]:
# Cross-entropy loss shared by train(), test() and the evaluation cells below.
criterion = nn.CrossEntropyLoss()

def train(epoch):
    """Run one training epoch over ``trainloader`` and print a summary line.

    Reads the notebook globals ``model``, ``trainloader``, ``criterion``,
    ``device`` and ``num_epochs`` and the ``learning_rate`` schedule helper.

    Args:
        epoch: Index of the current epoch; used to look up the learning rate
            and for logging.
    """
    model.train()
    correct = 0
    total = 0

    base_lr = 0.1
    epoch_lr = learning_rate(base_lr, epoch)

    # A fresh SGD optimizer is built every epoch (as in the original code), so
    # the momentum buffers restart at each epoch boundary.
    optimizer = torch.optim.SGD(model.parameters(), lr=epoch_lr, momentum=0.9, weight_decay=5e-4)

    print('\n=> Training Epoch #%d, LR=%.4f' %(epoch, epoch_lr))
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)               # Forward Propagation
        loss = criterion(outputs, targets)    # Loss
        loss.backward()                       # Backward Propagation
        optimizer.step()                      # Optimizer update

        _, predicted = torch.max(outputs, 1)
        total += targets.size(0)
        # Accumulate as a Python int so the percentage below is computed with
        # float arithmetic on every torch version.
        correct += predicted.eq(targets).sum().item()

    # Summary after the last batch: the loss shown is the last batch's loss,
    # the accuracy is over the whole epoch. len(trainloader) is the true
    # number of iterations and does not depend on the mutable global
    # `batch_size`.
    print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%'
                %(epoch, num_epochs, batch_idx+1,
                    len(trainloader), loss.item(), 100.*correct/total))

def test(epoch):
    """Evaluate ``model`` on ``valloader`` and checkpoint it when it improves.

    Reads the notebook globals ``model``, ``valloader``, ``criterion``,
    ``device`` and ``models_dir``. Updates the globals ``best_acc`` and
    ``model_best`` and writes ``theta_best.pt`` whenever the validation
    accuracy exceeds ``best_acc``.

    Args:
        epoch: Index of the current epoch (used for logging only).
    """
    global best_acc, model_best
    model.eval()
    correct = 0
    total = 0
    # The whole evaluation loop runs without autograd bookkeeping; previously
    # only a no-op wrapper sat inside no_grad and the forward pass still built
    # a graph.
    with torch.no_grad():
        for inputs, targets in valloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint when best model
    acc = 100.*correct/total
    # The loss shown is the last validation batch's loss.
    print("\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" %(epoch, loss.item(), acc))

    if acc > best_acc:
        # Snapshot the weights: a plain `model_best = model` only aliases the
        # live model, which keeps changing in later epochs.
        model_best = copy.deepcopy(model)
        best_acc = acc
        print('| Saving Best model...\t\t\tTop1 = %.2f%%' %(acc))
        torch.save(model.state_dict(),models_dir + '/' + 'theta_best.pt')
        


In [None]:
num_epochs = 200
best_acc = 50  # validation accuracy (in %) that test() must exceed before it stores a checkpoint

# Train and validate once per epoch, reporting the accumulated per-epoch time.
elapsed_time = 0
start = time.time()
for epoch in range(num_epochs):
    start_time = time.time()

    train(epoch)
    test(epoch)

    epoch_time = time.time() - start_time
    elapsed_time = elapsed_time + epoch_time
    print('| Elapsed time : %d:%02d:%02d' % get_hms(elapsed_time))

# Final summary: best validation accuracy and total wall-clock time.
print('\n[Phase 4] : Testing model')
print('* Test results : Acc@1 = %.2f%%' % best_acc)
end = time.time()
print('| Total time : %d:%02d:%02d' % get_hms(end - start))

In [None]:
# TENT
# Wrap a private copy of the best model. NOTE(review): tent.configure_model is
# expected to modify the module it receives (confirm in tent.py); working on a
# copy keeps `model_best` untouched so the "Baseline model" cell below really
# evaluates the unadapted network.
net = copy.deepcopy(model_best)
net = tent.configure_model(net)
params, param_names = tent.collect_params(net)
# optimizer = torch.optim.Adam(params, lr=1e-1)
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=5e-4)

tented_model = tent.Tent(net, optimizer)
tented_model = tented_model.to(device)

## Inference

### Regular CIFAR-10

In [None]:
# Baseline model: accuracy of the best checkpointed model on the clean
# validation set.
model = model_best
model.eval()
test_loss = 0
correct = 0
total = 0
# Inference only: keep the forward pass inside no_grad so no autograd graph is
# built.
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(valloader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += targets.size(0)
        # Python-int accumulation keeps the percentage exact on every torch version.
        correct += predicted.eq(targets).sum().item()

acc = 100.*correct/total
print(acc)

In [None]:
# Tented model: accuracy of the Tent-wrapped model on the clean validation set.

tented_model.eval()
test_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(valloader):
    inputs, targets = inputs.to(device), targets.to(device)
    # The forward pass is deliberately left outside no_grad, as in the
    # original cell. NOTE(review): Tent presumably adapts the model on each
    # forward call and needs gradients for that - confirm in tent.py.
    outputs = tented_model(inputs)

    # Metrics need no autograd bookkeeping.
    with torch.no_grad():
        loss = criterion(outputs, targets)

        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

acc = 100.*correct/total
print(acc)

### CIFAR-10 Rotations

In [None]:
# Load the pre-computed rotated CIFAR-10 images from disk.
# Later cells index this array as [rotation, image, H, W, C] and use the first
# 16 entries along axis 0.
data_rotated = np.load("data/CIFAR10_rotated.npy")
# data_rotated = np.transpose(data_rotated, (0,1,4,2,3))
data_rotated.shape

In [None]:
# Validation labels: `targets` on current torchvision, `test_labels` on old
# releases - support both.
val_labels = valset.targets if hasattr(valset, 'targets') else valset.test_labels

# Stack the first 16 rotation slices into one long array of images and repeat
# the label vector once per rotation so labels line up with the stacked images.
data_flattened = np.vstack([data_rotated[n] for n in range(16)])
labels_flattened = np.hstack([val_labels for n in range(16)])
print(data_flattened.shape, labels_flattened.shape)

In [None]:
# Raw validation images / labels. The attribute names differ between
# torchvision releases (`data`/`targets` vs. `test_data`/`test_labels`), so
# fall back to the old names only when the new ones are missing.
x_dev = valset.data if hasattr(valset, 'data') else valset.test_data
y_dev = valset.targets if hasattr(valset, 'targets') else valset.test_labels

In [None]:
 # Get Tented Model
# This cell keeps the model on the CPU (the inputs below are CPU tensors), so
# the checkpoint is mapped to the CPU explicitly.
model = ResNet(18,10)
model.load_state_dict(torch.load(models_dir + '/' + 'theta_best.pt', map_location='cpu'))

net = model
net = tent.configure_model(net)
params, param_names = tent.collect_params(net)
optimizer = torch.optim.Adam(params, lr=1e-3)
# optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=5e-4)

tented_model = tent.Tent(net, optimizer)
tented_model.eval()


im_ind = 23      # index of the validation image to visualise
Nsamples = 100   # number of forward passes whose logits are averaged
steps = 16       # number of rotation angles

# Show and save the unrotated image. The original rotated it by 0 degrees,
# which is an identity transform, so it is plotted directly.
plt.figure()
plt.imshow(x_dev[im_ind])
plt.title('original image')
plt.savefig(results_dir + '/sample_image.png', bbox_inches='tight')
s_rot = 0
end_rot = 179
rotations = (np.linspace(s_rot, end_rot, steps)).astype(int)

x, y = x_dev[im_ind], y_dev[im_ind]

fig = plt.figure(figsize=(steps, 8), dpi=80)

# All rotated versions of the chosen image, converted to normalised CHW
# tensors with the same transform the validation loader uses. Feeding the raw
# HWC pixel values would not match the input statistics the network was
# trained on.
ims = torch.stack([transform_test(im) for im in data_rotated[:, im_ind]])

# Run the Tent-wrapped model Nsamples times on the same batch and average the
# logits over the runs.
logits = []
for n in range(Nsamples):
    outputs = tented_model(ims)
    outputs = outputs.detach().cpu().numpy()
    logits.append(outputs)

logits = np.stack(logits)
print(logits.shape)
logits = np.mean(logits, axis=0)
print(logits.shape)
predictions = softmax(logits,1)

textsize = 15
lw = 5

# One fixed colour per class so lines and max-markers match.
c = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

ax0 = plt.subplot2grid((3, steps-1), (0, 0), rowspan=2, colspan=steps-1)
plt.gca().set_prop_cycle(color = c)
ax0.plot(rotations, predictions, linewidth=lw)


##########################
# Dots at max: mark, for every rotation angle, the class with the highest
# probability.

for i in range(predictions.shape[1]):

    selections = (predictions[:,i] == predictions.max(axis=1))
    for n in range(len(selections)):
        if selections[n]:
            ax0.plot(rotations[n], predictions[n, i], 'o', c=c[i], markersize=15.0)
##########################

lgd = ax0.legend(['airplane', 'automobile', 'bird',
            'cat', 'deer', 'dog',
            'frog', 'horse', 'ship',
            'truck'], loc='upper right', prop={'size': textsize, 'weight': 'normal'}, bbox_to_anchor=(1.35,1))
plt.xlabel('rotation angle')
# plt.ylabel('probability')
plt.title('True class: %d, Nsamples %d' % (y, Nsamples))
plt.tight_layout()
plt.autoscale(enable=True, axis='x', tight=True)
plt.subplots_adjust(wspace=0, hspace=0)

# Apply a uniform font size / weight to title, axis labels and tick labels.
for item in ([ax0.title, ax0.xaxis.label, ax0.yaxis.label] +
            ax0.get_xticklabels() + ax0.get_yticklabels()):
    item.set_fontsize(textsize)
    item.set_weight('normal')

plt.savefig(results_dir + '/percentile_label_probabilities.png', bbox_extra_artists=(lgd,), bbox_inches='tight')

In [None]:
 # Get Tented Model
# Move the network to the target device *before* collecting its parameters and
# building the optimizer, as the torch.optim documentation recommends.
# map_location lets the checkpoint load regardless of the device it was saved
# from.
model = ResNet(18,10).to(device)
model.load_state_dict(torch.load(models_dir + '/' + 'theta_best.pt', map_location=device))

net = model
net = tent.configure_model(net)
params, param_names = tent.collect_params(net)
optimizer = torch.optim.Adam(params, lr=1e-3)
# optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=5e-4)

tented_model = tent.Tent(net, optimizer)
tented_model.eval()

steps = 16       # number of rotation angles
N = 10000        # number of validation images per rotation
Nsamples = 100   # forward passes averaged per batch
# Validation labels: `targets` on current torchvision, `test_labels` on old releases.
y_dev = valset.targets if hasattr(valset, 'targets') else valset.test_labels

def preprocess_test(X):
    """Convert a batch of channels-last images into normalised tensors.

    Args:
        X: Array whose shape unpacks as (N, H, W, C); every ``X[n]`` is passed
            through ``ToTensor`` followed by ``Normalize``.

    Returns:
        A ``torch`` tensor of shape (N, C, H, W) holding the transformed images.
    """
    to_normalised_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    count, height, width, channels = X.shape
    converted = torch.zeros(count, channels, height, width)
    for idx, image in enumerate(X):
        converted[idx] = to_normalised_tensor(image)

    return converted

# for im_ind in range(N):
#     if(im_ind % 500 == 0):
#         print(im_ind)

#     y =  y_dev[im_ind]
    
#     ims = data_rotated[:,im_ind,:,:,:]
#     ims = preprocess_test(ims)
#     # print(ims.shape)

#     y = np.ones(ims.shape[0])*y
    
#     # sample_probs = tented_model.all_sample_eval(ims, torch.from_numpy(y), Nsamples=Nsamples)

#     predictions = torch.zeros(Nsamples, steps, 10)

#     ims = ims.to(device)
#     for i in range(Nsamples):
#         y = tented_model(ims)
#         # print(y.shape)
#         predictions[i] = y

#     probs = F.softmax(predictions, dim=2)
#     probs = probs.detach().numpy()

#     all_sample_preds[im_ind] = probs
#     all_preds[im_ind] = np.mean(probs, axis=0)

preds_list = []           # per-batch mean class probabilities
preds_samples_list = []   # per-batch class probabilities of every forward pass

# Normalise the stacked rotated images and turn the labels into an int64 tensor.
data_flattened = preprocess_test(data_flattened)
labels_flattened = torch.as_tensor(labels_flattened, dtype=torch.long)

chal_dataset = torch.utils.data.TensorDataset(data_flattened, labels_flattened)
batch_size = 200
chal_loader = torch.utils.data.DataLoader(chal_dataset, batch_size=batch_size)

correct = 0
with torch.no_grad():
    for idx, (x, y) in enumerate(chal_loader):
        if (idx % 100==0):
            print(idx)
        x = x.to(device)
        # Size the buffer from the actual batch so a shorter final batch does
        # not cause a shape mismatch.
        predictions = torch.zeros(Nsamples, x.size(0), 10)
        for i in range(Nsamples):
            outputs = tented_model(x)
            predictions[i] = F.softmax(outputs, dim=1)

        preds_samples_list.append(predictions.cpu().numpy())
        # Average the class probabilities over the Nsamples forward passes.
        preds = predictions.mean(0)
        preds_list.append(preds.cpu().numpy())
        # Count with Python ints so the accuracy below is a true float
        # division on every torch version.
        correct += (torch.argmax(preds,1) == y).sum().item()

acc = correct / len(chal_dataset)
print(acc)


In [None]:
# Join the per-batch results: sample-wise predictions along the batch axis
# (axis 1), averaged predictions row-wise.
preds_samples_list = np.concatenate(preds_samples_list, 1)
preds_list = np.concatenate(preds_list, axis=0)
print(preds_samples_list.shape, preds_list.shape)

In [None]:
# Put the image axis first: exchange axes 0 and 1, leave the class axis alone.
preds_samples_list = np.swapaxes(preds_samples_list, 0, 1)
preds_samples_list.shape

In [None]:
# Regroup the flat prediction arrays (rotation-major, N images per rotation)
# into per-image arrays indexed by rotation step.
all_preds = np.zeros((N, steps, 10))
all_sample_preds = np.zeros((N, Nsamples,steps, 10))

# Use `steps` and `N` instead of repeating their literal values, so the slices
# always agree with the array shapes allocated above.
for n in range(steps):
    all_preds[:,n,:] = preds_list[n*N:(n+1)*N]
    all_sample_preds[:,:,n,:] = preds_samples_list[n*N:(n+1)*N]

# Small offset added to every probability so that none is exactly zero.
all_preds += 1e-12
all_sample_preds += 1e-12

In [None]:
rotations = (np.linspace(0, 179, steps)).astype(int)

# For each image, the probability assigned to its true class at every
# rotation step -> array of shape (N, steps).
correct_preds = np.stack([all_preds[idx, :, y_dev[idx]] for idx in range(N)])

# Persist the three result arrays before plotting.
for file_name, array in (('/correct_preds.npy', correct_preds),
                         ('/all_preds.npy', all_preds),
                         ('/all_sample_preds.npy', all_sample_preds)):
    np.save(results_dir+file_name, array)

plot_predictive_entropy(correct_preds, all_preds, rotations, results_dir)

In [None]:
# Accuracy on the first rotation step (angle 0 in `rotations`).
# The original compared `preds` - the (batch, 10) probability tensor left over
# from the last evaluation batch - with all validation labels, which does not
# measure accuracy. NOTE(review): rotation step 0 is assumed to be the
# intended comparison; confirm.
val_labels = np.asarray(valset.targets if hasattr(valset, 'targets') else valset.test_labels)
print(np.mean(all_preds[:, 0, :].argmax(axis=1) == val_labels))

### CIFAR10-C

In [None]:
# Directory with one .npy file per corruption type.
chalPath = 'data/CIFAR-10-C/'
# Only iterate over corruption arrays: skip non-.npy entries and the
# `labels.npy` file that the CIFAR-10-C release ships alongside the image
# arrays (it is not an image array and cannot be preprocessed).
chals = sorted(f for f in os.listdir(chalPath)
               if f.endswith('.npy') and f != 'labels.npy')

# Labels paired with every 10000-image chunk in the loop below.
# `targets` on current torchvision, `test_labels` on old releases.
chal_labels = valset.targets if hasattr(valset, 'targets') else valset.test_labels
chal_labels = torch.as_tensor(chal_labels, dtype=torch.long)

Nsamples = 10   # forward passes averaged per batch

def preprocess_test(X):
    """Convert a batch of channels-last images into normalised tensors.

    Args:
        X: Array whose shape unpacks as (N, H, W, C); every ``X[n]`` is passed
            through ``ToTensor`` followed by ``Normalize``.

    Returns:
        A ``torch`` tensor of shape (N, C, H, W) holding the transformed images.
    """
    to_normalised_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    count, height, width, channels = X.shape
    converted = torch.zeros(count, channels, height, width)
    for idx, image in enumerate(X):
        converted[idx] = to_normalised_tensor(image)

    return converted

preds_list = []   # mean class probabilities of every evaluated batch
avg_list = []     # mean accuracy per corruption file

# Read the checkpoint from disk once; load_state_dict copies the values into
# each freshly built model, so the loaded dict can be reused safely.
state_dict = torch.load(models_dir + '/' + 'theta_best.pt', map_location=device)

for challenge in range(len(chals)):
    chal_data = np.load(chalPath + chals[challenge])

    avg = 0
    # Each file is evaluated as 5 consecutive chunks of 10000 images
    # (presumably the 5 severity levels - confirm against the dataset docs).
    for j in range(5):

        # Get Tented Model: start every chunk from the unadapted checkpoint.
        # The model is moved to the device before its parameters are handed
        # to the optimizer.
        model = ResNet(18,10).to(device)
        model.load_state_dict(state_dict)

        net = model
        net = tent.configure_model(net)
        params, param_names = tent.collect_params(net)
        optimizer = torch.optim.Adam(params, lr=1e-3)
        # optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=5e-4)

        tented_model = tent.Tent(net, optimizer)
        tented_model.eval()

        # Load CIFAR10-C Data
        chal_temp_data = chal_data[j * 10000:(j + 1) * 10000]
        chal_temp_data = preprocess_test(chal_temp_data)

        chal_dataset = torch.utils.data.TensorDataset(chal_temp_data, chal_labels)
        batch_size = 200
        chal_loader = torch.utils.data.DataLoader(chal_dataset, batch_size=batch_size)

        correct = 0
        with torch.no_grad():
            for x, y in chal_loader:
                x = x.to(device)
                # Buffer sized from the actual batch so a shorter final batch
                # cannot cause a shape mismatch.
                predictions = torch.zeros(Nsamples, x.size(0), 10)
                for i in range(Nsamples):
                    outputs = tented_model(x)
                    predictions[i] = F.softmax(outputs, dim=1)

                # Average the class probabilities over the forward passes.
                preds = predictions.mean(0)
                preds_list.append(preds.cpu().numpy())
                # Python-int count -> true float division below on every
                # torch version.
                correct += (torch.argmax(preds,1) == y).sum().item()

        chal_acc = correct / len(chal_dataset)
        avg += chal_acc
        print(round(chal_acc,4))

    avg /= 5
    avg_list.append(avg)
    print("Average:", round(avg,4)," ", chals[challenge])

print("Mean: ", np.mean(avg_list))

In [None]:
# Merge the per-batch probability arrays row-wise, then store them together
# with the per-corruption average accuracies.
preds_list = np.concatenate(preds_list, axis=0)
for file_name, payload in (('/preds_CIFAR-10-C.npy', preds_list),
                           ('/avg_list_CIFAR-10-C.npy', avg_list)):
    np.save(results_dir+file_name, payload)