In [122]:
import autograd
import autograd.numpy as np
import autograd.scipy as sp

import VariationalBayes as vb
from copy import deepcopy

import scipy as osp
from scipy import stats

import matplotlib.pyplot as plt
%matplotlib inline


In [119]:

class Kernel(object):
    """Radial kernel k(x, y) = (c + <x - y, x - y>) ** beta.

    Attributes:
        c: Offset added to the squared distance before exponentiation.
        beta: Exponent applied to the offset squared distance.
    """
    def __init__(self, c=1.0, beta=-0.5):
        self.c, self.beta = c, beta

    def k(self, x, y):
        """Evaluate the kernel at the pair of points (x, y)."""
        delta = x - y
        squared_distance = np.dot(delta, delta)
        return (squared_distance + self.c) ** self.beta

    
class MVNModel(object):
    """Normal observation model with a normal prior on the location theta.

    The log posterior evaluated here is, up to an additive constant,
        -0.5 * (obs - theta)' obs_info_mat (obs - theta)
        -0.5 * (theta - prior_loc)' prior_info_mat (theta - prior_loc)
    where prior_info_mat is fixed to the identity matrix.

    Attributes:
        obs: The observation vector (stored by reference, not copied).
        obs_dim: len(obs).
        obs_info_mat: Information (inverse covariance) matrix of the observation.
        prior_loc: Prior location for theta.
        prior_info_mat: Prior information matrix; the obs_dim identity.
        theta: Current parameter value used by get_log_post.
    """
    def __init__(self, obs, obs_info_mat, prior_loc):
        self.prior_loc = deepcopy(prior_loc)
        self.obs_dim = len(obs)
        self.obs = obs
        self.obs_info_mat = deepcopy(obs_info_mat)
        self.prior_info_mat = np.eye(self.obs_dim)

        self.theta = np.zeros(self.obs_dim)

    def get_log_post(self, broadcast=False):
        """Return the unnormalized log posterior at self.theta.

        Args:
            broadcast: If True, obs and prior_loc get a trailing axis added so
                they broadcast against a self.theta whose leading axis is the
                parameter dimension; the result then has one entry per
                trailing index of self.theta.

        Returns:
            The log likelihood plus the log prior (a scalar when
            broadcast=False and self.theta is a vector).
        """
        if broadcast:
            obs_centered = np.expand_dims(self.obs, 1) - self.theta
            theta_centered = self.theta - np.expand_dims(self.prior_loc, 1)
        else:
            obs_centered = self.obs - self.theta
            theta_centered = self.theta - self.prior_loc

        # The ellipsis carries any trailing (broadcast) axes through the
        # quadratic form unchanged.
        log_lik = -0.5 * np.einsum(
            'i...,ij,j...->...',
            obs_centered, self.obs_info_mat, obs_centered)

        log_prior = -0.5 * np.einsum(
            'i...,ij,j...->...',
            theta_centered, self.prior_info_mat, theta_centered)

        return log_lik + log_prior

    def get_post_cov(self):
        """Return the posterior covariance, inv(obs_info_mat + prior_info_mat)."""
        return np.linalg.inv(self.obs_info_mat + self.prior_info_mat)

    def get_post_mean(self):
        """Return the posterior mean of theta.

        Computed as post_cov @ (obs_info_mat @ obs + prior_info_mat @ prior_loc).
        """
        post_cov = self.get_post_cov()
        # Use the observation stored on the instance, not a module-level `obs`.
        post_suff_stat = np.matmul(self.obs_info_mat, self.obs) + \
                         np.matmul(self.prior_info_mat, self.prior_loc)
        return np.matmul(post_cov, post_suff_stat)

    def set_theta_get_log_post(self, theta):
        """Set self.theta to theta and return get_log_post() at that value."""
        self.theta = theta
        return self.get_log_post()



In [105]:
# Two-dimensional example problem: a single observation whose information
# matrix is 30 * I, with the prior located at the origin.
obs_dim = 2
obs = np.array([0.2, 0.5])
prior_loc = np.zeros(obs_dim)
obs_info_mat = np.eye(obs_dim) * 30
theta = np.full((obs_dim,), 0.2)

mvn_model = MVNModel(
    obs=obs, obs_info_mat=obs_info_mat, prior_loc=prior_loc)

# Show the exact posterior mean next to the raw observation.
print(mvn_model.get_post_mean())
print(obs)



[ 0.19354839  0.48387097]
[ 0.2  0.5]


In [181]:
num_draws = 30
# Draw from the exact posterior. The seed is fixed so that a fresh
# "Restart & Run All" reproduces the same draws.
theta_draws = osp.stats.multivariate_normal.rvs(
    size=num_draws,
    mean=mvn_model.get_post_mean(),
    cov=mvn_model.get_post_cov(),
    random_state=42)
print(theta_draws.shape)

# Monte Carlo summaries of the draws: per-coordinate mean, standard deviation
# and standard error of the mean.
theta_mean = np.mean(theta_draws, axis=0)
theta_sd = np.std(theta_draws, axis=0)
theta_se = theta_sd / np.sqrt(num_draws)

