In [1]:
import catboost
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats as stats
import sweetviz

from sklearn.model_selection import train_test_split
from nflows.distributions import ConditionalDiagonalNormal

from src.probabilistic_flow_boosting.tfboost.tree import EmbeddableCatBoostPriorNormal
from src.probabilistic_flow_boosting.tfboost.tfboost import TreeFlowBoost
from src.probabilistic_flow_boosting.tfboost.flow import ContinuousNormalizingFlow
from src.probabilistic_flow_boosting.pipelines.reporting.nodes import calculate_nll

from src.probabilistic_flow_boosting.pipelines.modeling.utils import setup_random_seed

# Single seed reused by both train_test_split calls and by both models below,
# so a fresh "Restart & Run All" reproduces the same splits and fits.
RANDOM_SEED = 10

# Project helper; presumably seeds the global RNGs (numpy / torch) — TODO confirm.
setup_random_seed(RANDOM_SEED)

In [2]:
# Load the raw avocado table; index_col=0 uses the first CSV column as the row index.
# NOTE(review): the path is relative, so the kernel's working directory must be the
# directory that contains `data/` — confirm before running.
df = pd.read_csv('data/01_raw/CatData/avocado/avocado.csv', index_col=0)

In [3]:
# Optional exploratory report (this is what the `sweetviz` import is for); uncomment to run.
# analysis = sweetviz.analyze(df)
# analysis.show_notebook()

In [4]:
# Features: every column except the timestamp and the regression target.
x = df.drop(['Date', 'AveragePrice'], axis='columns')
# Target: selected with a list so it stays a one-column DataFrame, not a Series.
y = df.loc[:, ['AveragePrice']]

In [5]:
# Outer split: hold out 20% of all rows as the final test set.
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    test_size=0.2,
    random_state=RANDOM_SEED,
)
# Inner split: carve a validation set (20% of the outer training rows, i.e. 16% of
# all rows) out of the training data; the remaining 64% is the actual fitting set.
x_tr, x_val, y_tr, y_val = train_test_split(
    x_train,
    y_train,
    test_size=0.2,
    random_state=RANDOM_SEED,
)

In [6]:
# Sanity check: (rows, columns) of the outer train and test feature frames.
x_train.shape, x_test.shape

((14599, 11), (3650, 11))

In [7]:
# Baseline: one CatBoost regressor trained with the RMSEWithUncertainty loss, i.e. it
# predicts a mean and a variance per row. Categorical columns are referenced by name
# because this model is fitted on DataFrames.
model = catboost.CatBoostRegressor(
    loss_function="RMSEWithUncertainty",
    num_trees=2000,
    cat_features=['type', 'year', 'region'],
    random_state=RANDOM_SEED,
)

In [8]:
%time model.fit(x_tr, y_tr, eval_set=(x_val, y_val))

0:	learn: 0.4836053	test: 0.4981347	best: 0.4981347 (0)	total: 65.5ms	remaining: 2m 10s
1:	learn: 0.4638331	test: 0.4774730	best: 0.4774730 (1)	total: 70ms	remaining: 1m 9s
2:	learn: 0.4413920	test: 0.4535651	best: 0.4535651 (2)	total: 73.3ms	remaining: 48.8s
3:	learn: 0.4222889	test: 0.4326436	best: 0.4326436 (3)	total: 76.7ms	remaining: 38.3s
4:	learn: 0.4061859	test: 0.4161256	best: 0.4161256 (4)	total: 79.8ms	remaining: 31.9s
5:	learn: 0.3900086	test: 0.3989042	best: 0.3989042 (5)	total: 82.8ms	remaining: 27.5s
6:	learn: 0.3749797	test: 0.3831905	best: 0.3831905 (6)	total: 85.7ms	remaining: 24.4s
7:	learn: 0.3599941	test: 0.3676352	best: 0.3676352 (7)	total: 89ms	remaining: 22.2s
8:	learn: 0.3454651	test: 0.3529090	best: 0.3529090 (8)	total: 92.1ms	remaining: 20.4s
9:	learn: 0.3319585	test: 0.3392962	best: 0.3392962 (9)	total: 95ms	remaining: 18.9s
10:	learn: 0.3198880	test: 0.3268236	best: 0.3268236 (10)	total: 98.2ms	remaining: 17.8s
11:	learn: 0.3074114	test: 0.3141876	best: 0.3

