In [1]:
%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from edward.models import Normal

# Render every matplotlib figure in this notebook with the ggplot look.
plt.style.use(style='ggplot')
def build_toy_dataset(N, w, noise_std=0.01, x_std=2.0):
  """Simulate a noisy linear-regression dataset.

  Each feature is drawn i.i.d. from Normal(0, x_std**2). Each target is
  ``x @ w`` plus i.i.d. Normal(0, noise_std**2) noise.

  Args:
    N: number of data points.
    w: true weight vector of length D (array-like).
    noise_std: standard deviation of the additive target noise
      (default 0.01, the value previously hard-coded).
    x_std: standard deviation of each feature
      (default 2.0, the value previously hard-coded).

  Returns:
    (x, y) where x has shape (N, D) and y has shape (N,).
  """
  D = len(w)
  # Draw x before the noise, as before, so a fixed NumPy seed reproduces the
  # original dataset when the defaults are used.
  x = np.random.normal(0.0, x_std, size=(N, D))
  y = np.dot(x, w) + np.random.normal(0.0, noise_std, size=N)
  return x, y


# Seed Edward's randomness before any data is drawn or any graph op is built.
ed.set_seed(42)

# N: rows in each split (train and test); D: number of features.
N, D = 40, 10

# One ground-truth weight vector shared by both splits.
w_true = 0.5 * np.random.randn(D)
# Two independent draws from the same generator: the train split first, then the test split.
(X_train, y_train), (X_test, y_test) = [
    build_toy_dataset(N, w_true) for _split in ("train", "test")
]
# Model: Bayesian linear regression with standard-normal priors.
# The inputs come through a placeholder with a fixed [N, D] shape. Any data fed
# later (the training or the test split) must therefore have exactly N rows and D columns.
X = tf.placeholder(tf.float32, [N, D])
# Priors: Normal(0, 1) on each of the D weights and on the shape-[1] bias.
w = Normal(loc=tf.zeros(D), scale=tf.ones(D))
b = Normal(loc=tf.zeros(1), scale=tf.ones(1))
# Likelihood: one Normal per row with unit scale. The shape-[1] bias is added
# to every entry of ed.dot(X, w).
# NOTE(review): the toy data's noise std is 0.01, but the likelihood scale is
# fixed at 1. The model therefore assumes far more observation noise than the
# data actually has. Confirm this is intentional.
y = Normal(loc=ed.dot(X, w) + b, scale=tf.ones(N))


# Variational family: independent Normals for w and b.
# Each location and pre-softplus scale is a free tf.Variable, randomly
# initialised. softplus keeps the resulting scales positive.
qw = Normal(loc=tf.Variable(tf.random_normal([D])),
            scale=tf.nn.softplus(tf.Variable(tf.random_normal([D]))))
qb = Normal(loc=tf.Variable(tf.random_normal([1])),
            scale=tf.nn.softplus(tf.Variable(tf.random_normal([1]))))

# KLqp fits (qw, qb) as approximations to the posterior of (w, b), conditioning
# on the training split fed into X and observed for y.
inference = ed.KLqp({w: qw, b: qb}, data={X: X_train, y: y_train})
# n_iter: number of optimisation steps.
# n_samples: presumably the number of variational samples per gradient
# estimate (per Edward's KLqp docs; verify).
inference.run(n_samples=5, n_iter=250)


# Posterior predictive: a copy of the likelihood y in which the priors w and b
# are swapped for the fitted approximations qw and qb.
y_post = ed.copy(y, {w: qw, b: qb})
# This is equivalent to
# y_post = Normal(loc=ed.dot(X, qw) + qb, scale=tf.ones(N))


# Score the posterior predictive on the held-out split.
# Each (heading, metric) pair prints its heading, then the result of its own
# ed.evaluate call, in the order listed.
for heading, metric in [
    ("Mean squared error on test data:", 'mean_squared_error'),
    ("Mean absolute error on test data:", 'mean_absolute_error'),
]:
  print(heading)
  print(ed.evaluate(metric, data={X: X_test, y_post: y_test}))




250/250 [100%] ██████████████████████████████ Elapsed: 7s | Loss: 66.347
Mean squared error on test data:
0.0480191
Mean absolute error on test data:
0.187246


In [6]:
# Draw a single posterior-predictive sample at the training inputs
# (shown as the cell's output).
train_feed = {X: X_train}
y_post.eval(feed_dict=train_feed)

array([-5.90990019, -1.124053  , -3.24419236, -0.86554849, -0.93430269,
       -2.3821857 ,  4.02491379,  3.15348768,  0.71714246,  2.09037328,
        0.50140125, -1.01154912, -3.33019471, -2.00523186,  4.79771423,
        0.38504279,  1.236485  ,  1.36871123, -3.50240827,  5.2772994 ,
        0.44167721, -0.67191738, -3.32381153,  0.17984629,  3.18410587,
       -3.24864459,  1.6402235 ,  0.1941148 ,  4.48797989,  1.08562958,
        2.20487332,  4.55905104, -2.43331051, -4.52374411, -0.96808124,
        0.35039425,  1.8222506 , -0.39504093,  0.36485246, -2.22324634], dtype=float32)