In [1]:
from typing import Union

import catboost
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats as stats
import torch.nn as nn

from sklearn.model_selection import train_test_split
from nflows.distributions import ConditionalDiagonalNormal

from src.probabilistic_flow_boosting.tfboost.tree import EmbeddableCatBoostPriorNormal
from src.probabilistic_flow_boosting.tfboost.tfboost import TreeFlowBoost
from src.probabilistic_flow_boosting.tfboost.flow import ContinuousNormalizingFlow
from src.probabilistic_flow_boosting.pipelines.reporting.nodes import calculate_nll

from src.probabilistic_flow_boosting.pipelines.modeling.utils import setup_random_seed

# Seed handed to the project's setup_random_seed helper and to the CatBoost model's
# random_state below; the train/test splits use their own fixed random_state.
RANDOM_SEED: int = 1

setup_random_seed(RANDOM_SEED)

In [2]:
# Raw diamonds dataset; the path is relative to the kernel's working directory.
DIAMONDS_CSV = 'data/01_raw/CatData/diamonds2/diamonds_dataset.csv'
df = pd.read_csv(DIAMONDS_CSV)

In [3]:
# Features: every column except id, url, date_fetched and the target column price.
x = df.drop(columns = ['id', 'url', 'price', 'date_fetched'])

# Target: log10(price), kept as a one-column DataFrame. log10 of a non-positive or
# missing price yields -inf/NaN and would silently corrupt training, so fail fast here.
# (NaN > 0 is False, so missing prices are caught by the same check.)
assert (df['price'] > 0).all(), "price must be strictly positive and non-missing for log10"
y = np.log10(df[['price']])

In [4]:
# Outer split: TEST_FRACTION of the rows become the held-out test set. Inner split:
# VAL_FRACTION of the remaining rows become the validation set passed to fit below.
# SPLIT_SEED stays at 42 (independent of RANDOM_SEED) so the partitions, and therefore
# the recorded outputs in this notebook, are unchanged.
SPLIT_SEED = 42
TEST_FRACTION = 0.2
VAL_FRACTION = 0.2

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=TEST_FRACTION, random_state=SPLIT_SEED
)
x_tr, x_val, y_tr, y_val = train_test_split(
    x_train, y_train, test_size=VAL_FRACTION, random_state=SPLIT_SEED
)

# Every row lands in exactly one of tr / val / test, and features stay aligned with targets.
assert len(x_tr) + len(x_val) + len(x_test) == len(x)
assert len(x_tr) == len(y_tr) and len(x_val) == len(y_val) and len(x_test) == len(y_test)

In [5]:
# Outer-split sizes as (rows, columns); x_train still includes the rows later split off as x_val.
(x_train.shape, x_test.shape)

((95445, 7), (23862, 7))

In [6]:
class TreeFlowWithoutShallow(TreeFlowBoost):
    """TreeFlowBoost variant whose flow conditions directly on the raw tree embedding.

    ``fit`` installs ``nn.Identity()`` as the flow's context encoder, so the flow receives
    the output of ``tree_model.embed`` unchanged, with no learned projection in between.
    This is presumably in place of the shallow encoder the parent class would otherwise
    use (confirm against ``TreeFlowBoost``). The flow's ``context_dim`` should therefore
    match the width of that embedding.
    """

    def fit(self, X: np.ndarray, y: np.ndarray, X_val: Union[np.ndarray, None] = None,
            y_val: Union[np.ndarray, None] = None, n_epochs: int = 100, batch_size: int = 1000,
            verbose: bool = False) -> "TreeFlowWithoutShallow":
        """Fit the tree model, then fit the flow conditioned on the tree's outputs.

        Args:
            X: training features, passed as-is to the tree model.
            y: training targets. The tree model receives ``y`` unchanged; the flow receives
                it as a 2-D array (1-D input is reshaped to a single column).
            X_val: optional validation features.
            y_val: optional validation targets, reshaped like ``y``. Validation data is
                used only by the flow, and only when BOTH ``X_val`` and ``y_val`` are given.
            n_epochs: forwarded to ``flow_model.fit``.
            batch_size: forwarded to ``flow_model.fit``.
            verbose: forwarded to ``flow_model.fit``.

        Returns:
            ``self``, so calls can be chained.

        Warns:
            UserWarning: if exactly one of ``X_val`` / ``y_val`` is given. That argument is
                still ignored, as before, but no longer silently.
        """
        import warnings  # local import keeps the notebook's import cell unchanged

        # The tree model is fit first; the flow is trained on its embedding and its
        # predicted distribution parameters.
        self.tree_model.fit(X, y)

        context: np.ndarray = self.tree_model.embed(X)
        params: np.ndarray = self.tree_model.pred_dist_param(X)
        y = y if y.ndim == 2 else y.reshape(-1, 1)

        if X_val is not None and y_val is not None:
            context_val = self.tree_model.embed(X_val)
            params_val = self.tree_model.pred_dist_param(X_val)
            y_val = y_val if y_val.ndim == 2 else y_val.reshape(-1, 1)
        else:
            if (X_val is None) != (y_val is None):
                # Half a validation pair is almost certainly a caller mistake; surface it.
                warnings.warn(
                    "X_val and y_val must be given together; ignoring the validation data.",
                    UserWarning,
                    stacklevel=2,
                )
            context_val = params_val = y_val = None

        # Identity encoder: the flow conditions on the tree embedding as-is.
        self.flow_model.setup_context_encoder(nn.Identity())

        self.flow_model.fit(y, context, params, y_val, context_val, params_val, n_epochs=n_epochs,
                            batch_size=batch_size, verbose=verbose)
        return self