120:	learn: -0.1456095	test: -0.1326208	best: -0.1326208 (120)	total: 467ms	remaining: 7.25s
121:	learn: -0.1471977	test: -0.1343623	best: -0.1343623 (121)	total: 469ms	remaining: 7.23s
122:	learn: -0.1482040	test: -0.1350198	best: -0.1350198 (122)	total: 472ms	remaining: 7.21s
123:	learn: -0.1496806	test: -0.1361855	best: -0.1361855 (123)	total: 475ms	remaining: 7.19s
124:	learn: -0.1505385	test: -0.1369387	best: -0.1369387 (124)	total: 479ms	remaining: 7.18s
125:	learn: -0.1516179	test: -0.1380804	best: -0.1380804 (125)	total: 482ms	remaining: 7.17s
126:	learn: -0.1528461	test: -0.1391746	best: -0.1391746 (126)	total: 486ms	remaining: 7.16s
127:	learn: -0.1533677	test: -0.1395042	best: -0.1395042 (127)	total: 489ms	remaining: 7.15s
128:	learn: -0.1541394	test: -0.1402507	best: -0.1402507 (128)	total: 493ms	remaining: 7.14s
129:	learn: -0.1559282	test: -0.1417686	best: -0.1417686 (129)	total: 496ms	remaining: 7.14s
130:	learn: -0.1575808	test: -0.1432639	best: -0.1432639 (130)	total: 

238:	learn: -0.2384821	test: -0.2051667	best: -0.2051667 (238)	total: 863ms	remaining: 6.36s
239:	learn: -0.2391431	test: -0.2056714	best: -0.2056714 (239)	total: 867ms	remaining: 6.35s
240:	learn: -0.2397284	test: -0.2060848	best: -0.2060848 (240)	total: 870ms	remaining: 6.35s
241:	learn: -0.2407043	test: -0.2068085	best: -0.2068085 (241)	total: 873ms	remaining: 6.34s
242:	learn: -0.2407234	test: -0.2067721	best: -0.2068085 (241)	total: 875ms	remaining: 6.32s
243:	learn: -0.2407356	test: -0.2067338	best: -0.2068085 (241)	total: 876ms	remaining: 6.3s
244:	learn: -0.2413952	test: -0.2073534	best: -0.2073534 (244)	total: 880ms	remaining: 6.3s
245:	learn: -0.2416901	test: -0.2075188	best: -0.2075188 (245)	total: 883ms	remaining: 6.29s
246:	learn: -0.2419480	test: -0.2077830	best: -0.2077830 (246)	total: 887ms	remaining: 6.29s
247:	learn: -0.2430962	test: -0.2087992	best: -0.2087992 (247)	total: 890ms	remaining: 6.29s
248:	learn: -0.2436737	test: -0.2091463	best: -0.2091463 (248)	total: 89

355:	learn: -0.3051916	test: -0.2548020	best: -0.2548020 (355)	total: 1.26s	remaining: 5.83s
356:	learn: -0.3057392	test: -0.2552369	best: -0.2552369 (356)	total: 1.27s	remaining: 5.83s
357:	learn: -0.3062574	test: -0.2554854	best: -0.2554854 (357)	total: 1.27s	remaining: 5.83s
358:	learn: -0.3067164	test: -0.2556804	best: -0.2556804 (358)	total: 1.27s	remaining: 5.82s
359:	learn: -0.3069626	test: -0.2558158	best: -0.2558158 (359)	total: 1.28s	remaining: 5.82s
360:	learn: -0.3074234	test: -0.2561920	best: -0.2561920 (360)	total: 1.28s	remaining: 5.81s
361:	learn: -0.3077549	test: -0.2564295	best: -0.2564295 (361)	total: 1.28s	remaining: 5.81s
362:	learn: -0.3085707	test: -0.2566783	best: -0.2566783 (362)	total: 1.29s	remaining: 5.81s
363:	learn: -0.3087620	test: -0.2567229	best: -0.2567229 (363)	total: 1.29s	remaining: 5.8s
364:	learn: -0.3091741	test: -0.2570940	best: -0.2570940 (364)	total: 1.29s	remaining: 5.8s
365:	learn: -0.3096078	test: -0.2574814	best: -0.2574814 (365)	total: 1.

465:	learn: -0.3647304	test: -0.2962068	best: -0.2962068 (465)	total: 1.66s	remaining: 5.46s
466:	learn: -0.3651372	test: -0.2964273	best: -0.2964273 (466)	total: 1.66s	remaining: 5.46s
467:	learn: -0.3659455	test: -0.2969832	best: -0.2969832 (467)	total: 1.67s	remaining: 5.46s
468:	learn: -0.3666758	test: -0.2976377	best: -0.2976377 (468)	total: 1.67s	remaining: 5.46s
469:	learn: -0.3670268	test: -0.2979884	best: -0.2979884 (469)	total: 1.68s	remaining: 5.45s
470:	learn: -0.3673945	test: -0.2982164	best: -0.2982164 (470)	total: 1.68s	remaining: 5.45s
471:	learn: -0.3677771	test: -0.2984660	best: -0.2984660 (471)	total: 1.68s	remaining: 5.45s
472:	learn: -0.3683950	test: -0.2987948	best: -0.2987948 (472)	total: 1.69s	remaining: 5.44s
473:	learn: -0.3688264	test: -0.2991552	best: -0.2991552 (473)	total: 1.69s	remaining: 5.44s
474:	learn: -0.3693459	test: -0.2996640	best: -0.2996640 (474)	total: 1.69s	remaining: 5.43s
475:	learn: -0.3695827	test: -0.2997894	best: -0.2997894 (475)	total: 

