## CatBoost-Optunaのサンプルコード

In [1]:
%load_ext lab_black

In [2]:
# ライブラリーのインポート
import os

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

%matplotlib inline

# ボストンの住宅価格データ
from sklearn.datasets import load_boston

# 前処理
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# CatBoost
import catboost as cb
from catboost import CatBoost, Pool

# Optuna
import optuna
from sklearn.model_selection import cross_val_score

# 評価指標
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

In [3]:
# データセットの読込み
boston = load_boston()

# 説明変数の格納
df = pd.DataFrame(boston.data, columns=boston.feature_names)
# 目的変数の追加
df["MEDV"] = boston.target

# データの中身を確認
df.head()

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,MEDV
0,0.00632,18.0,2.31,0.0,0.538,6.575,65.2,4.09,1.0,296.0,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0.0,0.469,6.421,78.9,4.9671,2.0,242.0,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0.0,0.469,7.185,61.1,4.9671,2.0,242.0,17.8,392.83,4.03,34.7
3,0.03237,0.0,2.18,0.0,0.458,6.998,45.8,6.0622,3.0,222.0,18.7,394.63,2.94,33.4
4,0.06905,0.0,2.18,0.0,0.458,7.147,54.2,6.0622,3.0,222.0,18.7,396.9,5.33,36.2


#### 前処理

In [4]:
# ランダムシード値
RANDOM_STATE = 10

# 学習データと評価データの割合
TEST_SIZE = 0.2

# 学習データと評価データを作成
x_train, x_test, y_train, y_test = train_test_split(
    df.iloc[:, 0 : df.shape[1] - 1],
    df.iloc[:, df.shape[1] - 1],
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
)

In [5]:
def objective(trial):

    iterations = trial.suggest_int("iterations", 50, 300)
    depth = trial.suggest_int("depth", 4, 10)
    learning_rate = trial.suggest_loguniform("learning_rate", 0.01, 0.3)
    random_strength = trial.suggest_int("random_strength", 0, 100)
    bagging_temperature = trial.suggest_loguniform("bagging_temperature", 0.01, 100.00)
    od_type = trial.suggest_categorical("od_type", ["IncToDec", "Iter"])
    od_wait = trial.suggest_int("od_wait", 10, 50)

    model = cb.CatBoostRegressor(
        iterations=iterations,
        depth=depth,
        learning_rate=learning_rate,
        random_strength=random_strength,
        bagging_temperature=bagging_temperature,
        od_type=od_type,
        od_wait=od_wait,
    )

    score = cross_val_score(
        model, x_train, y_train, cv=5, scoring="neg_mean_absolute_error"
    )
    mae = score.mean()

    return mae

In [6]:
%%time
# optunaで最適値を見つける
# 注：cross_val_scoreの出力は全て高いほど良い
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=10)

