In [1]:
import warnings

# Install a blanket "ignore" filter before the remaining imports execute, so
# warnings raised while those libraries are being imported are silenced too.
# NOTE(review): no category/module is given, so every warning category is
# hidden for the whole session, including deprecation and numerical warnings
# that may be relevant when debugging the training cell further down.
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import ngboost as ng
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch import optim
from tqdm import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.datasets import fetch_california_housing

from models.flow import build_model
from tfboost.flow import ContinuousNormalizingFlow
from tfboost.tree.engboost import EmbeddableNGBoost, EmbeddableNGBoost2, EmbeddableNGBoostDecisionPath
from tfboost.tree.ecatboost import EmbeddableCatBoost
from tfboost.tfboost import TreeFlowBoost

In [2]:
# Both splits share one on-disk layout: tab-separated values with no header
# row, so pandas labels the columns with integers starting at 0.
read_options = dict(sep='\t', header=None)

train = pd.read_csv('data/catboost/adult/train', **read_options)
test = pd.read_csv('data/catboost/adult/test', **read_options)

In [3]:
def _features_and_target(frame):
    """Split a raw frame into ``(features, target)`` NumPy arrays.

    The column labelled 1 is used as the target and every column labelled 2
    or higher as a feature; the column labelled 0 is not used.
    """
    return frame.loc[:, 2:].values, frame[1].values


x_train, y_train = _features_and_target(train)
x_test, y_test = _features_and_target(test)

In [4]:
# Positions of the categorical columns handed to the CatBoost wrapper.
# NOTE(review): presumably these index into x_train (i.e. after the first two
# raw columns were dropped) rather than into the raw file -- confirm.
cat_features = [0, 2, *range(4, 9), 12]

In [5]:
# One width is used both as the flow's context_dim and as TreeFlowBoost's
# embedding_size.
# NOTE(review): presumably these two must agree because the tree embedding is
# what conditions the flow -- TODO confirm against TreeFlowBoost.
embedding_dim = 100

# Conditional flow over a one-dimensional target.
flow_network = build_model(
    input_dim=1,
    hidden_dims=(80, 40),
    context_dim=embedding_dim,
    conditional=True,
)
flow = ContinuousNormalizingFlow(flow_network)

# Shallow CatBoost model that is told which columns are categorical.
tree = EmbeddableCatBoost(max_depth=3, cat_features=cat_features)

# Combined model: the tree model plus the conditional flow.
tfb = TreeFlowBoost(
    flow_model=flow,
    tree_model=tree,
    embedding_size=embedding_dim,
)

In [6]:
# Epoch budget forwarded to TreeFlowBoost.fit.
# NOTE(review): the output saved with this cell stops at iteration 9 of 200
# with a RuntimeError whose message is empty, so the fit did not complete in
# the recorded run -- re-run and capture the full traceback before relying on
# any cell below.
N_EPOCHS = 200

tfb.fit(x_train, y_train, n_epochs=N_EPOCHS)

Learning rate set to 0.072029
0:	learn: 13.3718057	total: 53.9ms	remaining: 53.9s
1:	learn: 13.0591973	total: 58.3ms	remaining: 29.1s
2:	learn: 12.7843169	total: 62.7ms	remaining: 20.8s
3:	learn: 12.5297040	total: 68.2ms	remaining: 17s
4:	learn: 12.3060430	total: 72.5ms	remaining: 14.4s
5:	learn: 12.1100417	total: 75.5ms	remaining: 12.5s
6:	learn: 11.9339570	total: 78.3ms	remaining: 11.1s
7:	learn: 11.7814654	total: 81.1ms	remaining: 10.1s
8:	learn: 11.6301723	total: 84.8ms	remaining: 9.33s
9:	learn: 11.4982404	total: 88.2ms	remaining: 8.73s
10:	learn: 11.3827204	total: 91.7ms	remaining: 8.24s
11:	learn: 11.2817363	total: 95ms	remaining: 7.82s
12:	learn: 11.1945273	total: 97.6ms	remaining: 7.41s
13:	learn: 11.1178831	total: 100ms	remaining: 7.08s
14:	learn: 11.0444158	total: 103ms	remaining: 6.76s
15:	learn: 10.9694920	total: 106ms	remaining: 6.55s
16:	learn: 10.9035157	total: 110ms	remaining: 6.37s
17:	learn: 10.8428278	total: 114ms	remaining: 6.21s
18:	learn: 10.7877635	total: 117ms	