(30, 2)


In [168]:
class SteinGradientGenerator(object):
    """Computes the two terms of a Stein variational update direction.

    For draws t_1, ..., t_n the direction for draw i is
        (1 / n) * sum_j [ k(t_i, t_j) * grad log_post(t_j) ]
      + (1 / n) * sum_j [ grad_{t_j} k(t_j, t_i) ]

    Attributes:
        mvn_model: Model exposing set_theta_get_log_post(theta).
        get_log_post_grad: autograd gradient of mvn_model.set_theta_get_log_post.
        kernel: Object with a method k(x, y) returning a scalar.
        get_kernel_grad: autograd gradient of kernel.k in its first argument.
    """
    def __init__(self, mvn_model, kernel=None):
        """
        Args:
            mvn_model: Model whose set_theta_get_log_post is differentiated.
            kernel: Optional kernel object with a k(x, y) method; a default
                Kernel() is used when omitted.
        """
        self.mvn_model = mvn_model
        self.get_log_post_grad = autograd.grad(self.mvn_model.set_theta_get_log_post)
        self.kernel = Kernel() if kernel is None else kernel
        # Differentiate this instance's kernel, not a module-level one.
        self.get_kernel_grad = autograd.grad(self.kernel.k, argnum=0)

    def get_kernel_mat(self, theta_draws):
        """Return the matrix of kernel values, entry [i, j] = k(t_i, t_j)."""
        return np.array([[self.kernel.k(t1, t2) for t2 in theta_draws]
                         for t1 in theta_draws])

    def get_kernel_grad_mat(self, theta_draws):
        """Return kernel gradients, entry [i, j] = grad of k(t_i, t_j) in t_i."""
        # t1 is in the first index; it is this that we need to sum out for the gradient step.
        return np.array([[self.get_kernel_grad(t1, t2) for t2 in theta_draws]
                         for t1 in theta_draws])

    def get_post_grads(self, theta_draws):
        """Return the log posterior gradient at each draw, one row per draw."""
        return np.array([self.get_log_post_grad(theta) for theta in theta_draws])

    def get_stein_grad_term(self, theta_draws):
        """Return the kernel-weighted average of the log posterior gradients."""
        kernel_mat = self.get_kernel_mat(theta_draws)
        post_grads = self.get_post_grads(theta_draws)
        num_draws = theta_draws.shape[0]
        return np.einsum('ij,jk->ik', kernel_mat, post_grads) / num_draws

    def get_stein_kernel_term(self, theta_draws):
        """Return the kernel-gradient term, averaged over the differentiated draw."""
        num_draws = theta_draws.shape[0]
        kernel_grad_mat = self.get_kernel_grad_mat(theta_draws)
        return np.sum(kernel_grad_mat, axis=0) / num_draws

    def get_stein_direction(self, theta_draws):
        """Return the full direction: gradient term plus kernel term."""
        return self.get_stein_grad_term(theta_draws) + self.get_stein_kernel_term(theta_draws)

In [161]:
kernel = Kernel()
get_kernel_grad = autograd.grad(kernel.k, argnum=0)

# Build the pairwise kernel matrix explicitly so that this cell does not rely
# on a `kernel_mat` left over from an earlier kernel session.
kernel_mat = np.array(
    [[kernel.k(t1, t2) for t2 in theta_draws] for t1 in theta_draws])
print(kernel_mat.shape)
#plt.matshow(kernel_mat); plt.colorbar()

# Entry [i, j] is the gradient of k(t_i, t_j) with respect to t_i.
kernel_grad_mat = np.array(
    [[get_kernel_grad(t1, t2) for t2 in theta_draws] for t1 in theta_draws])
print(kernel_grad_mat.shape)


(30, 30)
(30, 30, 2)


In [240]:
# Wrap mvn_model in a SteinGradientGenerator; this object is what gets passed
# to get_stein_step in the update loop.
stein_gradient_generator = SteinGradientGenerator(mvn_model)

def get_stein_direction(stein_gradient_generator, theta_draws):
    """Return the Stein update direction for theta_draws.

    Thin wrapper around stein_gradient_generator.get_stein_direction. The
    gradient and kernel terms are not evaluated separately here, because the
    generator already computes both when building the direction.

    Args:
        stein_gradient_generator: Object with a get_stein_direction(theta_draws)
            method.
        theta_draws: The current draws, passed through unchanged.

    Returns:
        Whatever stein_gradient_generator.get_stein_direction returns.
    """
    return stein_gradient_generator.get_stein_direction(theta_draws)

