In [1]:
import numpy as np 

import torch

from pytorch_tabnet.tab_model import TabNetRegressor
from pytorch_tabnet.metrics import Metric
from scipy.stats import spearmanr

In [2]:
class spearmanr_metric(Metric):
    """
    Custom TabNet evaluation metric: Spearman rank correlation.

    Exposed to TabNet under the name ``"spearmanr"`` (the training log shows it
    as ``val_0_spearmanr``). ``_maximize`` is True, so larger correlations are
    treated as improvements.
    """

    def __init__(self):
        # Larger correlation means a better model.
        self._maximize = True
        # Key used for this metric in TabNet's logs and history.
        self._name = "spearmanr"

    def __call__(self, y_true, y_score):
        """
        Score predictions by their Spearman rank correlation with the targets.

        Parameters
        ----------
        y_true: np.ndarray
            Target matrix or vector
        y_score: np.ndarray
            Score matrix or vector

        Returns
        -------
            float
            Spearman r of predictions vs targets (a scalar when each input is a
            single vector/column; scipy returns a matrix for several columns).
        """
        correlation = spearmanr(y_score, y_true)
        return correlation.statistic

def spearmanr_loss_fn(y_pred, y_true, eps=1e-8):
    """
    Differentiable correlation loss: ``1 - pearson_r(y_pred, y_true)``.

    Spearman's rank correlation is not differentiable (ranking has zero gradient
    almost everywhere), so this loss optimises the Pearson correlation of the
    raw values as a surrogate. The ``spearmanr`` in the name refers to the
    metric it is meant to improve, not to the exact quantity computed here.

    Parameters
    ----------
    y_pred: torch.Tensor
        Model outputs for one batch.
    y_true: torch.Tensor
        Targets for the same batch; must contain the same number of elements
        as ``y_pred``. Both are flattened, so ``(n, 1)`` vs ``(n,)`` is handled
        correctly instead of silently broadcasting to ``(n, n)``.
    eps: float
        Added to the denominator so that a batch with zero variance in either
        tensor (e.g. a single-sample final batch, or constant predictions)
        gives a finite loss of 1.0 instead of NaN.

    Returns
    -------
    torch.Tensor
        Scalar, approximately in [0, 2]: 0 for perfect positive correlation,
        1 for no correlation, 2 for perfect negative correlation.
    """
    # Flattening matches the original whole-tensor mean/sum/norm for same-shape
    # inputs, and rejects mismatched sizes instead of broadcasting them.
    pred_flat = y_pred.reshape(-1)
    true_flat = y_true.reshape(-1)

    pred_centered = pred_flat - torch.mean(pred_flat)
    true_centered = true_flat - torch.mean(true_flat)

    num = torch.sum(pred_centered * true_centered)
    den = torch.norm(pred_centered) * torch.norm(true_centered)

    return 1 - num / (den + eps)

In [3]:
# Train one TabNet regressor incrementally over the on-disk data chunks.
# Files are read as train_{X,Y}_<k>.npy / val_{X,Y}_<k>.npy for
# k in range(1, n_chunks), i.e. k = 1..n_chunks-1
# (NOTE(review): confirm there is intentionally no chunk 0).
# Each train chunk is fitted once against every validation chunk. The first fit
# passes warm_start=False; every later fit passes warm_start=True so it
# continues from the current weights.
n_chunks = 5

clf = TabNetRegressor()

warm_start = False

for train_chunk_num in range(1, n_chunks):
    train_X = np.load(f"train_X_{train_chunk_num}.npy")
    train_Y = np.load(f"train_Y_{train_chunk_num}.npy")
    # Collapse all non-sample axes into one feature axis and make the target a
    # single column. This depends only on the train chunk, so do it once here.
    train_X_2d = train_X.reshape(train_X.shape[0], -1)
    train_Y_2d = train_Y.reshape(-1, 1)

    for val_chunk_num in range(1, n_chunks):

        print("TRAIN_NUM: ", train_chunk_num, ", VAL_NUM: ", val_chunk_num)

        val_X = np.load(f"val_X_{val_chunk_num}.npy")
        val_Y = np.load(f"val_Y_{val_chunk_num}.npy")

        clf.fit(
            train_X_2d, train_Y_2d,
            eval_set=[(val_X.reshape(val_X.shape[0], -1), val_Y.reshape(-1, 1))],
            # spearmanr_metric registers itself as "spearmanr"; also listing the
            # string only evaluated the same metric twice per epoch.
            eval_metric=[spearmanr_metric],
            loss_fn=spearmanr_loss_fn,
            batch_size=1024, virtual_batch_size=128,
            max_epochs=50, patience=10,
            drop_last=False, warm_start=warm_start,
        )

        warm_start = True

# save tabnet model
saving_path_name = "./tabnet_model_test_1"
saved_filepath = clf.save_model(saving_path_name)



TRAIN_NUM:  1 , VAL_NUM:  1
epoch 0  | loss: 0.98693 | val_0_spearmanr: 0.02345 |  0:00:03s
epoch 1  | loss: 0.99182 | val_0_spearmanr: 0.00326 |  0:00:06s
epoch 2  | loss: 0.97326 | val_0_spearmanr: 0.01664 |  0:00:09s
epoch 3  | loss: 0.96265 | val_0_spearmanr: -0.00934|  0:00:13s
epoch 4  | loss: 0.95075 | val_0_spearmanr: -0.01168|  0:00:16s
epoch 5  | loss: 0.94479 | val_0_spearmanr: 0.0433  |  0:00:19s
epoch 6  | loss: 0.9507  | val_0_spearmanr: 0.02598 |  0:00:22s
epoch 7  | loss: 0.9416  | val_0_spearmanr: -0.02945|  0:00:25s
epoch 8  | loss: 0.97185 | val_0_spearmanr: -0.03285|  0:00:28s
epoch 9  | loss: 0.93883 | val_0_spearmanr: -0.00203|  0:00:31s
epoch 10 | loss: 0.81317 | val_0_spearmanr: -0.01054|  0:00:34s
epoch 11 | loss: 0.79402 | val_0_spearmanr: -0.01527|  0:00:37s
epoch 12 | loss: 0.81463 | val_0_spearmanr: -0.00955|  0:00:40s
epoch 13 | loss: 0.78666 | val_0_spearmanr: 0.02664 |  0:00:43s
epoch 14 | loss: 0.7659  | val_0_spearmanr: -0.03208|  0:00:47s
epoch 15 | l



