## CatBoost-Optunaのサンプルコード

In [1]:
%load_ext lab_black

In [2]:
# ライブラリーのインポート
import os

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

%matplotlib inline

# ボストンの住宅価格データ
from sklearn.datasets import load_boston

# 前処理
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# CatBoost
import catboost as cb
from catboost import CatBoost, Pool

# Optuna
import optuna
from optuna.samplers import CmaEsSampler
from sklearn.model_selection import cross_val_score

# 評価指標
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

In [3]:
print(cb.__version__)

0.26


In [4]:
print(optuna.__version__)

2.8.0


In [5]:
# データセットの読込み
boston = load_boston()

# 説明変数の格納
df = pd.DataFrame(boston.data, columns=boston.feature_names)
# 目的変数の追加
df["MEDV"] = boston.target

# データの中身を確認
df.head()

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,MEDV
0,0.00632,18.0,2.31,0.0,0.538,6.575,65.2,4.09,1.0,296.0,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0.0,0.469,6.421,78.9,4.9671,2.0,242.0,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0.0,0.469,7.185,61.1,4.9671,2.0,242.0,17.8,392.83,4.03,34.7
3,0.03237,0.0,2.18,0.0,0.458,6.998,45.8,6.0622,3.0,222.0,18.7,394.63,2.94,33.4
4,0.06905,0.0,2.18,0.0,0.458,7.147,54.2,6.0622,3.0,222.0,18.7,396.9,5.33,36.2


#### 前処理

In [6]:
# ランダムシード値
RANDOM_STATE = 10

# 学習データと評価データの割合
TEST_SIZE = 0.2

# 学習データと評価データを作成
x_train, x_test, y_train, y_test = train_test_split(
    df.iloc[:, 0 : df.shape[1] - 1],
    df.iloc[:, df.shape[1] - 1],
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
)

In [7]:
def objective(trial):

    iterations = trial.suggest_int("iterations", 50, 300)
    depth = trial.suggest_int("depth", 4, 10)
    learning_rate = trial.suggest_loguniform("learning_rate", 0.01, 0.3)
    random_strength = trial.suggest_int("random_strength", 0, 100)
    bagging_temperature = trial.suggest_loguniform("bagging_temperature", 0.01, 100.00)
    od_type = trial.suggest_categorical("od_type", ["IncToDec", "Iter"])
    od_wait = trial.suggest_int("od_wait", 10, 50)

    model = cb.CatBoostRegressor(
        iterations=iterations,
        depth=depth,
        learning_rate=learning_rate,
        random_strength=random_strength,
        bagging_temperature=bagging_temperature,
        od_type=od_type,
        od_wait=od_wait,
    )

    score = cross_val_score(
        model, x_train, y_train, cv=5, scoring="neg_mean_absolute_error"
    )
    mae = score.mean()

    return mae

In [8]:
%%time
# optunaで最適値を見つける
# 注：cross_val_scoreの出力は全て高いほど良い
study = optuna.create_study(direction='maximize', sampler=CmaEsSampler(seed=RANDOM_STATE))
study.optimize(objective, n_trials=10)

