In [1]:
from typing import Union

import catboost
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats as stats
import torch.nn as nn

from sklearn.model_selection import train_test_split
from nflows.distributions import ConditionalDiagonalNormal

from src.probabilistic_flow_boosting.tfboost.tree import EmbeddableCatBoostPriorNormal
from src.probabilistic_flow_boosting.tfboost.tfboost import TreeFlowBoost
from src.probabilistic_flow_boosting.tfboost.flow import ContinuousNormalizingFlow
from src.probabilistic_flow_boosting.pipelines.reporting.nodes import calculate_nll

from src.probabilistic_flow_boosting.pipelines.modeling.utils import setup_random_seed

# Global seed for this experiment; also passed to CatBoost as `random_state` below.
RANDOM_SEED = 1
# Project helper — NOTE(review): presumably seeds numpy/torch/random; confirm in modeling/utils.
setup_random_seed(RANDOM_SEED)

In [2]:
# Relative path: assumes the notebook runs from the project root (as the `src.` imports do).
LAPTOP_PRICES_CSV = 'data/01_raw/CatData/laptop/laptop_price.csv'
# The first CSV column becomes the index, so it is excluded from the features.
df = pd.read_csv(LAPTOP_PRICES_CSV, engine='python', index_col=0)

In [3]:
# Strip the unit suffixes ("kg", "GB") and parse the remainder as numbers.
# The dtype guard keeps this cell idempotent: on a re-run the columns are
# already numeric and the `.str` accessor would otherwise raise.
# regex=False: the suffixes are literal text, and pandas changed the default
# of `regex` between versions.
for column, unit in (('Weight', 'kg'), ('Ram', 'GB')):
    if not pd.api.types.is_numeric_dtype(df[column]):
        df[column] = pd.to_numeric(df[column].str.replace(unit, '', regex=False))

In [4]:
TARGET = 'Price_euros'
# Features: every column except the product name and the target itself.
x = df.drop(columns=['Product', TARGET])
# Target on a log10 scale; double brackets keep y a one-column DataFrame.
y = df[[TARGET]].pipe(np.log10)

In [5]:
# Same fraction and seed for both splits (seed kept at 42, independent of RANDOM_SEED).
split_kwargs = dict(test_size=0.2, random_state=42)
# 80/20 train/test hold-out ...
x_train, x_test, y_train, y_test = train_test_split(x, y, **split_kwargs)
# ... then 20% of the training part is carved off as the validation set for fit().
x_tr, x_val, y_tr, y_val = train_test_split(x_train, y_train, **split_kwargs)

In [6]:
# (rows, feature columns) of the train and test feature frames.
tuple(frame.shape for frame in (x_train, x_test))

((1042, 10), (261, 10))

