In [None]:
import os

import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms, datasets

from networks import BBP_LRT
import metrics
from utils import *
from utils_plotting import *

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

## Dataset

In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 128
# Per-channel mean / std handed to transforms.Normalize after ToTensor().
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)

# Training images are augmented (random flip + padded random crop) before
# being converted to tensors and normalised.
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                    transforms.RandomCrop(32, padding=4),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean, std)
                                    ])
# Validation images are only converted and normalised (no augmentation).
transform_test = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean, std)])

# The CIFAR-10 train split is used for fitting, the test split for validation.
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
valset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=4)

# Output locations for checkpoints and for saved predictions / figures.
models_dir = "Models/" + 'BBP_models/' + 'lrt'
results_dir = "Results/" + 'BBP_results/' + 'lrt'
# torch.save and np.save do not create missing parent directories, so make
# sure both exist before any later cell writes into them.
os.makedirs(models_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)

In [None]:
# Build the BBP_LRT network, then move it to the selected device.
# NOTE(review): the positional arguments (18, 10) are presumably the network
# depth and the number of classes -- confirm against networks.BBP_LRT.
net = BBP_LRT(18, 10)
net = net.to(device)

## Training

In [None]:
def train_model(net, criterion, trainloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
    """Run one training epoch over ``trainloader``.

    For every batch the network is evaluated ``num_ens`` times; the
    per-pass log-softmax outputs are combined with ``logmeanexp`` along the
    ensemble axis and the KL terms returned by the network are averaged.
    The combined output, the labels, the averaged KL and the weight from
    ``metrics.get_beta`` are passed to ``criterion``.

    Returns:
        Tuple ``(mean loss per batch, mean batch accuracy, mean batch KL)``.
    """
    net.train()

    # A new Adam optimizer is created on every call; its step size comes from
    # the learning_rate helper evaluated for the current epoch.
    base_lr = 0.01
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate(base_lr, epoch))

    num_batches = len(trainloader)
    epoch_loss = 0.0
    batch_accs = []
    batch_kls = []

    # Batches are numbered from 1, so get_beta receives a 0-based index below.
    for batch_no, (inputs, labels) in enumerate(trainloader, start=1):
        optimizer.zero_grad()

        inputs = inputs.to(device)
        labels = labels.to(device)

        # One slice along the last axis per stochastic forward pass.
        ens_log_probs = torch.zeros(inputs.shape[0], net.num_classes, num_ens).to(device)
        kl_total = 0.0
        for member in range(num_ens):
            logits, member_kl = net(inputs)
            kl_total += member_kl
            ens_log_probs[:, :, member] = F.log_softmax(logits, dim=1)

        kl = kl_total / num_ens
        batch_kls.append(kl.item())
        log_outputs = logmeanexp(ens_log_probs, dim=2)

        beta = metrics.get_beta(batch_no - 1, num_batches, beta_type, epoch, num_epochs)
        loss = criterion(log_outputs, labels, kl, beta)
        loss.backward()
        optimizer.step()

        batch_accs.append(metrics.acc(log_outputs.data, labels))
        epoch_loss += loss.cpu().data.numpy()

    return epoch_loss / num_batches, np.mean(batch_accs), np.mean(batch_kls)

def validate_model(net, criterion, validloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
    """Calculate ensemble accuracy and criterion loss on ``validloader``.

    The network is put in eval mode and evaluated ``num_ens`` times per batch
    under ``torch.no_grad()``. The per-pass log-softmax outputs are combined
    with ``logmeanexp`` along the ensemble axis, and the KL terms returned by
    the network are averaged over the ensemble (mirroring ``train_model``).

    Returns:
        Tuple ``(mean loss per batch, mean batch accuracy)``.
    """
    net.eval()
    valid_loss = 0.0
    accs = []

    with torch.no_grad():
        # Count batches from 1 so that ``i-1`` below is a 0-based batch index,
        # exactly as in train_model (previously the first batch was sent to
        # get_beta with index -1).
        for i, (inputs, labels) in enumerate(validloader, 1):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens).to(device)
            kl = 0.0
            for j in range(num_ens):
                net_out, _kl = net(inputs)
                kl += _kl
                outputs[:, :, j] = F.log_softmax(net_out, dim=1)

            # Average the KL over ensemble members; without this the KL term
            # would grow linearly with num_ens, unlike in training.
            kl = kl / num_ens
            log_outputs = logmeanexp(outputs, dim=2)

            beta = metrics.get_beta(i-1, len(validloader), beta_type, epoch, num_epochs)
            valid_loss += criterion(log_outputs, labels, kl, beta).item()
            accs.append(metrics.acc(log_outputs, labels))

    return valid_loss/len(validloader), np.mean(accs)

