In [1]:
from typing import Union

import catboost
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats as stats
import torch.nn as nn

from sklearn.model_selection import train_test_split
from nflows.distributions import ConditionalDiagonalNormal

from src.probabilistic_flow_boosting.tfboost.tree import EmbeddableCatBoostPriorNormal, EmbeddableOneHotEncoder
from src.probabilistic_flow_boosting.tfboost.tfboost import TreeFlowBoost
from src.probabilistic_flow_boosting.tfboost.flow import ContinuousNormalizingFlow
from src.probabilistic_flow_boosting.pipelines.reporting.nodes import calculate_nll

from src.probabilistic_flow_boosting.pipelines.modeling.utils import setup_random_seed

# Notebook-wide seed; pass this constant wherever a seed/random_state is needed.
RANDOM_SEED = 42
# setup_random_seed is a project helper — presumably seeds the global RNGs
# (random / numpy / torch); verify against its definition.
setup_random_seed(RANDOM_SEED)

In [2]:
# Path is relative to the notebook's working directory; the first CSV column is the row index.
DIAMONDS_CSV = 'data/01_raw/CatData/diamonds/diamonds.csv'
df = pd.read_csv(DIAMONDS_CSV, index_col=0)

In [3]:
# np.log10 silently produces -inf/NaN for non-positive values, so fail loudly instead.
assert (df['price'] > 0).all(), "log10 target requires strictly positive prices"

# Features: every column except the target.
x = df.drop(columns=['price'])
# Target: log10(price), kept as a one-column DataFrame. All NLLs reported below
# are therefore measured in log10-price space.
y = np.log10(df[['price']])

In [4]:
# 80/20 train/test split, then 80/20 of the training part into fit/validation.
# Seeded from RANDOM_SEED (not a literal) so the splits follow the notebook-wide seed.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=RANDOM_SEED)
x_tr, x_val, y_tr, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=RANDOM_SEED)

In [5]:
# The splits must partition the data: nothing dropped, nothing duplicated in size.
assert len(x_train) + len(x_test) == len(x)
assert len(x_tr) + len(x_val) == len(x_train)
x_train.shape, x_test.shape

((43152, 9), (10788, 9))

In [6]:
class TreeFlowWithoutShallow(TreeFlowBoost):
    """TreeFlowBoost variant whose flow consumes the tree embedding unprojected.

    ``fit`` installs ``nn.Identity()`` as the flow's context encoder, so the tree
    embedding reaches the flow without any learned projection. (The class name
    suggests the parent uses a shallow network here; that is not visible in this
    notebook.)
    """

    def fit(self, X: np.ndarray, y: np.ndarray, X_val: Union[np.ndarray, None] = None,
            y_val: Union[np.ndarray, None] = None, n_epochs: int = 100, batch_size: int = 1000, verbose: bool = False):
        """Fit the tree model on (X, y), then the flow conditioned on the tree outputs.

        Args:
            X: training features.
            y: training targets; 1-D targets are reshaped to a single column.
            X_val: optional validation features, forwarded (embedded) to the flow's fit.
            y_val: optional validation targets; must be given together with ``X_val``.
            n_epochs: number of flow training epochs.
            batch_size: flow training batch size.
            verbose: forwarded to the flow's ``fit`` only (the tree model is not given it).

        Returns:
            self, to allow chaining.

        Raises:
            ValueError: if exactly one of ``X_val`` / ``y_val`` is provided.
        """
        # Previously a lone X_val or y_val was silently ignored; fail fast instead
        # of training without the validation data the caller meant to supply.
        if (X_val is None) != (y_val is None):
            raise ValueError("X_val and y_val must be provided together (or both omitted).")

        # The tree model is trained on the training split only.
        self.tree_model.fit(X, y)

        context, params, y = self._tree_features(X, y)
        if X_val is not None:
            context_val, params_val, y_val = self._tree_features(X_val, y_val)
        else:
            context_val = params_val = None  # y_val is already None here

        # Identity encoder: no learned projection between tree embedding and flow.
        self.flow_model.setup_context_encoder(nn.Identity())

        self.flow_model.fit(y, context, params, y_val, context_val, params_val, n_epochs=n_epochs,
                            batch_size=batch_size, verbose=verbose)
        return self

    def _tree_features(self, X, y):
        """Return (tree embedding of X, tree-predicted distribution params for X, y as a column)."""
        context: np.ndarray = self.tree_model.embed(X)
        params: np.ndarray = self.tree_model.pred_dist_param(X)
        y_2d: np.ndarray = y if len(y.shape) == 2 else y.reshape(-1, 1)
        return context, params, y_2d


In [7]:
num_trees = 200
depth = 4
# A tree of depth `depth` has at most 2**depth leaves. The tree embedding size
# is presumably one slot per leaf across all trees. Computed once so the flow's
# context_dim and the wrapper's embedding_size cannot drift apart.
embedding_size = num_trees * 2**depth