583:	learn: -0.4087653	test: -0.3223087	best: -0.3223087 (583)	total: 2.06s	remaining: 4.99s
584:	learn: -0.4091044	test: -0.3224126	best: -0.3224126 (584)	total: 2.06s	remaining: 4.99s
585:	learn: -0.4094276	test: -0.3227431	best: -0.3227431 (585)	total: 2.06s	remaining: 4.98s
586:	learn: -0.4097289	test: -0.3228699	best: -0.3228699 (586)	total: 2.07s	remaining: 4.98s
587:	learn: -0.4101694	test: -0.3233663	best: -0.3233663 (587)	total: 2.07s	remaining: 4.98s
588:	learn: -0.4106593	test: -0.3235298	best: -0.3235298 (588)	total: 2.08s	remaining: 4.98s
589:	learn: -0.4110257	test: -0.3235104	best: -0.3235298 (588)	total: 2.08s	remaining: 4.98s
590:	learn: -0.4113734	test: -0.3237529	best: -0.3237529 (590)	total: 2.09s	remaining: 4.97s
591:	learn: -0.4115585	test: -0.3238487	best: -0.3238487 (591)	total: 2.09s	remaining: 4.97s
592:	learn: -0.4119969	test: -0.3240582	best: -0.3240582 (592)	total: 2.09s	remaining: 4.96s
593:	learn: -0.4123589	test: -0.3239496	best: -0.3240582 (592)	total: 

702:	learn: -0.4456346	test: -0.3408793	best: -0.3408793 (702)	total: 2.46s	remaining: 4.54s
703:	learn: -0.4459415	test: -0.3409523	best: -0.3409523 (703)	total: 2.46s	remaining: 4.53s
704:	learn: -0.4461878	test: -0.3408912	best: -0.3409523 (703)	total: 2.46s	remaining: 4.53s
705:	learn: -0.4464009	test: -0.3408648	best: -0.3409523 (703)	total: 2.47s	remaining: 4.52s
706:	learn: -0.4467945	test: -0.3409459	best: -0.3409523 (703)	total: 2.47s	remaining: 4.52s
707:	learn: -0.4475722	test: -0.3417624	best: -0.3417624 (707)	total: 2.48s	remaining: 4.52s
708:	learn: -0.4477950	test: -0.3417627	best: -0.3417627 (708)	total: 2.48s	remaining: 4.51s
709:	learn: -0.4478919	test: -0.3418016	best: -0.3418016 (709)	total: 2.48s	remaining: 4.51s
710:	learn: -0.4485718	test: -0.3425062	best: -0.3425062 (710)	total: 2.48s	remaining: 4.51s
711:	learn: -0.4488361	test: -0.3425760	best: -0.3425760 (711)	total: 2.49s	remaining: 4.5s
712:	learn: -0.4490134	test: -0.3426386	best: -0.3426386 (712)	total: 2

818:	learn: -0.4818566	test: -0.3578655	best: -0.3578711 (816)	total: 2.86s	remaining: 4.12s
819:	learn: -0.4821588	test: -0.3579201	best: -0.3579201 (819)	total: 2.86s	remaining: 4.12s
820:	learn: -0.4824711	test: -0.3582222	best: -0.3582222 (820)	total: 2.86s	remaining: 4.11s
821:	learn: -0.4826621	test: -0.3582695	best: -0.3582695 (821)	total: 2.87s	remaining: 4.11s
822:	learn: -0.4831338	test: -0.3583810	best: -0.3583810 (822)	total: 2.87s	remaining: 4.11s
823:	learn: -0.4832890	test: -0.3584548	best: -0.3584548 (823)	total: 2.88s	remaining: 4.1s
824:	learn: -0.4836085	test: -0.3585527	best: -0.3585527 (824)	total: 2.88s	remaining: 4.1s
825:	learn: -0.4838995	test: -0.3585700	best: -0.3585700 (825)	total: 2.88s	remaining: 4.1s
826:	learn: -0.4842623	test: -0.3588061	best: -0.3588061 (826)	total: 2.89s	remaining: 4.09s
827:	learn: -0.4844273	test: -0.3588179	best: -0.3588179 (827)	total: 2.89s	remaining: 4.09s
828:	learn: -0.4846710	test: -0.3588785	best: -0.3588785 (828)	total: 2.8

