# Usage example: Custom dataset (iris)

## Step 0. Prepare datasets and create DataFrame

In [1]:
import pandas as pd

iris = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv')

## Step 1. Preprocessing pd.DataFrame

In [2]:
import os,sys; sys.path.append(os.path.abspath(".."))
from deep_table.data.data_module import TabularDatamodule
from deep_table.preprocess import CategoryPreprocessor


# Convert target labels (str) to int
category_preprocesser = CategoryPreprocessor(categorical_columns=["species"], use_unk=False)
iris = category_preprocesser.fit_transform(iris)

## Step 2. Make TabularDatamodule instance

In [3]:
datamodule = TabularDatamodule(
    train=iris.iloc[:20],
    val=iris.iloc[20:40],
    test=iris.iloc[40:],
    task="multiclass",
    dim_out=3,
    categorical_columns=[],
    continuous_columns=[
        "sepal_length",
        "sepal_width",
        "petal_length",
        "petal_width"
    ],
    target=["species"],
    num_categories=0,
)

## Step 3. Training

In [4]:
from omegaconf import OmegaConf


# Encoder settings
encoder_config = OmegaConf.create({
    "embedding": {
        "name": "FeatureEmbedding",
        "args": {
            "dim_embed": 16
        }
    },
    "backbone": {
        "name": "MLPBackbone",
        "args": {}
    }
})

# model settings (learning rate, scheduler...)
model_config = OmegaConf.create({
    "name": "MLPHeadModel"
})

# training settings (epoch, gpu...)
trainer_config = OmegaConf.create({
    "gpus": 0,
    "max_epochs": 1,
    "seed": 42
})

In [5]:
from deep_table.estimators.base import Estimator
from deep_table.utils import get_scores


estimator = Estimator(encoder_config, model_config, trainer_config)
estimator.fit(datamodule)

predict = estimator.predict(datamodule.dataloader(split="test"))
get_scores(predict, target=datamodule.dataloader(split="test"), task="multiclass")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
`activation` is not set. `nn.Identity` would be used instead.
`activation` is not set. `nn.Identity` would be used instead.
Global seed set to 42

  | Name    | Type             | Params
---------------------------------------------
0 | encoder | Encoder          | 319 K 
1 | mlp     | Sequential       | 17.4 K
2 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
336 K     Trainable params
0         Non-trainable params
336 K     Total params
1.348     Total estimated model params size (MB)


                                                                                                                                                                                                                  

Global seed set to 42


Epoch 0:  50%|████████████████████████████████████████████████████████████████████▌                                                                    | 1/2 [00:08<00:04,  4.39s/it, loss=1.06, train_loss=1.060]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                                           | 0/1 [00:00<?, ?it/s][A
Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:19<00:00,  6.39s/it, loss=1.06, train_loss=1.060, val_loss=0.981][A
Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:19<00:00,  6.40s/it, loss=1.06, train_loss=1.060, val_loss=0.981][A
Predicting: 100%|████████████████████████████████████████████████████████████████████████████████████████████████

{'cross_entropy': 1.1514963963885783, 'accuracy': 0.09090909090909091}