## CatBoost-Optunaのサンプルコード

In [1]:
%load_ext lab_black

In [2]:
# ライブラリーのインポート
import os

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

%matplotlib inline

# ボストンの住宅価格データ
from sklearn.datasets import load_boston

# 前処理
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# CatBoost
import catboost as cb
from catboost import CatBoost, Pool

# Optuna
import optuna
from sklearn.model_selection import cross_val_score

# 評価指標
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

In [3]:
# データセットの読込み
boston = load_boston()

# 説明変数の格納
df = pd.DataFrame(boston.data, columns=boston.feature_names)
# 目的変数の追加
df["MEDV"] = boston.target

# データの中身を確認
df.head()

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,MEDV
0,0.00632,18.0,2.31,0.0,0.538,6.575,65.2,4.09,1.0,296.0,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0.0,0.469,6.421,78.9,4.9671,2.0,242.0,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0.0,0.469,7.185,61.1,4.9671,2.0,242.0,17.8,392.83,4.03,34.7
3,0.03237,0.0,2.18,0.0,0.458,6.998,45.8,6.0622,3.0,222.0,18.7,394.63,2.94,33.4
4,0.06905,0.0,2.18,0.0,0.458,7.147,54.2,6.0622,3.0,222.0,18.7,396.9,5.33,36.2


#### 前処理

In [4]:
# ランダムシード値
RANDOM_STATE = 10

# 学習データと評価データの割合
TEST_SIZE = 0.2

# 学習データと評価データを作成
x_train, x_test, y_train, y_test = train_test_split(
    df.iloc[:, 0 : df.shape[1] - 1],
    df.iloc[:, df.shape[1] - 1],
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
)

In [5]:
def objective(trial):

    iterations = trial.suggest_int("iterations", 50, 300)
    depth = trial.suggest_int("depth", 4, 10)
    learning_rate = trial.suggest_loguniform("learning_rate", 0.01, 0.3)
    random_strength = trial.suggest_int("random_strength", 0, 100)
    bagging_temperature = trial.suggest_loguniform("bagging_temperature", 0.01, 100.00)
    od_type = trial.suggest_categorical("od_type", ["IncToDec", "Iter"])
    od_wait = trial.suggest_int("od_wait", 10, 50)

    model = cb.CatBoostRegressor(
        iterations=iterations,
        depth=depth,
        learning_rate=learning_rate,
        random_strength=random_strength,
        bagging_temperature=bagging_temperature,
        od_type=od_type,
        od_wait=od_wait,
    )

    score = cross_val_score(
        model, x_train, y_train, cv=5, scoring="neg_mean_absolute_error"
    )
    mae = score.mean()

    return mae

In [6]:
%%time
# optunaで最適値を見つける
# 注：cross_val_scoreの出力は全て高いほど良い
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=10)