In [None]:
num_epochs = 200
criterion = metrics.ELBO(len(trainset)).to(device)
# A checkpoint is only written when validation accuracy exceeds this value;
# it starts at 0.6 and is raised to the best accuracy seen so far.
best_acc = 0.6
best_model_path = models_dir + '/' + 'theta_best.t7'
epoch_report = 'Epoch: {} | Train Loss: {:.4f} | Train Accuracy: {:.4f} | Val Loss: {:.4f} | Val Accuracy: {:.4f} | train_kl_div: {:.4f}'

for epoch in range(num_epochs):

    # One pass over the training data followed by one pass over validation.
    train_loss, train_acc, train_kl = train_model(
        net, criterion, trainloader,
        num_ens=1, beta_type=0.1, epoch=epoch, num_epochs=num_epochs)
    valid_loss, acc = validate_model(
        net, criterion, valloader,
        num_ens=1, beta_type=0.1, epoch=epoch, num_epochs=num_epochs)

    print(epoch_report.format(epoch, train_loss, train_acc, valid_loss, acc, train_kl))

    # Keep only the weights of the best-performing epoch.
    if acc > best_acc:
        print('| Saving Best model...')
        torch.save(net.state_dict(), best_model_path)
        best_acc = acc

## Inference

In [None]:
# Rebuild the architecture on the device selected at the top of the notebook
# (rather than unconditionally on CUDA) so this cell also runs on CPU-only
# machines, and remap the checkpoint tensors onto that same device.
net = BBP_LRT(18, 10).to(device)
net.load_state_dict(torch.load(models_dir + '/' + 'theta_best.t7', map_location=device))
# Switch to evaluation mode; as the last expression this also displays the model.
net.eval()

### CIFAR-10 Rotations

In [None]:
# Collect the whole (already normalised) validation set as NumPy arrays.
x_batches = []
y_batches = []
for x, y in valloader:
    x_batches.append(x.cpu().numpy())
    y_batches.append(y.cpu().numpy())

x_dev = np.concatenate(x_batches)
y_dev = np.concatenate(y_batches)
print(x_dev.shape)
print(y_dev.shape)

# The raw images are read from `valset.data`: this notebook already relies on
# the `data` / `targets` attribute naming of CIFAR10 (see `valset.targets` in
# the CIFAR-10-C section); `test_data` is the older name for the same array.
plot_rotate(x_dev, y_dev, net, results_dir, im_list=valset.data, im_ind=23, Nsamples=100, steps=16)

In [None]:
# Load the pre-computed rotated CIFAR-10 images from disk.
# NOTE(review): the evaluation loop indexes this array as
# data_rotated[:, im_ind, :, :, :] and passes the slice to preprocess_test,
# which unpacks its shape as (N, H, W, C); the layout is therefore presumably
# (rotation step, image, H, W, C) -- confirm against the file that was saved.
data_rotated = np.load("data/CIFAR10_rotated.npy")
# Left disabled: no axis transposition is applied to the loaded array here.
# data_rotated = np.transpose(data_rotated, (0,1,4,2,3))
data_rotated.shape

In [None]:
# Settings for the rotation experiment.
steps = 16        # number of rotation steps per image (size of the `steps` axis below)
N = 10000         # number of test images that are evaluated
Nsamples = 100    # forwarded as `Nsamples` to net.all_sample_eval
# Use `targets`, the attribute name the CIFAR-10-C section of this notebook
# already relies on; `test_labels` is the older name for the same labels.
y_dev = valset.targets

def preprocess_test(X):
    """Turn a batch of channels-last images into a normalised tensor batch.

    ``X`` is unpacked as ``(N, H, W, C)``. Each image is passed through
    ``ToTensor`` followed by ``Normalize`` with the fixed mean / std below and
    stored in a pre-allocated tensor of shape ``(N, C, H, W)``, which is
    returned.
    """
    num_images, height, width, channels = X.shape

    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    to_normalised_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # Pre-allocate the output and fill it one image at a time.
    batch = torch.zeros(num_images, channels, height, width)
    for idx, image in enumerate(X):
        batch[idx] = to_normalised_tensor(image)

    return batch

# Result buffers: the sample-averaged probabilities per image and rotation
# step, and the individual per-sample probabilities.
all_preds = np.zeros((N, steps, 10))
all_sample_preds = np.zeros((N, Nsamples, steps, 10))