TRAIN_NUM:  1 , VAL_NUM:  2
epoch 0  | loss: 0.97367 | val_0_spearmanr: 0.03336 |  0:00:03s
epoch 1  | loss: 0.95495 | val_0_spearmanr: 0.00431 |  0:00:06s
epoch 2  | loss: 0.94104 | val_0_spearmanr: 0.06734 |  0:00:09s
epoch 3  | loss: 0.87406 | val_0_spearmanr: 0.0699  |  0:00:12s
epoch 4  | loss: 0.94377 | val_0_spearmanr: 0.00573 |  0:00:15s
epoch 5  | loss: 0.96111 | val_0_spearmanr: 0.04524 |  0:00:18s
epoch 6  | loss: 0.91848 | val_0_spearmanr: 0.05543 |  0:00:21s
epoch 7  | loss: 0.89217 | val_0_spearmanr: 0.01172 |  0:00:24s
epoch 8  | loss: 0.80836 | val_0_spearmanr: -0.0089 |  0:00:28s
epoch 9  | loss: 0.79394 | val_0_spearmanr: 0.02695 |  0:00:31s
epoch 10 | loss: 0.77219 | val_0_spearmanr: 0.07131 |  0:00:34s
epoch 11 | loss: 0.7912  | val_0_spearmanr: 0.06087 |  0:00:37s
epoch 12 | loss: 0.7732  | val_0_spearmanr: 0.08622 |  0:00:40s
epoch 13 | loss: 0.78778 | val_0_spearmanr: 0.08745 |  0:00:43s
epoch 14 | loss: 0.78226 | val_0_spearmanr: 0.0529  |  0:00:46s
epoch 15 | l



TRAIN_NUM:  1 , VAL_NUM:  3
epoch 0  | loss: 0.65058 | val_0_spearmanr: 0.1914  |  0:00:03s
epoch 1  | loss: 0.64148 | val_0_spearmanr: 0.17395 |  0:00:06s
epoch 2  | loss: 0.63983 | val_0_spearmanr: 0.14325 |  0:00:09s
epoch 3  | loss: 0.65433 | val_0_spearmanr: 0.19141 |  0:00:12s
epoch 4  | loss: 0.6394  | val_0_spearmanr: 0.19171 |  0:00:15s
epoch 5  | loss: 0.63224 | val_0_spearmanr: 0.19384 |  0:00:18s
epoch 6  | loss: 0.63573 | val_0_spearmanr: 0.16959 |  0:00:21s
epoch 7  | loss: 0.6487  | val_0_spearmanr: 0.16903 |  0:00:25s
epoch 8  | loss: 0.62996 | val_0_spearmanr: 0.1746  |  0:00:28s
epoch 9  | loss: 0.65031 | val_0_spearmanr: 0.19464 |  0:00:31s
epoch 10 | loss: 0.60744 | val_0_spearmanr: 0.16772 |  0:00:34s
epoch 11 | loss: 0.60088 | val_0_spearmanr: 0.15015 |  0:00:37s
epoch 12 | loss: 0.59504 | val_0_spearmanr: 0.13821 |  0:00:40s
epoch 13 | loss: 0.58319 | val_0_spearmanr: 0.18087 |  0:00:43s
epoch 14 | loss: 0.56603 | val_0_spearmanr: 0.1674  |  0:00:47s
epoch 15 | l



TRAIN_NUM:  1 , VAL_NUM:  4
epoch 0  | loss: 0.61998 | val_0_spearmanr: 0.21046 |  0:00:03s
epoch 1  | loss: 0.63011 | val_0_spearmanr: 0.19511 |  0:00:06s
epoch 2  | loss: 0.64576 | val_0_spearmanr: 0.20095 |  0:00:10s
epoch 3  | loss: 0.64626 | val_0_spearmanr: 0.21577 |  0:00:13s
epoch 4  | loss: 0.6255  | val_0_spearmanr: 0.22606 |  0:00:17s
epoch 5  | loss: 0.61112 | val_0_spearmanr: 0.22163 |  0:00:20s
epoch 6  | loss: 0.59243 | val_0_spearmanr: 0.22884 |  0:00:24s
epoch 7  | loss: 0.58452 | val_0_spearmanr: 0.22692 |  0:00:27s
epoch 8  | loss: 0.57635 | val_0_spearmanr: 0.24084 |  0:00:31s
epoch 9  | loss: 0.57254 | val_0_spearmanr: 0.24072 |  0:00:34s
epoch 10 | loss: 0.57956 | val_0_spearmanr: 0.23787 |  0:00:38s
epoch 11 | loss: 0.54006 | val_0_spearmanr: 0.2085  |  0:00:41s
epoch 12 | loss: 0.53685 | val_0_spearmanr: 0.22221 |  0:00:45s
epoch 13 | loss: 0.51602 | val_0_spearmanr: 0.216   |  0:00:48s
epoch 14 | loss: 0.49969 | val_0_spearmanr: 0.24301 |  0:00:52s
epoch 15 | l



TRAIN_NUM:  2 , VAL_NUM:  1
epoch 0  | loss: 0.65601 | val_0_spearmanr: 0.23214 |  0:00:03s
epoch 1  | loss: 0.63857 | val_0_spearmanr: 0.23848 |  0:00:06s
epoch 2  | loss: 0.62706 | val_0_spearmanr: 0.26888 |  0:00:09s
epoch 3  | loss: 0.6172  | val_0_spearmanr: 0.27456 |  0:00:12s
epoch 4  | loss: 0.6039  | val_0_spearmanr: 0.25967 |  0:00:15s
epoch 5  | loss: 0.58414 | val_0_spearmanr: 0.25144 |  0:00:18s
epoch 6  | loss: 0.57484 | val_0_spearmanr: 0.21931 |  0:00:22s
epoch 7  | loss: 0.62591 | val_0_spearmanr: 0.21462 |  0:00:25s
epoch 8  | loss: 0.60153 | val_0_spearmanr: 0.19891 |  0:00:28s
epoch 9  | loss: 0.55793 | val_0_spearmanr: 0.21323 |  0:00:31s
epoch 10 | loss: 0.53581 | val_0_spearmanr: 0.23026 |  0:00:34s
epoch 11 | loss: 0.52384 | val_0_spearmanr: 0.24927 |  0:00:37s
epoch 12 | loss: 0.54101 | val_0_spearmanr: 0.21892 |  0:00:40s
epoch 13 | loss: 0.50822 | val_0_spearmanr: 0.18435 |  0:00:43s

Early stopping occurred at epoch 13 with best_epoch = 3 and best_val_0_spea



TRAIN_NUM:  2 , VAL_NUM:  2
epoch 0  | loss: 0.62383 | val_0_spearmanr: 0.22299 |  0:00:03s
epoch 1  | loss: 0.61812 | val_0_spearmanr: 0.21895 |  0:00:06s
epoch 2  | loss: 0.59333 | val_0_spearmanr: 0.18425 |  0:00:09s
epoch 3  | loss: 0.57126 | val_0_spearmanr: 0.20406 |  0:00:12s
epoch 4  | loss: 0.55669 | val_0_spearmanr: 0.24899 |  0:00:15s
epoch 5  | loss: 0.56642 | val_0_spearmanr: 0.18403 |  0:00:18s
epoch 6  | loss: 0.54473 | val_0_spearmanr: 0.19738 |  0:00:21s
epoch 7  | loss: 0.53134 | val_0_spearmanr: 0.19952 |  0:00:25s
epoch 8  | loss: 0.51744 | val_0_spearmanr: 0.23134 |  0:00:28s
epoch 9  | loss: 0.51687 | val_0_spearmanr: 0.23204 |  0:00:31s
epoch 10 | loss: 0.50004 | val_0_spearmanr: 0.21186 |  0:00:34s
epoch 11 | loss: 0.48613 | val_0_spearmanr: 0.21657 |  0:00:37s
epoch 12 | loss: 0.4798  | val_0_spearmanr: 0.24242 |  0:00:40s
epoch 13 | loss: 0.46695 | val_0_spearmanr: 0.22519 |  0:00:43s
epoch 14 | loss: 0.45861 | val_0_spearmanr: 0.24102 |  0:00:47s