def get_stein_step(stein_gradient_generator, theta_draws, start_scale=1.0, max_tries=50):
    """Take one backtracking step along the Stein direction.

    The base step size stein_step_eps is chosen so that the step moves the
    mean of the draws by at most one standard error in every coordinate.
    Starting from start_scale, the multiplier on that base step is halved
    until the norm of the Stein direction at the updated draws is strictly
    smaller than at the current draws.

    Args:
        stein_gradient_generator: Object with a get_stein_direction(theta_draws)
            method returning an array shaped like theta_draws.
        theta_draws: Array of draws with one draw per row.
        start_scale: Initial multiplier on the base step size.
        max_tries: Maximum number of step sizes to try before giving up.

    Returns:
        (theta_update, stein_direction, scale, stein_step_eps). theta_update
        is the accepted update, or theta_draws itself when no tried step
        reduced the direction norm. stein_direction is the direction at the
        input draws. scale is the accepted multiplier, or the last halved
        value when no step was accepted.
    """
    stein_direction = stein_gradient_generator.get_stein_direction(theta_draws)

    num_draws = theta_draws.shape[0]
    theta_se = np.std(theta_draws, axis=0) / np.sqrt(num_draws)

    # This is the step size that changes the mean by no more than one standard
    # error in any coordinate. Coordinates whose mean does not move place no
    # constraint on the step, so they are left out of the ratio.
    mean_shift = np.abs(np.mean(stein_direction, axis=0))
    moving = mean_shift > 0.
    if np.any(moving):
        stein_step_eps = np.min(theta_se[moving] / mean_shift[moving])
    else:
        # The mean is unaffected by any step; fall back to a unit base step
        # and let the backtracking below pick the scale.
        stein_step_eps = 1.0

    grad_mag = np.linalg.norm(stein_direction)
    scale = start_scale
    for _ in range(max_tries):
        theta_update = theta_draws + scale * stein_step_eps * stein_direction
        new_direction = stein_gradient_generator.get_stein_direction(theta_update)
        new_grad_mag = np.linalg.norm(new_direction)
        grad_mag_diff = new_grad_mag - grad_mag
        # Only a strict decrease is accepted; a NaN difference fails this
        # comparison and is treated as a rejected step.
        if grad_mag_diff < 0.:
            print('Accepting step from ', grad_mag, ' to ', new_grad_mag)
            return theta_update, stein_direction, scale, stein_step_eps
        scale *= 0.5
        print('Difference: ', grad_mag_diff, ' decreasing scale to ', scale)

    print('No improving step found after ', max_tries, ' tries; keeping current draws.')
    return theta_draws, stein_direction, scale, stein_step_eps


In [241]:

true_mean = mvn_model.get_post_mean()
# Initialise the particles with draws from the exact posterior. The seed is
# fixed so that re-running the notebook reproduces the same trajectory.
theta_draws = osp.stats.multivariate_normal.rvs(
    size=50,
    mean=true_mean,
    cov=mvn_model.get_post_cov(),
    random_state=0)
init_theta_draws = deepcopy(theta_draws)

# Run 20 Stein steps. Each step starts its search at 1.5 times the scale that
# was accepted on the previous step.
last_scale = 1.0
for step in range(20):
    theta_draws, stein_direction, last_scale, stein_step_eps = \
        get_stein_step(stein_gradient_generator, theta_draws, start_scale=1.5 * last_scale)
    print(np.linalg.norm(stein_direction))
    #print(np.mean(theta_draws, axis=0) - true_mean)

print('Done.')



Accepting step from  4.25636441352  to  1.75647844211
4.25636441352
Difference:  2.50305116805  decreasing scale to  0.5
Accepting step from  1.75647844211  to  1.27412998101
1.75647844211
Difference:  0.485066657232  decreasing scale to  0.25
Accepting step from  1.27412998101  to  0.315341478169
1.27412998101
Difference:  0.95273218818  decreasing scale to  0.125
Difference:  0.220691123654  decreasing scale to  0.0625
Accepting step from  0.315341478169  to  0.227029449171
0.315341478169
Difference:  0.0698183332681  decreasing scale to  0.03125
Accepting step from  0.227029449171  to  0.185081795535
0.227029449171
Difference:  0.0163459230463  decreasing scale to  0.015625
Accepting step from  0.185081795535  to  0.168569262437
0.185081795535
Accepting step from  0.168569262437  to  0.15911017447
0.168569262437
Accepting step from  0.15911017447  to  0.142865932185
0.15911017447
Accepting step from  0.142865932185  to  0.137192973216
0.142865932185
Accepting step from  0.1371929732

In [226]:
# Covariances, in order: draws after the Stein updates, the initial draws,
# and the exact posterior covariance.
print(
    np.cov(theta_draws.T),
    '\n-----\n',
    np.cov(init_theta_draws.T),
    '\n-----\n',
    mvn_model.get_post_cov())

print('\n\n\n\n')

# Means, in the same order as the covariances above.
print(
    np.mean(theta_draws, axis=0),
    '\n-----\n',
    np.mean(init_theta_draws, axis=0),
    '\n-----\n',
    mvn_model.get_post_mean())


[[ 0.02638812  0.00110151]
 [ 0.00110151  0.02600595]] 
-----
 [[ 0.02603235  0.00114453]
 [ 0.00114453  0.02561276]] 
-----
 [[ 0.03225806  0.        ]
 [ 0.          0.03225806]]





[ 0.19872657  0.48233941] 
-----
 [ 0.21501743  0.47621199] 
-----
 [ 0.19354839  0.48387097]