[32m[I 2021-05-28 15:00:47,195][0m A new study created in memory with name: no-name-31db246b-4693-49a2-be1c-b5967e7d7855[0m


0:	learn: 8.7475219	total: 68.8ms	remaining: 3.58s
1:	learn: 8.6971378	total: 73.2ms	remaining: 1.86s
2:	learn: 8.6205823	total: 79.9ms	remaining: 1.33s
3:	learn: 8.5852829	total: 89ms	remaining: 1.09s
4:	learn: 8.5315819	total: 95.5ms	remaining: 917ms
5:	learn: 8.4826872	total: 97ms	remaining: 760ms
6:	learn: 8.4344941	total: 99.3ms	remaining: 653ms
7:	learn: 8.3523818	total: 101ms	remaining: 571ms
8:	learn: 8.3111668	total: 103ms	remaining: 504ms
9:	learn: 8.2391643	total: 105ms	remaining: 451ms
10:	learn: 8.1765713	total: 106ms	remaining: 406ms
11:	learn: 8.1355926	total: 108ms	remaining: 369ms
12:	learn: 8.0927283	total: 110ms	remaining: 339ms
13:	learn: 8.0365863	total: 115ms	remaining: 319ms
14:	learn: 7.9774973	total: 117ms	remaining: 296ms
15:	learn: 7.9073579	total: 134ms	remaining: 311ms
16:	learn: 7.8422313	total: 186ms	remaining: 393ms
17:	learn: 7.7918533	total: 197ms	remaining: 383ms
18:	learn: 7.7472088	total: 206ms	remaining: 369ms
19:	learn: 7.6770582	total: 218ms	rema

[32m[I 2021-05-28 15:00:52,234][0m Trial 0 finished with value: -4.578114446355497 and parameters: {'iterations': 53, 'depth': 6, 'learning_rate': 0.017651158009228764, 'random_strength': 29, 'bagging_temperature': 0.5283333075503086, 'od_type': 'Iter', 'od_wait': 11}. Best is trial 0 with value: -4.578114446355497.[0m


50:	learn: 6.3704412	total: 304ms	remaining: 11.9ms
51:	learn: 6.3504792	total: 308ms	remaining: 5.93ms
52:	learn: 6.3145912	total: 312ms	remaining: 0us
0:	learn: 8.4605037	total: 31.3ms	remaining: 4.23s
1:	learn: 8.2538625	total: 48.8ms	remaining: 3.27s
2:	learn: 8.0091408	total: 66ms	remaining: 2.93s
3:	learn: 7.7506151	total: 85.8ms	remaining: 2.83s
4:	learn: 7.4663337	total: 100ms	remaining: 2.62s
5:	learn: 7.2086322	total: 117ms	remaining: 2.54s
6:	learn: 7.1236299	total: 123ms	remaining: 2.26s
7:	learn: 6.9115896	total: 136ms	remaining: 2.17s
8:	learn: 6.6299682	total: 153ms	remaining: 2.16s
9:	learn: 6.4858113	total: 164ms	remaining: 2.06s
10:	learn: 6.2766058	total: 181ms	remaining: 2.06s
11:	learn: 6.0661135	total: 195ms	remaining: 2.02s
12:	learn: 5.8798132	total: 207ms	remaining: 1.95s
13:	learn: 5.8482282	total: 208ms	remaining: 1.81s
14:	learn: 5.7000272	total: 232ms	remaining: 1.87s
15:	learn: 5.5184467	total: 251ms	remaining: 1.88s
16:	learn: 5.4016222	total: 270ms	remai

[32m[I 2021-05-28 15:01:11,402][0m Trial 1 finished with value: -2.394129071691389 and parameters: {'iterations': 136, 'depth': 9, 'learning_rate': 0.09328944499157514, 'random_strength': 45, 'bagging_temperature': 13.376360372058201, 'od_type': 'Iter', 'od_wait': 42}. Best is trial 1 with value: -2.394129071691389.[0m


131:	learn: 1.1997150	total: 2.17s	remaining: 65.8ms
132:	learn: 1.1872048	total: 2.19s	remaining: 49.3ms
133:	learn: 1.1770550	total: 2.2s	remaining: 32.8ms
134:	learn: 1.1714245	total: 2.22s	remaining: 16.4ms
135:	learn: 1.1528452	total: 2.23s	remaining: 0us
0:	learn: 8.2087307	total: 21.9ms	remaining: 5.74s
1:	learn: 7.7982952	total: 34.8ms	remaining: 4.54s
2:	learn: 7.3548943	total: 50.5ms	remaining: 4.38s
3:	learn: 6.8866070	total: 64.3ms	remaining: 4.16s
4:	learn: 6.4170515	total: 80.2ms	remaining: 4.14s
5:	learn: 6.0060421	total: 95.3ms	remaining: 4.08s
6:	learn: 5.9221275	total: 97.6ms	remaining: 3.57s
7:	learn: 5.6631774	total: 116ms	remaining: 3.7s
8:	learn: 5.2408144	total: 155ms	remaining: 4.38s
9:	learn: 5.0523566	total: 174ms	remaining: 4.39s
10:	learn: 4.7884987	total: 190ms	remaining: 4.36s
11:	learn: 4.5628440	total: 206ms	remaining: 4.32s
12:	learn: 4.3448560	total: 220ms	remaining: 4.23s
13:	learn: 4.3265348	total: 221ms	remaining: 3.94s
14:	learn: 4.1865383	total: 2

[32m[I 2021-05-28 15:01:48,576][0m Trial 2 finished with value: -2.349479813970581 and parameters: {'iterations': 263, 'depth': 9, 'learning_rate': 0.2111595957847497, 'random_strength': 70, 'bagging_temperature': 1.9401050666809319, 'od_type': 'Iter', 'od_wait': 15}. Best is trial 2 with value: -2.349479813970581.[0m


261:	learn: 0.0337299	total: 4.76s	remaining: 18.2ms
262:	learn: 0.0332453	total: 4.77s	remaining: 0us
0:	learn: 8.7687224	total: 36.7ms	remaining: 4.99s
1:	learn: 8.7188550	total: 68.4ms	remaining: 4.62s
2:	learn: 8.6679288	total: 90ms	remaining: 4.02s
3:	learn: 8.6022439	total: 113ms	remaining: 3.77s
4:	learn: 8.5351401	total: 134ms	remaining: 3.54s
5:	learn: 8.4648760	total: 158ms	remaining: 3.45s
6:	learn: 8.4042285	total: 182ms	remaining: 3.37s
7:	learn: 8.3471983	total: 218ms	remaining: 3.52s
8:	learn: 8.3205419	total: 221ms	remaining: 3.15s
9:	learn: 8.2632540	total: 240ms	remaining: 3.05s
10:	learn: 8.2273264	total: 264ms	remaining: 3.03s
11:	learn: 8.1672399	total: 287ms	remaining: 2.99s
12:	learn: 8.1064757	total: 311ms	remaining: 2.97s
13:	learn: 8.0527752	total: 333ms	remaining: 2.92s
14:	learn: 8.0103098	total: 364ms	remaining: 2.96s
15:	learn: 7.9545735	total: 388ms	remaining: 2.93s
16:	learn: 7.8954142	total: 392ms	remaining: 2.77s
17:	learn: 7.8393687	total: 416ms	remai

[32m[I 2021-05-28 15:02:11,900][0m Trial 3 finished with value: -3.484283046519194 and parameters: {'iterations': 137, 'depth': 10, 'learning_rate': 0.016749270670677583, 'random_strength': 94, 'bagging_temperature': 0.05665628574379606, 'od_type': 'IncToDec', 'od_wait': 18}. Best is trial 2 with value: -2.349479813970581.[0m


132:	learn: 4.3759011	total: 3.13s	remaining: 94.1ms
133:	learn: 4.3620622	total: 3.16s	remaining: 70.7ms
134:	learn: 4.3518630	total: 3.18s	remaining: 47.1ms
135:	learn: 4.3401131	total: 3.2s	remaining: 23.5ms
136:	learn: 4.3277383	total: 3.22s	remaining: 0us
0:	learn: 8.1549532	total: 1.59ms	remaining: 267ms
1:	learn: 7.7702192	total: 3.27ms	remaining: 273ms
2:	learn: 7.1846270	total: 5.13ms	remaining: 284ms
3:	learn: 6.9925071	total: 9.13ms	remaining: 377ms
4:	learn: 6.7110259	total: 11.2ms	remaining: 367ms
5:	learn: 6.4895323	total: 15.7ms	remaining: 427ms
6:	learn: 6.2470572	total: 17.3ms	remaining: 401ms
7:	learn: 5.8128021	total: 21.8ms	remaining: 438ms
8:	learn: 5.6955781	total: 23.5ms	remaining: 418ms
9:	learn: 5.4038695	total: 25.1ms	remaining: 399ms
10:	learn: 5.1552313	total: 27.5ms	remaining: 395ms
11:	learn: 5.0492234	total: 29ms	remaining: 380ms
12:	learn: 4.9235334	total: 31.5ms	remaining: 378ms
13:	learn: 4.7916021	total: 32.9ms	remaining: 364ms
14:	learn: 4.6443357	to

[32m[I 2021-05-28 15:02:21,718][0m Trial 4 finished with value: -2.1462088632112724 and parameters: {'iterations': 169, 'depth': 6, 'learning_rate': 0.1621454610773983, 'random_strength': 47, 'bagging_temperature': 9.576623495090212, 'od_type': 'Iter', 'od_wait': 41}. Best is trial 4 with value: -2.1462088632112724.[0m


163:	learn: 0.7116290	total: 513ms	remaining: 15.6ms
164:	learn: 0.7034866	total: 515ms	remaining: 12.5ms
165:	learn: 0.6977557	total: 518ms	remaining: 9.36ms
166:	learn: 0.6910427	total: 521ms	remaining: 6.24ms
167:	learn: 0.6832040	total: 525ms	remaining: 3.12ms
168:	learn: 0.6761587	total: 529ms	remaining: 0us
0:	learn: 8.7020266	total: 25ms	remaining: 3.09s
1:	learn: 8.5917337	total: 61.7ms	remaining: 3.79s
2:	learn: 8.4804391	total: 85.6ms	remaining: 3.48s
3:	learn: 8.3375779	total: 110ms	remaining: 3.34s
4:	learn: 8.1920124	total: 136ms	remaining: 3.25s
5:	learn: 8.0445795	total: 166ms	remaining: 3.29s
6:	learn: 7.9180883	total: 191ms	remaining: 3.21s
7:	learn: 7.8006744	total: 220ms	remaining: 3.22s
8:	learn: 7.7492487	total: 224ms	remaining: 2.88s
9:	learn: 7.6364527	total: 246ms	remaining: 2.83s
10:	learn: 7.5697485	total: 266ms	remaining: 2.76s
11:	learn: 7.4473160	total: 289ms	remaining: 2.72s
12:	learn: 7.3310466	total: 313ms	remaining: 2.7s
13:	learn: 7.2353104	total: 336m

[32m[I 2021-05-28 15:02:42,978][0m Trial 5 finished with value: -2.798970312587162 and parameters: {'iterations': 125, 'depth': 10, 'learning_rate': 0.037841392799911396, 'random_strength': 77, 'bagging_temperature': 0.028040086144031362, 'od_type': 'Iter', 'od_wait': 28}. Best is trial 4 with value: -2.1462088632112724.[0m


124:	learn: 2.8299419	total: 2.85s	remaining: 0us
0:	learn: 8.7322005	total: 2.67ms	remaining: 257ms
1:	learn: 8.6742097	total: 3.68ms	remaining: 175ms
2:	learn: 8.6345933	total: 5.96ms	remaining: 187ms
3:	learn: 8.5771242	total: 7.01ms	remaining: 163ms
4:	learn: 8.5285529	total: 8.75ms	remaining: 161ms
5:	learn: 8.4773138	total: 11.6ms	remaining: 177ms
6:	learn: 8.4176342	total: 14.9ms	remaining: 192ms
7:	learn: 8.3589241	total: 38.2ms	remaining: 425ms
8:	learn: 8.2876643	total: 42.8ms	remaining: 418ms
9:	learn: 8.2049898	total: 43.9ms	remaining: 382ms
10:	learn: 8.1330742	total: 45ms	remaining: 352ms
11:	learn: 8.1012271	total: 46.6ms	remaining: 330ms
12:	learn: 8.0456733	total: 49ms	remaining: 317ms
13:	learn: 7.9774781	total: 50ms	remaining: 297ms
14:	learn: 7.9261796	total: 51.5ms	remaining: 282ms
15:	learn: 7.8591088	total: 53.7ms	remaining: 272ms
16:	learn: 7.8256882	total: 55.2ms	remaining: 260ms
17:	learn: 7.7805385	total: 56.5ms	remaining: 248ms
18:	learn: 7.7511039	total: 57

[32m[I 2021-05-28 15:02:50,266][0m Trial 6 finished with value: -3.7682181367395424 and parameters: {'iterations': 97, 'depth': 5, 'learning_rate': 0.021241914311665576, 'random_strength': 54, 'bagging_temperature': 0.016889218419894617, 'od_type': 'IncToDec', 'od_wait': 35}. Best is trial 4 with value: -2.1462088632112724.[0m


72:	learn: 5.6800652	total: 88.9ms	remaining: 29.2ms
73:	learn: 5.6663007	total: 90.5ms	remaining: 28.1ms
74:	learn: 5.6182876	total: 92.1ms	remaining: 27ms
75:	learn: 5.5836955	total: 93.6ms	remaining: 25.9ms
76:	learn: 5.5642072	total: 94.7ms	remaining: 24.6ms
77:	learn: 5.5268290	total: 96.3ms	remaining: 23.5ms
78:	learn: 5.5076449	total: 97.4ms	remaining: 22.2ms
79:	learn: 5.4710953	total: 98.8ms	remaining: 21ms
80:	learn: 5.4405924	total: 99.9ms	remaining: 19.7ms
81:	learn: 5.4268996	total: 101ms	remaining: 18.4ms
82:	learn: 5.3983212	total: 102ms	remaining: 17.2ms
83:	learn: 5.3635685	total: 103ms	remaining: 15.9ms
84:	learn: 5.3306361	total: 105ms	remaining: 14.8ms
85:	learn: 5.2944806	total: 106ms	remaining: 13.5ms
86:	learn: 5.2716279	total: 107ms	remaining: 12.2ms
87:	learn: 5.2486700	total: 108ms	remaining: 11.1ms
88:	learn: 5.2185985	total: 109ms	remaining: 9.82ms
89:	learn: 5.2081395	total: 110ms	remaining: 8.58ms
90:	learn: 5.1930880	total: 111ms	remaining: 7.34ms
91:	lea

[32m[I 2021-05-28 15:02:56,498][0m Trial 7 finished with value: -3.994852074330274 and parameters: {'iterations': 121, 'depth': 6, 'learning_rate': 0.012366657132830063, 'random_strength': 34, 'bagging_temperature': 0.033435245628234765, 'od_type': 'Iter', 'od_wait': 17}. Best is trial 4 with value: -2.1462088632112724.[0m


116:	learn: 5.5665250	total: 348ms	remaining: 11.9ms
117:	learn: 5.5452501	total: 362ms	remaining: 9.19ms
118:	learn: 5.5361321	total: 374ms	remaining: 6.28ms
119:	learn: 5.5163922	total: 380ms	remaining: 3.17ms
120:	learn: 5.5076087	total: 382ms	remaining: 0us
0:	learn: 8.0294535	total: 5.99ms	remaining: 760ms
1:	learn: 7.4675206	total: 10.1ms	remaining: 639ms
2:	learn: 7.0189695	total: 14ms	remaining: 584ms
3:	learn: 6.6179760	total: 24.1ms	remaining: 746ms
4:	learn: 6.1687219	total: 29.1ms	remaining: 717ms
5:	learn: 5.8128024	total: 33.8ms	remaining: 688ms
6:	learn: 5.5109643	total: 39.6ms	remaining: 685ms
7:	learn: 5.2363254	total: 42.6ms	remaining: 639ms
8:	learn: 5.1365521	total: 49ms	remaining: 647ms
9:	learn: 5.0128615	total: 52.8ms	remaining: 622ms
10:	learn: 4.7781904	total: 55.7ms	remaining: 592ms
11:	learn: 4.6616817	total: 61.5ms	remaining: 594ms
12:	learn: 4.5683443	total: 66.5ms	remaining: 588ms
13:	learn: 4.5057925	total: 74.3ms	remaining: 605ms
14:	learn: 4.3847756	tot

[32m[I 2021-05-28 15:03:04,603][0m Trial 8 finished with value: -2.2554173941713644 and parameters: {'iterations': 128, 'depth': 7, 'learning_rate': 0.19795699801568598, 'random_strength': 42, 'bagging_temperature': 0.025858603085983166, 'od_type': 'IncToDec', 'od_wait': 34}. Best is trial 4 with value: -2.1462088632112724.[0m


124:	learn: 0.6019402	total: 565ms	remaining: 13.6ms
125:	learn: 0.5919986	total: 575ms	remaining: 9.13ms
126:	learn: 0.5874453	total: 588ms	remaining: 4.63ms
127:	learn: 0.5810956	total: 590ms	remaining: 0us
0:	learn: 8.6856553	total: 19ms	remaining: 4.66s
1:	learn: 8.5595627	total: 36.2ms	remaining: 4.42s
2:	learn: 8.4488703	total: 58.4ms	remaining: 4.73s
3:	learn: 8.3177675	total: 78.4ms	remaining: 4.74s
4:	learn: 8.1840025	total: 96ms	remaining: 4.63s
5:	learn: 8.0481569	total: 111ms	remaining: 4.42s
6:	learn: 7.9313550	total: 127ms	remaining: 4.34s
7:	learn: 7.8228535	total: 144ms	remaining: 4.27s
8:	learn: 7.7743642	total: 145ms	remaining: 3.82s
9:	learn: 7.6691052	total: 162ms	remaining: 3.81s
10:	learn: 7.6017658	total: 176ms	remaining: 3.75s
11:	learn: 7.4945924	total: 193ms	remaining: 3.76s
12:	learn: 7.3910499	total: 217ms	remaining: 3.89s
13:	learn: 7.3015815	total: 239ms	remaining: 3.96s
14:	learn: 7.1993510	total: 254ms	remaining: 3.92s
15:	learn: 7.1049198	total: 271ms	r

[32m[I 2021-05-28 15:03:37,771][0m Trial 9 finished with value: -2.3716752323583505 and parameters: {'iterations': 246, 'depth': 10, 'learning_rate': 0.034716087948196195, 'random_strength': 20, 'bagging_temperature': 3.2079668827273053, 'od_type': 'Iter', 'od_wait': 49}. Best is trial 4 with value: -2.1462088632112724.[0m


240:	learn: 1.6611746	total: 4.87s	remaining: 101ms
241:	learn: 1.6548321	total: 4.9s	remaining: 81ms
242:	learn: 1.6451596	total: 4.93s	remaining: 60.8ms
243:	learn: 1.6378092	total: 4.93s	remaining: 40.4ms
244:	learn: 1.6283470	total: 4.95s	remaining: 20.2ms
245:	learn: 1.6218378	total: 4.97s	remaining: 0us
CPU times: user 2min 30s, sys: 13.3 s, total: 2min 44s
Wall time: 2min 50s


In [7]:
# チューニングしたハイパーパラメーターをフィット
optimised_model = cb.CatBoostRegressor(
    iterations=study.best_params["iterations"],
    depth=study.best_params["depth"],
    learning_rate=study.best_params["learning_rate"],
    random_strength=study.best_params["random_strength"],
    bagging_temperature=study.best_params["bagging_temperature"],
    od_type=study.best_params["od_type"],
    od_wait=study.best_params["od_wait"],
)

optimised_model.fit(x_train, y_train)

# CatBoost推論
y_pred = optimised_model.predict(x_test)

0:	learn: 8.0878271	total: 1.35ms	remaining: 228ms
1:	learn: 7.7317133	total: 2.94ms	remaining: 246ms
2:	learn: 7.1322833	total: 6.53ms	remaining: 361ms
3:	learn: 6.9407280	total: 9.03ms	remaining: 372ms
4:	learn: 6.6622336	total: 11.6ms	remaining: 379ms
5:	learn: 6.4466713	total: 14.6ms	remaining: 396ms
6:	learn: 6.3189913	total: 16.1ms	remaining: 373ms
7:	learn: 5.8711181	total: 17.5ms	remaining: 352ms
8:	learn: 5.7548422	total: 18.9ms	remaining: 336ms
9:	learn: 5.4081181	total: 20.3ms	remaining: 323ms
10:	learn: 5.1628842	total: 22.1ms	remaining: 317ms
11:	learn: 5.0685140	total: 23.6ms	remaining: 308ms
12:	learn: 4.9318039	total: 25ms	remaining: 300ms
13:	learn: 4.7761419	total: 26.4ms	remaining: 292ms
14:	learn: 4.6298204	total: 28ms	remaining: 287ms
15:	learn: 4.4922132	total: 30ms	remaining: 287ms
16:	learn: 4.3214346	total: 31.3ms	remaining: 280ms
17:	learn: 4.2064487	total: 33.3ms	remaining: 280ms
18:	learn: 4.0886323	total: 34.6ms	remaining: 273ms
19:	learn: 3.9145181	total: 

In [8]:
# 評価
def calculate_scores(true, pred):
    """全ての評価指標を計算する

    Parameters
    ----------
    true (np.array)       : 実測値
    pred (np.array)       : 予測値

    Returns
    -------
    scores (pd.DataFrame) : 各評価指標を纏めた結果

    """
    scores = {}
    scores = pd.DataFrame(
        {
            "R2": r2_score(true, pred),
            "MAE": mean_absolute_error(true, pred),
            "MSE": mean_squared_error(true, pred),
            "RMSE": np.sqrt(mean_squared_error(true, pred)),
        },
        index=["scores"],
    )
    return scores

In [9]:
scores = calculate_scores(y_test, y_pred)
print(scores)

              R2       MAE        MSE      RMSE
scores  0.864608  2.541975  14.159432  3.762902
