In [62]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import statsmodels.api as sma
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import mean_squared_error

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## hold-outで線形回帰を学習&評価

In [40]:
df = sns.load_dataset('tips')
df.head()

Unnamed: 0,total_bill,tip,sex,smoker,day,time,size
0,16.99,1.01,Female,No,Sun,Dinner,2
1,10.34,1.66,Male,No,Sun,Dinner,3
2,21.01,3.5,Male,No,Sun,Dinner,3
3,23.68,3.31,Male,No,Sun,Dinner,2
4,24.59,3.61,Female,No,Sun,Dinner,4


In [60]:
y_col = 'tip'
y = df[y_col]
X = df.drop(columns=y_col)
numerical_columns = list(X.select_dtypes(include='number'))

# 特徴量をダミー変数にする
X = pd.get_dummies(X, drop_first=True)

# hold-out
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# 標準化
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train[numerical_columns])
X_train[numerical_columns] = X_train_scaled
X_test[numerical_columns] = scaler.transform(X_test[numerical_columns])

# 線形回帰
model = LinearRegression()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# 学習データとテストデータの精度
mean_squared_error(y_test, y_pred)

0.955080898861715

# K-Fold CVで線形回帰を学習&評価

In [83]:
y = df[y_col]
X = df['total_bill'].values.reshape(-1, 1)
k = 5
cv = KFold(n_splits=k, random_state=0, shuffle=True)
model = LinearRegression()
mse_list = []

for train_index, test_index in cv.split(X):
  X_train, X_test = X[train_index], X[test_index]
  y_train, y_test = y[train_index], y[test_index]
  # モデル学習
  model.fit(X_train, y_train)
  y_pred = model.predict(X_test)
  # MSE
  mse = mean_squared_error(y_test, y_pred)
  mse_list.append(mse)

print(f"MSE({k}FoldCV): {np.mean(mse_list)})")
print(f"std: {np.std(mse_list)})")

[[16.99]
 [10.34]
 [21.01]
 [23.68]
 [24.59]
 [ 8.77]
 [26.88]
 [14.78]
 [10.27]
 [35.26]
 [18.43]
 [14.83]
 [10.33]
 [16.29]
 [20.65]
 [17.92]
 [20.29]
 [39.42]
 [19.82]
 [17.81]
 [13.37]
 [12.69]
 [21.7 ]
 [19.65]
 [ 9.55]
 [18.35]
 [15.06]
 [20.69]
 [17.78]
 [24.06]
 [16.31]
 [18.69]
 [31.27]
 [16.04]
 [17.46]
 [13.94]
 [ 9.68]
 [22.23]
 [32.4 ]
 [28.55]
 [18.04]
 [12.54]
 [10.29]
 [34.81]
 [ 9.94]
 [25.56]
 [38.01]
 [26.41]
 [11.24]
 [48.27]
 [20.29]
 [13.81]
 [11.02]
 [20.08]
 [16.45]
 [ 3.07]
 [20.23]
 [15.01]
 [12.02]
 [26.86]
 [10.51]
 [27.2 ]
 [22.76]
 [17.29]
 [19.44]
 [16.66]
 [10.07]
 [32.68]
 [15.98]
 [34.83]
 [13.03]
 [18.28]
 [24.71]
 [21.16]
 [28.97]
 [22.49]
 [16.32]
 [22.75]
 [40.17]
 [27.28]
 [12.03]
 [21.01]
 [12.46]
 [11.35]
 [15.38]
 [44.3 ]
 [22.42]
 [15.36]
 [20.49]
 [14.31]
 [38.07]
 [23.95]
 [25.71]
 [17.31]
 [29.93]
 [10.65]
 [24.08]
 [11.69]
 [13.42]
 [14.26]
 [15.95]
 [ 8.52]
 [14.52]
 [11.38]
 [22.82]
 [19.08]
 [20.27]
 [11.17]
 [12.26]
 [18.26]
 [ 8.51]
 