In [122]:
import autograd
import autograd.numpy as np
import autograd.scipy as sp

import VariationalBayes as vb
from copy import deepcopy

import scipy as osp
from scipy import stats

import matplotlib.pyplot as plt
%matplotlib inline


In [119]:

class Kernel(object):
    """Kernel of the form k(x, y) = (c + ||x - y||^2) ** beta.

    With the default beta = -0.5 this is the inverse multiquadric kernel.
    """
    def __init__(self, c=1.0, beta=-0.5):
        # Offset added to the squared distance before exponentiation.
        self.c = c
        # Exponent applied to (c + squared distance).
        self.beta = beta

    def k(self, x, y):
        """Evaluate the kernel at the pair of points (x, y)."""
        delta = x - y
        sq_dist = np.dot(delta, delta)
        base = self.c + sq_dist
        return base ** self.beta

    
class MVNModel(object):
    """Gaussian observation model with a Gaussian prior on its location.

    The (unnormalized) log posterior evaluated by ``get_log_post`` is

        -0.5 (obs - theta)^T obs_info_mat (obs - theta)
        -0.5 (theta - prior_loc)^T prior_info_mat (theta - prior_loc),

    so the posterior is Gaussian with precision
    ``obs_info_mat + prior_info_mat``.

    Args:
        obs: observation vector of length ``obs_dim``.
        obs_info_mat: (obs_dim, obs_dim) information (precision) matrix of
            the observation.
        prior_loc: prior location, length ``obs_dim``.
        prior_info_mat: optional (obs_dim, obs_dim) prior information
            matrix; defaults to the identity.
    """
    def __init__(self, obs, obs_info_mat, prior_loc, prior_info_mat=None):
        # Copy every input so later mutation of the caller's arrays cannot
        # silently change the model.
        self.prior_loc = deepcopy(prior_loc)
        self.obs_dim = len(obs)
        self.obs = deepcopy(obs)
        self.obs_info_mat = deepcopy(obs_info_mat)
        if prior_info_mat is None:
            self.prior_info_mat = np.eye(self.obs_dim)
        else:
            self.prior_info_mat = deepcopy(prior_info_mat)

        # Current parameter value; read by get_log_post().
        self.theta = np.zeros(self.obs_dim)

    def get_log_post(self, broadcast=False):
        """Return the unnormalized log posterior at ``self.theta``.

        Args:
            broadcast: if False, ``self.theta`` is a vector of length
                ``obs_dim``. If True, ``self.theta`` is expected to carry the
                parameter dimension on axis 0 (e.g. shape (obs_dim, n));
                ``obs`` and ``prior_loc`` are expanded along axis 1 and one
                value is returned per trailing index.
        """
        if broadcast:
            obs_centered = np.expand_dims(self.obs, 1) - self.theta
            theta_centered = self.theta - np.expand_dims(self.prior_loc, 1)
        else:
            obs_centered = self.obs - self.theta
            theta_centered = self.theta - self.prior_loc

        # Quadratic forms contracted over the parameter axis; '...' keeps any
        # trailing (broadcast) axes.
        log_lik = -0.5 * np.einsum(
            'i...,ij,j...->...',
            obs_centered, self.obs_info_mat, obs_centered)

        log_prior = -0.5 * np.einsum(
            'i...,ij,j...->...',
            theta_centered, self.prior_info_mat, theta_centered)

        return log_lik + log_prior

    def get_post_cov(self):
        """Return the posterior covariance, inv(obs_info_mat + prior_info_mat)."""
        return np.linalg.inv(self.obs_info_mat + self.prior_info_mat)

    def get_post_mean(self):
        """Return the posterior mean.

        Computed as post_cov @ (obs_info_mat @ obs + prior_info_mat @ prior_loc).
        """
        post_cov = self.get_post_cov()
        # Use the model's own observation, never a same-named global.
        post_suff_stat = (np.matmul(self.obs_info_mat, self.obs) +
                          np.matmul(self.prior_info_mat, self.prior_loc))
        return np.matmul(post_cov, post_suff_stat)

    def set_theta_get_log_post(self, theta):
        """Store ``theta`` on the model and return the log posterior there.

        Side effect: overwrites ``self.theta``. When this method is
        differentiated (e.g. with autograd), the stored value is whatever
        object the differentiator passed in — presumably a traced array;
        reset ``self.theta`` before relying on it afterwards.
        """
        self.theta = theta
        return self.get_log_post()