932:	learn: -0.5129949	test: -0.3711926	best: -0.3711926 (932)	total: 3.26s	remaining: 3.72s
933:	learn: -0.5130947	test: -0.3712914	best: -0.3712914 (933)	total: 3.26s	remaining: 3.72s
934:	learn: -0.5133643	test: -0.3712815	best: -0.3712914 (933)	total: 3.26s	remaining: 3.72s
935:	learn: -0.5136080	test: -0.3713709	best: -0.3713709 (935)	total: 3.27s	remaining: 3.71s
936:	learn: -0.5138253	test: -0.3715280	best: -0.3715280 (936)	total: 3.27s	remaining: 3.71s
937:	learn: -0.5139562	test: -0.3715538	best: -0.3715538 (937)	total: 3.27s	remaining: 3.71s
938:	learn: -0.5144065	test: -0.3722388	best: -0.3722388 (938)	total: 3.28s	remaining: 3.7s
939:	learn: -0.5147013	test: -0.3724618	best: -0.3724618 (939)	total: 3.28s	remaining: 3.7s
940:	learn: -0.5149502	test: -0.3725355	best: -0.3725355 (940)	total: 3.28s	remaining: 3.69s
941:	learn: -0.5150814	test: -0.3726389	best: -0.3726389 (941)	total: 3.29s	remaining: 3.69s
942:	learn: -0.5152242	test: -0.3725895	best: -0.3726389 (941)	total: 3.

1046:	learn: -0.5422780	test: -0.3874503	best: -0.3874503 (1046)	total: 3.65s	remaining: 3.33s
1047:	learn: -0.5424844	test: -0.3874878	best: -0.3874878 (1047)	total: 3.66s	remaining: 3.32s
1048:	learn: -0.5428619	test: -0.3876362	best: -0.3876362 (1048)	total: 3.66s	remaining: 3.32s
1049:	learn: -0.5430116	test: -0.3876988	best: -0.3876988 (1049)	total: 3.67s	remaining: 3.32s
1050:	learn: -0.5431661	test: -0.3876669	best: -0.3876988 (1049)	total: 3.67s	remaining: 3.31s
1051:	learn: -0.5435460	test: -0.3878897	best: -0.3878897 (1051)	total: 3.67s	remaining: 3.31s
1052:	learn: -0.5436239	test: -0.3877975	best: -0.3878897 (1051)	total: 3.68s	remaining: 3.31s
1053:	learn: -0.5437986	test: -0.3878298	best: -0.3878897 (1051)	total: 3.68s	remaining: 3.3s
1054:	learn: -0.5438625	test: -0.3877917	best: -0.3878897 (1051)	total: 3.68s	remaining: 3.3s
1055:	learn: -0.5440767	test: -0.3878801	best: -0.3878897 (1051)	total: 3.69s	remaining: 3.3s
1056:	learn: -0.5443791	test: -0.3879419	best: -0.387

1164:	learn: -0.5617284	test: -0.3925202	best: -0.3925202 (1164)	total: 4.05s	remaining: 2.9s
1165:	learn: -0.5618501	test: -0.3925342	best: -0.3925342 (1165)	total: 4.06s	remaining: 2.9s
1166:	learn: -0.5619782	test: -0.3926474	best: -0.3926474 (1166)	total: 4.06s	remaining: 2.9s
1167:	learn: -0.5620905	test: -0.3927087	best: -0.3927087 (1167)	total: 4.06s	remaining: 2.89s
1168:	learn: -0.5621716	test: -0.3926297	best: -0.3927087 (1167)	total: 4.07s	remaining: 2.89s
1169:	learn: -0.5622499	test: -0.3926867	best: -0.3927087 (1167)	total: 4.07s	remaining: 2.89s
1170:	learn: -0.5624157	test: -0.3928027	best: -0.3928027 (1170)	total: 4.07s	remaining: 2.88s
1171:	learn: -0.5626592	test: -0.3930798	best: -0.3930798 (1171)	total: 4.08s	remaining: 2.88s
1172:	learn: -0.5627438	test: -0.3931348	best: -0.3931348 (1172)	total: 4.08s	remaining: 2.88s
1173:	learn: -0.5628411	test: -0.3930331	best: -0.3931348 (1172)	total: 4.08s	remaining: 2.87s
1174:	learn: -0.5630307	test: -0.3930663	best: -0.393

