In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

In [2]:
import edward as ed
import numpy as np
import tensorflow as tf
import six
import os

In [3]:
from sklearn.linear_model import LogisticRegression

In [4]:
from edward.models import Normal,Empirical,Bernoulli,Categorical
from tensorflow.contrib import slim
from tensorflow.contrib.keras.api.keras.layers import Dense
from tensorflow.examples.tutorials.mnist import input_data

In [5]:
import matplotlib.pyplot as plt
import train_utils as util

In [6]:
%matplotlib inline

In [7]:
ed.set_seed(7)

In [8]:
mtype = 'bpm'
iftype = 'VI'
dataset = 'MNIST'
zero_constraint = False
test_logistic_regression = True

In [9]:
N = 20000 # number of training data points
N_test = 5000 # number of testing data points

noise_std = .01
B = 100 # batch size during training

if dataset == 'MNIST':
    DATA_DIR = "../data/mnist"
    IMG_DIR = "img"
    mnist = input_data.read_data_sets(DATA_DIR)
    X = mnist.train.images[:N]
    Y = mnist.train.labels[:N]
    X_test = mnist.test.images[:N_test]
    Y_test = mnist.test.labels[:N_test]
    M = X.shape[1]+1
    H = 10
    #normalise
    X = (X - X.mean())/X.std()
elif dataset == 'synthesis': 
    M = 2 # number of features
    H = 2 # number of classes
    
    s_mean = 0.
    s_std = 1.
    d_mean = 0.
    d_std = 3.
    
    Y,X = util.build_toy_dataset(mtype,N,M-1,H,s_std,s_mean,d_std)
    Y_test,X_test = util.build_toy_dataset(mtype,N_test,M-1,H,s_std,s_mean,d_std)
    

Extracting ../data/mnist/train-images-idx3-ubyte.gz
Extracting ../data/mnist/train-labels-idx1-ubyte.gz
Extracting ../data/mnist/t10k-images-idx3-ubyte.gz
Extracting ../data/mnist/t10k-labels-idx1-ubyte.gz


In [10]:
if test_logistic_regression:
    lgr = LogisticRegression(n_jobs=8,fit_intercept=True)
    lgr.fit(X,Y)
    print('Logistic regression train accuracy: ',sum(lgr.predict(X)==Y)/N)
    print('Logistic regression test accuracy: ',sum(lgr.predict(X_test)==Y_test)/N_test)

Logistic regression train accuracy:  0.9461
Logistic regression test accuracy:  0.7998


In [11]:
X = np.hstack((X,np.ones((N,1))))
X_test = np.hstack((X_test,np.ones((X_test.shape[0],1))))

In [12]:
X.shape

(20000, 785)

In [13]:
X = X.astype(np.float32)
#Y = Y.astype(np.float32)

x_ph = tf.placeholder(tf.float32, [B,M])
y_ph = tf.placeholder(tf.int32,[B])
y_ph_ohe = tf.placeholder(tf.float32,[B,H]) 

In [14]:
# model
if zero_constraint:
    w = Normal(tf.zeros([H-1,M]),tf.ones([H-1,M]))
    
    if iftype == 'HMC':
        qw = Empirical(params=tf.Variable(tf.random_normal([B,H-1,M])))
    else:
        qw = Normal(tf.Variable(tf.random_normal([H-1,M])), tf.nn.softplus(tf.Variable(tf.random_normal([H-1,M]))))

    y = Categorical(tf.nn.softmax(Normal(tf.concat([tf.matmul(x_ph,tf.transpose(w)),tf.zeros([B,1])],axis=1), noise_std)))
    y_test = tf.nn.softmax(tf.concat([tf.matmul(x_ph,tf.transpose(qw.loc)),tf.zeros([B,1])],axis=1))
else:

    w = Normal(tf.zeros([H,M]),tf.ones([H,M]))

    if iftype == 'HMC':
        qw = Empirical(params=tf.Variable(tf.random_normal([B,H,M])))
    else:
        qw = Normal(tf.Variable(tf.random_normal([H,M])), tf.nn.softplus(tf.Variable(tf.random_normal([H,M]))))

    y = Categorical(tf.nn.softmax(Normal(tf.matmul(x_ph,tf.transpose(w)), noise_std)))
    y_test = tf.nn.softmax(tf.matmul(x_ph,tf.transpose(qw.loc)))


In [15]:
# inference
scaling = float(N) / B
nprint = 1000
niter = 20000

if iftype == 'EP':
    inference = ed.KLpq({w:qw},data={y:y_ph})
elif iftype == 'VI':
    inference = ed.KLqp({w:qw},data={y:y_ph})
elif iftype == 'HMC':
    inference = ed.HMC({w:qw},data={y:y_ph})
else:
    print('invalid inference type')
    
inference.initialize(n_iter=niter,n_print=nprint,scale={y:scaling})


In [16]:
sess = ed.get_session()
tf.global_variables_initializer().run()