In [105]:
# Two-dimensional toy problem: one observation with precision 30 * I and a
# prior location at the origin.
obs_dim = 2
obs = np.array([0.2, 0.5])
obs_info_mat = np.eye(obs_dim) * 30
prior_loc = np.zeros(obs_dim)
theta = np.full(obs_dim, 0.2)

mvn_model = MVNModel(obs, obs_info_mat, prior_loc)
# Compare the analytic posterior mean with the raw observation.
print(mvn_model.get_post_mean())
print(obs)



[ 0.19354839  0.48387097]
[ 0.2  0.5]


In [181]:
num_draws = 30
# Exact draws from the analytic posterior, one draw per row.
theta_draws = osp.stats.multivariate_normal.rvs(
    mean=mvn_model.get_post_mean(),
    cov=mvn_model.get_post_cov(),
    size=num_draws)
print(theta_draws.shape)

# Monte Carlo summaries of the draws (population sd, i.e. ddof=0).
theta_mean = theta_draws.mean(axis=0)
theta_sd = theta_draws.std(axis=0)
theta_se = theta_sd / np.sqrt(num_draws)

(30, 2)


In [168]:
class SteinGradientGenerator(object):
    """Stein variational gradient directions for draws from an ``MVNModel``.

    Every kernel / gradient helper lives on the instance and is used through
    ``self``; the class does not read any notebook globals.

    Args:
        mvn_model: model providing ``set_theta_get_log_post(theta)``.
        kernel: optional object with a ``k(x, y)`` method; defaults to
            ``Kernel()``.
    """
    def __init__(self, mvn_model, kernel=None):
        self.mvn_model = mvn_model
        # d log p(theta) / d theta. NOTE: set_theta_get_log_post stores its
        # argument on mvn_model.theta, so each gradient call overwrites it.
        self.get_log_post_grad = autograd.grad(self.mvn_model.set_theta_get_log_post)
        self.kernel = Kernel() if kernel is None else kernel
        # d k(x, y) / d x: gradient with respect to the FIRST argument.
        self.get_kernel_grad = autograd.grad(self.kernel.k, argnum=0)

    def get_kernel_mat(self, theta_draws):
        """Return K with K[i, j] = k(theta_draws[i], theta_draws[j])."""
        return np.array([[self.kernel.k(t1, t2) for t2 in theta_draws]
                         for t1 in theta_draws])

    def get_kernel_grad_mat(self, theta_draws):
        """Return G with G[i, j] = d k(t_i, t_j) / d t_i.

        t1 is in the first index; it is this that we need to sum out for
        the gradient step.
        """
        return np.array([[self.get_kernel_grad(t1, t2) for t2 in theta_draws]
                         for t1 in theta_draws])

    def get_post_grads(self, theta_draws):
        """Return the log-posterior gradient at each draw, one row per draw."""
        return np.array([self.get_log_post_grad(theta) for theta in theta_draws])

    def get_stein_grad_term(self, theta_draws):
        """Return (1/n) * sum_j k(t_i, t_j) * grad log p(t_j) for each draw i."""
        kernel_mat = self.get_kernel_mat(theta_draws)
        post_grads = self.get_post_grads(theta_draws)
        num_draws = theta_draws.shape[0]
        return np.einsum('ij,jk->ik', kernel_mat, post_grads) / num_draws

    def get_stein_kernel_term(self, theta_draws):
        """Return (1/n) * sum_i d k(t_i, t_j) / d t_i for each draw j."""
        num_draws = theta_draws.shape[0]
        kernel_grad_mat = self.get_kernel_grad_mat(theta_draws)
        return np.sum(kernel_grad_mat, axis=0) / num_draws

    def get_stein_direction(self, theta_draws):
        """Return the Stein direction: gradient term plus kernel term."""
        return (self.get_stein_grad_term(theta_draws) +
                self.get_stein_kernel_term(theta_draws))

