In [1]:
import autograd.numpy as np
from autograd.core import primitive
from autograd import grad, jacobian, hessian
from autograd.numpy.numpy_grads import unbroadcast
import scipy.stats

In [51]:
def LogitTermClosure(beta, beta_cov, x_mat, std_vec):
    """Build a cached ``LogitTerm`` autograd primitive with hand-coded VJPs.

    ``LogitTerm(beta, beta_cov, x_mat, std_vec)`` returns

        sum_{i, j} log(1 + exp(z[i, j])),  where
        mu[i]    = x_mat[i] . beta
        sigma[i] = x_mat[i]^T beta_cov x_mat[i]
        z[i, j]  = sigma[i] * std_vec[j] + mu[i].

    Intermediates that depend only on the data (``x_outer``) or on the
    parameters (``sigma``, ``mu``, ``z``, ``p = sigmoid(z)``) are cached in
    closure state and recomputed only when the corresponding inputs change.
    Gradients with respect to ``beta`` (argnum 0) and ``beta_cov`` (argnum 1)
    are registered as custom VJPs expressed through weighted sums of ``p``.

    Args:
        beta: Initial parameter vector; its size defines K.
        beta_cov: Initial (K, K) parameter matrix.
        x_mat: Initial (N, K) data matrix.
        std_vec: Initial 1-d vector; ``z`` gets one column per entry.

    Returns:
        ``(LogitTerm, SetDataState, SetParamState)``.  The two setters force a
        recomputation of the data / parameter caches respectively.
    """
    class DataState:
        pass

    class ParamState:
        pass

    data = DataState()
    par = ParamState()

    def SetDataState(x_mat, std_vec):
        """Recompute the data cache and mark the parameter cache stale."""
        # Store copies: if we kept references, in-place mutation by the caller
        # would make the cached arrays compare equal to the "new" inputs.
        data.x_mat = np.array(x_mat)
        data.std_vec = np.array(std_vec)
        # x_outer[i] is the outer product of row i of x_mat with itself.
        data.x_outer = np.einsum('ij,ik->ijk', data.x_mat, data.x_mat)
        # z depends on the data, so the parameter cache must be rebuilt
        # (lazily, with whatever parameters arrive next) before it is used.
        par.stale = True

    def SetParamState(beta, beta_cov):
        """Recompute sigma, mu, z and p for the given parameters and current data."""
        beta = np.array(beta)
        beta_cov = np.array(beta_cov)
        K = beta.size
        assert beta_cov.shape == (K, K), 'beta_cov must be K x K, with K = beta.size.'
        # Compute everything before touching `par` so a failure leaves the
        # previous cache intact.
        sigma = np.einsum('ijk,jk->i', data.x_outer, beta_cov)
        mu = np.einsum('ij,j->i', data.x_mat, beta)
        z = np.einsum('i,j->ij', sigma, data.std_vec) + np.expand_dims(mu, 1)
        par.beta = beta
        par.K = K
        par.beta_cov = beta_cov
        par.sigma = sigma
        par.mu = mu
        par.z = z
        # log(1 + exp(z)) and sigmoid(z) written via logaddexp so that large
        # |z| neither overflows to inf nor produces inf / inf = nan.
        par.log1pexp_z = np.logaddexp(0., z)
        par.p = np.exp(-np.logaddexp(0., -z))
        par.stale = False

    def CheckDataCache(x_mat, std_vec):
        """Refresh the data cache if x_mat or std_vec differ from the cached copies."""
        if not (np.array_equal(x_mat, data.x_mat) and
                np.array_equal(std_vec, data.std_vec)):
            print('Refreshing data cache.  (So refreshing.)')
            SetDataState(x_mat, std_vec)

    def CheckParCache(beta, beta_cov):
        """Refresh the parameter cache if the parameters changed or the data did."""
        # array_equal (unlike `!=` + `.any()`) also handles a change of shape.
        changed = not (np.array_equal(beta, par.beta) and
                       np.array_equal(beta_cov, par.beta_cov))
        if changed:
            print('Refreshing parameter cache.  (So refreshing.)')
        if changed or par.stale:
            SetParamState(beta, beta_cov)

    def RefreshCache(beta, beta_cov, x_mat, std_vec):
        """Bring both caches in line with the given arguments."""
        # Data first: a data refresh marks the parameter cache stale, and
        # CheckParCache then rebuilds it with the *incoming* parameters (so a
        # change of K through LogitTerm works).
        CheckDataCache(x_mat, std_vec)
        CheckParCache(beta, beta_cov)

    SetDataState(x_mat, std_vec)
    SetParamState(beta, beta_cov)

    # Define the functions.

    # LogitTerm and its VJPs are the entry points reached from outside, so they
    # validate the cache against their actual arguments; the helpers below read
    # the cache directly.
    @primitive
    def LogitTerm(beta, beta_cov, x_mat, std_vec):
        RefreshCache(beta, beta_cov, x_mat, std_vec)
        return np.sum(par.log1pexp_z)

    # We will simplify the formulas by expressing gradients as functions of
    # weighted sums over p, since only p depends on the parameters.
    @primitive
    def WeightedPSum(beta, beta_cov, weights):
        # weights must broadcast against p, i.e. (N, 1) or (N, len(std_vec)).
        return np.sum(par.p * weights)

    @primitive
    def LogitTerm_grad_beta_term(beta, beta_cov, a):
        # dz[i, j] / dbeta[a] = x_mat[i, a] for every column j.
        weights = np.expand_dims(data.x_mat[:, a], 1)
        return WeightedPSum(beta, beta_cov, weights)

    @primitive
    def LogitTerm_grad_beta(beta, beta_cov):
        return np.array([LogitTerm_grad_beta_term(beta, beta_cov, a)
                         for a in range(par.K)])

    @primitive
    def LogitTerm_vjp_beta(g, ans, vs, gvs, beta, beta_cov, x_mat, std_vec):
        # ans, vs, gvs are unused; they are part of the VJP signature that
        # defvjp passes through.
        RefreshCache(beta, beta_cov, x_mat, std_vec)
        return g * LogitTerm_grad_beta(beta, beta_cov)
    LogitTerm.defvjp(LogitTerm_vjp_beta, argnum=0)

    @primitive
    def LogitTerm_grad_beta_cov_term(beta, beta_cov, a, b):
        # dz[i, j] / dbeta_cov[a, b] = x_outer[i, a, b] * std_vec[j].
        z_grad = np.einsum('j,i->ij', data.std_vec, data.x_outer[:, a, b])
        return WeightedPSum(beta, beta_cov, z_grad)

    @primitive
    def LogitTerm_grad_beta_cov(beta, beta_cov):
        GradTerm = lambda a, b: LogitTerm_grad_beta_cov_term(beta, beta_cov, a, b)
        # Row index a, column index b, matching beta_cov[a, b].
        return np.array([[GradTerm(a, b) for b in range(par.K)]
                         for a in range(par.K)])

    @primitive
    def LogitTerm_vjp_beta_cov(g, ans, vs, gvs, beta, beta_cov, x_mat, std_vec):
        RefreshCache(beta, beta_cov, x_mat, std_vec)
        return g * LogitTerm_grad_beta_cov(beta, beta_cov)
    LogitTerm.defvjp(LogitTerm_vjp_beta_cov, argnum=1)

    return LogitTerm, SetDataState, SetParamState


