In [1]:
from typing import Union

import catboost
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats as stats
import torch.nn as nn

from sklearn.model_selection import train_test_split
from nflows.distributions import ConditionalDiagonalNormal

from src.probabilistic_flow_boosting.tfboost.tree import EmbeddableCatBoostPriorNormal
from src.probabilistic_flow_boosting.tfboost.tfboost import TreeFlowBoost
from src.probabilistic_flow_boosting.tfboost.flow import ContinuousNormalizingFlow
from src.probabilistic_flow_boosting.pipelines.reporting.nodes import calculate_nll

from src.probabilistic_flow_boosting.pipelines.modeling.utils import setup_random_seed

# Seed used for the project's RNG setup below and for the CatBoost model's
# `random_state` further down. The train/test splits do NOT use it; they pass a
# hard-coded random_state of their own.
RANDOM_SEED = 3

# NOTE(review): presumably seeds the numpy / torch generators -- confirm in
# pipelines.modeling.utils.
setup_random_seed(RANDOM_SEED)

In [2]:
# Load the raw BigMart table and replace missing `Outlet_Size` entries with an
# empty string in one chained expression.
# NOTE(review): presumably done so the column can be used as a categorical
# feature without NaNs -- confirm against the tree model's requirements.
df = (
    pd.read_csv('data/01_raw/CatData/bigmart/bigmart.csv')
    .assign(Outlet_Size=lambda frame: frame['Outlet_Size'].fillna(''))
)

In [3]:
# Features: every column except the row identifier and the target.
x = df.drop(['Item_Identifier', 'Item_Outlet_Sales'], axis='columns')

# Target: log10 of the sales figure. The list selector keeps `y` a one-column
# DataFrame rather than a Series.
y = np.log10(df.loc[:, ['Item_Outlet_Sales']])

In [4]:
# Outer split: hold out 20% of the rows as the test set.
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    test_size=0.2,
    random_state=42,
)

# Inner split: carve a further 20% of the training rows off as a validation set.
# Both splits use a fixed random_state of 42, independent of RANDOM_SEED.
x_tr, x_val, y_tr, y_val = train_test_split(
    x_train,
    y_train,
    test_size=0.2,
    random_state=42,
)

In [5]:
# Sanity check: (rows, columns) of the outer train and test feature frames.
x_train.shape, x_test.shape

((6818, 10), (1705, 10))

In [6]:
class TreeFlowWithoutShallow(TreeFlowBoost):
    """TreeFlowBoost variant that uses ``nn.Identity()`` as the flow's context encoder.

    NOTE(review): the name suggests the parent class inserts a shallow encoder
    between the tree embedding and the flow; that is not visible here -- confirm
    against ``TreeFlowBoost.fit``.
    """

    def fit(self, X: np.ndarray, y: np.ndarray, X_val: Union[np.ndarray, None] = None,
            y_val: Union[np.ndarray, None] = None, n_epochs: int = 100, batch_size: int = 1000, verbose: bool = False):
        """Fit the tree model on ``(X, y)``, then fit the flow model on its outputs.

        Args:
            X: Training features, passed to the tree model's ``fit``, ``embed``
                and ``pred_dist_param``.
            y: Training targets. Passed unchanged to the tree model; reshaped to
                a single column for the flow unless already 2-D.
            X_val: Optional validation features.
            y_val: Optional validation targets. Validation data is used only
                when both ``X_val`` and ``y_val`` are given; otherwise the flow
                receives ``None`` for all three validation inputs.
            n_epochs: Forwarded to ``flow_model.fit``.
            batch_size: Forwarded to ``flow_model.fit``.
            verbose: Forwarded to ``flow_model.fit``.

        Returns:
            ``self``, so calls can be chained.
        """

        def as_column(target: np.ndarray) -> np.ndarray:
            # Leave 2-D targets alone; reshape anything else to shape (n, 1).
            if len(target.shape) == 2:
                return target
            return target.reshape(-1, 1)

        # The tree model sees the target exactly as the caller supplied it.
        self.tree_model.fit(X, y)

        context = self.tree_model.embed(X)
        params = self.tree_model.pred_dist_param(X)
        y = as_column(y)

        has_validation = X_val is not None and y_val is not None
        if has_validation:
            context_val = self.tree_model.embed(X_val)
            params_val = self.tree_model.pred_dist_param(X_val)
            y_val = as_column(y_val)
        else:
            # A partial validation set is treated as no validation set at all.
            context_val = params_val = y_val = None

        # Install a pass-through module as the flow's context encoder.
        self.flow_model.setup_context_encoder(nn.Identity())

        self.flow_model.fit(
            y, context, params, y_val, context_val, params_val,
            n_epochs=n_epochs, batch_size=batch_size, verbose=verbose,
        )
        return self


