## CatBoost-Optunaのサンプルコード

In [1]:
%load_ext lab_black

In [2]:
# ライブラリーのインポート
import os

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

%matplotlib inline

# ボストンの住宅価格データ
from sklearn.datasets import load_boston

# 前処理
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# CatBoost
import catboost as cb
from catboost import CatBoost, Pool

# Optuna
import optuna
from optuna.samplers import TPESampler
from sklearn.model_selection import cross_val_score

# 評価指標
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

In [3]:
print(cb.__version__)

0.26


In [4]:
print(optuna.__version__)

2.8.0


In [5]:
# データセットの読込み
boston = load_boston()

# 説明変数の格納
df = pd.DataFrame(boston.data, columns=boston.feature_names)
# 目的変数の追加
df["MEDV"] = boston.target

# データの中身を確認
df.head()

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,MEDV
0,0.00632,18.0,2.31,0.0,0.538,6.575,65.2,4.09,1.0,296.0,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0.0,0.469,6.421,78.9,4.9671,2.0,242.0,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0.0,0.469,7.185,61.1,4.9671,2.0,242.0,17.8,392.83,4.03,34.7
3,0.03237,0.0,2.18,0.0,0.458,6.998,45.8,6.0622,3.0,222.0,18.7,394.63,2.94,33.4
4,0.06905,0.0,2.18,0.0,0.458,7.147,54.2,6.0622,3.0,222.0,18.7,396.9,5.33,36.2


#### 前処理

In [6]:
# ランダムシード値
RANDOM_STATE = 10

# 学習データと評価データの割合
TEST_SIZE = 0.2

# 学習データと評価データを作成
x_train, x_test, y_train, y_test = train_test_split(
    df.iloc[:, 0 : df.shape[1] - 1],
    df.iloc[:, df.shape[1] - 1],
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
)

In [7]:
def objective(trial):

    iterations = trial.suggest_int("iterations", 50, 300)
    depth = trial.suggest_int("depth", 4, 10)
    learning_rate = trial.suggest_loguniform("learning_rate", 0.01, 0.3)
    random_strength = trial.suggest_int("random_strength", 0, 100)
    bagging_temperature = trial.suggest_loguniform("bagging_temperature", 0.01, 100.00)
    od_type = trial.suggest_categorical("od_type", ["IncToDec", "Iter"])
    od_wait = trial.suggest_int("od_wait", 10, 50)

    model = cb.CatBoostRegressor(
        iterations=iterations,
        depth=depth,
        learning_rate=learning_rate,
        random_strength=random_strength,
        bagging_temperature=bagging_temperature,
        od_type=od_type,
        od_wait=od_wait,
    )

    score = cross_val_score(
        model, x_train, y_train, cv=5, scoring="neg_mean_absolute_error"
    )
    mae = score.mean()

    return mae

In [8]:
%%time
# optunaで最適値を見つける
# 注：cross_val_scoreの出力は全て高いほど良い
study = optuna.create_study(direction='maximize', sampler=TPESampler(seed=RANDOM_STATE))
study.optimize(objective, n_trials=10)

