In [1]:
import numpy as np
import scipy.stats

In [2]:
import pandas as pd
pd.set_option("display.precision", 3)

In [3]:
df = pd.read_csv('CA_deaths_2015-2022.csv')

In [4]:
df.head()

Unnamed: 0,2015deaths,2016deaths,2017deaths,2018deaths,2019deaths,2020deaths,2021deaths,2022deaths
0,7.0,8.0,14.0,19.0,22.0,24.0,19.0,17.0
1,5.0,5.0,13.0,9.0,8.0,24.0,11.0,16.0
2,5.0,7.0,9.0,10.0,13.0,8.0,12.0,16.0
3,3.0,7.0,5.0,4.0,7.0,13.0,8.0,11.0
4,4.0,6.0,8.0,8.0,8.0,11.0,8.0,8.0


In [5]:
df.tail()

Unnamed: 0,2015deaths,2016deaths,2017deaths,2018deaths,2019deaths,2020deaths,2021deaths,2022deaths
1323,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
1324,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1325,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1326,0.0,1.0,0.0,0.0,0.0,0.0,0.0,2.0
1327,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [6]:
df.values.max()

24.0

In [7]:
A = df.values[:]

In [8]:
A[:4]

array([[ 7.,  8., 14., 19., 22., 24., 19., 17.],
       [ 5.,  5., 13.,  9.,  8., 24., 11., 16.],
       [ 5.,  7.,  9., 10., 13.,  8., 12., 16.],
       [ 3.,  7.,  5.,  4.,  7., 13.,  8., 11.]])

In [9]:
T = 2 # num test
L = 2 # num lookback


x_list = []
y_list = []

x_te_list = []
y_te_list = []

for p in range(L, A.shape[1]):
    x_N2 = A[:, (p-L):p]
    y_N = A[:, p][:,np.newaxis]

    if p > A.shape[1] - T:
        # test set
        x_te_list.append(x_N2)
        y_te_list.append(y_N)
    else:
        x_list.append(x_N2)
        y_list.append(y_N)
        

In [10]:
train_x_N2 = np.vstack(x_list)
train_y_N1 = np.vstack(y_list)

test_x_N2 = np.vstack(x_te_list)
test_y_N1 = np.vstack(y_te_list)

In [11]:
import sklearn.linear_model

In [12]:
model = sklearn.linear_model.LinearRegression()

In [13]:
model.fit(train_x_N2, train_y_N1)

LinearRegression()

In [14]:
model.coef_

array([[0.42367552, 0.47476729]])

In [15]:
model.intercept_

array([0.27662864])

In [16]:
yhat_N1 = model.predict(test_x_N2)

In [17]:
yhat_N1.max()

19.465419706589834

In [18]:
yhat_N1.min()

0.2766286440512751

In [19]:
test_x_N2[:5]

array([[24., 19.],
       [24., 11.],
       [ 8., 12.],
       [13.,  8.],
       [11.,  8.]])

In [20]:
yhat_N1[:5]

array([[19.46541971],
       [15.66728136],
       [ 9.36324033],
       [ 9.58254876],
       [ 8.73519771]])

## Try to build jax model

In [21]:
import numpy as np

from jax import config
config.update("jax_enable_x64", True)
config.update('jax_platform_name', 'cpu')

import jax
import jax.numpy as jnp


import scipy.optimize

In [22]:
def calc_neg_log_lik(theta, x_N2, y_N, to_pos=jnp.exp, reduce=True):
    N = x_N2.shape[0]
    y_N = jnp.reshape(y_N, (N,))
    mu_N = to_pos(theta[0] + jnp.dot(x_N2, theta[1:]))
    assert mu_N.shape == (N,)
    assert y_N.shape == mu_N.shape
    loglik_N = -mu_N + y_N * jnp.log(mu_N)
    if reduce:
        return -1.0 * jnp.mean(loglik_N)
    else:
        return -1.0 * loglik_N