In [161]:
# Inspect the kernel matrices for the current `theta_draws`. Both matrices
# are computed here, so the cell works on a fresh kernel.
kernel = Kernel()
get_kernel_grad = autograd.grad(kernel.k, argnum=0)

# kernel_mat[i, j] = k(theta_draws[i], theta_draws[j])
kernel_mat = np.array([[kernel.k(t1, t2) for t2 in theta_draws]
                       for t1 in theta_draws])
print(kernel_mat.shape)
#plt.matshow(kernel_mat); plt.colorbar()

# kernel_grad_mat[i, j] = d k(t_i, t_j) / d t_i
kernel_grad_mat = np.array([[get_kernel_grad(t1, t2) for t2 in theta_draws]
                            for t1 in theta_draws])
print(kernel_grad_mat.shape)


(30, 30)
(30, 30, 2)


In [197]:
stein_gradient_generator = SteinGradientGenerator(mvn_model=mvn_model)  # Stein directions for the model built above

def update_theta(stein_gradient_generator, theta_draws, se_fraction=0.5):
    """Move the draws one step along the Stein direction.

    The step size is chosen so that the mean of the draws changes by no more
    than ``se_fraction`` standard errors in any dimension.

    Args:
        stein_gradient_generator: object with a
            ``get_stein_direction(theta_draws)`` method returning an array
            broadcast-compatible with ``theta_draws``.
        theta_draws: array with one draw per row (axis 0 indexes draws).
        se_fraction: maximum allowed change of the mean, in standard errors.

    Returns:
        Tuple ``(theta_update, stein_step, stein_step_eps)``.
    """
    # get_stein_direction already computes the gradient and kernel terms;
    # computing them separately here would double the cost.
    stein_step = stein_gradient_generator.get_stein_direction(theta_draws)

    num_draws = theta_draws.shape[0]
    theta_se = np.std(theta_draws, axis=0) / np.sqrt(num_draws)

    # Step size that changes the mean by at most se_fraction standard errors.
    # Dimensions with a zero mean step impose no limit; if none moves, the
    # step size is 0 instead of inf (which would produce inf/nan draws).
    mean_step_size = np.abs(np.mean(stein_step, axis=0))
    moving = mean_step_size > 0
    if np.any(moving):
        stein_step_eps = se_fraction * np.min(theta_se[moving] / mean_step_size[moving])
    else:
        stein_step_eps = 0.0

    theta_update = theta_draws + stein_step_eps * stein_step
    return theta_update, stein_step, stein_step_eps


In [201]:

# Draw a fresh, larger sample from the exact posterior and compare its mean
# with the analytic posterior mean before and after one Stein update.
true_mean = mvn_model.get_post_mean()
theta_draws = osp.stats.multivariate_normal.rvs(
    mean=true_mean,
    cov=mvn_model.get_post_cov(),
    size=100)
print(theta_draws.mean(axis=0) - true_mean)

# Keep a copy of the pre-update draws for comparison.
last_theta = deepcopy(theta_draws)
print('Updating')
theta_draws, stein_step, stein_step_eps = update_theta(
    stein_gradient_generator, theta_draws)
print(theta_draws.mean(axis=0) - true_mean)




[  4.58293346e-05   1.07220308e-03]
Updating
[ 0.01484262 -0.00921673]