In [7]:
# Tree-ensemble hyper-parameters.
depth = 2
num_trees = 500

# Size of the context vector given to the flow: num_trees * 2**depth (= 2000 here).
# NOTE(review): presumably one slot per leaf of every depth-`depth` tree --
# confirm against EmbeddableCatBoostPriorNormal.embed.
context_dim = num_trees * (2 ** depth)

# NOTE(review): cat_features are positional column indices into `x`; verify
# they still point at the categorical columns if the feature set changes.
tree = EmbeddableCatBoostPriorNormal(
    cat_features=[1, 3, 5, 7, 8, 9],
    loss_function="RMSEWithUncertainty",
    depth=depth,
    num_trees=num_trees,
    random_state=RANDOM_SEED,
)

# Conditional flow over the 1-D target, conditioned on a `context_dim`-wide vector.
flow = ContinuousNormalizingFlow(
    input_dim=1,
    hidden_dims=(200, 200, 100, 50),
    num_blocks=5,
    context_dim=context_dim,
    conditional=True,
)

treeflow = TreeFlowWithoutShallow(tree, flow, embedding_size=context_dim)

In [8]:
# Train on the inner split (x_tr / y_tr) with x_val / y_val as the validation set.
# `.values` passes plain arrays, matching the np.ndarray annotations on `fit`.
# verbose=True is forwarded to the flow model's fit.
%time treeflow.fit(x_tr.values, y_tr.values, x_val.values, y_val.values, n_epochs=50, batch_size=1024, verbose=True)

0:	learn: 0.5530084	total: 49.7ms	remaining: 24.8s
1:	learn: 0.5190486	total: 52.4ms	remaining: 13.1s
2:	learn: 0.4932207	total: 54.9ms	remaining: 9.09s
3:	learn: 0.4733343	total: 55.9ms	remaining: 6.93s
4:	learn: 0.4555082	total: 57ms	remaining: 5.64s
5:	learn: 0.4405363	total: 58.1ms	remaining: 4.78s
6:	learn: 0.4276602	total: 59.1ms	remaining: 4.17s
7:	learn: 0.4132764	total: 60.2ms	remaining: 3.7s
8:	learn: 0.3994643	total: 61.3ms	remaining: 3.34s
9:	learn: 0.3876195	total: 62.3ms	remaining: 3.05s
10:	learn: 0.3754727	total: 63.3ms	remaining: 2.82s
11:	learn: 0.3638768	total: 64.4ms	remaining: 2.62s
12:	learn: 0.3560535	total: 65.6ms	remaining: 2.46s
13:	learn: 0.3448078	total: 66.6ms	remaining: 2.31s
14:	learn: 0.3338963	total: 67.7ms	remaining: 2.19s
15:	learn: 0.3243187	total: 68.8ms	remaining: 2.08s
16:	learn: 0.3146545	total: 70.1ms	remaining: 1.99s
17:	learn: 0.3045234	total: 71.3ms	remaining: 1.91s
18:	learn: 0.2979169	total: 72.4ms	remaining: 1.83s
19:	learn: 0.2880315	tota

167:	learn: -0.0443669	total: 246ms	remaining: 485ms
168:	learn: -0.0444182	total: 247ms	remaining: 483ms
169:	learn: -0.0444728	total: 248ms	remaining: 482ms
170:	learn: -0.0454972	total: 249ms	remaining: 480ms
171:	learn: -0.0463885	total: 251ms	remaining: 478ms
172:	learn: -0.0464336	total: 252ms	remaining: 476ms
173:	learn: -0.0466713	total: 253ms	remaining: 474ms
174:	learn: -0.0468739	total: 254ms	remaining: 472ms
175:	learn: -0.0470112	total: 256ms	remaining: 471ms
176:	learn: -0.0471963	total: 257ms	remaining: 469ms
177:	learn: -0.0473819	total: 258ms	remaining: 467ms
178:	learn: -0.0475885	total: 259ms	remaining: 465ms
179:	learn: -0.0476132	total: 260ms	remaining: 462ms
180:	learn: -0.0477912	total: 261ms	remaining: 460ms
181:	learn: -0.0487658	total: 262ms	remaining: 459ms
182:	learn: -0.0488661	total: 264ms	remaining: 457ms
183:	learn: -0.0496411	total: 265ms	remaining: 455ms
184:	learn: -0.0496914	total: 266ms	remaining: 453ms
185:	learn: -0.0497179	total: 267ms	remaining:

330:	learn: -0.0804499	total: 440ms	remaining: 225ms
331:	learn: -0.0805515	total: 442ms	remaining: 224ms
332:	learn: -0.0807306	total: 443ms	remaining: 222ms
333:	learn: -0.0809640	total: 444ms	remaining: 221ms
334:	learn: -0.0809962	total: 445ms	remaining: 219ms
335:	learn: -0.0810279	total: 447ms	remaining: 218ms
336:	learn: -0.0813829	total: 448ms	remaining: 217ms
337:	learn: -0.0815312	total: 450ms	remaining: 215ms
338:	learn: -0.0817023	total: 451ms	remaining: 214ms
339:	learn: -0.0817047	total: 452ms	remaining: 213ms
340:	learn: -0.0818370	total: 453ms	remaining: 211ms
341:	learn: -0.0820382	total: 455ms	remaining: 210ms
342:	learn: -0.0822476	total: 456ms	remaining: 209ms
343:	learn: -0.0823813	total: 457ms	remaining: 207ms
344:	learn: -0.0824743	total: 458ms	remaining: 206ms
345:	learn: -0.0827221	total: 460ms	remaining: 205ms
346:	learn: -0.0827511	total: 461ms	remaining: 203ms
347:	learn: -0.0830980	total: 462ms	remaining: 202ms
348:	learn: -0.0831277	total: 463ms	remaining:

490:	learn: -0.0945047	total: 635ms	remaining: 11.6ms
491:	learn: -0.0945954	total: 636ms	remaining: 10.3ms
492:	learn: -0.0946727	total: 637ms	remaining: 9.05ms
493:	learn: -0.0946914	total: 639ms	remaining: 7.75ms
494:	learn: -0.0947874	total: 640ms	remaining: 6.46ms
495:	learn: -0.0947979	total: 641ms	remaining: 5.17ms
496:	learn: -0.0949390	total: 642ms	remaining: 3.88ms
497:	learn: -0.0949587	total: 643ms	remaining: 2.58ms
498:	learn: -0.0949749	total: 644ms	remaining: 1.29ms
499:	learn: -0.0950703	total: 646ms	remaining: 0us
train loss: 0.5950672030448914
val loss: 0.48311564326286316
train loss: 0.08451010286808014
val loss: 0.11328640580177307
train loss: -0.009409218095242977
val loss: 0.014782877638936043
train loss: -0.04572813957929611
val loss: 0.014453115873038769
train loss: -0.07006014138460159
val loss: -0.04006689414381981
train loss: -0.06970561295747757
val loss: -0.05144418030977249
train loss: -0.07544279098510742
val loss: -0.04950166493654251
train loss: -0.0800

TreeFlowWithoutShallow(embedding_size=2000,
                       flow_model=<src.probabilistic_flow_boosting.tfboost.flow.flow.ContinuousNormalizingFlow object at 0x7f638b94ab38>,
                       tree_model=<src.probabilistic_flow_boosting.tfboost.tree.ecatboost.EmbeddableCatBoostPriorNormal object at 0x7f638b94a710>)

In [9]:
# Show the context width computed as num_trees * 2**depth.
print(context_dim)

2000


In [10]:
# Score the fitted model on the outer training split.
# Note: x_train / y_train contain both the rows used for fitting (x_tr) and the
# validation rows (x_val), so this is not a pure fit-set score.
# NOTE(review): presumably a mean negative log-likelihood, per the helper's name -- confirm.
%time calculate_nll(treeflow, x_train, y_train, batch_size = 1024)

CPU times: user 7.67 s, sys: 402 ms, total: 8.07 s
Wall time: 2.86 s


-0.08942635

In [11]:
# Score the fitted model on the held-out test split, which was not passed to `fit`.
%time calculate_nll(treeflow, x_test, y_test, batch_size = 1024)

CPU times: user 4.17 s, sys: 164 ms, total: 4.33 s
Wall time: 798 ms


-0.0644494