In [1]:
import os
import argparse

import torch
import numpy as np
from torch.optim import Adam, lr_scheduler
from torch.nn import functional as F
from torch.nn import Parameter
from torch import nn


In [2]:
class ELBO(nn.Module):
    def __init__(self, train_size):
        super(ELBO, self).__init__()
        self.train_size = train_size

    def forward(self, input, target, kl, beta):
        assert not target.requires_grad
        return F.nll_loss(input, target,reduction='mean')  + (1/self.train_size * kl)


# def lr_linear(epoch_num, decay_start, total_epochs, start_value):
#     if epoch_num < decay_start:
#         return start_value
#     return start_value*float(total_epochs-epoch_num)/float(total_epochs-decay_start)


def acc(outputs, targets):
    return np.mean(outputs.cpu().numpy().argmax(axis=1) == targets.data.cpu().numpy())


def calculate_kl(mu_p, sig_p, mu_q, sig_q):
    kl = 0.5 * (2 * torch.log(sig_p / sig_q) - 1 + (sig_q / sig_p).pow(2) + ((mu_p - mu_q) / sig_p).pow(2)).sum()
    return kl


def get_beta(batch_idx, m, beta_type, epoch, num_epochs):
    if type(beta_type) is float:
        return beta_type

    if beta_type == "Blundell":
        beta = 2 ** (m - (batch_idx + 1)) / (2 ** m - 1)
    elif beta_type == "Soenderby":
        if epoch is None or num_epochs is None:
            raise ValueError('Soenderby method requires both epoch and num_epochs to be passed.')
        beta = min(epoch / (num_epochs // 4), 1)
    elif beta_type == "Standard":
        beta = 1 / m
    else:
        beta = 0
    return beta

In [3]:
from torch import nn


class ModuleWrapper(nn.Module):
    """Wrapper for nn.Module with support for arbitrary flags and a universal forward pass"""

    def __init__(self):
        super(ModuleWrapper, self).__init__()

    def set_flag(self, flag_name, value):
        setattr(self, flag_name, value)
        for m in self.children():
            if hasattr(m, 'set_flag'):
                m.set_flag(flag_name, value)

    def forward(self, x):
        for module in self.children():
            x = module(x)

        kl = 0.0
        for module in self.modules():
            if hasattr(module, 'kl_loss'):
                kl = kl + module.kl_loss()

        return x, kl


class FlattenLayer(ModuleWrapper):

    def __init__(self, num_features):
        super(FlattenLayer, self).__init__()
        self.num_features = num_features

    def forward(self, x):
        return x.view(-1, self.num_features)


In [4]:
class BBBLinear(ModuleWrapper):
    def __init__(self, in_features, out_features, bias=False, priors=None):
        super(BBBLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = bias
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        if priors is None:
            priors = {
                'prior_mu': 0,
                'prior_sigma': 0.1,
                'posterior_mu_initial': (0, 0.1),
                'posterior_rho_initial': (-3, 0.1),
            }
        self.prior_mu = priors['prior_mu']
        self.prior_sigma = priors['prior_sigma']
        self.posterior_mu_initial = priors['posterior_mu_initial']
        self.posterior_rho_initial = priors['posterior_rho_initial']

        self.W_mu = Parameter(torch.empty((out_features, in_features), device=self.device))
        self.W_rho = Parameter(torch.empty((out_features, in_features), device=self.device))

        if self.use_bias:
            self.bias_mu = Parameter(torch.empty((out_features), device=self.device))
            self.bias_rho = Parameter(torch.empty((out_features), device=self.device))
        else:
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_rho', None)

        self.reset_parameters()

    def reset_parameters(self):
        self.W_mu.data.normal_(*self.posterior_mu_initial)
        self.W_rho.data.normal_(*self.posterior_rho_initial)

        if self.use_bias:
            self.bias_mu.data.normal_(*self.posterior_mu_initial)
            self.bias_rho.data.normal_(*self.posterior_rho_initial)

    def forward(self, input, sample=True):
        if self.training or sample:
            W_eps = torch.empty(self.W_mu.size()).normal_(0, 1).to(self.device)
            self.W_sigma = torch.log1p(torch.exp(self.W_rho))
            weight = self.W_mu + W_eps * self.W_sigma

            if self.use_bias:
                bias_eps = torch.empty(self.bias_mu.size()).normal_(0, 1).to(self.device)
                self.bias_sigma = torch.log1p(torch.exp(self.bias_rho))
                bias = self.bias_mu + bias_eps * self.bias_sigma
            else:
                bias = None
        else:
            weight = self.W_mu
            bias = self.bias_mu if self.use_bias else None

        return F.linear(input, weight, bias)

    def kl_loss(self):
        kl = calculate_kl(self.W_mu, self.W_sigma,self.prior_mu, self.prior_sigma)
        return kl


In [5]:
class Net(ModuleWrapper):

    def __init__(self,priors):
        super().__init__()
        self.priors = priors
        self.flatten = FlattenLayer(4)
        self.fc = BBBLinear(4, 2)

In [6]:
def logmeanexp(x, dim=None, keepdim=False):
    """Stable computation of log(mean(exp(x))"""

    
    if dim is None:
        x, dim = x.view(-1), 0
    x_max, _ = torch.max(x, dim, keepdim=True)
    x = x_max + torch.log(torch.mean(torch.exp(x - x_max), dim, keepdim=True))
    return x if keepdim else x.squeeze(dim)

# check if dimension is correct

# def dimension_check(x, dim=None, keepdim=False):
#     if dim is None:
#         x, dim = x.view(-1), 0

#     return x if keepdim else x.squeeze(dim)


def adjust_learning_rate(optimizer, lr):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def save_array_to_file(numpy_array, filename):
    file = open(filename, 'a')
    shape = " ".join(map(str, numpy_array.shape))
    np.savetxt(file, numpy_array.flatten(), newline=" ", fmt="%.3f")
    file.write("\n")
    file.close()


In [7]:
def train_model(net, optimizer, criterion, trainloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
    net.train()
    training_loss = 0.0
    accs = []
    kl_list = []
    for i, (inputs, labels) in enumerate(trainloader, 1):

        optimizer.zero_grad()

        inputs, labels = inputs.to(device), labels.to(device)
        outputs = torch.zeros(inputs.shape[0],2, num_ens).to(device)

        kl = 0.0
        for j in range(num_ens):
            net_out, _kl = net(inputs.float())
            kl += _kl
            outputs[:, :, j] = F.log_softmax(net_out, dim=1)
        
        kl = kl / num_ens
        kl_list.append(kl.item())
        log_outputs = torch.mean(outputs, dim=2)

        beta =get_beta(i-1, len(trainloader), beta_type, epoch, num_epochs)
        loss = criterion(log_outputs, labels, kl, beta)
        loss.backward()
        optimizer.step()

        accs.append(acc(log_outputs.data, labels))
        training_loss += loss.cpu().data.numpy()
    return training_loss/len(trainloader), np.mean(accs), np.mean(kl_list)


def validate_model(net, criterion, validloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
    """Calculate ensemble accuracy and NLL Loss"""
    net.train()
    valid_loss = 0.0
    accs = []

    for i, (inputs, labels) in enumerate(validloader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = torch.zeros(inputs.shape[0], 2, num_ens).to(device)
        kl = 0.0
        for j in range(num_ens):
            net_out, _kl = net(inputs.float())
            kl += _kl
            outputs[:, :, j] = F.log_softmax(net_out, dim=1).data

        log_outputs = torch.mean(outputs, dim=2)

        beta = get_beta(i-1, len(validloader), beta_type, epoch, num_epochs)
        valid_loss += criterion(log_outputs, labels, kl, beta).item()
        accs.append(acc(log_outputs, labels))

    return valid_loss/len(validloader), np.mean(accs)


In [8]:
############### Configuration file for Bayesian ###############
layer_type = 'lrt'  # 'bbb' or 'lrt'
activation_type = 'softplus'  # 'softplus' or 'relu'
priors={
    'prior_mu': 0,
    'prior_sigma': 0.1,
    'posterior_mu_initial': (0, 0.1),  # (mean, std) normal_
    'posterior_rho_initial': (-1, 0.1),  # (mean, std) normal_
}

n_epochs = 100
lr_start = 0.01
num_workers = 4
valid_size = 0.2
batch_size = 1000
train_ens = 10
valid_ens = 10
beta_type = 'Blundell'  # 'Blundell', 'Standard', etc. Use float for const value
record_mean_var = True
recording_freq_per_epoch = 1
record_layers = ['fc']

In [9]:
def getModel(net_type,priors):
    if (net_type == 'Net'):
        return Net(priors)

In [10]:
device="cpu"
Images=np.load('../../data/orthogonal_var3.npy',allow_pickle=True)
net_type="Net"
label0=Images.item().get('[0]')
label1=Images.item().get('[1]')

train_data=[]
reduction_num=1000
for i in range(reduction_num):
    train_data.append([label0[i],0])
for i in range(reduction_num):
    train_data.append([label1[i],1])
test_label0=label0[1001:1300]
test_label1=label1[1001:1300]
valid_label0=label0[1400:1600]
valid_label1=label1[1400:1600]

test_data=[]
for i in range(200):
    test_data.append([test_label0[i],0])
for i in range(200):
    test_data.append([test_label1[i],1])

valid_data=[]
for i in range(200):
    valid_data.append([valid_label0[i],0])
for i in range(200):
    valid_data.append([valid_label1[i],1])

train_loader = torch.utils.data.DataLoader(train_data, batch_size=reduction_num,drop_last = True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=400, drop_last=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=400, drop_last=True)

net = getModel(net_type, priors).to(device)

ckpt_dir = f'checkpoints/var3/bayesian'
ckpt_name = f'checkpoints/var3/bayesian/model_{net_type}_{layer_type}_{activation_type}.pt'

if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir, exist_ok=True)

criterion = ELBO(reduction_num*2).to(device)
optimizer = Adam(net.parameters(), lr=lr_start)
#lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer, patience=6, verbose=True)
valid_loss_max = np.Inf
for epoch in range(n_epochs):  # loop over the dataset multiple times
    print('train')
    train_loss, train_acc, train_kl = train_model(net, optimizer, criterion, train_loader, num_ens=train_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
    print('test')
    valid_loss, valid_acc = validate_model(net, criterion, valid_loader, num_ens=valid_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
    #lr_sched.step(valid_loss)

    print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
        epoch, train_loss, train_acc, valid_loss, valid_acc, train_kl))

    for name, param in net.named_parameters():
        if param.requires_grad:
            print(name, param.data)

    # save model if validation accuracy has increased
    if valid_loss <= valid_loss_max:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
            valid_loss_max, valid_loss))
        torch.save(net.state_dict(), ckpt_name)
        valid_loss_max = valid_loss


train
test
Epoch: 0 	Training Loss: 0.6518 	Training Accuracy: 0.6770 	Validation Loss: 0.7503 	Validation Accuracy: 0.6350 	train_kl_div: 22.8075
fc.W_mu tensor([[-0.1073, -0.1939,  0.0089, -0.0087],
        [-0.0805,  0.0494, -0.0917, -0.0457]])
fc.W_rho tensor([[-2.9046, -2.9496, -2.9447, -2.9470],
        [-3.1022, -3.0272, -2.9980, -3.0925]])
Validation loss decreased (inf --> 0.750336).  Saving model ...
train
test
Epoch: 1 	Training Loss: 0.6469 	Training Accuracy: 0.6975 	Validation Loss: 0.7294 	Validation Accuracy: 0.7800 	train_kl_div: 20.7397
fc.W_mu tensor([[-0.0921, -0.2047,  0.0096, -0.0179],
        [-0.0948,  0.0617, -0.0722, -0.0264]])
fc.W_rho tensor([[-2.8891, -2.9308, -2.9247, -2.9271],
        [-3.0843, -3.0093, -2.9785, -3.0729]])
Validation loss decreased (0.750336 --> 0.729381).  Saving model ...
train
test
Epoch: 2 	Training Loss: 0.6519 	Training Accuracy: 0.6870 	Validation Loss: 0.7293 	Validation Accuracy: 0.7925 	train_kl_div: 20.2836
fc.W_mu tensor([[-0.

test
Epoch: 24 	Training Loss: 0.4549 	Training Accuracy: 0.9840 	Validation Loss: 0.6759 	Validation Accuracy: 0.9925 	train_kl_div: 48.9048
fc.W_mu tensor([[ 0.2027, -0.4609,  0.0331, -0.0246],
        [-0.3664,  0.3310, -0.0278,  0.0174]])
fc.W_rho tensor([[-2.7937, -2.4784, -2.5505, -2.5671],
        [-2.6461, -2.6295, -2.6734, -2.7122]])
train
test
Epoch: 25 	Training Loss: 0.4487 	Training Accuracy: 0.9865 	Validation Loss: 0.6869 	Validation Accuracy: 0.9925 	train_kl_div: 49.7897
fc.W_mu tensor([[ 0.2133, -0.4709,  0.0395, -0.0246],
        [-0.3766,  0.3410, -0.0266,  0.0193]])
fc.W_rho tensor([[-2.7779, -2.4583, -2.5400, -2.5566],
        [-2.6246, -2.6079, -2.6629, -2.7011]])
train
test
Epoch: 26 	Training Loss: 0.4469 	Training Accuracy: 0.9855 	Validation Loss: 0.6559 	Validation Accuracy: 0.9925 	train_kl_div: 50.5575
fc.W_mu tensor([[ 0.2238, -0.4806,  0.0394, -0.0248],
        [-0.3867,  0.3507, -0.0251,  0.0209]])
fc.W_rho tensor([[-2.7656, -2.4376, -2.5294, -2.5462],


test
Epoch: 48 	Training Loss: 0.3364 	Training Accuracy: 0.9870 	Validation Loss: 0.5613 	Validation Accuracy: 0.9925 	train_kl_div: 49.1051
fc.W_mu tensor([[ 0.4218, -0.6761,  0.0403, -0.0240],
        [-0.5884,  0.5427, -0.0279,  0.0203]])
fc.W_rho tensor([[-2.2923, -1.9953, -2.3875, -2.4064],
        [-2.1186, -2.1168, -2.4909, -2.5169]])
Validation loss decreased (0.562540 --> 0.561344).  Saving model ...
train
test
Epoch: 49 	Training Loss: 0.3286 	Training Accuracy: 0.9905 	Validation Loss: 0.5625 	Validation Accuracy: 0.9925 	train_kl_div: 48.9205
fc.W_mu tensor([[ 0.4300, -0.6844,  0.0310, -0.0238],
        [-0.5970,  0.5508, -0.0273,  0.0202]])
fc.W_rho tensor([[-2.2739, -1.9778, -2.3830, -2.4028],
        [-2.0979, -2.1004, -2.4863, -2.5109]])
train
test
Epoch: 50 	Training Loss: 0.3232 	Training Accuracy: 0.9885 	Validation Loss: 0.5517 	Validation Accuracy: 0.9925 	train_kl_div: 48.6461
fc.W_mu tensor([[ 0.4382, -0.6925,  0.0232, -0.0240],
        [-0.6055,  0.5588, -0.026

test
Epoch: 72 	Training Loss: 0.2613 	Training Accuracy: 0.9880 	Validation Loss: 0.4484 	Validation Accuracy: 0.9925 	train_kl_div: 42.5705
fc.W_mu tensor([[ 0.6062, -0.8580,  0.0350, -0.0331],
        [-0.7819,  0.7203, -0.0251,  0.0261]])
fc.W_rho tensor([[-1.8712, -1.6143, -2.3332, -2.3444],
        [-1.7301, -1.7273, -2.4007, -2.4129]])
Validation loss decreased (0.461295 --> 0.448402).  Saving model ...
train
test
Epoch: 73 	Training Loss: 0.2691 	Training Accuracy: 0.9890 	Validation Loss: 0.4632 	Validation Accuracy: 0.9925 	train_kl_div: 42.2176
fc.W_mu tensor([[ 0.6133, -0.8648,  0.0322, -0.0343],
        [-0.7893,  0.7270, -0.0257,  0.0269]])
fc.W_rho tensor([[-1.8562, -1.6011, -2.3319, -2.3442],
        [-1.7150, -1.7119, -2.3982, -2.4096]])
train
test
Epoch: 74 	Training Loss: 0.2553 	Training Accuracy: 0.9920 	Validation Loss: 0.4488 	Validation Accuracy: 0.9925 	train_kl_div: 41.8935
fc.W_mu tensor([[ 0.6202, -0.8719,  0.0308, -0.0349],
        [-0.7965,  0.7338, -0.026

test
Epoch: 96 	Training Loss: 0.2177 	Training Accuracy: 0.9935 	Validation Loss: 0.3962 	Validation Accuracy: 0.9925 	train_kl_div: 37.4838
fc.W_mu tensor([[ 0.7644, -1.0166,  0.0328, -0.0254],
        [-0.9479,  0.8748, -0.0253,  0.0266]])
fc.W_rho tensor([[-1.5833, -1.3492, -2.3174, -2.3266],
        [-1.3765, -1.4619, -2.3536, -2.3507]])
train
test
Epoch: 97 	Training Loss: 0.2291 	Training Accuracy: 0.9880 	Validation Loss: 0.3868 	Validation Accuracy: 0.9925 	train_kl_div: 37.4217
fc.W_mu tensor([[ 0.7712, -1.0225,  0.0286, -0.0240],
        [-0.9550,  0.8805, -0.0254,  0.0258]])
fc.W_rho tensor([[-1.5747, -1.3388, -2.3169, -2.3256],
        [-1.3762, -1.4479, -2.3521, -2.3494]])
train
test
Epoch: 98 	Training Loss: 0.2004 	Training Accuracy: 0.9945 	Validation Loss: 0.3889 	Validation Accuracy: 0.9925 	train_kl_div: 37.3734
fc.W_mu tensor([[ 0.7777, -1.0283,  0.0228, -0.0236],
        [-0.9619,  0.8862, -0.0249,  0.0254]])
fc.W_rho tensor([[-1.5658, -1.3271, -2.3163, -2.3248],