182:	learn: 9.9220053	total: 779ms	remaining: 3.48s
183:	learn: 9.9213978	total: 783ms	remaining: 3.47s
184:	learn: 9.9208337	total: 787ms	remaining: 3.47s
185:	learn: 9.9204366	total: 796ms	remaining: 3.48s
186:	learn: 9.9176618	total: 802ms	remaining: 3.49s
187:	learn: 9.9167612	total: 807ms	remaining: 3.48s
188:	learn: 9.9157873	total: 812ms	remaining: 3.48s
189:	learn: 9.9151436	total: 816ms	remaining: 3.48s
190:	learn: 9.9143503	total: 820ms	remaining: 3.47s
191:	learn: 9.9138578	total: 824ms	remaining: 3.47s
192:	learn: 9.9133003	total: 829ms	remaining: 3.46s
193:	learn: 9.9127300	total: 833ms	remaining: 3.46s
194:	learn: 9.9119513	total: 838ms	remaining: 3.46s
195:	learn: 9.9114774	total: 842ms	remaining: 3.46s
196:	learn: 9.9109674	total: 847ms	remaining: 3.45s
197:	learn: 9.9098358	total: 850ms	remaining: 3.44s
198:	learn: 9.9091942	total: 855ms	remaining: 3.44s
199:	learn: 9.9085088	total: 859ms	remaining: 3.44s
200:	learn: 9.9080626	total: 862ms	remaining: 3.43s
201:	learn: 

352:	learn: 9.8205318	total: 1.56s	remaining: 2.87s
353:	learn: 9.8204226	total: 1.57s	remaining: 2.86s
354:	learn: 9.8199123	total: 1.57s	remaining: 2.86s
355:	learn: 9.8185144	total: 1.58s	remaining: 2.86s
356:	learn: 9.8183074	total: 1.59s	remaining: 2.86s
357:	learn: 9.8178992	total: 1.59s	remaining: 2.85s
358:	learn: 9.8177282	total: 1.59s	remaining: 2.85s
359:	learn: 9.8176825	total: 1.6s	remaining: 2.85s
360:	learn: 9.8170927	total: 1.61s	remaining: 2.84s
361:	learn: 9.8158915	total: 1.61s	remaining: 2.84s
362:	learn: 9.8153845	total: 1.61s	remaining: 2.83s
363:	learn: 9.8149244	total: 1.62s	remaining: 2.83s
364:	learn: 9.8144596	total: 1.63s	remaining: 2.83s
365:	learn: 9.8142553	total: 1.64s	remaining: 2.84s
366:	learn: 9.8138309	total: 1.64s	remaining: 2.83s
367:	learn: 9.8132422	total: 1.65s	remaining: 2.83s
368:	learn: 9.8128621	total: 1.65s	remaining: 2.83s
369:	learn: 9.8127884	total: 1.66s	remaining: 2.82s
370:	learn: 9.8125031	total: 1.66s	remaining: 2.82s
371:	learn: 9

538:	learn: 9.7520818	total: 2.35s	remaining: 2.01s
539:	learn: 9.7519914	total: 2.35s	remaining: 2.01s
540:	learn: 9.7516445	total: 2.36s	remaining: 2s
541:	learn: 9.7508368	total: 2.37s	remaining: 2s
542:	learn: 9.7503338	total: 2.37s	remaining: 2s
543:	learn: 9.7502040	total: 2.37s	remaining: 1.99s
544:	learn: 9.7499156	total: 2.38s	remaining: 1.99s
545:	learn: 9.7496332	total: 2.38s	remaining: 1.98s
546:	learn: 9.7494696	total: 2.38s	remaining: 1.98s
547:	learn: 9.7490644	total: 2.39s	remaining: 1.97s
548:	learn: 9.7488161	total: 2.39s	remaining: 1.97s
549:	learn: 9.7487908	total: 2.4s	remaining: 1.96s
550:	learn: 9.7487715	total: 2.4s	remaining: 1.96s
551:	learn: 9.7487688	total: 2.4s	remaining: 1.95s
552:	learn: 9.7482377	total: 2.41s	remaining: 1.95s
553:	learn: 9.7478952	total: 2.41s	remaining: 1.94s
554:	learn: 9.7476968	total: 2.42s	remaining: 1.94s
555:	learn: 9.7469132	total: 2.42s	remaining: 1.93s
556:	learn: 9.7467012	total: 2.42s	remaining: 1.93s
557:	learn: 9.7462111	to

