In [1]:
import torch

from sklearn.model_selection import train_test_split

from src.probabilistic_flow_boosting.extras.datasets.uci_dataset import UCIDataSet
from src.probabilistic_flow_boosting.pipelines.modeling.utils import setup_random_seed
from src.probabilistic_flow_boosting.tfboost.softtreeflow import SoftTreeFlow

In [2]:
# Notebook-wide configuration constants.
RANDOM_SEED = 42  # handed to setup_random_seed below and to train_test_split
TRAIN = False  # NOTE(review): not referenced by any cell visible in this notebook
MODEL_FILEPATH = "treeflow_wine.model"  # NOTE(review): not referenced by any visible cell

# Presumably seeds the random number generators used later on (project helper;
# its implementation is not shown here). Runs before any split or model creation.
setup_random_seed(RANDOM_SEED)

In [3]:
# Directory that holds the data file and all index files used by this notebook.
UCI_WINE_DIR = "data/01_raw/UCI/wine-quality-red"


def load_wine_split(columns_index, rows_index):
    """Load one selection of the wine-quality data through ``UCIDataSet``.

    Replaces four copy-pasted ``UCIDataSet(...)`` constructions that differed
    only in the two index file names.

    Args:
        columns_index: Name of the column-index file inside ``UCI_WINE_DIR``.
        rows_index: Name of the row-index file inside ``UCI_WINE_DIR``.

    Returns:
        Whatever ``UCIDataSet(...).load()`` returns for that selection.
    """
    return UCIDataSet(
        filepath_data=UCI_WINE_DIR + "/data.txt",
        filepath_index_columns=UCI_WINE_DIR + "/" + columns_index,
        filepath_index_rows=UCI_WINE_DIR + "/" + rows_index,
    ).load()


# Training rows come from index_train_1.txt: features first, then the target.
x_train = load_wine_split("index_features.txt", "index_train_1.txt")
y_train = load_wine_split("index_target.txt", "index_train_1.txt")

# Test rows come from index_test_1.txt, with the same column selections.
x_test = load_wine_split("index_features.txt", "index_test_1.txt")
y_test = load_wine_split("index_target.txt", "index_test_1.txt")

In [4]:
# Hold out 20% of the training rows for validation. The split is seeded with
# RANDOM_SEED so it is the same on every run.
x_tr, x_val, y_tr, y_val = train_test_split(
    x_train,
    y_train,
    test_size=0.2,
    random_state=RANDOM_SEED,
)

In [5]:
def to_default_float_tensor(data):
    """Convert a pandas-like object to a tensor of torch's default float dtype.

    Uses ``torch.tensor`` with an explicit dtype instead of the legacy
    ``torch.Tensor(...)`` constructor; the resulting dtype is the same
    (the default floating-point dtype).

    Args:
        data: An object exposing its array through ``.values``, or a tensor
            that was already converted by an earlier run of this cell.

    Returns:
        A ``torch.Tensor`` with dtype ``torch.get_default_dtype()``.
    """
    target_dtype = torch.get_default_dtype()
    if isinstance(data, torch.Tensor):
        # This cell rebinds the names it reads, so on a re-run the inputs are
        # already tensors. Pass them through instead of failing on ``.values``.
        return data.to(target_dtype)
    return torch.tensor(data.values, dtype=target_dtype)


x_tr = to_default_float_tensor(x_tr)
x_val = to_default_float_tensor(x_val)
x_test = to_default_float_tensor(x_test)

y_tr = to_default_float_tensor(y_tr)
y_val = to_default_float_tensor(y_val)
y_test = to_default_float_tensor(y_test)

In [6]:
# Size the model from the data: the second dimension of the feature tensor
# gives the input width, the second dimension of the target tensor the output width.
n_inputs = x_tr.shape[1]
n_outputs = y_tr.shape[1]

model = SoftTreeFlow(input_dim=n_inputs, output_dim=n_outputs, tree_depth=10)

In [7]:
# Fit on the training split; the validation split is passed alongside it.
# The call stays the cell's last expression, so its return value is displayed.
model.fit(
    x_tr,
    y_tr,
    x_val,
    y_val,
    n_epochs=100,
)

0
Loss 17.400732040405273
Loss 7.352973461151123
Loss 3.408940315246582
Loss 1.1664212942123413
Loss 1.3002073764801025
Loss 1.8551942110061646
Loss 2.3327219486236572
Loss 3.124753475189209
Loss 2.9238228797912598
Loss validation 3.17427921295166
1
Loss 2.6553080081939697
Loss 2.308692216873169
Loss 1.7537710666656494
Loss 1.7112067937850952
Loss 1.6154204607009888
Loss 1.4688996076583862
Loss 1.1862447261810303
Loss 1.2359182834625244
Loss 1.2193986177444458
Loss validation 1.2884440422058105
2
Loss 1.3096102476119995
Loss 1.2146120071411133
Loss 1.2533512115478516
Loss 1.2171361446380615
Loss 1.3942463397979736
Loss 1.4009062051773071
Loss 1.3824520111083984
Loss 1.2279256582260132
Loss 1.3114547729492188
Loss validation 1.3340744972229004
3
Loss 1.1459684371948242
Loss 1.2380015850067139
Loss 1.3452630043029785
Loss 1.278502106666565
Loss 1.2614468336105347
Loss 1.128752589225769
Loss 1.3482824563980103
Loss 1.2675280570983887
Loss 1.203536868095398
Loss validation 1.29790639877319

SoftTreeFlow(device=device(type='cpu'), input_dim=11, output_dim=1,
             tree_depth=10)

In [12]:
def _negated_mean_log_prob(features, targets):
    """Return the negated mean of ``model.log_prob(features, targets)``."""
    return -model.log_prob(features, targets).mean()


# NOTE: despite the ``logprob_`` prefix, each value below is the *negated*
# mean log-probability of a split. The names are kept because later cells
# display them.
logprob_train = _negated_mean_log_prob(x_tr, y_tr)
logprob_val = _negated_mean_log_prob(x_val, y_val)
logprob_test = _negated_mean_log_prob(x_test, y_test)

In [13]:
# Show the negated mean log-probability on the training split (sign is flipped where it is computed).
logprob_train

-0.17518461

In [14]:
# Show the negated mean log-probability on the validation split (sign is flipped where it is computed).
logprob_val

0.078504

In [15]:
# Show the negated mean log-probability on the test split (sign is flipped where it is computed).
logprob_test

-0.17684701