In [1]:
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import numpy as np
from hsvi.tensorflow import Hierarchy_SVI
from utils.distributions import Normal, OneHotCategorical
from utils.train_util import get_next_batch

In [2]:
from tensorflow.examples.tutorials.mnist import input_data

### Example of training a Bayesian MLP model by variational inference

In [3]:
class Bayesian_MLP:
    def __init__(self,x,net_shape,learning_rate=0.001,num_samples=1,ac_fn=tf.nn.relu):
        self.x = x
        self.net_shape = net_shape
        self.num_samples = num_samples # number of samples of parameters
        self.ac_fn = ac_fn
        self._build_net()
        self._conf_opt(learning_rate)
        
    def _build_net(self):
        self.H,self.W,self.B,self.parm_var = [], [],[],{}
        
        ### expand input dimensions ###
        h = tf.expand_dims(self.x,axis=0)
        h = tf.tile(h,[self.num_samples,1,1])
        
        ### define variables ###
        with tf.variable_scope('global'):
            for i in range(len(self.net_shape)-1):
                ### define variational parameters ###
                print('conf layer {}'.format(i))
                d1 = self.net_shape[i]
                d2 = self.net_shape[i+1]
                w_loc_var = tf.get_variable(dtype=tf.float32, initializer=tf.random_normal([d1,d2],stddev=0.001),name='l'+str(i)+'_w_loc')
                w_s_var = tf.get_variable(dtype=tf.float32, initializer=tf.ones([d1,d2])*-3.,name='l'+str(i)+'_w_scale')
                b_loc_var = tf.get_variable(dtype=tf.float32, initializer=tf.random_normal([d2],stddev=0.001),name='l'+str(i)+'_b_loc')
                b_s_var = tf.get_variable(dtype=tf.float32, initializer=tf.ones([d2])*-3.,name='l'+str(i)+'_b_scale')
                w = Normal(loc=w_loc_var,scale=tf.exp(w_s_var))
                b = Normal(loc=b_loc_var,scale=tf.exp(b_s_var))
                self.W.append(w)
                self.B.append(b)
                self.parm_var[w] = [w_loc_var,w_s_var]
                self.parm_var[b] = [b_loc_var,b_s_var]

                ### sample parameters to compute output ###
                ew = w.sample(self.num_samples)
                eb = b.sample(self.num_samples)
                z = tf.einsum('sbi,sij->sbj',h,ew)+tf.expand_dims(eb,1)
                if i != len(self.net_shape) - 2:
                    h = self.ac_fn(z)
                else:
                    h = OneHotCategorical(logits=z)

                self.H.append(h)
            
    def _conf_opt(self,learning_rate):
        ### config optimizer ###
        with tf.variable_scope('global'):
            step = tf.Variable(0, trainable=False, name='global_step')                                      
            self.optimizer = (tf.train.AdamOptimizer(learning_rate,beta1=0.9),step)
      
    def forward(self,x,sess):
        h = x
        for l in range(len(self.W)):
            w = sess.run(self.W[l].loc)
            b = sess.run(self.B[l].loc)
            h = tf.add(tf.matmul(h,w),b)
            if l != len(self.W)-1:
                h = self.ac_fn(h)
            else:
                h = OneHotCategorical(logits=h)
        return h

In [4]:
def config_inference(model,Y,TRAIN_SIZE,vi_type='KLqp',scale=1.):
    
    ### config priors for parameters ###
    prior_cfg = {}
    for qw in model.W+model.B:
        pw = Normal(loc=tf.zeros_like(qw),scale=tf.ones_like(qw))
        prior_cfg[pw] = qw        
        
    ### config variational inference ###
    inference = Hierarchy_SVI(latent_vars={'global':prior_cfg},data={'global':{model.H[-1]:Y}},vi_types={'global':vi_type},scale={model.H[-1]:scale},optimizer={'global':model.optimizer},train_size=TRAIN_SIZE)        

    return inference

In [5]:
train_size = 50000
test_size = 10000
batch_size = 256
epoch = 50
hidden = [100,100]
num_samples = 1 # number of samples of parameters

In [6]:
### load data ###
DATA_DIR = '/home/yu/gits/data/mnist/'
data = input_data.read_data_sets(DATA_DIR,one_hot=True)
X_TRAIN = data.train.images[:train_size]
Y_TRAIN = data.train.labels[:train_size]
X_TEST = data.test.images[:test_size]
Y_TEST = data.test.labels[:test_size]

Extracting /home/yu/gits/data/mnist/train-images-idx3-ubyte.gz
Extracting /home/yu/gits/data/mnist/train-labels-idx1-ubyte.gz
Extracting /home/yu/gits/data/mnist/t10k-images-idx3-ubyte.gz
Extracting /home/yu/gits/data/mnist/t10k-labels-idx1-ubyte.gz


In [7]:
### config net shape ###
in_dim = X_TRAIN.shape[1]
out_dim = Y_TRAIN.shape[1]
net_shape = [in_dim]+hidden+[out_dim]

### config data input ###
x_ph = tf.placeholder(dtype=tf.float32,shape=[None,in_dim])
y_ph = tf.placeholder(dtype=tf.float32,shape=[num_samples,None,out_dim])

In [8]:
### define model ###
model = Bayesian_MLP(x=x_ph,net_shape=net_shape,num_samples=num_samples)
inference = config_inference(model,y_ph,train_size)

conf layer 0
conf layer 1
conf layer 2
start init hsvi
global KLqp
config optimizer in scope global


In [9]:
### train process ###
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
tf.global_variables_initializer().run(session=sess)
for e in range(epoch):
    shuffle_inds = np.arange(X_TRAIN.shape[0])
    np.random.shuffle(shuffle_inds)
    x_train = X_TRAIN[shuffle_inds]
    y_train = Y_TRAIN[shuffle_inds]
    ii = 0
    num_iter = int(np.ceil(x_train.shape[0]/batch_size))
    for _ in range(num_iter):
        x_batch,y_batch,ii = get_next_batch(x_train,batch_size,ii,labels=y_train)
        y_batch = np.expand_dims(y_batch,axis=0)
        y_batch = np.repeat(y_batch,num_samples,axis=0)

        feed_dict = {x_ph:x_batch,y_ph:y_batch}
        info_dict = inference.update(scope='global',feed_dict=feed_dict,sess=sess)
    if (e+1)%10==0:
        print('epoch {} loss {}'.format(e+1, info_dict['loss']))

epoch 10 loss 2.6054372787475586
epoch 20 loss 1.7771762609481812
epoch 30 loss 1.3696476221084595
epoch 40 loss 1.1777886152267456
epoch 50 loss 1.039338231086731


In [10]:
### test process ###
ty = model.forward(x_ph,sess)
y_pred_prob = sess.run(ty,feed_dict={x_ph:X_TEST})
y_pred = np.argmax(y_pred_prob,axis=1)
correct = np.sum(np.argmax(Y_TEST,axis=1)==y_pred)
acc = correct/Y_TEST.shape[0]
print('accuracy is {}'.format(acc))

accuracy is 0.9728