for im_ind in range(N):
    # Print a progress marker every 500 images.
    if im_ind % 500 == 0:
        print(im_ind)

    # Every rotated version of test image `im_ind`, as a normalised batch.
    ims = preprocess_test(data_rotated[:, im_ind, :, :, :])

    # Repeat the image's label once per rotated copy.
    y = np.ones(ims.shape[0]) * y_dev[im_ind]

    with torch.no_grad():
        sample_probs = net.all_sample_eval(ims, torch.from_numpy(y), Nsamples=Nsamples)
    # Average over the leading (sample) axis.
    probs = sample_probs.mean(dim=0)

    all_sample_preds[im_ind] = sample_probs.cpu().numpy()
    predictions = probs.cpu().numpy()
    all_preds[im_ind] = predictions

In [None]:
# One integer value per rotation step, evenly spaced between 0 and 179
# (presumably degrees -- confirm against how data_rotated was generated).
rotations = np.linspace(0, 179, steps).astype(int)

# For every image keep, at each rotation step, the probability assigned to
# its true class.
correct_preds = np.zeros((N, steps))
for i in range(N):
    true_class = y_dev[i]
    correct_preds[i] = all_preds[i][:, true_class]

# Persist the arrays next to the other results of this model.
arrays_to_save = [
    ('/correct_preds.npy', correct_preds),
    ('/all_preds.npy', all_preds),
    ('/all_sample_preds.npy', all_sample_preds),
]
for suffix, array in arrays_to_save:
    np.save(results_dir + suffix, array)

plot_predictive_entropy(correct_preds, all_preds, rotations, results_dir)

### CIFAR-10-C

In [None]:
chalPath = 'data/CIFAR-10-C/'
# Only the per-corruption image arrays are evaluated. The CIFAR-10-C release
# also ships a `labels.npy` in the same folder; it is not an image array and
# would break preprocess_test, so it (and any non-.npy file) is skipped.
chals = sorted(
    fname for fname in os.listdir(chalPath)
    if fname.endswith('.npy') and fname != 'labels.npy'
)

# Labels are taken from the clean test set and reused for each block of
# corrupted images.
chal_labels = torch.tensor(valset.targets, dtype=torch.long)

def preprocess_test(X):
    """Turn a batch of channels-last images into a normalised tensor batch.

    ``X`` is unpacked as ``(N, H, W, C)``. Each image is passed through
    ``ToTensor`` followed by ``Normalize`` with the fixed mean / std below and
    stored in a pre-allocated tensor of shape ``(N, C, H, W)``, which is
    returned.
    """
    num_images, height, width, channels = X.shape

    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    to_normalised_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # Pre-allocate the output and fill it one image at a time.
    batch = torch.zeros(num_images, channels, height, width)
    for idx, image in enumerate(X):
        batch[idx] = to_normalised_tensor(image)

    return batch

preds_list = []
avg_list = []
# Put the network in evaluation mode through the module itself, the same way
# it is done right after the checkpoint is loaded.
net.eval()

# Each corruption file is processed as consecutive blocks, one block per
# label set, so a block has exactly as many images as there are labels.
num_blocks = 5
block_size = len(chal_labels)

for chal_name in chals:
    chal_data = np.load(chalPath + chal_name)

    avg = 0
    for j in range(num_blocks):
        chal_temp_data = chal_data[j * block_size:(j + 1) * block_size]
        chal_temp_data = preprocess_test(chal_temp_data)

        chal_dataset = torch.utils.data.TensorDataset(chal_temp_data, chal_labels)
        chal_loader = torch.utils.data.DataLoader(chal_dataset, batch_size=100)
        chal_error = 0

        with torch.no_grad():
            for x, y in chal_loader:
                cost, err, probs = net.sample_eval(x, y, Nsamples=10, logits=False)
                preds_list.append(probs.cpu().numpy())
                # NOTE(review): `err` is treated as a per-batch error count,
                # since it is summed and divided by the dataset size below --
                # confirm against BBP_LRT.sample_eval.
                chal_error += err.cpu().numpy()

        chal_acc = 1 - (chal_error/len(chal_dataset))
        avg += chal_acc
        print(chal_acc)

    # Mean accuracy over the blocks of this corruption file.
    avg /= num_blocks
    avg_list.append(avg)
    print("Average:",avg," ", chal_name)

print("Mean: ", np.mean(avg_list))


In [None]:
# Stack the per-batch probability arrays collected in the CIFAR-10-C loop
# into a single array (row-wise), then store it with the per-corruption
# average accuracies in the results directory.
preds_list = np.vstack(preds_list)
np.save(results_dir+'/preds_CIFAR-10-C.npy', preds_list)
np.save(results_dir+'/avg_list_CIFAR-10-C.npy', avg_list)