class MyPoissonGLM():
    
    def __init__(self, init_theta=np.zeros(3), to_pos=jnp.exp, theta=None):
        self.to_pos = to_pos
        self.init_theta = init_theta
        if theta is not None:
            self.theta = theta
            self.coef_ = theta[1:].copy()
            self.intercept_ = theta[0].copy()
        
    def fit(self, x_N2, y_N):
        calc_grad = jax.grad(calc_neg_log_lik, argnums=[0])

        def f(theta, *args):
            return np.asarray(calc_neg_log_lik(theta, *args), dtype=np.float64, order='F').item()
        def g(theta, *args):
            return np.asarray(calc_grad(theta, *args)[0], dtype=np.float64, order='F').copy()
        ans = scipy.optimize.minimize(
            f, self.init_theta, args=(x_N2, y_N, self.to_pos), jac=g,
            method='L-BFGS-B',
            options={'ftol':1e-13, 'gtol':1e-14})
        self.ans = ans
        self.theta = ans.x
        self.intercept_ = ans.x[0].copy()
        self.coef_ = ans.x[1:].copy()
        return self
        
    def predict(self, x_N2):
        return self.to_pos(self.intercept_ + np.dot(x_N2, self.coef_))
    
    def score(self, x_N2, y_N, eval_logpmf_method='scipy'):
        if eval_logpmf_method.count('scipy'):
            mu_N = self.predict(x_N2)
            return np.mean([-1.0 * scipy.stats.poisson(mu).logpmf(y) for (mu, y) in zip(mu_N, y_N)])
        else:
            return calc_neg_log_lik(self.theta, x_N2, y_N, to_pos=self.to_pos).item()

# Common evaluation metrics

In [23]:
def make_model_perf_df(model, method_name=None, eval_logpmf_method='scipy'):
    yhat_stats_list = list()
    for split_label, x, y in [("train", train_x_N2, train_y_N1), ("test", test_x_N2, test_y_N1)]:

        rmse = np.sqrt(np.mean(np.square(y-model.predict(x))))
        mae = np.mean(np.abs(y-model.predict(x)))
        neglogpmf = model.score(x, y, eval_logpmf_method)

        yhat_df = pd.DataFrame(model.predict(x))
        yhat_stats_df = yhat_df.describe(percentiles=[0.01, 0.1, 0.9, 0.99]).T.copy()
        
        yhat_stats_df.insert(0, 'rmse', rmse)
        yhat_stats_df.insert(0, 'mae', mae)
        yhat_stats_df.insert(0, 'neglogpmf', neglogpmf)
        yhat_stats_df.insert(0, 'split', split_label)
        if method_name is not None:
            yhat_stats_df.insert(0, 'method', method_name)
        yhat_stats_list.append(yhat_stats_df)
        
    df = pd.concat(yhat_stats_list).reset_index(drop=True).copy()
    return df

In [24]:
eval_logpmf_method = 'scipy'

In [25]:
cols_to_print = ['method', 'split', 'neglogpmf', 'mae', 'rmse', 'min', 'mean', 'max']

## sklearn PoissonRegressor model with alpha = 0.0 (corresponds to max lik on train)

In [26]:
sk_model_alph0 = sklearn.linear_model.PoissonRegressor(alpha=0.0)

In [27]:
sk_model_alph0.fit(train_x_N2, train_y_N1[:,0])

PoissonRegressor(alpha=0.0)

In [28]:
my_sk0_model = MyPoissonGLM(theta=np.hstack([sk_model_alph0.intercept_, sk_model_alph0.coef_]))
perf_sk_alpha0 = make_model_perf_df(my_sk0_model, 'sk_alpha0', eval_logpmf_method)
print(perf_sk_alpha0[cols_to_print])

      method  split  neglogpmf    mae   rmse    min   mean      max
0  sk_alpha0  train      1.389  1.123  2.633  0.789  1.034  134.872
1  sk_alpha0   test      1.577  1.489  3.322  0.789  1.256   83.123


## sklearn PoissonRegressor model with alpha default

In [29]:
sk_model_alph1 = sklearn.linear_model.PoissonRegressor(alpha=1.0)

In [30]:
sk_model_alph1.fit(train_x_N2, train_y_N1[:,0])

PoissonRegressor()