Early stopp



TRAIN_NUM:  2 , VAL_NUM:  3
epoch 0  | loss: 0.57871 | val_0_spearmanr: 0.19094 |  0:00:03s
epoch 1  | loss: 0.55977 | val_0_spearmanr: 0.21115 |  0:00:06s
epoch 2  | loss: 0.55556 | val_0_spearmanr: 0.19487 |  0:00:09s
epoch 3  | loss: 0.52739 | val_0_spearmanr: 0.19201 |  0:00:12s
epoch 4  | loss: 0.54604 | val_0_spearmanr: 0.16154 |  0:00:15s
epoch 5  | loss: 0.54831 | val_0_spearmanr: 0.19401 |  0:00:18s
epoch 6  | loss: 0.50867 | val_0_spearmanr: 0.2031  |  0:00:21s
epoch 7  | loss: 0.50148 | val_0_spearmanr: 0.20023 |  0:00:25s
epoch 8  | loss: 0.47973 | val_0_spearmanr: 0.194   |  0:00:28s
epoch 9  | loss: 0.47104 | val_0_spearmanr: 0.18497 |  0:00:31s
epoch 10 | loss: 0.47284 | val_0_spearmanr: 0.20827 |  0:00:34s
epoch 11 | loss: 0.45239 | val_0_spearmanr: 0.21828 |  0:00:37s
epoch 12 | loss: 0.45207 | val_0_spearmanr: 0.20148 |  0:00:40s
epoch 13 | loss: 0.44771 | val_0_spearmanr: 0.18226 |  0:00:43s
epoch 14 | loss: 0.43168 | val_0_spearmanr: 0.19265 |  0:00:47s
epoch 15 | l



TRAIN_NUM:  2 , VAL_NUM:  4
epoch 0  | loss: 0.4067  | val_0_spearmanr: 0.28778 |  0:00:03s
epoch 1  | loss: 0.39708 | val_0_spearmanr: 0.32051 |  0:00:06s
epoch 2  | loss: 0.39484 | val_0_spearmanr: 0.3124  |  0:00:10s
epoch 3  | loss: 0.37638 | val_0_spearmanr: 0.29391 |  0:00:13s
epoch 4  | loss: 0.38246 | val_0_spearmanr: 0.28642 |  0:00:17s
epoch 5  | loss: 0.36183 | val_0_spearmanr: 0.29944 |  0:00:20s
epoch 6  | loss: 0.35435 | val_0_spearmanr: 0.31987 |  0:00:24s
epoch 7  | loss: 0.34519 | val_0_spearmanr: 0.31657 |  0:00:27s
epoch 8  | loss: 0.33688 | val_0_spearmanr: 0.30955 |  0:00:31s
epoch 9  | loss: 0.32417 | val_0_spearmanr: 0.31279 |  0:00:34s
epoch 10 | loss: 0.31977 | val_0_spearmanr: 0.30146 |  0:00:38s
epoch 11 | loss: 0.30909 | val_0_spearmanr: 0.29455 |  0:00:41s

Early stopping occurred at epoch 11 with best_epoch = 1 and best_val_0_spearmanr = 0.32051




TRAIN_NUM:  3 , VAL_NUM:  1
epoch 0  | loss: 0.62939 | val_0_spearmanr: 0.2583  |  0:00:03s
epoch 1  | loss: 0.58922 | val_0_spearmanr: 0.26767 |  0:00:06s
epoch 2  | loss: 0.55667 | val_0_spearmanr: 0.23714 |  0:00:09s
epoch 3  | loss: 0.54134 | val_0_spearmanr: 0.24767 |  0:00:12s
epoch 4  | loss: 0.50555 | val_0_spearmanr: 0.25984 |  0:00:15s
epoch 5  | loss: 0.49357 | val_0_spearmanr: 0.25773 |  0:00:18s
epoch 6  | loss: 0.47899 | val_0_spearmanr: 0.22589 |  0:00:21s
epoch 7  | loss: 0.46129 | val_0_spearmanr: 0.22508 |  0:00:25s
epoch 8  | loss: 0.45001 | val_0_spearmanr: 0.28197 |  0:00:28s
epoch 9  | loss: 0.42876 | val_0_spearmanr: 0.23647 |  0:00:31s
epoch 10 | loss: 0.42153 | val_0_spearmanr: 0.22456 |  0:00:34s
epoch 11 | loss: 0.40598 | val_0_spearmanr: 0.24172 |  0:00:37s
epoch 12 | loss: 0.38344 | val_0_spearmanr: 0.24458 |  0:00:40s
epoch 13 | loss: 0.36952 | val_0_spearmanr: 0.23234 |  0:00:43s
epoch 14 | loss: 0.36229 | val_0_spearmanr: 0.22564 |  0:00:47s
epoch 15 | l



TRAIN_NUM:  3 , VAL_NUM:  2
epoch 0  | loss: 0.45524 | val_0_spearmanr: 0.30839 |  0:00:03s
epoch 1  | loss: 0.4452  | val_0_spearmanr: 0.2927  |  0:00:06s
epoch 2  | loss: 0.41253 | val_0_spearmanr: 0.29446 |  0:00:09s
epoch 3  | loss: 0.39723 | val_0_spearmanr: 0.3031  |  0:00:12s
epoch 4  | loss: 0.38161 | val_0_spearmanr: 0.31325 |  0:00:15s
epoch 5  | loss: 0.35317 | val_0_spearmanr: 0.32263 |  0:00:18s
epoch 6  | loss: 0.34671 | val_0_spearmanr: 0.29998 |  0:00:21s
epoch 7  | loss: 0.329   | val_0_spearmanr: 0.31972 |  0:00:25s
epoch 8  | loss: 0.32432 | val_0_spearmanr: 0.29115 |  0:00:28s
epoch 9  | loss: 0.30338 | val_0_spearmanr: 0.29719 |  0:00:31s
epoch 10 | loss: 0.29868 | val_0_spearmanr: 0.29357 |  0:00:34s
epoch 11 | loss: 0.29102 | val_0_spearmanr: 0.31207 |  0:00:37s
epoch 12 | loss: 0.28606 | val_0_spearmanr: 0.29029 |  0:00:40s
epoch 13 | loss: 0.26979 | val_0_spearmanr: 0.29659 |  0:00:43s
epoch 14 | loss: 0.25528 | val_0_spearmanr: 0.3101  |  0:00:47s
epoch 15 | l