In [39]:
def LogitTermAD(beta, beta_cov, x_mat, std_vec):
    """Plain autograd-differentiable evaluation of sum(log(1 + exp(z))).

    Computes
        mu[i]    = x_mat[i] . beta
        sigma[i] = x_mat[i]^T beta_cov x_mat[i]
        z[i, j]  = sigma[i] * std_vec[j] + mu[i]
    and returns sum_{i, j} log(1 + exp(z[i, j])).  Everything is written in
    numpy operations so ``grad`` can differentiate it directly; it is used as
    the reference for the hand-coded gradients of ``LogitTerm``.
    """
    x_outer = np.einsum('ij,ik->ijk', x_mat, x_mat)
    sigma = np.einsum('ijk,jk->i', x_outer, beta_cov)
    mu = np.einsum('ij,j->i', x_mat, beta)
    z = np.einsum('i,j->ij', sigma, std_vec) + np.expand_dims(mu, 1)
    # logaddexp(0, z) == log(1 + exp(z)), but does not overflow for large z.
    return np.sum(np.logaddexp(0., z))

# @primitive
# def LogitTerm(beta, beta_cov, x_mat, std_vec):
#     return LogitTermAD(beta, beta_cov, x_mat, std_vec)

# Every gradient is of the form of weighted sums of p, so writing everything
# in terms of this makes it easy to differentiate.
# @primitive
# def WeightedPSum(beta, beta_cov, p, x_mat, std_vec, weights):
#     return np.sum(p * weights)