[32m[I 2021-07-08 03:43:22,213][0m A new study created in memory with name: no-name-f293d2d5-f1ae-488e-9f77-9262ab133aa0[0m


0:	learn: 8.5922409	total: 47.2ms	remaining: 11.4s
1:	learn: 8.2893103	total: 47.8ms	remaining: 5.76s
2:	learn: 8.1242319	total: 48.3ms	remaining: 3.87s
3:	learn: 7.8783668	total: 48.9ms	remaining: 2.92s
4:	learn: 7.8568728	total: 49.3ms	remaining: 2.35s
5:	learn: 7.5191292	total: 49.8ms	remaining: 1.97s
6:	learn: 7.2639482	total: 50.4ms	remaining: 1.7s
7:	learn: 7.0118040	total: 51ms	remaining: 1.5s
8:	learn: 6.9121598	total: 51.6ms	remaining: 1.34s
9:	learn: 6.6747831	total: 52.1ms	remaining: 1.21s
10:	learn: 6.4660698	total: 52.6ms	remaining: 1.11s
11:	learn: 6.4102914	total: 53.2ms	remaining: 1.02s
12:	learn: 6.2519145	total: 53.7ms	remaining: 950ms
13:	learn: 6.1160302	total: 54.2ms	remaining: 887ms
14:	learn: 6.0731424	total: 54.7ms	remaining: 832ms
15:	learn: 5.8761473	total: 55.3ms	remaining: 784ms
16:	learn: 5.7990006	total: 55.8ms	remaining: 742ms
17:	learn: 5.7461937	total: 56.3ms	remaining: 704ms
18:	learn: 5.7123859	total: 56.9ms	remaining: 671ms
19:	learn: 5.6460952	total

[32m[I 2021-07-08 03:43:26,439][0m Trial 0 finished with value: -2.0383364230490484 and parameters: {'iterations': 243, 'depth': 4, 'learning_rate': 0.08629294202140579, 'random_strength': 75, 'bagging_temperature': 0.986343187233007, 'od_type': 'IncToDec', 'od_wait': 41}. Best is trial 0 with value: -2.0383364230490484.[0m


0:	learn: 8.5495398	total: 1.06ms	remaining: 96.7ms
1:	learn: 8.1937218	total: 1.75ms	remaining: 78.8ms
2:	learn: 8.0060134	total: 2.44ms	remaining: 72.4ms
3:	learn: 7.7241865	total: 3.15ms	remaining: 69.3ms
4:	learn: 7.6994018	total: 3.71ms	remaining: 64.5ms
5:	learn: 7.3085786	total: 4.48ms	remaining: 64.2ms
6:	learn: 7.0234833	total: 5.26ms	remaining: 63.9ms
7:	learn: 6.7396682	total: 5.95ms	remaining: 62.4ms
8:	learn: 6.6357289	total: 6.82ms	remaining: 62.9ms
9:	learn: 6.3787949	total: 7.56ms	remaining: 62ms
10:	learn: 6.1588691	total: 8.51ms	remaining: 62.7ms
11:	learn: 6.1051517	total: 9.6ms	remaining: 64ms
12:	learn: 5.9435438	total: 10.4ms	remaining: 63.5ms
13:	learn: 5.8054623	total: 11.3ms	remaining: 62.7ms
14:	learn: 5.7658590	total: 11.9ms	remaining: 61.3ms
15:	learn: 5.5606122	total: 12.7ms	remaining: 60.2ms
16:	learn: 5.4825938	total: 13.4ms	remaining: 59ms
17:	learn: 5.4302452	total: 14.1ms	remaining: 57.9ms
18:	learn: 5.3993389	total: 14.8ms	remaining: 56.7ms
19:	learn:

[32m[I 2021-07-08 03:43:28,083][0m Trial 1 finished with value: -2.4923876458514056 and parameters: {'iterations': 92, 'depth': 4, 'learning_rate': 0.1028867751008824, 'random_strength': 96, 'bagging_temperature': 0.01037034167132714, 'od_type': 'Iter', 'od_wait': 35}. Best is trial 0 with value: -2.0383364230490484.[0m


45:	learn: 4.1704343	total: 40ms	remaining: 40ms
46:	learn: 4.1243329	total: 41.5ms	remaining: 39.7ms
47:	learn: 4.1052837	total: 42.4ms	remaining: 38.8ms
48:	learn: 4.0634783	total: 42.9ms	remaining: 37.7ms
49:	learn: 4.0124100	total: 43.5ms	remaining: 36.6ms
50:	learn: 3.9838622	total: 44.1ms	remaining: 35.4ms
51:	learn: 3.9768166	total: 45.2ms	remaining: 34.7ms
52:	learn: 3.9050252	total: 46.6ms	remaining: 34.3ms
53:	learn: 3.8998454	total: 47.2ms	remaining: 33.2ms
54:	learn: 3.8579823	total: 47.9ms	remaining: 32.2ms
55:	learn: 3.8574914	total: 48.5ms	remaining: 31.2ms
56:	learn: 3.8539477	total: 49.3ms	remaining: 30.3ms
57:	learn: 3.8383278	total: 49.8ms	remaining: 29.2ms
58:	learn: 3.7706051	total: 50.9ms	remaining: 28.5ms
59:	learn: 3.7476149	total: 52ms	remaining: 27.7ms
60:	learn: 3.7353387	total: 53.1ms	remaining: 27ms
61:	learn: 3.6384011	total: 54.9ms	remaining: 26.5ms
62:	learn: 3.6284983	total: 56ms	remaining: 25.8ms
63:	learn: 3.6153534	total: 57.3ms	remaining: 25.1ms
64:

[32m[I 2021-07-08 03:43:34,814][0m Trial 2 finished with value: -2.185127259175691 and parameters: {'iterations': 231, 'depth': 6, 'learning_rate': 0.22681076499740013, 'random_strength': 72, 'bagging_temperature': 1.4797129411694714, 'od_type': 'Iter', 'od_wait': 37}. Best is trial 0 with value: -2.0383364230490484.[0m


0:	learn: 8.4869522	total: 13.3ms	remaining: 2.12s
1:	learn: 8.2205788	total: 22.4ms	remaining: 1.77s
2:	learn: 7.9833178	total: 25.2ms	remaining: 1.32s
3:	learn: 7.7559933	total: 34.6ms	remaining: 1.35s
4:	learn: 7.4939533	total: 43.4ms	remaining: 1.34s
5:	learn: 7.2577574	total: 52ms	remaining: 1.33s
6:	learn: 7.0389291	total: 62.9ms	remaining: 1.37s
7:	learn: 6.8221687	total: 66.6ms	remaining: 1.26s
8:	learn: 6.7273961	total: 81ms	remaining: 1.36s
9:	learn: 6.6141214	total: 89.9ms	remaining: 1.35s
10:	learn: 6.4049395	total: 99.3ms	remaining: 1.34s
11:	learn: 6.2767197	total: 105ms	remaining: 1.29s
12:	learn: 6.1757365	total: 119ms	remaining: 1.34s
13:	learn: 6.0950315	total: 122ms	remaining: 1.27s
14:	learn: 5.9634965	total: 134ms	remaining: 1.29s
15:	learn: 5.8671789	total: 147ms	remaining: 1.32s
16:	learn: 5.7283037	total: 158ms	remaining: 1.33s
17:	learn: 5.6443854	total: 171ms	remaining: 1.35s
18:	learn: 5.5691962	total: 174ms	remaining: 1.29s
19:	learn: 5.4543500	total: 182ms	

[32m[I 2021-07-08 03:43:44,134][0m Trial 3 finished with value: -2.228374945884673 and parameters: {'iterations': 160, 'depth': 7, 'learning_rate': 0.0817554539585436, 'random_strength': 51, 'bagging_temperature': 3.995661855958767, 'od_type': 'Iter', 'od_wait': 31}. Best is trial 0 with value: -2.0383364230490484.[0m


158:	learn: 1.4096377	total: 988ms	remaining: 6.21ms
159:	learn: 1.3944956	total: 997ms	remaining: 0us
0:	learn: 8.7646113	total: 1.8ms	remaining: 498ms
1:	learn: 8.7255967	total: 3.45ms	remaining: 477ms
2:	learn: 8.6663064	total: 5.06ms	remaining: 464ms
3:	learn: 8.6387246	total: 6.41ms	remaining: 439ms
4:	learn: 8.5967512	total: 7.34ms	remaining: 401ms
5:	learn: 8.5583467	total: 8.59ms	remaining: 390ms
6:	learn: 8.5205257	total: 9.71ms	remaining: 376ms
7:	learn: 8.4562335	total: 10.8ms	remaining: 364ms
8:	learn: 8.4234590	total: 11.7ms	remaining: 350ms
9:	learn: 8.3666334	total: 12.8ms	remaining: 343ms
10:	learn: 8.3171907	total: 13.7ms	remaining: 332ms
11:	learn: 8.2844629	total: 14.7ms	remaining: 325ms
12:	learn: 8.2502046	total: 15.6ms	remaining: 318ms
13:	learn: 8.2051766	total: 16.5ms	remaining: 311ms
14:	learn: 8.1578658	total: 18.2ms	remaining: 319ms
15:	learn: 8.1018063	total: 19.4ms	remaining: 317ms
16:	learn: 8.0496489	total: 20.4ms	remaining: 314ms
17:	learn: 8.0087428	tot

[32m[I 2021-07-08 03:43:51,063][0m Trial 4 finished with value: -3.0096098903450015 and parameters: {'iterations': 278, 'depth': 6, 'learning_rate': 0.01360252170284926, 'random_strength': 30, 'bagging_temperature': 0.028571789869344867, 'od_type': 'IncToDec', 'od_wait': 35}. Best is trial 0 with value: -2.0383364230490484.[0m


265:	learn: 3.8039544	total: 336ms	remaining: 15.2ms
266:	learn: 3.7954743	total: 337ms	remaining: 13.9ms
267:	learn: 3.7921248	total: 338ms	remaining: 12.6ms
268:	learn: 3.7888957	total: 339ms	remaining: 11.4ms
269:	learn: 3.7795806	total: 341ms	remaining: 10.1ms
270:	learn: 3.7679138	total: 343ms	remaining: 8.85ms
271:	learn: 3.7590532	total: 345ms	remaining: 7.61ms
272:	learn: 3.7454324	total: 346ms	remaining: 6.34ms
273:	learn: 3.7427179	total: 348ms	remaining: 5.08ms
274:	learn: 3.7346784	total: 349ms	remaining: 3.81ms
275:	learn: 3.7321462	total: 350ms	remaining: 2.54ms
276:	learn: 3.7224812	total: 351ms	remaining: 1.27ms
277:	learn: 3.7121059	total: 352ms	remaining: 0us
0:	learn: 8.7617270	total: 7.12ms	remaining: 1.32s
1:	learn: 8.7146221	total: 12.3ms	remaining: 1.13s
2:	learn: 8.6565349	total: 22.4ms	remaining: 1.37s
3:	learn: 8.5931387	total: 28.8ms	remaining: 1.31s
4:	learn: 8.5183927	total: 35.7ms	remaining: 1.3s
5:	learn: 8.4495361	total: 47.6ms	remaining: 1.43s
6:	learn:

[32m[I 2021-07-08 03:44:03,430][0m Trial 5 finished with value: -3.0630207859676393 and parameters: {'iterations': 187, 'depth': 9, 'learning_rate': 0.019672956852114582, 'random_strength': 86, 'bagging_temperature': 0.2550413260241303, 'od_type': 'IncToDec', 'od_wait': 46}. Best is trial 0 with value: -2.0383364230490484.[0m


174:	learn: 3.5353493	total: 1.57s	remaining: 108ms
175:	learn: 3.5257122	total: 1.58s	remaining: 98.5ms
176:	learn: 3.5114544	total: 1.59s	remaining: 89.7ms
177:	learn: 3.5048042	total: 1.59s	remaining: 80.5ms
178:	learn: 3.5034716	total: 1.59s	remaining: 71.2ms
179:	learn: 3.4976160	total: 1.59s	remaining: 62ms
180:	learn: 3.4860367	total: 1.6s	remaining: 53.1ms
181:	learn: 3.4685352	total: 1.61s	remaining: 44.1ms
182:	learn: 3.4629814	total: 1.61s	remaining: 35.3ms
183:	learn: 3.4607748	total: 1.61s	remaining: 26.3ms
184:	learn: 3.4468292	total: 1.62s	remaining: 17.5ms
185:	learn: 3.4340157	total: 1.63s	remaining: 8.74ms
186:	learn: 3.4298414	total: 1.63s	remaining: 0us
0:	learn: 8.6616680	total: 1.4ms	remaining: 182ms
1:	learn: 8.5497070	total: 2.5ms	remaining: 161ms
2:	learn: 8.4810405	total: 3.32ms	remaining: 142ms
3:	learn: 8.3827033	total: 4.52ms	remaining: 144ms
4:	learn: 8.2498153	total: 5.84ms	remaining: 147ms
5:	learn: 8.1589502	total: 6.67ms	remaining: 139ms
6:	learn: 8.02

[32m[I 2021-07-08 03:44:05,877][0m Trial 6 finished with value: -2.6836579532082596 and parameters: {'iterations': 131, 'depth': 5, 'learning_rate': 0.03800259814899221, 'random_strength': 9, 'bagging_temperature': 19.249640862099252, 'od_type': 'Iter', 'od_wait': 48}. Best is trial 0 with value: -2.0383364230490484.[0m


98:	learn: 3.4668382	total: 96.7ms	remaining: 31.2ms
99:	learn: 3.4446361	total: 97.2ms	remaining: 30.1ms
100:	learn: 3.4362948	total: 98ms	remaining: 29.1ms
101:	learn: 3.4152302	total: 98.6ms	remaining: 28ms
102:	learn: 3.3987077	total: 99.3ms	remaining: 27ms
103:	learn: 3.3824778	total: 99.9ms	remaining: 25.9ms
104:	learn: 3.3641483	total: 100ms	remaining: 24.9ms
105:	learn: 3.3460621	total: 101ms	remaining: 23.8ms
106:	learn: 3.3369735	total: 102ms	remaining: 22.8ms
107:	learn: 3.3193919	total: 102ms	remaining: 21.8ms
108:	learn: 3.3043545	total: 103ms	remaining: 20.8ms
109:	learn: 3.2915542	total: 104ms	remaining: 19.8ms
110:	learn: 3.2753304	total: 104ms	remaining: 18.8ms
111:	learn: 3.2589236	total: 105ms	remaining: 17.8ms
112:	learn: 3.2534182	total: 106ms	remaining: 16.9ms
113:	learn: 3.2518255	total: 106ms	remaining: 15.9ms
114:	learn: 3.2444516	total: 107ms	remaining: 14.9ms
115:	learn: 3.2408963	total: 108ms	remaining: 14ms
116:	learn: 3.2306575	total: 109ms	remaining: 13ms

[32m[I 2021-07-08 03:44:16,383][0m Trial 7 finished with value: -2.165174896630533 and parameters: {'iterations': 297, 'depth': 7, 'learning_rate': 0.166067103795112, 'random_strength': 25, 'bagging_temperature': 2.4518087630137058, 'od_type': 'IncToDec', 'od_wait': 34}. Best is trial 0 with value: -2.0383364230490484.[0m


293:	learn: 0.1390902	total: 1.38s	remaining: 14.1ms
294:	learn: 0.1372148	total: 1.38s	remaining: 9.38ms
295:	learn: 0.1361613	total: 1.39s	remaining: 4.69ms
296:	learn: 0.1347658	total: 1.39s	remaining: 0us
0:	learn: 8.7666925	total: 2.96ms	remaining: 172ms
1:	learn: 8.7290690	total: 4.42ms	remaining: 126ms
2:	learn: 8.6718915	total: 6.85ms	remaining: 128ms
3:	learn: 8.6452640	total: 8.84ms	remaining: 122ms
4:	learn: 8.6047412	total: 11.6ms	remaining: 125ms
5:	learn: 8.5676419	total: 13.9ms	remaining: 122ms
6:	learn: 8.5311098	total: 16.4ms	remaining: 122ms
7:	learn: 8.4690260	total: 18.3ms	remaining: 117ms
8:	learn: 8.4373179	total: 20.5ms	remaining: 114ms
9:	learn: 8.3823913	total: 22.7ms	remaining: 111ms
10:	learn: 8.3345957	total: 25.1ms	remaining: 110ms
11:	learn: 8.3029161	total: 26.9ms	remaining: 105ms
12:	learn: 8.2697515	total: 29.4ms	remaining: 104ms
13:	learn: 8.2261432	total: 32.4ms	remaining: 104ms
14:	learn: 8.1803332	total: 35.7ms	remaining: 105ms
15:	learn: 8.1260640	

[32m[I 2021-07-08 03:44:18,249][0m Trial 8 finished with value: -4.764836134101692 and parameters: {'iterations': 59, 'depth': 6, 'learning_rate': 0.01310986403693353, 'random_strength': 30, 'bagging_temperature': 0.21031838969395286, 'od_type': 'IncToDec', 'od_wait': 27}. Best is trial 0 with value: -2.0383364230490484.[0m


55:	learn: 6.7101914	total: 185ms	remaining: 9.88ms
56:	learn: 6.6890312	total: 188ms	remaining: 6.58ms
57:	learn: 6.6535002	total: 191ms	remaining: 3.29ms
58:	learn: 6.6215036	total: 192ms	remaining: 0us
0:	learn: 8.6809446	total: 477us	remaining: 61.1ms
1:	learn: 8.5436818	total: 4.87ms	remaining: 309ms
2:	learn: 8.4052972	total: 10.6ms	remaining: 444ms
3:	learn: 8.2675510	total: 14.8ms	remaining: 461ms
4:	learn: 8.1437394	total: 19.6ms	remaining: 487ms
5:	learn: 8.0041002	total: 24.6ms	remaining: 504ms
6:	learn: 7.8831789	total: 29.9ms	remaining: 522ms
7:	learn: 7.7744427	total: 36.3ms	remaining: 549ms
8:	learn: 7.6420837	total: 42ms	remaining: 559ms
9:	learn: 7.5293479	total: 49.2ms	remaining: 586ms
10:	learn: 7.4275725	total: 57ms	remaining: 611ms
11:	learn: 7.3354754	total: 66.6ms	remaining: 649ms
12:	learn: 7.2510757	total: 73.5ms	remaining: 656ms
13:	learn: 7.1497962	total: 81.3ms	remaining: 668ms
14:	learn: 7.0578836	total: 90.7ms	remaining: 690ms
15:	learn: 6.9564753	total: 9

[32m[I 2021-07-08 03:44:24,359][0m Trial 9 finished with value: -2.551511397410123 and parameters: {'iterations': 129, 'depth': 8, 'learning_rate': 0.03247855784527887, 'random_strength': 4, 'bagging_temperature': 33.08725191864784, 'od_type': 'Iter', 'od_wait': 27}. Best is trial 0 with value: -2.0383364230490484.[0m


128:	learn: 2.5341023	total: 644ms	remaining: 0us
CPU times: user 42.1 s, sys: 18.5 s, total: 1min
Wall time: 1min 2s


In [9]:
# チューニングしたハイパーパラメーターをフィット
optimised_model = cb.CatBoostRegressor(
    iterations=study.best_params["iterations"],
    depth=study.best_params["depth"],
    learning_rate=study.best_params["learning_rate"],
    random_strength=study.best_params["random_strength"],
    bagging_temperature=study.best_params["bagging_temperature"],
    od_type=study.best_params["od_type"],
    od_wait=study.best_params["od_wait"],
)

optimised_model.fit(x_train, y_train)

# CatBoost推論
y_pred = optimised_model.predict(x_test)

0:	learn: 8.5545919	total: 1.34ms	remaining: 325ms
1:	learn: 8.2758365	total: 1.92ms	remaining: 231ms
2:	learn: 8.1254918	total: 2.42ms	remaining: 194ms
3:	learn: 7.8676602	total: 2.89ms	remaining: 172ms
4:	learn: 7.8437689	total: 3.29ms	remaining: 157ms
5:	learn: 7.5069089	total: 3.99ms	remaining: 158ms
6:	learn: 7.2411512	total: 4.65ms	remaining: 157ms
7:	learn: 6.9952850	total: 5.32ms	remaining: 156ms
8:	learn: 6.8964687	total: 5.92ms	remaining: 154ms
9:	learn: 6.6667356	total: 6.56ms	remaining: 153ms
10:	learn: 6.4656556	total: 7.11ms	remaining: 150ms
11:	learn: 6.4025256	total: 7.7ms	remaining: 148ms
12:	learn: 6.2475438	total: 8.29ms	remaining: 147ms
13:	learn: 6.0911106	total: 8.87ms	remaining: 145ms
14:	learn: 6.0563384	total: 9.48ms	remaining: 144ms
15:	learn: 5.8618741	total: 10.3ms	remaining: 147ms
16:	learn: 5.8004092	total: 11ms	remaining: 147ms
17:	learn: 5.7543304	total: 11.7ms	remaining: 146ms
18:	learn: 5.6872703	total: 12.3ms	remaining: 145ms
19:	learn: 5.6333654	tota

In [10]:
# チューニングしたハイパーパラメータ
study.best_params

{'iterations': 243,
 'depth': 4,
 'learning_rate': 0.08629294202140579,
 'random_strength': 75,
 'bagging_temperature': 0.986343187233007,
 'od_type': 'IncToDec',
 'od_wait': 41}

In [11]:
# チューニングしたハイパーパラメーターをフィット
optimised_model = cb.CatBoostRegressor(
    iterations=study.best_params["iterations"],
    depth=study.best_params["depth"],
    learning_rate=study.best_params["learning_rate"],
    random_strength=study.best_params["random_strength"],
    bagging_temperature=study.best_params["bagging_temperature"],
    od_type=study.best_params["od_type"],
    od_wait=study.best_params["od_wait"],
)

optimised_model.fit(x_train, y_train)

# CatBoost推論
y_pred = optimised_model.predict(x_test)

0:	learn: 8.5545919	total: 509us	remaining: 123ms
1:	learn: 8.2758365	total: 968us	remaining: 117ms
2:	learn: 8.1254918	total: 1.41ms	remaining: 113ms
3:	learn: 7.8676602	total: 1.95ms	remaining: 117ms
4:	learn: 7.8437689	total: 2.34ms	remaining: 111ms
5:	learn: 7.5069089	total: 2.81ms	remaining: 111ms
6:	learn: 7.2411512	total: 3.33ms	remaining: 112ms
7:	learn: 6.9952850	total: 3.8ms	remaining: 112ms
8:	learn: 6.8964687	total: 4.29ms	remaining: 112ms
9:	learn: 6.6667356	total: 4.74ms	remaining: 110ms
10:	learn: 6.4656556	total: 5.19ms	remaining: 110ms
11:	learn: 6.4025256	total: 5.63ms	remaining: 108ms
12:	learn: 6.2475438	total: 6.1ms	remaining: 108ms
13:	learn: 6.0911106	total: 6.55ms	remaining: 107ms
14:	learn: 6.0563384	total: 7.02ms	remaining: 107ms
15:	learn: 5.8618741	total: 7.47ms	remaining: 106ms
16:	learn: 5.8004092	total: 8.01ms	remaining: 106ms
17:	learn: 5.7543304	total: 8.48ms	remaining: 106ms
18:	learn: 5.6872703	total: 8.98ms	remaining: 106ms
19:	learn: 5.6333654	total

In [12]:
# 評価
def calculate_scores(true, pred):
    """全ての評価指標を計算する

    Parameters
    ----------
    true (np.array)       : 実測値
    pred (np.array)       : 予測値

    Returns
    -------
    scores (pd.DataFrame) : 各評価指標を纏めた結果

    """
    scores = {}
    scores = pd.DataFrame(
        {
            "R2": r2_score(true, pred),
            "MAE": mean_absolute_error(true, pred),
            "MSE": mean_squared_error(true, pred),
            "RMSE": np.sqrt(mean_squared_error(true, pred)),
        },
        index=["scores"],
    )
    return scores

In [13]:
scores = calculate_scores(y_test, y_pred)
print(scores)

              R2       MAE       MSE      RMSE
scores  0.871765  2.572099  13.41089  3.662088
