https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation

In [11]:
import pandas as pd
import numpy as np
import seaborn as sns
import statsmodels.formula.api as smf

from sklearn.metrics import r2_score, confusion_matrix
from sklearn.model_selection import cross_val_score, cross_validate
from sklearn.linear_model import LinearRegression

from src.split_data import statmodels_split

In [29]:
df = sns.load_dataset('iris')
df.head(10)

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species
0,5.1,3.5,1.4,0.2,setosa
1,4.9,3.0,1.4,0.2,setosa
2,4.7,3.2,1.3,0.2,setosa
3,4.6,3.1,1.5,0.2,setosa
4,5.0,3.6,1.4,0.2,setosa
5,5.4,3.9,1.7,0.4,setosa
6,4.6,3.4,1.4,0.3,setosa
7,5.0,3.4,1.5,0.2,setosa
8,4.4,2.9,1.4,0.2,setosa
9,4.9,3.1,1.5,0.1,setosa


In [14]:
#let's do linear regression with statsmodels first

In [33]:
formula = 'sepal_length ~ sepal_width + petal_length + petal_width + C(species)'
train, test = statmodels_split(df, random_state=3)
model = smf.ols(formula=formula, data=train).fit()

r2_statsmodels = r2_score(test['sepal_length'],model.predict(test))  #0.8308454946286852 with random_state=3
r2_statsmodels

0.8308454946286856

In [34]:
# Do it with sklearn
df_dummies = pd.get_dummies(df, drop_first=True)
df_dummies.columns

Index(['sepal_length', 'sepal_width', 'petal_length', 'petal_width',
       'species_versicolor', 'species_virginica'],
      dtype='object')

In [35]:
# formula = 'sepal_length ~ sepal_width + petal_length + petal_width + C(species)' #same as above
LHS = formula.split('~')[0].strip() # 'sepal_length'
y = df_dummies[LHS]
X = df_dummies.drop(columns=LHS) # df[['sepal_width', 'petal_length', 'petal_width' + species_dummy_1 + species_dummy_2]]
X_train, X_test, y_train, y_test = train_test_split(X,y, random_state=3)

linear_model = LinearRegression(n_jobs=-1).fit(X_train, y_train)
r2_sklearn = r2_score(y_test,linear_model.predict(X_test))
r2_sklearn

0.8308454946286854

In [36]:
round(r2_statsmodels,5) == round(r2_sklearn,5)

True

So now we know we can get the same r-squared, to within rounding error.
What about 10 fold cross validation?

In [37]:
# start with sklearn

In [38]:
cross_val_sklearn_r2 = np.mean(cross_val_score(linear_model, X_test, y_test, scoring='r2', cv=10, n_jobs=-1))
cross_val_sklearn_r2

0.6636266115633545

In [39]:
LHS = formula.split('~')[0].strip() # 'sepal_length'
y = df_dummies[LHS]
X = df_dummies.drop(columns=LHS) # df[['sepal_width', 'petal_length', 'petal_width' + species_dummy_1 + species_dummy_2]]
# X_train, X_test, y_train, y_test = train_test_split(X,y, random_state=3)


In [40]:
X.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150 entries, 0 to 149
Data columns (total 5 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   sepal_width         150 non-null    float64
 1   petal_length        150 non-null    float64
 2   petal_width         150 non-null    float64
 3   species_versicolor  150 non-null    uint8  
 4   species_virginica   150 non-null    uint8  
dtypes: float64(3), uint8(2)
memory usage: 3.9 KB


In [41]:
X.shape, y.shape

((150, 5), (150,))

In [42]:
X_train.shape, y_train.shape

((112, 5), (112,))

In [43]:
X_test.shape, y_test.shape

((38, 5), (38,))

In [44]:
df_dummies = pd.get_dummies(df, drop_first=True)
df_dummies.columns
LHS = formula.split('~')[0].strip() # 'sepal_length'
y = df_dummies[LHS]
X = df_dummies.drop(columns=LHS)
X_train, X_test, y_train, y_test = train_test_split(X,y, random_state=3)
model = sm.OLS
cross_val_statsmodels_r2 = cross_val_score(SMWrapper(model), 
                                                   X_train, 
                                                   y_train, 
                                                   scoring='r2', 
                                                   cv=10, 
                                                   n_jobs=-1
                                            )
cross_val_statsmodels_r2

array([0.71260406, 0.87968716, 0.80601214, 0.87399924, 0.88859355,
       0.867954  , 0.75752086, 0.82381163, 0.86767947, 0.89468072])