@primitive
def WeightedPSum_grad_beta(beta, beta_cov, p, x_mat, std_vec, weights):
    """Gradient of sum(p * weights) with respect to beta.

    With p = sigmoid(z) and dz[n, m] / dbeta[k] = x_mat[n, k], the result is
    sum_{n, m} p (1 - p)[n, m] * weights[n, m] * x_mat[n, k], one entry per k.
    beta, beta_cov and std_vec are unused here.
    """
    p_slope = p * (1 - p)
    return np.einsum('nm,nm,nk->k', p_slope, weights, x_mat)

@primitive
def WeightedPSum_grad_beta_cov(beta, beta_cov, p, x_mat, std_vec, weights):
    """Gradient of sum(p * weights) with respect to beta_cov.

    With p = sigmoid(z) and dz[n, m] / dbeta_cov[a, b] =
    std_vec[m] * x_mat[n, a] * x_mat[n, b], the result is the (K, K) matrix
    sum_{n, m} p (1 - p)[n, m] * weights[n, m] * std_vec[m] * x_outer[n, a, b].
    beta and beta_cov are unused here.
    """
    row_outer = np.einsum('ij,ik->ijk', x_mat, x_mat)
    p_slope = p * (1 - p)
    return np.einsum('nm,nm,m,nab->ab', p_slope, weights, std_vec, row_outer)

# def LogitTerm_grad_beta_term(beta, p, x_mat, std_vec, a):
#     # sigma = np.einsum('ijk,jk->i', x_outer, beta_cov)
#     # mu = np.einsum('ij,j->i', x_mat, beta)
#     # z = np.einsum('i,j->ij', sigma, std_vec) + np.expand_dims(mu, 1)
#     # p = np.exp(z) / (1 + np.exp(z))
#     return WeightedPSum(beta, beta_cov, p, x_mat, std_vec, np.expand_dims(x_mat[:, a], 1))
#     # return np.sum(p * np.expand_dims(x_mat[:, a], 1))

# def LogitTerm_grad_beta(beta, beta_cov, p, x_mat, std_vec):
#     K = beta.size
#     return np.array([ LogitTerm_grad_beta_term(beta, beta_cov, p, x_mat, std_vec, a) for a in range(K)])

# def LogitTerm_vjp_beta(g, ans, vs, gvs, beta, beta_cov, x_mat, std_vec):
# #     sigma = np.einsum('ijk,jk->i', x_outer, beta_cov)
# #     mu = np.einsum('ij,j->i', x_mat, beta)
# #     z = np.einsum('i,j->ij', sigma, std_vec) + np.expand_dims(mu, 1)
# #     p = np.exp(z) / (1 + np.exp(z))
#     return g * LogitTerm_grad_beta(beta, beta_cov, x_mat, std_vec)
# LogitTerm.defvjp(LogitTerm_vjp_beta, argnum=0)

# def LogitTerm_grad_beta_cov_term(beta, beta_cov, x_mat, std_vec, a, b):
#     sigma = np.einsum('ijk,jk->i', x_outer, beta_cov)
#     mu = np.einsum('ij,j->i', x_mat, beta)
#     z = np.einsum('i,j->ij', sigma, std_vec) + np.expand_dims(mu, 1)
#     p = np.exp(z) / (1 + np.exp(z))
#     x_outer = np.einsum('ij,ik->ijk', x_mat, x_mat)
#     z_grad = np.einsum('j,i->ij', std_vec, x_outer[:, a, b])
#     return WeightedPSum(beta, beta_cov, p, x_mat, std_vec, z_grad)
#     # return np.einsum('ji,i,j', p, std_vec, x_outer[:, a, b])

# def LogitTerm_grad_beta_cov(beta, beta_cov, x_mat, std_vec):
#     K = beta.size
#     GradTerm = lambda a, b: LogitTerm_grad_beta_cov_term(beta, beta_cov, x_mat, std_vec, a, b)
#     return np.array([[ GradTerm(a, b) for a in range(K) ] for b in range(K) ])

# def LogitTerm_vjp_beta_cov(g, ans, vs, gvs, beta, beta_cov, x_mat, std_vec):
#     return g * LogitTerm_grad_beta_cov(beta, beta_cov, x_mat, std_vec)
# LogitTerm.defvjp(LogitTerm_vjp_beta_cov, argnum=1)


