In [1]:
import torch
import numpy as np
from hsvi.pytorch import Hierarchy_SVI

In [2]:
from torch import nn
from torch.nn import Parameter
from torch.distributions import Normal, OneHotCategorical
import torch.nn.functional as F

In [3]:
from torch.utils.data import Dataset,Subset
from torchvision import datasets
from torchvision.transforms import ToTensor,Lambda,Compose
from torch.utils.data import DataLoader

In [4]:
class Bayesian_Linear(nn.Module):
    def __init__(self,d1,d2,bias=True):
        super(Bayesian_Linear, self).__init__()
        self.bias = bias
        self.w_loc = torch.normal(torch.zeros([d1,d2]),torch.ones([d1,d2])*0.001)
        self.w_loc = Parameter(self.w_loc)
        self.w_logv = torch.ones([d1,d2])*-3
        self.w_logv = Parameter(self.w_logv)
        self.w = Normal(self.w_loc,torch.exp(self.w_logv))
        self.w_prior = Normal(loc=torch.zeros([d1,d2]),scale=torch.ones([d1,d2])) 
        
        if bias:
            self.b_loc = torch.normal(torch.zeros([d2]),torch.ones([d2])*0.001)
            self.b_loc = Parameter(self.b_loc)
            self.b_logv = torch.ones([d2])*-3
            self.b_logv = Parameter(self.b_logv)
            self.b = Normal(self.b_loc,torch.exp(self.b_logv))
            self.b_prior = Normal(loc=torch.zeros([d2]),scale=torch.ones([d2]))
            
    def forward(self,x):
        h = torch.matmul(x,self.w.rsample())
        if self.bias:
            h += self.b.rsample()
        return h

In [5]:
class Bayesian_MLP(nn.Module):
    def __init__(self,net_shape,ac_fn=nn.ReLU):
        super(Bayesian_MLP, self).__init__()
        self.net_shape = net_shape
        self.ac_fn = ac_fn
        self.net = self._build_net()
    
    def _build_net(self):
        modules = []
        for i in range(len(self.net_shape)-1):
            print('conf layer {}'.format(i))
            d1 = self.net_shape[i]
            d2 = self.net_shape[i+1]
            modules.append(Bayesian_Linear(d1,d2))
            if i != len(net_shape) - 2:
                modules.append(self.ac_fn())
        return nn.Sequential(*modules)
    
    def forward(self,x):
        return self.net(x)
        

In [6]:
def config_inference(model,TRAIN_SIZE,vi_type='KLqp',learning_rate=0.001,scale=1.):
    
    
    ### config variational inference ###
    inference = Hierarchy_SVI(vi_types={'global':vi_type},var_dict={'global':model.parameters()},learning_rate={'global':0.001},train_size=TRAIN_SIZE,scale={'global':{}})        


    return inference

In [7]:
train_size = 50000
test_size = 10000
batch_size = 256
epoch = 50
hidden = [100,100]
use_cuda = False
lr = 0.01

In [8]:
device = torch.device('cuda' if use_cuda else 'cpu')

In [9]:
X_TRAIN = datasets.MNIST(
    root="/home/yu/gits/data/mnist/",
    train=True,
    download=False,
    transform=Compose([ToTensor(), Lambda(lambda x: torch.flatten(x))]),
    #target_transform=Compose([
    #                              lambda x:torch.LongTensor([x]), 
    #                                lambda x:F.one_hot(x,10)])
)
indices = torch.arange(train_size)
X_TRAIN = Subset(X_TRAIN, indices)

X_TEST = datasets.MNIST(
    root="/home/yu/gits/data/mnist/",
    train=False,
    download=False,
    transform=Compose([ToTensor(), Lambda(lambda x: torch.flatten(x))]),
    #target_transform=Compose([
    #                             lambda x:torch.LongTensor([x]), 
    #                             lambda x:F.one_hot(x,10)])
)
indices = torch.arange(test_size)
X_TEST = Subset(X_TEST, indices)

In [10]:
train_dataloader = DataLoader(X_TRAIN, batch_size=batch_size, shuffle=True,num_workers=8)
test_dataloader = DataLoader(X_TEST, batch_size=batch_size, shuffle=True,num_workers=8)

In [11]:
### config net shape ###
d_dim = 784
out_dim = 10
net_shape = [d_dim] + hidden + [out_dim]

In [12]:
model = Bayesian_MLP(net_shape)
model.to(device)

conf layer 0
conf layer 1
conf layer 2


Bayesian_MLP(
  (net): Sequential(
    (0): Bayesian_Linear()
    (1): ReLU()
    (2): Bayesian_Linear()
    (3): ReLU()
    (4): Bayesian_Linear()
  )
)

In [13]:
pqz = {}
for module in model.modules():
    if isinstance(module,Bayesian_Linear):
        pqz.update({module.w_prior:module.w,module.b_prior:module.b})

In [14]:
inference = config_inference(model,train_size,learning_rate=lr,vi_type='KLqp')
inference.latent_vars = {'global':pqz}

start init hsvi
global KLqp


In [15]:
### training process ###

for e in range(epoch):
    
    for i,(x_batch,y_batch) in enumerate(train_dataloader): 
        x_batch.to(device)
        y_batch.to(device)
        
        ### use to compute the liklihood of raw data ###        
        ll = F.cross_entropy(model(x_batch),y_batch) 
        
        inference.data = {'global':{}}
        inference.extra_loss = {'global':ll}

        loss = inference.update('global',retain_graph=True)
    if (e+1)%10==0:
        print('epoch {} loss {}'.format(e+1, loss))

  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag


epoch 10 loss 4.550666332244873
epoch 20 loss 4.528886318206787
epoch 30 loss 4.500082969665527
epoch 40 loss 4.497791767120361
epoch 50 loss 4.5003461837768555


In [16]:
### test process ###
correct = 0
for i,(x_batch,y_batch) in enumerate(test_dataloader): 
    x_batch.to(device)
    y_batch.to(device)
    py = torch.argmax(model(x_batch),1)
    #print((py==y_batch).sum())
    #correct += (py==y_batch).sum()
    correct += (py==y_batch).sum()
acc = correct/test_size
print('accuracy is {}'.format(acc))

accuracy is 0.9747999906539917
