In [2]:
import paragami

import autograd
from autograd import numpy as np

# Use the original scipy for functions we don't need to differentiate.
import scipy as osp

In [23]:
# Fix the global NumPy seed so the simulated data set is reproducible.
np.random.seed(42)

num_obs = 10000
data_dim = 3

# True values of parameters.
# The covariance is diag(0, 1, ..., data_dim - 1) plus a small random
# perturbation.  The identity is sized by data_dim (rather than a literal 3)
# so that changing data_dim above does not break broadcasting.
true_sigma = \
    np.eye(data_dim) * np.diag(np.arange(0, data_dim)) + \
    np.random.random((data_dim, data_dim)) * 0.1
# Symmetrize: the random perturbation is not symmetric on its own.
true_sigma = 0.5 * (true_sigma + true_sigma.T)
true_mu = np.arange(0, data_dim)

true_norm_param_dict = dict()
true_norm_param_dict['mu'] = true_mu
true_norm_param_dict['sigma'] = true_sigma

# Data: num_obs independent draws from N(true_mu, true_sigma), one per row.
data = np.random.multivariate_normal(
    mean=true_norm_param_dict['mu'],
    cov=true_norm_param_dict['sigma'],
    size=(num_obs, ))

In [42]:
def get_mvn_log_probs(obs, mean, cov):
    """Multivariate normal log densities, up to an additive constant.

    Computes ``-0.5 * ((x - mean)^T cov^{-1} (x - mean) + log|cov|)`` for
    every row ``x`` of ``obs``.  The ``-0.5 * d * log(2 * pi)`` normalizing
    term is omitted.

    Parameters
    ----------
    obs : array, shape (n, d)
        One observation per row.
    mean : array, shape (d,)
        Mean vector, broadcast against every row of ``obs``.
    cov : array, shape (d, d)
        Covariance matrix.

    Returns
    -------
    array, shape (n,)
        The log density of each row.  If the determinant of ``cov`` is not
        strictly positive, every entry is ``-inf`` (zero density), so that a
        negated log-probability loss becomes ``+inf`` rather than attractive
        to a minimizer.  Note that only the sign of the determinant is
        checked, not full positive-definiteness.
    """
    # Validate the covariance *before* inverting it: inverting a singular
    # matrix would raise instead of reaching the guard below.
    cov_det_sign, cov_log_det = np.linalg.slogdet(cov)
    if cov_det_sign <= 0:
        # np.full takes (shape, fill_value), in that order.
        return np.full(obs.shape[0], -float('inf'))

    cov_inv = np.linalg.inv(cov)
    obs_centered = obs - np.expand_dims(mean, axis=0)
    # Quadratic form (x - mean)^T cov_inv (x - mean), one value per row.
    return -0.5 * (
        np.einsum('ni,ij,nj->n', obs_centered, cov_inv, obs_centered) +
        cov_log_det)

def get_data_lp(data, norm_param_dict, weights):
    """Weighted sum of the per-observation log probabilities of ``data``.

    ``norm_param_dict`` supplies the mean under ``'mu'`` and the covariance
    under ``'sigma'``; ``weights`` multiplies the per-row log probabilities
    elementwise before they are summed to a scalar.
    """
    per_obs_lp = get_mvn_log_probs(
        data,
        mean=norm_param_dict['mu'],
        cov=norm_param_dict['sigma'])
    return np.sum(weights * per_obs_lp)

def get_prior_lp(norm_param_dict, prior_param_dict):
    """Log prior of the mean parameter, returned as a scalar.

    The mean ``norm_param_dict['mu']`` is treated as a single observation
    from a multivariate normal with mean ``prior_param_dict['prior_mean']``
    and covariance ``prior_param_dict['prior_cov']``.  Only ``'mu'`` enters
    the prior; ``'sigma'`` is not read here.
    """
    # Add a leading axis so that mu looks like a one-row data set.
    mu_as_obs = np.expand_dims(norm_param_dict['mu'], axis=0)
    lp_per_row = get_mvn_log_probs(
        obs=mu_as_obs,
        mean=prior_param_dict['prior_mean'],
        cov=prior_param_dict['prior_cov'])

    # Collapse the length-one result to a scalar.
    return np.sum(lp_per_row)

def get_loss(data, norm_param_dict, prior_param_dict, weights):
    """Negative of (log prior + weighted data log probability)."""
    log_prior = get_prior_lp(norm_param_dict, prior_param_dict)
    log_lik = get_data_lp(data, norm_param_dict, weights)
    return -1 * (log_prior + log_lik)
    
class NormalModel:
    """A data set bundled with a prior and per-observation weights.

    The three ``get_loss_*`` methods all evaluate the module-level
    ``get_loss``; they differ only in which inputs come from the stored
    state and which are passed in by the caller.
    """

    def __init__(self, data):
        """Store ``data`` and install the default prior and weights.

        ``data.shape[0]`` is taken as the number of observations and
        ``data.shape[1]`` as the dimension of each observation.
        """
        self.data = data
        self.num_obs = self.data.shape[0]
        self.data_dim = self.data.shape[1]

        # Defaults: zero prior mean, prior covariance 100 * I, unit weights.
        default_prior_mean = np.zeros(self.data_dim)
        default_prior_cov = 100 * np.eye(self.data_dim)
        self.set_prior(default_prior_mean, default_prior_cov)
        self.set_weights(np.ones(self.num_obs))

    def set_weights(self, weights):
        """Replace the stored observation weights."""
        self.weights = weights

    def set_prior(self, prior_mean, prior_cov):
        """Replace the stored prior with a fresh dict of mean and covariance."""
        self.prior_dict = {
            'prior_mean': prior_mean,
            'prior_cov': prior_cov,
        }

    def get_loss_for_opt(self, norm_param_dict):
        """Loss as a function of the model parameters only."""
        return get_loss(
            data=self.data,
            norm_param_dict=norm_param_dict,
            prior_param_dict=self.prior_dict,
            weights=self.weights)

    def get_loss_by_prior(self, norm_param_dict, prior_dict):
        """Loss using the supplied ``prior_dict`` instead of the stored prior."""
        return get_loss(
            data=self.data,
            norm_param_dict=norm_param_dict,
            prior_param_dict=prior_dict,
            weights=self.weights)

    def get_loss_by_weights(self, norm_param_dict, weights):
        """Loss using the supplied ``weights`` instead of the stored weights."""
        return get_loss(
            data=self.data,
            norm_param_dict=norm_param_dict,
            prior_param_dict=self.prior_dict,
            weights=weights)

    
# Build the model with its default prior and weights, then report the loss
# at the data-generating parameters as a sanity check.
model = NormalModel(data)
print('Loss at true parameter: {}'.format(
    model.get_loss_for_opt(norm_param_dict=true_norm_param_dict)))

Loss at true parameter: 103.94712761511776


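Define the paragami patterns that describe the shape and constraints of the model parameters, the prior parameters, and the observation weights.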

In [43]:
# Pattern for the model parameters: a length-data_dim numeric array 'mu' and
# a 'sigma' entry declared with PSDMatrixPattern(size=data_dim).
# NOTE(review): entries are added in the order 'mu', then 'sigma'; presumably
# this order fixes the layout of the flattened vector -- confirm in paragami.
norm_pattern = paragami.PatternDict()
norm_pattern['mu'] = paragami.NumericArrayPattern(shape=(data_dim, ))
norm_pattern['sigma'] = paragami.PSDMatrixPattern(size=data_dim)

# Pattern for the prior parameters, using the same keys that
# NormalModel.set_prior stores ('prior_mean' and 'prior_cov').
prior_pattern = paragami.PatternDict()
prior_pattern['prior_mean'] = paragami.NumericArrayPattern(shape=(data_dim, ))
prior_pattern['prior_cov'] = paragami.PSDMatrixPattern(size=data_dim)

# Pattern for the observation weights: one numeric entry per observation.
weight_pattern = paragami.NumericArrayPattern(shape=(num_obs, ))

In [44]:
# Optimize.
opt_fun = paragami.FlattenedFunction(
    original_fun=model.get_loss_for_opt,
    patterns=norm_pattern,
    free=True)
opt_fun_grad = autograd.grad(opt_fun)
opt_fun_hessian = autograd.hessian(opt_fun)

# Initialize with zeros.
init_param = np.zeros(norm_pattern.flat_length(free=True))
mle_opt = osp.optimize.minimize(
    method='trust-ncg',
    x0=init_param,
    fun=opt_fun,
    jac=opt_fun_grad,
    hess=opt_fun_hessian,
    options={'gtol': 1e-8, 'disp': True})

         Current function value: 97.135404
         Iterations: 19
         Function evaluations: 21
         Gradient evaluations: 18
         Hessian evaluations: 18


In [45]:
# Fold the optimizer's free vector back into named parameters and show it
# above the data-generating parameters for comparison.
opt_norm_param_dict = norm_pattern.fold(mle_opt.x, free=True)
print(opt_norm_param_dict, true_norm_param_dict, sep='\n')

OrderedDict([('mu', array([0.00322004, 0.96876934, 1.88305126])), ('sigma', array([[0.03823283, 0.07566443, 0.03912985],
       [0.07566443, 0.96271823, 0.09310977],
       [0.03912985, 0.09310977, 1.95881609]]))])
{'mu': array([0, 1, 2]), 'sigma': array([[0.03745401, 0.07746864, 0.03950388],
       [0.07746864, 1.01560186, 0.05110853],
       [0.03950388, 0.05110853, 2.0601115 ]])}
