In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd 

In [None]:
# Q1: reading and formatting dataset

# Q1.1: read German credit dataset (whitespace-separated, no header row)
# sep=r'\s+' replaces delim_whitespace=True, which is deprecated since pandas 2.2
dataset = pd.read_csv('GermanCredit.txt', sep=r'\s+', header=None)

# Q1.2: credit risk data
m = 800  # number of rows used for training
n = 200  # number of rows used for testing
# fail early with a clear message instead of a shape mismatch further down
assert dataset.shape[0] >= m + n, "dataset has fewer than m + n rows"

# the last column holds the response; subtracting 1 shifts its coding down by
# one (presumably from {1, 2} to {0, 1} -- confirm against the data file)
ytrain = np.array(dataset.iloc[0:m, -1]) - 1
ytest = np.array(dataset.iloc[m:(m + n), -1]) - 1

# Q1.3: center and scale features (the first 24 columns)
# the mean and standard deviation are computed over all rows before the
# train/test split
xdata = np.array(dataset.iloc[:, 0:24])
xdata = (xdata - np.mean(xdata, axis=0)) / np.std(xdata, axis=0)
xtrain = xdata[0:m, :]
xtest = xdata[m:(m + n), :]

# Q1.4: add a leading column of ones
xtrain = np.concatenate((np.ones([m, 1]), xtrain), axis=1)
xtest = np.concatenate((np.ones([n, 1]), xtest), axis=1)

# dimension of the coefficient vector, taken from the design matrix
# (24 features + 1 column of ones = 25) rather than hard-coded
D = xtrain.shape[1]

In [None]:
# install the jax package using pip from the interpreter that runs this kernel
# (sys.executable), so the package lands in the kernel's own environment
import sys
!{sys.executable} -m pip install jax 

In [None]:
# Q5: write log-likelihood function in numpy
def loglikelihood_numpy(beta, x=None, y=None):
    """Logistic-regression log-likelihood evaluated with NumPy.

    Computes ``sum_i [ y_i * (x_i . beta) - log(1 + exp(x_i . beta)) ]``.

    Parameters
    ----------
    beta : array_like
        Coefficient vector; its length must match the number of columns of ``x``.
    x : array_like, optional
        Design matrix. Defaults to the module-level ``xtrain``.
    y : array_like, optional
        Response vector with one entry per row of ``x``. Defaults to the
        module-level ``ytrain``.

    Returns
    -------
    float
        The log-likelihood summed over all rows.
    """
    if x is None:
        x = xtrain
    if y is None:
        y = ytrain
    x_beta = np.matmul(x, beta)
    # logaddexp(0, t) equals log(1 + exp(t)) but does not overflow to inf for
    # large t, which would otherwise turn the whole sum into -inf
    return np.sum(y * x_beta - np.logaddexp(0, x_beta))

# simulate regression coeff from a standard normal; the fixed seed makes the
# printed log-likelihood reproducible across kernel restarts, and the local
# generator leaves NumPy's global random state untouched
rng = np.random.default_rng(0)
beta = rng.standard_normal(D)
print(loglikelihood_numpy(beta))


In [None]:
import jax.numpy as jnp
from jax import jit

# write log-likelihood function in a JAX-compatible way
def loglikelihood_jax(beta, x=None, y=None):
    """Logistic-regression log-likelihood evaluated with jax.numpy.

    Computes ``sum_i [ y_i * (x_i . beta) - log(1 + exp(x_i . beta)) ]``.

    The function is deliberately left undecorated: the code that follows wraps
    it with ``jit(...)`` and ``grad(...)`` itself and compares the compiled and
    uncompiled variants, which is only meaningful if this base function is not
    already jit-compiled.

    Parameters
    ----------
    beta : array_like
        Coefficient vector; its length must match the number of columns of ``x``.
    x : array_like, optional
        Design matrix. Defaults to the module-level ``xtrain``.
    y : array_like, optional
        Response vector with one entry per row of ``x``. Defaults to the
        module-level ``ytrain``.

    Returns
    -------
    jax.Array
        Scalar log-likelihood summed over all rows.
    """
    if x is None:
        x = xtrain
    if y is None:
        y = ytrain
    x_beta = jnp.matmul(x, beta)
    # logaddexp(0, t) equals log(1 + exp(t)) but stays finite for large t,
    # where exp(t) would overflow (especially in JAX's default float32)
    return jnp.sum(y * x_beta - jnp.logaddexp(0.0, x_beta))

# check that we get the same result as the numpy implementation
print(loglikelihood_jax(beta))

# there is no point using JIT for the log-likelihood
jit_loglikelihood_jax = jit(loglikelihood_jax)

import time
%timeit jit_loglikelihood_jax(beta) # timing with JIT 
%timeit loglikelihood_jax(beta) # timing without JIT

# use the implementation without JIT
loglikelihood = loglikelihood_jax


In [None]:
# Q6: auto-diff to get the gradient of the log-likelihood
from jax import grad
gradloglikelihood1 = grad(loglikelihood_jax) # without JIT
gradloglikelihood2 = jit(grad(loglikelihood_jax)) # with JIT 

# for the gradient, there is a signfiicant speedup with JIT
%timeit gradloglikelihood1(beta)
%timeit gradloglikelihood2(beta)

# use the faster implementation
gradloglikelihood = jit(grad(loglikelihood_jax))