TRAIN_NUM:  3 , VAL_NUM:  3
epoch 0  | loss: 0.38872 | val_0_spearmanr: 0.25551 |  0:00:03s
epoch 1  | loss: 0.35029 | val_0_spearmanr: 0.26033 |  0:00:06s
epoch 2  | loss: 0.32906 | val_0_spearmanr: 0.26204 |  0:00:09s
epoch 3  | loss: 0.30402 | val_0_spearmanr: 0.21845 |  0:00:12s
epoch 4  | loss: 0.29564 | val_0_spearmanr: 0.24502 |  0:00:15s
epoch 5  | loss: 0.27324 | val_0_spearmanr: 0.25097 |  0:00:18s
epoch 6  | loss: 0.27233 | val_0_spearmanr: 0.23907 |  0:00:21s
epoch 7  | loss: 0.26783 | val_0_spearmanr: 0.22416 |  0:00:25s
epoch 8  | loss: 0.24547 | val_0_spearmanr: 0.23176 |  0:00:28s
epoch 9  | loss: 0.24201 | val_0_spearmanr: 0.2337  |  0:00:31s
epoch 10 | loss: 0.23145 | val_0_spearmanr: 0.24637 |  0:00:34s
epoch 11 | loss: 0.23344 | val_0_spearmanr: 0.23592 |  0:00:37s
epoch 12 | loss: 0.21395 | val_0_spearmanr: 0.23689 |  0:00:40s

Early stopping occurred at epoch 12 with best_epoch = 2 and best_val_0_spearmanr = 0.26204




TRAIN_NUM:  3 , VAL_NUM:  4
epoch 0  | loss: 0.33448 | val_0_spearmanr: 0.32022 |  0:00:03s
epoch 1  | loss: 0.30024 | val_0_spearmanr: 0.32432 |  0:00:06s
epoch 2  | loss: 0.28981 | val_0_spearmanr: 0.33176 |  0:00:10s
epoch 3  | loss: 0.27906 | val_0_spearmanr: 0.33431 |  0:00:13s
epoch 4  | loss: 0.26365 | val_0_spearmanr: 0.32124 |  0:00:17s
epoch 5  | loss: 0.25472 | val_0_spearmanr: 0.32218 |  0:00:20s
epoch 6  | loss: 0.2437  | val_0_spearmanr: 0.30452 |  0:00:24s
epoch 7  | loss: 0.22744 | val_0_spearmanr: 0.30716 |  0:00:27s
epoch 8  | loss: 0.22379 | val_0_spearmanr: 0.32364 |  0:00:31s
epoch 9  | loss: 0.22288 | val_0_spearmanr: 0.32155 |  0:00:34s
epoch 10 | loss: 0.21393 | val_0_spearmanr: 0.31573 |  0:00:38s
epoch 11 | loss: 0.21149 | val_0_spearmanr: 0.31225 |  0:00:41s
epoch 12 | loss: 0.20239 | val_0_spearmanr: 0.3068  |  0:00:45s
epoch 13 | loss: 0.19592 | val_0_spearmanr: 0.33325 |  0:00:48s

Early stopping occurred at epoch 13 with best_epoch = 3 and best_val_0_spea



TRAIN_NUM:  4 , VAL_NUM:  1
epoch 0  | loss: 0.56032 | val_0_spearmanr: 0.27088 |  0:00:05s
epoch 1  | loss: 0.49542 | val_0_spearmanr: 0.25879 |  0:00:11s
epoch 2  | loss: 0.4606  | val_0_spearmanr: 0.2812  |  0:00:17s
epoch 3  | loss: 0.43375 | val_0_spearmanr: 0.30031 |  0:00:23s
epoch 4  | loss: 0.40977 | val_0_spearmanr: 0.30093 |  0:00:29s
epoch 5  | loss: 0.386   | val_0_spearmanr: 0.27286 |  0:00:35s
epoch 6  | loss: 0.36287 | val_0_spearmanr: 0.28188 |  0:00:40s
epoch 7  | loss: 0.34781 | val_0_spearmanr: 0.27771 |  0:00:46s
epoch 8  | loss: 0.33023 | val_0_spearmanr: 0.30344 |  0:00:52s
epoch 9  | loss: 0.30863 | val_0_spearmanr: 0.31578 |  0:00:58s
epoch 10 | loss: 0.29787 | val_0_spearmanr: 0.27139 |  0:01:04s
epoch 11 | loss: 0.29477 | val_0_spearmanr: 0.29713 |  0:01:10s
epoch 12 | loss: 0.28114 | val_0_spearmanr: 0.2886  |  0:01:15s
epoch 13 | loss: 0.27085 | val_0_spearmanr: 0.25641 |  0:01:21s
epoch 14 | loss: 0.2647  | val_0_spearmanr: 0.3031  |  0:01:27s
epoch 15 | l



TRAIN_NUM:  4 , VAL_NUM:  2
epoch 0  | loss: 0.31867 | val_0_spearmanr: 0.33723 |  0:00:05s
epoch 1  | loss: 0.30361 | val_0_spearmanr: 0.37365 |  0:00:11s
epoch 2  | loss: 0.2947  | val_0_spearmanr: 0.34437 |  0:00:17s
epoch 3  | loss: 0.27323 | val_0_spearmanr: 0.35942 |  0:00:23s
epoch 4  | loss: 0.26314 | val_0_spearmanr: 0.38765 |  0:00:29s
epoch 5  | loss: 0.2639  | val_0_spearmanr: 0.36902 |  0:00:35s
epoch 6  | loss: 0.25233 | val_0_spearmanr: 0.35091 |  0:00:40s
epoch 7  | loss: 0.24754 | val_0_spearmanr: 0.34546 |  0:00:46s
epoch 8  | loss: 0.24168 | val_0_spearmanr: 0.3768  |  0:00:52s
epoch 9  | loss: 0.23699 | val_0_spearmanr: 0.36417 |  0:00:58s
epoch 10 | loss: 0.23254 | val_0_spearmanr: 0.3359  |  0:01:04s
epoch 11 | loss: 0.22736 | val_0_spearmanr: 0.3468  |  0:01:10s
epoch 12 | loss: 0.22058 | val_0_spearmanr: 0.35267 |  0:01:15s
epoch 13 | loss: 0.21661 | val_0_spearmanr: 0.35482 |  0:01:21s
epoch 14 | loss: 0.21552 | val_0_spearmanr: 0.36229 |  0:01:27s

Early stopp



TRAIN_NUM:  4 , VAL_NUM:  3
epoch 0  | loss: 0.28342 | val_0_spearmanr: 0.33256 |  0:00:05s
epoch 1  | loss: 0.27198 | val_0_spearmanr: 0.27499 |  0:00:11s
epoch 2  | loss: 0.25675 | val_0_spearmanr: 0.33068 |  0:00:17s
epoch 3  | loss: 0.24512 | val_0_spearmanr: 0.26895 |  0:00:23s
epoch 4  | loss: 0.24167 | val_0_spearmanr: 0.34228 |  0:00:29s
epoch 5  | loss: 0.22869 | val_0_spearmanr: 0.30433 |  0:00:35s
epoch 6  | loss: 0.22538 | val_0_spearmanr: 0.29057 |  0:00:40s
epoch 7  | loss: 0.22749 | val_0_spearmanr: 0.30121 |  0:00:46s
epoch 8  | loss: 0.21292 | val_0_spearmanr: 0.30483 |  0:00:52s
epoch 9  | loss: 0.21249 | val_0_spearmanr: 0.3181  |  0:00:58s
epoch 10 | loss: 0.21191 | val_0_spearmanr: 0.31735 |  0:01:04s
epoch 11 | loss: 0.2119  | val_0_spearmanr: 0.31162 |  0:01:10s
epoch 12 | loss: 0.20908 | val_0_spearmanr: 0.336   |  0:01:15s
epoch 13 | loss: 0.20406 | val_0_spearmanr: 0.3166  |  0:01:21s
epoch 14 | loss: 0.20035 | val_0_spearmanr: 0.322   |  0:01:27s