1280:	learn: -0.5771103	test: -0.3957045	best: -0.3960929 (1265)	total: 4.45s	remaining: 2.5s
1281:	learn: -0.5773066	test: -0.3955978	best: -0.3960929 (1265)	total: 4.45s	remaining: 2.49s
1282:	learn: -0.5775668	test: -0.3956065	best: -0.3960929 (1265)	total: 4.46s	remaining: 2.49s
1283:	learn: -0.5776392	test: -0.3955549	best: -0.3960929 (1265)	total: 4.46s	remaining: 2.49s
1284:	learn: -0.5778672	test: -0.3957160	best: -0.3960929 (1265)	total: 4.46s	remaining: 2.48s
1285:	learn: -0.5779582	test: -0.3955858	best: -0.3960929 (1265)	total: 4.47s	remaining: 2.48s
1286:	learn: -0.5781396	test: -0.3957426	best: -0.3960929 (1265)	total: 4.47s	remaining: 2.48s
1287:	learn: -0.5782921	test: -0.3958053	best: -0.3960929 (1265)	total: 4.47s	remaining: 2.47s
1288:	learn: -0.5784661	test: -0.3956043	best: -0.3960929 (1265)	total: 4.48s	remaining: 2.47s
1289:	learn: -0.5787009	test: -0.3957052	best: -0.3960929 (1265)	total: 4.48s	remaining: 2.47s
1290:	learn: -0.5789147	test: -0.3958622	best: -0.3

1396:	learn: -0.5945456	test: -0.3979338	best: -0.3979338 (1396)	total: 4.85s	remaining: 2.09s
1397:	learn: -0.5947491	test: -0.3978718	best: -0.3979338 (1396)	total: 4.85s	remaining: 2.09s
1398:	learn: -0.5948824	test: -0.3978340	best: -0.3979338 (1396)	total: 4.85s	remaining: 2.08s
1399:	learn: -0.5949735	test: -0.3976803	best: -0.3979338 (1396)	total: 4.86s	remaining: 2.08s
1400:	learn: -0.5951265	test: -0.3977941	best: -0.3979338 (1396)	total: 4.86s	remaining: 2.08s
1401:	learn: -0.5953363	test: -0.3976748	best: -0.3979338 (1396)	total: 4.86s	remaining: 2.07s
1402:	learn: -0.5954260	test: -0.3976586	best: -0.3979338 (1396)	total: 4.87s	remaining: 2.07s
1403:	learn: -0.5956631	test: -0.3976830	best: -0.3979338 (1396)	total: 4.87s	remaining: 2.07s
1404:	learn: -0.5957516	test: -0.3976703	best: -0.3979338 (1396)	total: 4.87s	remaining: 2.06s
1405:	learn: -0.5958609	test: -0.3979357	best: -0.3979357 (1405)	total: 4.88s	remaining: 2.06s
1406:	learn: -0.5959416	test: -0.3979082	best: -0.

1514:	learn: -0.6113608	test: -0.3995915	best: -0.4000917 (1502)	total: 5.25s	remaining: 1.68s
1515:	learn: -0.6114163	test: -0.3995904	best: -0.4000917 (1502)	total: 5.25s	remaining: 1.68s
1516:	learn: -0.6115181	test: -0.3996626	best: -0.4000917 (1502)	total: 5.25s	remaining: 1.67s
1517:	learn: -0.6116835	test: -0.3996426	best: -0.4000917 (1502)	total: 5.26s	remaining: 1.67s
1518:	learn: -0.6119318	test: -0.3996847	best: -0.4000917 (1502)	total: 5.26s	remaining: 1.67s
1519:	learn: -0.6120469	test: -0.3995547	best: -0.4000917 (1502)	total: 5.26s	remaining: 1.66s
1520:	learn: -0.6121032	test: -0.3995482	best: -0.4000917 (1502)	total: 5.27s	remaining: 1.66s
1521:	learn: -0.6121468	test: -0.3995636	best: -0.4000917 (1502)	total: 5.27s	remaining: 1.65s
1522:	learn: -0.6122154	test: -0.3994993	best: -0.4000917 (1502)	total: 5.27s	remaining: 1.65s
1523:	learn: -0.6123719	test: -0.3994590	best: -0.4000917 (1502)	total: 5.27s	remaining: 1.65s
1524:	learn: -0.6124599	test: -0.3992861	best: -0.