712:	learn: 9.7109340	total: 3.14s	remaining: 1.26s
713:	learn: 9.7109322	total: 3.14s	remaining: 1.26s
714:	learn: 9.7108366	total: 3.15s	remaining: 1.25s
715:	learn: 9.7107134	total: 3.15s	remaining: 1.25s
716:	learn: 9.7106597	total: 3.15s	remaining: 1.25s
717:	learn: 9.7099956	total: 3.16s	remaining: 1.24s
718:	learn: 9.7099135	total: 3.16s	remaining: 1.24s
719:	learn: 9.7099116	total: 3.17s	remaining: 1.23s
720:	learn: 9.7096179	total: 3.17s	remaining: 1.23s
721:	learn: 9.7093444	total: 3.18s	remaining: 1.22s
722:	learn: 9.7093001	total: 3.19s	remaining: 1.22s
723:	learn: 9.7092525	total: 3.19s	remaining: 1.22s
724:	learn: 9.7092235	total: 3.19s	remaining: 1.21s
725:	learn: 9.7090194	total: 3.2s	remaining: 1.21s
726:	learn: 9.7084508	total: 3.2s	remaining: 1.2s
727:	learn: 9.7083774	total: 3.21s	remaining: 1.2s
728:	learn: 9.7083064	total: 3.22s	remaining: 1.2s
729:	learn: 9.7080821	total: 3.22s	remaining: 1.19s
730:	learn: 9.7076323	total: 3.23s	remaining: 1.19s
731:	learn: 9.707

871:	learn: 9.6819853	total: 3.93s	remaining: 576ms
872:	learn: 9.6819032	total: 3.93s	remaining: 572ms
873:	learn: 9.6816628	total: 3.93s	remaining: 567ms
874:	learn: 9.6815814	total: 3.94s	remaining: 563ms
875:	learn: 9.6812124	total: 3.94s	remaining: 559ms
876:	learn: 9.6811739	total: 3.95s	remaining: 554ms
877:	learn: 9.6811217	total: 3.95s	remaining: 549ms
878:	learn: 9.6811188	total: 3.96s	remaining: 545ms
879:	learn: 9.6809680	total: 3.96s	remaining: 540ms
880:	learn: 9.6803627	total: 3.97s	remaining: 536ms
881:	learn: 9.6801200	total: 3.97s	remaining: 531ms
882:	learn: 9.6800627	total: 3.97s	remaining: 527ms
883:	learn: 9.6799711	total: 3.98s	remaining: 522ms
884:	learn: 9.6797752	total: 3.98s	remaining: 517ms
885:	learn: 9.6796200	total: 3.99s	remaining: 513ms
886:	learn: 9.6795842	total: 3.99s	remaining: 508ms
887:	learn: 9.6794592	total: 4s	remaining: 504ms
888:	learn: 9.6794454	total: 4s	remaining: 499ms
889:	learn: 9.6793670	total: 4s	remaining: 495ms
890:	learn: 9.6793088

7.077732086181641:   4%|▍         | 9/200 [02:12<46:50, 14.71s/it]


RuntimeError: 

In [None]:
def _show_mse(targets, predictions):
    """Print the mean squared error between ``targets`` and ``predictions``."""
    print(mean_squared_error(targets, predictions))


# Forwarded as num_samples to tfb.predict.
# NOTE(review): assumes tfb.predict returns one value per row so that it can
# be scored directly against the targets -- confirm.
n_draws = 100

# For each split: score the bare tree model first, then the full model.
print("Train")
y_hat_train_tree = tfb.tree_model.predict(x_train)
_show_mse(y_train, y_hat_train_tree)

y_hat_train_tfb = tfb.predict(x_train, num_samples=n_draws)
_show_mse(y_train, y_hat_train_tfb)

print("Test")
y_hat_test_tree = tfb.tree_model.predict(x_test)
_show_mse(y_test, y_hat_test_tree)

y_hat_test_tfb = tfb.predict(x_test, num_samples=n_draws)
_show_mse(y_test, y_hat_test_tfb)