In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score, RepeatedKFold
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

df = pd.read_csv('/Users/chentingkao/PracticeData/housing.csv')
feature_df = df.drop(['MEDV'], axis = 1)
target = df['MEDV']

X_train, X_test, y_train, y_test = train_test_split(feature_df, target, test_size = 0.25)

In [2]:
ridge_model = Ridge(alpha = 1)
ridge_model.fit(X_train, y_train)

# The intercept is the value of the target variable when all the predictors
# are zero.
# intercept 就是那個常數項
ridge_model.intercept_

22.972988767284903

In [3]:
# correlation coefficient
# 可說是各自變數的權重
ridge_model.coef_

array([-0.04379411,  0.03072625, -0.02053928,  1.50380613, -8.87804506,
        4.79488244, -0.01391751, -1.13890216,  0.20222418, -0.01177644,
       -0.83984277,  0.00941658, -0.43911811])

In [4]:
y_predict = ridge_model.predict(X_test)
r2_score(y_test, y_predict)

0.7092900412378031

In [15]:
from sklearn.model_selection import GridSearchCV

# 做三次，每次均分為十份，其中某一份做為 test
cross_validation = RepeatedKFold(n_splits = 10, n_repeats = 3)

# define GridSearch
grid = {'alpha': np.arange(0, 1, 0.1)}
model = Ridge()
search = GridSearchCV(model, grid, scoring = 'neg_mean_absolute_error', cv = cross_validation, n_jobs = -1)

# return a GridSearchCV object
search_results = search.fit(X_train, y_train)
print(search_results.best_score_)
print(search_results.best_params_)

# Get the best model
selected_model = search_results.best_estimator_

# print the correlation coefficient of all features
print(pd.Series(selected_model.coef_, index = X_train.columns))

# make predictions and see how accurate they could be
y_predict = selected_model.predict(X_test)
print(r2_score(y_test, y_predict))


-3.341103417845311
{'alpha': 0.6000000000000001}
CRIM        -0.045433
 ZN          0.030323
 INDUS      -0.012432
 CHAS        1.565749
 NOX       -10.847450
 RM          4.788400
 AGE        -0.012182
 DIS        -1.166753
 RAD         0.206715
 TAX        -0.011560
 PTRATIO    -0.863998
 B           0.009372
 LSTAT      -0.435621
dtype: float64
0.7109356983325323
