In [1]:
import autograd.numpy as np
from autograd.core import primitive
from autograd import grad, jacobian, hessian
from autograd.numpy.numpy_grads import unbroadcast
import scipy.stats

In [2]:
def LogitTermClosure(beta, beta_cov, x_mat, std_vec):
    """Build a cached logit term with hand-written derivatives.

    The returned ``LogitTerm(beta, beta_cov, x_mat, std_vec)`` evaluates
    ``sum_ij log(1 + exp(z[i, j]))`` where

        mu[i]    = sum_k x_mat[i, k] * beta[k]
        sigma[i] = sum_jk x_mat[i, j] * x_mat[i, k] * beta_cov[j, k]
        z[i, j]  = sigma[i] * std_vec[j] + mu[i]

    Intermediate quantities are stored in two closure-local state objects
    (``data`` and ``par``) and recomputed only when the arguments passed to
    ``LogitTerm`` differ from the cached ones.

    Args:
        beta: 1-d array of length K.
        beta_cov: array of shape (K, K).
        x_mat: 2-d array whose second dimension has length K.
        std_vec: 1-d array of points multiplying ``sigma``.

    Returns:
        Tuple ``(LogitTerm, SetDataState, SetParamState)``.  The two setters
        overwrite the cached data / parameter state directly.

    Raises:
        AssertionError: if ``beta_cov`` does not have shape (K, K).
    """
    class DataState:
        pass

    class ParamState:
        pass

    def HashData(x_mat, std_vec):
        # Hash the actual contents.  str() of a large numpy array elides the
        # interior with "...", so hashing the printed form would not notice
        # changes to entries away from the corners.
        def ArrayKey(arr):
            arr = np.asarray(arr)
            return (arr.shape, arr.dtype.str, arr.tobytes())
        return hash((ArrayKey(x_mat), ArrayKey(std_vec)))

    data = DataState()
    par = ParamState()

    def SetDataState(x_mat, std_vec):
        """Cache the data and the per-row outer products x_i x_i^T."""
        data.x_mat = x_mat
        data.std_vec = std_vec
        data.x_outer = np.einsum('ij,ik->ijk', x_mat, x_mat)
        data.data_hash = HashData(x_mat, std_vec)

    def SetParamState(beta, beta_cov):
        """Cache everything that depends on the parameters (needs data set)."""
        K = beta.size
        # Validate before touching the cache so a bad call leaves it intact.
        assert beta_cov.shape == (K, K)
        par.beta = beta
        par.K = K
        par.beta_cov = beta_cov
        par.sigma = np.einsum('ijk,jk->i', data.x_outer, beta_cov)
        par.mu = np.einsum('ij,j->i', data.x_mat, beta)
        par.z = np.einsum('i,j->ij', par.sigma, data.std_vec) + \
                np.expand_dims(par.mu, 1)
        # p = exp(z) / (1 + exp(z)), written as exp(-log(1 + exp(-z))) so that
        # large |z| does not overflow into inf / inf = nan.
        par.p = np.exp(-np.logaddexp(0.0, -par.z))
        par.p_1m_p = par.p * (1 - par.p)

    def CheckParCache(beta, beta_cov):
        if (beta != par.beta).any() or (beta_cov != par.beta_cov).any():
            print('Refreshing parameter cache.  (So refreshing.)')
            SetParamState(beta, beta_cov)

    def CheckDataCache(x_mat, std_vec):
        if HashData(x_mat, std_vec) != data.data_hash:
            print('Refreshing data cache.  (So refreshing.)')
            SetDataState(x_mat, std_vec)
            # The parameter cache is derived from the data, so rebuild it too.
            SetParamState(par.beta, par.beta_cov)

    SetDataState(x_mat, std_vec)
    SetParamState(beta, beta_cov)

    # Define the functions.

    # Only this will be accessible on the outside, so only here do we need the
    # cache check.
    @primitive
    def LogitTerm(beta, beta_cov, x_mat, std_vec):
        CheckDataCache(x_mat, std_vec)
        CheckParCache(beta, beta_cov)
        # logaddexp(0, z) == log(1 + exp(z)) without overflow for large z.
        return np.sum(np.logaddexp(0.0, par.z))

    # We will simplify the formulas by expressing gradients as functions of
    # weighted sums over p, since only p depends on the parameters.
    # Note that the value is read from the cached par.p; the beta / beta_cov
    # arguments exist so derivatives can be attached to them.
    @primitive
    def WeightedPSum(beta, beta_cov, weights):
        return np.sum(par.p * weights)

    @primitive
    def WeightedPSum_grad_beta(weights):
        return np.einsum('ij,ij,ik->k', par.p_1m_p, weights, data.x_mat)
    @primitive
    def WeightedPSum_vjp_beta(g, ans, vs, gvs, beta, beta_cov, weights):
        return g * WeightedPSum_grad_beta(weights)
    WeightedPSum.defvjp(WeightedPSum_vjp_beta, argnum=0)

    @primitive
    def WeightedPSum_grad_beta_cov(weights):
        return np.einsum('ij,ij,j,iab->ab', par.p_1m_p, weights, data.std_vec, data.x_outer)
    @primitive
    def WeightedPSum_vjp_beta_cov(g, ans, vs, gvs, beta, beta_cov, weights):
        return g * WeightedPSum_grad_beta_cov(weights)
    WeightedPSum.defvjp(WeightedPSum_vjp_beta_cov, argnum=1)

    # Here and below, the functions are deliberately NOT primitive so that they
    # inherit derivatives from WeightedPSum.
    def LogitTerm_grad_beta_term(beta, beta_cov, a):
        weights = np.expand_dims(data.x_mat[:, a], 1)
        return WeightedPSum(beta, beta_cov, weights)

    def LogitTerm_grad_beta(beta, beta_cov):
        return np.array([ LogitTerm_grad_beta_term(beta, beta_cov, a) for a in range(par.K)])

    def LogitTerm_vjp_beta(g, ans, vs, gvs, beta, beta_cov, x_mat, std_vec):
        return g * LogitTerm_grad_beta(beta, beta_cov)
    LogitTerm.defvjp(LogitTerm_vjp_beta, argnum=0)

    def LogitTerm_grad_beta_cov_term(beta, beta_cov, a, b):
        z_grad = np.einsum('j,i->ij', data.std_vec, data.x_outer[:, a, b])
        return WeightedPSum(beta, beta_cov, z_grad)

    def LogitTerm_grad_beta_cov(beta, beta_cov):
        GradTerm = lambda a, b: LogitTerm_grad_beta_cov_term(beta, beta_cov, a, b)
        # Entry [b][a] holds GradTerm(a, b); x_outer[:, a, b] is symmetric in
        # (a, b), so the index order does not change the result.
        return np.array([[ GradTerm(a, b) for a in range(par.K) ] for b in range(par.K) ])

    def LogitTerm_vjp_beta_cov(g, ans, vs, gvs, beta, beta_cov, x_mat, std_vec):
        return g * LogitTerm_grad_beta_cov(beta, beta_cov)
    LogitTerm.defvjp(LogitTerm_vjp_beta_cov, argnum=1)


    return LogitTerm, SetDataState, SetParamState