In [7]:
# Width of the tree embedding fed to the flow: num_trees * 2**depth, presumably one slot
# per leaf of each depth-`depth` tree. TODO confirm against EmbeddableCatBoostPriorNormal.embed.
depth, num_trees = 4, 200
context_dim = num_trees * (2 ** depth)

tree = EmbeddableCatBoostPriorNormal(
    cat_features=[0, 2, 3, 4, 5, 6],  # positional column indices treated as categorical
    loss_function="RMSEWithUncertainty",
    depth=depth,
    num_trees=num_trees,
    random_state=RANDOM_SEED,
)

flow = ContinuousNormalizingFlow(
    input_dim=1,  # one target column: log10(price)
    hidden_dims=(200, 200, 100, 50),
    num_blocks=5,
    context_dim=context_dim,
    conditional=True,
)

# Embedding size and flow context width are the same number: the full tree embedding.
treeflow = TreeFlowWithoutShallow(tree, flow, embedding_size=context_dim)

In [8]:
# Fit the tree model on (x_tr, y_tr), then the flow for 30 epochs with (x_val, y_val) as its
# validation data; `.values` passes plain arrays rather than DataFrames. %time reports cost.
%time treeflow.fit(x_tr.values, y_tr.values, x_val.values, y_val.values, n_epochs=30, batch_size=2048, verbose=True)

0:	learn: 0.4777632	total: 62.2ms	remaining: 12.4s
1:	learn: 0.4378460	total: 76.8ms	remaining: 7.6s
2:	learn: 0.4033551	total: 90ms	remaining: 5.91s
3:	learn: 0.3773403	total: 105ms	remaining: 5.13s
4:	learn: 0.3495352	total: 119ms	remaining: 4.64s
5:	learn: 0.3231689	total: 132ms	remaining: 4.27s
6:	learn: 0.2991023	total: 144ms	remaining: 3.98s
7:	learn: 0.2755703	total: 155ms	remaining: 3.72s
8:	learn: 0.2528433	total: 167ms	remaining: 3.54s
9:	learn: 0.2306212	total: 179ms	remaining: 3.41s
10:	learn: 0.2110095	total: 191ms	remaining: 3.27s
11:	learn: 0.1936577	total: 202ms	remaining: 3.17s
12:	learn: 0.1764162	total: 214ms	remaining: 3.07s
13:	learn: 0.1555244	total: 225ms	remaining: 2.99s
14:	learn: 0.1356510	total: 237ms	remaining: 2.92s
15:	learn: 0.1153216	total: 250ms	remaining: 2.87s
16:	learn: 0.0953899	total: 261ms	remaining: 2.81s
17:	learn: 0.0775564	total: 274ms	remaining: 2.77s
18:	learn: 0.0598135	total: 286ms	remaining: 2.72s
19:	learn: 0.0406675	total: 299ms	remaini

164:	learn: -1.2291296	total: 1.91s	remaining: 406ms
165:	learn: -1.2323368	total: 1.92s	remaining: 394ms
166:	learn: -1.2344290	total: 1.94s	remaining: 382ms
167:	learn: -1.2373028	total: 1.95s	remaining: 371ms
168:	learn: -1.2394368	total: 1.96s	remaining: 360ms
169:	learn: -1.2414209	total: 1.97s	remaining: 348ms
170:	learn: -1.2444188	total: 1.98s	remaining: 336ms
171:	learn: -1.2470315	total: 1.99s	remaining: 324ms
172:	learn: -1.2489727	total: 2s	remaining: 312ms
173:	learn: -1.2519358	total: 2.01s	remaining: 300ms
174:	learn: -1.2559915	total: 2.02s	remaining: 289ms
175:	learn: -1.2595412	total: 2.03s	remaining: 277ms
176:	learn: -1.2612947	total: 2.04s	remaining: 265ms
177:	learn: -1.2630619	total: 2.05s	remaining: 254ms
178:	learn: -1.2648368	total: 2.07s	remaining: 243ms
179:	learn: -1.2663116	total: 2.08s	remaining: 231ms
180:	learn: -1.2716811	total: 2.09s	remaining: 219ms
181:	learn: -1.2719955	total: 2.09s	remaining: 207ms
182:	learn: -1.2748397	total: 2.1s	remaining: 195

TreeFlowWithoutShallow(embedding_size=3200,
                       flow_model=<src.probabilistic_flow_boosting.tfboost.flow.flow.ContinuousNormalizingFlow object at 0x7f917adc1358>,
                       tree_model=<src.probabilistic_flow_boosting.tfboost.tree.ecatboost.EmbeddableCatBoostPriorNormal object at 0x7f917adc1320>)

In [9]:
# Flow context width: num_trees * 2**depth (3200 for depth=4 and 200 trees).
context_dim

3200


In [10]:
# NLL of the fitted model on the outer training split; the likelihood is of log10(price),
# so values are not comparable to NLLs on raw price, and negative values are possible.
# NOTE(review): x_train/y_train also contain the x_val/y_val rows split off for validation
# (the model was fit on x_tr/y_tr), so this is not a pure training-set score; use x_tr, y_tr
# for that. Also, fit received `.values` arrays while this call passes DataFrames;
# presumably calculate_nll accepts both. Confirm.
%time calculate_nll(treeflow, x_train, y_train, batch_size = 1024)

CPU times: user 2min 4s, sys: 4.36 s, total: 2min 8s
Wall time: 1min 54s


-2.1184237

In [11]:
# NLL on the held-out test split. These rows were never passed to treeflow.fit (it received
# only x_tr and x_val), so this is the generalization score.
%time calculate_nll(treeflow, x_test, y_test, batch_size = 1024)

CPU times: user 37.9 s, sys: 1.1 s, total: 39 s
Wall time: 30.5 s


-2.0629003