## CatBoost-Optunaのサンプルコード

In [1]:
# ライブラリーのインポート
import os

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

# ボストンの住宅価格データ
from sklearn.datasets import load_boston

# 前処理
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# CatBoost
import catboost as  cb
from catboost import CatBoost, Pool

# Optuna
import optuna
from sklearn.model_selection import cross_val_score

# 評価指標
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

In [2]:
# データセットの読込み
boston = load_boston()

# 説明変数の格納
df = pd.DataFrame(boston.data, columns = boston.feature_names)
# 目的変数の追加
df['MEDV'] = boston.target

# データの中身を確認
df.head()

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,MEDV
0,0.00632,18.0,2.31,0.0,0.538,6.575,65.2,4.09,1.0,296.0,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0.0,0.469,6.421,78.9,4.9671,2.0,242.0,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0.0,0.469,7.185,61.1,4.9671,2.0,242.0,17.8,392.83,4.03,34.7
3,0.03237,0.0,2.18,0.0,0.458,6.998,45.8,6.0622,3.0,222.0,18.7,394.63,2.94,33.4
4,0.06905,0.0,2.18,0.0,0.458,7.147,54.2,6.0622,3.0,222.0,18.7,396.9,5.33,36.2


#### 前処理

In [3]:
# ランダムシード値
RANDOM_STATE = 10

# 学習データと評価データの割合
TEST_SIZE = 0.2

# 学習データと評価データを作成
x_train, x_test, y_train, y_test = train_test_split(df.iloc[:, 0:df.shape[1]-1],
                                                    df.iloc[:, df.shape[1]-1],
                                                    test_size=TEST_SIZE,
                                                    random_state=RANDOM_STATE)

In [4]:
def objective(trial):
    
    iterations = trial.suggest_int('iterations', 50, 300)
    depth = trial.suggest_int('depth', 4, 10)                                       
    learning_rate = trial.suggest_loguniform('learning_rate', 0.01, 0.3)               
    random_strength = trial.suggest_int('random_strength', 0, 100)   
    bagging_temperature = trial.suggest_loguniform('bagging_temperature', 0.01, 100.00) 
    od_type = trial.suggest_categorical('od_type', ['IncToDec', 'Iter'])
    od_wait = trial.suggest_int('od_wait', 10, 50)


    model = cb.CatBoostRegressor(iterations = iterations,
                                depth = depth,
                                learning_rate = learning_rate,
                                random_strength = random_strength,
                                bagging_temperature = bagging_temperature,
                                od_type = od_type,
                                od_wait = od_wait
                               )

    score = cross_val_score(model,
                            x_train,
                            y_train,
                            cv=5,
                            scoring="neg_mean_absolute_error"
                           )
    mae = score.mean()

    return mae

In [5]:
%%time
# optunaで最適値を見つける
# 注：cross_val_scoreの出力は全て高いほど良い
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=10)