1632:	learn: -0.6275025	test: -0.4009365	best: -0.4012533 (1578)	total: 5.64s	remaining: 1.27s
1633:	learn: -0.6277890	test: -0.4010957	best: -0.4012533 (1578)	total: 5.65s	remaining: 1.26s
1634:	learn: -0.6278583	test: -0.4011303	best: -0.4012533 (1578)	total: 5.65s	remaining: 1.26s
1635:	learn: -0.6280552	test: -0.4011301	best: -0.4012533 (1578)	total: 5.65s	remaining: 1.26s
1636:	learn: -0.6281271	test: -0.4010060	best: -0.4012533 (1578)	total: 5.66s	remaining: 1.25s
1637:	learn: -0.6281884	test: -0.4010228	best: -0.4012533 (1578)	total: 5.66s	remaining: 1.25s
1638:	learn: -0.6284380	test: -0.4012779	best: -0.4012779 (1638)	total: 5.67s	remaining: 1.25s
1639:	learn: -0.6288584	test: -0.4015479	best: -0.4015479 (1639)	total: 5.67s	remaining: 1.24s
1640:	learn: -0.6290533	test: -0.4015236	best: -0.4015479 (1639)	total: 5.67s	remaining: 1.24s
1641:	learn: -0.6292665	test: -0.4017663	best: -0.4017663 (1641)	total: 5.68s	remaining: 1.24s
1642:	learn: -0.6293296	test: -0.4017624	best: -0.

1749:	learn: -0.6436892	test: -0.4035446	best: -0.4038460 (1744)	total: 6.04s	remaining: 863ms
1750:	learn: -0.6438301	test: -0.4033131	best: -0.4038460 (1744)	total: 6.04s	remaining: 860ms
1751:	learn: -0.6440944	test: -0.4035161	best: -0.4038460 (1744)	total: 6.05s	remaining: 856ms
1752:	learn: -0.6442019	test: -0.4034754	best: -0.4038460 (1744)	total: 6.05s	remaining: 853ms
1753:	learn: -0.6443889	test: -0.4035135	best: -0.4038460 (1744)	total: 6.05s	remaining: 849ms
1754:	learn: -0.6446074	test: -0.4036239	best: -0.4038460 (1744)	total: 6.06s	remaining: 846ms
1755:	learn: -0.6448174	test: -0.4036655	best: -0.4038460 (1744)	total: 6.06s	remaining: 842ms
1756:	learn: -0.6448838	test: -0.4036245	best: -0.4038460 (1744)	total: 6.07s	remaining: 839ms
1757:	learn: -0.6449606	test: -0.4034630	best: -0.4038460 (1744)	total: 6.07s	remaining: 835ms
1758:	learn: -0.6450092	test: -0.4033819	best: -0.4038460 (1744)	total: 6.07s	remaining: 832ms
1759:	learn: -0.6450965	test: -0.4033322	best: -0.

1867:	learn: -0.6578711	test: -0.4043961	best: -0.4043961 (1867)	total: 6.44s	remaining: 455ms
1868:	learn: -0.6579599	test: -0.4043370	best: -0.4043961 (1867)	total: 6.44s	remaining: 452ms
1869:	learn: -0.6583040	test: -0.4044638	best: -0.4044638 (1869)	total: 6.45s	remaining: 448ms
1870:	learn: -0.6585070	test: -0.4045128	best: -0.4045128 (1870)	total: 6.45s	remaining: 445ms
1871:	learn: -0.6587522	test: -0.4045361	best: -0.4045361 (1871)	total: 6.46s	remaining: 441ms
1872:	learn: -0.6588050	test: -0.4044768	best: -0.4045361 (1871)	total: 6.46s	remaining: 438ms
1873:	learn: -0.6588792	test: -0.4043590	best: -0.4045361 (1871)	total: 6.46s	remaining: 435ms
1874:	learn: -0.6589682	test: -0.4043989	best: -0.4045361 (1871)	total: 6.46s	remaining: 431ms
1875:	learn: -0.6590805	test: -0.4043220	best: -0.4045361 (1871)	total: 6.47s	remaining: 428ms
1876:	learn: -0.6591407	test: -0.4043386	best: -0.4045361 (1871)	total: 6.47s	remaining: 424ms
1877:	learn: -0.6592207	test: -0.4043740	best: -0.

