## CatBoost-Optunaのサンプルコード

In [1]:
# ライブラリーのインポート
import os

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

# ボストンの住宅価格データ
from sklearn.datasets import load_boston

# 前処理
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# CatBoost
import catboost as  cb
from catboost import CatBoost, Pool

# Optuna
import optuna
from sklearn.model_selection import cross_val_score

# 評価指標
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

In [2]:
# データセットの読込み
boston = load_boston()

# 説明変数の格納
df = pd.DataFrame(boston.data, columns = boston.feature_names)
# 目的変数の追加
df['MEDV'] = boston.target

# データの中身を確認
df.head()

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,MEDV
0,0.00632,18.0,2.31,0.0,0.538,6.575,65.2,4.09,1.0,296.0,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0.0,0.469,6.421,78.9,4.9671,2.0,242.0,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0.0,0.469,7.185,61.1,4.9671,2.0,242.0,17.8,392.83,4.03,34.7
3,0.03237,0.0,2.18,0.0,0.458,6.998,45.8,6.0622,3.0,222.0,18.7,394.63,2.94,33.4
4,0.06905,0.0,2.18,0.0,0.458,7.147,54.2,6.0622,3.0,222.0,18.7,396.9,5.33,36.2


#### 前処理

In [3]:
# ランダムシード値
RANDOM_STATE = 10

# 学習データと評価データの割合
TEST_SIZE = 0.2

# 学習データと評価データを作成
x_train, x_test, y_train, y_test = train_test_split(df.iloc[:, 0:df.shape[1]-1],
                                                    df.iloc[:, df.shape[1]-1],
                                                    test_size=TEST_SIZE,
                                                    random_state=RANDOM_STATE)

In [4]:
def objective(trial):
    
    iterations = trial.suggest_int('iterations', 50, 300)
    depth = trial.suggest_int('depth', 4, 10)                                       
    learning_rate = trial.suggest_loguniform('learning_rate', 0.01, 0.3)               
    random_strength = trial.suggest_int('random_strength', 0, 100)   
    bagging_temperature = trial.suggest_loguniform('bagging_temperature', 0.01, 100.00) 
    od_type = trial.suggest_categorical('od_type', ['IncToDec', 'Iter'])
    od_wait = trial.suggest_int('od_wait', 10, 50)


    model = cb.CatBoostRegressor(iterations = iterations,
                                depth = depth,
                                learning_rate = learning_rate,
                                random_strength = random_strength,
                                bagging_temperature = bagging_temperature,
                                od_type = od_type,
                                od_wait = od_wait
                               )

    score = cross_val_score(model,
                            x_train,
                            y_train,
                            cv=5,
                            scoring="neg_mean_absolute_error"
                           )
    mae = score.mean()

    return mae

In [5]:
%%time
# optunaで最適値を見つける
# 注：cross_val_scoreの出力は全て高いほど良い
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=10)

