Generalized additive models (GAMs) with statsmodels

In [1]:
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.gam.api import GLMGam, BSplines
from statsmodels.gam.tests.test_penalized import df_autos
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

The following example illustrates a Gaussian and a Poisson regression in which the categorical variables (fuel and drive) enter as linear terms, while the effects of two continuous explanatory variables (weight and hp) are captured by penalized B-splines.

The data come from the UCI Automobile dataset: https://archive.ics.uci.edu/ml/datasets/automobile

A DataFrame containing a subset of the columns can be loaded directly from the statsmodels unit-test module `statsmodels.gam.tests.test_penalized`.

In [4]:
# Overview of the loaded extract: column dtypes and non-null counts.
df_autos.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 203 entries, 0 to 204
Data columns (total 5 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   city_mpg  203 non-null    int64  
 1   fuel      203 non-null    object 
 2   drive     203 non-null    object 
 3   weight    203 non-null    int64  
 4   hp        203 non-null    float64
dtypes: float64(1), int64(2), object(2)
memory usage: 9.5+ KB


In [5]:
# Frequency of each drive category; drive enters the models below as a categorical term.
df_autos['drive'].value_counts()

fwd    118
rwd     76
4wd      9
Name: drive, dtype: int64

In [6]:
# Frequency of each fuel category; fuel enters the models below as a categorical term.
df_autos['fuel'].value_counts()

gas       183
diesel     20
Name: fuel, dtype: int64

In [None]:
# Summary statistics; describe() reports only the numeric columns by default.
df_autos.describe()

In [None]:
# Box plot of weight, one of the two variables modelled with a spline below.
# The trailing `;` suppresses the echoed Text/Axes repr so only the figure shows.
sns.boxplot(y='weight', data=df_autos).set_title('weight: box plot');

In [None]:
# Letter-value (boxen) plot of weight: shows more tail quantiles than a box plot.
# The trailing `;` suppresses the echoed Text/Axes repr so only the figure shows.
sns.boxenplot(y='weight', data=df_autos).set_title('weight: letter-value (boxen) plot');

In [None]:
# Box plot of hp, the second variable modelled with a spline below.
# The trailing `;` suppresses the echoed Text/Axes repr so only the figure shows.
sns.boxplot(y='hp', data=df_autos).set_title('hp: box plot');

In [None]:
# Letter-value (boxen) plot of hp: shows more tail quantiles than a box plot.
# The trailing `;` suppresses the echoed Text/Axes repr so only the figure shows.
sns.boxenplot(y='hp', data=df_autos).set_title('hp: letter-value (boxen) plot');

In [2]:
# Cubic (degree-3) B-spline bases for the two smooth terms, weight and hp,
# with 12 and 10 degrees of freedom respectively. Column order matters:
# smooth index 0 is weight and index 1 is hp in every model that uses `bs`,
# and each alpha array below supplies one penalty weight per basis.
x_spline = df_autos.loc[:, ['weight', 'hp']]
bs = BSplines(x_spline, degree=[3, 3], df=[12, 10])

In [7]:
# Gaussian GAM: fuel and drive as linear categorical terms, weight and hp as
# penalized spline smooths supplied through `bs`.
formula1 = 'city_mpg ~ C(fuel) + C(drive)'

# Penalization weights, one per smooth term, in the column order of `bs`
# (weight, hp). Larger values force a smoother (closer to linear) fit.
# NOTE(review): presumably taken from the same mgcv-based statsmodels unit
# tests as the Poisson alphas further down -- confirm provenance.
alpha_1 = np.array([21833888.8, 6460.38479])

gam_bs_1 = GLMGam.from_formula(formula1,
                               data=df_autos,
                               smoother=bs,
                               alpha=alpha_1)
res_bs_1 = gam_bs_1.fit()

# Last expression -> rich (HTML) display of the summary instead of print().
res_bs_1.summary()

                 Generalized Linear Model Regression Results                  
Dep. Variable:               city_mpg   No. Observations:                  203
Model:                         GLMGam   Df Residuals:                   189.13
Model Family:                Gaussian   Df Model:                        12.87
Link Function:               identity   Scale:                          4.8825
Method:                         PIRLS   Log-Likelihood:                -441.81
Date:                Mon, 09 Aug 2021   Deviance:                       923.45
Time:                        00:35:28   Pearson chi2:                     923.
No. Iterations:                     3                                         
Covariance Type:            nonrobust                                         
                      coef    std err          z      P>|z|      [0.025      0.975]
-----------------------------------------------------------------------------------
Intercept          51.9923      1.997     

In [8]:
# Fully linear OLS baseline: the same categorical terms as the GAM, with weight
# and hp entering as straight-line effects instead of splines. A distinct name
# keeps this fit available after the later OLS variants are run.
ols_full = smf.ols(formula='city_mpg ~ C(fuel) + C(drive) + weight + hp',
                   data=df_autos).fit()
ols_full.summary()

0,1,2,3
Dep. Variable:,city_mpg,R-squared:,0.788
Model:,OLS,Adj. R-squared:,0.782
Method:,Least Squares,F-statistic:,146.0
Date:,"Mon, 09 Aug 2021",Prob (F-statistic):,2.9000000000000002e-64
Time:,00:35:42,Log-Likelihood:,-512.5
No. Observations:,203,AIC:,1037.0
Df Residuals:,197,BIC:,1057.0
Df Model:,5,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
Intercept,54.3232,2.309,23.523,0.000,49.769,58.877
C(fuel)[T.gas],-7.4126,0.870,-8.522,0.000,-9.128,-5.697
C(drive)[T.fwd],1.7066,1.092,1.562,0.120,-0.448,3.861
C(drive)[T.rwd],1.1618,1.126,1.031,0.304,-1.060,3.383
weight,-0.0074,0.001,-9.119,0.000,-0.009,-0.006
hp,-0.0471,0.010,-4.833,0.000,-0.066,-0.028

0,1,2,3
Omnibus:,60.405,Durbin-Watson:,1.286
Prob(Omnibus):,0.0,Jarque-Bera (JB):,241.199
Skew:,1.114,Prob(JB):,4.21e-53
Kurtosis:,7.853,Cond. No.,30800.0


In [9]:
# Linear model without drive: neither drive dummy was significant at the 5%
# level in the full linear fit, so check how much fit is lost by dropping it.
ols_no_drive = smf.ols(formula='city_mpg ~ C(fuel) + weight + hp',
                       data=df_autos).fit()
ols_no_drive.summary()

0,1,2,3
Dep. Variable:,city_mpg,R-squared:,0.785
Model:,OLS,Adj. R-squared:,0.781
Method:,Least Squares,F-statistic:,241.6
Date:,"Mon, 09 Aug 2021",Prob (F-statistic):,4.6000000000000004e-66
Time:,00:37:33,Log-Likelihood:,-513.9
No. Observations:,203,AIC:,1036.0
Df Residuals:,199,BIC:,1049.0
Df Model:,3,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
Intercept,56.9187,1.687,33.737,0.000,53.592,60.246
C(fuel)[T.gas],-7.5871,0.862,-8.806,0.000,-9.286,-5.888
weight,-0.0078,0.001,-10.677,0.000,-0.009,-0.006
hp,-0.0459,0.010,-4.790,0.000,-0.065,-0.027

0,1,2,3
Omnibus:,57.16,Durbin-Watson:,1.283
Prob(Omnibus):,0.0,Jarque-Bera (JB):,229.938
Skew:,1.043,Prob(JB):,1.17e-50
Kurtosis:,7.779,Cond. No.,22000.0


In [11]:
# Single-predictor linear model: city_mpg explained by weight alone.
ols_weight = smf.ols(formula='city_mpg ~ weight', data=df_autos).fit()
ols_weight.summary()

0,1,2,3
Dep. Variable:,city_mpg,R-squared:,0.575
Model:,OLS,Adj. R-squared:,0.573
Method:,Least Squares,F-statistic:,271.7
Date:,"Mon, 09 Aug 2021",Prob (F-statistic):,3.5e-39
Time:,00:38:22,Log-Likelihood:,-582.92
No. Observations:,203,AIC:,1170.0
Df Residuals:,201,BIC:,1176.0
Df Model:,1,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
Intercept,49.5768,1.507,32.901,0.000,46.606,52.548
weight,-0.0095,0.001,-16.483,0.000,-0.011,-0.008

0,1,2,3
Omnibus:,31.886,Durbin-Watson:,1.451
Prob(Omnibus):,0.0,Jarque-Bera (JB):,44.85
Skew:,0.935,Prob(JB):,1.82e-10
Kurtosis:,4.344,Cond. No.,13000.0


In [None]:
# Partial effect of smooth 0 (weight, first column of the spline basis) in the
# Gaussian GAM, with component-plus-residual points (cpr=True).
# plot_partial returns the Figure; the trailing `;` stops Jupyter from
# rendering it a second time through the Figure's repr.
res_bs_1.plot_partial(0, cpr=True);

In [None]:
# Partial effect of smooth 1 (hp, second column of the spline basis) in the
# Gaussian GAM, with component-plus-residual points (cpr=True).
# plot_partial returns the Figure; the trailing `;` stops Jupyter from
# rendering it a second time through the Figure's repr.
res_bs_1.plot_partial(1, cpr=True);

Poisson Regression

In [None]:
# Penalization weights for the Poisson GAM, one per smooth term (weight, hp),
# taken from the statsmodels unit tests that compare against R's mgcv package.
alpha_2 = np.array([8283989284.5829611, 14628207.58927821])

# Same linear part as the Gaussian model. Writing C() explicitly keeps the
# coefficient labels (e.g. "C(fuel)[T.gas]") consistent with formula1.
gam_bs_2 = GLMGam.from_formula('city_mpg ~ C(fuel) + C(drive)',
                               data=df_autos,
                               smoother=bs,
                               alpha=alpha_2,
                               family=sm.families.Poisson())
res_bs_2 = gam_bs_2.fit()

# Last expression -> rich (HTML) display of the summary instead of print().
res_bs_2.summary()

In [None]:
# Data-driven penalization weights for the Poisson GAM, for comparison with the
# mgcv-derived alpha_2 used above:
#   * select_penweight minimizes an information criterion (AIC, made explicit
#     here) over alpha;
#   * select_penweight_kfold picks alpha by k-fold cross-validation.
# Both searches are randomized (basinhopping optimizer; shuffled folds) and draw
# from NumPy's global RNG, so seed it to make the printed values reproducible.
PENWEIGHT_SEED = 0
np.random.seed(PENWEIGHT_SEED)

# Each selector returns (alpha, optimizer/CV details); only alpha is shown.
alpha_aic, _ = gam_bs_2.select_penweight(criterion='aic')
alpha_kfold, _ = gam_bs_2.select_penweight_kfold()
print('AIC-selected alpha:  ', alpha_aic)
print('k-fold CV alpha:     ', alpha_kfold)


Source:
- https://www.statsmodels.org/devel/gam.html