Early stopp



TRAIN_NUM:  4 , VAL_NUM:  4
epoch 0  | loss: 0.25614 | val_0_spearmanr: 0.347   |  0:00:06s
epoch 1  | loss: 0.24283 | val_0_spearmanr: 0.33794 |  0:00:12s
epoch 2  | loss: 0.23003 | val_0_spearmanr: 0.34997 |  0:00:18s
epoch 3  | loss: 0.22575 | val_0_spearmanr: 0.34116 |  0:00:24s
epoch 4  | loss: 0.21828 | val_0_spearmanr: 0.35122 |  0:00:31s
epoch 5  | loss: 0.21588 | val_0_spearmanr: 0.34667 |  0:00:37s
epoch 6  | loss: 0.20382 | val_0_spearmanr: 0.35222 |  0:00:43s
epoch 7  | loss: 0.20589 | val_0_spearmanr: 0.35443 |  0:00:49s
epoch 8  | loss: 0.19999 | val_0_spearmanr: 0.35695 |  0:00:55s
epoch 9  | loss: 0.1915  | val_0_spearmanr: 0.32988 |  0:01:01s
epoch 10 | loss: 0.19319 | val_0_spearmanr: 0.36075 |  0:01:08s
epoch 11 | loss: 0.1939  | val_0_spearmanr: 0.35721 |  0:01:14s
epoch 12 | loss: 0.1874  | val_0_spearmanr: 0.34805 |  0:01:20s
epoch 13 | loss: 0.18604 | val_0_spearmanr: 0.3484  |  0:01:26s
epoch 14 | loss: 0.1837  | val_0_spearmanr: 0.34523 |  0:01:32s
epoch 15 | l



Successfully saved model at ./tabnet_model_test_1.zip


In [4]:
# Second pass: keep fine-tuning the `clf` trained in the previous cell
# (warm_start=True throughout). Train and validation chunks are visited in a
# shuffled order; the generator is seeded so that order is reproducible.
n_chunks = 5

rng = np.random.default_rng(0)

for train_chunk_num in rng.permutation(np.arange(1, n_chunks)):
    train_X = np.load(f"train_X_{train_chunk_num}.npy")
    train_Y = np.load(f"train_Y_{train_chunk_num}.npy")
    # Collapse all non-sample axes into one feature axis and make the target a
    # single column. This depends only on the train chunk, so do it once here.
    train_X_2d = train_X.reshape(train_X.shape[0], -1)
    train_Y_2d = train_Y.reshape(-1, 1)

    for val_chunk_num in rng.permutation(np.arange(1, n_chunks)):

        print("TRAIN_NUM: ", train_chunk_num, ", VAL_NUM: ", val_chunk_num)

        val_X = np.load(f"val_X_{val_chunk_num}.npy")
        val_Y = np.load(f"val_Y_{val_chunk_num}.npy")

        clf.fit(
            train_X_2d, train_Y_2d,
            eval_set=[(val_X.reshape(val_X.shape[0], -1), val_Y.reshape(-1, 1))],
            # spearmanr_metric registers itself as "spearmanr"; also listing the
            # string only evaluated the same metric twice per epoch.
            eval_metric=[spearmanr_metric],
            loss_fn=spearmanr_loss_fn,
            batch_size=1024, virtual_batch_size=128,
            max_epochs=50, patience=10,
            drop_last=False, warm_start=True,
        )

# save tabnet model
saving_path_name = "./tabnet_model_test_2"
saved_filepath = clf.save_model(saving_path_name)

TRAIN_NUM:  4 , VAL_NUM:  2
epoch 0  | loss: 0.21373 | val_0_spearmanr: 0.34241 |  0:00:05s
epoch 1  | loss: 0.20173 | val_0_spearmanr: 0.3207  |  0:00:11s
epoch 2  | loss: 0.19663 | val_0_spearmanr: 0.35533 |  0:00:17s
epoch 3  | loss: 0.19178 | val_0_spearmanr: 0.35332 |  0:00:22s
epoch 4  | loss: 0.18888 | val_0_spearmanr: 0.34361 |  0:00:28s
epoch 5  | loss: 0.1859  | val_0_spearmanr: 0.35847 |  0:00:34s
epoch 6  | loss: 0.18697 | val_0_spearmanr: 0.33475 |  0:00:40s
epoch 7  | loss: 0.18206 | val_0_spearmanr: 0.33783 |  0:00:46s
epoch 8  | loss: 0.17869 | val_0_spearmanr: 0.34326 |  0:00:51s
epoch 9  | loss: 0.17957 | val_0_spearmanr: 0.3395  |  0:00:57s
epoch 10 | loss: 0.17571 | val_0_spearmanr: 0.34456 |  0:01:03s
epoch 11 | loss: 0.17451 | val_0_spearmanr: 0.32147 |  0:01:09s
epoch 12 | loss: 0.17436 | val_0_spearmanr: 0.3387  |  0:01:14s
epoch 13 | loss: 0.17218 | val_0_spearmanr: 0.3473  |  0:01:20s
epoch 14 | loss: 0.16816 | val_0_spearmanr: 0.32698 |  0:01:26s
epoch 15 | l



TRAIN_NUM:  4 , VAL_NUM:  3
epoch 0  | loss: 0.19585 | val_0_spearmanr: 0.31855 |  0:00:05s
epoch 1  | loss: 0.19041 | val_0_spearmanr: 0.33144 |  0:00:11s
epoch 2  | loss: 0.1907  | val_0_spearmanr: 0.32718 |  0:00:17s
epoch 3  | loss: 0.18383 | val_0_spearmanr: 0.31521 |  0:00:23s
epoch 4  | loss: 0.17644 | val_0_spearmanr: 0.32796 |  0:00:29s
epoch 5  | loss: 0.17497 | val_0_spearmanr: 0.31668 |  0:00:34s
epoch 6  | loss: 0.1748  | val_0_spearmanr: 0.32069 |  0:00:40s
epoch 7  | loss: 0.17272 | val_0_spearmanr: 0.32712 |  0:00:46s
epoch 8  | loss: 0.17303 | val_0_spearmanr: 0.30226 |  0:00:52s
epoch 9  | loss: 0.174   | val_0_spearmanr: 0.2872  |  0:00:58s
epoch 10 | loss: 0.17188 | val_0_spearmanr: 0.31397 |  0:01:04s
epoch 11 | loss: 0.16855 | val_0_spearmanr: 0.29519 |  0:01:09s

Early stopping occurred at epoch 11 with best_epoch = 1 and best_val_0_spearmanr = 0.33144