In [15]:
def LogitTermAD(beta, beta_cov, x_mat, std_vec):
    """Plain implementation of the logit term, for automatic differentiation.

    Computes ``sum_ij log(1 + exp(z[i, j]))`` with

        mu[i]    = sum_k x_mat[i, k] * beta[k]
        sigma[i] = sum_jk x_mat[i, j] * x_mat[i, k] * beta_cov[j, k]
        z[i, j]  = sigma[i] * std_vec[j] + mu[i]

    Args:
        beta: 1-d array of length K.
        beta_cov: array of shape (K, K).
        x_mat: 2-d array whose second dimension has length K.
        std_vec: 1-d array of points multiplying ``sigma``.

    Returns:
        The scalar sum over all entries of ``z``.
    """
    x_outer = np.einsum('ij,ik->ijk', x_mat, x_mat)
    sigma = np.einsum('ijk,jk->i', x_outer, beta_cov)
    mu = np.einsum('ij,j->i', x_mat, beta)
    z = np.einsum('i,j->ij', sigma, std_vec) + np.expand_dims(mu, 1)
    # log(1 + exp(z)) overflows to inf once z exceeds roughly 709;
    # logaddexp(0, z) is the same quantity evaluated without overflow.
    return np.sum(np.logaddexp(0.0, z))

# Wrapping functions
def UnWrap(par_vec, K):
    """Split a flat parameter vector into ``(beta, beta_cov)``.

    The first ``K`` entries are returned as ``beta``; the remaining entries
    are reshaped to ``(K, K)`` and returned as ``beta_cov``.
    """
    mean_part, cov_part = par_vec[:K], par_vec[K:]
    return mean_part, cov_part.reshape(K, K)

