<a target="_blank" href="https://colab.research.google.com/github/yandex-research/rtdl/blob/main/rtdl/revisiting_models/example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Notes

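Hyperparameters are not tuned and may be suboptimal. The default training budget (number of epochs and early-stopping patience) is also deliberately small for a fast demonstration, so the reported scores are not representative of the models' quality.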

In [None]:
# Install the third-party packages used below: `delu` (seeding, batching and
# early stopping) and `rtdl_revisiting_models` (the MLP, ResNet and
# FTTransformer models).
# NOTE(review): versions are not pinned; consider pinning them so that the
# recorded outputs stay reproducible.
%pip install delu
%pip install rtdl_revisiting_models

In [None]:
import math
from typing import Dict, Literal

import delu  # Deep Learning Utilities: https://yura52.github.io/delu
import numpy as np
import scipy.special
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
import torch
import torch.nn.functional as F
import torch.optim
from torch import Tensor
from tqdm.std import tqdm

from rtdl_revisiting_models import MLP, ResNet, FTTransformer

In [2]:
# Run on a GPU when one is available, otherwise on the CPU. All data tensors
# and the model are moved to `device`, and predictions are copied back with
# `.cpu()` for evaluation, so no other code depends on this choice.
# NOTE: the outputs recorded in this notebook were produced on the CPU;
# results on a GPU may differ slightly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Set random seeds in all libraries.
delu.random.seed(0)

0

## Dataset

In [3]:
# >>> Dataset.
# The task type selects the label dtype, the loss function and the metric.
TaskType = Literal['regression', 'binclass', 'multiclass']

# Default task: regression on the California Housing dataset
# (`n_classes` stays None because a regression task has no classes).
task_type: TaskType = 'regression'
n_classes = None
dataset = sklearn.datasets.fetch_california_housing()
X_cont: np.ndarray = dataset.data
Y: np.ndarray = dataset.target

# NOTE: uncomment to solve a classification task.
# n_classes = 2
# assert n_classes >= 2
# task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

# >>> Continuous features.
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# >>> Categorical features.
# NOTE: the above datasets do not have categorical features, but,
# for the demonstration purposes, it is possible to generate them.
cat_cardinalities = [
    # NOTE: uncomment the two lines below to add two categorical features.
    # 4,  # Allowed values: [0, 1, 2, 3].
    # 7,  # Allowed values: [0, 1, 2, 3, 4, 5, 6].
]
# Categorical features are stored as int64 codes in both branches.
# F.one_hot (used for MLP/ResNet) only accepts int64 indices. The default
# integer dtype of np.random.randint is platform-dependent (e.g. int32 on
# Windows with NumPy < 2), so the dtype is pinned explicitly.
X_cat = (
    np.column_stack(
        [
            np.random.randint(0, c, (len(X_cont),), dtype=np.int64)
            for c in cat_cardinalities
        ]
    )
    if cat_cardinalities
    else np.empty((len(X_cont), 0), dtype=np.int64)
)

# >>> Labels.
# Regression labels must be represented by float32; classification labels
# must be int64 class indices covering exactly [0, n_classes).
if task_type != 'regression':
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(range(n_classes)) == set(
        Y.tolist()
    ), 'Classification labels must form the range [0, 1, ..., n_classes - 1]'
else:
    Y = Y.astype(np.float32)

# >>> Split the dataset.
# First 80/20 into train+val vs. test, then 80/20 into train vs. val
# (i.e. 64% / 16% / 20% of all objects overall).
all_idx = np.arange(len(Y))
trainval_idx, test_idx = sklearn.model_selection.train_test_split(
    all_idx, train_size=0.8
)
train_idx, val_idx = sklearn.model_selection.train_test_split(
    trainval_idx, train_size=0.8
)
# part name -> {'x_cont', 'x_cat', 'y'} arrays restricted to that part's rows.
data_numpy = {
    part: {'x_cont': X_cont[idx], 'x_cat': X_cat[idx], 'y': Y[idx]}
    for part, idx in (('train', train_idx), ('val', val_idx), ('test', test_idx))
}

## Preprocessing

In [4]:
# >>> Feature preprocessing.
# NOTE
# The choice between preprocessing strategies depends on a task and a model.

# (A) Simple preprocessing strategy.
# preprocessing = sklearn.preprocessing.StandardScaler().fit(
#     data_numpy['train']['x_cont']
# )

# (B) Fancy preprocessing strategy.
# The QuantileTransformer is fitted on the training features perturbed by small
# Gaussian noise with per-feature std = min(1e-3, 1e-3 * feature std).
# Presumably this breaks ties between repeated values. The noise is used only
# for fitting: every part is transformed without it.
X_cont_std = data_numpy['train']['x_cont'].std(axis=0)
noise_hint = 1e-3
noise_std = np.minimum(noise_hint, noise_hint * X_cont_std)
noise = np.random.normal(0.0, noise_std, data_numpy['train']['x_cont'].shape)
preprocessing = sklearn.preprocessing.QuantileTransformer(
    # About one quantile per 30 training rows, capped at 1000. The lower bound
    # of 10 keeps the value valid for tiny training sets. Without it,
    # `len(train_idx) // 30` would be 0, which QuantileTransformer rejects.
    n_quantiles=max(min(len(train_idx) // 30, 1000), 10),
    output_distribution='normal',
    subsample=10**9,
).fit(data_numpy['train']['x_cont'] + noise)

# Transform every part with the transformer fitted on the training part.
for part in data_numpy:
    data_numpy[part]['x_cont'] = preprocessing.transform(data_numpy[part]['x_cont'])

# >>> Label preprocessing.
# Regression targets are standardized using training-set statistics only.
# Y_std is kept so that evaluate() can report the RMSE in the original units.
if task_type == 'regression':
    Y_mean, Y_std = (
        data_numpy['train']['y'].mean().item(),
        data_numpy['train']['y'].std().item(),
    )
    for part in data_numpy:
        data_numpy[part]['y'] = (data_numpy[part]['y'] - Y_mean) / Y_std

# >>> Convert data to tensors (placed on `device`).
data = {
    part: {
        key: torch.as_tensor(array, device=device)
        for key, array in arrays.items()
    }
    for part, arrays in data_numpy.items()
}

# F.binary_cross_entropy_with_logits requires float targets, so non-multiclass
# labels are cast to float. Binary labels are otherwise int64. Multiclass
# labels stay int64 class indices for F.cross_entropy.
if task_type != 'multiclass':
    for part in data:
        data[part]['y'] = data[part]['y'].float()

## Model

In [5]:
# The output size: one output per class for multiclass classification,
# otherwise a single output (a regression prediction or a binary logit).
if task_type == 'multiclass':
    d_out = n_classes
else:
    d_out = 1

# # NOTE: uncomment to train MLP
# model = MLP(
#     d_in=n_cont_features + sum(cat_cardinalities),
#     d_out=d_out,
#     n_blocks=2,
#     d_block=384,
#     dropout=0.1,
# ).to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

# # NOTE: uncomment to train ResNet
# model = ResNet(
#     d_in=n_cont_features + sum(cat_cardinalities),
#     d_out=d_out,
#     n_blocks=2,
#     d_block=192,
#     d_hidden=None,
#     d_hidden_multiplier=2.0,
#     dropout1=0.3,
#     dropout2=0.0,
# ).to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

# FT-Transformer (trained by default). It is built with the kwargs returned by
# FTTransformer.get_default_kwargs() and paired with the optimizer returned by
# model.make_default_optimizer().
model = FTTransformer(
    n_cont_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,
    d_out=d_out,
    **FTTransformer.get_default_kwargs(),
).to(device)
optimizer = model.make_default_optimizer()

## Training

In [6]:
def apply_model(batch: Dict[str, Tensor]) -> Tensor:
    """Run the global ``model`` on one batch and return its raw outputs.

    Args:
        batch: a dict holding the ``'x_cont'`` and ``'x_cat'`` tensors of one
            batch (other keys, such as ``'y'``, are ignored here).

    Returns:
        The model outputs with a trailing dimension of size 1 squeezed away.
        Assuming outputs of shape (batch, d_out), single-output models then
        yield one value per object.

    Raises:
        RuntimeError: if ``model`` is neither an MLP/ResNet nor an
            FTTransformer.
    """
    if isinstance(model, (MLP, ResNet)):
        # These models consume one dense matrix: one-hot encode every
        # categorical column and append it to the continuous features.
        one_hot_blocks = [
            F.one_hot(codes, n_values)
            for codes, n_values in zip(batch['x_cat'].T, cat_cardinalities)
        ]
        dense_input = torch.column_stack([batch['x_cont'], *one_hot_blocks])
        return model(dense_input).squeeze(-1)

    if isinstance(model, FTTransformer):
        # FTTransformer takes categorical features as a separate argument;
        # pass None when x_cat is empty (e.g. there are no categorical columns).
        x_cat = batch['x_cat'] if batch['x_cat'].numel() else None
        return model(batch['x_cont'], x_cat).squeeze(-1)

    raise RuntimeError(f'Unknown model type: {type(model)}')


# Loss per task type: BCE-with-logits for binary classification,
# cross-entropy for multiclass classification, and MSE for everything else
# (i.e. regression).
loss_fn = {
    'binclass': F.binary_cross_entropy_with_logits,
    'multiclass': F.cross_entropy,
}.get(task_type, F.mse_loss)


@torch.no_grad()
def evaluate(part: str) -> float:
    """Score the global ``model`` on ``data[part]``; higher is better.

    Metrics:
        * binclass / multiclass: accuracy;
        * regression: negated RMSE, converted back to the original target
          units with the training-set ``Y_std``.

    The model is left in eval mode, so callers must switch it back with
    ``model.train()`` before resuming training.
    """
    model.eval()

    eval_batch_size = 8096
    batch_predictions = [
        apply_model(batch)
        for batch in delu.iter_batches(data[part], eval_batch_size)
    ]
    y_pred = torch.cat(batch_predictions).cpu().numpy()
    y_true = data[part]['y'].cpu().numpy()

    if task_type == 'binclass':
        # Logits -> probabilities -> hard 0/1 labels.
        return sklearn.metrics.accuracy_score(
            y_true, np.round(scipy.special.expit(y_pred))
        )
    if task_type == 'multiclass':
        return sklearn.metrics.accuracy_score(y_true, y_pred.argmax(1))
    assert task_type == 'regression'
    # Targets and predictions are standardized; scaling the RMSE by Y_std
    # expresses it in the original units.
    return -(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5 * Y_std)


print(f'Test score before training: {evaluate("test"):.4f}')

Test score before training: -1.2005


In [7]:
batch_size = 256
epoch_size = math.ceil(len(train_idx) / batch_size)

# (A) Fast training & bad performance for demonstration purposes.
n_epochs = 20
patience = 2
print(
    'WARNING: the number of epochs and patience are configured'
    ' for a fast demonstration, but for bad performance.\n'
)

# (B) Longer training & better task performance.
# n_epochs = 1_000_000_000
# patience = 17

# Stop after `patience` consecutive epochs without a higher validation score.
early_stopping = delu.EarlyStopping(patience, mode='max')
# Validation/test scores of the epoch with the best validation score so far.
best = {
    'val': -math.inf,
    'test': -math.inf,
    'epoch': -1,
}

for epoch in range(n_epochs):
    # evaluate() leaves the model in eval mode, so training mode is restored
    # once at the start of every epoch. Nothing inside the batch loop changes
    # it, so per-batch calls are unnecessary.
    model.train()
    for batch in tqdm(
        delu.iter_batches(data['train'], batch_size, shuffle=True),
        desc=f'Epoch {epoch}',
        total=epoch_size,
    ):
        optimizer.zero_grad()
        loss = loss_fn(apply_model(batch), batch['y'])
        loss.backward()
        optimizer.step()

    val_score = evaluate('val')
    test_score = evaluate('test')
    print(f'(val) {val_score:.4f} (test) {test_score:.4f}')

    early_stopping.update(val_score)
    if early_stopping.should_stop():
        # The stopping epoch did not improve the validation score,
        # so `best` cannot change on it.
        break

    if val_score > best['val']:
        print('🌸 New best epoch! 🌸')
        best = {'val': val_score, 'test': test_score, 'epoch': epoch}
    print()

print('\n\nResult:')
print(best)

Epoch 0:   0%|          | 0/52 [00:00<?, ?it/s]

Epoch 0: 100%|██████████| 52/52 [00:10<00:00,  5.06it/s]


(val) -0.6088 (test) -0.6100
🌸 New best epoch! 🌸



Epoch 1: 100%|██████████| 52/52 [00:11<00:00,  4.49it/s]


(val) -0.5894 (test) -0.5924
🌸 New best epoch! 🌸



Epoch 2: 100%|██████████| 52/52 [00:12<00:00,  4.04it/s]


(val) -0.5625 (test) -0.5639
🌸 New best epoch! 🌸



Epoch 3: 100%|██████████| 52/52 [00:12<00:00,  4.00it/s]


(val) -0.5600 (test) -0.5591
🌸 New best epoch! 🌸



Epoch 4: 100%|██████████| 52/52 [00:11<00:00,  4.72it/s]


(val) -0.5460 (test) -0.5421
🌸 New best epoch! 🌸



Epoch 5: 100%|██████████| 52/52 [00:11<00:00,  4.48it/s]


(val) -0.5436 (test) -0.5421
🌸 New best epoch! 🌸



Epoch 6: 100%|██████████| 52/52 [00:11<00:00,  4.65it/s]


(val) -0.5370 (test) -0.5340
🌸 New best epoch! 🌸



Epoch 7: 100%|██████████| 52/52 [00:10<00:00,  4.78it/s]


(val) -0.5365 (test) -0.5334
🌸 New best epoch! 🌸



Epoch 8: 100%|██████████| 52/52 [00:10<00:00,  5.20it/s]


(val) -0.5364 (test) -0.5321
🌸 New best epoch! 🌸



Epoch 9: 100%|██████████| 52/52 [00:09<00:00,  5.30it/s]


(val) -0.5381 (test) -0.5313



Epoch 10: 100%|██████████| 52/52 [00:14<00:00,  3.58it/s]


(val) -0.5362 (test) -0.5325
🌸 New best epoch! 🌸



Epoch 11: 100%|██████████| 52/52 [00:13<00:00,  3.98it/s]


(val) -0.5233 (test) -0.5184
🌸 New best epoch! 🌸



Epoch 12: 100%|██████████| 52/52 [00:10<00:00,  5.09it/s]


(val) -0.5160 (test) -0.5131
🌸 New best epoch! 🌸



Epoch 13: 100%|██████████| 52/52 [00:10<00:00,  5.04it/s]


(val) -0.5316 (test) -0.5352



Epoch 14: 100%|██████████| 52/52 [00:10<00:00,  5.02it/s]


(val) -0.5143 (test) -0.5141
🌸 New best epoch! 🌸



Epoch 15: 100%|██████████| 52/52 [00:10<00:00,  5.05it/s]


(val) -0.5131 (test) -0.5088
🌸 New best epoch! 🌸



Epoch 16: 100%|██████████| 52/52 [00:09<00:00,  5.28it/s]


(val) -0.5091 (test) -0.5102
🌸 New best epoch! 🌸



Epoch 17: 100%|██████████| 52/52 [00:10<00:00,  5.12it/s]


(val) -0.5056 (test) -0.5036
🌸 New best epoch! 🌸



Epoch 18: 100%|██████████| 52/52 [00:09<00:00,  5.29it/s]


(val) -0.5104 (test) -0.5072



Epoch 19: 100%|██████████| 52/52 [00:09<00:00,  5.28it/s]


(val) -0.5175 (test) -0.5154


Result:
{'val': -0.505572326153713, 'test': -0.5035526553426177, 'epoch': 17}