TRAIN_NUM:  4 , VAL_NUM:  1
epoch 0  | loss: 0.20181 | val_0_spearmanr: 0.29948 |  0:00:05s
epoch 1  | loss: 0.19297 | val_0_spearmanr: 0.3099  |  0:00:11s
epoch 2  | loss: 0.18715 | val_0_spearmanr: 0.28596 |  0:00:17s
epoch 3  | loss: 0.18288 | val_0_spearmanr: 0.28324 |  0:00:23s
epoch 4  | loss: 0.1813  | val_0_spearmanr: 0.31041 |  0:00:29s
epoch 5  | loss: 0.17315 | val_0_spearmanr: 0.30145 |  0:00:34s
epoch 6  | loss: 0.1731  | val_0_spearmanr: 0.30617 |  0:00:40s
epoch 7  | loss: 0.16992 | val_0_spearmanr: 0.29896 |  0:00:46s
epoch 8  | loss: 0.16921 | val_0_spearmanr: 0.29282 |  0:00:52s
epoch 9  | loss: 0.17047 | val_0_spearmanr: 0.30728 |  0:00:58s
epoch 10 | loss: 0.1674  | val_0_spearmanr: 0.28894 |  0:01:04s
epoch 11 | loss: 0.1669  | val_0_spearmanr: 0.29874 |  0:01:09s
epoch 12 | loss: 0.16483 | val_0_spearmanr: 0.28466 |  0:01:15s
epoch 13 | loss: 0.16413 | val_0_spearmanr: 0.27433 |  0:01:21s
epoch 14 | loss: 0.16091 | val_0_spearmanr: 0.28738 |  0:01:27s

Early stopp



TRAIN_NUM:  4 , VAL_NUM:  4
epoch 0  | loss: 0.19165 | val_0_spearmanr: 0.30432 |  0:00:06s
epoch 1  | loss: 0.18032 | val_0_spearmanr: 0.36958 |  0:00:12s
epoch 2  | loss: 0.17687 | val_0_spearmanr: 0.3648  |  0:00:18s
epoch 3  | loss: 0.17314 | val_0_spearmanr: 0.36085 |  0:00:24s
epoch 4  | loss: 0.17596 | val_0_spearmanr: 0.36289 |  0:00:30s
epoch 5  | loss: 0.17002 | val_0_spearmanr: 0.36124 |  0:00:37s
epoch 6  | loss: 0.16591 | val_0_spearmanr: 0.36719 |  0:00:43s
epoch 7  | loss: 0.16306 | val_0_spearmanr: 0.34661 |  0:00:49s
epoch 8  | loss: 0.16432 | val_0_spearmanr: 0.35419 |  0:00:55s
epoch 9  | loss: 0.16168 | val_0_spearmanr: 0.35438 |  0:01:01s
epoch 10 | loss: 0.16164 | val_0_spearmanr: 0.33519 |  0:01:08s
epoch 11 | loss: 0.16758 | val_0_spearmanr: 0.35096 |  0:01:14s

Early stopping occurred at epoch 11 with best_epoch = 1 and best_val_0_spearmanr = 0.36958




TRAIN_NUM:  1 , VAL_NUM:  1
epoch 0  | loss: 0.48107 | val_0_spearmanr: 0.29438 |  0:00:03s
epoch 1  | loss: 0.41171 | val_0_spearmanr: 0.28794 |  0:00:06s
epoch 2  | loss: 0.38238 | val_0_spearmanr: 0.26072 |  0:00:09s
epoch 3  | loss: 0.35548 | val_0_spearmanr: 0.28422 |  0:00:12s
epoch 4  | loss: 0.33418 | val_0_spearmanr: 0.31639 |  0:00:15s
epoch 5  | loss: 0.31621 | val_0_spearmanr: 0.29096 |  0:00:18s
epoch 6  | loss: 0.29588 | val_0_spearmanr: 0.27493 |  0:00:21s
epoch 7  | loss: 0.29399 | val_0_spearmanr: 0.29063 |  0:00:25s
epoch 8  | loss: 0.27489 | val_0_spearmanr: 0.27487 |  0:00:28s
epoch 9  | loss: 0.27833 | val_0_spearmanr: 0.29043 |  0:00:31s
epoch 10 | loss: 0.2617  | val_0_spearmanr: 0.26672 |  0:00:34s
epoch 11 | loss: 0.25614 | val_0_spearmanr: 0.25945 |  0:00:37s
epoch 12 | loss: 0.24374 | val_0_spearmanr: 0.27377 |  0:00:40s
epoch 13 | loss: 0.24412 | val_0_spearmanr: 0.29549 |  0:00:44s
epoch 14 | loss: 0.23729 | val_0_spearmanr: 0.28468 |  0:00:47s

Early stopp



TRAIN_NUM:  1 , VAL_NUM:  2
epoch 0  | loss: 0.33079 | val_0_spearmanr: 0.32101 |  0:00:03s
epoch 1  | loss: 0.29927 | val_0_spearmanr: 0.32413 |  0:00:06s
epoch 2  | loss: 0.3022  | val_0_spearmanr: 0.35339 |  0:00:09s
epoch 3  | loss: 0.27933 | val_0_spearmanr: 0.35017 |  0:00:12s
epoch 4  | loss: 0.26805 | val_0_spearmanr: 0.32791 |  0:00:15s
epoch 5  | loss: 0.26087 | val_0_spearmanr: 0.33087 |  0:00:19s
epoch 6  | loss: 0.25661 | val_0_spearmanr: 0.3275  |  0:00:22s
epoch 7  | loss: 0.24515 | val_0_spearmanr: 0.33592 |  0:00:25s
epoch 8  | loss: 0.23521 | val_0_spearmanr: 0.34172 |  0:00:28s
epoch 9  | loss: 0.23718 | val_0_spearmanr: 0.33938 |  0:00:31s
epoch 10 | loss: 0.23301 | val_0_spearmanr: 0.33197 |  0:00:34s
epoch 11 | loss: 0.22519 | val_0_spearmanr: 0.3394  |  0:00:37s
epoch 12 | loss: 0.22192 | val_0_spearmanr: 0.3524  |  0:00:41s

Early stopping occurred at epoch 12 with best_epoch = 2 and best_val_0_spearmanr = 0.35339




TRAIN_NUM:  1 , VAL_NUM:  4
epoch 0  | loss: 0.29043 | val_0_spearmanr: 0.3838  |  0:00:03s
epoch 1  | loss: 0.2762  | val_0_spearmanr: 0.36906 |  0:00:06s
epoch 2  | loss: 0.27046 | val_0_spearmanr: 0.37229 |  0:00:10s
epoch 3  | loss: 0.25993 | val_0_spearmanr: 0.37646 |  0:00:13s
epoch 4  | loss: 0.25035 | val_0_spearmanr: 0.3899  |  0:00:17s
epoch 5  | loss: 0.23865 | val_0_spearmanr: 0.39271 |  0:00:20s
epoch 6  | loss: 0.23052 | val_0_spearmanr: 0.38334 |  0:00:24s
epoch 7  | loss: 0.23421 | val_0_spearmanr: 0.37597 |  0:00:27s
epoch 8  | loss: 0.2242  | val_0_spearmanr: 0.38089 |  0:00:31s
epoch 9  | loss: 0.22096 | val_0_spearmanr: 0.38177 |  0:00:34s
epoch 10 | loss: 0.21908 | val_0_spearmanr: 0.38616 |  0:00:38s
epoch 11 | loss: 0.21512 | val_0_spearmanr: 0.3815  |  0:00:41s
epoch 12 | loss: 0.2134  | val_0_spearmanr: 0.391   |  0:00:45s
epoch 13 | loss: 0.20468 | val_0_spearmanr: 0.38796 |  0:00:48s
epoch 14 | loss: 0.20402 | val_0_spearmanr: 0.3851  |  0:00:52s
epoch 15 | l