1983:	learn: -0.6717559	test: -0.4033376	best: -0.4045361 (1871)	total: 6.84s	remaining: 55.2ms
1984:	learn: -0.6718156	test: -0.4032887	best: -0.4045361 (1871)	total: 6.84s	remaining: 51.7ms
1985:	learn: -0.6719524	test: -0.4032936	best: -0.4045361 (1871)	total: 6.85s	remaining: 48.3ms
1986:	learn: -0.6720420	test: -0.4032593	best: -0.4045361 (1871)	total: 6.85s	remaining: 44.8ms
1987:	learn: -0.6721079	test: -0.4031020	best: -0.4045361 (1871)	total: 6.85s	remaining: 41.4ms
1988:	learn: -0.6721437	test: -0.4030444	best: -0.4045361 (1871)	total: 6.86s	remaining: 37.9ms
1989:	learn: -0.6722627	test: -0.4029381	best: -0.4045361 (1871)	total: 6.86s	remaining: 34.5ms
1990:	learn: -0.6723327	test: -0.4028918	best: -0.4045361 (1871)	total: 6.86s	remaining: 31ms
1991:	learn: -0.6723744	test: -0.4028468	best: -0.4045361 (1871)	total: 6.86s	remaining: 27.6ms
1992:	learn: -0.6724764	test: -0.4028405	best: -0.4045361 (1871)	total: 6.87s	remaining: 24.1ms
1993:	learn: -0.6725187	test: -0.4028153	b

<catboost.core.CatBoostRegressor at 0x7f1183f454a8>

In [9]:
def calculate_nll_catboost(model: "catboost.CatBoostRegressor", x: pd.DataFrame, y: pd.DataFrame):
    """Mean Gaussian negative log-likelihood of ``y`` under a CatBoost uncertainty model.

    The model is expected to have been trained with
    ``loss_function="RMSEWithUncertainty"``, so that ``model.predict`` yields two
    columns per row: the predicted mean and the predicted variance.

    Args:
        model: Fitted regressor whose ``predict`` returns an ``(n, 2)`` array laid
            out as ``[mean, variance]``.
        x: Feature frame, one row per observation. Its ``.values`` are passed to
            ``model.predict``.
        y: Targets as a one-column DataFrame. A Series (1-D target) is accepted as
            well and treated as a single column.

    Returns:
        The average negative log-likelihood over all rows, or ``np.nan`` when ``y``
        has more than one column (only one-dimensional targets are supported).
    """
    features = x.values
    targets = y.values
    if targets.ndim == 1:
        # A Series has no column axis; treat it as a single target column.
        targets = targets.reshape(-1, 1)

    if targets.shape[1] > 1:
        return np.nan

    prediction = model.predict(features)
    mean = prediction[:, 0]
    std = np.sqrt(prediction[:, 1])  # CatBoost reports the variance, norm wants the std

    # Plain univariate Gaussian log-density; no tensor conversion required.
    return -stats.norm.logpdf(targets[:, 0], loc=mean, scale=std).mean()

In [10]:
# CatBoost NLL on the outer training split. Note that x_train still contains the
# validation rows (x_val was carved out of it), not only the rows used for fitting.
%time calculate_nll_catboost(model, x_train, y_train)

CPU times: user 263 ms, sys: 3.22 ms, total: 266 ms
Wall time: 33.3 ms


-0.595981489421833

In [11]:
# CatBoost NLL on the held-out test split (lower is better).
%time calculate_nll_catboost(model, x_test, y_test)

CPU times: user 226 ms, sys: 10.6 ms, total: 237 ms
Wall time: 12.1 ms


-0.4018589076764371

In [12]:
# Tree component of TreeFlow (project class). Categorical columns are given by
# position because this model is fitted on plain `.values` arrays further down.
# NOTE(review): 8, 9, 10 are presumably the positions of 'type', 'year', 'region'
# (the names used for the CatBoost baseline) — confirm against x.columns.
tree = EmbeddableCatBoostPriorNormal(
    loss_function="RMSEWithUncertainty",
    num_trees=200,
    depth=4,
    cat_features=[8, 9, 10],
    random_state=RANDOM_SEED,
)

# Conditional flow over the 1-D target.
# NOTE(review): context_dim equals the embedding_size passed to TreeFlowBoost below;
# presumably the two must match — confirm.
flow = ContinuousNormalizingFlow(
    input_dim=1,
    hidden_dims=(200, 200, 100, 50),
    num_blocks=5,
    context_dim=100,
    conditional=True,
)

treeflow = TreeFlowBoost(tree, flow, embedding_size=100)

In [None]:
# Fit TreeFlow on the inner training split with the validation split passed alongside.
# Unlike the CatBoost baseline, inputs are handed over as raw numpy arrays (`.values`).
%time treeflow.fit(x_tr.values, y_tr.values, x_val.values, y_val.values, n_epochs=50, batch_size=2048, verbose=True)