# Data is passed to CatBoost as .values (no column names), so it needs the
# *positions* of the non-numeric (string) columns. Derive them from the dtypes
# instead of hard-coding [1, 2, 3].
cat_feature_idx = [x.columns.get_loc(col) for col in x.select_dtypes(exclude='number').columns]

tree = EmbeddableCatBoostPriorNormal(
    cat_features=cat_feature_idx,
    loss_function="RMSEWithUncertainty",
    depth=depth,
    num_trees=num_trees,
    random_state=RANDOM_SEED,
)

flow = ContinuousNormalizingFlow(input_dim=1, hidden_dims=(200, 100, 100, 50),
                                 num_blocks=5, context_dim=embedding_size, conditional=True)

treeflow = TreeFlowWithoutShallow(tree, flow, embedding_size=embedding_size)

In [8]:
%time treeflow.fit(x_tr.values, y_tr.values, x_val.values, y_val.values, n_epochs=30, batch_size=1024, verbose=True)

0:	learn: 0.5579413	total: 53.7ms	remaining: 10.7s
1:	learn: 0.5227974	total: 59.7ms	remaining: 5.91s
2:	learn: 0.4911946	total: 64.6ms	remaining: 4.24s
3:	learn: 0.4624084	total: 70.4ms	remaining: 3.45s
4:	learn: 0.4353872	total: 75.8ms	remaining: 2.96s
5:	learn: 0.4100112	total: 81.2ms	remaining: 2.62s
6:	learn: 0.3852236	total: 86.1ms	remaining: 2.37s
7:	learn: 0.3613284	total: 90.7ms	remaining: 2.18s
8:	learn: 0.3385223	total: 95.7ms	remaining: 2.03s
9:	learn: 0.3171481	total: 101ms	remaining: 1.92s
10:	learn: 0.2952672	total: 106ms	remaining: 1.83s
11:	learn: 0.2733442	total: 111ms	remaining: 1.74s
12:	learn: 0.2523981	total: 116ms	remaining: 1.67s
13:	learn: 0.2310772	total: 121ms	remaining: 1.6s
14:	learn: 0.2105034	total: 126ms	remaining: 1.55s
15:	learn: 0.1904924	total: 131ms	remaining: 1.5s
16:	learn: 0.1697434	total: 136ms	remaining: 1.46s
17:	learn: 0.1498967	total: 141ms	remaining: 1.42s
18:	learn: 0.1294843	total: 145ms	remaining: 1.39s
19:	learn: 0.1094417	total: 150ms	

182:	learn: -1.3691934	total: 852ms	remaining: 79.1ms
183:	learn: -1.3711917	total: 857ms	remaining: 74.5ms
184:	learn: -1.3718332	total: 862ms	remaining: 69.9ms
185:	learn: -1.3729500	total: 867ms	remaining: 65.3ms
186:	learn: -1.3766775	total: 872ms	remaining: 60.6ms
187:	learn: -1.3788115	total: 877ms	remaining: 56ms
188:	learn: -1.3807965	total: 882ms	remaining: 51.3ms
189:	learn: -1.3831274	total: 886ms	remaining: 46.7ms
190:	learn: -1.3840225	total: 891ms	remaining: 42ms
191:	learn: -1.3852516	total: 894ms	remaining: 37.2ms
192:	learn: -1.3875799	total: 898ms	remaining: 32.6ms
193:	learn: -1.3889695	total: 902ms	remaining: 27.9ms
194:	learn: -1.3903934	total: 906ms	remaining: 23.2ms
195:	learn: -1.3938186	total: 910ms	remaining: 18.6ms
196:	learn: -1.3949155	total: 915ms	remaining: 13.9ms
197:	learn: -1.3988189	total: 920ms	remaining: 9.29ms
198:	learn: -1.3998589	total: 924ms	remaining: 4.64ms
199:	learn: -1.4018989	total: 929ms	remaining: 0us
train loss: -1.1898952722549438
val

TreeFlowWithoutShallow(embedding_size=3200,
                       flow_model=<src.probabilistic_flow_boosting.tfboost.flow.flow.ContinuousNormalizingFlow object at 0x7f2dcf1dea20>,
                       tree_model=<src.probabilistic_flow_boosting.tfboost.tree.ecatboost.EmbeddableCatBoostPriorNormal object at 0x7f2dcf1deeb8>)

In [9]:
# Size of the tree embedding fed to the flow (num_trees trees x 2**depth leaves).
num_trees * 2**depth

3200


In [12]:
%time calculate_nll(treeflow, x_train, y_train, batch_size = 1024)

CPU times: user 32.1 s, sys: 1.3 s, total: 33.4 s
Wall time: 25 s


-2.0762994

In [13]:
%time calculate_nll(treeflow, x_test, y_test, batch_size = 1024)

CPU times: user 12.1 s, sys: 525 ms, total: 12.6 s
Wall time: 6.37 s


-1.92711