In [17]:
ii = 0
sess = ed.get_session()
for t in range(niter):
    x_batch,y_batch,ii = util.get_next_batch(X,B,ii,Y)
  
    info_dict = inference.update(feed_dict={x_ph:x_batch,y_ph:y_batch})
    inference.print_progress(info_dict)
    
    if t % nprint == 0:
        print('\n w mean:')
        print(sess.run(qw))

    1/20000 [  0%]                                ETA: 5259s | Loss: 53990.281

  if labels!=None:



 w mean:
[[  2.67822266e-01  -1.51020789e+00  -1.17865455e+00 ...,  -1.84761047e-01
   -6.00543320e-02   2.62241697e+00]
 [ -6.88719392e-01   3.66260767e-01  -7.26246953e-01 ...,  -2.00679526e-03
    7.05349445e-03   2.45263249e-01]
 [  3.59947920e-01   2.20906568e+00  -1.11120105e+00 ...,   1.89409447e+00
   -8.55887771e-01  -1.10888946e+00]
 ..., 
 [  8.41239452e-01  -6.53325200e-01  -4.13533211e-01 ...,  -2.44671404e-02
    3.94323897e+00   2.75328219e-01]
 [  1.64461064e+00   1.09980273e+00   1.27512217e-01 ...,   6.00731850e-01
    7.01308250e-04  -5.82185626e-01]
 [  3.35248321e-01   1.60996556e-01  -1.01432458e-01 ...,   1.39070201e+00
    7.07330346e-01  -3.94003719e-01]]


  if labels == None:


 1000/20000 [  5%] █                              ETA: 25s | Loss: 38683.613
 w mean:
[[-1.21243393  0.17324366  0.60993636 ..., -2.18824697 -0.8407042
   0.3460393 ]
 [ 0.46579108 -0.49815896 -0.10677464 ..., -1.54218745 -0.53067887
  -0.29770374]
 [-1.38680172  0.67412698 -1.00517583 ...,  1.96893454 -0.54159641
   1.56141317]
 ..., 
 [ 1.39497411  1.95443225 -0.132515   ..., -1.60168505  1.98014581
   2.77921438]
 [ 2.46722293  0.80048126 -1.03170204 ..., -0.34431851 -1.21401703
   0.01540397]
 [ 0.23285615 -0.27904716 -0.12100218 ..., -0.4286074   0.18486455
  -0.55285728]]
 2000/20000 [ 10%] ███                            ETA: 22s | Loss: 35170.191
 w mean:
[[ 1.1341182  -0.28098479  1.03415859 ..., -0.4112249  -0.51607668
  -1.52690589]
 [ 0.42127228 -0.93755102 -0.26536256 ..., -0.81387937  0.54299688
   1.36934638]
 [-1.10966778 -1.67036104  0.4514921  ..., -0.97145045 -0.32489905
   0.09871459]
 ..., 
 [-0.76343101 -1.23990417 -0.68970513 ..., -0.44924581  0.41307643
   1.8978

15000/20000 [ 75%] ██████████████████████         ETA: 5s | Loss: 35091.758
 w mean:
[[ 0.14566132 -0.65791881  0.03277641 ...,  0.76480663  0.18264417
   0.38647068]
 [ 0.89844805  1.20096481 -0.7505306  ...,  0.47410506  0.99964255
   0.42704001]
 [-1.54161704 -0.15599906 -1.72597766 ..., -0.43599355 -0.42263824
   0.11891308]
 ..., 
 [ 0.5641728  -1.35753286 -0.55483788 ...,  1.16626787  0.01191663
  -0.12132458]
 [ 1.20249116  0.33333489 -0.50969553 ...,  1.28476775 -0.47734132
  -1.79122746]
 [ 0.69181776  0.07064035 -0.66236871 ..., -0.43659431  1.93525434
  -0.66263449]]
16000/20000 [ 80%] ████████████████████████       ETA: 4s | Loss: 33927.852
 w mean:
[[-1.14550257  2.46165872  0.77798581 ..., -0.45457432  1.3382113
  -0.04069033]
 [ 0.90337908 -0.34321791 -1.23148751 ...,  0.89406478 -0.99558806
  -0.19531184]
 [-0.80012143 -1.04712629 -0.27539328 ...,  1.93954539  1.97658396
  -1.19280612]
 ..., 
 [ 0.2288304  -0.75238866  0.59764934 ..., -0.07961151 -0.68861341
   0.489621

In [18]:
ii = 0
acu = 0

for i in range(int(np.floor(N_test/B))):
    x_batch,y_batch,ii = util.get_next_batch(X_test,B,ii,Y_test)
    y_test_batch = sess.run(y_test,feed_dict={x_ph:x_batch,y_ph:y_batch})
    acu += sum(np.argmax(y_test_batch,axis=1)==y_batch)
print('Test accuracy: ', acu*1./N_test)

Test accuracy:  0.8066


In [19]:
ii = 0
acu = 0

for i in range(int(np.floor(N/B))):
    x_batch,y_batch,ii = util.get_next_batch(X,B,ii,Y)
    y_test_batch = sess.run(y_test,feed_dict={x_ph:x_batch,y_ph:y_batch})
    acu += sum(np.argmax(y_test_batch,axis=1)==y_batch)
print('Train accuracy: ', acu*1./N)

Train accuracy:  0.9057