[32m[I 2021-05-29 00:59:12,494][0m A new study created in memory with name: no-name-e3bbc88f-6a33-4bff-b308-8de46a051b62[0m


0:	learn: 8.7693502	total: 73.7ms	remaining: 20.8s
1:	learn: 8.7281049	total: 98.4ms	remaining: 13.9s
2:	learn: 8.6771608	total: 131ms	remaining: 12.3s
3:	learn: 8.6215166	total: 152ms	remaining: 10.7s
4:	learn: 8.5557783	total: 184ms	remaining: 10.3s
5:	learn: 8.4951588	total: 209ms	remaining: 9.66s
6:	learn: 8.4695853	total: 215ms	remaining: 8.49s
7:	learn: 8.4129352	total: 258ms	remaining: 8.91s
8:	learn: 8.3418025	total: 294ms	remaining: 8.98s
9:	learn: 8.3015648	total: 329ms	remaining: 9.02s
10:	learn: 8.2403188	total: 431ms	remaining: 10.7s
11:	learn: 8.1732789	total: 481ms	remaining: 10.9s
12:	learn: 8.1136806	total: 590ms	remaining: 12.3s
13:	learn: 8.0985542	total: 745ms	remaining: 14.4s
14:	learn: 8.0436243	total: 801ms	remaining: 14.4s
15:	learn: 7.9820556	total: 1.22s	remaining: 20.5s
16:	learn: 7.9323706	total: 1.3s	remaining: 20.5s
17:	learn: 7.8970682	total: 1.51s	remaining: 22.3s
18:	learn: 7.8468525	total: 1.58s	remaining: 22s
19:	learn: 7.7941979	total: 1.63s	remainin

[32m[I 2021-05-29 01:00:50,612][0m Trial 0 finished with value: -2.8375805709673605 and parameters: {'iterations': 284, 'depth': 9, 'learning_rate': 0.0171817917818281, 'random_strength': 80, 'bagging_temperature': 90.47803777549244, 'od_type': 'Iter', 'od_wait': 12}. Best is trial 0 with value: -2.8375805709673605.[0m


283:	learn: 2.9364607	total: 16s	remaining: 0us
0:	learn: 8.1447236	total: 32.5ms	remaining: 2.01s
1:	learn: 7.6785606	total: 55ms	remaining: 1.68s
2:	learn: 7.0491642	total: 77.1ms	remaining: 1.54s
3:	learn: 6.5189991	total: 94.6ms	remaining: 1.4s
4:	learn: 6.2668127	total: 113ms	remaining: 1.31s
5:	learn: 5.8303194	total: 134ms	remaining: 1.27s
6:	learn: 5.5180394	total: 154ms	remaining: 1.23s
7:	learn: 5.3363849	total: 178ms	remaining: 1.23s
8:	learn: 5.1228054	total: 199ms	remaining: 1.2s
9:	learn: 4.8532796	total: 215ms	remaining: 1.14s
10:	learn: 4.7974527	total: 226ms	remaining: 1.07s
11:	learn: 4.5943713	total: 252ms	remaining: 1.07s
12:	learn: 4.4776939	total: 263ms	remaining: 1.01s
13:	learn: 4.2986247	total: 278ms	remaining: 974ms
14:	learn: 4.2403647	total: 284ms	remaining: 910ms
15:	learn: 4.2067028	total: 302ms	remaining: 888ms
16:	learn: 4.0960455	total: 323ms	remaining: 873ms
17:	learn: 4.0937539	total: 325ms	remaining: 812ms
18:	learn: 3.9245705	total: 351ms	remaining:

[32m[I 2021-05-29 01:01:03,700][0m Trial 1 finished with value: -2.4258870671272805 and parameters: {'iterations': 63, 'depth': 8, 'learning_rate': 0.24647014178653762, 'random_strength': 77, 'bagging_temperature': 2.0273199913212046, 'od_type': 'IncToDec', 'od_wait': 29}. Best is trial 1 with value: -2.4258870671272805.[0m


61:	learn: 1.1573978	total: 3.21s	remaining: 51.7ms
62:	learn: 1.1480058	total: 3.25s	remaining: 0us
0:	learn: 7.9530901	total: 5.81ms	remaining: 1.63s
1:	learn: 7.4984528	total: 8.86ms	remaining: 1.24s
2:	learn: 7.2306484	total: 11.1ms	remaining: 1.03s
3:	learn: 6.9050198	total: 14.6ms	remaining: 1.01s
4:	learn: 6.6155809	total: 17ms	remaining: 942ms
5:	learn: 6.4059049	total: 25.1ms	remaining: 1.15s
6:	learn: 6.1387199	total: 27.3ms	remaining: 1.07s
7:	learn: 5.9581101	total: 29.3ms	remaining: 1s
8:	learn: 5.7186292	total: 34ms	remaining: 1.03s
9:	learn: 5.4528466	total: 53.9ms	remaining: 1.47s
10:	learn: 5.2555872	total: 68.3ms	remaining: 1.68s
11:	learn: 5.1620246	total: 73.3ms	remaining: 1.65s
12:	learn: 5.0046443	total: 77.1ms	remaining: 1.6s
13:	learn: 4.8661088	total: 82.8ms	remaining: 1.58s
14:	learn: 4.7351335	total: 88.4ms	remaining: 1.57s
15:	learn: 4.6130708	total: 94.7ms	remaining: 1.57s
16:	learn: 4.4106044	total: 104ms	remaining: 1.61s
17:	learn: 4.3645371	total: 107ms	

[32m[I 2021-05-29 01:01:26,825][0m Trial 2 finished with value: -2.1352924351823033 and parameters: {'iterations': 282, 'depth': 5, 'learning_rate': 0.21315734619560464, 'random_strength': 25, 'bagging_temperature': 3.425218006435679, 'od_type': 'IncToDec', 'od_wait': 23}. Best is trial 2 with value: -2.1352924351823033.[0m


275:	learn: 0.3048700	total: 1.3s	remaining: 28.2ms
276:	learn: 0.3035265	total: 1.3s	remaining: 23.5ms
277:	learn: 0.3031691	total: 1.3s	remaining: 18.8ms
278:	learn: 0.2996760	total: 1.31s	remaining: 14.1ms
279:	learn: 0.2989792	total: 1.31s	remaining: 9.35ms
280:	learn: 0.2968023	total: 1.31s	remaining: 4.67ms
281:	learn: 0.2955450	total: 1.31s	remaining: 0us
0:	learn: 7.7234882	total: 88.4ms	remaining: 8.58s
1:	learn: 7.0808570	total: 130ms	remaining: 6.23s
2:	learn: 6.5094246	total: 181ms	remaining: 5.73s
3:	learn: 5.8368957	total: 221ms	remaining: 5.2s
4:	learn: 5.3322381	total: 263ms	remaining: 4.89s
5:	learn: 4.8932797	total: 301ms	remaining: 4.61s
6:	learn: 4.5218797	total: 359ms	remaining: 4.67s
7:	learn: 4.2760506	total: 413ms	remaining: 4.65s
8:	learn: 4.2143700	total: 417ms	remaining: 4.12s
9:	learn: 4.0421559	total: 477ms	remaining: 4.2s
10:	learn: 3.9823961	total: 536ms	remaining: 4.24s
11:	learn: 3.7473770	total: 576ms	remaining: 4.13s
12:	learn: 3.5831520	total: 621ms	

[32m[I 2021-05-29 01:02:29,223][0m Trial 3 finished with value: -2.327691201612903 and parameters: {'iterations': 98, 'depth': 10, 'learning_rate': 0.2942383352490169, 'random_strength': 45, 'bagging_temperature': 0.2830167831445327, 'od_type': 'Iter', 'od_wait': 26}. Best is trial 2 with value: -2.1352924351823033.[0m


0:	learn: 8.7715133	total: 58.6ms	remaining: 4.69s
1:	learn: 8.7319348	total: 85.7ms	remaining: 3.39s
2:	learn: 8.6830270	total: 108ms	remaining: 2.81s
3:	learn: 8.6295948	total: 145ms	remaining: 2.8s
4:	learn: 8.5664338	total: 197ms	remaining: 3s
5:	learn: 8.5081743	total: 239ms	remaining: 2.99s
6:	learn: 8.4835406	total: 250ms	remaining: 2.65s
7:	learn: 8.4335194	total: 300ms	remaining: 2.74s
8:	learn: 8.3650644	total: 353ms	remaining: 2.82s
9:	learn: 8.3262998	total: 387ms	remaining: 2.75s
10:	learn: 8.2672749	total: 422ms	remaining: 2.69s
11:	learn: 8.2026100	total: 461ms	remaining: 2.65s
12:	learn: 8.1450953	total: 486ms	remaining: 2.54s
13:	learn: 8.1304549	total: 488ms	remaining: 2.33s
14:	learn: 8.0773830	total: 524ms	remaining: 2.31s
15:	learn: 8.0179495	total: 563ms	remaining: 2.29s
16:	learn: 7.9698355	total: 606ms	remaining: 2.28s
17:	learn: 7.9356125	total: 643ms	remaining: 2.25s
18:	learn: 7.8870027	total: 679ms	remaining: 2.22s
19:	learn: 7.8361032	total: 714ms	remaining

[32m[I 2021-05-29 01:02:49,893][0m Trial 4 finished with value: -4.191822733453054 and parameters: {'iterations': 81, 'depth': 9, 'learning_rate': 0.016475528368573977, 'random_strength': 86, 'bagging_temperature': 0.5705564608809537, 'od_type': 'Iter', 'od_wait': 16}. Best is trial 2 with value: -2.1352924351823033.[0m


80:	learn: 5.4683756	total: 3.31s	remaining: 0us
0:	learn: 8.2447472	total: 7.82ms	remaining: 680ms
1:	learn: 7.9030555	total: 15.8ms	remaining: 681ms
2:	learn: 7.3836816	total: 19.8ms	remaining: 562ms
3:	learn: 7.2035537	total: 23.9ms	remaining: 502ms
4:	learn: 6.9397521	total: 27.2ms	remaining: 452ms
5:	learn: 6.7260758	total: 32.4ms	remaining: 443ms
6:	learn: 6.4984644	total: 37.3ms	remaining: 431ms
7:	learn: 6.0919334	total: 46.6ms	remaining: 466ms
8:	learn: 5.9703118	total: 61.4ms	remaining: 539ms
9:	learn: 5.6865558	total: 69.6ms	remaining: 543ms
10:	learn: 5.4440873	total: 79.3ms	remaining: 555ms
11:	learn: 5.3336445	total: 89.8ms	remaining: 569ms
12:	learn: 5.2076771	total: 107ms	remaining: 620ms
13:	learn: 5.0696482	total: 123ms	remaining: 653ms
14:	learn: 4.9167014	total: 133ms	remaining: 648ms
15:	learn: 4.7913383	total: 147ms	remaining: 663ms
16:	learn: 4.6192347	total: 154ms	remaining: 644ms
17:	learn: 4.5185647	total: 164ms	remaining: 639ms
18:	learn: 4.4102914	total: 170

[32m[I 2021-05-29 01:02:58,116][0m Trial 5 finished with value: -2.3028413612479244 and parameters: {'iterations': 88, 'depth': 6, 'learning_rate': 0.13967548539848726, 'random_strength': 40, 'bagging_temperature': 0.053936141029853044, 'od_type': 'Iter', 'od_wait': 46}. Best is trial 2 with value: -2.1352924351823033.[0m


85:	learn: 1.6038089	total: 1.12s	remaining: 26.1ms
86:	learn: 1.5941681	total: 1.13s	remaining: 13ms
87:	learn: 1.5790108	total: 1.14s	remaining: 0us
0:	learn: 8.7616870	total: 11.7ms	remaining: 1.87s
1:	learn: 8.7224233	total: 26.9ms	remaining: 2.12s
2:	learn: 8.6954505	total: 42.2ms	remaining: 2.21s
3:	learn: 8.6560889	total: 45.3ms	remaining: 1.77s
4:	learn: 8.6151158	total: 49ms	remaining: 1.52s
5:	learn: 8.5797550	total: 52.4ms	remaining: 1.34s
6:	learn: 8.5370199	total: 56.6ms	remaining: 1.24s
7:	learn: 8.4962440	total: 64.5ms	remaining: 1.23s
8:	learn: 8.4465080	total: 69.4ms	remaining: 1.16s
9:	learn: 8.3889112	total: 76.7ms	remaining: 1.15s
10:	learn: 8.3380463	total: 79.2ms	remaining: 1.07s
11:	learn: 8.3154815	total: 86.6ms	remaining: 1.07s
12:	learn: 8.2761317	total: 92.9ms	remaining: 1.05s
13:	learn: 8.2274683	total: 98.3ms	remaining: 1.02s
14:	learn: 8.1898868	total: 104ms	remaining: 1s
15:	learn: 8.1413801	total: 112ms	remaining: 1.01s
16:	learn: 8.1164692	total: 118ms	

[32m[I 2021-05-29 01:03:17,953][0m Trial 6 finished with value: -3.6082454006746483 and parameters: {'iterations': 160, 'depth': 5, 'learning_rate': 0.014265242739615158, 'random_strength': 43, 'bagging_temperature': 0.38640015146213635, 'od_type': 'Iter', 'od_wait': 38}. Best is trial 2 with value: -2.1352924351823033.[0m


0:	learn: 8.6349540	total: 2.15ms	remaining: 246ms
1:	learn: 8.4811674	total: 4.67ms	remaining: 264ms
2:	learn: 8.3894058	total: 7.36ms	remaining: 275ms
3:	learn: 8.2561363	total: 10.2ms	remaining: 282ms
4:	learn: 8.2444442	total: 13.3ms	remaining: 292ms
5:	learn: 8.0672622	total: 16.7ms	remaining: 304ms
6:	learn: 7.9210725	total: 21.1ms	remaining: 325ms
7:	learn: 7.7793812	total: 25.4ms	remaining: 340ms
8:	learn: 7.7108547	total: 29.5ms	remaining: 347ms
9:	learn: 7.5625925	total: 31.6ms	remaining: 331ms
10:	learn: 7.4119839	total: 37.3ms	remaining: 353ms
11:	learn: 7.3653587	total: 43.1ms	remaining: 370ms
12:	learn: 7.2508966	total: 46.8ms	remaining: 368ms
13:	learn: 7.1492705	total: 50ms	remaining: 361ms
14:	learn: 7.1089759	total: 52.1ms	remaining: 348ms
15:	learn: 6.9707495	total: 53.1ms	remaining: 329ms
16:	learn: 6.9119286	total: 54.5ms	remaining: 314ms
17:	learn: 6.8694146	total: 57.2ms	remaining: 308ms
18:	learn: 6.8348747	total: 60.6ms	remaining: 306ms
19:	learn: 6.7683402	tot

[32m[I 2021-05-29 01:03:26,976][0m Trial 7 finished with value: -2.913945736725042 and parameters: {'iterations': 115, 'depth': 4, 'learning_rate': 0.042383214291235094, 'random_strength': 34, 'bagging_temperature': 37.087809095485, 'od_type': 'IncToDec', 'od_wait': 37}. Best is trial 2 with value: -2.1352924351823033.[0m


110:	learn: 3.9480760	total: 481ms	remaining: 17.3ms
111:	learn: 3.9453364	total: 489ms	remaining: 13.1ms
112:	learn: 3.9085892	total: 490ms	remaining: 8.68ms
113:	learn: 3.8999265	total: 494ms	remaining: 4.33ms
114:	learn: 3.8861384	total: 496ms	remaining: 0us
0:	learn: 8.5600131	total: 31.2ms	remaining: 9.29s
1:	learn: 8.3245074	total: 70.4ms	remaining: 10.4s
2:	learn: 8.1189789	total: 101ms	remaining: 9.98s
3:	learn: 7.9249924	total: 124ms	remaining: 9.15s
4:	learn: 7.7071365	total: 154ms	remaining: 9.04s
5:	learn: 7.4607625	total: 179ms	remaining: 8.73s
6:	learn: 7.2855120	total: 188ms	remaining: 7.84s
7:	learn: 7.1242839	total: 247ms	remaining: 8.98s
8:	learn: 6.9079575	total: 283ms	remaining: 9.13s
9:	learn: 6.7938836	total: 308ms	remaining: 8.91s
10:	learn: 6.5969155	total: 342ms	remaining: 8.96s
11:	learn: 6.4228388	total: 366ms	remaining: 8.76s
12:	learn: 6.2671161	total: 438ms	remaining: 9.64s
13:	learn: 6.1736777	total: 449ms	remaining: 9.14s
14:	learn: 6.0198542	total: 480m

[32m[I 2021-05-29 01:04:34,629][0m Trial 8 finished with value: -2.1836195988136273 and parameters: {'iterations': 299, 'depth': 9, 'learning_rate': 0.06728993578612004, 'random_strength': 15, 'bagging_temperature': 21.616946862185006, 'od_type': 'Iter', 'od_wait': 25}. Best is trial 2 with value: -2.1352924351823033.[0m


294:	learn: 0.4884072	total: 10.9s	remaining: 147ms
295:	learn: 0.4852303	total: 10.9s	remaining: 110ms
296:	learn: 0.4817034	total: 10.9s	remaining: 73.4ms
297:	learn: 0.4784301	total: 10.9s	remaining: 36.7ms
298:	learn: 0.4762602	total: 10.9s	remaining: 0us
0:	learn: 8.1897347	total: 13.5ms	remaining: 2.13s
1:	learn: 7.7690461	total: 43.8ms	remaining: 3.44s
2:	learn: 7.3168838	total: 69.8ms	remaining: 3.63s
3:	learn: 6.8392991	total: 128ms	remaining: 4.96s
4:	learn: 6.3632634	total: 156ms	remaining: 4.82s
5:	learn: 5.9469142	total: 184ms	remaining: 4.7s
6:	learn: 5.8642426	total: 193ms	remaining: 4.18s
7:	learn: 5.6042678	total: 229ms	remaining: 4.32s
8:	learn: 5.1779035	total: 275ms	remaining: 4.58s
9:	learn: 4.9888690	total: 294ms	remaining: 4.38s
10:	learn: 4.7250926	total: 316ms	remaining: 4.25s
11:	learn: 4.5014871	total: 336ms	remaining: 4.12s
12:	learn: 4.2834662	total: 362ms	remaining: 4.06s
13:	learn: 4.2656466	total: 365ms	remaining: 3.78s
14:	learn: 4.1272933	total: 380ms	

[32m[I 2021-05-29 01:05:09,372][0m Trial 9 finished with value: -2.386212744551508 and parameters: {'iterations': 159, 'depth': 9, 'learning_rate': 0.2181972659782335, 'random_strength': 74, 'bagging_temperature': 0.02054765333767902, 'od_type': 'IncToDec', 'od_wait': 38}. Best is trial 2 with value: -2.1352924351823033.[0m


155:	learn: 0.2021918	total: 3.83s	remaining: 73.8ms
156:	learn: 0.1986208	total: 3.87s	remaining: 49.3ms
157:	learn: 0.1934581	total: 3.89s	remaining: 24.6ms
158:	learn: 0.1882613	total: 3.91s	remaining: 0us
CPU times: user 2min 28s, sys: 25.1 s, total: 2min 53s
Wall time: 5min 56s


In [7]:
# チューニングしたハイパーパラメーターをフィット
optimised_model = cb.CatBoostRegressor(
    iterations=study.best_params["iterations"],
    depth=study.best_params["depth"],
    learning_rate=study.best_params["learning_rate"],
    random_strength=study.best_params["random_strength"],
    bagging_temperature=study.best_params["bagging_temperature"],
    od_type=study.best_params["od_type"],
    od_wait=study.best_params["od_wait"],
)

optimised_model.fit(x_train, y_train)

# CatBoost推論
y_pred = optimised_model.predict(x_test)

0:	learn: 7.8798080	total: 1.9ms	remaining: 534ms
1:	learn: 7.4827022	total: 3.58ms	remaining: 502ms
2:	learn: 7.2242465	total: 8.86ms	remaining: 824ms
3:	learn: 6.9364951	total: 10.9ms	remaining: 757ms
4:	learn: 6.6810219	total: 14.2ms	remaining: 789ms
5:	learn: 6.4374952	total: 16.1ms	remaining: 743ms
6:	learn: 6.1396703	total: 18.9ms	remaining: 742ms
7:	learn: 5.9755073	total: 20.8ms	remaining: 711ms
8:	learn: 5.6799235	total: 22.8ms	remaining: 693ms
9:	learn: 5.4349898	total: 25ms	remaining: 680ms
10:	learn: 5.2570654	total: 27.4ms	remaining: 676ms
11:	learn: 5.1677522	total: 29.1ms	remaining: 654ms
12:	learn: 5.0101363	total: 37ms	remaining: 766ms
13:	learn: 4.8927956	total: 39.3ms	remaining: 752ms
14:	learn: 4.7692774	total: 42.7ms	remaining: 760ms
15:	learn: 4.6294900	total: 44.7ms	remaining: 743ms
16:	learn: 4.4102814	total: 46.5ms	remaining: 725ms
17:	learn: 4.3582896	total: 48.4ms	remaining: 709ms
18:	learn: 4.3183086	total: 50.3ms	remaining: 696ms
19:	learn: 4.2879303	total:

In [8]:
# 評価
def calculate_scores(true, pred):
    """全ての評価指標を計算する

    Parameters
    ----------
    true (np.array)       : 実測値
    pred (np.array)       : 予測値

    Returns
    -------
    scores (pd.DataFrame) : 各評価指標を纏めた結果

    """
    scores = {}
    scores = pd.DataFrame(
        {
            "R2": r2_score(true, pred),
            "MAE": mean_absolute_error(true, pred),
            "MSE": mean_squared_error(true, pred),
            "RMSE": np.sqrt(mean_squared_error(true, pred)),
        },
        index=["scores"],
    )
    return scores

In [9]:
scores = calculate_scores(y_test, y_pred)
print(scores)

              R2       MAE        MSE      RMSE
scores  0.879226  2.486826  12.630648  3.553962