In [7]:
class TreeFlowWithoutShallow(TreeFlowBoost):
    """TreeFlowBoost variant whose flow sees the raw tree embedding as context.

    ``fit`` installs ``nn.Identity()`` as the flow model's context encoder, so the
    output of ``tree_model.embed`` reaches the flow unchanged.
    NOTE(review): presumably this replaces the shallow encoder the parent's ``fit``
    would set up — confirm in TreeFlowBoost.
    """

    def fit(self, X: np.ndarray, y: np.ndarray, X_val: Union[np.ndarray, None] = None,
            y_val: Union[np.ndarray, None] = None, n_epochs: int = 100, batch_size: int = 1000, verbose: bool = False):
        """Fit the tree model, then train the flow conditioned on its embeddings.

        Args:
            X: training features, passed to the tree model.
            y: training targets; anything not 2-D is reshaped to a column ``(n, 1)``
                before it is given to the flow.
            X_val: optional validation features (must be given together with ``y_val``).
            y_val: optional validation targets (must be given together with ``X_val``).
            n_epochs: forwarded to ``flow_model.fit``.
            batch_size: forwarded to ``flow_model.fit``.
            verbose: forwarded to ``flow_model.fit``.

        Returns:
            self, so the call can be chained / displayed in a notebook.

        Raises:
            ValueError: if exactly one of ``X_val`` / ``y_val`` is provided. Previously
                such a half-specified validation set was silently ignored.
        """
        # Validate before the (comparatively expensive) tree training starts.
        if (X_val is None) != (y_val is None):
            raise ValueError("X_val and y_val must be provided together (or both omitted).")

        self.tree_model.fit(X, y)

        context = self.tree_model.embed(X)
        params = self.tree_model.pred_dist_param(X)
        # np.asarray also accepts pandas Series, which have no `.reshape`.
        y_flow = np.asarray(y)
        if y_flow.ndim != 2:
            y_flow = y_flow.reshape(-1, 1)

        if X_val is not None:
            context_val = self.tree_model.embed(X_val)
            params_val = self.tree_model.pred_dist_param(X_val)
            y_val_flow = np.asarray(y_val)
            if y_val_flow.ndim != 2:
                y_val_flow = y_val_flow.reshape(-1, 1)
        else:
            context_val = params_val = y_val_flow = None

        # Identity encoder: the tree embedding is used as the flow context as-is.
        self.flow_model.setup_context_encoder(nn.Identity())

        self.flow_model.fit(y_flow, context, params, y_val_flow, context_val, params_val, n_epochs=n_epochs,
                            batch_size=batch_size, verbose=verbose)
        return self


In [8]:
depth = 2
num_trees = 500
# Width of the tree embedding that the flow receives as context.
# NOTE(review): num_trees * 2**depth equals the total leaf count when every tree is
# oblivious (2**depth leaves) and embed() encodes one slot per leaf — confirm in ecatboost.
context_dim = num_trees*2**depth

# Derive the categorical column positions from the feature frame instead of
# hard-coding them, so they stay correct if columns are added or dropped above.
# Positional indices are required because the fit cell passes NumPy arrays
# (no column names) to CatBoost.
cat_features = [x.columns.get_loc(col) for col in x.select_dtypes(exclude='number').columns]

tree = EmbeddableCatBoostPriorNormal(
    cat_features=cat_features,
    loss_function="RMSEWithUncertainty",
    depth=depth,
    num_trees=num_trees,
    random_state=RANDOM_SEED
)
# input_dim=1: the target y is a single column (log10 price).
flow = ContinuousNormalizingFlow(input_dim=1, hidden_dims=(200, 200, 100, 50),
                                 num_blocks=5, context_dim=context_dim, conditional=True)

treeflow = TreeFlowWithoutShallow(tree, flow, embedding_size=context_dim)

In [9]:
%time treeflow.fit(x_tr.values, y_tr.values, x_val.values, y_val.values, n_epochs=100, batch_size=1024, verbose=True)

0:	learn: 0.1004663	total: 47.3ms	remaining: 23.6s
1:	learn: 0.0782722	total: 48ms	remaining: 12s
2:	learn: 0.0621595	total: 49ms	remaining: 8.11s
3:	learn: 0.0548402	total: 49.9ms	remaining: 6.19s
4:	learn: 0.0370204	total: 50.4ms	remaining: 4.99s
5:	learn: 0.0185338	total: 51.2ms	remaining: 4.21s
6:	learn: 0.0023820	total: 51.8ms	remaining: 3.65s
7:	learn: -0.0131330	total: 52.2ms	remaining: 3.21s
8:	learn: -0.0271088	total: 52.7ms	remaining: 2.88s
9:	learn: -0.0411166	total: 53.2ms	remaining: 2.61s
10:	learn: -0.0544373	total: 53.6ms	remaining: 2.38s
11:	learn: -0.0586728	total: 54.3ms	remaining: 2.21s
12:	learn: -0.0727161	total: 55ms	remaining: 2.06s
13:	learn: -0.0825138	total: 55.5ms	remaining: 1.93s
14:	learn: -0.0939279	total: 56.1ms	remaining: 1.81s
15:	learn: -0.1022889	total: 56.5ms	remaining: 1.71s
16:	learn: -0.1139086	total: 57ms	remaining: 1.62s
17:	learn: -0.1220989	total: 57.5ms	remaining: 1.54s
18:	learn: -0.1348636	total: 58.1ms	remaining: 1.47s
19:	learn: -0.139777

