In [2]:
from omegaconf import OmegaConf


# Encoder settings: names of the embedding and the backbone to use.
encoder_config = OmegaConf.create(
    dict(
        embedding=dict(name="FeatureEmbedding"),
        backbone=dict(name="FTTransformerBackbone"),
    )
)

# Model settings (learning rate, scheduler, ...); only the model name is given here.
model_config = OmegaConf.create(dict(name="MLPHeadModel"))

# Trainer settings (epochs, GPU, ...). Noted as not necessary; here only
# the number of epochs is set.
trainer_config = OmegaConf.create(dict(max_epochs=1))

In [5]:
# Append the absolute path of the parent directory to sys.path before the
# deep_table imports below (presumably so the package resolves from a
# notebook subfolder -- confirm against the repository layout).
import os
import sys

sys.path.append(os.path.abspath(".."))

from deep_table.data.data_module import TabularDatamodule
from deep_table.data.datasets import Adult


# Adult dataset rooted at ../data; its processed frames are looked up below
# under the keys "train", "val" and "test".
adult_dataset = Adult(root="../data")
adult_dataframes = adult_dataset.processed_dataframes()

# Gather the keyword arguments in the same order the constructor call
# originally listed them: the three splits first, then dataset metadata.
datamodule_kwargs = {
    "train": adult_dataframes["train"],
    "val": adult_dataframes["val"],
    "test": adult_dataframes["test"],
    "task": adult_dataset.task,
    "dim_out": adult_dataset.dim_out,
    "categorical_columns": adult_dataset.categorical_columns,
    "continuous_columns": adult_dataset.continuous_columns,
    "target": adult_dataset.target_columns,
    "num_categories": adult_dataset.num_categories(),
}
datamodule = TabularDatamodule(**datamodule_kwargs)


In [6]:
from deep_table.estimators.base import Estimator
from deep_table.utils import get_scores


# Build the estimator from the encoder, model and trainer configs (passed
# positionally, in that order), then fit it on the datamodule.
estimator = Estimator(encoder_config, model_config, trainer_config)
estimator.fit(datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
`activation` is not set. `nn.Identity` would be used instead.
`activation` is not set. `nn.Identity` would be used instead.
Global seed set to 0

  | Name    | Type              | Params
----------------------------------------------
0 | encoder | Encoder           | 30.8 K
1 | mlp     | Sequential        | 4.6 K 
2 | loss    | BCEWithLogitsLoss | 0     
----------------------------------------------
35.4 K    Trainable params
0         Non-trainable params
35.4 K    Total params
0.142     Total estimated model params size (MB)


                                                                                                                                                                                                                  

Global seed set to 0


Epoch 0:  75%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                 | 191/255 [00:29<00:09,  6.60it/s, loss=0.355, train_loss=0.385]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                                          | 0/64 [00:00<?, ?it/s][A
Epoch 0:  76%|███████████████████████████████████████████████████████████████████████████████████████████████████▉                                | 193/255 [00:36<00:11,  5.27it/s, loss=0.355, train_loss=0.385][A
Epoch 0:  81%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                        | 207/255 [00:36<00:08,  5.64it/s, loss=0.355, train_loss=0.385][A
Epoch 0:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████

In [7]:
# Predict on a loader over the "test" split, then build a second loader over
# the same split to serve as the scoring target.
predict = estimator.predict(datamodule.dataloader(split="test"))
test_targets = datamodule.dataloader(split="test")
get_scores(predict, target=test_targets, task="binary")

Predicting: 191it [00:23, ?it/s]


{'accuracy': 0.8488422087095387,
 'AUC': 0.9042503388917222,
 'F1 score': 0.902792589959316,
 'cross_entropy': 0.3235419449084581}

In [8]:
# Stage 1: fit an Estimator configured with the "SAINTPretrainModel" model
# config; the encoder and trainer configs are the ones defined earlier.
pretrain_model_config = OmegaConf.create(dict(name="SAINTPretrainModel"))
pretrain_model = Estimator(encoder_config, pretrain_model_config, trainer_config)
pretrain_model.fit(datamodule)

# Stage 2: create a new Estimator with model_config and fit it, handing the
# stage-1 estimator over via from_pretrained.
# Note: this rebinds the name `estimator` to the new object.
estimator = Estimator(encoder_config, model_config, trainer_config)
estimator.fit(datamodule, from_pretrained=pretrain_model)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
`activation` is not set. `nn.Identity` would be used instead.
`activation` is not set. `nn.Identity` would be used instead.
Global seed set to 0

  | Name               | Type             | Params
--------------------------------------------------------
0 | encoder            | Encoder          | 30.8 K
1 | cutmix             | Cutmix           | 0     
2 | mixup              | Mixup            | 0     
3 | g1                 | SimpleMLPLayer   | 62.2 K
4 | g2                 | SimpleMLPLayer   | 62.2 K
5 | feature_wise_mlp   | ModuleList       | 1.2 M 
6 | contranstive_loss  | InfoNCELoss      | 0     
7 | mse_loss           | MSELoss          | 0     
8 | cross_entropy_loss | CrossEntropyLoss | 0     
--------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.347     Total estimated model params

                                                                                                                                                                                                                  

Global seed set to 0


Epoch 0:  75%|████████████████████████████████████████████████████████▉                   | 191/255 [00:40<00:13,  4.71it/s, loss=528, train_contrastive_loss=374.0, train_denoising_loss=2.530, train_loss=400.0]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                                          | 0/64 [00:00<?, ?it/s][A
Epoch 0:  76%|█████████████████████████████████████████████████████████▌                  | 193/255 [00:48<00:15,  3.98it/s, loss=528, train_contrastive_loss=374.0, train_denoising_loss=2.530, train_loss=400.0][A
Epoch 0:  78%|███████████████████████████████████████████████████████████                 | 198/255 [00:48<00:13,  4.08it/s, loss=528, train_contrastive_loss=374.0, train_denoising_loss=2.530, train_loss=400.0][A
Epoch 0:  80%|████████████████████████████████████████████████████████████▌               | 203/255 [00:48<00:12,

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Global seed set to 0

  | Name    | Type              | Params
----------------------------------------------
0 | encoder | Encoder           | 30.8 K
1 | mlp     | Sequential        | 4.6 K 
2 | loss    | BCEWithLogitsLoss | 0     
----------------------------------------------
35.4 K    Trainable params
0         Non-trainable params
35.4 K    Total params
0.142     Total estimated model params size (MB)



                                                                                                                                                                                                                  

Global seed set to 0


Epoch 0:  75%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                 | 191/255 [00:28<00:09,  6.72it/s, loss=0.348, train_loss=0.406]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                                          | 0/64 [00:00<?, ?it/s][A
Epoch 0:  76%|███████████████████████████████████████████████████████████████████████████████████████████████████▉                                | 193/255 [00:36<00:11,  5.37it/s, loss=0.348, train_loss=0.406][A
Epoch 0:  81%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                        | 207/255 [00:36<00:08,  5.74it/s, loss=0.348, train_loss=0.406][A
Epoch 0:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████

In [9]:
# Same evaluation as before, now using whatever `estimator` currently refers
# to: predict on the "test" split, then score against a second loader over
# that split.
predict = estimator.predict(datamodule.dataloader(split="test"))
test_targets = datamodule.dataloader(split="test")
get_scores(predict, target=test_targets, task="binary")

Predicting: 191it [00:23, ?it/s]


{'accuracy': 0.8525274860266568,
 'AUC': 0.9066998626869078,
 'F1 score': 0.907507993374167,
 'cross_entropy': 0.31816212925849996}