# Usage example: Custom model

## 1. Custom Embedding

In [1]:
import torch
from torch import Tensor
import torch.nn as nn

import os,sys; sys.path.append(os.path.abspath(".."))
from deep_table.nn.encoders.embedding.base import BaseEmbedding


class YourEmbedding(BaseEmbedding):
    """Example custom embedding for tabular inputs.

    Every continuous feature gets its own ``nn.Linear(1, dim_embed)``
    projection, while all categorical features share one ``nn.Embedding``
    table with ``num_categories`` rows.

    Note: ``use_cls``, ``initialization`` and ``activation`` are accepted in
    the signature but are not used by this example; the base class is always
    given ``use_cls=False``.
    NOTE(review): presumably kept only for signature compatibility with the
    library's built-in embeddings -- confirm before relying on them.
    """

    def __init__(
        self,
        num_continuous_features: int,
        num_categorical_features: int,
        num_categories: int,
        dim_embed: int,
        use_cls: bool = False,
        initialization: str = "uniform",
        activation=None,
    ) -> None:
        # One row per feature (continuous + categorical), each of width
        # ``dim_embed``.
        # NOTE(review): presumably the shape of the tensor handed to the
        # backbone -- confirm against BaseEmbedding.
        backbone_input_shape = (
            num_continuous_features + num_categorical_features,
            dim_embed,
        )
        super().__init__(
            num_continuous_features=num_continuous_features,
            num_categorical_features=num_categorical_features,
            num_categories=num_categories,
            dim_embed=dim_embed,
            is_in_backbone_continuous=True,
            is_in_backbone_categorical=True,
            dim_in_backbone=backbone_input_shape,
            dim_skip_backbone=0,
            use_cls=False,
        )

        # Continuous projections are built before the categorical table so the
        # parameter-initialisation order stays fixed.
        self.con_embed = nn.ModuleList(
            nn.Linear(1, dim_embed) for _ in range(self.num_continuous_features)
        )
        self.cat_embed = nn.Embedding(self.num_categories, dim_embed)

    def continuous_embedding(self, x: Tensor) -> Tensor:
        """Embed each continuous column with its own linear layer.

        Column ``k`` of ``x`` is reshaped to ``(-1, 1)``, projected by
        ``self.con_embed[k]`` and given a new axis at position 1; the
        per-column results are concatenated along that axis.

        Assumes ``x`` is ``(batch, num_continuous_features)`` -- TODO confirm
        against the caller in the base class.
        """
        per_column = [
            projection(x[:, column].view(-1, 1)).unsqueeze(1)
            for column, projection in enumerate(self.con_embed)
        ]
        return torch.cat(per_column, dim=1)

    def categorical_embedding(self, x: Tensor) -> Tensor:
        """Look up every entry of ``x`` in the shared embedding table."""
        return self.cat_embed(x)

## 2. Custom Backbone

In [2]:
from deep_table.nn.encoders.backbone.base import BaseBackbone


class YourBackbone(BaseBackbone):
    """Example custom backbone: flatten the embedded features, then apply one
    linear layer.

    The output width is fixed at 12. ``use_cls`` and any extra keyword
    arguments are accepted but not used by this example.
    """

    def __init__(
        self,
        num_features: int,
        dim_embed: int,
        use_cls: bool = True,
        **kwargs
    ) -> None:
        super().__init__()
        self._dim_out = 12
        # The layer consumes all feature embeddings flattened into one vector
        # of length ``num_features * dim_embed``.
        self.layer = nn.Linear(num_features * dim_embed, self._dim_out)

    def dim_out(self, is_pretrain: bool = False) -> int:
        """Return the output width; ``is_pretrain`` does not affect it."""
        return self._dim_out

    def forward(self, x: Tensor) -> Tensor:
        """Flatten every axis after the first and project to ``dim_out`` features."""
        return self.layer(x.flatten(1))

## 3. Custom Model

In [3]:
from deep_table.nn.models.base import BaseModel


