## Logit Model with NUTS with / without autodiff from JAX

The details are from [here](https://cran.r-project.org/web/packages/hmclearn/vignettes/logistic_regression_hmclearn.html). 

The log posterior for logistic regression is given by the sum of the log likelihood and the log prior:

$$
\log p(\boldsymbol{\beta} | \mathbf{y}, \mathbf{X}) = \log p(\mathbf{y} | \mathbf{X}, \boldsymbol{\beta}) + \log p(\boldsymbol{\beta})
$$

The log likelihood for logistic regression is given by:

$$
\log p(\mathbf{y} | \mathbf{X}, \boldsymbol{\beta}) = \sum_{i=1}^{n} \left[ y_i \log \left( \frac{1}{1 + \exp(-\mathbf{x}_i^T \boldsymbol{\beta})} \right) + (1 - y_i) \log \left( 1 - \frac{1}{1 + \exp(-\mathbf{x}_i^T \boldsymbol{\beta})} \right) \right]
$$

The log prior for a Gaussian distribution is given by:

$$
\log p(\boldsymbol{\beta}) = -\frac{1}{2} \boldsymbol{\beta}^T \boldsymbol{\Sigma}^{-1} \boldsymbol{\beta} + \text{const}
$$

where $\boldsymbol{\Sigma}$ is the covariance matrix of the Gaussian prior.

The gradient of the log posterior with respect to $\boldsymbol{\beta}$ is given by the sum of the gradient of the log likelihood and the gradient of the log prior:

$$
\nabla \log p(\boldsymbol{\beta} | \mathbf{y}, \mathbf{X}) = \nabla \log p(\mathbf{y} | \mathbf{X}, \boldsymbol{\beta}) + \nabla \log p(\boldsymbol{\beta})
$$

The gradient of the log likelihood with respect to $\boldsymbol{\beta}$ is given by:

$$
\nabla \log p(\mathbf{y} | \mathbf{X}, \boldsymbol{\beta}) = \mathbf{X}^T (\mathbf{y} - \mathbf{p})
$$

where $\mathbf{p} = (p_1, ..., p_n)^T$ and $p_i = 1 / (1 + \exp(-\mathbf{x}_i^T \boldsymbol{\beta}))$.

The gradient of the log prior with respect to $\boldsymbol{\beta}$ is given by:

$$
\nabla \log p(\boldsymbol{\beta}) = -\boldsymbol{\Sigma}^{-1} \boldsymbol{\beta}
$$

In [23]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
from jax.scipy.special import expit
from tqdm import tqdm

class ModelwithNUTS(object):
    
    def __init__(self):
        self.prng_key = None
        self.epsilon = None
    
    def log_likelihood(self):
        raise NotImplementedError
    
    def grad_ll(self):
        raise NotImplementedError
    
    def leapfrog(self, beta, p, epsilon):
        p_half_step = p + epsilon / 2 * self.grad_ll(beta)
        beta_new = beta + epsilon * p_half_step
        p_new = p_half_step + epsilon / 2 * self.grad_ll(beta_new)
        return beta_new, p_new
    
    def build_tree(self, u, v, j, epsilon, beta, p, r, Emax=1000):
        if j == 0:
            # Base case, take one leapfrog step in the direction v
            beta_prime, p_prime = self.leapfrog(beta, p, v*epsilon)
            if u <= jnp.exp(self.log_likelihood(beta_prime) - 0.5*jnp.dot(p_prime, p_prime)):
                n_prime = 1
            else:
                n_prime = 0
            s_prime = int(self.log_likelihood(beta_prime) - 0.5*jnp.dot(p_prime, p_prime) > u - Emax)
            return beta_prime, p_prime, beta_prime, p_prime, beta_prime, n_prime, s_prime
        else:
            # Recursion, build left and right subtrees
            beta_minus, p_minus, beta_plus, p_plus, beta_prime, n_prime, s_prime = self.build_tree(u, v, j-1, epsilon, beta, p, r)
            if s_prime == 1:
                if v == -1:
                    beta_minus, p_minus, _, _, beta_double_prime, n_double_prime, s_double_prime = self.build_tree(u, v, j-1, epsilon, beta_minus, p_minus, r)
                else:
                    _, _, beta_plus, p_plus, beta_double_prime, n_double_prime, s_double_prime = self.build_tree(u, v, j-1, epsilon, beta_plus, p_plus, r)
                if random.uniform(self.prng_key) < n_double_prime / max(n_prime + n_double_prime, 1):
                    beta_prime = beta_double_prime
                if jnp.dot(beta_plus-beta_minus, p_minus) >= 0 and jnp.dot(beta_plus-beta_minus, p_plus) >= 0:
                    s_prime = s_double_prime
                else:
                    s_prime = 0
                n_prime += n_double_prime
            return beta_minus, p_minus, beta_plus, p_plus, beta_prime, n_prime, s_prime
    
    def NUTS(self, current_beta):
        p = random.normal(self.prng_key, shape=current_beta.shape)
        u = random.uniform(self.prng_key, minval=0, maxval=jnp.exp(self.log_likelihood(current_beta) - 0.5*jnp.dot(p, p)))
        beta_minus = beta_plus = beta_prime = current_beta
        p_minus = p_plus = p
        j = 0
        n_prime = s_prime = 1
        r = 1e-10
        while s_prime == 1:
            v = random.choice(self.prng_key, a=jnp.array([-1, 1]))
            if v == -1:
                beta_minus, p_minus, _, _, beta_prime, n_prime, s_prime = self.build_tree(u, v, j, self.epsilon, beta_minus, p_minus, r)
            else:
                _, _, beta_plus, p_plus, beta_prime, n_prime, s_prime = self.build_tree(u, v, j, self.epsilon, beta_plus, p_plus, r)
            r += n_prime
            if s_prime == 1 and random.uniform(self.prng_key) < min(1, n_prime / r):
                current_beta = beta_prime
            j += 1
        return current_beta
    
    
class LogisticRegression(ModelwithNUTS):
    def __init__(self, X, y, initial_beta, seed=0, epsilon=0.01):
        self.X = X
        self.y = y
        self.initial_beta = initial_beta
        self.prng_key = random.PRNGKey(seed)
        self.epsilon = epsilon
        
    def log_likelihood(self, beta):
        z = jnp.dot(self.X, beta)
        return jnp.dot(self.y, z) - jnp.log(1 + jnp.exp(z)).sum()
    
    def grad_ll(self, beta):
        return grad(self.log_likelihood)(beta)
    
    def fit(self, n_sample=100):
        self.beta = [self.initial_beta]
        for i in tqdm(range(n_sample)):
            self.beta.append(self.NUTS(self.beta[-1]))
        return self.beta
    
class LogisticRegression_noAutoDiff(ModelwithNUTS):
    def __init__(self, X, y, initial_beta, seed=0, epsilon=0.01):
        self.X = X
        self.y = y
        self.initial_beta = initial_beta
        self.prng_key = random.PRNGKey(seed)
        self.epsilon = epsilon
        
    def log_likelihood(self, beta):
        z = jnp.dot(self.X, beta)
        return jnp.dot(self.y, z) - jnp.log(1 + jnp.exp(z)).sum()
    
    def grad_ll(self, beta):
        z = jnp.dot(self.X, beta)
        return jnp.dot(self.X.T, (self.y - expit(z)))
    
    def fit(self, n_sample=100):
        self.beta = [self.initial_beta]
        for i in tqdm(range(n_sample)):
            self.beta.append(self.NUTS(self.beta[-1]))
        return self.beta

In [17]:
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

X = np.random.normal(size=(100, 2))
X = np.hstack([np.ones((100, 1)), X])  # add intercept
true_beta = np.array([0, 1, 2])
y = np.random.binomial(1, sigmoid(X @ true_beta))

initial_beta = np.zeros(X.shape[1])

model = LogisticRegression(X, y, jnp.array(initial_beta))

In [20]:
%%time
_ = model.fit(100)

100%|██████████| 100/100 [03:25<00:00,  2.05s/it]

CPU times: user 2min 47s, sys: 1min 16s, total: 4min 3s
Wall time: 3min 25s





In [24]:
model = LogisticRegression_noAutoDiff(X, y, jnp.array(initial_beta))

In [25]:
%%time
_ = model.fit(100)

100%|██████████| 100/100 [01:51<00:00,  1.11s/it]

CPU times: user 1min 9s, sys: 1min 4s, total: 2min 14s
Wall time: 1min 51s



