This library is a set of personal implementations. It has a few goals: it compiles the data science tools that I use day-to-day in a single place, it's a place for me to practice coding to keep from getting too rusty, and it's a way for me to check my understanding of new models by implementing them.

The main convenience feature over `sklearn` or `statsmodels` is a port of the Python `stargazer` package that can produce latex tables displaying any of the linear models in the package side-by-side. It duplicates part of the `statsmodels` GLM functionality, and provides a general GLM class that the user can initialize with an arbitrary link function and `oryx` distribution.  The backend of the package is written in `jax`. Overhead is higher, but the package will outperform `statsmodels` and `sklearn` in large-sample or high-dimensional cases. See the demo notebook for specifics. 




This notebook walks through the core functionality of the package. 

# Basics and OLS

To begin, let's load the Longly dataset from `statsmodels`. `statjax` follows the `sklearn` convention of `model().fit(X,y)` rather than the `statsmodels` convention of model(y,X).fit(). Models can take dataframe, array, or ModelMatrix (see below) inputs. They will add intercepts unless add_intercept = False is passed in fit. 


In [1]:
import statsmodels.api as sm
import statjax as sj
from statjax import OLS


longley = sm.datasets.longley.load_pandas()
X = longley.exog
y = longley.endog

ols= OLS().fit(X, y)
ols_no_intercept= sj.OLS().fit(X, y, add_intercept=False)