def Wrap(beta, beta_cov):
    """Flatten ``(beta, beta_cov)`` into a single 1-d parameter vector.

    ``beta`` comes first, followed by ``beta_cov`` raveled in row-major
    order, which is the layout ``UnWrap`` expects.
    """
    # The previous unused ``K = beta.size`` lookup was dropped: it had no
    # effect on the result and rejected list-like ``beta`` for no reason.
    return np.concatenate((beta, beta_cov.ravel()))

def LogitTermWrap(par_vec, x_mat, std_vec):
    """Evaluate ``LogitTerm`` from a flat parameter vector.

    The number of coefficients is taken from the second dimension of
    ``x_mat``; ``par_vec`` is split with ``UnWrap`` accordingly.
    """
    num_coef = x_mat.shape[1]
    unpacked = UnWrap(par_vec, num_coef)
    return LogitTerm(unpacked[0], unpacked[1], x_mat, std_vec)

def LogitTermWrapAD(par_vec, x_mat, std_vec):
    """Evaluate ``LogitTermAD`` from a flat parameter vector.

    The number of coefficients is taken from the second dimension of
    ``x_mat``; ``par_vec`` is split with ``UnWrap`` accordingly.
    """
    num_coef = x_mat.shape[1]
    unpacked = UnWrap(par_vec, num_coef)
    return LogitTermAD(unpacked[0], unpacked[1], x_mat, std_vec)


In [18]:
# Fix the seed so that re-running the notebook from a fresh kernel regenerates
# the same random problem (and therefore the same printed values).
np.random.seed(0)

N = 1000
K = 5
beta = np.random.rand(K)
beta_cov = np.full((K, K), 0.1) + np.eye(K)
x_mat = np.random.rand(N * K).reshape(N, K)
std_vec = np.array([-0.8, -0.3, 0.3, 0.8])

LogitTerm, SetDataState, SetParamState = LogitTermClosure(beta, beta_cov, x_mat, std_vec)

# Each pair should print the same number: the cached hand-written version
# followed by the plain reference implementation.  The single-argument
# print(...) form gives identical output on Python 2 and parses on Python 3.
print(LogitTerm(beta, beta_cov, x_mat, std_vec))
print(LogitTermAD(beta, beta_cov, x_mat, std_vec))

# Changed parameters: exercises the parameter-cache refresh.
print(LogitTerm(beta + 1, beta_cov, x_mat, std_vec))
print(LogitTermAD(beta + 1, beta_cov, x_mat, std_vec))

# Changed data: exercises the data-cache refresh.
print(LogitTerm(beta + 1, beta_cov, x_mat * 0.1, std_vec))
print(LogitTermAD(beta + 1, beta_cov, x_mat * 0.1, std_vec))


6380.39646398
6380.39646398
Refreshing parameter cache.  (So refreshing.)
14690.1153357
14690.1153357
Refreshing data cache.  (So refreshing.)
3562.45966451
3562.45966451


In [19]:
par_vec = Wrap(beta, beta_cov)

# Round-trip check: unpacking the flat vector must give back the inputs.
# (Previously the UnWrap result was computed and silently discarded.)
beta_roundtrip, beta_cov_roundtrip = UnWrap(par_vec, K)
assert (beta_roundtrip == beta).all()
assert (beta_cov_roundtrip == beta_cov).all()

# Largest absolute difference between autograd's gradient of the reference
# implementation and the gradient built from the hand-written vjps.
print('Grads:')
LogitTermWrapADGrad = grad(LogitTermWrapAD)
LogitTermWrapGrad = grad(LogitTermWrap)
print(np.max(np.abs(LogitTermWrapADGrad(par_vec, x_mat, std_vec) -
                    LogitTermWrapGrad(par_vec, x_mat, std_vec))))


# Same comparison for the Hessians.
print('Hessians:')
LogitTermWrapADHess = hessian(LogitTermWrapAD)
LogitTermWrapHess = hessian(LogitTermWrap)
print(np.max(np.abs(LogitTermWrapADHess(par_vec, x_mat, std_vec) -
                    LogitTermWrapHess(par_vec, x_mat, std_vec))))




Refreshing data cache.  (So refreshing.)
Refreshing parameter cache.  (So refreshing.)
2.50111042988e-12
2.84217094304e-14


In [11]:
import timeit

# Average wall-clock seconds per Hessian evaluation over time_num runs:
# first the hand-written-vjp version, then the plain autograd version.
time_num = 5
print('Hessians:\n')
print(timeit.timeit(lambda: LogitTermWrapHess(par_vec, x_mat, std_vec), number=time_num) / time_num)
print(timeit.timeit(lambda: LogitTermWrapADHess(par_vec, x_mat, std_vec), number=time_num) / time_num)



Hessians:

0.191372203827
0.0229320049286