TRAIN_NUM:  1 , VAL_NUM:  3
epoch 0  | loss: 0.2469  | val_0_spearmanr: 0.33132 |  0:00:03s
epoch 1  | loss: 0.24121 | val_0_spearmanr: 0.29895 |  0:00:06s
epoch 2  | loss: 0.23662 | val_0_spearmanr: 0.32747 |  0:00:09s
epoch 3  | loss: 0.22165 | val_0_spearmanr: 0.34957 |  0:00:12s
epoch 4  | loss: 0.22381 | val_0_spearmanr: 0.32593 |  0:00:15s
epoch 5  | loss: 0.21998 | val_0_spearmanr: 0.31656 |  0:00:18s
epoch 6  | loss: 0.2124  | val_0_spearmanr: 0.29987 |  0:00:22s
epoch 7  | loss: 0.20571 | val_0_spearmanr: 0.30817 |  0:00:25s
epoch 8  | loss: 0.2009  | val_0_spearmanr: 0.31243 |  0:00:28s
epoch 9  | loss: 0.19971 | val_0_spearmanr: 0.32311 |  0:00:31s
epoch 10 | loss: 0.19766 | val_0_spearmanr: 0.33661 |  0:00:34s
epoch 11 | loss: 0.19832 | val_0_spearmanr: 0.31523 |  0:00:37s
epoch 12 | loss: 0.19619 | val_0_spearmanr: 0.32896 |  0:00:40s
epoch 13 | loss: 0.19044 | val_0_spearmanr: 0.33957 |  0:00:44s

Early stopping occurred at epoch 13 with best_epoch = 3 and best_val_0_spea



TRAIN_NUM:  2 , VAL_NUM:  4
epoch 0  | loss: 0.43307 | val_0_spearmanr: 0.40747 |  0:00:03s
epoch 1  | loss: 0.36896 | val_0_spearmanr: 0.38502 |  0:00:06s
epoch 2  | loss: 0.33143 | val_0_spearmanr: 0.38588 |  0:00:10s
epoch 3  | loss: 0.31527 | val_0_spearmanr: 0.40175 |  0:00:13s
epoch 4  | loss: 0.28298 | val_0_spearmanr: 0.38531 |  0:00:17s
epoch 5  | loss: 0.26357 | val_0_spearmanr: 0.39666 |  0:00:20s
epoch 6  | loss: 0.2573  | val_0_spearmanr: 0.40273 |  0:00:24s
epoch 7  | loss: 0.25016 | val_0_spearmanr: 0.39114 |  0:00:27s
epoch 8  | loss: 0.24292 | val_0_spearmanr: 0.38923 |  0:00:31s
epoch 9  | loss: 0.23708 | val_0_spearmanr: 0.38656 |  0:00:34s
epoch 10 | loss: 0.23273 | val_0_spearmanr: 0.3876  |  0:00:38s

Early stopping occurred at epoch 10 with best_epoch = 0 and best_val_0_spearmanr = 0.40747




TRAIN_NUM:  2 , VAL_NUM:  3
epoch 0  | loss: 0.36908 | val_0_spearmanr: 0.33204 |  0:00:03s
epoch 1  | loss: 0.33051 | val_0_spearmanr: 0.35713 |  0:00:06s
epoch 2  | loss: 0.29862 | val_0_spearmanr: 0.33939 |  0:00:09s
epoch 3  | loss: 0.27875 | val_0_spearmanr: 0.33522 |  0:00:12s
epoch 4  | loss: 0.26925 | val_0_spearmanr: 0.33156 |  0:00:15s
epoch 5  | loss: 0.25556 | val_0_spearmanr: 0.32366 |  0:00:18s
epoch 6  | loss: 0.25188 | val_0_spearmanr: 0.32319 |  0:00:22s
epoch 7  | loss: 0.23626 | val_0_spearmanr: 0.31996 |  0:00:25s
epoch 8  | loss: 0.23087 | val_0_spearmanr: 0.32264 |  0:00:28s
epoch 9  | loss: 0.23241 | val_0_spearmanr: 0.31066 |  0:00:31s
epoch 10 | loss: 0.22192 | val_0_spearmanr: 0.28947 |  0:00:34s
epoch 11 | loss: 0.21845 | val_0_spearmanr: 0.31126 |  0:00:37s

Early stopping occurred at epoch 11 with best_epoch = 1 and best_val_0_spearmanr = 0.35713




TRAIN_NUM:  2 , VAL_NUM:  1
epoch 0  | loss: 0.31618 | val_0_spearmanr: 0.28134 |  0:00:03s
epoch 1  | loss: 0.28633 | val_0_spearmanr: 0.28493 |  0:00:06s
epoch 2  | loss: 0.27183 | val_0_spearmanr: 0.27042 |  0:00:09s
epoch 3  | loss: 0.26479 | val_0_spearmanr: 0.2455  |  0:00:12s
epoch 4  | loss: 0.25476 | val_0_spearmanr: 0.2563  |  0:00:15s
epoch 5  | loss: 0.24368 | val_0_spearmanr: 0.27462 |  0:00:18s
epoch 6  | loss: 0.24072 | val_0_spearmanr: 0.29206 |  0:00:22s
epoch 7  | loss: 0.22549 | val_0_spearmanr: 0.26665 |  0:00:25s
epoch 8  | loss: 0.22929 | val_0_spearmanr: 0.27359 |  0:00:28s
epoch 9  | loss: 0.21464 | val_0_spearmanr: 0.27991 |  0:00:31s
epoch 10 | loss: 0.21428 | val_0_spearmanr: 0.28402 |  0:00:34s
epoch 11 | loss: 0.21034 | val_0_spearmanr: 0.29028 |  0:00:37s
epoch 12 | loss: 0.20832 | val_0_spearmanr: 0.29893 |  0:00:40s
epoch 13 | loss: 0.2003  | val_0_spearmanr: 0.29711 |  0:00:44s
epoch 14 | loss: 0.19829 | val_0_spearmanr: 0.29588 |  0:00:47s
epoch 15 | l



