<a target="_blank" href="https://colab.research.google.com/github/yandex-research/rtdl-num-embeddings/blob/main/package/example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---

**See also** [RTDL](https://github.com/yandex-research/rtdl)
-- **other projects on tabular deep learning**.

---

- This notebook provides a usage example of the
  [rtdl_num_embeddings](https://github.com/yandex-research/rtdl-num-embeddings)
  package.
- Hyperparameters are not tuned and may be suboptimal.

In [1]:
# Install the extra packages used by this notebook
# (`delu` for training utilities, `rtdl_revisiting_models` for the MLP backbone).
# If the notebook fails because of these additional packages,
# try installing specific versions:
# %pip install delu==0.0.25
# %pip install rtdl_revisiting_models==0.0.2
# NOTE(review): `rtdl_num_embeddings` is imported below but not installed here;
# presumably it is installed beforehand — confirm for fresh environments (e.g. Colab).

%pip install delu
%pip install rtdl_revisiting_models

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# ruff: noqa: E402
import math
import warnings
from typing import Literal, Optional

# Silence warnings emitted while importing the third-party packages below.
# `catch_warnings()` restores the previous warning filters on exit, whereas the
# `simplefilter('ignore')` + `resetwarnings()` pair wipes *all* filters
# (including Python's and IPython's defaults) for the rest of the session.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import delu  # Deep Learning Utilities: https://github.com/Yura52/delu
    import numpy as np
    import scipy.special
    import sklearn.datasets
    import sklearn.metrics
    import sklearn.model_selection
    import sklearn.preprocessing
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim
    from torch import Tensor
    from tqdm.std import tqdm

import rtdl_revisiting_models
import rtdl_num_embeddings

In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Set random seeds in all libraries.
# The trailing semicolon suppresses the cell output: otherwise the return value
# of this call is displayed as a stray `0` under the cell.
delu.random.seed(0);

0

## Dataset

In [4]:
# >>> Dataset.
# Load the raw data, build (optional) synthetic categorical features, fix label
# dtypes, and split everything into train/val/test parts stored in `data_numpy`.
TaskType = Literal['regression', 'binclass', 'multiclass']

task_type: TaskType = 'regression'
# Only used for classification tasks; stays None for regression.
n_classes = None
dataset = sklearn.datasets.fetch_california_housing()
X_cont: np.ndarray = dataset['data']
Y: np.ndarray = dataset['target']

# NOTE: uncomment to solve a classification task.
# n_classes = 2
# assert n_classes >= 2
# task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

# >>> Continuous features.
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# >>> Categorical features.
# NOTE: the above datasets do not have categorical features, but,
# for demonstration purposes, it is possible to generate them.
cat_cardinalities = [
    # NOTE: uncomment the two lines below to add two categorical features.
    # 4,  # Allowed values: [0, 1, 2, 3].
    # 7,  # Allowed values: [0, 1, 2, 3, 4, 5, 6].
]
# One random integer column per cardinality, drawn from NumPy's global RNG
# (seeded in the setup cell); None when there are no categorical features.
X_cat = (
    np.column_stack(
        [np.random.randint(0, c, (len(X_cont),)) for c in cat_cardinalities]
    )
    if cat_cardinalities
    else None
)

# >>> Labels.
# Regression labels must be represented by float32.
# Classification labels are cast to int64 and must be exactly 0..n_classes-1.
if task_type == 'regression':
    Y = Y.astype(np.float32)
else:
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(Y.tolist()) == set(
        range(n_classes)
    ), 'Classification labels must form the range [0, 1, ..., n_classes - 1]'

# >>> Split the dataset.
# 80% train+val / 20% test, then 80% / 20% of train+val, i.e. 64% / 16% / 20%.
# `train_test_split` is called without `random_state`, so the split depends on
# NumPy's global RNG state: reordering the random calls above changes the split.
all_idx = np.arange(len(Y))
trainval_idx, test_idx = sklearn.model_selection.train_test_split(
    all_idx, train_size=0.8
)
train_idx, val_idx = sklearn.model_selection.train_test_split(
    trainval_idx, train_size=0.8
)
data_numpy = {
    'train': {'x_cont': X_cont[train_idx], 'y': Y[train_idx]},
    'val': {'x_cont': X_cont[val_idx], 'y': Y[val_idx]},
    'test': {'x_cont': X_cont[test_idx], 'y': Y[test_idx]},
}
# The 'x_cat' key is present only when categorical features were generated.
if X_cat is not None:
    data_numpy['train']['x_cat'] = X_cat[train_idx]
    data_numpy['val']['x_cat'] = X_cat[val_idx]
    data_numpy['test']['x_cat'] = X_cat[test_idx]

## Preprocessing

In [5]:
# >>> Feature preprocessing.
# NOTE
# The choice between preprocessing strategies depends on a task and a model.
#
# Everything below is recomputed from the raw `X_cont` / `Y` arrays through the
# split indices instead of transforming `data_numpy` in place. Re-running this
# cell therefore does not transform the features (or standardize the labels) twice.
part_idx = {'train': train_idx, 'val': val_idx, 'test': test_idx}

# (A) Simple preprocessing strategy.
# preprocessing = sklearn.preprocessing.StandardScaler().fit(X_cont[train_idx])

# (B) Fancy preprocessing strategy.
# The noise is added to improve the output of QuantileTransformer in some cases.
X_cont_train_numpy = X_cont[train_idx]
noise = (
    np.random.default_rng(0)
    .normal(0.0, 1e-5, X_cont_train_numpy.shape)
    .astype(X_cont_train_numpy.dtype)
)
preprocessing = sklearn.preprocessing.QuantileTransformer(
    n_quantiles=max(min(len(train_idx) // 30, 1000), 10),
    output_distribution='normal',
    subsample=10**9,
).fit(X_cont_train_numpy + noise)
del X_cont_train_numpy

# The noise is only used for fitting; all parts are transformed noise-free.
for part, idx in part_idx.items():
    data_numpy[part]['x_cont'] = preprocessing.transform(X_cont[idx])

# >>> Label preprocessing.
if task_type == 'regression':
    # Standardize with training-split statistics only. `Y_std` is also used by
    # `evaluate` to report the RMSE in the original label units.
    Y_mean = Y[train_idx].mean().item()
    Y_std = Y[train_idx].std().item()
    for part, idx in part_idx.items():
        data_numpy[part]['y'] = (Y[idx] - Y_mean) / Y_std

# >>> Convert data to tensors.
data = {
    part: {k: torch.as_tensor(v, device=device) for k, v in data_numpy[part].items()}
    for part in data_numpy
}

if task_type != 'multiclass':
    # Binary-classification targets must be float for
    # F.binary_cross_entropy_with_logits, and regression targets float for
    # F.mse_loss. Multiclass targets keep their int64 class indices.
    for part in data:
        data[part]['y'] = data[part]['y'].float()

## Model

In [12]:
class Model(nn.Module):
    """An MLP backbone on top of embedded continuous and one-hot categorical features.

    Args:
        n_cont_features: the number of continuous features.
        cat_cardinalities: the number of distinct values of every categorical
            feature (an empty list means there are no categorical features).
        bins: per-feature bin boundaries, e.g. from
            `rtdl_num_embeddings.compute_bins`. Required by the piecewise-linear
            variants below; may be None for the other variants.
        mlp_kwargs: keyword arguments for `rtdl_revisiting_models.MLP`,
            except `d_in`, which is computed here.
    """

    def __init__(
        self,
        n_cont_features: int,
        cat_cardinalities: list[int],
        bins: Optional[list[Tensor]],
        mlp_kwargs: dict,
    ) -> None:
        super().__init__()
        self.cat_cardinalities = cat_cardinalities
        # One-hot encoding gives every categorical feature as many columns as it
        # has distinct values, so the categorical part is sum(cardinalities) wide.
        d_cat = sum(cat_cardinalities)

        # Keep exactly one of the continuous-feature embeddings below active;
        # `d_num` must be the width of its flattened output.

        # d_embedding = 24
        # self.cont_embeddings = rtdl_num_embeddings.PeriodicEmbeddings(
        #     n_cont_features, d_embedding, lite=False
        # )
        # d_num = n_cont_features * d_embedding

        # assert bins is not None
        # self.cont_embeddings = rtdl_num_embeddings.PiecewiseLinearEncoding(bins)
        # d_num = sum(len(b) - 1 for b in bins)

        assert bins is not None
        d_embedding = 8
        self.cont_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
            bins, d_embedding, activation=False, version='B'
        )
        d_num = n_cont_features * d_embedding

        # d_embedding = 32
        # self.cont_embeddings = rtdl_num_embeddings.LinearReLUEmbeddings(
        #     n_cont_features, d_embedding
        # )
        # d_num = n_cont_features * d_embedding

        self.backbone = rtdl_revisiting_models.MLP(d_in=d_num + d_cat, **mlp_kwargs)

    def forward(self, x_cont: Tensor, x_cat: Optional[Tensor]) -> Tensor:
        """Embed and concatenate all features, then apply the backbone."""
        # MLP-like backbones need flat inputs: collapse every dimension after the first.
        columns = [self.cont_embeddings(x_cont).flatten(1)]
        if x_cat is not None:
            # `x_cat.T` yields one categorical feature (column) at a time.
            for values, n_values in zip(x_cat.T, self.cat_cardinalities):
                columns.append(F.one_hot(values, n_values))
        return self.backbone(torch.column_stack(columns))


# Bin boundaries are needed only by PiecewiseLinearEncoding and
# PiecewiseLinearEmbeddings; they are computed from the training part only.
bins = rtdl_num_embeddings.compute_bins(data['train']['x_cont'])
model = Model(
    n_cont_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,
    bins=bins,
    mlp_kwargs=dict(
        n_blocks=2,
        d_block=384,
        dropout=0.4,
        # One logit per class for multiclass, a single output otherwise.
        d_out=n_classes if task_type == 'multiclass' else 1,
    ),
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

## Training

In [13]:
def apply_model(batch: dict[str, Tensor]) -> Tensor:
    return model(batch['x_cont'], batch.get('x_cat')).squeeze(-1)


# Training loss per task type; any other task type (i.e. regression) uses MSE.
loss_fn = {
    'binclass': F.binary_cross_entropy_with_logits,
    'multiclass': F.cross_entropy,
}.get(task_type, F.mse_loss)


@torch.no_grad()
def evaluate(part: str) -> float:
    """Score the global `model` on one part of `data`.

    Args:
        part: the name of the data part, a key of `data` ('train', 'val', 'test').

    Returns:
        Accuracy for classification; for regression, minus the RMSE expressed in
        the original label units. Higher is always better.
    """
    model.eval()

    eval_batch_size = 8096
    predictions = [
        apply_model(batch) for batch in delu.iter_batches(data[part], eval_batch_size)
    ]
    y_pred = torch.cat(predictions).cpu().numpy()
    y_true = data[part]['y'].cpu().numpy()

    if task_type == 'binclass':
        # Logits -> probabilities -> hard 0/1 labels.
        score = sklearn.metrics.accuracy_score(
            y_true, np.round(scipy.special.expit(y_pred))
        )
    elif task_type == 'multiclass':
        score = sklearn.metrics.accuracy_score(y_true, y_pred.argmax(1))
    else:
        assert task_type == 'regression'
        # Labels were standardized, so multiplying by Y_std restores the original units.
        rmse = sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5 * Y_std
        score = -rmse
    return score


# Sanity check before training: the score of the freshly initialized model is a
# baseline for the per-epoch scores printed during training.
print(f'Test score before training: {evaluate("test"):.4f}')

Test score before training: -1.1612


In [14]:
# For demonstration purposes (fast training and bad performance),
# one can set smaller values:
# n_epochs = 20
# patience = 2
n_epochs = 1_000_000_000
patience = 16

batch_size = 256
# Batches per epoch (the last batch may be smaller); used only by the progress bar.
epoch_size = math.ceil(len(train_idx) / batch_size)
timer = delu.tools.Timer()
early_stopping = delu.tools.EarlyStopping(patience, mode='max')
best = dict(val=-math.inf, test=-math.inf, epoch=-1)

print(f'Device: {device.type.upper()}')
print('-' * 88 + '\n')
timer.run()
for epoch in range(n_epochs):
    # `evaluate` leaves the model in eval mode, so switch back at every epoch start.
    model.train()
    train_batches = delu.iter_batches(data['train'], batch_size, shuffle=True)
    for batch in tqdm(train_batches, desc=f'Epoch {epoch}', total=epoch_size):
        optimizer.zero_grad()
        loss = loss_fn(apply_model(batch), batch['y'])
        loss.backward()
        optimizer.step()

    val_score, test_score = evaluate('val'), evaluate('test')
    print(f'(val) {val_score:.4f} (test) {test_score:.4f} [time] {timer}')

    # Stopping is decided on the validation score only.
    early_stopping.update(val_score)
    if early_stopping.should_stop():
        break

    if val_score > best['val']:
        print('🌸 New best epoch! 🌸')
        best = dict(val=val_score, test=test_score, epoch=epoch)
    print()

print('\n\nResult:')
print(best)

Device: CPU
----------------------------------------------------------------------------------------



Epoch 0: 100%|██████████| 52/52 [00:00<00:00, 246.63it/s]


(val) -0.6263 (test) -0.6354 [time] 0:00:00.225773
🌸 New best epoch! 🌸



Epoch 1: 100%|██████████| 52/52 [00:00<00:00, 256.21it/s]


(val) -0.5990 (test) -0.6039 [time] 0:00:00.442324
🌸 New best epoch! 🌸



Epoch 2: 100%|██████████| 52/52 [00:00<00:00, 268.94it/s]


(val) -0.5846 (test) -0.5873 [time] 0:00:00.649662
🌸 New best epoch! 🌸



Epoch 3: 100%|██████████| 52/52 [00:00<00:00, 268.85it/s]


(val) -0.5727 (test) -0.5744 [time] 0:00:00.857772
🌸 New best epoch! 🌸



Epoch 4: 100%|██████████| 52/52 [00:00<00:00, 247.14it/s]


(val) -0.5599 (test) -0.5578 [time] 0:00:01.081754
🌸 New best epoch! 🌸



Epoch 5: 100%|██████████| 52/52 [00:00<00:00, 266.45it/s]


(val) -0.5517 (test) -0.5493 [time] 0:00:01.291182
🌸 New best epoch! 🌸



Epoch 6: 100%|██████████| 52/52 [00:00<00:00, 262.55it/s]


(val) -0.5448 (test) -0.5408 [time] 0:00:01.503628
🌸 New best epoch! 🌸



Epoch 7: 100%|██████████| 52/52 [00:00<00:00, 265.18it/s]


(val) -0.5377 (test) -0.5350 [time] 0:00:01.713532
🌸 New best epoch! 🌸



Epoch 8: 100%|██████████| 52/52 [00:00<00:00, 260.52it/s]


(val) -0.5294 (test) -0.5256 [time] 0:00:01.926896
🌸 New best epoch! 🌸



Epoch 9: 100%|██████████| 52/52 [00:00<00:00, 261.05it/s]


(val) -0.5247 (test) -0.5210 [time] 0:00:02.139919
🌸 New best epoch! 🌸



Epoch 10: 100%|██████████| 52/52 [00:00<00:00, 261.69it/s]


(val) -0.5224 (test) -0.5178 [time] 0:00:02.352378
🌸 New best epoch! 🌸



Epoch 11: 100%|██████████| 52/52 [00:00<00:00, 259.48it/s]


(val) -0.5219 (test) -0.5149 [time] 0:00:02.566670
🌸 New best epoch! 🌸



Epoch 12: 100%|██████████| 52/52 [00:00<00:00, 247.34it/s]


(val) -0.5173 (test) -0.5139 [time] 0:00:02.790892
🌸 New best epoch! 🌸



Epoch 13: 100%|██████████| 52/52 [00:00<00:00, 249.44it/s]


(val) -0.5125 (test) -0.5112 [time] 0:00:03.013256
🌸 New best epoch! 🌸



Epoch 14: 100%|██████████| 52/52 [00:00<00:00, 250.70it/s]


(val) -0.5128 (test) -0.5078 [time] 0:00:03.235203



Epoch 15: 100%|██████████| 52/52 [00:00<00:00, 256.02it/s]


(val) -0.5049 (test) -0.5040 [time] 0:00:03.452615
🌸 New best epoch! 🌸



Epoch 16: 100%|██████████| 52/52 [00:00<00:00, 251.34it/s]


(val) -0.5052 (test) -0.5039 [time] 0:00:03.672649



Epoch 17: 100%|██████████| 52/52 [00:00<00:00, 257.77it/s]


(val) -0.5001 (test) -0.4992 [time] 0:00:03.888797
🌸 New best epoch! 🌸



Epoch 18: 100%|██████████| 52/52 [00:00<00:00, 253.97it/s]


(val) -0.5027 (test) -0.5002 [time] 0:00:04.107126



Epoch 19: 100%|██████████| 52/52 [00:00<00:00, 249.83it/s]


(val) -0.5009 (test) -0.4977 [time] 0:00:04.330228



Epoch 20: 100%|██████████| 52/52 [00:00<00:00, 251.16it/s]


(val) -0.5004 (test) -0.4993 [time] 0:00:04.552041



Epoch 21: 100%|██████████| 52/52 [00:00<00:00, 247.56it/s]


(val) -0.4957 (test) -0.4973 [time] 0:00:04.776276
🌸 New best epoch! 🌸



Epoch 22: 100%|██████████| 52/52 [00:00<00:00, 234.48it/s]


(val) -0.5122 (test) -0.5117 [time] 0:00:05.012518



Epoch 23: 100%|██████████| 52/52 [00:00<00:00, 249.30it/s]


(val) -0.4940 (test) -0.4953 [time] 0:00:05.235928
🌸 New best epoch! 🌸



Epoch 24: 100%|██████████| 52/52 [00:00<00:00, 254.71it/s]


(val) -0.4939 (test) -0.4953 [time] 0:00:05.454505
🌸 New best epoch! 🌸



Epoch 25: 100%|██████████| 52/52 [00:00<00:00, 246.90it/s]


(val) -0.4930 (test) -0.4932 [time] 0:00:05.679297
🌸 New best epoch! 🌸



Epoch 26: 100%|██████████| 52/52 [00:00<00:00, 251.18it/s]


(val) -0.4925 (test) -0.4946 [time] 0:00:05.900805
🌸 New best epoch! 🌸



Epoch 27: 100%|██████████| 52/52 [00:00<00:00, 250.78it/s]


(val) -0.4909 (test) -0.4942 [time] 0:00:06.121866
🌸 New best epoch! 🌸



Epoch 28: 100%|██████████| 52/52 [00:00<00:00, 251.44it/s]


(val) -0.4900 (test) -0.4937 [time] 0:00:06.342433
🌸 New best epoch! 🌸



Epoch 29: 100%|██████████| 52/52 [00:00<00:00, 246.28it/s]


(val) -0.4896 (test) -0.4940 [time] 0:00:06.567599
🌸 New best epoch! 🌸



Epoch 30: 100%|██████████| 52/52 [00:00<00:00, 251.14it/s]


(val) -0.4857 (test) -0.4913 [time] 0:00:06.787959
🌸 New best epoch! 🌸



Epoch 31: 100%|██████████| 52/52 [00:00<00:00, 256.05it/s]


(val) -0.4845 (test) -0.4885 [time] 0:00:07.005285
🌸 New best epoch! 🌸



Epoch 32: 100%|██████████| 52/52 [00:00<00:00, 232.52it/s]


(val) -0.4917 (test) -0.4999 [time] 0:00:07.243684



Epoch 33: 100%|██████████| 52/52 [00:00<00:00, 254.40it/s]


(val) -0.4870 (test) -0.4911 [time] 0:00:07.461907



Epoch 34: 100%|██████████| 52/52 [00:00<00:00, 250.99it/s]


(val) -0.4864 (test) -0.4890 [time] 0:00:07.682530



Epoch 35: 100%|██████████| 52/52 [00:00<00:00, 250.68it/s]


(val) -0.4897 (test) -0.4964 [time] 0:00:07.904021



Epoch 36: 100%|██████████| 52/52 [00:00<00:00, 250.84it/s]


(val) -0.4887 (test) -0.4902 [time] 0:00:08.125986



Epoch 37: 100%|██████████| 52/52 [00:00<00:00, 246.42it/s]


(val) -0.4926 (test) -0.4956 [time] 0:00:08.350656



Epoch 38: 100%|██████████| 52/52 [00:00<00:00, 249.11it/s]


(val) -0.4865 (test) -0.4933 [time] 0:00:08.572795



Epoch 39: 100%|██████████| 52/52 [00:00<00:00, 253.03it/s]


(val) -0.4801 (test) -0.4847 [time] 0:00:08.792359
🌸 New best epoch! 🌸



Epoch 40: 100%|██████████| 52/52 [00:00<00:00, 245.83it/s]


(val) -0.4831 (test) -0.4874 [time] 0:00:09.018162



Epoch 41: 100%|██████████| 52/52 [00:00<00:00, 248.36it/s]


(val) -0.4827 (test) -0.4869 [time] 0:00:09.241389



Epoch 42: 100%|██████████| 52/52 [00:00<00:00, 263.28it/s]


(val) -0.4838 (test) -0.4868 [time] 0:00:09.453039



Epoch 43: 100%|██████████| 52/52 [00:00<00:00, 233.57it/s]


(val) -0.4989 (test) -0.5055 [time] 0:00:09.689764



Epoch 44: 100%|██████████| 52/52 [00:00<00:00, 254.18it/s]


(val) -0.4843 (test) -0.4854 [time] 0:00:09.908063



Epoch 45: 100%|██████████| 52/52 [00:00<00:00, 260.12it/s]


(val) -0.4790 (test) -0.4835 [time] 0:00:10.122170
🌸 New best epoch! 🌸



Epoch 46: 100%|██████████| 52/52 [00:00<00:00, 252.88it/s]


(val) -0.4857 (test) -0.4868 [time] 0:00:10.342053



Epoch 47: 100%|██████████| 52/52 [00:00<00:00, 250.65it/s]


(val) -0.4815 (test) -0.4866 [time] 0:00:10.564456



Epoch 48: 100%|██████████| 52/52 [00:00<00:00, 247.36it/s]


(val) -0.4867 (test) -0.4881 [time] 0:00:10.788933



Epoch 49: 100%|██████████| 52/52 [00:00<00:00, 249.47it/s]


(val) -0.4795 (test) -0.4859 [time] 0:00:11.010892



Epoch 50: 100%|██████████| 52/52 [00:00<00:00, 263.57it/s]


(val) -0.4823 (test) -0.4845 [time] 0:00:11.222769



Epoch 51: 100%|██████████| 52/52 [00:00<00:00, 248.24it/s]


(val) -0.4789 (test) -0.4854 [time] 0:00:11.446526
🌸 New best epoch! 🌸



Epoch 52: 100%|██████████| 52/52 [00:00<00:00, 248.31it/s]


(val) -0.4776 (test) -0.4861 [time] 0:00:11.669961
🌸 New best epoch! 🌸



Epoch 53: 100%|██████████| 52/52 [00:00<00:00, 238.19it/s]


(val) -0.4811 (test) -0.4885 [time] 0:00:11.902521



Epoch 54: 100%|██████████| 52/52 [00:00<00:00, 244.49it/s]


(val) -0.4791 (test) -0.4825 [time] 0:00:12.129379



Epoch 55: 100%|██████████| 52/52 [00:00<00:00, 248.88it/s]


(val) -0.4783 (test) -0.4843 [time] 0:00:12.353191



Epoch 56: 100%|██████████| 52/52 [00:00<00:00, 255.12it/s]


(val) -0.4720 (test) -0.4813 [time] 0:00:12.570314
🌸 New best epoch! 🌸



Epoch 57: 100%|██████████| 52/52 [00:00<00:00, 254.43it/s]


(val) -0.4810 (test) -0.4832 [time] 0:00:12.789299



Epoch 58: 100%|██████████| 52/52 [00:00<00:00, 256.92it/s]


(val) -0.4727 (test) -0.4815 [time] 0:00:13.005111



Epoch 59: 100%|██████████| 52/52 [00:00<00:00, 256.85it/s]


(val) -0.4739 (test) -0.4799 [time] 0:00:13.220974



Epoch 60: 100%|██████████| 52/52 [00:00<00:00, 256.60it/s]


(val) -0.4717 (test) -0.4808 [time] 0:00:13.437758
🌸 New best epoch! 🌸



Epoch 61: 100%|██████████| 52/52 [00:00<00:00, 246.64it/s]


(val) -0.4731 (test) -0.4786 [time] 0:00:13.662981



Epoch 62: 100%|██████████| 52/52 [00:00<00:00, 254.89it/s]


(val) -0.4734 (test) -0.4830 [time] 0:00:13.880748



Epoch 63: 100%|██████████| 52/52 [00:00<00:00, 232.55it/s]


(val) -0.4808 (test) -0.4867 [time] 0:00:14.118283



Epoch 64: 100%|██████████| 52/52 [00:00<00:00, 242.15it/s]


(val) -0.4777 (test) -0.4808 [time] 0:00:14.347416



Epoch 65: 100%|██████████| 52/52 [00:00<00:00, 246.65it/s]


(val) -0.4743 (test) -0.4810 [time] 0:00:14.571619



Epoch 66: 100%|██████████| 52/52 [00:00<00:00, 244.91it/s]


(val) -0.4714 (test) -0.4785 [time] 0:00:14.797858
🌸 New best epoch! 🌸



Epoch 67: 100%|██████████| 52/52 [00:00<00:00, 250.30it/s]


(val) -0.4711 (test) -0.4794 [time] 0:00:15.019236
🌸 New best epoch! 🌸



Epoch 68: 100%|██████████| 52/52 [00:00<00:00, 256.60it/s]


(val) -0.4746 (test) -0.4812 [time] 0:00:15.236164



Epoch 69: 100%|██████████| 52/52 [00:00<00:00, 247.10it/s]


(val) -0.4706 (test) -0.4811 [time] 0:00:15.460918
🌸 New best epoch! 🌸



Epoch 70: 100%|██████████| 52/52 [00:00<00:00, 249.32it/s]


(val) -0.4720 (test) -0.4779 [time] 0:00:15.683674



Epoch 71: 100%|██████████| 52/52 [00:00<00:00, 245.12it/s]


(val) -0.4754 (test) -0.4815 [time] 0:00:15.910964



Epoch 72: 100%|██████████| 52/52 [00:00<00:00, 256.16it/s]


(val) -0.4719 (test) -0.4808 [time] 0:00:16.128079



Epoch 73: 100%|██████████| 52/52 [00:00<00:00, 242.04it/s]


(val) -0.4708 (test) -0.4783 [time] 0:00:16.356467



Epoch 74: 100%|██████████| 52/52 [00:00<00:00, 258.67it/s]


(val) -0.4729 (test) -0.4762 [time] 0:00:16.571125



Epoch 75: 100%|██████████| 52/52 [00:00<00:00, 257.02it/s]


(val) -0.4717 (test) -0.4808 [time] 0:00:16.787014



Epoch 76: 100%|██████████| 52/52 [00:00<00:00, 248.82it/s]


(val) -0.4713 (test) -0.4802 [time] 0:00:17.010697



Epoch 77: 100%|██████████| 52/52 [00:00<00:00, 246.91it/s]


(val) -0.4684 (test) -0.4752 [time] 0:00:17.235772
🌸 New best epoch! 🌸



Epoch 78: 100%|██████████| 52/52 [00:00<00:00, 247.80it/s]


(val) -0.4698 (test) -0.4756 [time] 0:00:17.460502



Epoch 79: 100%|██████████| 52/52 [00:00<00:00, 254.02it/s]


(val) -0.4645 (test) -0.4734 [time] 0:00:17.678918
🌸 New best epoch! 🌸



Epoch 80: 100%|██████████| 52/52 [00:00<00:00, 252.59it/s]


(val) -0.4665 (test) -0.4740 [time] 0:00:17.898728



Epoch 81: 100%|██████████| 52/52 [00:00<00:00, 259.38it/s]


(val) -0.4696 (test) -0.4769 [time] 0:00:18.112861



Epoch 82: 100%|██████████| 52/52 [00:00<00:00, 257.77it/s]


(val) -0.4681 (test) -0.4747 [time] 0:00:18.328794



Epoch 83: 100%|██████████| 52/52 [00:00<00:00, 241.52it/s]


(val) -0.4704 (test) -0.4754 [time] 0:00:18.557887



Epoch 84: 100%|██████████| 52/52 [00:00<00:00, 265.57it/s]


(val) -0.4740 (test) -0.4769 [time] 0:00:18.767399



Epoch 85: 100%|██████████| 52/52 [00:00<00:00, 262.45it/s]


(val) -0.4715 (test) -0.4786 [time] 0:00:18.978528



Epoch 86: 100%|██████████| 52/52 [00:00<00:00, 251.20it/s]


(val) -0.4706 (test) -0.4754 [time] 0:00:19.199502



Epoch 87: 100%|██████████| 52/52 [00:00<00:00, 263.31it/s]


(val) -0.4707 (test) -0.4769 [time] 0:00:19.410841



Epoch 88: 100%|██████████| 52/52 [00:00<00:00, 254.89it/s]


(val) -0.4775 (test) -0.4874 [time] 0:00:19.628789



Epoch 89: 100%|██████████| 52/52 [00:00<00:00, 265.51it/s]


(val) -0.4757 (test) -0.4823 [time] 0:00:19.838074



Epoch 90: 100%|██████████| 52/52 [00:00<00:00, 251.95it/s]


(val) -0.4709 (test) -0.4762 [time] 0:00:20.058314



Epoch 91: 100%|██████████| 52/52 [00:00<00:00, 243.46it/s]


(val) -0.4685 (test) -0.4757 [time] 0:00:20.286802



Epoch 92: 100%|██████████| 52/52 [00:00<00:00, 231.67it/s]


(val) -0.4654 (test) -0.4739 [time] 0:00:20.526148



Epoch 93: 100%|██████████| 52/52 [00:00<00:00, 249.48it/s]


(val) -0.4630 (test) -0.4725 [time] 0:00:20.749025
🌸 New best epoch! 🌸



Epoch 94: 100%|██████████| 52/52 [00:00<00:00, 248.24it/s]


(val) -0.4651 (test) -0.4731 [time] 0:00:20.971487



Epoch 95: 100%|██████████| 52/52 [00:00<00:00, 247.37it/s]


(val) -0.4647 (test) -0.4722 [time] 0:00:21.196638



Epoch 96: 100%|██████████| 52/52 [00:00<00:00, 245.91it/s]


(val) -0.4803 (test) -0.4880 [time] 0:00:21.421940



Epoch 97: 100%|██████████| 52/52 [00:00<00:00, 257.05it/s]


(val) -0.4700 (test) -0.4783 [time] 0:00:21.638249



Epoch 98: 100%|██████████| 52/52 [00:00<00:00, 253.47it/s]


(val) -0.4653 (test) -0.4733 [time] 0:00:21.856979



Epoch 99: 100%|██████████| 52/52 [00:00<00:00, 261.04it/s]


(val) -0.4666 (test) -0.4724 [time] 0:00:22.070344



Epoch 100: 100%|██████████| 52/52 [00:00<00:00, 252.54it/s]


(val) -0.4688 (test) -0.4761 [time] 0:00:22.289933



Epoch 101: 100%|██████████| 52/52 [00:00<00:00, 245.99it/s]


(val) -0.4664 (test) -0.4719 [time] 0:00:22.515645



Epoch 102: 100%|██████████| 52/52 [00:00<00:00, 230.94it/s]


(val) -0.4651 (test) -0.4733 [time] 0:00:22.754611



Epoch 103: 100%|██████████| 52/52 [00:00<00:00, 249.65it/s]


(val) -0.4682 (test) -0.4777 [time] 0:00:22.976523



Epoch 104: 100%|██████████| 52/52 [00:00<00:00, 235.49it/s]


(val) -0.4744 (test) -0.4761 [time] 0:00:23.211378



Epoch 105: 100%|██████████| 52/52 [00:00<00:00, 246.09it/s]


(val) -0.4648 (test) -0.4720 [time] 0:00:23.436431



Epoch 106: 100%|██████████| 52/52 [00:00<00:00, 247.07it/s]


(val) -0.4666 (test) -0.4741 [time] 0:00:23.661494



Epoch 107: 100%|██████████| 52/52 [00:00<00:00, 247.36it/s]


(val) -0.4618 (test) -0.4714 [time] 0:00:23.886248
🌸 New best epoch! 🌸



Epoch 108: 100%|██████████| 52/52 [00:00<00:00, 248.42it/s]


(val) -0.4740 (test) -0.4796 [time] 0:00:24.109617



Epoch 109: 100%|██████████| 52/52 [00:00<00:00, 243.78it/s]


(val) -0.4650 (test) -0.4697 [time] 0:00:24.337298



Epoch 110: 100%|██████████| 52/52 [00:00<00:00, 242.56it/s]


(val) -0.4623 (test) -0.4687 [time] 0:00:24.565750



Epoch 111: 100%|██████████| 52/52 [00:00<00:00, 227.37it/s]


(val) -0.4664 (test) -0.4705 [time] 0:00:24.809605



Epoch 112: 100%|██████████| 52/52 [00:00<00:00, 245.89it/s]


(val) -0.4616 (test) -0.4702 [time] 0:00:25.035651
🌸 New best epoch! 🌸



Epoch 113: 100%|██████████| 52/52 [00:00<00:00, 242.56it/s]


(val) -0.4698 (test) -0.4733 [time] 0:00:25.264318



Epoch 114: 100%|██████████| 52/52 [00:00<00:00, 255.63it/s]


(val) -0.4698 (test) -0.4766 [time] 0:00:25.482060



Epoch 115: 100%|██████████| 52/52 [00:00<00:00, 250.80it/s]


(val) -0.4626 (test) -0.4705 [time] 0:00:25.703718



Epoch 116: 100%|██████████| 52/52 [00:00<00:00, 242.71it/s]


(val) -0.4715 (test) -0.4739 [time] 0:00:25.932528



Epoch 117: 100%|██████████| 52/52 [00:00<00:00, 248.54it/s]


(val) -0.4647 (test) -0.4729 [time] 0:00:26.156345



Epoch 118: 100%|██████████| 52/52 [00:00<00:00, 241.60it/s]


(val) -0.4639 (test) -0.4676 [time] 0:00:26.385842



Epoch 119: 100%|██████████| 52/52 [00:00<00:00, 241.23it/s]


(val) -0.4634 (test) -0.4700 [time] 0:00:26.615703



Epoch 120: 100%|██████████| 52/52 [00:00<00:00, 236.84it/s]


(val) -0.4645 (test) -0.4687 [time] 0:00:26.849054



Epoch 121: 100%|██████████| 52/52 [00:00<00:00, 260.16it/s]


(val) -0.4658 (test) -0.4723 [time] 0:00:27.062289



Epoch 122: 100%|██████████| 52/52 [00:00<00:00, 260.83it/s]


(val) -0.4654 (test) -0.4702 [time] 0:00:27.275472



Epoch 123: 100%|██████████| 52/52 [00:00<00:00, 247.96it/s]


(val) -0.4780 (test) -0.4876 [time] 0:00:27.499093



Epoch 124: 100%|██████████| 52/52 [00:00<00:00, 248.23it/s]


(val) -0.4643 (test) -0.4730 [time] 0:00:27.722564



Epoch 125: 100%|██████████| 52/52 [00:00<00:00, 249.86it/s]


(val) -0.4642 (test) -0.4705 [time] 0:00:27.944486



Epoch 126: 100%|██████████| 52/52 [00:00<00:00, 245.48it/s]


(val) -0.4687 (test) -0.4748 [time] 0:00:28.170121



Epoch 127: 100%|██████████| 52/52 [00:00<00:00, 261.16it/s]


(val) -0.4622 (test) -0.4701 [time] 0:00:28.382843



Epoch 128: 100%|██████████| 52/52 [00:00<00:00, 254.43it/s]


(val) -0.4611 (test) -0.4666 [time] 0:00:28.601271
🌸 New best epoch! 🌸



Epoch 129: 100%|██████████| 52/52 [00:00<00:00, 244.72it/s]


(val) -0.4619 (test) -0.4709 [time] 0:00:28.827736



Epoch 130: 100%|██████████| 52/52 [00:00<00:00, 256.93it/s]


(val) -0.4629 (test) -0.4743 [time] 0:00:29.043704



Epoch 131: 100%|██████████| 52/52 [00:00<00:00, 261.67it/s]


(val) -0.4664 (test) -0.4706 [time] 0:00:29.256837



Epoch 132: 100%|██████████| 52/52 [00:00<00:00, 264.75it/s]


(val) -0.4698 (test) -0.4754 [time] 0:00:29.466750



Epoch 133: 100%|██████████| 52/52 [00:00<00:00, 254.25it/s]


(val) -0.4613 (test) -0.4685 [time] 0:00:29.684580



Epoch 134: 100%|██████████| 52/52 [00:00<00:00, 256.48it/s]


(val) -0.4679 (test) -0.4706 [time] 0:00:29.899860



Epoch 135: 100%|██████████| 52/52 [00:00<00:00, 246.48it/s]


(val) -0.4627 (test) -0.4698 [time] 0:00:30.124583



Epoch 136: 100%|██████████| 52/52 [00:00<00:00, 247.34it/s]


(val) -0.4702 (test) -0.4739 [time] 0:00:30.349664



Epoch 137: 100%|██████████| 52/52 [00:00<00:00, 247.82it/s]


(val) -0.4675 (test) -0.4693 [time] 0:00:30.572873



Epoch 138: 100%|██████████| 52/52 [00:00<00:00, 230.78it/s]


(val) -0.4664 (test) -0.4716 [time] 0:00:30.811788



Epoch 139: 100%|██████████| 52/52 [00:00<00:00, 259.66it/s]


(val) -0.4650 (test) -0.4704 [time] 0:00:31.025837



Epoch 140: 100%|██████████| 52/52 [00:00<00:00, 261.77it/s]


(val) -0.4683 (test) -0.4730 [time] 0:00:31.238291



Epoch 141: 100%|██████████| 52/52 [00:00<00:00, 255.34it/s]


(val) -0.4718 (test) -0.4752 [time] 0:00:31.455647



Epoch 142: 100%|██████████| 52/52 [00:00<00:00, 254.83it/s]


(val) -0.4735 (test) -0.4786 [time] 0:00:31.673475



Epoch 143: 100%|██████████| 52/52 [00:00<00:00, 247.36it/s]


(val) -0.4653 (test) -0.4702 [time] 0:00:31.897596



Epoch 144: 100%|██████████| 52/52 [00:00<00:00, 243.74it/s]

(val) -0.4689 (test) -0.4718 [time] 0:00:32.124995


Result:
{'val': -0.4610990747680548, 'test': -0.4665925272393832, 'epoch': 128}