0:	learn: 0.4887763	total: 3.07ms	remaining: 612ms
1:	learn: 0.4724965	total: 5.9ms	remaining: 584ms
2:	learn: 0.4535065	total: 8.14ms	remaining: 534ms
3:	learn: 0.4356101	total: 10.2ms	remaining: 498ms
4:	learn: 0.4205684	total: 12.3ms	remaining: 479ms
5:	learn: 0.4051330	total: 14.1ms	remaining: 455ms
6:	learn: 0.3904869	total: 15.8ms	remaining: 437ms
7:	learn: 0.3763966	total: 17.6ms	remaining: 423ms
8:	learn: 0.3636416	total: 19.8ms	remaining: 420ms
9:	learn: 0.3510158	total: 21.8ms	remaining: 415ms
10:	learn: 0.3400005	total: 23.8ms	remaining: 410ms
11:	learn: 0.3284093	total: 25.8ms	remaining: 404ms
12:	learn: 0.3168091	total: 27.5ms	remaining: 396ms
13:	learn: 0.3070087	total: 29.6ms	remaining: 393ms
14:	learn: 0.2969666	total: 31.9ms	remaining: 393ms
15:	learn: 0.2865719	total: 33.8ms	remaining: 389ms
16:	learn: 0.2761719	total: 35.6ms	remaining: 384ms
17:	learn: 0.2665670	total: 37.6ms	remaining: 380ms
18:	learn: 0.2572313	total: 39.4ms	remaining: 375ms
19:	learn: 0.2488853	to

179:	learn: -0.1246050	total: 398ms	remaining: 44.2ms
180:	learn: -0.1248205	total: 400ms	remaining: 42ms
181:	learn: -0.1252273	total: 403ms	remaining: 39.8ms
182:	learn: -0.1259464	total: 405ms	remaining: 37.6ms
183:	learn: -0.1269690	total: 406ms	remaining: 35.3ms
184:	learn: -0.1274339	total: 409ms	remaining: 33.1ms
185:	learn: -0.1278428	total: 411ms	remaining: 30.9ms
186:	learn: -0.1280714	total: 413ms	remaining: 28.7ms
187:	learn: -0.1280836	total: 414ms	remaining: 26.4ms
188:	learn: -0.1284054	total: 416ms	remaining: 24.2ms
189:	learn: -0.1290842	total: 418ms	remaining: 22ms
190:	learn: -0.1298225	total: 420ms	remaining: 19.8ms
191:	learn: -0.1301655	total: 422ms	remaining: 17.6ms
192:	learn: -0.1305142	total: 425ms	remaining: 15.4ms
193:	learn: -0.1311342	total: 427ms	remaining: 13.2ms
194:	learn: -0.1317702	total: 429ms	remaining: 11ms
195:	learn: -0.1325286	total: 431ms	remaining: 8.8ms
196:	learn: -0.1331129	total: 434ms	remaining: 6.61ms
197:	learn: -0.1334405	total: 436ms

In [None]:
# TreeFlow NLL on the outer training split (which still includes the validation rows).
# NOTE(review): DataFrames are passed here while fit() received `.values` arrays;
# presumably calculate_nll converts internally — confirm.
%time calculate_nll(treeflow, x_train, y_train, batch_size = 1024)

In [None]:
# TreeFlow NLL on the held-out test split; compare with the CatBoost test NLL above.
%time calculate_nll(treeflow, x_test, y_test, batch_size = 1024)

In [None]:
# First 10 test rows, used for the qualitative comparison of the two predictive
# distributions; the plotting loop at the end of the notebook visualises these rows.
data = x_test.iloc[:10, :]

In [None]:
# CatBoost baseline: one row per observation, columns are [mean, variance].
y_test_catboost = model.predict(data)

# TreeFlow: draw 1000 samples per observation and drop singleton axes so the result
# can be indexed as [row, sample] by the plotting code.
# NOTE(review): a DataFrame is passed here while fit() received `.values` — confirm
# that sample() accepts it.
y_test_samples = treeflow.sample(data, num_samples=1000).squeeze()

In [None]:
# One figure per inspected test row: the true target (red line), a KDE of the TreeFlow
# samples (blue) and the Gaussian predicted by CatBoost (orange).
for i in range(len(y_test_catboost)):
    # CatBoost columns are [mean, variance]; build the Gaussian once per row.
    gaussian = stats.norm(loc=y_test_catboost[i, 0], scale=np.sqrt(y_test_catboost[i, 1]))

    # y_test has a single column, so [i, 0] is the scalar true value for this row.
    plt.axvline(x=y_test.values[i, 0], color='r', label='True value')

    sns.kdeplot(y_test_samples[i, :], color='blue', label='TreeFlow')

    # Named `grid` (not `x`) so the feature DataFrame `x` defined earlier in the
    # notebook is not overwritten and earlier cells stay re-runnable.
    grid = np.linspace(gaussian.ppf(0.001), gaussian.ppf(0.999), 100)
    plt.plot(grid, gaussian.pdf(grid), color='orange', label='CatBoost')

    plt.xlabel('AveragePrice')
    plt.title(f'Test row {i}')
    plt.legend()
    plt.show()