In [21]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import glob

from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

%matplotlib inline

In [22]:
df0 = pd.read_csv("boston.csv", index_col=0)
df1 = df0.reset_index(drop=True)

In [23]:
X = df1.drop("medv", axis=1)
y = df1["medv"]

In [24]:
# KFoldの設定
kf = KFold(n_splits=5, shuffle=True, random_state=42)

In [25]:
list_RMSE = []
for i, (train_index, test_index) in enumerate(kf.split(X)):
    X_train, X_test = X.loc[train_index], X.loc[test_index]
    y_train, y_test = y.loc[train_index], y.loc[test_index]
    
    reg = RandomForestRegressor()
    
    reg.fit(X_train, y_train)
    
    name = "model_" + str(i) + ".pickle"
    with open(name, mode = "wb") as f:
        pickle.dump(reg, f)
        
    y_pred = reg.predict(X_test)
    
    RMSE = np.sqrt(mean_squared_error(y_test, y_pred))
    print(RMSE)
    
    list_RMSE.append(RMSE)

2.7859494386133066
3.3208660104566
4.09500817344655
2.98439004498307
2.8716366430836935


In [26]:
np.mean(list_RMSE)

3.211570062116644

In [28]:
models = glob.glob("*.pickle")
models

['model_0.pickle',
 'model_1.pickle',
 'model_2.pickle',
 'model_3.pickle',
 'model_4.pickle']

In [31]:
for model_ in models:
    with open(model_, mode="br") as f:
        model = pickle.load(f)
        
    pred = model.predict(X_test)

    print(pred)

[21.865 20.294 14.044 19.218 13.852 24.487 15.497 19.498 20.141 24.964
 22.491 19.103 23.343 21.858 27.907 21.926 26.134 22.204 32.835 18.749
 19.921 19.449 21.133 19.18  15.706 14.962 19.621 26.545 48.457 49.681
 22.459 22.012 30.225 28.847 34.687 34.957 30.328 23.773 22.649 22.146
 26.883 23.306 20.421 23.932 25.576 31.117 48.182 30.291 30.96  35.864
 20.703 21.197 38.624 34.056 23.539 28.126 32.851 33.915 26.154 20.841
 22.049 17.646 21.299 19.909 20.702 19.344 25.247 30.72  18.686 25.201
 23.148 22.105 16.9   21.424 42.618  7.62   7.659 11.662  8.136 14.847
  9.375 11.747 14.378 12.657 15.045 11.46  14.712 15.368 18.756 18.088
 20.482 21.732 21.779 17.661 21.237 18.881 16.427 23.237 21.51  19.987
 14.827]
[21.854 20.528 14.219 19.592 13.92  24.521 15.92  19.552 20.188 25.169
 22.526 18.914 23.224 21.717 27.255 22.174 25.894 22.104 32.955 18.969
 19.02  18.795 20.469 19.543 15.393 15.107 20.046 26.138 48.222 49.517
 22.155 22.301 30.637 29.263 34.359 34.848 29.962 23.631 22.429 21.4