#### Introduction to Statistical Learning Lab 7.2

# Generalized Additive Models 

We now fit a GAM to predict `wage` from the `Wage` dataset using natural spline
functions of `year` and `age`, treating `education` as a qualitative predictor. Since
this is just a big linear regression model using an appropriate choice of
basis functions, we can fit it with the `ols` function from `statsmodels.formula.api` (imported below as `smf`).
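
To make the "big linear regression on basis functions" idea concrete, here is a minimal sketch on synthetic data. It uses a plain cubic regression spline in the truncated-power basis rather than a natural spline, and the helper `cubic_spline_basis`, the knot locations, and the simulated data are all assumptions chosen for illustration, not part of the lab.

```python
import numpy as np

# Synthetic data (an assumption for illustration, not the Wage data)
rng = np.random.default_rng(0)
x = np.sort(rng.uniform(0, 10, 200))
y = np.sin(x) + rng.normal(scale=0.1, size=x.size)

def cubic_spline_basis(x, knots):
    """Truncated-power basis: 1, x, x^2, x^3, then (x - k)_+^3 for each knot k."""
    cols = [np.ones_like(x), x, x**2, x**3]
    cols += [np.clip(x - k, 0, None) ** 3 for k in knots]
    return np.column_stack(cols)

knots = [2.5, 5.0, 7.5]              # hypothetical knot locations
X = cubic_spline_basis(x, knots)     # design matrix: one column per basis function

# Ordinary least squares on the expanded design matrix
beta, *_ = np.linalg.lstsq(X, y, rcond=None)
fitted = X @ beta
```

Once the basis columns are built, the fit is ordinary least squares; the nonlinearity lives entirely in the design matrix.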

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import statsmodels.api as sm
from statsmodels.gam.api import GLMGam, BSplines
import statsmodels.formula.api as smf
from sklearn.preprocessing import scale
from islpy import datasets
sns.set()
%matplotlib inline

In [None]:
Wage = datasets.Wage()

In [None]:
Wage.describe()

Let's visualize the data before modeling it:

In [None]:
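# Pass x and y by keyword: recent seaborn versions no longer accept them positionally
sns.scatterplot(x='age', y='wage', data=Wage)
plt.show()
sns.scatterplot(x='year', y='wage', data=Wage)
plt.show()
education_order = ['1. < HS Grad', '2. HS Grad', '3. Some College',
                   '4. College Grad', '5. Advanced Degree']
sns.boxplot(x='education', y='wage', data=Wage, order=education_order)
plt.xticks(rotation=90)
plt.show()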

Let's fit an OLS regression of `wage` on `year`, `age`, and `education`:

In [None]:
lm = smf.ols(formula='wage ~ year + age + education', data=Wage).fit()
lm.summary()
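
The summary above reports an R-squared value. As a reminder of what that number measures, the sketch below computes it by hand for a small least-squares fit on synthetic data (the data and variable names are assumptions for illustration, not the `Wage` fit).

```python
import numpy as np

# Synthetic data (an assumption for illustration)
rng = np.random.default_rng(1)
n = 100
x = rng.normal(size=n)
y = 1.0 + 2.0 * x + rng.normal(size=n)

# Least-squares fit with an intercept
X = np.column_stack([np.ones(n), x])
beta, *_ = np.linalg.lstsq(X, y, rcond=None)
resid = y - X @ beta

# R-squared = 1 - RSS / TSS: the share of variance in y explained by the fit
rss = resid @ resid
tss = np.sum((y - y.mean()) ** 2)
r2 = 1 - rss / tss
```

For a fitted statsmodels OLS result such as `lm`, the same quantity is available as `lm.rsquared`.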

Judging by the R-squared, this may not be a good fit, so we turn to a GAM. We build a B-spline basis with `BSplines` and pass it to `GLMGam` as the smoother, using `df=4` with cubic splines for `year` and `df=6` with degree-5 splines for `age`. `education` enters the model as a set of dummy variables.

In [None]:
x_spline = Wage[['year', 'age']]
# df and degree take one entry per column of x_spline: year first, then age
bs = BSplines(x_spline, df=[4, 6], degree=[3, 5], include_intercept=False)
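
For intuition about what a B-spline basis looks like, here is a minimal sketch of the Cox-de Boor recursion in plain NumPy. The function `bspline_basis` and the knot vector are hypothetical and written only for illustration; this is not how statsmodels builds its basis internally, and knot placement there may differ.

```python
import numpy as np

def bspline_basis(x, knots, degree):
    """Evaluate all B-spline basis functions of the given degree at x
    using the Cox-de Boor recursion."""
    x = np.asarray(x, dtype=float)
    t = np.asarray(knots, dtype=float)
    # Degree 0: indicators of the half-open knot intervals [t_i, t_{i+1})
    B = np.array([(t[i] <= x) & (x < t[i + 1]) for i in range(len(t) - 1)],
                 dtype=float).T
    for d in range(1, degree + 1):
        n_funcs = len(t) - 1 - d
        B_next = np.zeros((len(x), n_funcs))
        for i in range(n_funcs):
            left_den = t[i + d] - t[i]
            right_den = t[i + d + 1] - t[i + 1]
            # Terms with a zero denominator (repeated knots) are dropped
            if left_den > 0:
                B_next[:, i] += (x - t[i]) / left_den * B[:, i]
            if right_den > 0:
                B_next[:, i] += (t[i + d + 1] - x) / right_den * B[:, i + 1]
        B = B_next
    return B

# Cubic basis on [0, 1) with one interior knot at 0.5 (illustrative choice)
knots = [0, 0, 0, 0, 0.5, 1, 1, 1, 1]
x = np.linspace(0, 0.99, 50)
B = bspline_basis(x, knots, degree=3)
```

Each row of `B` holds the values of the five basis functions at one point. The functions are non-negative and sum to one on the interval, which is why a full B-spline basis already contains a constant.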

`bs` defines how the spline basis is constructed. Now let's fit the model for `wage`:

In [None]:
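# year and age enter only through the smoother; repeating them as linear terms
# would be collinear with the spline basis. education is the dummy-coded part.
gam_bs = GLMGam.from_formula('wage ~ education', data=Wage, smoother=bs)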

In [None]:
res_bs = gam_bs.fit()
res_bs.summary()

Finally, let's visualize the fitted smooth terms:

In [None]:
# Smooth terms are indexed in the column order of x_spline: 0 is year, 1 is age
res_bs.plot_partial(1, cpr=True, plot_se=False)
plt.show()
res_bs.plot_partial(0, cpr=True, plot_se=False)
plt.show()
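
The `cpr=True` option overlays component-plus-residual (partial residual) points on each smooth. As a rough sketch of the idea, the example below computes partial residuals for one term of an ordinary linear model on synthetic data. The data and names are assumptions for illustration, and the GAM plot uses working residuals from the fitted GLM rather than this simplified version.

```python
import numpy as np

# Synthetic data with two predictors (an assumption for illustration)
rng = np.random.default_rng(2)
n = 200
x1 = rng.normal(size=n)
x2 = rng.normal(size=n)
y = 0.5 + 2.0 * x1 - 1.0 * x2 + rng.normal(scale=0.5, size=n)

# Fit y ~ 1 + x1 + x2 by least squares
X = np.column_stack([np.ones(n), x1, x2])
beta, *_ = np.linalg.lstsq(X, y, rcond=None)
resid = y - X @ beta

# Partial residuals for x1: the fitted x1 component plus the model residuals
component_x1 = beta[1] * x1
partial_resid_x1 = component_x1 + resid

# Regressing the partial residuals on x1 recovers the fitted slope for x1
Z = np.column_stack([np.ones(n), x1])
gamma, *_ = np.linalg.lstsq(Z, partial_resid_x1, rcond=None)
```

Plotting `partial_resid_x1` against `x1` shows how well the fitted component tracks the data after the other terms have been accounted for.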
