In [1]:
from typing import Any, Dict, List, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import uncertainty_toolbox as uct

from sklearn.model_selection import train_test_split, GridSearchCV

from src.probabilistic_flow_boosting.extras.datasets.uci_dataset import UCIDataSet
from src.probabilistic_flow_boosting.pipelines.modeling.utils import setup_random_seed
from src.probabilistic_flow_boosting.tfboost.softtreeflow import SoftTreeFlow

pd.set_option('display.float_format', lambda x: '%.5f' % x)

In [2]:
RANDOM_SEED = 42

setup_random_seed(RANDOM_SEED)

In [3]:
x_train = UCIDataSet(
    filepath_data = "data/01_raw/UCI/wine-quality-red/data.txt",
    filepath_index_columns = "data/01_raw/UCI/wine-quality-red/index_features.txt",
    filepath_index_rows = "data/01_raw/UCI/wine-quality-red/index_train_1.txt"
).load()
y_train = UCIDataSet(
    filepath_data = "data/01_raw/UCI/wine-quality-red/data.txt",
    filepath_index_columns = "data/01_raw/UCI/wine-quality-red/index_target.txt",
    filepath_index_rows = "data/01_raw/UCI/wine-quality-red/index_train_1.txt"
).load()

x_test = UCIDataSet(
    filepath_data = "data/01_raw/UCI/wine-quality-red/data.txt",
    filepath_index_columns = "data/01_raw/UCI/wine-quality-red/index_features.txt",
    filepath_index_rows = "data/01_raw/UCI/wine-quality-red/index_test_1.txt"
).load()
y_test = UCIDataSet(
    filepath_data = "data/01_raw/UCI/wine-quality-red/data.txt",
    filepath_index_columns = "data/01_raw/UCI/wine-quality-red/index_target.txt",
    filepath_index_rows = "data/01_raw/UCI/wine-quality-red/index_test_1.txt"
).load()

In [4]:
def softtreeflow_logprob(model: SoftTreeFlow, X: torch.Tensor, y: torch.Tensor):
    model.eval()
    
    with torch.no_grad():
        result = model.log_prob(X, y).mean().item()
        
    model.train()
    return result


def fit_softtreeflow(
    x_train: np.ndarray, 
    y_train: np.ndarray, 
    param_grid: Dict[str, List[Any]],
    cv: int = 3,
    n_epochs: int = 200, 
    patience: int = 50,
    batch_size: int = 128,
    random_state: int = 42,
    test_size: float = 0.2
):  
    model = SoftTreeFlow(
        input_dim=x_train.shape[1],
        output_dim=y_train.shape[1],
    )
    model.train()
    
    x_tr, x_val, y_tr, y_val = train_test_split(x_train, y_train, test_size = test_size, random_state=random_state)
    
    grid = GridSearchCV(
        estimator=model,
        scoring=softtreeflow_logprob,
        cv=cv,
        param_grid=param_grid,
        refit=True,
        return_train_score=True
    )

    grid.fit(x_tr, y_tr, X_val=x_val, y_val=y_val, n_epochs=1)
    return grid.best_estimator_

In [5]:
param_grid = {
    "tree_depth": [2, 3, 4, 5]
}


model = fit_softtreeflow(
    x_train=x_train.values, 
    y_train=y_train.values,
    param_grid=param_grid
)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.39s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.35s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.31s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.32s/

In [8]:
def univariate_probabilistic_regression_report(estimator: SoftTreeFlow, X: np.ndarray, y: np.ndarray, **kwargs) -> Dict[str, float]:
    y_pred: np.ndarray = estimator.predict(X, **kwargs)
    
    metrics = {
        "nll": -estimator.log_prob(X, y).mean(),
        "crps": estimator.crps(X, y),
    }
    
    y: np.ndarray = y.squeeze()
    metrics.update(uct.get_all_accuracy_metrics(y_pred=y_pred, y_true=y, verbose=False))
    return metrics

In [9]:
univariate_probabilistic_regression_report(model, x_train.values, y_train.values)

{'log_prob': array([[-9.092747 ],
        [-1.7988298],
        [-4.5356703],
        ...,
        [-0.6083021],
        [-4.5737314],
        [-0.6005993]], dtype=float32),
 'crps': None,
 'mae': 1.5187798268766848,
 'rmse': 1.7030753923928892,
 'mdae': 1.1693744659423828,
 'marpd': 24.633862327780932,
 'r2': -3.425165169008465,
 'corr': 0.01244446970312382}

In [None]:
univariate_probabilistic_regression_report(model, x_test.values, y_test.values)

In [None]:
with torch.no_grad():
    model.eval()

    logprob_train = - model.log_prob(x_tr, y_tr, batch_size=2000).mean()
    logprob_val = - model.log_prob(x_val, y_val, batch_size=2000).mean()
    logprob_test = - model.log_prob(x_test, y_test, batch_size=2000).mean()

    print(logprob_train.item(), logprob_val.item(), logprob_test.item())

In [None]:
with torch.no_grad():
    model.eval()

    logprob_train = - model.log_prob(x_tr, y_tr, batch_size=2000).mean()
    logprob_val = - model.log_prob(x_val, y_val, batch_size=2000).mean()
    logprob_test = - model.log_prob(x_test, y_test, batch_size=2000).mean()

    print(logprob_train.item(), logprob_val.item(), logprob_test.item())

In [None]:
y_test_samples = model.sample(x_test, num_samples = 1000, batch_size = 2000)
y_test_samples = y_test_samples.detach().numpy().squeeze()

In [None]:
y_test_samples.shape

In [None]:
paths = model.predict_tree_path(x_test)
paths = paths.detach().numpy()

In [None]:
for i in range(30):
    print(f"Sample {i}, paths probability:")
    print(pd.DataFrame(paths[i, :]).sort_values(0, ascending=False).head(5))
    plt.axvline(x=y_test.detach().numpy()[i, :], color='r', label='True value')

    ## TreeFlow
    sns.kdeplot(y_test_samples[i, :], color='blue', label='TreeFlow')

    plt.xlim([0, 10])
    plt.legend()
    plt.show()
    plt.close()