[32m[I 2021-07-08 06:21:06,296][0m A new study created in memory with name: no-name-28612e3c-bba0-4ade-863a-de70ea8c0cd6[0m


0:	learn: 8.5922409	total: 49.8ms	remaining: 12s
1:	learn: 8.2893103	total: 51.4ms	remaining: 6.19s
2:	learn: 8.1242319	total: 54.9ms	remaining: 4.39s
3:	learn: 7.8783668	total: 56.7ms	remaining: 3.39s
4:	learn: 7.8568728	total: 58.6ms	remaining: 2.79s
5:	learn: 7.5191292	total: 60.2ms	remaining: 2.38s
6:	learn: 7.2639482	total: 61.2ms	remaining: 2.06s
7:	learn: 7.0118040	total: 62.2ms	remaining: 1.83s
8:	learn: 6.9121598	total: 63.5ms	remaining: 1.65s
9:	learn: 6.6747831	total: 64.8ms	remaining: 1.51s
10:	learn: 6.4660698	total: 66.5ms	remaining: 1.4s
11:	learn: 6.4102914	total: 67.8ms	remaining: 1.3s
12:	learn: 6.2519145	total: 69.9ms	remaining: 1.24s
13:	learn: 6.1160302	total: 71.1ms	remaining: 1.16s
14:	learn: 6.0731424	total: 72.6ms	remaining: 1.1s
15:	learn: 5.8761473	total: 74.1ms	remaining: 1.05s
16:	learn: 5.7990006	total: 75.1ms	remaining: 999ms
17:	learn: 5.7461937	total: 76.5ms	remaining: 956ms
18:	learn: 5.7123859	total: 77.8ms	remaining: 917ms
19:	learn: 5.6460952	total:

[32m[I 2021-07-08 06:21:17,042][0m Trial 0 finished with value: -2.0383364230490484 and parameters: {'iterations': 243, 'depth': 4, 'learning_rate': 0.08629294202140579, 'random_strength': 75, 'bagging_temperature': 0.986343187233007, 'od_type': 'IncToDec', 'od_wait': 41}. Best is trial 0 with value: -2.0383364230490484.[0m


240:	learn: 1.4250360	total: 728ms	remaining: 6.04ms
241:	learn: 1.4173680	total: 731ms	remaining: 3.02ms
242:	learn: 1.4151471	total: 736ms	remaining: 0us
0:	learn: 8.7075941	total: 10.7ms	remaining: 1.85s
1:	learn: 8.6347212	total: 24.3ms	remaining: 2.1s
2:	learn: 8.5291119	total: 39.5ms	remaining: 2.27s
3:	learn: 8.4285440	total: 50.5ms	remaining: 2.16s
4:	learn: 8.3562589	total: 63.7ms	remaining: 2.17s
5:	learn: 8.2461077	total: 74.1ms	remaining: 2.09s
6:	learn: 8.1478683	total: 85.5ms	remaining: 2.05s
7:	learn: 8.0855509	total: 106ms	remaining: 2.2s
8:	learn: 8.0036814	total: 129ms	remaining: 2.37s
9:	learn: 7.9129093	total: 144ms	remaining: 2.38s
10:	learn: 7.8737984	total: 148ms	remaining: 2.21s
11:	learn: 7.7880271	total: 160ms	remaining: 2.18s
12:	learn: 7.7183322	total: 165ms	remaining: 2.06s
13:	learn: 7.6354401	total: 180ms	remaining: 2.07s
14:	learn: 7.5341017	total: 183ms	remaining: 1.96s
15:	learn: 7.4812743	total: 187ms	remaining: 1.85s
16:	learn: 7.4021796	total: 202ms

[32m[I 2021-07-08 06:21:29,796][0m Trial 1 finished with value: -2.782343713054268 and parameters: {'iterations': 175, 'depth': 8, 'learning_rate': 0.028966079322940607, 'random_strength': 49, 'bagging_temperature': 0.7370362034389614, 'od_type': 'IncToDec', 'od_wait': 30}. Best is trial 0 with value: -2.0383364230490484.[0m


163:	learn: 3.0898782	total: 1.69s	remaining: 114ms
164:	learn: 3.0825264	total: 1.7s	remaining: 103ms
165:	learn: 3.0648982	total: 1.71s	remaining: 92.6ms
166:	learn: 3.0547071	total: 1.72s	remaining: 82.3ms
167:	learn: 3.0377205	total: 1.73s	remaining: 72ms
168:	learn: 3.0264765	total: 1.74s	remaining: 61.6ms
169:	learn: 3.0202795	total: 1.74s	remaining: 51.1ms
170:	learn: 3.0117722	total: 1.75s	remaining: 40.8ms
171:	learn: 2.9987338	total: 1.75s	remaining: 30.6ms
172:	learn: 2.9823503	total: 1.76s	remaining: 20.4ms
173:	learn: 2.9799026	total: 1.76s	remaining: 10.1ms
174:	learn: 2.9692683	total: 1.77s	remaining: 0us
0:	learn: 8.7573997	total: 7.22ms	remaining: 1.26s
1:	learn: 8.7027471	total: 12.1ms	remaining: 1.04s
2:	learn: 8.6510754	total: 17.3ms	remaining: 989ms
3:	learn: 8.5994648	total: 27ms	remaining: 1.15s
4:	learn: 8.5394719	total: 37.6ms	remaining: 1.28s
5:	learn: 8.4806743	total: 45.1ms	remaining: 1.27s
6:	learn: 8.4236890	total: 56.2ms	remaining: 1.35s
7:	learn: 8.36382

[32m[I 2021-07-08 06:21:37,794][0m Trial 2 finished with value: -3.3605406943633875 and parameters: {'iterations': 175, 'depth': 7, 'learning_rate': 0.015613370417884475, 'random_strength': 50, 'bagging_temperature': 0.69024383280847, 'od_type': 'Iter', 'od_wait': 29}. Best is trial 0 with value: -2.0383364230490484.[0m


161:	learn: 4.3131370	total: 894ms	remaining: 71.7ms
162:	learn: 4.3094481	total: 898ms	remaining: 66.1ms
163:	learn: 4.2971802	total: 909ms	remaining: 61ms
164:	learn: 4.2863411	total: 915ms	remaining: 55.5ms
165:	learn: 4.2819279	total: 920ms	remaining: 49.9ms
166:	learn: 4.2617518	total: 929ms	remaining: 44.5ms
167:	learn: 4.2453314	total: 932ms	remaining: 38.8ms
168:	learn: 4.2294145	total: 937ms	remaining: 33.3ms
169:	learn: 4.2254471	total: 942ms	remaining: 27.7ms
170:	learn: 4.2073162	total: 947ms	remaining: 22.1ms
171:	learn: 4.1931948	total: 952ms	remaining: 16.6ms
172:	learn: 4.1831482	total: 964ms	remaining: 11.1ms
173:	learn: 4.1691719	total: 968ms	remaining: 5.56ms
174:	learn: 4.1641233	total: 973ms	remaining: 0us
0:	learn: 8.5967532	total: 4.05ms	remaining: 706ms
1:	learn: 8.4131210	total: 8.96ms	remaining: 775ms
2:	learn: 8.2454702	total: 14.6ms	remaining: 837ms
3:	learn: 8.0820953	total: 20.8ms	remaining: 889ms
4:	learn: 7.8943982	total: 26.7ms	remaining: 908ms
5:	learn

[32m[I 2021-07-08 06:21:45,771][0m Trial 3 finished with value: -2.382976629538914 and parameters: {'iterations': 175, 'depth': 7, 'learning_rate': 0.05470111775234296, 'random_strength': 50, 'bagging_temperature': 0.7758081335702709, 'od_type': 'Iter', 'od_wait': 31}. Best is trial 0 with value: -2.0383364230490484.[0m


148:	learn: 2.2848179	total: 674ms	remaining: 118ms
149:	learn: 2.2658906	total: 676ms	remaining: 113ms
150:	learn: 2.2539276	total: 678ms	remaining: 108ms
151:	learn: 2.2482346	total: 680ms	remaining: 103ms
152:	learn: 2.2427063	total: 681ms	remaining: 97.9ms
153:	learn: 2.2303139	total: 683ms	remaining: 93.1ms
154:	learn: 2.2166341	total: 685ms	remaining: 88.4ms
155:	learn: 2.2046147	total: 687ms	remaining: 83.7ms
156:	learn: 2.1911713	total: 689ms	remaining: 78.9ms
157:	learn: 2.1777765	total: 691ms	remaining: 74.3ms
158:	learn: 2.1665154	total: 693ms	remaining: 69.7ms
159:	learn: 2.1556616	total: 694ms	remaining: 65.1ms
160:	learn: 2.1408412	total: 696ms	remaining: 60.5ms
161:	learn: 2.1284284	total: 698ms	remaining: 56ms
162:	learn: 2.1172236	total: 700ms	remaining: 51.5ms
163:	learn: 2.1081558	total: 702ms	remaining: 47.1ms
164:	learn: 2.0967895	total: 704ms	remaining: 42.6ms
165:	learn: 2.0852244	total: 706ms	remaining: 38.3ms
166:	learn: 2.0736932	total: 708ms	remaining: 33.9ms

[32m[I 2021-07-08 06:21:53,590][0m Trial 4 finished with value: -2.7699019921906562 and parameters: {'iterations': 175, 'depth': 7, 'learning_rate': 0.03069307922035012, 'random_strength': 50, 'bagging_temperature': 1.2421993474408937, 'od_type': 'IncToDec', 'od_wait': 30}. Best is trial 0 with value: -2.0383364230490484.[0m


165:	learn: 3.1153275	total: 535ms	remaining: 29ms
166:	learn: 3.0975062	total: 538ms	remaining: 25.8ms
167:	learn: 3.0826989	total: 541ms	remaining: 22.5ms
168:	learn: 3.0663281	total: 543ms	remaining: 19.3ms
169:	learn: 3.0622734	total: 545ms	remaining: 16ms
170:	learn: 3.0450166	total: 547ms	remaining: 12.8ms
171:	learn: 3.0345949	total: 553ms	remaining: 9.65ms
172:	learn: 3.0240976	total: 555ms	remaining: 6.42ms
173:	learn: 3.0130023	total: 559ms	remaining: 3.21ms
174:	learn: 3.0099448	total: 561ms	remaining: 0us
0:	learn: 8.6942507	total: 2.8ms	remaining: 487ms
1:	learn: 8.6131706	total: 5.72ms	remaining: 495ms
2:	learn: 8.4958177	total: 8.56ms	remaining: 491ms
3:	learn: 8.3844593	total: 10.9ms	remaining: 467ms
4:	learn: 8.3047830	total: 13.8ms	remaining: 470ms
5:	learn: 8.1831390	total: 17.6ms	remaining: 497ms
6:	learn: 8.0841639	total: 20.2ms	remaining: 486ms
7:	learn: 8.0155969	total: 23.4ms	remaining: 488ms
8:	learn: 7.9259848	total: 30.9ms	remaining: 570ms
9:	learn: 7.8264470

[32m[I 2021-07-08 06:22:07,506][0m Trial 5 finished with value: -2.7289487499846756 and parameters: {'iterations': 175, 'depth': 8, 'learning_rate': 0.03235828463102615, 'random_strength': 50, 'bagging_temperature': 1.190495932869843, 'od_type': 'IncToDec', 'od_wait': 30}. Best is trial 0 with value: -2.0383364230490484.[0m


174:	learn: 2.7813801	total: 3.36s	remaining: 0us
0:	learn: 8.6090910	total: 2.66ms	remaining: 463ms
1:	learn: 8.4350329	total: 12.1ms	remaining: 1.04s
2:	learn: 8.2756879	total: 35.4ms	remaining: 2.03s
3:	learn: 8.1201116	total: 44ms	remaining: 1.88s
4:	learn: 7.9412237	total: 57.1ms	remaining: 1.94s
5:	learn: 7.7723773	total: 80.8ms	remaining: 2.27s
6:	learn: 7.6130138	total: 92.7ms	remaining: 2.22s
7:	learn: 7.4508386	total: 112ms	remaining: 2.33s
8:	learn: 7.3756789	total: 127ms	remaining: 2.33s
9:	learn: 7.2850991	total: 137ms	remaining: 2.27s
10:	learn: 7.1216549	total: 151ms	remaining: 2.26s
11:	learn: 7.0147467	total: 155ms	remaining: 2.11s
12:	learn: 6.9283804	total: 169ms	remaining: 2.11s
13:	learn: 6.8557721	total: 180ms	remaining: 2.06s
14:	learn: 6.7442366	total: 182ms	remaining: 1.94s
15:	learn: 6.6558867	total: 194ms	remaining: 1.92s
16:	learn: 6.5338236	total: 208ms	remaining: 1.94s
17:	learn: 6.4559903	total: 221ms	remaining: 1.93s
18:	learn: 6.3868900	total: 225ms	rem

[32m[I 2021-07-08 06:22:21,284][0m Trial 6 finished with value: -2.4333455260549774 and parameters: {'iterations': 175, 'depth': 7, 'learning_rate': 0.051678840853704625, 'random_strength': 50, 'bagging_temperature': 0.5039679252875129, 'od_type': 'IncToDec', 'od_wait': 30}. Best is trial 0 with value: -2.0383364230490484.[0m


0:	learn: 8.6446974	total: 10.8ms	remaining: 1.86s
1:	learn: 8.4985810	total: 19.9ms	remaining: 1.71s
2:	learn: 8.3637625	total: 21.6ms	remaining: 1.23s
3:	learn: 8.2314090	total: 31.9ms	remaining: 1.35s
4:	learn: 8.0788452	total: 34.8ms	remaining: 1.18s
5:	learn: 7.9334988	total: 43.4ms	remaining: 1.22s
6:	learn: 7.7954330	total: 49.2ms	remaining: 1.17s
7:	learn: 7.6538454	total: 54.9ms	remaining: 1.14s
8:	learn: 7.5870480	total: 56.8ms	remaining: 1.04s
9:	learn: 7.5063631	total: 59.9ms	remaining: 982ms
10:	learn: 7.3625642	total: 68.9ms	remaining: 1.02s
11:	learn: 7.2663151	total: 71.8ms	remaining: 969ms
12:	learn: 7.1878568	total: 81.7ms	remaining: 1.01s
13:	learn: 7.1207727	total: 84.2ms	remaining: 963ms
14:	learn: 7.0196926	total: 95.9ms	remaining: 1.02s
15:	learn: 6.9376870	total: 101ms	remaining: 1s
16:	learn: 6.8258980	total: 110ms	remaining: 1.01s
17:	learn: 6.7532843	total: 121ms	remaining: 1.05s
18:	learn: 6.6886682	total: 125ms	remaining: 1.02s
19:	learn: 6.5959878	total: 1

[32m[I 2021-07-08 06:22:32,453][0m Trial 7 finished with value: -2.5740135911445168 and parameters: {'iterations': 174, 'depth': 7, 'learning_rate': 0.042976044935830626, 'random_strength': 50, 'bagging_temperature': 0.928447224832653, 'od_type': 'IncToDec', 'od_wait': 30}. Best is trial 0 with value: -2.0383364230490484.[0m


0:	learn: 8.7139610	total: 1.56ms	remaining: 270ms
1:	learn: 8.6235145	total: 5.77ms	remaining: 496ms
2:	learn: 8.5387944	total: 10.4ms	remaining: 594ms
3:	learn: 8.4547346	total: 15.1ms	remaining: 643ms
4:	learn: 8.3573467	total: 19.1ms	remaining: 645ms
5:	learn: 8.2629172	total: 26ms	remaining: 728ms
6:	learn: 8.1721002	total: 30.7ms	remaining: 733ms
7:	learn: 8.0775724	total: 37ms	remaining: 769ms
8:	learn: 8.0313997	total: 40.9ms	remaining: 749ms
9:	learn: 7.9753315	total: 46.2ms	remaining: 758ms
10:	learn: 7.8778975	total: 51.1ms	remaining: 758ms
11:	learn: 7.8096837	total: 58.5ms	remaining: 789ms
12:	learn: 7.7529966	total: 63.7ms	remaining: 789ms
13:	learn: 7.7029429	total: 67.9ms	remaining: 776ms
14:	learn: 7.6302446	total: 72.6ms	remaining: 769ms
15:	learn: 7.5684842	total: 77.3ms	remaining: 764ms
16:	learn: 7.4857219	total: 82.1ms	remaining: 758ms
17:	learn: 7.4303476	total: 90.4ms	remaining: 783ms
18:	learn: 7.3807207	total: 94.2ms	remaining: 768ms
19:	learn: 7.3117909	total

[32m[I 2021-07-08 06:22:43,823][0m Trial 8 finished with value: -2.8851952375498584 and parameters: {'iterations': 174, 'depth': 7, 'learning_rate': 0.026127392446553644, 'random_strength': 50, 'bagging_temperature': 0.7302894477560489, 'od_type': 'Iter', 'od_wait': 30}. Best is trial 0 with value: -2.0383364230490484.[0m


166:	learn: 3.3194858	total: 836ms	remaining: 35ms
167:	learn: 3.3042292	total: 846ms	remaining: 30.2ms
168:	learn: 3.2879999	total: 854ms	remaining: 25.3ms
169:	learn: 3.2840664	total: 864ms	remaining: 20.3ms
170:	learn: 3.2668330	total: 874ms	remaining: 15.3ms
171:	learn: 3.2549433	total: 885ms	remaining: 10.3ms
172:	learn: 3.2444416	total: 892ms	remaining: 5.15ms
173:	learn: 3.2327978	total: 905ms	remaining: 0us
0:	learn: 8.7034155	total: 2.29ms	remaining: 398ms
1:	learn: 8.6043814	total: 5.2ms	remaining: 450ms
2:	learn: 8.5118286	total: 9.6ms	remaining: 551ms
3:	learn: 8.4201453	total: 20ms	remaining: 853ms
4:	learn: 8.3140091	total: 29.4ms	remaining: 1s
5:	learn: 8.2113684	total: 40.6ms	remaining: 1.14s
6:	learn: 8.1128395	total: 49.4ms	remaining: 1.18s
7:	learn: 8.0105151	total: 58.8ms	remaining: 1.23s
8:	learn: 7.9607979	total: 67.8ms	remaining: 1.25s
9:	learn: 7.9004783	total: 77.9ms	remaining: 1.28s
10:	learn: 7.7952435	total: 86.9ms	remaining: 1.29s
11:	learn: 7.7220707	total

[32m[I 2021-07-08 06:22:53,737][0m Trial 9 finished with value: -2.8102495335147575 and parameters: {'iterations': 175, 'depth': 7, 'learning_rate': 0.02868590328555137, 'random_strength': 50, 'bagging_temperature': 0.2806815622908689, 'od_type': 'IncToDec', 'od_wait': 30}. Best is trial 0 with value: -2.0383364230490484.[0m


169:	learn: 3.1388626	total: 760ms	remaining: 22.4ms
170:	learn: 3.1218500	total: 765ms	remaining: 17.9ms
171:	learn: 3.1112588	total: 768ms	remaining: 13.4ms
172:	learn: 3.1007260	total: 772ms	remaining: 8.92ms
173:	learn: 3.0894482	total: 775ms	remaining: 4.46ms
174:	learn: 3.0863349	total: 781ms	remaining: 0us
CPU times: user 1min, sys: 27.6 s, total: 1min 28s
Wall time: 1min 47s


In [9]:
# チューニングしたハイパーパラメーターをフィット
optimised_model = cb.CatBoostRegressor(
    iterations=study.best_params["iterations"],
    depth=study.best_params["depth"],
    learning_rate=study.best_params["learning_rate"],
    random_strength=study.best_params["random_strength"],
    bagging_temperature=study.best_params["bagging_temperature"],
    od_type=study.best_params["od_type"],
    od_wait=study.best_params["od_wait"],
)

optimised_model.fit(x_train, y_train)

# CatBoost推論
y_pred = optimised_model.predict(x_test)

0:	learn: 8.5545919	total: 1.05ms	remaining: 255ms
1:	learn: 8.2758365	total: 2.29ms	remaining: 276ms
2:	learn: 8.1254918	total: 3.23ms	remaining: 258ms
3:	learn: 7.8676602	total: 4.6ms	remaining: 275ms
4:	learn: 7.8437689	total: 5.29ms	remaining: 252ms
5:	learn: 7.5069089	total: 6.16ms	remaining: 243ms
6:	learn: 7.2411512	total: 7.08ms	remaining: 239ms
7:	learn: 6.9952850	total: 8.26ms	remaining: 243ms
8:	learn: 6.8964687	total: 9.07ms	remaining: 236ms
9:	learn: 6.6667356	total: 10.1ms	remaining: 235ms
10:	learn: 6.4656556	total: 10.7ms	remaining: 226ms
11:	learn: 6.4025256	total: 11.4ms	remaining: 220ms
12:	learn: 6.2475438	total: 12.1ms	remaining: 215ms
13:	learn: 6.0911106	total: 13.1ms	remaining: 215ms
14:	learn: 6.0563384	total: 13.7ms	remaining: 209ms
15:	learn: 5.8618741	total: 14.7ms	remaining: 208ms
16:	learn: 5.8004092	total: 15.4ms	remaining: 204ms
17:	learn: 5.7543304	total: 16.4ms	remaining: 204ms
18:	learn: 5.6872703	total: 17.3ms	remaining: 204ms
19:	learn: 5.6333654	to

In [10]:
# チューニングしたハイパーパラメータ
study.best_params

{'iterations': 243,
 'depth': 4,
 'learning_rate': 0.08629294202140579,
 'random_strength': 75,
 'bagging_temperature': 0.986343187233007,
 'od_type': 'IncToDec',
 'od_wait': 41}

In [11]:
# チューニングしたハイパーパラメーターをフィット
optimised_model = cb.CatBoostRegressor(
    iterations=study.best_params["iterations"],
    depth=study.best_params["depth"],
    learning_rate=study.best_params["learning_rate"],
    random_strength=study.best_params["random_strength"],
    bagging_temperature=study.best_params["bagging_temperature"],
    od_type=study.best_params["od_type"],
    od_wait=study.best_params["od_wait"],
)

optimised_model.fit(x_train, y_train)

# CatBoost推論
y_pred = optimised_model.predict(x_test)

0:	learn: 8.5545919	total: 702us	remaining: 170ms
1:	learn: 8.2758365	total: 1.37ms	remaining: 165ms
2:	learn: 8.1254918	total: 1.96ms	remaining: 157ms
3:	learn: 7.8676602	total: 2.55ms	remaining: 152ms
4:	learn: 7.8437689	total: 3.05ms	remaining: 145ms
5:	learn: 7.5069089	total: 3.63ms	remaining: 143ms
6:	learn: 7.2411512	total: 4.26ms	remaining: 144ms
7:	learn: 6.9952850	total: 4.82ms	remaining: 142ms
8:	learn: 6.8964687	total: 5.49ms	remaining: 143ms
9:	learn: 6.6667356	total: 6.15ms	remaining: 143ms
10:	learn: 6.4656556	total: 6.82ms	remaining: 144ms
11:	learn: 6.4025256	total: 7.47ms	remaining: 144ms
12:	learn: 6.2475438	total: 8.15ms	remaining: 144ms
13:	learn: 6.0911106	total: 8.76ms	remaining: 143ms
14:	learn: 6.0563384	total: 9.61ms	remaining: 146ms
15:	learn: 5.8618741	total: 10.3ms	remaining: 146ms
16:	learn: 5.8004092	total: 11.1ms	remaining: 148ms
17:	learn: 5.7543304	total: 11.9ms	remaining: 148ms
18:	learn: 5.6872703	total: 12.5ms	remaining: 147ms
19:	learn: 5.6333654	to

In [12]:
# 評価
def calculate_scores(true, pred):
    """全ての評価指標を計算する

    Parameters
    ----------
    true (np.array)       : 実測値
    pred (np.array)       : 予測値

    Returns
    -------
    scores (pd.DataFrame) : 各評価指標を纏めた結果

    """
    scores = {}
    scores = pd.DataFrame(
        {
            "R2": r2_score(true, pred),
            "MAE": mean_absolute_error(true, pred),
            "MSE": mean_squared_error(true, pred),
            "RMSE": np.sqrt(mean_squared_error(true, pred)),
        },
        index=["scores"],
    )
    return scores

In [13]:
scores = calculate_scores(y_test, y_pred)
print(scores)

              R2       MAE       MSE      RMSE
scores  0.871765  2.572099  13.41089  3.662088
