In [12]:
import torch
import numpy as np
from matplotlib import pyplot as plt
from scipy import ndimage
import os, sys
import math
import pickle
import notebook_utils as nbutils
import data_utils as datutil
import datetime as dt
import hmc
from models import *
import gpytorch
from notebook_utils import *
import torch.nn.functional as F

In [8]:
class Identity(nn.Module):
    '''
    Pass-through module: forward() hands back exactly what it receives.

    Useful as a placeholder wherever an nn.Module is structurally required
    but no transformation should be applied.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, x):
        '''Return the input object itself (no copy, no modification).'''
        return x

class bayes_fc_block(nn.Module):
    '''
    Bayesian linear layer. Instead of a single (W, b) pair the block holds
    four learnable parameters: fc_w_mu, fc_b_mu, fc_w_sig and fc_b_sig.
    The "sig" parameters are stored in log space (forward() passes them
    through exp) with one independent value per weight / bias entry, i.e.
    the covariance is diagonal.

    Every call to forward() draws fresh standard-normal noise and builds
        W = fc_w_mu + exp(fc_w_sig) * noise_w
        b = fc_b_mu + exp(fc_b_sig) * noise_b
    so two calls on the same input generally give different outputs.

    input_dim: the second dimension of input x of size (batch_size, input_dim)
    output_dim: the second dimension of output block(x) (batch_size, output_dim)
    device: stored as self.device for backward compatibility. The noise is
            drawn on whatever device/dtype the parameters currently have,
            so the block keeps working after .to(...) / .cpu() / .cuda().
    '''
    def __init__(self, input_dim, output_dim, device):
        super(bayes_fc_block, self).__init__()
        # Registration order (w_mu, b_mu, w_sig, b_sig) is kept so that
        # parameters() ordering and state_dict keys stay unchanged.
        self.fc_w_mu = nn.Parameter(torch.empty(input_dim, output_dim))
        self.fc_b_mu = nn.Parameter(torch.empty(output_dim))
        self.fc_w_sig = nn.Parameter(torch.empty(input_dim, output_dim))
        self.fc_b_sig = nn.Parameter(torch.empty(output_dim))

        self.device = device
        self.ind = input_dim
        self.outd = output_dim

        # Means: uniform in [-stdv, stdv]; log-sigmas: uniform in
        # [log(stdv) - 0.5, log(stdv)], with stdv = 1 / sqrt(input_dim).
        # The fill order (w_mu, w_sig, b_mu, b_sig) is kept so that a seeded
        # run produces the same initial values as before.
        stdv = 1. / math.sqrt(input_dim)
        log_stdv = math.log(stdv)
        with torch.no_grad():
            self.fc_w_mu.uniform_(-stdv, stdv)
            self.fc_w_sig.uniform_(log_stdv - 0.5, log_stdv)
            self.fc_b_mu.uniform_(-stdv, stdv)
            self.fc_b_sig.uniform_(log_stdv - 0.5, log_stdv)

    def forward(self, x):
        '''
        Sample one (W, b) and apply the affine map x @ W + b.

        x: input of size (batch_size, input_dim)
        return: tensor of size (batch_size, output_dim)
        '''
        # randn_like follows the parameters' current device and dtype, so a
        # module that has been moved since construction still works.
        noise_w = torch.randn_like(self.fc_w_mu)
        noise_b = torch.randn_like(self.fc_b_mu)
        w = self.fc_w_mu + torch.exp(self.fc_w_sig) * noise_w
        b = self.fc_b_mu + torch.exp(self.fc_b_sig) * noise_b
        # W is stored as (input_dim, output_dim) while F.linear expects
        # (output_dim, input_dim), hence the transpose.
        return F.linear(x, w.t(), b)


class feature(nn.Module):
    '''
    Base feature extractor followed by a stack of Bayesian dense layers.

    forward() returns a Monte-Carlo estimate of the expected negative
    log-likelihood of the labels; infer() returns per-sample class
    probabilities.
    '''
    def __init__(self, base_feature, fc_layers, device):
        '''
        Wrapper class for a base feature extractor and
        a series of bayesian layers.
        The intended final loss is the KL divergence between the
        (unnormalised) true posterior and the approximating density; for
        now only the expected log-likelihood term is implemented.

        base_feature: base feature extractor (maybe resnet till before dense)
                    set it to Identity class when you have encoded feature.
                    Otherwise feature extractor parameters will be jointly learned
        fc_layers: array containing dense layer lengths starting with base feature dim
                    e.g. [256, 100, 10] will expect a 256 dimensional input and then
                    place linear(256, 100) and then linear(100, 10) sequentially,
                    with a ReLU between consecutive dense layers (none after the last).
                    With fewer than two entries no dense layer is created and
                    the dense stack is a pass-through.
        device: device handed to every bayes_fc_block and stored as self.device.
        '''
        super(feature, self).__init__()
        self.base_layer = base_feature
        self.fc_architecture = fc_layers
        self.device = device

        # One Bayesian block per consecutive pair of widths, separated by
        # ReLUs. Sequential indices are block=0, relu=1, block=2, ... so
        # state_dict keys match the previous layout.
        linear_list = []
        num_blocks = len(fc_layers) - 1
        for idx in range(num_blocks):
            linear_list.append(bayes_fc_block(fc_layers[idx], fc_layers[idx+1], device))
            if idx < num_blocks - 1:
                linear_list.append(nn.ReLU())

        # Always defined (possibly empty) so forward()/infer() never hit a
        # missing attribute.
        self.bayes_fc_blocks = nn.Sequential(*linear_list)

    def forward(self, x, labels, num_sample=1):
        '''
        Monte-Carlo estimate of the negative log-likelihood of `labels`.

        x: input batch, passed through the base layer first
        labels: integer class indices, one per batch element
        num_sample: number of stochastic passes through the dense layers
                    to average over

        return: scalar tensor, the NLL averaged over samples and batch
                (the int 0 if num_sample is 0)
        '''
        x = self.base_layer(x)
        out = 0
        batch_size = x.size()[0]
        target = labels.reshape((batch_size, 1))

        # The dense layers are re-run once per sample because each pass
        # draws new weights. Only the expectation term is implemented; the
        # KL term is ignored for now.
        for _ in range(num_sample):
            last_layer = self.bayes_fc_blocks(x)
            # log_softmax instead of log(softmax(.)): the latter underflows
            # to log(0) = -inf when the true class gets a tiny probability.
            log_probs = F.log_softmax(last_layer, dim=1)
            class_logprob = torch.gather(log_probs, dim=1, index=target)
            out -= torch.sum(class_logprob) / (num_sample*batch_size)

        return out

    def infer(self, x, num_sample=20):
        '''
        function to generate class probabilities with
        multiple samples from posterior

        x: input (image/encoded features)
        num_sample: how many samples to get from posterior

        return: class probabilities of shape (num_sample, x.shape[0], num_classes)

        Note: no torch.no_grad() is applied here; wrap the call yourself
        when gradients are not needed.
        '''
        x = self.base_layer(x)
        # Allocate on the features' device so a model that was moved after
        # construction still works.
        class_prob = torch.zeros((num_sample, x.size()[0], self.fc_architecture[-1]), device=x.device)
        for count in range(num_sample):
            class_outp = self.bayes_fc_blocks(x)
            class_prob[count,:,:] = F.softmax(class_outp, dim=1)

        return class_prob
    

In [3]:
# Data loader initialization
# NOTE(review): the training loader is created with shuffle=False; presumably
# this means batches come in the same order every epoch -- confirm against
# data_utils.generate_dataloaders that this is intended for SGD training.
trainloader = datutil.generate_dataloaders('ENCODED256_D110_CIFAR10_TRAIN', batch_size=300, shuffle=False, num_workers=2)
testloader = datutil.generate_dataloaders('ENCODED256_D110_CIFAR10_TEST', batch_size=200, shuffle=False, num_workers=2)

# Prefer the second GPU (the original hard-coded choice), but fall back to the
# first GPU or the CPU so the notebook still runs on hosts with fewer devices.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

In [10]:
# Widths of the Bayesian dense head: input width first, output width last.
fc_layer_setup = [256, 128, 10]

# The inputs are pre-encoded features, so the base extractor is a pass-through.
# Alternative kept for reference (jointly trains a ResNet backbone):
# base_model = PreResNet(num_classes=fc_layer_setup[-1], depth=164)
# base_model.fc = Identity()
base_model = Identity()
final_model = feature(base_feature=base_model, fc_layers=fc_layer_setup, device=device)

final_model.to(device)
print(final_model)

# Plain SGD with momentum over every parameter of the wrapped model.
optimizer = torch.optim.SGD(
    params=final_model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0003,
)

feature(
  (base_layer): Identity()
  (bayes_fc_blocks): Sequential(
    (0): bayes_fc_block()
    (1): ReLU()
    (2): bayes_fc_block()
  )
)


In [11]:
# Exponentially smoothed minibatch loss; the value 0 doubles as the
# "not initialised yet" marker.
running_loss = 0

for epoch in range(20):  # loop over the dataset multiple times

    for i, data in enumerate(trainloader):

        # Each batch is an (inputs, labels) pair; move both to the target device.
        inputs, labels = (tensor.to(device) for tensor in data)

        # One optimisation step on the sampled negative log-likelihood.
        optimizer.zero_grad()
        loss = final_model(inputs, labels)
        loss.sum().backward()
        optimizer.step()

        # Seed the average with the first loss, afterwards blend 90/10.
        batch_loss = loss.item()
        if running_loss != 0:
            running_loss = 0.9*running_loss + 0.1*batch_loss
        else:
            running_loss = batch_loss

        # Interval equals the number of batches, so this prints once per epoch
        # (at the first batch).
        log_interval = len(trainloader) // 1
        if i % log_interval == 0:
            print('[%d, %5d] loss: %.3f' %(epoch + 1, i, running_loss))

    # Evaluate on the held-out loader after every epoch.
    print("=== Accuracy using SGD params ===")
    accuracy, ece = nbutils.validate(model=final_model, dataloader=testloader, device=device)


[1,     0] loss: 2.392
=== Accuracy using SGD params ===
Accuracy statistics
Overall accuracy : 90 %
ECE values are 0.059, 0.057 when mid bin and avg used respectively
Loss : tensor(0.3449, device='cuda:1')
[2,     0] loss: 0.093
=== Accuracy using SGD params ===
Accuracy statistics
Overall accuracy : 90 %
ECE values are 0.019, 0.013 when mid bin and avg used respectively
Loss : tensor(0.3238, device='cuda:1')
[3,     0] loss: 0.039
=== Accuracy using SGD params ===
Accuracy statistics
Overall accuracy : 90 %
ECE values are 0.016, 0.011 when mid bin and avg used respectively
Loss : tensor(0.3251, device='cuda:1')
[4,     0] loss: 0.025
=== Accuracy using SGD params ===
Accuracy statistics
Overall accuracy : 90 %
ECE values are 0.017, 0.014 when mid bin and avg used respectively
Loss : tensor(0.3299, device='cuda:1')
[5,     0] loss: 0.021
=== Accuracy using SGD params ===
Accuracy statistics
Overall accuracy : 90 %
ECE values are 0.018, 0.017 when mid bin and avg used respectively
Loss

In [None]:
# Persist model + optimizer state together with the last accuracy returned by
# nbutils.validate in the training cell.
savefile = 'VIBayesNN_notebook_model_file'
savedir = 'saved_models/'
checkpoint = {'model_state': final_model.state_dict(),
              'optim_state': optimizer.state_dict(),
              'acc': accuracy}

# Minute-resolution timestamp keeps successive checkpoints from overwriting
# each other.
curtime = dt.datetime.now()
tm = curtime.strftime("%Y-%m-%d-%H.%M")

# torch.save does not create missing directories, so make sure the target
# folder exists before writing.
os.makedirs(savedir, exist_ok=True)
torch.save(checkpoint, os.path.join(savedir, savefile + '-' + tm + '.model'))