[32m[I 2021-05-16 02:37:52,956][0m A new study created in memory with name: no-name-456ccd5f-fc2e-41db-9899-7aedfe8113fc[0m


0:	learn: 8.6374276	total: 62.1ms	remaining: 4.41s
1:	learn: 8.5210440	total: 64.1ms	remaining: 2.24s
2:	learn: 8.4429531	total: 65.9ms	remaining: 1.52s
3:	learn: 8.3318635	total: 68.9ms	remaining: 1.17s
4:	learn: 8.2393306	total: 70.8ms	remaining: 949ms
5:	learn: 8.1435998	total: 72.5ms	remaining: 798ms
6:	learn: 8.0317041	total: 83ms	remaining: 771ms
7:	learn: 7.9249001	total: 84.4ms	remaining: 676ms
8:	learn: 7.7964491	total: 86.4ms	remaining: 605ms
9:	learn: 7.6461657	total: 89.7ms	remaining: 556ms
10:	learn: 7.5218745	total: 94.8ms	remaining: 526ms
11:	learn: 7.4679788	total: 100ms	remaining: 502ms
12:	learn: 7.3726608	total: 102ms	remaining: 463ms
13:	learn: 7.2575991	total: 104ms	remaining: 430ms
14:	learn: 7.1791690	total: 105ms	remaining: 401ms
15:	learn: 7.0716827	total: 107ms	remaining: 375ms
16:	learn: 7.0229334	total: 110ms	remaining: 355ms
17:	learn: 6.9542098	total: 112ms	remaining: 335ms
18:	learn: 6.9116055	total: 114ms	remaining: 318ms
19:	learn: 6.8590174	total: 117m

[32m[I 2021-05-16 02:38:02,991][0m Trial 0 finished with value: -3.3788407250472714 and parameters: {'iterations': 72, 'depth': 5, 'learning_rate': 0.04378701466670342, 'random_strength': 55, 'bagging_temperature': 15.708883855131923, 'od_type': 'IncToDec', 'od_wait': 12}. Best is trial 0 with value: -3.3788407250472714.[0m


68:	learn: 4.5251607	total: 194ms	remaining: 8.45ms
69:	learn: 4.5140418	total: 196ms	remaining: 5.61ms
70:	learn: 4.4865565	total: 202ms	remaining: 2.85ms
71:	learn: 4.4455706	total: 205ms	remaining: 0us
0:	learn: 8.7203559	total: 47.2ms	remaining: 6.84s
1:	learn: 8.6418447	total: 87.9ms	remaining: 6.33s
2:	learn: 8.5458857	total: 128ms	remaining: 6.09s
3:	learn: 8.4415937	total: 160ms	remaining: 5.67s
4:	learn: 8.3199805	total: 181ms	remaining: 5.09s
5:	learn: 8.2085471	total: 200ms	remaining: 4.67s
6:	learn: 8.1639454	total: 203ms	remaining: 4.02s
7:	learn: 8.0625425	total: 235ms	remaining: 4.04s
8:	learn: 7.9336910	total: 249ms	remaining: 3.79s
9:	learn: 7.8624784	total: 291ms	remaining: 3.96s
10:	learn: 7.7546864	total: 401ms	remaining: 4.92s
11:	learn: 7.6388785	total: 422ms	remaining: 4.71s
12:	learn: 7.5365311	total: 459ms	remaining: 4.7s
13:	learn: 7.5126392	total: 461ms	remaining: 4.34s
14:	learn: 7.4210514	total: 490ms	remaining: 4.28s
15:	learn: 7.3164326	total: 547ms	remai

[32m[I 2021-05-16 02:38:20,777][0m Trial 1 finished with value: -2.8928640825189915 and parameters: {'iterations': 146, 'depth': 9, 'learning_rate': 0.03325189967779843, 'random_strength': 77, 'bagging_temperature': 9.42099631416822, 'od_type': 'IncToDec', 'od_wait': 11}. Best is trial 1 with value: -2.8928640825189915.[0m


140:	learn: 2.9646340	total: 1.97s	remaining: 70ms
141:	learn: 2.9597824	total: 1.99s	remaining: 56ms
142:	learn: 2.9496901	total: 1.99s	remaining: 41.8ms
143:	learn: 2.9322020	total: 2s	remaining: 27.8ms
144:	learn: 2.9252477	total: 2.02s	remaining: 13.9ms
145:	learn: 2.9072511	total: 2.04s	remaining: 0us
0:	learn: 8.5217402	total: 36.6ms	remaining: 9.74s
1:	learn: 8.2580887	total: 55.5ms	remaining: 7.35s
2:	learn: 8.0007758	total: 69.9ms	remaining: 6.15s
3:	learn: 7.6744642	total: 86ms	remaining: 5.66s
4:	learn: 7.3614812	total: 105ms	remaining: 5.49s
5:	learn: 7.0569893	total: 123ms	remaining: 5.33s
6:	learn: 6.7990085	total: 148ms	remaining: 5.51s
7:	learn: 6.5788726	total: 167ms	remaining: 5.41s
8:	learn: 6.4951168	total: 170ms	remaining: 4.88s
9:	learn: 6.2987118	total: 202ms	remaining: 5.2s
10:	learn: 6.2004883	total: 224ms	remaining: 5.21s
11:	learn: 5.9846349	total: 324ms	remaining: 6.88s
12:	learn: 5.7938267	total: 347ms	remaining: 6.78s
13:	learn: 5.6364654	total: 368ms	rema

[32m[I 2021-05-16 02:39:32,341][0m Trial 2 finished with value: -2.245669442966687 and parameters: {'iterations': 267, 'depth': 10, 'learning_rate': 0.09610152280175836, 'random_strength': 74, 'bagging_temperature': 4.94289559514254, 'od_type': 'IncToDec', 'od_wait': 13}. Best is trial 2 with value: -2.245669442966687.[0m


0:	learn: 8.7680822	total: 2.48ms	remaining: 360ms
1:	learn: 8.7141998	total: 8.82ms	remaining: 635ms
2:	learn: 8.6672132	total: 22.5ms	remaining: 1.07s
3:	learn: 8.6047023	total: 39.2ms	remaining: 1.39s
4:	learn: 8.5542573	total: 45ms	remaining: 1.27s
5:	learn: 8.5093981	total: 50.9ms	remaining: 1.19s
6:	learn: 8.4726841	total: 54.5ms	remaining: 1.08s
7:	learn: 8.4152666	total: 59.6ms	remaining: 1.03s
8:	learn: 8.3669693	total: 70.1ms	remaining: 1.07s
9:	learn: 8.3100259	total: 75.2ms	remaining: 1.02s
10:	learn: 8.2690674	total: 95ms	remaining: 1.17s
11:	learn: 8.2163402	total: 110ms	remaining: 1.23s
12:	learn: 8.1624917	total: 118ms	remaining: 1.21s
13:	learn: 8.1132279	total: 124ms	remaining: 1.17s
14:	learn: 8.0708190	total: 132ms	remaining: 1.15s
15:	learn: 8.0259924	total: 142ms	remaining: 1.15s
16:	learn: 7.9746770	total: 149ms	remaining: 1.13s
17:	learn: 7.9216625	total: 157ms	remaining: 1.12s
18:	learn: 7.8782045	total: 169ms	remaining: 1.13s
19:	learn: 7.8271966	total: 178ms	

[32m[I 2021-05-16 02:39:47,726][0m Trial 3 finished with value: -3.3335653008921495 and parameters: {'iterations': 146, 'depth': 7, 'learning_rate': 0.012388496300759307, 'random_strength': 4, 'bagging_temperature': 40.67524742763052, 'od_type': 'Iter', 'od_wait': 20}. Best is trial 2 with value: -2.245669442966687.[0m


0:	learn: 8.7216749	total: 5.27ms	remaining: 421ms
1:	learn: 8.6542753	total: 7.11ms	remaining: 281ms
2:	learn: 8.5423390	total: 9.44ms	remaining: 245ms
3:	learn: 8.4957720	total: 14.3ms	remaining: 274ms
4:	learn: 8.3991517	total: 15.9ms	remaining: 241ms
5:	learn: 8.3352620	total: 19.2ms	remaining: 241ms
6:	learn: 8.2655921	total: 21ms	remaining: 222ms
7:	learn: 8.1583102	total: 29ms	remaining: 264ms
8:	learn: 8.1055724	total: 38.4ms	remaining: 307ms
9:	learn: 8.0142905	total: 44.3ms	remaining: 315ms
10:	learn: 7.9338212	total: 46.1ms	remaining: 294ms
11:	learn: 7.8817405	total: 48.4ms	remaining: 278ms
12:	learn: 7.8270358	total: 51.3ms	remaining: 268ms
13:	learn: 7.7536154	total: 52.7ms	remaining: 252ms
14:	learn: 7.6791007	total: 54.9ms	remaining: 242ms
15:	learn: 7.5898259	total: 57.1ms	remaining: 232ms
16:	learn: 7.5082679	total: 59ms	remaining: 222ms
17:	learn: 7.4447727	total: 66.7ms	remaining: 233ms
18:	learn: 7.3887324	total: 81.2ms	remaining: 265ms
19:	learn: 7.2979069	total: 

[32m[I 2021-05-16 02:39:54,912][0m Trial 4 finished with value: -3.5662443684053926 and parameters: {'iterations': 81, 'depth': 6, 'learning_rate': 0.023785728500298968, 'random_strength': 15, 'bagging_temperature': 0.15212084079957228, 'od_type': 'IncToDec', 'od_wait': 23}. Best is trial 2 with value: -2.245669442966687.[0m


70:	learn: 5.0120351	total: 364ms	remaining: 51.2ms
71:	learn: 4.9849211	total: 369ms	remaining: 46.1ms
72:	learn: 4.9620133	total: 383ms	remaining: 42ms
73:	learn: 4.9271775	total: 391ms	remaining: 37ms
74:	learn: 4.8966613	total: 395ms	remaining: 31.6ms
75:	learn: 4.8798560	total: 407ms	remaining: 26.8ms
76:	learn: 4.8614183	total: 412ms	remaining: 21.4ms
77:	learn: 4.8240417	total: 416ms	remaining: 16ms
78:	learn: 4.7952541	total: 424ms	remaining: 10.7ms
79:	learn: 4.7780640	total: 431ms	remaining: 5.39ms
80:	learn: 4.7419924	total: 437ms	remaining: 0us
0:	learn: 8.4213876	total: 809us	remaining: 186ms
1:	learn: 8.0374310	total: 15ms	remaining: 1.72s
2:	learn: 7.6323791	total: 34.7ms	remaining: 2.63s
3:	learn: 7.3246944	total: 56.9ms	remaining: 3.23s
4:	learn: 7.0520710	total: 74.7ms	remaining: 3.38s
5:	learn: 6.7508617	total: 93.4ms	remaining: 3.5s
6:	learn: 6.5135918	total: 111ms	remaining: 3.56s
7:	learn: 6.2577844	total: 130ms	remaining: 3.63s
8:	learn: 6.0161487	total: 158ms	re

[32m[I 2021-05-16 02:40:37,798][0m Trial 5 finished with value: -2.1252732899928377 and parameters: {'iterations': 231, 'depth': 9, 'learning_rate': 0.09348431522145284, 'random_strength': 3, 'bagging_temperature': 48.16021746453347, 'od_type': 'IncToDec', 'od_wait': 15}. Best is trial 5 with value: -2.1252732899928377.[0m


0:	learn: 8.2951121	total: 1.17ms	remaining: 280ms
1:	learn: 7.9909774	total: 2.73ms	remaining: 327ms
2:	learn: 7.7998408	total: 4.76ms	remaining: 377ms
3:	learn: 7.5476663	total: 7.81ms	remaining: 463ms
4:	learn: 7.2043242	total: 9.7ms	remaining: 458ms
5:	learn: 7.0153139	total: 11.8ms	remaining: 463ms
6:	learn: 6.7779765	total: 15.7ms	remaining: 525ms
7:	learn: 6.5968249	total: 17.1ms	remaining: 498ms
8:	learn: 6.3791875	total: 19.7ms	remaining: 508ms
9:	learn: 6.1338258	total: 21.6ms	remaining: 499ms
10:	learn: 5.9329388	total: 22.6ms	remaining: 473ms
11:	learn: 5.8541967	total: 24.6ms	remaining: 469ms
12:	learn: 5.7163018	total: 27ms	remaining: 473ms
13:	learn: 5.5592132	total: 28.1ms	remaining: 456ms
14:	learn: 5.4150116	total: 29.5ms	remaining: 445ms
15:	learn: 5.2772929	total: 30.6ms	remaining: 430ms
16:	learn: 5.1014759	total: 33.2ms	remaining: 437ms
17:	learn: 5.0303440	total: 35.4ms	remaining: 438ms
18:	learn: 4.8904423	total: 36.8ms	remaining: 429ms
19:	learn: 4.7666559	tota

[32m[I 2021-05-16 02:40:57,368][0m Trial 6 finished with value: -2.0923687315399833 and parameters: {'iterations': 241, 'depth': 5, 'learning_rate': 0.126920876848543, 'random_strength': 14, 'bagging_temperature': 2.9269391977263366, 'od_type': 'IncToDec', 'od_wait': 34}. Best is trial 6 with value: -2.0923687315399833.[0m


0:	learn: 8.2143467	total: 1.49ms	remaining: 292ms
1:	learn: 7.6612893	total: 6.29ms	remaining: 614ms
2:	learn: 7.2159847	total: 11.9ms	remaining: 768ms
3:	learn: 6.6741513	total: 21.7ms	remaining: 1.05s
4:	learn: 6.3141407	total: 34ms	remaining: 1.3s
5:	learn: 6.1129814	total: 40.7ms	remaining: 1.29s
6:	learn: 5.9080322	total: 42.4ms	remaining: 1.15s
7:	learn: 5.6082395	total: 46.6ms	remaining: 1.1s
8:	learn: 5.3464211	total: 52ms	remaining: 1.09s
9:	learn: 5.0695627	total: 57.7ms	remaining: 1.08s
10:	learn: 4.9035075	total: 64.4ms	remaining: 1.09s
11:	learn: 4.6673724	total: 74.1ms	remaining: 1.14s
12:	learn: 4.4134125	total: 82.3ms	remaining: 1.16s
13:	learn: 4.2591311	total: 89.5ms	remaining: 1.17s
14:	learn: 4.1438890	total: 100ms	remaining: 1.22s
15:	learn: 4.0265282	total: 109ms	remaining: 1.23s
16:	learn: 3.8739927	total: 115ms	remaining: 1.22s
17:	learn: 3.7942612	total: 120ms	remaining: 1.2s
18:	learn: 3.6823631	total: 129ms	remaining: 1.21s
19:	learn: 3.5894838	total: 137ms	

[32m[I 2021-05-16 02:41:19,747][0m Trial 7 finished with value: -2.0189016584035113 and parameters: {'iterations': 197, 'depth': 7, 'learning_rate': 0.14357109323108547, 'random_strength': 4, 'bagging_temperature': 0.6905090911097129, 'od_type': 'Iter', 'od_wait': 11}. Best is trial 7 with value: -2.0189016584035113.[0m


0:	learn: 8.7813414	total: 36.5ms	remaining: 4.53s
1:	learn: 8.7429728	total: 64.7ms	remaining: 3.98s
2:	learn: 8.7050383	total: 92.6ms	remaining: 3.77s
3:	learn: 8.6640818	total: 132ms	remaining: 3.99s
4:	learn: 8.6211794	total: 170ms	remaining: 4.07s
5:	learn: 8.5762531	total: 209ms	remaining: 4.15s
6:	learn: 8.5379681	total: 239ms	remaining: 4.03s
7:	learn: 8.5010824	total: 263ms	remaining: 3.85s
8:	learn: 8.4669946	total: 266ms	remaining: 3.43s
9:	learn: 8.4275363	total: 295ms	remaining: 3.39s
10:	learn: 8.4025681	total: 321ms	remaining: 3.33s
11:	learn: 8.3647459	total: 345ms	remaining: 3.25s
12:	learn: 8.3271797	total: 413ms	remaining: 3.56s
13:	learn: 8.2928593	total: 440ms	remaining: 3.49s
14:	learn: 8.2551286	total: 469ms	remaining: 3.44s
15:	learn: 8.2184594	total: 500ms	remaining: 3.4s
16:	learn: 8.1798361	total: 507ms	remaining: 3.22s
17:	learn: 8.1427158	total: 539ms	remaining: 3.2s
18:	learn: 8.1083495	total: 581ms	remaining: 3.24s
19:	learn: 8.0811307	total: 618ms	remain

[32m[I 2021-05-16 02:41:48,829][0m Trial 8 finished with value: -4.018428514351837 and parameters: {'iterations': 125, 'depth': 10, 'learning_rate': 0.010346308789977146, 'random_strength': 12, 'bagging_temperature': 0.28280439887778097, 'od_type': 'Iter', 'od_wait': 28}. Best is trial 7 with value: -2.0189016584035113.[0m


123:	learn: 5.2713859	total: 3.86s	remaining: 31.1ms
124:	learn: 5.2499642	total: 3.88s	remaining: 0us
0:	learn: 8.4546573	total: 51ms	remaining: 12.2s
1:	learn: 8.1377439	total: 77.4ms	remaining: 9.25s
2:	learn: 7.8323843	total: 110ms	remaining: 8.76s
3:	learn: 7.4467400	total: 137ms	remaining: 8.09s
4:	learn: 7.0911925	total: 164ms	remaining: 7.75s
5:	learn: 6.7444514	total: 193ms	remaining: 7.54s
6:	learn: 6.4514530	total: 237ms	remaining: 7.91s
7:	learn: 6.2113150	total: 268ms	remaining: 7.8s
8:	learn: 6.1244963	total: 272ms	remaining: 7.02s
9:	learn: 5.9124422	total: 300ms	remaining: 6.93s
10:	learn: 5.8128987	total: 333ms	remaining: 6.96s
11:	learn: 5.5787758	total: 359ms	remaining: 6.84s
12:	learn: 5.3767372	total: 421ms	remaining: 7.39s
13:	learn: 5.2138509	total: 450ms	remaining: 7.29s
14:	learn: 5.0969180	total: 483ms	remaining: 7.28s
15:	learn: 4.9592363	total: 518ms	remaining: 7.29s
16:	learn: 4.7940707	total: 522ms	remaining: 6.88s
17:	learn: 4.6586788	total: 549ms	remaini

[32m[I 2021-05-16 02:42:49,530][0m Trial 9 finished with value: -2.2175817925697023 and parameters: {'iterations': 241, 'depth': 10, 'learning_rate': 0.11828933853491311, 'random_strength': 99, 'bagging_temperature': 57.54687774779825, 'od_type': 'Iter', 'od_wait': 45}. Best is trial 7 with value: -2.0189016584035113.[0m


239:	learn: 0.2368461	total: 9.18s	remaining: 38.3ms
240:	learn: 0.2345006	total: 9.23s	remaining: 0us
CPU times: user 4min 15s, sys: 23 s, total: 4min 38s
Wall time: 4min 56s


In [6]:
# チューニングしたハイパーパラメーターをフィット
optimised_model = cb.CatBoostRegressor(iterations = study.best_params['iterations'],
                                       depth = study.best_params['depth'],
                                       learning_rate = study.best_params['learning_rate'],
                                       random_strength = study.best_params['random_strength'],
                                       bagging_temperature = study.best_params['bagging_temperature'],
                                       od_type = study.best_params['od_type'],
                                       od_wait = study.best_params['od_wait']
                                      )

optimised_model.fit(x_train ,y_train)

# CatBoost推論
y_pred = optimised_model.predict(x_test)

0:	learn: 8.1990532	total: 684us	remaining: 134ms
1:	learn: 7.6643618	total: 4.35ms	remaining: 425ms
2:	learn: 7.2272107	total: 11.2ms	remaining: 725ms
3:	learn: 6.6520430	total: 15.3ms	remaining: 736ms
4:	learn: 6.2663841	total: 20.9ms	remaining: 802ms
5:	learn: 6.0027620	total: 24.4ms	remaining: 776ms
6:	learn: 5.7039948	total: 32.1ms	remaining: 871ms
7:	learn: 5.3652695	total: 39.3ms	remaining: 927ms
8:	learn: 5.0776113	total: 46.6ms	remaining: 973ms
9:	learn: 4.8263310	total: 56.1ms	remaining: 1.05s
10:	learn: 4.6858115	total: 63.1ms	remaining: 1.07s
11:	learn: 4.4901096	total: 72.8ms	remaining: 1.12s
12:	learn: 4.3708827	total: 78.5ms	remaining: 1.11s
13:	learn: 4.2245661	total: 85.8ms	remaining: 1.12s
14:	learn: 4.1181626	total: 88.7ms	remaining: 1.07s
15:	learn: 4.0110823	total: 110ms	remaining: 1.24s
16:	learn: 3.8703520	total: 119ms	remaining: 1.25s
17:	learn: 3.7196390	total: 125ms	remaining: 1.24s
18:	learn: 3.6282867	total: 131ms	remaining: 1.22s
19:	learn: 3.5361535	total:

In [7]:
# 評価
def calculate_scores(true, pred):
    """全ての評価指標を計算する

    Parameters
    ----------
    true (np.array)       : 実測値
    pred (np.array)       : 予測値

    Returns
    -------
    scores (pd.DataFrame) : 各評価指標を纏めた結果

    """
    scores = {}
    scores = pd.DataFrame({'R2': r2_score(true, pred),
                          'MAE': mean_absolute_error(true, pred),
                          'MSE': mean_squared_error(true, pred),
                          'RMSE': np.sqrt(mean_squared_error(true, pred))},
                           index = ['scores'])
    return scores

In [8]:
scores = calculate_scores(y_test, y_pred)
print(scores)

              R2       MAE        MSE      RMSE
scores  0.865739  2.541304  14.041161  3.747154