366:	learn: -0.8533085	total: 235ms	remaining: 85.3ms
367:	learn: -0.8536114	total: 236ms	remaining: 84.7ms
368:	learn: -0.8537597	total: 237ms	remaining: 84ms
369:	learn: -0.8538549	total: 237ms	remaining: 83.3ms
370:	learn: -0.8540950	total: 238ms	remaining: 82.6ms
371:	learn: -0.8543767	total: 238ms	remaining: 82ms
372:	learn: -0.8546475	total: 239ms	remaining: 81.3ms
373:	learn: -0.8548140	total: 239ms	remaining: 80.6ms
374:	learn: -0.8549613	total: 240ms	remaining: 79.9ms
375:	learn: -0.8556378	total: 240ms	remaining: 79.2ms
376:	learn: -0.8559866	total: 241ms	remaining: 78.5ms
377:	learn: -0.8562025	total: 241ms	remaining: 77.8ms
378:	learn: -0.8565572	total: 242ms	remaining: 77.2ms
379:	learn: -0.8570891	total: 242ms	remaining: 76.5ms
380:	learn: -0.8571429	total: 243ms	remaining: 75.8ms
381:	learn: -0.8573466	total: 243ms	remaining: 75.1ms
382:	learn: -0.8576566	total: 244ms	remaining: 74.4ms
383:	learn: -0.8578483	total: 244ms	remaining: 73.8ms
384:	learn: -0.8580088	total: 24

train loss: -0.8354262113571167
val loss: -0.8248106241226196
train loss: -0.8527163863182068
val loss: -0.8443257212638855
train loss: -0.8688827753067017
val loss: -0.8631647825241089
train loss: -0.8826000094413757
val loss: -0.8798490762710571
train loss: -0.8932698369026184
val loss: -0.8935533761978149
train loss: -0.9010194540023804
val loss: -0.904115617275238
train loss: -0.9065582752227783
val loss: -0.9119405746459961
train loss: -0.9109293818473816
val loss: -0.9177980422973633
train loss: -0.9152336120605469
val loss: -0.9225893020629883
train loss: -0.9203722476959229
val loss: -0.9271175861358643
train loss: -0.9268805384635925
val loss: -0.9319059252738953
train loss: -0.9348483681678772
val loss: -0.9371142387390137
train loss: -0.9439576268196106
val loss: -0.9425539970397949
train loss: -0.9535970091819763
val loss: -0.9477899670600891
train loss: -0.9630493521690369
val loss: -0.9523224830627441
train loss: -0.971698522567749
val loss: -0.9557737708091736
train loss

TreeFlowWithoutShallow(embedding_size=2000,
                       flow_model=<src.probabilistic_flow_boosting.tfboost.flow.flow.ContinuousNormalizingFlow object at 0x7fc78f4aa748>,
                       tree_model=<src.probabilistic_flow_boosting.tfboost.tree.ecatboost.EmbeddableCatBoostPriorNormal object at 0x7fc78f4aa6a0>)

In [10]:
# Width of the context vector fed to the flow.
print(f"{context_dim}")

2000


In [11]:
%time calculate_nll(treeflow, x_train, y_train, batch_size = 1024)

CPU times: user 3.84 s, sys: 116 ms, total: 3.96 s
Wall time: 709 ms


-1.1460997

In [12]:
%time calculate_nll(treeflow, x_test, y_test, batch_size = 1024)

CPU times: user 3.17 s, sys: 76.2 ms, total: 3.25 s
Wall time: 286 ms


-0.9361255