diff --git a/README.md b/README.md
index 29edb58e..85073ca0 100644
--- a/README.md
+++ b/README.md
@@ -100,7 +100,7 @@ from pytorch_tabular.config import (
data_config = DataConfig(
target=[
"target"
- ], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
+ ], # target should always be a list.
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
)
diff --git a/docs/gs_usage.md b/docs/gs_usage.md
index 7285c519..af3c249f 100644
--- a/docs/gs_usage.md
+++ b/docs/gs_usage.md
@@ -14,7 +14,7 @@ from pytorch_tabular.config import (
data_config = DataConfig(
target=[
"target"
- ], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
+ ], # target should always be a list.
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
)
diff --git a/docs/tutorials/01-Approaching Any Tabular Problem with PyTorch Tabular.ipynb b/docs/tutorials/01-Approaching Any Tabular Problem with PyTorch Tabular.ipynb
index af4ddbf2..ab803092 100644
--- a/docs/tutorials/01-Approaching Any Tabular Problem with PyTorch Tabular.ipynb
+++ b/docs/tutorials/01-Approaching Any Tabular Problem with PyTorch Tabular.ipynb
@@ -532,7 +532,7 @@
"data_config = DataConfig(\n",
" target=[\n",
" target_col\n",
- " ], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented\n",
+ " ], # target should always be a list\n",
" continuous_cols=num_col_names,\n",
" categorical_cols=cat_col_names,\n",
")\n",
diff --git a/docs/tutorials/15-Multi Target Classification.ipynb b/docs/tutorials/15-Multi Target Classification.ipynb
new file mode 100644
index 00000000..6d882cc2
--- /dev/null
+++ b/docs/tutorials/15-Multi Target Classification.ipynb
@@ -0,0 +1,2008 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "Collapsed": "false"
+ },
+ "outputs": [],
+ "source": [
+ "from sklearn.datasets import make_classification\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.metrics import accuracy_score, f1_score\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "from pytorch_tabular.utils import make_mixed_dataset, print_metrics\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "Collapsed": "false"
+ },
+ "outputs": [],
+ "source": [
+ "# Load dataset\n",
+ "data, cat_col_names, num_col_names = make_mixed_dataset(task=\"classification\", n_samples=10000, n_features=8, n_categories=4, weights=[0.8], random_state=42)\n",
+ "\n",
+ "# Create a new, second target\n",
+ "data['second_target'] = 0\n",
+ "for c in cat_col_names:\n",
+ " data.second_target += data[c]\n",
+ "\n",
+ "# Correlate it to 1st target\n",
+ "data.second_target += (data.target == 'class_0')\n",
+ "\n",
+ "# Create random discrete noise to make task non-trivial\n",
+ "random_noise = np.random.normal(0, 0.8, data.shape[0]).astype(int)\n",
+ "data.second_target = (data.second_target + random_noise).mod(3) \n",
+ "\n",
+ "# Now that dataset is complete, we can perform train/test split\n",
+ "train, test = train_test_split(data, random_state=42)\n",
+ "train, val = train_test_split(train, random_state=42)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "target\n",
+ "class_0 0.7968\n",
+ "class_1 0.2032\n",
+ "Name: proportion, dtype: float64"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data.target.value_counts(normalize=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "second_target\n",
+ "2.0 0.3364\n",
+ "1.0 0.3354\n",
+ "0.0 0.3282\n",
+ "Name: proportion, dtype: float64"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data.second_target.value_counts(normalize=True)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {
+ "Collapsed": "false"
+ },
+ "source": [
+ "# Importing the Library"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "Collapsed": "false"
+ },
+ "outputs": [],
+ "source": [
+ "from pytorch_tabular import TabularModel\n",
+ "from pytorch_tabular.models import CategoryEmbeddingModelConfig\n",
+ "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig\n",
+ "from pytorch_tabular.models.common.heads import LinearHeadConfig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "results = []"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {
+ "Collapsed": "false"
+ },
+ "source": [
+ "## Define the Configs\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "Collapsed": "false"
+ },
+ "outputs": [],
+ "source": [
+ "trainer_config = TrainerConfig(\n",
+ " auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate\n",
+ " batch_size=1024,\n",
+ " max_epochs=100,\n",
+ " early_stopping=\"valid_loss\", # Monitor valid_loss for early stopping\n",
+ " early_stopping_mode = \"min\", # Set the mode as min because for valid_loss, lower is better\n",
+ " early_stopping_patience=5, # No. of epochs of degradation training will wait before terminating\n",
+ " checkpoints=\"valid_loss\", # Save the best checkpoint, monitoring valid_loss\n",
+ " load_best=True, # After training, load the best checkpoint\n",
+ "# accelerator=\"cpu\"\n",
+ ")\n",
+ "optimizer_config = OptimizerConfig()\n",
+ "\n",
+ "head_config = LinearHeadConfig(\n",
+ " layers=\"\", # No additional layer in head, just a mapping layer to output_dim\n",
+ " dropout=0.1,\n",
+ " initialization=\"kaiming\"\n",
+ ").__dict__ # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)\n",
+ "\n",
+ "model_config = CategoryEmbeddingModelConfig(\n",
+ " task=\"classification\",\n",
+ " layers=\"1024-512-512\", # Number of nodes in each layer\n",
+ " activation=\"LeakyReLU\", # Activation between each layer\n",
+ " head = \"LinearHead\", #Linear Head\n",
+ " head_config = head_config, # Linear Head Config\n",
+ " learning_rate = 1e-3,\n",
+ " metrics=[\"f1_score\",\"accuracy\",\"auroc\"], \n",
+ " metrics_prob_input=[True, False, True] # f1_score and auroc are given probability scores, while accuracy is not\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Data Config For Each Target"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data_config_first_target = DataConfig(\n",
+ " target=['target'], #target should always be a list\n",
+ " continuous_cols=num_col_names,\n",
+ " categorical_cols=cat_col_names\n",
+ ")\n",
+ "\n",
+ "data_config_second_target = DataConfig(\n",
+ " target=['second_target'], #target should always be a list\n",
+ " continuous_cols=num_col_names,\n",
+ " categorical_cols=cat_col_names\n",
+ ")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {
+ "Collapsed": "false"
+ },
+ "source": [
+ "## Training the Single-Target Model "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "Collapsed": "false",
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
2024-07-10 11:38:02,110 - {pytorch_tabular.tabular_model:140} - INFO - Experiment Tracking is turned off \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m110\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Seed set to 42\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:02,126 - {pytorch_tabular.tabular_model:541} - INFO - Preparing the DataLoaders \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m126\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:02,130 - {pytorch_tabular.tabular_datamodule:507} - INFO - Setting up the datamodule for \n",
+ "classification task \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m130\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for \n",
+ "classification task \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:02,150 - {pytorch_tabular.tabular_model:591} - INFO - Preparing the Model: CategoryEmbeddingModel \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m150\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: CategoryEmbeddingModel \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:02,190 - {pytorch_tabular.tabular_model:338} - INFO - Preparing the Trainer \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m190\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (mps), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:02,366 - {pytorch_tabular.tabular_model:647} - INFO - Auto LR Find Started \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m366\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d2cae94714354ea09299a9e44f238d89",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Finding best initial lr: 0%| | 0/100 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LR finder stopped early after 92 steps due to diverging loss.\n",
+ "Learning rate set to 0.025118864315095822\n",
+ "Restoring states from the checkpoint path at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_744b05b3-1568-4cf0-97ac-eaefc0c795f4.ckpt\n",
+ "Restored all states from the checkpoint at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_744b05b3-1568-4cf0-97ac-eaefc0c795f4.ckpt\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:10,876 - {pytorch_tabular.tabular_model:660} - INFO - Suggested LR: 0.025118864315095822. For plot\n",
+ "and detailed analysis, use `find_learning_rate` method. \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:10\u001b[0m,\u001b[1;36m876\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.025118864315095822\u001b[0m. For plot\n",
+ "and detailed analysis, use `find_learning_rate` method. \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:10,881 - {pytorch_tabular.tabular_model:669} - INFO - Training Started \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:10\u001b[0m,\u001b[1;36m881\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+ "┃ ┃ Name ┃ Type ┃ Params ┃\n",
+ "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+ "│ 0 │ _backbone │ CategoryEmbeddingBackbone │ 804 K │\n",
+ "│ 1 │ _embedding_layer │ Embedding1dLayer │ 68 │\n",
+ "│ 2 │ head │ LinearHead │ 1.0 K │\n",
+ "│ 3 │ loss │ CrossEntropyLoss │ 0 │\n",
+ "└───┴──────────────────┴───────────────────────────┴────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+ "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
+ "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+ "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone │ CategoryEmbeddingBackbone │ 804 K │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer │ 68 │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head │ LinearHead │ 1.0 K │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss │ CrossEntropyLoss │ 0 │\n",
+ "└───┴──────────────────┴───────────────────────────┴────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Trainable params: 805 K \n",
+ "Non-trainable params: 0 \n",
+ "Total params: 805 K \n",
+ "Total estimated model params size (MB): 3 \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1mTrainable params\u001b[0m: 805 K \n",
+ "\u001b[1mNon-trainable params\u001b[0m: 0 \n",
+ "\u001b[1mTotal params\u001b[0m: 805 K \n",
+ "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 3 \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d67a105c98184698acff2a3caab52b74",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:29,013 - {pytorch_tabular.tabular_model:680} - INFO - Training the model completed \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:29\u001b[0m,\u001b[1;36m013\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:29,014 - {pytorch_tabular.tabular_model:1512} - INFO - Loading the best model \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:29\u001b[0m,\u001b[1;36m014\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "543ac972faed487c9fa81364d979f3f7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ Test metric ┃ DataLoader 0 ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ test_accuracy │ 0.9484000205993652 │\n",
+ "│ test_auroc │ 0.9717501997947693 │\n",
+ "│ test_f1_score │ 0.9484000205993652 │\n",
+ "│ test_loss │ 0.17070142924785614 │\n",
+ "│ test_loss_0 │ 0.17070142924785614 │\n",
+ "└───────────────────────────┴───────────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9484000205993652 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_auroc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9717501997947693 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_f1_score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9484000205993652 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.17070142924785614 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_loss_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.17070142924785614 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "└───────────────────────────┴───────────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "tabular_model = TabularModel(\n",
+ " data_config=data_config_first_target,\n",
+ " model_config=model_config,\n",
+ " optimizer_config=optimizer_config,\n",
+ " trainer_config=trainer_config,\n",
+ ")\n",
+ "\n",
+ "tabular_model.fit(train=train, validation=val)\n",
+ "\n",
+ "result = tabular_model.evaluate(test)\n",
+ "\n",
+ "result = {k: float(v) for k,v in result[0].items()}\n",
+ "result = pd.DataFrame({'f1':result['test_f1_score'],'auroc':result['test_auroc']},\n",
+ " index=['1st Target (single mode)'])\n",
+ "results.append(result)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Now train separately on the 2nd target"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:30,958 - {pytorch_tabular.tabular_model:140} - INFO - Experiment Tracking is turned off \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m958\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Seed set to 42\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:30,972 - {pytorch_tabular.tabular_model:541} - INFO - Preparing the DataLoaders \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m972\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:30,976 - {pytorch_tabular.tabular_datamodule:507} - INFO - Setting up the datamodule for \n",
+ "classification task \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m976\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for \n",
+ "classification task \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:30,996 - {pytorch_tabular.tabular_model:591} - INFO - Preparing the Model: CategoryEmbeddingModel \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m996\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: CategoryEmbeddingModel \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:31,035 - {pytorch_tabular.tabular_model:338} - INFO - Preparing the Trainer \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:31\u001b[0m,\u001b[1;36m035\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (mps), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:31,054 - {pytorch_tabular.tabular_model:647} - INFO - Auto LR Find Started \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:31\u001b[0m,\u001b[1;36m054\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ace5c6e2844d4ff58f2550a8e9da9ee4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Finding best initial lr: 0%| | 0/100 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LR finder stopped early after 91 steps due to diverging loss.\n",
+ "Learning rate set to 0.0002511886431509582\n",
+ "Restoring states from the checkpoint path at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_cd3e2f01-7f2a-4103-ac17-9b4b4f8aeed1.ckpt\n",
+ "Restored all states from the checkpoint at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_cd3e2f01-7f2a-4103-ac17-9b4b4f8aeed1.ckpt\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:38,580 - {pytorch_tabular.tabular_model:660} - INFO - Suggested LR: 0.0002511886431509582. For \n",
+ "plot and detailed analysis, use `find_learning_rate` method. \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:38\u001b[0m,\u001b[1;36m580\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.0002511886431509582\u001b[0m. For \n",
+ "plot and detailed analysis, use `find_learning_rate` method. \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:38,584 - {pytorch_tabular.tabular_model:669} - INFO - Training Started \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:38\u001b[0m,\u001b[1;36m584\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+ "┃ ┃ Name ┃ Type ┃ Params ┃\n",
+ "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+ "│ 0 │ _backbone │ CategoryEmbeddingBackbone │ 804 K │\n",
+ "│ 1 │ _embedding_layer │ Embedding1dLayer │ 68 │\n",
+ "│ 2 │ head │ LinearHead │ 1.5 K │\n",
+ "│ 3 │ loss │ CrossEntropyLoss │ 0 │\n",
+ "└───┴──────────────────┴───────────────────────────┴────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+ "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
+ "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+ "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone │ CategoryEmbeddingBackbone │ 804 K │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer │ 68 │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head │ LinearHead │ 1.5 K │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss │ CrossEntropyLoss │ 0 │\n",
+ "└───┴──────────────────┴───────────────────────────┴────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Trainable params: 806 K \n",
+ "Non-trainable params: 0 \n",
+ "Total params: 806 K \n",
+ "Total estimated model params size (MB): 3 \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1mTrainable params\u001b[0m: 806 K \n",
+ "\u001b[1mNon-trainable params\u001b[0m: 0 \n",
+ "\u001b[1mTotal params\u001b[0m: 806 K \n",
+ "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 3 \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "483e9eeb0b754be9b966618d11ab2baa",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:54,807 - {pytorch_tabular.tabular_model:680} - INFO - Training the model completed \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:54\u001b[0m,\u001b[1;36m807\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:54,808 - {pytorch_tabular.tabular_model:1512} - INFO - Loading the best model \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:54\u001b[0m,\u001b[1;36m808\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0fad2da095be4448ae4507d193f28c4a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ Test metric ┃ DataLoader 0 ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ test_accuracy │ 0.6808000206947327 │\n",
+ "│ test_auroc │ 0.8134375810623169 │\n",
+ "│ test_f1_score │ 0.6808000206947327 │\n",
+ "│ test_loss │ 0.8243984580039978 │\n",
+ "│ test_loss_0 │ 0.8243984580039978 │\n",
+ "└───────────────────────────┴───────────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6808000206947327 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_auroc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8134375810623169 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_f1_score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6808000206947327 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8243984580039978 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_loss_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8243984580039978 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "└───────────────────────────┴───────────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "tabular_model = TabularModel(\n",
+ " data_config=data_config_second_target,\n",
+ " model_config=model_config,\n",
+ " optimizer_config=optimizer_config,\n",
+ " trainer_config=trainer_config,\n",
+ ")\n",
+ "\n",
+ "tabular_model.fit(train=train, validation=val)\n",
+ "\n",
+ "result = tabular_model.evaluate(test)\n",
+ "\n",
+ "result = {k: float(v) for k,v in result[0].items()}\n",
+ "result = pd.DataFrame({'f1':result['test_f1_score'],'auroc':result['test_auroc']},\n",
+ " index=['2nd Target (single mode)'])\n",
+ "\n",
+ "results.append(result)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Multi-Target Training\n",
+ "\n",
+ "Instead of training one model for the first target, and another model for the second target, we can train a single model that will make a prediction for both targets.\n",
+ "\n",
+ "This is usually beneficial in reducing training time, but may also lead to better results since the model may have a better representation (embedding) by learning from multiple targets.\n",
+ "\n",
+ "To perform multi-target training, we only need to modify the 'target' field in the data_config to include a list of multiple targets.\n",
+ "\n",
+ "Each metric (f1, AU-ROC, etc.) is reported as the sum over all targets, as well as separately for each target with the suffix '_n' (starting at n=0); for example, the f1 score for the 2nd target is test_f1_score_1."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:57,010 - {pytorch_tabular.tabular_model:140} - INFO - Experiment Tracking is turned off \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m010\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Seed set to 42\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:57,023 - {pytorch_tabular.tabular_model:541} - INFO - Preparing the DataLoaders \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m023\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:57,028 - {pytorch_tabular.tabular_datamodule:507} - INFO - Setting up the datamodule for \n",
+ "classification task \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m028\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for \n",
+ "classification task \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:57,049 - {pytorch_tabular.tabular_model:591} - INFO - Preparing the Model: CategoryEmbeddingModel \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m049\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: CategoryEmbeddingModel \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:57,089 - {pytorch_tabular.tabular_model:338} - INFO - Preparing the Trainer \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m089\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (mps), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:38:57,115 - {pytorch_tabular.tabular_model:647} - INFO - Auto LR Find Started \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m115\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c952afd65ca84628a3d1142ddc2e936e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Finding best initial lr: 0%| | 0/100 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LR finder stopped early after 94 steps due to diverging loss.\n",
+ "Learning rate set to 4.786300923226385e-05\n",
+ "Restoring states from the checkpoint path at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_90994562-d7c8-47ac-974c-b9c85a34479e.ckpt\n",
+ "Restored all states from the checkpoint at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_90994562-d7c8-47ac-974c-b9c85a34479e.ckpt\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:39:06,162 - {pytorch_tabular.tabular_model:660} - INFO - Suggested LR: 4.786300923226385e-05. For \n",
+ "plot and detailed analysis, use `find_learning_rate` method. \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:39:06\u001b[0m,\u001b[1;36m162\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m4.786300923226385e-05\u001b[0m. For \n",
+ "plot and detailed analysis, use `find_learning_rate` method. \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:39:06,166 - {pytorch_tabular.tabular_model:669} - INFO - Training Started \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:39:06\u001b[0m,\u001b[1;36m166\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+ "┃ ┃ Name ┃ Type ┃ Params ┃\n",
+ "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+ "│ 0 │ _backbone │ CategoryEmbeddingBackbone │ 804 K │\n",
+ "│ 1 │ _embedding_layer │ Embedding1dLayer │ 68 │\n",
+ "│ 2 │ head │ LinearHead │ 2.6 K │\n",
+ "│ 3 │ loss │ CrossEntropyLoss │ 0 │\n",
+ "└───┴──────────────────┴───────────────────────────┴────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+ "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
+ "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+ "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone │ CategoryEmbeddingBackbone │ 804 K │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer │ 68 │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head │ LinearHead │ 2.6 K │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss │ CrossEntropyLoss │ 0 │\n",
+ "└───┴──────────────────┴───────────────────────────┴────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Trainable params: 807 K \n",
+ "Non-trainable params: 0 \n",
+ "Total params: 807 K \n",
+ "Total estimated model params size (MB): 3 \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1mTrainable params\u001b[0m: 807 K \n",
+ "\u001b[1mNon-trainable params\u001b[0m: 0 \n",
+ "\u001b[1mTotal params\u001b[0m: 807 K \n",
+ "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 3 \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "49d66315050e49ed94823d49ea1ff8c0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "`Trainer.fit` stopped: `max_epochs=100` reached.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:07,357 - {pytorch_tabular.tabular_model:680} - INFO - Training the model completed \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:07\u001b[0m,\u001b[1;36m357\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:07,363 - {pytorch_tabular.tabular_model:1512} - INFO - Loading the best model \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:07\u001b[0m,\u001b[1;36m363\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b06cba39db194aaea56bc99fd05c334b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ Test metric ┃ DataLoader 0 ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ test_accuracy │ 1.5648000240325928 │\n",
+ "│ test_accuracy_0 │ 0.946399986743927 │\n",
+ "│ test_accuracy_1 │ 0.618399977684021 │\n",
+ "│ test_auroc │ 1.7456306219100952 │\n",
+ "│ test_auroc_0 │ 0.9698674082756042 │\n",
+ "│ test_auroc_1 │ 0.7757631540298462 │\n",
+ "│ test_f1_score │ 1.5648000240325928 │\n",
+ "│ test_f1_score_0 │ 0.946399986743927 │\n",
+ "│ test_f1_score_1 │ 0.618399977684021 │\n",
+ "│ test_loss │ 1.0632164478302002 │\n",
+ "│ test_loss_0 │ 0.16491228342056274 │\n",
+ "│ test_loss_1 │ 0.8983041048049927 │\n",
+ "└───────────────────────────┴───────────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.5648000240325928 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.946399986743927 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.618399977684021 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_auroc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.7456306219100952 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_auroc_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9698674082756042 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_auroc_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7757631540298462 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_f1_score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.5648000240325928 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_f1_score_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.946399986743927 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_f1_score_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.618399977684021 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0632164478302002 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_loss_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.16491228342056274 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_loss_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8983041048049927 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "└───────────────────────────┴───────────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "data_config_multi = DataConfig(\n",
+ " target=['target','second_target'], #target should always be a list\n",
+ " continuous_cols=num_col_names,\n",
+ " categorical_cols=cat_col_names\n",
+ ")\n",
+ "\n",
+ "tabular_model = TabularModel(\n",
+ " data_config=data_config_multi,\n",
+ " model_config=model_config,\n",
+ " optimizer_config=optimizer_config,\n",
+ " trainer_config=trainer_config,\n",
+ ")\n",
+ "\n",
+ "tabular_model.fit(train=train, validation=val)\n",
+ "\n",
+ "result = tabular_model.evaluate(test)\n",
+ "\n",
+ "result = {k: float(v) for k,v in result[0].items()}\n",
+ "result1 = pd.DataFrame({'f1':result['test_f1_score_0'],'auroc':result['test_auroc_0']},\n",
+ " index=['1st Target (multi-target mode)'])\n",
+ "results.append(result1)\n",
+ "result2 = pd.DataFrame({'f1':result['test_f1_score_1'],'auroc':result['test_auroc_1']},\n",
+ " index=['2nd Target (multi-target mode)'])\n",
+ "results.append(result2)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ " | | \n",
+ " f1 | \n",
+ " auroc | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 1st Target (single mode) | \n",
+ " 0.948400 | \n",
+ " 0.971750 | \n",
+ "
\n",
+ " \n",
+ " | 2nd Target (single mode) | \n",
+ " 0.680800 | \n",
+ " 0.813438 | \n",
+ "
\n",
+ " \n",
+ " | 1st Target (multi-target mode) | \n",
+ " 0.946400 | \n",
+ " 0.969867 | \n",
+ "
\n",
+ " \n",
+ " | 2nd Target (multi-target mode) | \n",
+ " 0.618400 | \n",
+ " 0.775763 | \n",
+ "
\n",
+ " \n",
+ "
\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "res_df = pd.concat(results)\n",
+ "res_df.style.highlight_max(color=\"lightgreen\",axis=0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this run, we see that the multi-target model performed on par with the single-target model on the 1st target, but slightly worse on the 2nd target. This may vary for this artificial dataset depending on random number generation, and in general multi-target models may not outperform their single-target variants. Additional tuning may be needed for multi-target training, e.g. a larger embedding size."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Deep Learning Model\n",
+ "\n",
+ "We can also test whether a deeper model benefits from the shared embedding. In this case, we test GANDALF with multi-target classification, similar to our experiment above.\n",
+ "\n",
+ "Note: Without an accelerator (GPU), this training will take considerable time on a CPU."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pytorch_tabular.models import GANDALFConfig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:09,147 - {pytorch_tabular.tabular_model:140} - INFO - Experiment Tracking is turned off \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m147\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Seed set to 42\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:09,160 - {pytorch_tabular.tabular_model:541} - INFO - Preparing the DataLoaders \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m160\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:09,165 - {pytorch_tabular.tabular_datamodule:507} - INFO - Setting up the datamodule for \n",
+ "classification task \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m165\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for \n",
+ "classification task \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:09,188 - {pytorch_tabular.tabular_model:591} - INFO - Preparing the Model: GANDALFModel \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m188\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: GANDALFModel \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:09,212 - {pytorch_tabular.tabular_model:338} - INFO - Preparing the Trainer \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m212\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (mps), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:09,240 - {pytorch_tabular.tabular_model:647} - INFO - Auto LR Find Started \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m240\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3d222f1be7284f9aae6216f29e17ea4b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Finding best initial lr: 0%| | 0/100 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "`Trainer.fit` stopped: `max_steps=100` reached.\n",
+ "Learning rate set to 0.10964781961431852\n",
+ "Restoring states from the checkpoint path at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_92ae2770-372c-495e-9410-7ff7723d419e.ckpt\n",
+ "Restored all states from the checkpoint at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_92ae2770-372c-495e-9410-7ff7723d419e.ckpt\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:22,112 - {pytorch_tabular.tabular_model:660} - INFO - Suggested LR: 0.10964781961431852. For plot \n",
+ "and detailed analysis, use `find_learning_rate` method. \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:22\u001b[0m,\u001b[1;36m112\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.10964781961431852\u001b[0m. For plot \n",
+ "and detailed analysis, use `find_learning_rate` method. \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:22,115 - {pytorch_tabular.tabular_model:669} - INFO - Training Started \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:22\u001b[0m,\u001b[1;36m115\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+ "┃ ┃ Name ┃ Type ┃ Params ┃\n",
+ "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+ "│ 0 │ _backbone │ GANDALFBackbone │ 9.6 K │\n",
+ "│ 1 │ _embedding_layer │ Embedding1dLayer │ 68 │\n",
+ "│ 2 │ _head │ Sequential │ 90 │\n",
+ "│ 3 │ loss │ CrossEntropyLoss │ 0 │\n",
+ "└───┴──────────────────┴──────────────────┴────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+ "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
+ "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+ "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone │ GANDALFBackbone │ 9.6 K │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer │ 68 │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _head │ Sequential │ 90 │\n",
+ "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss │ CrossEntropyLoss │ 0 │\n",
+ "└───┴──────────────────┴──────────────────┴────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Trainable params: 9.8 K \n",
+ "Non-trainable params: 0 \n",
+ "Total params: 9.8 K \n",
+ "Total estimated model params size (MB): 0 \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1mTrainable params\u001b[0m: 9.8 K \n",
+ "\u001b[1mNon-trainable params\u001b[0m: 0 \n",
+ "\u001b[1mTotal params\u001b[0m: 9.8 K \n",
+ "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0 \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "833c09b04f5a4841a69215c4161a94ec",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:28,324 - {pytorch_tabular.tabular_model:680} - INFO - Training the model completed \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:28\u001b[0m,\u001b[1;36m324\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "2024-07-10 11:40:28,326 - {pytorch_tabular.tabular_model:1512} - INFO - Loading the best model \n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:28\u001b[0m,\u001b[1;36m326\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model \n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8db6da2e21c24377a31514cc5a245919",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ Test metric ┃ DataLoader 0 ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ test_accuracy │ 1.210800051689148 │\n",
+ "│ test_accuracy_0 │ 0.9120000004768372 │\n",
+ "│ test_accuracy_1 │ 0.298799991607666 │\n",
+ "│ test_auroc │ 1.4514379501342773 │\n",
+ "│ test_auroc_0 │ 0.9510558843612671 │\n",
+ "│ test_auroc_1 │ 0.500382125377655 │\n",
+ "│ test_f1_score │ 1.210800051689148 │\n",
+ "│ test_f1_score_0 │ 0.9120000004768372 │\n",
+ "│ test_f1_score_1 │ 0.298799991607666 │\n",
+ "│ test_loss │ 1.3304287195205688 │\n",
+ "│ test_loss_0 │ 0.22366906702518463 │\n",
+ "│ test_loss_1 │ 1.106759786605835 │\n",
+ "└───────────────────────────┴───────────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.210800051689148 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9120000004768372 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.298799991607666 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_auroc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.4514379501342773 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_auroc_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9510558843612671 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_auroc_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.500382125377655 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_f1_score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.210800051689148 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_f1_score_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9120000004768372 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_f1_score_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.298799991607666 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.3304287195205688 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_loss_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.22366906702518463 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_loss_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.106759786605835 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "└───────────────────────────┴───────────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " f1 | \n",
+ " auroc | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 1st Target (single mode) | \n",
+ " 0.9484 | \n",
+ " 0.971750 | \n",
+ "
\n",
+ " \n",
+ " | 2nd Target (single mode) | \n",
+ " 0.6808 | \n",
+ " 0.813438 | \n",
+ "
\n",
+ " \n",
+ " | 1st Target (multi-target mode) | \n",
+ " 0.9464 | \n",
+ " 0.969867 | \n",
+ "
\n",
+ " \n",
+ " | 2nd Target (multi-target mode) | \n",
+ " 0.6184 | \n",
+ " 0.775763 | \n",
+ "
\n",
+ " \n",
+ " | 1st Target (Gandalf, multi-target mode) | \n",
+ " 0.9120 | \n",
+ " 0.951056 | \n",
+ "
\n",
+ " \n",
+ " | 2nd Target (Gandalf, multi-target mode) | \n",
+ " 0.2988 | \n",
+ " 0.500382 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " f1 auroc\n",
+ "1st Target (single mode) 0.9484 0.971750\n",
+ "2nd Target (single mode) 0.6808 0.813438\n",
+ "1st Target (multi-target mode) 0.9464 0.969867\n",
+ "2nd Target (multi-target mode) 0.6184 0.775763\n",
+ "1st Target (Gandalf, multi-target mode) 0.9120 0.951056\n",
+ "2nd Target (Gandalf, multi-target mode) 0.2988 0.500382"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data_config_2nd = DataConfig(\n",
+ " target=['target', 'second_target'],\n",
+ " continuous_cols=num_col_names,\n",
+ " categorical_cols=cat_col_names,\n",
+ ")\n",
+ "\n",
+ "trainer_gl_config = TrainerConfig(\n",
+ " auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate\n",
+ " batch_size=1024,\n",
+ " max_epochs=50,\n",
+ " early_stopping=\"valid_loss\", # Monitor valid_loss for early stopping\n",
+ " early_stopping_mode = \"min\", # Set the mode as min because for val_loss, lower is better\n",
+ " early_stopping_patience=5, # No. of epochs of degradation training will wait before terminating\n",
+ " checkpoints=\"valid_loss\", # Save best checkpoint monitoring val_loss\n",
+ " load_best=True, # After training, load the best checkpoint\n",
+ " # accelerator=\"cpu\"\n",
+ ")\n",
+ "\n",
+ "model_config_gandalf = GANDALFConfig(task='classification',\n",
+ " metrics=[\"f1_score\",\"accuracy\",\"auroc\"], \n",
+ " metrics_prob_input=[True, False, True], # f1_score needs probability scores, while accuracy doesn't,\n",
+ " gflu_stages=6, gflu_dropout=0.2\n",
+ ")\n",
+ "\n",
+ "\n",
+ "tabular_model = TabularModel(\n",
+ " data_config=data_config_2nd,\n",
+ " model_config=model_config_gandalf,\n",
+ " optimizer_config=optimizer_config,\n",
+ " trainer_config=trainer_gl_config,\n",
+ ")\n",
+ "\n",
+ "tabular_model.fit(train=train, validation=val)\n",
+ "\n",
+ "result = tabular_model.evaluate(test)\n",
+ "\n",
+ "result = {k: float(v) for k,v in result[0].items()}\n",
+ "result1 = pd.DataFrame({'f1':result['test_f1_score_0'],'auroc':result['test_auroc_0']},\n",
+ " index=['1st Target (Gandalf, multi-target mode)'])\n",
+ "results.append(result1)\n",
+ "result2 = pd.DataFrame({'f1':result['test_f1_score_1'],'auroc':result['test_auroc_1']},\n",
+ " index=['2nd Target (Gandalf, multi-target mode)'])\n",
+ "results.append(result2)\n",
+ "\n",
+ "pd.concat(results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Results Summary\n",
+ "\n",
+ "Here we see that the deeper Gandalf model in multi-target mode performed worse than either variant of the Category Embedding Model; on the second target its AUROC (0.50) is no better than chance. As before, additional hyperparameter tuning may be required for a fair comparison."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.18"
+ },
+ "varInspector": {
+ "cols": {
+ "lenName": 16,
+ "lenType": 16,
+ "lenVar": 40
+ },
+ "kernels_config": {
+ "python": {
+ "delete_cmd_postfix": "",
+ "delete_cmd_prefix": "del ",
+ "library": "var_list.py",
+ "varRefreshCmd": "print(var_dic_list())"
+ },
+ "r": {
+ "delete_cmd_postfix": ") ",
+ "delete_cmd_prefix": "rm(",
+ "library": "var_list.r",
+ "varRefreshCmd": "cat(var_dic_list()) "
+ }
+ },
+ "types_to_exclude": [
+ "module",
+ "function",
+ "builtin_function_or_method",
+ "instance",
+ "_Feature"
+ ],
+ "window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/examples/__only_for_dev__/adhoc_scaffold.py b/examples/__only_for_dev__/adhoc_scaffold.py
index d028827f..efedb4af 100644
--- a/examples/__only_for_dev__/adhoc_scaffold.py
+++ b/examples/__only_for_dev__/adhoc_scaffold.py
@@ -53,8 +53,7 @@ def print_metrics(y_true, y_pred, tag):
from pytorch_tabular.models import GatedAdditiveTreeEnsembleConfig # noqa: E402
data_config = DataConfig(
- # target should always be a list. Multi-targets are only supported for regression.
- # Multi-Task Classification is not implemented
+ # target should always be a list.
target=["target"],
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
diff --git a/examples/__only_for_dev__/to_test_dae.py b/examples/__only_for_dev__/to_test_dae.py
index c00a5c1f..5d91d125 100644
--- a/examples/__only_for_dev__/to_test_dae.py
+++ b/examples/__only_for_dev__/to_test_dae.py
@@ -145,8 +145,7 @@ def print_metrics(y_true, y_pred, tag):
lr = 1e-3
data_config = DataConfig(
- # target should always be a list. Multi-targets are only supported for regression.
- # Multi-Task Classification is not implemented
+ # target should always be a list.
target=[target_name],
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py
index 7df8180f..b0a4af35 100644
--- a/src/pytorch_tabular/config/config.py
+++ b/src/pytorch_tabular/config/config.py
@@ -197,6 +197,8 @@ class InferredConfig:
output_dim (Optional[int]): The number of output targets
+ output_cardinality (Optional[List[int]]): The number of unique values in each classification target
+
categorical_cardinality (Optional[List[int]]): The number of unique values in categorical features
embedding_dims (Optional[List]): The dimensions of the embedding for each categorical column as a
@@ -216,6 +218,10 @@ class InferredConfig:
default=None,
metadata={"help": "The number of output targets"},
)
+ output_cardinality: Optional[List[int]] = field(
+ default=None,
+ metadata={"help": "The number of unique values in classification output"},
+ )
categorical_cardinality: Optional[List[int]] = field(
default=None,
metadata={"help": "The number of unique values in categorical features"},
diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py
index 1328d34c..12b5518c 100644
--- a/src/pytorch_tabular/models/base_model.py
+++ b/src/pytorch_tabular/models/base_model.py
@@ -122,23 +122,43 @@ def __init__(
config.metrics_params.append(vars(metric))
if config.task == "classification":
config.metrics_prob_input = self.custom_metrics_prob_inputs
+ for i, mp in enumerate(config.metrics_params):
+ mp.sub_params_list = []
+ for j, num_classes in enumerate(inferred_config.output_cardinality):
+ config.metrics_params[i].sub_params_list.append(
+ OmegaConf.create(
+ {
+ "task": mp.get("task", "multiclass"),
+ "num_classes": mp.get("num_classes", num_classes),
+ }
+ )
+ )
+
# Updating default metrics in config
elif config.task == "classification":
# Adding metric_params to config for classification task
for i, mp in enumerate(config.metrics_params):
- # For classification task, output_dim == number of classses
- config.metrics_params[i]["task"] = mp.get("task", "multiclass")
- config.metrics_params[i]["num_classes"] = mp.get("num_classes", inferred_config.output_dim)
- if config.metrics[i] in (
- "accuracy",
- "precision",
- "recall",
- "precision_recall",
- "specificity",
- "f1_score",
- "fbeta_score",
- ):
- config.metrics_params[i]["top_k"] = mp.get("top_k", 1)
+ mp.sub_params_list = []
+ for j, num_classes in enumerate(inferred_config.output_cardinality):
+ # config.metrics_params[i][j]["task"] = mp.get("task", "multiclass")
+ # config.metrics_params[i][j]["num_classes"] = mp.get("num_classes", num_classes)
+
+ config.metrics_params[i].sub_params_list.append(
+ OmegaConf.create(
+ {"task": mp.get("task", "multiclass"), "num_classes": mp.get("num_classes", num_classes)}
+ )
+ )
+
+ if config.metrics[i] in (
+ "accuracy",
+ "precision",
+ "recall",
+ "precision_recall",
+ "specificity",
+ "f1_score",
+ "fbeta_score",
+ ):
+ config.metrics_params[i].sub_params_list[j]["top_k"] = mp.get("top_k", 1)
if self.custom_optimizer is not None:
config.optimizer = str(self.custom_optimizer.__class__.__name__)
@@ -267,7 +287,22 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
)
else:
# TODO loss fails with batch size of 1?
- computed_loss = self.loss(y_hat.squeeze(), y.squeeze()) + reg_loss
+ computed_loss = reg_loss
+ start_index = 0
+ for i in range(len(self.hparams.output_cardinality)):
+ end_index = start_index + self.hparams.output_cardinality[i]
+ _loss = self.loss(y_hat[:, start_index:end_index], y[:, i])
+ computed_loss += _loss
+ if self.hparams.output_dim > 1:
+ self.log(
+ f"{tag}_loss_{i}",
+ _loss,
+ on_epoch=True,
+ on_step=False,
+ logger=True,
+ prog_bar=False,
+ )
+ start_index = end_index
self.log(
f"{tag}_loss",
computed_loss,
@@ -325,11 +360,29 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
_metrics.append(_metric)
avg_metric = torch.stack(_metrics, dim=0).sum()
else:
- y_hat = nn.Softmax(dim=-1)(y_hat.squeeze())
- if prob_inp:
- avg_metric = metric(y_hat, y.squeeze(), **metric_params)
- else:
- avg_metric = metric(torch.argmax(y_hat, dim=-1), y.squeeze(), **metric_params)
+ _metrics = []
+ start_index = 0
+ for i, cardinality in enumerate(self.hparams.output_cardinality):
+ end_index = start_index + cardinality
+ y_hat_i = nn.Softmax(dim=-1)(y_hat[:, start_index:end_index].squeeze())
+ if prob_inp:
+ _metric = metric(y_hat_i, y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i])
+ else:
+ _metric = metric(
+ torch.argmax(y_hat_i, dim=-1), y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i]
+ )
+ if len(self.hparams.output_cardinality) > 1:
+ self.log(
+ f"{tag}_{metric_str}_{i}",
+ _metric,
+ on_epoch=True,
+ on_step=False,
+ logger=True,
+ prog_bar=False,
+ )
+ _metrics.append(_metric)
+ start_index = end_index
+ avg_metric = torch.stack(_metrics, dim=0).sum()
metrics.append(avg_metric)
self.log(
f"{tag}_{metric_str}",
diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py
index 917bc931..50849ae4 100644
--- a/src/pytorch_tabular/tabular_datamodule.py
+++ b/src/pytorch_tabular/tabular_datamodule.py
@@ -282,13 +282,21 @@ def _update_config(self, config) -> InferredConfig:
if config.task == "regression":
# self._output_dim_reg = len(config.target) if config.target else None if self.train is not None:
output_dim = len(config.target) if config.target else None
+ output_cardinality = None
elif config.task == "classification":
# self._output_dim_clf = len(np.unique(self.train_dataset.y)) if config.target else None
if self.train is not None:
- output_dim = len(np.unique(self.train[config.target[0]])) if config.target else None
+ output_cardinality = (
+ self.train[config.target].fillna("NA").nunique().tolist() if config.target else None
+ )
+ output_dim = sum(output_cardinality)
else:
- output_dim = len(np.unique(self.train_dataset.y)) if config.target else None
+ output_cardinality = (
+ self.train_dataset.data[config.target].fillna("NA").nunique().tolist() if config.target else None
+ )
+ output_dim = sum(output_cardinality)
elif config.task == "ssl":
+ output_cardinality = None
output_dim = None
else:
raise ValueError(f"{config.task} is an unsupported task.")
@@ -308,6 +316,7 @@ def _update_config(self, config) -> InferredConfig:
categorical_dim=categorical_dim,
continuous_dim=continuous_dim,
output_dim=output_dim,
+ output_cardinality=output_cardinality,
categorical_cardinality=categorical_cardinality,
embedding_dims=embedding_dims,
)
@@ -381,11 +390,14 @@ def _label_encode_target(self, data: DataFrame, stage: str) -> DataFrame:
if self.config.task != "classification":
return data
if stage == "fit" or self.label_encoder is None:
- self.label_encoder = LabelEncoder()
- data[self.config.target[0]] = self.label_encoder.fit_transform(data[self.config.target[0]])
+ self.label_encoder = [None] * len(self.config.target)
+ for i in range(len(self.config.target)):
+ self.label_encoder[i] = LabelEncoder()
+ data[self.config.target[i]] = self.label_encoder[i].fit_transform(data[self.config.target[i]])
else:
- if self.config.target[0] in data.columns:
- data[self.config.target[0]] = self.label_encoder.transform(data[self.config.target[0]])
+ for i in range(len(self.config.target)):
+ if self.config.target[i] in data.columns:
+ data[self.config.target[i]] = self.label_encoder[i].transform(data[self.config.target[i]])
return data
def _target_transform(self, data: DataFrame, stage: str) -> DataFrame:
@@ -818,7 +830,8 @@ def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
# TODO Is the target encoding necessary?
if len(set(self.target) - set(df.columns)) > 0:
if self.config.task == "classification":
- df.loc[:, self.target] = np.array([self.label_encoder.classes_[0]] * len(df)).reshape(-1, 1)
+ for i in range(len(self.target)):
+ df.loc[:, self.target[i]] = np.array([self.label_encoder[i].classes_[0]] * len(df)).reshape(-1, 1)
else:
df.loc[:, self.target] = np.zeros((len(df), len(self.target)))
df, _ = self.preprocess_data(df, stage="inference")
diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py
index 3f902a7d..11234934 100644
--- a/src/pytorch_tabular/tabular_model.py
+++ b/src/pytorch_tabular/tabular_model.py
@@ -211,9 +211,6 @@ def num_params(self):
def _run_validation(self):
"""Validates the Config params and throws errors if something is wrong."""
- if self.config.task == "classification":
- if len(self.config.target) > 1:
- raise NotImplementedError("Multi-Target Classification is not implemented.")
if self.config.task == "regression":
if self.config.target_range is not None:
if (
@@ -1291,12 +1288,16 @@ def _format_predicitons(
pred_df[f"{target_col}_q{int(q*100)}"] = quantile_predictions[:, j, i].reshape(-1, 1)
elif self.config.task == "classification":
- point_predictions = nn.Softmax(dim=-1)(point_predictions).numpy()
- for i, class_ in enumerate(self.datamodule.label_encoder.classes_):
- pred_df[f"{class_}_probability"] = point_predictions[:, i]
- pred_df["prediction"] = self.datamodule.label_encoder.inverse_transform(
- np.argmax(point_predictions, axis=1)
- )
+ start_index = 0
+ for i, target_col in enumerate(self.config.target):
+ end_index = start_index + self.datamodule._inferred_config.output_cardinality[i]
+ prob_prediction = nn.Softmax(dim=-1)(point_predictions[:, start_index:end_index]).numpy()
+ start_index = end_index
+ for j, class_ in enumerate(self.datamodule.label_encoder[i].classes_):
+ pred_df[f"{target_col}_{class_}_probability"] = prob_prediction[:, j]
+ pred_df[f"{target_col}_prediction"] = self.datamodule.label_encoder[i].inverse_transform(
+ np.argmax(prob_prediction, axis=1)
+ )
warnings.warn(
"Classification prediction column will be renamed to"
" `{target_col}_prediction` in the next release to maintain"
@@ -2046,15 +2047,12 @@ def _combine_predictions(
elif callable(aggregate):
bagged_pred = aggregate(pred_prob_l)
if self.config.task == "classification":
- classes = self.datamodule.label_encoder.classes_
+ # FIXME: only the first target's label encoder is used here; multi-target needs to iterate over .label_encoder[i]
+ classes = self.datamodule.label_encoder[0].classes_
if aggregate == "hard_voting":
pred_df = pd.DataFrame(
np.concatenate(pred_prob_l, axis=1),
- columns=[
- f"{c}_probability_fold_{i}"
- for i in range(len(pred_prob_l))
- for c in self.datamodule.label_encoder.classes_
- ],
+ columns=[f"{c}_probability_fold_{i}" for i in range(len(pred_prob_l)) for c in classes],
index=pred_idx,
)
pred_df["prediction"] = classes[final_pred]
@@ -2062,7 +2060,8 @@ def _combine_predictions(
final_pred = classes[np.argmax(bagged_pred, axis=1)]
pred_df = pd.DataFrame(
bagged_pred,
- columns=[f"{c}_probability" for c in self.datamodule.label_encoder.classes_],
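+ # FIXME: column names use only the first target's classes (label_encoder[0]); multi-target is not handled here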
+ columns=[f"{c}_probability" for c in self.datamodule.label_encoder[0].classes_],
index=pred_idx,
)
pred_df["prediction"] = final_pred
diff --git a/tests/test_autoint.py b/tests/test_autoint.py
index 025b0ed6..6f4b90b7 100644
--- a/tests/test_autoint.py
+++ b/tests/test_autoint.py
@@ -78,6 +78,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]
+@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -91,6 +92,7 @@ def test_regression(
@pytest.mark.parametrize("batch_norm_continuous_input", [True, False])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -100,7 +102,7 @@ def test_classification(
):
(train, test, target) = classification_data
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
diff --git a/tests/test_categorical_embedding.py b/tests/test_categorical_embedding.py
index cce097ba..5f0265a2 100644
--- a/tests/test_categorical_embedding.py
+++ b/tests/test_categorical_embedding.py
@@ -124,6 +124,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]
+@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -136,6 +137,7 @@ def test_regression(
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -146,7 +148,7 @@ def test_classification(
return
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
diff --git a/tests/test_common.py b/tests/test_common.py
index 5f7c4922..bd5d428e 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -753,7 +753,7 @@ def test_cross_validate_regression(
[
"accuracy",
None,
- lambda y_true, y_pred: accuracy_score(y_true, y_pred["prediction"].values),
+ lambda y_true, y_pred: accuracy_score(y_true, y_pred["target_prediction"].values),
],
)
@pytest.mark.parametrize("return_oof", [True])
diff --git a/tests/test_danet.py b/tests/test_danet.py
index f65ecd59..ac4e3861 100644
--- a/tests/test_danet.py
+++ b/tests/test_danet.py
@@ -80,6 +80,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]
+@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -91,6 +92,7 @@ def test_regression(
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -98,7 +100,7 @@ def test_classification(
):
(train, test, target) = classification_data
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
diff --git a/tests/test_ft_transformer.py b/tests/test_ft_transformer.py
index 6fc626e3..dac827bd 100644
--- a/tests/test_ft_transformer.py
+++ b/tests/test_ft_transformer.py
@@ -86,6 +86,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]
+@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -97,6 +98,7 @@ def test_regression(
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -104,7 +106,7 @@ def test_classification(
):
(train, test, target) = classification_data
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
diff --git a/tests/test_gandalf.py b/tests/test_gandalf.py
index 7a702d63..374767fe 100644
--- a/tests/test_gandalf.py
+++ b/tests/test_gandalf.py
@@ -80,6 +80,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]
+@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -91,6 +92,7 @@ def test_regression(
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -98,7 +100,7 @@ def test_classification(
):
(train, test, target) = classification_data
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
diff --git a/tests/test_gate.py b/tests/test_gate.py
index 1dc17b8f..c4c2693f 100644
--- a/tests/test_gate.py
+++ b/tests/test_gate.py
@@ -85,6 +85,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]
+@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -96,6 +97,7 @@ def test_regression(
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -103,7 +105,7 @@ def test_classification(
):
(train, test, target) = classification_data
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
diff --git a/tests/test_mdn.py b/tests/test_mdn.py
index ca78a7bc..c0786714 100644
--- a/tests/test_mdn.py
+++ b/tests/test_mdn.py
@@ -76,6 +76,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]
+@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -89,6 +90,7 @@ def test_regression(
@pytest.mark.parametrize("num_gaussian", [1, 2])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -97,7 +99,7 @@ def test_classification(
):
(train, test, target) = classification_data
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
diff --git a/tests/test_node.py b/tests/test_node.py
index cfac47be..ebc4b5fb 100644
--- a/tests/test_node.py
+++ b/tests/test_node.py
@@ -83,6 +83,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]
+@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -94,6 +95,7 @@ def test_regression(
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -101,7 +103,7 @@ def test_classification(
):
(train, test, target) = classification_data
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
diff --git a/tests/test_ssl.py b/tests/test_ssl.py
index 1e4b65dd..76d08a31 100644
--- a/tests/test_ssl.py
+++ b/tests/test_ssl.py
@@ -149,6 +149,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]
+@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -161,6 +162,7 @@ def test_regression(
@pytest.mark.parametrize("freeze_backbone", [False])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -172,7 +174,7 @@ def test_classification(
ssl_train, ssl_val = train_test_split(ssl, random_state=42)
finetune_train, finetune_val = train_test_split(finetune, random_state=42)
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
diff --git a/tests/test_tabnet.py b/tests/test_tabnet.py
index cb135117..1524aa6d 100644
--- a/tests/test_tabnet.py
+++ b/tests/test_tabnet.py
@@ -78,6 +78,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]
+@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[[f"feature_{i}" for i in range(54)]],
@@ -87,6 +88,7 @@ def test_regression(
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -94,7 +96,7 @@ def test_classification(
):
(train, test, target) = classification_data
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
diff --git a/tests/test_tabtransformer.py b/tests/test_tabtransformer.py
index b0b64a93..aa3c644f 100644
--- a/tests/test_tabtransformer.py
+++ b/tests/test_tabtransformer.py
@@ -84,6 +84,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]
+@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -95,6 +96,7 @@ def test_regression(
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -102,7 +104,7 @@ def test_classification(
):
(train, test, target) = classification_data
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,