In [5]:
import autograd
import autograd.numpy as np
import autograd.scipy as sp
from autograd.test_util import check_grads

num_obs = 1000

# Seed the global generator so the simulated data -- and every number printed
# from it further down -- is identical on each Restart & Run All.
np.random.seed(0)

# Some valid values: variances 1, 2, 3 on the diagonal plus a small random
# perturbation, then symmetrised. np.diag of a 1-D array already returns the
# full diagonal matrix, so no extra masking with an identity matrix is needed.
true_sigma = np.diag(np.array([1, 2, 3])) + np.random.random((3, 3)) * 0.1
true_sigma = 0.5 * (true_sigma + true_sigma.T)

# NOTE(review): integer literals make this an integer array, and the recorded
# TypeError from the gradient checks mentions numpy.int64. Check whether a
# float mean (e.g. [0., 1., 2.]) makes those checks pass -- TODO confirm.
true_mu = np.array([0, 1, 2])

# One simulated observation per row.
x = np.random.multivariate_normal(
    mean=true_mu, cov=true_sigma, size=(num_obs, ))

# Succeeds
# Run the check first and report afterwards: check_grads signals a problem by
# raising, so printing beforehand would announce a result not yet obtained.
# NOTE(review): no argnum is given here, so presumably only the first
# positional argument (x) is being checked -- confirm against test_util.
check_grads(autograd.scipy.stats.multivariate_normal.logpdf, modes=['rev'])(x, true_mu, true_sigma)
print('Differentiation succeeds')

# Fails
# Gradient check w.r.t. the mean, with all arguments passed positionally.
# In the recorded run this ended in the AssertionError branch.
def logpdf_mu(mu):
    """Sum of the log-densities of the rows of x for mean `mu` and true_sigma."""
    per_row = autograd.scipy.stats.multivariate_normal.logpdf(x, mu, true_sigma)
    return np.sum(per_row)


try:
    print('Can evaluate the function: ', logpdf_mu(true_mu))
    check_grads(logpdf_mu, modes=['rev'])(true_mu)
except AssertionError:
    # A bare `assert` carries no message, so the failing site is spelled out.
    print('Differentiation fails:\ntest_util.py 31:\tassert vspace(vjp_y) == x_vs')

# Same experiment with mean and cov passed by keyword. This re-binds
# logpdf_mu. In the recorded run this variant ended in the TypeError branch.
def logpdf_mu(mu):
    """Sum of the log-densities of the rows of x, passing mean/cov by keyword."""
    per_row = autograd.scipy.stats.multivariate_normal.logpdf(
        x, mean=mu, cov=true_sigma)
    return np.sum(per_row)


try:
    print('Can evaluate the function: ', logpdf_mu(true_mu))
    check_grads(logpdf_mu, modes=['rev'])(true_mu)
except TypeError as err:
    print(f'Differentiation fails: {err}')




Differentiation succeeds
Can evaluate the function:  -5135.146812235187
Differentiation fails:
test_util.py 31:	assert vspace(vjp_y) == x_vs
Can evaluate the function:  -5135.146812235187
Differentiation fails: Can't differentiate w.r.t. type <class 'numpy.int64'>


In [6]:
# Gradient w.r.t. positional argument 2, which is the covariance in the call
# below (the recorded output is a 3x3 matrix). It gets its own name so it is
# not mistaken for -- or overwritten by -- the mean gradient.
model_logpdf_params_sigma_grad = autograd.grad(
    autograd.scipy.stats.multivariate_normal.logpdf, argnum=2)
print(model_logpdf_params_sigma_grad(x[0, :], true_mu, true_sigma))

# Gradient w.r.t. positional argument 1, the mean (recorded output is a
# length-3 vector).
model_logpdf_params_mu_grad = autograd.grad(
    autograd.scipy.stats.multivariate_normal.logpdf, argnum=1)
print(model_logpdf_params_mu_grad(x[0, :], true_mu, true_sigma))

# Evaluate the log-density of the first observation, positionally and by
# keyword. The recorded run printed the same value for both.
print(autograd.scipy.stats.multivariate_normal.logpdf(x[0, :], true_mu, true_sigma))
print(autograd.scipy.stats.multivariate_normal.logpdf(x[0, :], mean=true_mu, cov=true_sigma))

[[ 0.37957388  0.27230804  0.04135152]
 [ 0.27230804 -0.45748685 -0.13267569]
 [ 0.04135152 -0.13267569  0.14400415]]
[-0.49025091  1.1833296   0.21333178]
-5.2780275801618695
-5.2780275801618695
