<a target="_blank" href="https://colab.research.google.com/github/yandex-research/rtdl/blob/main/rtdl/revisiting_models/example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Notes

Hyperparameters are not tuned and may be suboptimal.

In [None]:
%pip install delu==0.0.22
%pip install rtdl

In [1]:
import math
import warnings
from typing import Dict, Literal

warnings.simplefilter("ignore")
import delu  # Deep Learning Utilities: https://github.com/Yura52/delu
import numpy as np
import scipy.special
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
import torch
import torch.nn.functional as F
import torch.optim
from torch import Tensor
from tqdm.std import tqdm

warnings.resetwarnings()

from rtdl.revisiting_models import MLP, ResNet, FTTransformer

In [2]:
# Use the GPU when PyTorch reports that CUDA is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Set random seeds in all libraries.
delu.random.seed(0)

0

## Dataset

In [3]:
# >>> Dataset.
TaskType = Literal['regression', 'binclass', 'multiclass']

# Default setup: a regression task on the California Housing dataset.
# `n_classes` stays None because it is only meaningful for classification.
dataset = sklearn.datasets.fetch_california_housing()
task_type: TaskType = 'regression'
n_classes = None
X_cont: np.ndarray = dataset.data
Y: np.ndarray = dataset.target

# NOTE: uncomment to solve a classification task.
# n_classes = 2
# assert n_classes >= 2
# task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

# >>> Continuous features.
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# >>> Categorical features.
# NOTE: the above datasets do not have categorical features, but,
# for the demonstration purposes, it is possible to generate them.
cat_cardinalities = [
    # NOTE: uncomment the two lines below to add two categorical features.
    # 4,  # Allowed values: [0, 1, 2, 3].
    # 7,  # Allowed values: [0, 1, 2, 3, 4, 5, 6].
]
# One random integer column per cardinality, or an empty (n_objects, 0) array
# when no categorical features are requested.
# The dtype is requested explicitly: the default integer dtype of
# `np.random.randint` is platform-dependent, while the empty branch below uses
# int64 and index-consuming PyTorch operations such as `F.one_hot` need int64.
X_cat = (
    np.column_stack(
        [
            np.random.randint(0, c, (len(X_cont),), dtype=np.int64)
            for c in cat_cardinalities
        ]
    )
    if cat_cardinalities
    else np.empty((len(X_cont), 0), dtype=np.int64)
)

# >>> Labels.
if task_type != 'regression':
    # Classification labels are stored as int64 class indices and must cover
    # every class exactly.
    assert n_classes is not None
    Y = Y.astype(np.int64)
    observed_classes = set(Y.tolist())
    assert observed_classes == set(
        range(n_classes)
    ), 'Classification labels must form the range [0, 1, ..., n_classes - 1]'
    del observed_classes
else:
    # Regression labels must be represented by float32.
    Y = Y.astype(np.float32)

# >>> Split the dataset.
# Two consecutive 80/20 splits: first (train + val) vs. test,
# then train vs. val within the first part.
all_idx = np.arange(len(Y))
trainval_idx, test_idx = sklearn.model_selection.train_test_split(
    all_idx, train_size=0.8
)
train_idx, val_idx = sklearn.model_selection.train_test_split(
    trainval_idx, train_size=0.8
)
# Every part holds the same three arrays, selected by that part's indices.
data_numpy = {
    part: {'x_cont': X_cont[idx], 'x_cat': X_cat[idx], 'y': Y[idx]}
    for part, idx in (('train', train_idx), ('val', val_idx), ('test', test_idx))
}

## Preprocessing

In [4]:
# >>> Feature preprocessing.
# NOTE
# The choice between preprocessing strategies depends on a task and a model.

# (A) Simple preprocessing strategy.
# preprocessing = sklearn.preprocessing.StandardScaler().fit(
#     data_numpy['train']['x_cont']
# )