[32m[I 2021-05-08 11:20:18,224][0m A new study created in memory with name: no-name-91941c11-8bfc-4154-83f3-08cfbbe330bf[0m


0:	learn: 8.6626503	total: 67.8ms	remaining: 11.1s
1:	learn: 8.5623648	total: 106ms	remaining: 8.58s
2:	learn: 8.4176540	total: 125ms	remaining: 6.71s
3:	learn: 8.2814829	total: 142ms	remaining: 5.67s
4:	learn: 8.1851065	total: 150ms	remaining: 4.77s
5:	learn: 8.0372811	total: 158ms	remaining: 4.16s
6:	learn: 7.9070159	total: 168ms	remaining: 3.76s
7:	learn: 7.8259450	total: 177ms	remaining: 3.46s
8:	learn: 7.7198170	total: 209ms	remaining: 3.6s
9:	learn: 7.6018988	total: 227ms	remaining: 3.49s
10:	learn: 7.5532044	total: 233ms	remaining: 3.23s
11:	learn: 7.4432000	total: 246ms	remaining: 3.11s
12:	learn: 7.3556810	total: 257ms	remaining: 2.98s
13:	learn: 7.2507997	total: 279ms	remaining: 2.99s
14:	learn: 7.1231800	total: 283ms	remaining: 2.81s
15:	learn: 7.0604626	total: 295ms	remaining: 2.72s
16:	learn: 6.9631341	total: 312ms	remaining: 2.69s
17:	learn: 6.9297835	total: 322ms	remaining: 2.61s
18:	learn: 6.8164309	total: 341ms	remaining: 2.6s
19:	learn: 6.7171372	total: 363ms	remainin

[32m[I 2021-05-08 11:20:42,576][0m Trial 0 finished with value: -2.615471131074588 and parameters: {'iterations': 164, 'depth': 8, 'learning_rate': 0.040408940463150064, 'random_strength': 33, 'bagging_temperature': 7.962117362539803, 'od_type': 'IncToDec', 'od_wait': 33}. Best is trial 0 with value: -2.615471131074588.[0m


163:	learn: 2.4999194	total: 1.88s	remaining: 0us
0:	learn: 8.7864622	total: 105ms	remaining: 19.7s
1:	learn: 8.7529965	total: 177ms	remaining: 16.5s
2:	learn: 8.7187124	total: 213ms	remaining: 13.2s
3:	learn: 8.6744350	total: 267ms	remaining: 12.3s
4:	learn: 8.6289577	total: 327ms	remaining: 12s
5:	learn: 8.5811550	total: 360ms	remaining: 11s
6:	learn: 8.5398266	total: 415ms	remaining: 10.8s
7:	learn: 8.5006758	total: 475ms	remaining: 10.8s
8:	learn: 8.4821331	total: 484ms	remaining: 9.69s
9:	learn: 8.4425252	total: 524ms	remaining: 9.38s
10:	learn: 8.4173073	total: 565ms	remaining: 9.15s
11:	learn: 8.3755295	total: 623ms	remaining: 9.19s
12:	learn: 8.3331662	total: 698ms	remaining: 9.45s
13:	learn: 8.2955190	total: 752ms	remaining: 9.4s
14:	learn: 8.2655189	total: 809ms	remaining: 9.39s
15:	learn: 8.2260441	total: 874ms	remaining: 9.45s
16:	learn: 8.1842770	total: 877ms	remaining: 8.87s
17:	learn: 8.1444271	total: 932ms	remaining: 8.85s
18:	learn: 8.1073193	total: 995ms	remaining: 8.

[32m[I 2021-05-08 11:21:23,210][0m Trial 1 finished with value: -3.6018846243816265 and parameters: {'iterations': 189, 'depth': 10, 'learning_rate': 0.011177969351553116, 'random_strength': 91, 'bagging_temperature': 0.13909608951480798, 'od_type': 'Iter', 'od_wait': 40}. Best is trial 0 with value: -2.615471131074588.[0m


186:	learn: 4.5318671	total: 5.49s	remaining: 58.7ms
187:	learn: 4.5161975	total: 5.53s	remaining: 29.4ms
188:	learn: 4.5006065	total: 5.56s	remaining: 0us
0:	learn: 8.2109488	total: 18.4ms	remaining: 1.45s
1:	learn: 7.8720039	total: 34.7ms	remaining: 1.35s
2:	learn: 7.4031691	total: 41.7ms	remaining: 1.07s
3:	learn: 7.0157002	total: 53.2ms	remaining: 1.01s
4:	learn: 6.7818350	total: 64.1ms	remaining: 961ms
5:	learn: 6.3485132	total: 74.4ms	remaining: 918ms
6:	learn: 6.0499550	total: 84.5ms	remaining: 881ms
7:	learn: 5.8074383	total: 94.6ms	remaining: 851ms
8:	learn: 5.6064880	total: 102ms	remaining: 804ms
9:	learn: 5.3837384	total: 112ms	remaining: 785ms
10:	learn: 5.3238155	total: 117ms	remaining: 732ms
11:	learn: 5.0469699	total: 128ms	remaining: 723ms
12:	learn: 4.9278733	total: 140ms	remaining: 721ms
13:	learn: 4.7150686	total: 168ms	remaining: 793ms
14:	learn: 4.5513574	total: 172ms	remaining: 743ms
15:	learn: 4.4328180	total: 174ms	remaining: 697ms
16:	learn: 4.3198357	total: 18

[32m[I 2021-05-08 11:21:37,154][0m Trial 2 finished with value: -2.3571034605682053 and parameters: {'iterations': 80, 'depth': 8, 'learning_rate': 0.15842858557625358, 'random_strength': 19, 'bagging_temperature': 0.14000335974578265, 'od_type': 'IncToDec', 'od_wait': 36}. Best is trial 2 with value: -2.3571034605682053.[0m


77:	learn: 1.2326552	total: 973ms	remaining: 25ms
78:	learn: 1.2128095	total: 981ms	remaining: 12.4ms
79:	learn: 1.1955573	total: 1s	remaining: 0us
0:	learn: 8.7618507	total: 37ms	remaining: 6.03s
1:	learn: 8.7056841	total: 64ms	remaining: 5.19s
2:	learn: 8.6601608	total: 80.8ms	remaining: 4.34s
3:	learn: 8.6104367	total: 96.9ms	remaining: 3.87s
4:	learn: 8.5516475	total: 116ms	remaining: 3.69s
5:	learn: 8.4974626	total: 132ms	remaining: 3.48s
6:	learn: 8.4467533	total: 137ms	remaining: 3.07s
7:	learn: 8.3960352	total: 156ms	remaining: 3.05s
8:	learn: 8.3325630	total: 195ms	remaining: 3.35s
9:	learn: 8.2933373	total: 225ms	remaining: 3.47s
10:	learn: 8.2331282	total: 239ms	remaining: 3.33s
11:	learn: 8.1731573	total: 254ms	remaining: 3.21s
12:	learn: 8.1196801	total: 272ms	remaining: 3.16s
13:	learn: 8.0889081	total: 275ms	remaining: 2.95s
14:	learn: 8.0396279	total: 291ms	remaining: 2.89s
15:	learn: 7.9809738	total: 306ms	remaining: 2.83s
16:	learn: 7.9350768	total: 325ms	remaining: 2

[32m[I 2021-05-08 11:22:21,126][0m Trial 3 finished with value: -3.266895399183845 and parameters: {'iterations': 164, 'depth': 9, 'learning_rate': 0.01534361648520562, 'random_strength': 22, 'bagging_temperature': 10.05358866896086, 'od_type': 'IncToDec', 'od_wait': 39}. Best is trial 2 with value: -2.3571034605682053.[0m


160:	learn: 3.9034744	total: 8.36s	remaining: 156ms
161:	learn: 3.8957700	total: 8.38s	remaining: 103ms
162:	learn: 3.8833099	total: 8.39s	remaining: 51.5ms
163:	learn: 3.8759552	total: 8.4s	remaining: 0us
0:	learn: 8.5127127	total: 48.6ms	remaining: 12.1s
1:	learn: 8.2417718	total: 77.2ms	remaining: 9.53s
2:	learn: 7.9777969	total: 103ms	remaining: 8.48s
3:	learn: 7.6432274	total: 129ms	remaining: 7.92s
4:	learn: 7.3233362	total: 171ms	remaining: 8.33s
5:	learn: 7.0127594	total: 236ms	remaining: 9.57s
6:	learn: 6.7497497	total: 264ms	remaining: 9.13s
7:	learn: 6.5263058	total: 295ms	remaining: 8.88s
8:	learn: 6.4419056	total: 302ms	remaining: 8.05s
9:	learn: 6.2432769	total: 326ms	remaining: 7.8s
10:	learn: 6.1447895	total: 363ms	remaining: 7.86s
11:	learn: 5.9261589	total: 414ms	remaining: 8.17s
12:	learn: 5.7336020	total: 571ms	remaining: 10.4s
13:	learn: 5.5751819	total: 621ms	remaining: 10.4s
14:	learn: 5.4607474	total: 737ms	remaining: 11.5s
15:	learn: 5.3214180	total: 847ms	rema

[32m[I 2021-05-08 11:23:33,813][0m Trial 4 finished with value: -2.2351302007517675 and parameters: {'iterations': 249, 'depth': 10, 'learning_rate': 0.09907036538210659, 'random_strength': 73, 'bagging_temperature': 0.013642546811938504, 'od_type': 'Iter', 'od_wait': 22}. Best is trial 4 with value: -2.2351302007517675.[0m


0:	learn: 8.7447824	total: 54.3ms	remaining: 13.9s
1:	learn: 8.6865001	total: 88.8ms	remaining: 11.3s
2:	learn: 8.6270172	total: 115ms	remaining: 9.75s
3:	learn: 8.5505083	total: 141ms	remaining: 8.95s
4:	learn: 8.4711315	total: 169ms	remaining: 8.54s
5:	learn: 8.3897161	total: 207ms	remaining: 8.67s
6:	learn: 8.3195415	total: 250ms	remaining: 8.94s
7:	learn: 8.2526589	total: 282ms	remaining: 8.77s
8:	learn: 8.2220436	total: 285ms	remaining: 7.84s
9:	learn: 8.1562894	total: 313ms	remaining: 7.72s
10:	learn: 8.1095981	total: 383ms	remaining: 8.57s
11:	learn: 8.0381595	total: 421ms	remaining: 8.59s
12:	learn: 7.9687696	total: 484ms	remaining: 9.09s
13:	learn: 7.9100631	total: 552ms	remaining: 9.58s
14:	learn: 7.8618843	total: 601ms	remaining: 9.7s
15:	learn: 7.7986326	total: 643ms	remaining: 9.69s
16:	learn: 7.7314496	total: 663ms	remaining: 9.36s
17:	learn: 7.6680001	total: 700ms	remaining: 9.29s
18:	learn: 7.6094938	total: 730ms	remaining: 9.15s
19:	learn: 7.5672535	total: 755ms	remain

[32m[I 2021-05-08 11:24:37,527][0m Trial 5 finished with value: -2.7427893671987174 and parameters: {'iterations': 257, 'depth': 10, 'learning_rate': 0.019633093134444342, 'random_strength': 49, 'bagging_temperature': 0.5282616678863579, 'od_type': 'Iter', 'od_wait': 33}. Best is trial 4 with value: -2.2351302007517675.[0m


0:	learn: 8.7506395	total: 56.7ms	remaining: 9.65s
1:	learn: 8.6820674	total: 110ms	remaining: 9.31s
2:	learn: 8.6059129	total: 174ms	remaining: 9.77s
3:	learn: 8.5381596	total: 213ms	remaining: 8.89s
4:	learn: 8.4642968	total: 255ms	remaining: 8.46s
5:	learn: 8.3839643	total: 294ms	remaining: 8.09s
6:	learn: 8.3216146	total: 307ms	remaining: 7.2s
7:	learn: 8.2537648	total: 371ms	remaining: 7.55s
8:	learn: 8.1826575	total: 401ms	remaining: 7.22s
9:	learn: 8.1180727	total: 433ms	remaining: 6.96s
10:	learn: 8.0446443	total: 463ms	remaining: 6.74s
11:	learn: 7.9848173	total: 493ms	remaining: 6.53s
12:	learn: 7.9214489	total: 546ms	remaining: 6.63s
13:	learn: 7.8541351	total: 577ms	remaining: 6.47s
14:	learn: 7.7892439	total: 604ms	remaining: 6.28s
15:	learn: 7.7229817	total: 618ms	remaining: 5.99s
16:	learn: 7.6753460	total: 650ms	remaining: 5.89s
17:	learn: 7.6019820	total: 674ms	remaining: 5.73s
18:	learn: 7.5270949	total: 710ms	remaining: 5.68s
19:	learn: 7.4677363	total: 759ms	remaini

[32m[I 2021-05-08 11:25:37,095][0m Trial 6 finished with value: -2.8868714861971876 and parameters: {'iterations': 171, 'depth': 10, 'learning_rate': 0.01827034750739825, 'random_strength': 7, 'bagging_temperature': 31.081597106467587, 'od_type': 'Iter', 'od_wait': 10}. Best is trial 4 with value: -2.2351302007517675.[0m


168:	learn: 3.1138040	total: 5.21s	remaining: 61.7ms
169:	learn: 3.1019961	total: 5.28s	remaining: 31.1ms
170:	learn: 3.0898844	total: 5.3s	remaining: 0us
0:	learn: 8.7430354	total: 10.9ms	remaining: 3.25s
1:	learn: 8.6452339	total: 23.1ms	remaining: 3.41s
2:	learn: 8.5538923	total: 27.9ms	remaining: 2.74s
3:	learn: 8.4633771	total: 32.9ms	remaining: 2.42s
4:	learn: 8.3583996	total: 39.5ms	remaining: 2.31s
5:	learn: 8.2570420	total: 49.5ms	remaining: 2.41s
6:	learn: 8.1591193	total: 53.9ms	remaining: 2.24s
7:	learn: 8.0579486	total: 59ms	remaining: 2.14s
8:	learn: 8.0086942	total: 61.8ms	remaining: 1.99s
9:	learn: 7.9489879	total: 65.9ms	remaining: 1.9s
10:	learn: 7.8441325	total: 74.7ms	remaining: 1.95s
11:	learn: 7.7716810	total: 81.6ms	remaining: 1.94s
12:	learn: 7.7116515	total: 84.4ms	remaining: 1.85s
13:	learn: 7.6588296	total: 89.8ms	remaining: 1.82s
14:	learn: 7.5815176	total: 103ms	remaining: 1.94s
15:	learn: 7.5165321	total: 107ms	remaining: 1.89s
16:	learn: 7.4297093	total: 

[32m[I 2021-05-08 11:26:19,152][0m Trial 7 finished with value: -2.51825892634249 and parameters: {'iterations': 298, 'depth': 7, 'learning_rate': 0.028257638621714754, 'random_strength': 96, 'bagging_temperature': 0.29926149630720833, 'od_type': 'Iter', 'od_wait': 25}. Best is trial 4 with value: -2.2351302007517675.[0m


295:	learn: 2.3257934	total: 5.16s	remaining: 34.8ms
296:	learn: 2.3186051	total: 5.17s	remaining: 17.4ms
297:	learn: 2.3139066	total: 5.18s	remaining: 0us
0:	learn: 8.6735747	total: 6.53ms	remaining: 1.09s
1:	learn: 8.5506409	total: 14.2ms	remaining: 1.18s
2:	learn: 8.4762616	total: 29.7ms	remaining: 1.64s
3:	learn: 8.3684248	total: 32ms	remaining: 1.31s
4:	learn: 8.2700932	total: 36.7ms	remaining: 1.2s
5:	learn: 8.1293966	total: 40.7ms	remaining: 1.1s
6:	learn: 8.0128414	total: 47.5ms	remaining: 1.09s
7:	learn: 7.8992551	total: 50ms	remaining: 1000ms
8:	learn: 7.8419506	total: 56.3ms	remaining: 994ms
9:	learn: 7.7206747	total: 57.5ms	remaining: 908ms
10:	learn: 7.5966416	total: 75.5ms	remaining: 1.08s
11:	learn: 7.5563173	total: 82.6ms	remaining: 1.07s
12:	learn: 7.4598603	total: 91.8ms	remaining: 1.09s
13:	learn: 7.3734589	total: 97.1ms	remaining: 1.07s
14:	learn: 7.3379685	total: 99.8ms	remaining: 1.02s
15:	learn: 7.2218445	total: 107ms	remaining: 1.01s
16:	learn: 7.1713981	total: 

[32m[I 2021-05-08 11:26:41,258][0m Trial 8 finished with value: -2.7931911424749165 and parameters: {'iterations': 168, 'depth': 4, 'learning_rate': 0.033586375571097005, 'random_strength': 28, 'bagging_temperature': 0.4270458362857897, 'od_type': 'IncToDec', 'od_wait': 34}. Best is trial 4 with value: -2.2351302007517675.[0m


167:	learn: 3.6252480	total: 503ms	remaining: 0us
0:	learn: 8.4253072	total: 9.71ms	remaining: 2.13s
1:	learn: 8.1144414	total: 15.7ms	remaining: 1.71s
2:	learn: 7.8414385	total: 45.5ms	remaining: 3.29s
3:	learn: 7.5823463	total: 58.4ms	remaining: 3.15s
4:	learn: 7.2845833	total: 68.8ms	remaining: 2.96s
5:	learn: 7.0211487	total: 79.9ms	remaining: 2.85s
6:	learn: 6.7796356	total: 86.7ms	remaining: 2.64s
7:	learn: 6.5436003	total: 97.8ms	remaining: 2.59s
8:	learn: 6.4433533	total: 108ms	remaining: 2.53s
9:	learn: 6.3236865	total: 113ms	remaining: 2.36s
10:	learn: 6.0955547	total: 120ms	remaining: 2.27s
11:	learn: 5.9627317	total: 125ms	remaining: 2.16s
12:	learn: 5.8592409	total: 138ms	remaining: 2.19s
13:	learn: 5.7788469	total: 141ms	remaining: 2.07s
14:	learn: 5.6429292	total: 144ms	remaining: 1.97s
15:	learn: 5.5478319	total: 147ms	remaining: 1.87s
16:	learn: 5.4058874	total: 151ms	remaining: 1.8s
17:	learn: 5.3231056	total: 158ms	remaining: 1.77s
18:	learn: 5.2503641	total: 169ms	r

[32m[I 2021-05-08 11:27:09,842][0m Trial 9 finished with value: -2.1046695187554985 and parameters: {'iterations': 220, 'depth': 7, 'learning_rate': 0.09707351397968994, 'random_strength': 33, 'bagging_temperature': 0.15607369505514285, 'od_type': 'Iter', 'od_wait': 11}. Best is trial 9 with value: -2.1046695187554985.[0m


CPU times: user 5min 50s, sys: 34.5 s, total: 6min 24s
Wall time: 6min 51s


In [6]:
# チューニングしたハイパーパラメーターをフィット
optimised_model = cb.CatBoostRegressor(iterations = study.best_params['iterations'],
                                       depth = study.best_params['depth'],
                                       learning_rate = study.best_params['learning_rate'],
                                       random_strength = study.best_params['random_strength'],
                                       bagging_temperature = study.best_params['bagging_temperature'],
                                       od_type = study.best_params['od_type'],
                                       od_wait = study.best_params['od_wait']
                                      )

optimised_model.fit(x_train ,y_train)

# CatBoost推論
y_pred = optimised_model.predict(x_test)

0:	learn: 8.3623300	total: 6.48ms	remaining: 1.42s
1:	learn: 8.0608052	total: 25.2ms	remaining: 2.74s
2:	learn: 7.8113851	total: 35.8ms	remaining: 2.59s
3:	learn: 7.5663558	total: 43.7ms	remaining: 2.36s
4:	learn: 7.2677175	total: 67.9ms	remaining: 2.92s
5:	learn: 6.9998095	total: 79.7ms	remaining: 2.84s
6:	learn: 6.7199944	total: 95.2ms	remaining: 2.9s
7:	learn: 6.4699148	total: 100ms	remaining: 2.65s
8:	learn: 6.3908418	total: 107ms	remaining: 2.51s
9:	learn: 6.2780960	total: 119ms	remaining: 2.51s
10:	learn: 6.0485992	total: 129ms	remaining: 2.45s
11:	learn: 5.9301835	total: 155ms	remaining: 2.69s
12:	learn: 5.8313952	total: 171ms	remaining: 2.73s
13:	learn: 5.7607332	total: 187ms	remaining: 2.75s
14:	learn: 5.6313833	total: 202ms	remaining: 2.76s
15:	learn: 5.4846757	total: 210ms	remaining: 2.68s
16:	learn: 5.3466353	total: 222ms	remaining: 2.65s
17:	learn: 5.2590406	total: 228ms	remaining: 2.56s
18:	learn: 5.1738165	total: 231ms	remaining: 2.44s
19:	learn: 5.0676711	total: 237ms	r

In [7]:
# 評価
def calculate_scores(true, pred):
    """全ての評価指標を計算する

    Parameters
    ----------
    true (np.array)       : 実測値
    pred (np.array)       : 予測値

    Returns
    -------
    scores (pd.DataFrame) : 各評価指標を纏めた結果

    """
    scores = {}
    scores = pd.DataFrame({'R2': r2_score(true, pred),
                          'MAE': mean_absolute_error(true, pred),
                          'MSE': mean_squared_error(true, pred),
                          'RMSE': np.sqrt(mean_squared_error(true, pred))},
                           index = ['scores'])
    return scores

In [8]:
scores = calculate_scores(y_test, y_pred)
print(scores)

              R2       MAE        MSE      RMSE
scores  0.879188  2.535765  12.634581  3.554516