TRAIN_NUM:  2 , VAL_NUM:  2
epoch 0  | loss: 0.20482 | val_0_spearmanr: 0.35737 |  0:00:03s
epoch 1  | loss: 0.20165 | val_0_spearmanr: 0.33742 |  0:00:06s
epoch 2  | loss: 0.19092 | val_0_spearmanr: 0.32062 |  0:00:09s
epoch 3  | loss: 0.18833 | val_0_spearmanr: 0.31641 |  0:00:12s
epoch 4  | loss: 0.1841  | val_0_spearmanr: 0.31746 |  0:00:15s
epoch 5  | loss: 0.17914 | val_0_spearmanr: 0.32028 |  0:00:18s
epoch 6  | loss: 0.17743 | val_0_spearmanr: 0.32265 |  0:00:22s
epoch 7  | loss: 0.17276 | val_0_spearmanr: 0.3164  |  0:00:25s
epoch 8  | loss: 0.17026 | val_0_spearmanr: 0.31247 |  0:00:28s
epoch 9  | loss: 0.17559 | val_0_spearmanr: 0.32671 |  0:00:31s
epoch 10 | loss: 0.16335 | val_0_spearmanr: 0.32092 |  0:00:34s

Early stopping occurred at epoch 10 with best_epoch = 0 and best_val_0_spearmanr = 0.35737




TRAIN_NUM:  3 , VAL_NUM:  1
epoch 0  | loss: 0.40686 | val_0_spearmanr: 0.28465 |  0:00:03s
epoch 1  | loss: 0.34305 | val_0_spearmanr: 0.31672 |  0:00:06s
epoch 2  | loss: 0.30524 | val_0_spearmanr: 0.29618 |  0:00:09s
epoch 3  | loss: 0.27097 | val_0_spearmanr: 0.25581 |  0:00:12s
epoch 4  | loss: 0.25859 | val_0_spearmanr: 0.27894 |  0:00:15s
epoch 5  | loss: 0.24503 | val_0_spearmanr: 0.2737  |  0:00:18s
epoch 6  | loss: 0.23017 | val_0_spearmanr: 0.26903 |  0:00:22s
epoch 7  | loss: 0.21645 | val_0_spearmanr: 0.28118 |  0:00:25s
epoch 8  | loss: 0.20991 | val_0_spearmanr: 0.27937 |  0:00:28s
epoch 9  | loss: 0.20722 | val_0_spearmanr: 0.28852 |  0:00:31s
epoch 10 | loss: 0.2025  | val_0_spearmanr: 0.29454 |  0:00:34s
epoch 11 | loss: 0.19104 | val_0_spearmanr: 0.30244 |  0:00:37s

Early stopping occurred at epoch 11 with best_epoch = 1 and best_val_0_spearmanr = 0.31672




TRAIN_NUM:  3 , VAL_NUM:  4
epoch 0  | loss: 0.31776 | val_0_spearmanr: 0.39297 |  0:00:03s
epoch 1  | loss: 0.27789 | val_0_spearmanr: 0.37855 |  0:00:06s
epoch 2  | loss: 0.24874 | val_0_spearmanr: 0.37754 |  0:00:10s
epoch 3  | loss: 0.23634 | val_0_spearmanr: 0.37548 |  0:00:13s
epoch 4  | loss: 0.22679 | val_0_spearmanr: 0.37672 |  0:00:17s
epoch 5  | loss: 0.21018 | val_0_spearmanr: 0.37181 |  0:00:20s
epoch 6  | loss: 0.20626 | val_0_spearmanr: 0.36006 |  0:00:24s
epoch 7  | loss: 0.1982  | val_0_spearmanr: 0.37265 |  0:00:27s
epoch 8  | loss: 0.18791 | val_0_spearmanr: 0.38294 |  0:00:31s
epoch 9  | loss: 0.18797 | val_0_spearmanr: 0.3708  |  0:00:34s
epoch 10 | loss: 0.17718 | val_0_spearmanr: 0.36771 |  0:00:38s

Early stopping occurred at epoch 10 with best_epoch = 0 and best_val_0_spearmanr = 0.39297




TRAIN_NUM:  3 , VAL_NUM:  3
epoch 0  | loss: 0.28215 | val_0_spearmanr: 0.31166 |  0:00:03s
epoch 1  | loss: 0.2513  | val_0_spearmanr: 0.30165 |  0:00:06s
epoch 2  | loss: 0.24091 | val_0_spearmanr: 0.31594 |  0:00:09s
epoch 3  | loss: 0.22682 | val_0_spearmanr: 0.31091 |  0:00:12s
epoch 4  | loss: 0.21974 | val_0_spearmanr: 0.31417 |  0:00:15s
epoch 5  | loss: 0.21176 | val_0_spearmanr: 0.34012 |  0:00:19s
epoch 6  | loss: 0.20181 | val_0_spearmanr: 0.34202 |  0:00:22s
epoch 7  | loss: 0.19413 | val_0_spearmanr: 0.33845 |  0:00:25s
epoch 8  | loss: 0.18891 | val_0_spearmanr: 0.31359 |  0:00:28s
epoch 9  | loss: 0.18601 | val_0_spearmanr: 0.3031  |  0:00:31s
epoch 10 | loss: 0.17669 | val_0_spearmanr: 0.31332 |  0:00:34s
epoch 11 | loss: 0.17113 | val_0_spearmanr: 0.31961 |  0:00:37s
epoch 12 | loss: 0.17086 | val_0_spearmanr: 0.32004 |  0:00:40s
epoch 13 | loss: 0.17024 | val_0_spearmanr: 0.32022 |  0:00:43s
epoch 14 | loss: 0.16532 | val_0_spearmanr: 0.31688 |  0:00:46s
epoch 15 | l



TRAIN_NUM:  3 , VAL_NUM:  2
epoch 0  | loss: 0.20737 | val_0_spearmanr: 0.34956 |  0:00:03s
epoch 1  | loss: 0.19841 | val_0_spearmanr: 0.34286 |  0:00:06s
epoch 2  | loss: 0.1887  | val_0_spearmanr: 0.34398 |  0:00:09s
epoch 3  | loss: 0.18309 | val_0_spearmanr: 0.35127 |  0:00:12s
epoch 4  | loss: 0.17696 | val_0_spearmanr: 0.3557  |  0:00:15s
epoch 5  | loss: 0.17871 | val_0_spearmanr: 0.35859 |  0:00:18s
epoch 6  | loss: 0.17101 | val_0_spearmanr: 0.35404 |  0:00:21s
epoch 7  | loss: 0.16494 | val_0_spearmanr: 0.34686 |  0:00:24s
epoch 8  | loss: 0.16082 | val_0_spearmanr: 0.34972 |  0:00:27s
epoch 9  | loss: 0.15919 | val_0_spearmanr: 0.35463 |  0:00:30s
epoch 10 | loss: 0.15474 | val_0_spearmanr: 0.35097 |  0:00:33s
epoch 11 | loss: 0.1506  | val_0_spearmanr: 0.35006 |  0:00:36s
epoch 12 | loss: 0.14982 | val_0_spearmanr: 0.34164 |  0:00:40s
epoch 13 | loss: 0.1467  | val_0_spearmanr: 0.3387  |  0:00:43s
epoch 14 | loss: 0.14579 | val_0_spearmanr: 0.35918 |  0:00:46s
epoch 15 | l



Successfully saved model at ./tabnet_model_test_2.zip