class YourModel(BaseModel):
    """Example custom model: the encoder followed by a single linear head."""

    def __init__(
        self,
        encoder,
        dim_out: int,
        **kwargs,
    ) -> None:
        # Record the constructor arguments (all but ``encoder``) so that
        # ``self.hparams.dim_out`` can be read in ``_build_network``.
        # NOTE(review): this is deliberately done before ``super().__init__``;
        # presumably the base constructor invokes ``_build_network`` -- confirm
        # against BaseModel before reordering.
        self.save_hyperparameters(ignore="encoder")
        super().__init__(encoder, **kwargs)

    def _build_network(self) -> None:
        """Create the linear head mapping the encoder output to ``dim_out``."""
        self.mlp = nn.Linear(
            self.encoder.dim_out(is_pretrain=False),
            self.hparams.dim_out,
        )

    def forward(self, x):
        """Encode ``x`` and pass the representation through the linear head."""
        representation = self.encoder(x)
        return self.mlp(representation)

## 4. Training

In [7]:
from deep_table.data.datasets import Adult
from deep_table.data.data_module import TabularDatamodule


# Load the Adult dataset from ../data and fetch its processed dataframes,
# which are indexed below by the "train", "val" and "test" keys.
adult_dataset = Adult(root="../data")
adult_dataframes = adult_dataset.processed_dataframes()

# Bundle the three splits with the dataset's own metadata (task, output
# dimension, column roles, number of categories).
datamodule = TabularDatamodule(
    # Expands to train=..., val=..., test=... in that order.
    **{split: adult_dataframes[split] for split in ("train", "val", "test")},
    task=adult_dataset.task,
    dim_out=adult_dataset.dim_out,
    categorical_columns=adult_dataset.categorical_columns,
    continuous_columns=adult_dataset.continuous_columns,
    target=adult_dataset.target_columns,
    num_categories=adult_dataset.num_categories(),
)

In [8]:
from omegaconf import OmegaConf


# Encoder settings: one "args" mapping each for the embedding and the backbone.
# NOTE(review): presumably these are forwarded as keyword arguments to the
# custom embedding / backbone constructors -- confirm against Estimator.
encoder_config = OmegaConf.create(
    {
        "embedding": {"args": {"dim_embed": 16}},
        "backbone": {"args": {}},
    }
)

# Model settings (learning rate, scheduler...); left empty here.
model_config = OmegaConf.create({"args": {}})

# Training settings (epoch, gpu...): no GPU, a single epoch, fixed seed.
trainer_config = OmegaConf.create(
    {
        "gpus": 0,
        "max_epochs": 1,
        "seed": 42,
    }
)

In [9]:
from deep_table.estimators.base import Estimator
from deep_table.utils import get_scores

# Plug the three custom classes defined above into the estimator, alongside
# the encoder / model / trainer configurations.
estimator = Estimator(
    encoder_config,
    model_config,
    trainer_config,
    custom_embedding=YourEmbedding,
    custom_backbone=YourBackbone,
    custom_model=YourModel,
)

# Fit on the datamodule.
estimator.fit(datamodule)

# Predict on the test split; `predict` is scored in the next cell.
predict = estimator.predict(datamodule.dataloader(split="test"))

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Global seed set to 42

  | Name    | Type              | Params
----------------------------------------------
0 | encoder | Encoder           | 4.7 K 
1 | mlp     | Linear            | 13    
2 | loss    | BCEWithLogitsLoss | 0     
----------------------------------------------
4.7 K     Trainable params
0         Non-trainable params
4.7 K     Total params
0.019     Total estimated model params size (MB)


                                                                                                                                                                                                                  

Global seed set to 42


Epoch 0:  75%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                 | 191/255 [00:26<00:08,  7.30it/s, loss=0.327, train_loss=0.390]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                                          | 0/64 [00:00<?, ?it/s][A
Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:48<00:00,  5.28it/s, loss=0.327, train_loss=0.390, val_loss=0.333][A
Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:48<00:00,  5.28it/s, loss=0.327, train_loss=0.390, val_loss=0.333][A
Predicting: 191it [00:22, ?it/s]


In [10]:
# Score the test-split predictions; kept as the cell's final expression so the
# resulting metrics are displayed.
get_scores(
    predict,
    target=datamodule.dataloader(split="test"),
    task="binary",
)

{'accuracy': 0.8495792641729624,
 'AUC': 0.9011576892508752,
 'F1 score': 0.904039810352259,
 'cross_entropy': 0.3254130371925385}