# (B) Fancy preprocessing strategy.
# The transformer is fitted on the training part only.
# The noise is added to improve the output of QuantileTransformer in some cases.
X_cont_train_numpy = data_numpy['train']['x_cont']
noise = np.random.default_rng(0).normal(
    loc=0.0, scale=1e-5, size=X_cont_train_numpy.shape
)
noise = noise.astype(X_cont_train_numpy.dtype)
preprocessing = sklearn.preprocessing.QuantileTransformer(
    # One quantile per 30 training objects, clipped to the range [10, 1000].
    n_quantiles=max(10, min(1000, len(train_idx) // 30)),
    output_distribution='normal',
    subsample=10**9,
)
preprocessing.fit(X_cont_train_numpy + noise)
del X_cont_train_numpy

# Apply the fitted transformation to the continuous features of every part.
for part in data_numpy:
    data_numpy[part].update(
        x_cont=preprocessing.transform(data_numpy[part]['x_cont'])
    )

# >>> Label preprocessing.
# Regression labels are standardized with the statistics of the training part;
# `Y_mean` and `Y_std` are kept as plain Python floats.
if task_type == 'regression':
    Y_mean = float(np.mean(data_numpy['train']['y']))
    Y_std = float(np.std(data_numpy['train']['y']))
    for part in data_numpy:
        data_numpy[part].update(y=(data_numpy[part]['y'] - Y_mean) / Y_std)

# >>> Convert data to tensors.
# Mirror the structure of `data_numpy`, with every array placed on `device`.
data = {}
for part, arrays in data_numpy.items():
    data[part] = {
        name: torch.as_tensor(array, device=device) for name, array in arrays.items()
    }

# Required by F.binary_cross_entropy_with_logits
if task_type != 'multiclass':
    for tensors in data.values():
        tensors['y'] = tensors['y'].float()

## Model

In [5]:
# The output size.
# Only multiclass classification uses one output per class;
# the other task types use a single output.
d_out = 1 if task_type != 'multiclass' else n_classes

# # NOTE: uncomment to train MLP
# model = MLP(
#     d_in=n_cont_features + sum(cat_cardinalities),
#     d_out=d_out,
#     n_blocks=2,
#     d_block=384,
#     dropout=0.1,
# ).to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

# # NOTE: uncomment to train ResNet
# model = ResNet(
#     d_in=n_cont_features + sum(cat_cardinalities),
#     d_out=d_out,
#     n_blocks=2,
#     d_block=192,
#     d_hidden=None,
#     d_hidden_multiplier=2.0,
#     dropout1=0.3,
#     dropout2=0.0,
# ).to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

# FT-Transformer configured with the keyword arguments returned by
# `get_default_kwargs` and the optimizer returned by `make_default_optimizer`.
model = FTTransformer(
    **FTTransformer.get_default_kwargs(),
    n_cont_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,
    d_out=d_out,
)
model = model.to(device)
optimizer = model.make_default_optimizer()

## Training

In [6]:
def apply_model(batch: Dict[str, Tensor]) -> Tensor:
    """Apply the global `model` to one batch.

    Args:
        batch: a dictionary with the tensors 'x_cont' and 'x_cat'.

    Returns:
        The model output after `.squeeze(-1)`, i.e. without the last dimension
        when that dimension has size one.

    Raises:
        RuntimeError: if `model` is not an MLP, ResNet or FTTransformer.
    """
    if isinstance(model, (MLP, ResNet)):
        # These models take one flat input: the continuous features followed
        # by the one-hot encoding of every categorical feature.
        features = [batch['x_cont']]
        for column, cardinality in zip(batch['x_cat'].T, cat_cardinalities):
            features.append(F.one_hot(column, cardinality))
        return model(torch.column_stack(features)).squeeze(-1)

    if isinstance(model, FTTransformer):
        # Continuous and categorical features are passed separately;
        # None is passed when the categorical tensor has no elements.
        x_cat = batch['x_cat'] if batch['x_cat'].numel() else None
        return model(batch['x_cont'], x_cat).squeeze(-1)

    raise RuntimeError(f'Unknown model type: {type(model)}')


# Choose the loss by task type: binary cross-entropy on logits for binary
# classification, cross-entropy for multiclass classification,
# and mean squared error otherwise (regression).
if task_type == 'binclass':
    loss_fn = F.binary_cross_entropy_with_logits
elif task_type == 'multiclass':
    loss_fn = F.cross_entropy
else:
    loss_fn = F.mse_loss


@torch.no_grad()
def evaluate(part: str) -> float:
    """Compute the score of the global `model` on one part of `data`.

    The model is switched to evaluation mode and is left in that mode, so the
    caller must call `model.train()` before continuing training.

    Args:
        part: a key of `data`, for example 'val' or 'test'.

    Returns:
        The accuracy for classification tasks. For regression, the negated RMSE
        multiplied by `Y_std`, which undoes the label standardization.
        In all cases, the higher the score, the better. The score is always
        returned as a built-in `float`.
    """
    model.eval()

    # Predict in large batches and gather the predictions on the CPU.
    eval_batch_size = 8096
    y_pred = (
        torch.cat(
            [
                apply_model(batch)
                for batch in delu.iter_batches(data[part], eval_batch_size)
            ]
        )
        .cpu()
        .numpy()
    )
    y_true = data[part]['y'].cpu().numpy()

    if task_type == 'binclass':
        # Logits -> probabilities -> hard 0/1 predictions.
        y_pred = np.round(scipy.special.expit(y_pred))
        score = sklearn.metrics.accuracy_score(y_true, y_pred)
    elif task_type == 'multiclass':
        y_pred = y_pred.argmax(1)
        score = sklearn.metrics.accuracy_score(y_true, y_pred)
    else:
        assert task_type == 'regression'
        score = -(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5 * Y_std)
    # The metric may be a NumPy scalar. Converting it to a built-in float honors
    # the return annotation and keeps printed results (e.g. a dictionary of
    # scores) free of NumPy scalar wrappers.
    return float(score)  # The higher -- the better.


# Report the score of the untrained model as a reference point.
print('Test score before training: {:.4f}'.format(evaluate('test')))

Test score before training: -1.1442


In [7]:
# (A) Fast training & bad performance for demonstration purposes.
n_epochs = 20
patience = 2
# The message has no placeholders, so plain string literals are used.
print(
    'WARNING: the number of epochs and patience are configured'
    ' for a fast demonstration, but for bad performance.'
)

# (B) Longer training & better task performance.
# n_epochs = 1_000_000_000
# patience = 17

batch_size = 256
# The number of batches per epoch, rounded up.
epoch_size = math.ceil(len(train_idx) / batch_size)
# The best validation score so far, together with the test score and the
# epoch at which it was observed.
best = dict(val=-math.inf, test=-math.inf, epoch=-1)
# Scores are "the higher -- the better", hence mode='max'.
early_stopping = delu.tools.EarlyStopping(patience, mode='max')
timer = delu.tools.Timer()

print(f'Device: {device.type.upper()}')
print('-' * 88 + '\n')
timer.run()
for epoch in range(n_epochs):
    # `evaluate` switches the model to evaluation mode at the end of every
    # epoch, so training mode is restored once per epoch (not once per batch).
    model.train()
    for batch in tqdm(
        delu.iter_batches(data['train'], batch_size, shuffle=True),
        desc=f'Epoch {epoch}',
        total=epoch_size,
    ):
        optimizer.zero_grad()
        loss = loss_fn(apply_model(batch), batch['y'])
        loss.backward()
        optimizer.step()

    val_score = evaluate('val')
    test_score = evaluate('test')
    print(f'(val) {val_score:.4f} (test) {test_score:.4f} [time] {timer}')

    # Stop when the early stopping helper, updated with the validation score,
    # reports that training should stop.
    early_stopping.update(val_score)
    if early_stopping.should_stop():
        break

    # Model selection uses the validation score only; the test score is
    # recorded for the final report.
    if val_score > best['val']:
        print('🌸 New best epoch! 🌸')
        best = {'val': val_score, 'test': test_score, 'epoch': epoch}
    print()

print('\n\nResult:')
print(best)

Device: CPU
----------------------------------------------------------------------------------------



Epoch 0: 100%|██████████| 52/52 [00:10<00:00,  4.94it/s]


(val) -0.6441 (test) -0.6452 [time] 0:00:11.644658
🌸 New best epoch! 🌸



Epoch 1: 100%|██████████| 52/52 [00:10<00:00,  5.17it/s]


(val) -0.6141 (test) -0.6151 [time] 0:00:22.773217
🌸 New best epoch! 🌸



Epoch 2: 100%|██████████| 52/52 [00:10<00:00,  5.09it/s]


(val) -0.5833 (test) -0.5832 [time] 0:00:34.048887
🌸 New best epoch! 🌸



Epoch 3: 100%|██████████| 52/52 [00:10<00:00,  5.12it/s]


(val) -0.5726 (test) -0.5730 [time] 0:00:45.262249
🌸 New best epoch! 🌸



Epoch 4: 100%|██████████| 52/52 [00:10<00:00,  5.00it/s]


(val) -0.5767 (test) -0.5709 [time] 0:00:56.717894



Epoch 5: 100%|██████████| 52/52 [00:09<00:00,  5.25it/s]


(val) -0.5605 (test) -0.5572 [time] 0:01:07.669551
🌸 New best epoch! 🌸



Epoch 6: 100%|██████████| 52/52 [00:10<00:00,  5.15it/s]


(val) -0.5537 (test) -0.5539 [time] 0:01:18.819721
🌸 New best epoch! 🌸



Epoch 7: 100%|██████████| 52/52 [00:10<00:00,  5.09it/s]


(val) -0.5475 (test) -0.5437 [time] 0:01:30.083844
🌸 New best epoch! 🌸



Epoch 8: 100%|██████████| 52/52 [00:13<00:00,  3.96it/s]


(val) -0.5590 (test) -0.5554 [time] 0:01:44.311932



Epoch 9: 100%|██████████| 52/52 [00:10<00:00,  5.07it/s]


(val) -0.5495 (test) -0.5501 [time] 0:01:55.693106


Result:
{'val': -0.5475281168196952, 'test': -0.5437376174999943, 'epoch': 7}