# TODO: define Hessians (not implemented yet).


# Wrapping functions
def UnWrap(par_vec, K):
    """Split a flat parameter vector into ``(beta, beta_cov)``.

    Args:
        par_vec: 1-d vector holding beta (first K entries) followed by the
            K * K entries of beta_cov in row-major order.
        K: Length of beta.

    Returns:
        ``(beta, beta_cov)`` with shapes (K,) and (K, K).
    """
    cov_flat = par_vec[K:]
    return par_vec[:K], cov_flat.reshape(K, K)

def Wrap(beta, beta_cov):
    """Flatten ``(beta, beta_cov)`` into one parameter vector; inverse of ``UnWrap``.

    Layout: beta first, then beta_cov in row-major (C) order.
    """
    return np.concatenate((beta, beta_cov.ravel()))

def LogitTermWrap(par_vec, x_mat, std_vec):
    """Evaluate the global ``LogitTerm`` on a flat parameter vector.

    K (the length of beta) is taken from the number of columns of ``x_mat``
    instead of being fixed at 2, so any K works as long as ``par_vec`` has
    K + K * K entries laid out as produced by ``Wrap``.
    """
    K = x_mat.shape[1]
    beta, beta_cov = UnWrap(par_vec, K)
    return LogitTerm(beta, beta_cov, x_mat, std_vec)

def LogitTermWrapAD(par_vec, x_mat, std_vec):
    """Evaluate the reference ``LogitTermAD`` on a flat parameter vector.

    K (the length of beta) is taken from the number of columns of ``x_mat``
    instead of being fixed at 2, so any K works as long as ``par_vec`` has
    K + K * K entries laid out as produced by ``Wrap``.
    """
    K = x_mat.shape[1]
    beta, beta_cov = UnWrap(par_vec, K)
    return LogitTermAD(beta, beta_cov, x_mat, std_vec)


In [52]:
# Problem size, plus a fixed seed so the random data matrix (and hence the
# printed values) are reproducible on Restart & Run All.
N = 10
K = 2
np.random.seed(0)

beta = np.array([1., -0.5])
assert beta.size == K
beta_cov = np.full((K, K), 0.1) + np.eye(K)
x_mat = np.random.rand(N * K).reshape(N, K)
std_vec = np.array([-0.8, -0.3, 0.3, 0.8])

LogitTerm, SetDataState, SetParamState = LogitTermClosure(beta, beta_cov, x_mat, std_vec)

# Compare the cached primitive with the plain reference on the initial inputs,
# after a parameter change, and after a data change.  The second and third
# evaluations should report a parameter- and a data-cache refresh respectively.
checks = [(beta, x_mat), (beta + 1, x_mat), (beta + 1, x_mat * 0.1)]
for check_beta, check_x_mat in checks:
    cached_value = LogitTerm(check_beta, beta_cov, check_x_mat, std_vec)
    reference_value = LogitTermAD(check_beta, beta_cov, check_x_mat, std_vec)
    print(cached_value)
    print(reference_value)
    assert np.allclose(cached_value, reference_value)


36.1203404731
36.1203404731
Refreshing parameter cache.  (So refreshing.)
69.6750717025
69.6750717025
Refreshing data cache.  (So refreshing.)
30.7931641682
30.7931641682


In [54]:
par_vec = Wrap(beta, beta_cov)
# Round-trip check: UnWrap must invert Wrap.
beta_roundtrip, beta_cov_roundtrip = UnWrap(par_vec, beta.size)
assert np.array_equal(beta_roundtrip, beta)
assert np.array_equal(beta_cov_roundtrip, beta_cov)

# The hand-coded VJPs of LogitTerm should match autograd's gradient of the
# reference implementation up to floating-point rounding.
LogitTermWrapADGrad = grad(LogitTermWrapAD, argnum=0)
LogitTermWrapGrad = grad(LogitTermWrap, argnum=0)
max_grad_diff = np.max(np.abs(LogitTermWrapADGrad(par_vec, x_mat, std_vec) -
                              LogitTermWrapGrad(par_vec, x_mat, std_vec)))
print(max_grad_diff)
assert max_grad_diff < 1e-8


3.5527136788e-15
3.5527136788e-15


[0, 10, 20, 1, 11, 21]