In [31]:
my_sk1_model = MyPoissonGLM(theta=np.hstack([sk_model_alph1.intercept_, sk_model_alph1.coef_]))
perf_sk_alpha1 = make_model_perf_df(my_sk1_model, 'sk_alpha1', eval_logpmf_method)
print(perf_sk_alpha1[cols_to_print])

      method  split  neglogpmf    mae   rmse    min   mean      max
0  sk_alpha1  train      1.390  1.116  2.468  0.798  1.034  120.073
1  sk_alpha1   test      1.576  1.475  3.176  0.798  1.247   76.787


## Flat model (intercept set to train mean, coefs all zero)

In [32]:
flat_model = MyPoissonGLM(theta=np.asarray([np.log(np.mean(train_y_N1)), 0., 0.]))

In [33]:
perf_flat = make_model_perf_df(flat_model, 'flat', eval_logpmf_method)
print(perf_flat[cols_to_print])

  method  split  neglogpmf    mae   rmse    min   mean    max
0   flat  train      1.594  1.037  1.657  1.034  1.034  1.034
1   flat   test      1.977  1.264  2.079  1.034  1.034  1.034


## Existing model

>  learned coefficients are [0.13583314, 0.13693048] and the intercept is -0.457

In [34]:
cur_model = MyPoissonGLM(theta=np.asarray([-0.457, 0.1358, 0.1369]))

In [35]:
perf_cur = make_model_perf_df(cur_model, 'fromslack', eval_logpmf_method)
print(perf_cur[cols_to_print])

      method  split  neglogpmf    mae   rmse    min   mean      max
0  fromslack  train      1.422  1.166  5.006  0.633  0.939  335.694
1  fromslack   test      1.700  1.665  6.867  0.633  1.293  222.138


## Our JAX model with link = softplus

In [36]:
our_sp_model = MyPoissonGLM(to_pos=jax.nn.softplus)
our_sp_model.fit(train_x_N2, train_y_N1)

<__main__.MyPoissonGLM at 0x7fedc8ced5b0>

In [37]:
perf_ours_sp = make_model_perf_df(our_sp_model, 'ours_sp', eval_logpmf_method)
print(perf_ours_sp[cols_to_print])

    method  split  neglogpmf    mae   rmse    min   mean     max
0  ours_sp  train      1.290  1.221  1.967  0.469  1.031  24.399
1  ours_sp   test      1.422  1.674  2.669  0.469  1.436  22.542


### Our JAX with link = log (should match the sklearn implementation exactly)

In [38]:
our_log_model = MyPoissonGLM(to_pos=jnp.exp)

In [39]:
our_log_model.fit(train_x_N2, train_y_N1)

<__main__.MyPoissonGLM at 0x7fedca9939a0>

In [40]:
perf_ours_log = make_model_perf_df(our_log_model, 'ours_log', eval_logpmf_method)
print(perf_ours_log[cols_to_print])

     method  split  neglogpmf    mae   rmse    min   mean      max
0  ours_log  train      1.389  1.123  2.633  0.789  1.034  134.872
1  ours_log   test      1.577  1.489  3.322  0.789  1.256   83.123


In [41]:
agg_df = pd.concat([perf_cur, perf_sk_alpha0, perf_ours_log, perf_ours_sp]).copy()

print(agg_df.query("split == 'test'")[cols_to_print].to_string(index=False))

   method split  neglogpmf   mae  rmse   min  mean     max
fromslack  test      1.700 1.665 6.867 0.633 1.293 222.138
sk_alpha0  test      1.577 1.489 3.322 0.789 1.256  83.123
 ours_log  test      1.577 1.489 3.322 0.789 1.256  83.123
  ours_sp  test      1.422 1.674 2.669 0.469 1.436  22.542


## Final numbers

(just copy-pasted from above)

```
                       link  negloglik   MAE  RMSE max(yhat)
copied from slack       log      1.700 1.665 6.867  222.138
using sklearn PoiReg    log      1.577 1.489 3.322   83.123
using our JAX           log      1.577 1.489 3.322   83.123
using our JAX      softplus      1.422 1.674 2.669   22.542
```