The OLS model supports non-robust, heteroskedasticity-robust, and clustered standard errors following [Cameron and Miller (2015)](https://cameron.econ.ucdavis.edu/research/Cameron_Miller_JHR_2015_February.pdf).

In [2]:
ols_robust = OLS("robust").fit(X,y,)
decade = X["YEAR"] // 10
ols_clustered = OLS("clustered").fit(X, y, decade)

To visualize, use the table functionality, which closely follows the python `stargazer` package – see their documentation for table options. 

In [3]:
ols_table = sj.RegressionTable([ols , ols_robust, ols_clustered])
ols_table.custom_columns(["OLS", "OLS (robust)", "OLS (clustered)"])
ols_table.title("OLS regressions")
display(ols_table)

0,1,2,3
,,,
,(TOTEMP),(TOTEMP),(TOTEMP)
,,,
,OLS,OLS (robust),OLS (clustered)
,(1),(2),(3)
,,,
Intercept,-3482258.637***,-3482258.637***,-3482258.637***
,(4.69e-03),(6.08e-03),(1.38e-03)
GNPDEFL,15.062,15.062,15.062
,(83.113),(87.551),(45.657)


The tables can be converted to latex for display - see demo_table.pdf.

In [4]:
print(ols_table.render_latex())

\begin{tabular}{@{\extracolsep{5pt}}lccc}
\\[-1.8ex]\hline
\hline \\[-1.8ex]
\textit{Dependent variable: } &  \multicolumn{3}{c}{TOTEMP} 
% \\[-1.8ex]
\cr \cline{2-4}
\\[-1.8ex]\\[-1.8ex] & \multicolumn{1}{c}{OLS} & \multicolumn{1}{c}{OLS (robust)} & \multicolumn{1}{c}{OLS (clustered)}  \\
\\[-1.8ex] & (1) & (2) & (3) \\
\hline \\[-1.8ex]
 Intercept & -3482258.637$^{***}$ & -3482258.637$^{***}$ & -3482258.637$^{***}$ \\
& (4.69e-03) & (6.08e-03) & (1.38e-03) \\
 GNPDEFL & 15.062$^{}$ & 15.062$^{}$ & 15.062$^{}$ \\
& (83.113) & (87.551) & (45.657) \\
 GNP & -0.036$^{*}$ & -0.036$^{}$ & -0.036$^{***}$ \\
& (0.019) & (0.022) & (6.48e-03) \\
 UNEMP & -2.020$^{***}$ & -2.020$^{***}$ & -2.020$^{***}$ \\
& (0.268) & (0.300) & (0.073) \\
 ARMED & -1.033$^{***}$ & -1.033$^{***}$ & -1.033$^{***}$ \\
& (0.179) & (0.110) & (0.063) \\
 POP & -0.051$^{}$ & -0.051$^{}$ & -0.051$^{}$ \\
& (0.206) & (0.261) & (0.057) \\
 YEAR & 1829.151$^{***}$ & 1829.151$^{***}$ & 1829.151$^{***}$ \\
& (11.349) & (13.

RegressionTable can take either a single model or a list of models. If arrays are passed, the table will substitute variable names. The table will produce rows for all variables. 

In [5]:
ols_from_arrays = OLS().fit(X.values, y.values)
sj.RegressionTable([ols, ols_from_arrays,ols_no_intercept])

0,1,2,3
,,,
,(TOTEMP),(y),(TOTEMP)
,,,
,(1),(2),(3)
,,,
Intercept,-3482258.637***,-3482258.637***,
,(4.69e-03),(4.69e-03),
GNPDEFL,15.062,,-52.994
,(83.113),,(129.545)
GNP,-0.036*,,0.071**


The final input type supported is formulaic model_matrix. If a model_matrix is passed, the model will predict untransformed data using that model_matrix, meaning that the model will automatically transform untransformed data if passed. The variable names will automatically export. 

In [6]:
from formulaic import model_matrix

y_f,X_f = model_matrix("TOTEMP ~ GNPDEFL + GNP + ARMED:UNEMP", longley.data )

formula_model = sj.OLS().fit(X_f, y_f, add_intercept=False)
print(f"predictions are equal: {bool((formula_model.predict(X_f)==formula_model.predict(longley.data)).prod())}")
display(sj.RegressionTable(formula_model))

predictions are equal: True


0,1
,
,(TOTEMP)
,
,(1)
,
Intercept,49160.443***
,(0.159)
GNPDEFL,39.530***
,(9.652)
GNP,0.037***


 However, it will not support model matrices with different formulas. 

In [7]:
different_formula_X = model_matrix("GNPDEFL + GNP + ARMED*UNEMP", longley.data)

try:
    different_formula_X = model_matrix("GNPDEFL + GNP + ARMED*UNEMP", longley.data)
    formula_model.predict(different_formula_X)
except Exception as e:
    print(f"error: {e}")

error: Predictor matrix has different features than those used to fit the model.


# Linear Models
## GLM

All of the GLM implementations follow Agresti, Foundations of Linear and Generalized Linear Models (2015). The models use either Fisher scoring or iterative least squares to fit $\beta$, then Newton-Raphson to fit any other parameters of the distribution. The `oryx` library provides the infrastructure for random variables.

Normal, Bernoulli, Poisson, Gamma, and Inverse Normal GLMS are supported by default. Certain link functions can be imported from glm as sj.glm.link, currently identity_link, log_link, inverse_link, logit_link, probit_link, and inverse_squared_link. 


In [20]:
from statjax.glm import PoissonGLM, GammaGLM, InverseNormalGLM, NormalGLM

scotland =  sm.datasets.scotland.load()
X = scotland.exog
y = scotland.endog

nglm = NormalGLM().fit(X,y) 

poisson_id = PoissonGLM(link=sj.glm.identity_link).fit(X,y, )
poisson_log = PoissonGLM(link = sj.glm.log_link).fit(X,y)
glm_gamma_new = GammaGLM().fit(X,y)
inv_gauss = InverseNormalGLM().fit(X,y)
inv_gauss2 = InverseNormalGLM(link = sj.glm.identity_link).fit(X,y)

glm_table = sj.RegressionTable([nglm,poisson_id, poisson_log,glm_gamma_new,inv_gauss,inv_gauss2])
glm_table.custom_columns(["Normal", "Poisson (Identity)", "Poisson (Log)", "Gamma", "Inverse Gaussian", "Inverse Gaussian (Identity)"])
glm_table.title("Base GLMS")
glm_table


0,1,2,3,4,5,6
,,,,,,
,(YES),(YES),(YES),(YES),(YES),(YES)
,,,,,,
,Normal,Poisson (Identity),Poisson (Log),Gamma,Inverse Gaussian,Inverse Gaussian (Identity)
,(1),(2),(3),(4),(5),(6)
,,,,,,
Intercept,137.414***,129.525,5.779***,-0.018*,-1.07e-03***,113.638***
,(35.439),(87.264),(1.477),(0.010),(3.40e-04),(35.800)
COUTAX,-0.116**,-0.109,-2.47e-03,4.96e-05***,1.91e-06***,-0.093*
,(0.050),(0.123),(2.09e-03),(1.42e-05),(4.77e-07),(0.050)


The GLM framework allows the definition of custom GLMs from a link function and distribution. The initial guess of parameters must have the correct shape since the GLM will not infer how many parameters or the shape of those parameters. The link function must have the form  $g: S\to \R$ where $S$ is the support of the model, and must be invertible by `oryx`. If it is not invertible, the user can specific a custom inverse: see the `oryx` documentation for more information. 
The first parameter of the specified distribution must be the mean of the distribution or whatever parameter is equal to $g^{-1} (\mathbf X \beta)$. 

By default, the model will use weighted least squares to fit the model. If this fails to train, `sj.glm.fit_glm_gradient` is a more robust alternative. However, it is sensitive to the initial guess of beta, unlike the least squares fitting procedure. 

Note that in the below demonstration, the least squares inverse gaussian model produces a slightly different result: with non-canonical link functions, the least squares and gradient-based algorithms do not necessarily converge to the same solution.


In [9]:
from src.statjax.glm import GLM
from oryx.distributions import Poisson, InverseGaussian
import jax.numpy as jnp

identity = lambda x: x
custom_inverse = lambda x: 1/x

custom_poisson = GLM(link = custom_inverse,
                     dist = Poisson, 
                     params_init = (jnp.zeros([X.shape[1]+1] ) ),
                    )

custom_inverse_gaussian_1 = GLM(link = identity,
                                dist = InverseGaussian,
                                params_init = (jnp.zeros([X.shape[1]+1] )+ 1e-8 , jnp.ones(1)  ),
                               )

custom_inverse_gaussian_2 = GLM(link = identity,
                              dist = InverseGaussian,
                              params_init = (jnp.zeros([X.shape[1]+1] )+ 1e-8 , jnp.ones(1)  ),
                              fit = sj.glm.fit_glm_gradient
                             )

custom_glm_table = sj.RegressionTable([custom_poisson.fit(X,y),
                                    PoissonGLM(link = sj.glm.inverse_link).fit(X,y),
                                    custom_inverse_gaussian_1.fit(X,y),
                                    custom_inverse_gaussian_2.fit(X,y),
                                    InverseNormalGLM(link = sj.glm.identity_link).fit(X,y)])

custom_glm_table.custom_columns(["Custom Poisson", "Poisson", "Custom Inverse Gaussian: least squares", "Custom Inverse Gaussian: gradient", "Inverse Gaussian "])
custom_glm_table.title_text = "Poisson with inverse link and Inverse Gaussian with identity link GLM Comparison"
display(custom_glm_table)

0,1,2,3,4,5
,,,,,
,(YES),(YES),(YES),(YES),(YES)
,,,,,
,Custom Poisson,Poisson,Custom Inverse Gaussian: least squares,Custom Inverse Gaussian: gradient,Inverse Gaussian
,(1),(2),(3),(4),(5)
,,,,,
Intercept,-0.019,-0.019,113.638***,114.002***,114.002***
,(0.025),(0.025),(35.800),(35.794),(35.794)
COUTAX,5.03e-05,5.03e-05,-0.093*,-0.094*,-0.094*
,(3.52e-05),(3.52e-05),(0.050),(0.050),(0.050)


The BernoulliGLM class uses a logit link by default. To use a probit link instead, simply use the method from `statjax.glm`. 

In [10]:

from statjax.glm import BernoulliGLM, probit_link

import sklearn.datasets
X2, y2 = sklearn.datasets.load_breast_cancer(return_X_y=True)
X2= X2[:, [2,3,6,7,8]]

logit_glm = BernoulliGLM().fit(X2,y2)
probit_glm = BernoulliGLM(link = probit_link).fit(X2,y2)
print(f"accuacy of logit and probit glms: {((((logit_glm.predict(X2)) > .5) == y2).mean(), ((probit_glm.predict((X2)) > .5) == y2).mean())}")
sj.RegressionTable([logit_glm, probit_glm])



accuacy of logit and probit glms: (Array(0.91915643, dtype=float32), Array(0.9209139, dtype=float32))


0,1,2
,,
,(y),(y)
,,
,(1),(2)
,,
Intercept,-1.077,-0.940
,(6.610),(3.842)
x0,0.254*,0.156*
,(0.153),(0.089)
x1,-0.027**,-0.016**


## Regularized and Gradient-Based Models

The package also provides access to basic regularized models, as well as functionality for the user to define linear models according to a predict, loss, and regularization function. The user can define arbitrary NLMs by loss, predict, and regularization in a similar way to the GLMs above, but the models tend to be unstable. 

In [11]:
import sklearn.datasets
from statjax.nlm import ElasticNet, LASSO
from statjax import Ridge

X,y = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)
r = sj.RegressionTable([ sj.OLS().fit(X,y), NormalGLM().fit(X,y), ElasticNet(100,100).fit(X,y), Ridge(100).fit(X,y), ElasticNet(0,100).fit(X,y), LASSO(100).fit(X,y)])
r.custom_columns(["OLS", "Normal GLM", "ElasticNet", "Ridge (Analytic)", "Ridge (Gradient)", "LASSO"])
display(r)


0,1,2,3,4,5,6
,,,,,,
,(target),(target),(target),(target),(target),(target)
,,,,,,
,OLS,Normal GLM,ElasticNet,Ridge (Analytic),Ridge (Gradient),LASSO
,(1),(2),(3),(4),(5),(6)
,,,,,,
Intercept,152.133***,152.133***,123.972,124.065,124.065,152.020
,(2.576),(2.544),,,,
age,-10.010,-10.010,2.387,2.897,2.870,-7.16e-04
,(59.749),(59.001),,,,


There's also a default neural network - more on that in the next section. 

# Causal Models

`statjax` currently offers four causal ate estimators. We'll first download the Lalonde dataset. 

In [12]:
import pandas as pd
# https://users.nber.org/~rdehejia/nswdata2.html

columns = ["training",   # Treatment assignment indicator
           "age",        # Age of participant
           "education",  # Years of education
           "black",      # Indicate whether individual is black
           "hispanic",   # Indicate whether individual is hispanic
           "married",    # Indicate whether individual is married
           "no_degree",  # Indicate if individual has no high-school diploma
           "re74",       # Real earnings in 1974, prior to study participation
           "re75",       # Real earnings in 1975, prior to study participation
           "re78"]       # Real earnings in 1978, after study end


file_names = ["http://www.nber.org/~rdehejia/data/nswre74_treated.txt",
              "http://www.nber.org/~rdehejia/data/nswre74_control.txt",
              "http://www.nber.org/~rdehejia/data/psid_controls.txt",
              "http://www.nber.org/~rdehejia/data/psid2_controls.txt",
              "http://www.nber.org/~rdehejia/data/psid3_controls.txt",
              "http://www.nber.org/~rdehejia/data/cps_controls.txt",
              "http://www.nber.org/~rdehejia/data/cps2_controls.txt",
              "http://www.nber.org/~rdehejia/data/cps3_controls.txt"]
files = [pd.read_csv(file_name, sep='\s+', header=None, names=columns) for file_name in file_names]
lalonde = pd.concat(files, ignore_index=True)


  files = [pd.read_csv(file_name, sep='\s+', header=None, names=columns) for file_name in file_names]


`statjax` follows the Rudin causal model, with $D$ indicating treatment status, $X$ indicating covariates, and $Y$ indicating outcomes. We remove points with features outside the range of that feature among the treated points:

In [13]:
from statjax import causal

D = lalonde[["training"]]
X = lalonde[["age", "education", "black", "hispanic", "married", "no_degree", "re74", "re75"]]
Y = lalonde[["re78"]]

in_overlap = causal.check_overlap(D,lalonde[list(X.columns)])

print(f"n violating overlap: {sum(~in_overlap)}")
D = D[in_overlap]
X = X[in_overlap]
Y = Y[in_overlap]

n violating overlap: 7407


The ExperimentalEstimator class assumes random assignment, and simply compares the group means. 

In [14]:
linreg = OLS().fit(D,Y)
linreg_controlled = OLS().fit(jnp.hstack([X.values, D.values]),Y)
print(f"ols coef on treatment w/o controls: {linreg.beta[1]}")
print(f"ols coef on treatment w/ controls: {linreg_controlled.beta[1]}")

naive_model = causal.ExperimentalEstimator().fit(D,Y)
exp_est = naive_model.ate
print(f"naive ate: {exp_est}")

ols coef on treatment w/o controls: -6497.048111659512
ols coef on treatment w/ controls: -107.58057145844275
naive ate: -6497.048111659312


The RegressionEstimator fits two regression models, one for each treatment outcome, and compares the difference in expectation. It defaults to linear estimators. 

In [15]:
regression_model = causal.RegressionEstimator().fit(D,X, Y)
print(f"ols imputation ate: {regression_model.ate}")

ols imputation ate: -2809.682457499978


It allows for more other regression models. While it takes longer to train, here we use the NNRegression model, which is a flexible neural net that can be used as a miscellaneous non-parametric model. 

In [16]:
from statjax.nn import NNRegression

nn_regression_model =  causal.RegressionEstimator(model=NNRegression(hidden_layers = (128,128,64,64)))
nn_regression_model.fit(D,X,Y)
print(f"neural imputation ate: {nn_regression_model.ate}")

neural imputation ate: 26136.701298586253


The PropensityScoreEstimator fits $p(D_i=1|X_i)$  then uses that as inverse weights. The default model is a logistic regression, but the user can specify alternatives.

Note that the model automatically prunes points with $p>1-\delta$ or $p<\delta$, with the default at initialization as $\delta = .1$. 

In [17]:
from jax.nn import sigmoid

logit_model = causal.PropensityScoreEstimator().fit(D,X,Y)
probit_model = causal.PropensityScoreEstimator(propensity_model=BernoulliGLM(link=probit_link)).fit(D,X,Y)

nn_ps_model = causal.PropensityScoreEstimator(propensity_model=NNRegression(hidden_layers = (32,32),output_activation=sigmoid)).fit(D,X,Y)

print(f"logit ate: {logit_model.ate}, probit ate: {probit_model.ate}, nn ate: {nn_ps_model.ate}")

logit ate: 2186.539131282006, probit ate: 2571.336146299483, nn ate: 1320.169593785673


Finally, `statjax` contains a doubly robust estimator, which can take a custom propensity_model and outcome_model. To demonstrate why pruning is important, we can fit a DRE without any pruning ($\delta = 0$) to see the effect on the ATE.  

As above, the bulk of the runtime is the outcome network. 

In [18]:
from statjax.causal import DREstimator
    
dre_model = DREstimator().fit(D,X,Y)
dre_model_unpruned = DREstimator(delta = 0.).fit(D,X,Y)

neural_dre = DREstimator(propensity_model=NNRegression(hidden_layers = (32,32),output_activation=sigmoid)).fit(D,X,Y)
fully_neural_dre = DREstimator(outcome_model=NNRegression(hidden_layers=(128, 128, 64, 64)),
                              propensity_model=NNRegression(hidden_layers = (32,32),output_activation=sigmoid),
                              ).fit(D,X,Y)

print(f"dre ate: {dre_model.ate}\nneural propensity dre ate: {neural_dre.ate}\nfully neural dre ate: {fully_neural_dre.ate}\nunpruned dre ate: {dre_model_unpruned.ate}")

dre ate: 2084.2925401808584
neural propensity dre ate: 1137.4334587308513
fully neural dre ate: 2421.439895598937
unpruned dre ate: 12717764449.22189


The submodels of any causal model can be accessed at model0/model1 for regression-based models and propensity_model for propensity scoring models. Here are all 3 from the DRE:

In [19]:
dre_table = sj.RegressionTable([dre_model.propensity_model, dre_model.model0, dre_model.model1])
dre_table.custom_columns(["P(D=1|X)", "E[Y|X,D=0]", "E[Y|X,D=1]"])
dre_table

0,1,2,3
,,,
,(training),(re78),(re78)
,,,
,P(D=1|X),"E[Y|X,D=0]","E[Y|X,D=1]"
,(1),(2),(3)
,,,
Intercept,-4.183**,4361.918***,-1508.424
,(1.669),(599.002),(6006.589)
age,0.023,-110.543***,83.562
,(0.016),(8.069),(85.387)
