<a target="_blank" href="https://colab.research.google.com/github/yandex-research/tabular-dl-num-embeddings/blob/main/package/example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Notes

**Hyperparameters are not tuned and may be suboptimal.**

In [None]:
%pip install delu==0.0.23
%pip install rtdl

In [1]:
# ruff: noqa: E402
import math
import warnings
from typing import Dict, Literal, List, Optional

warnings.simplefilter("ignore")
import delu  # Deep Learning Utilities: https://github.com/Yura52/delu
import numpy as np
import scipy.special
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch import Tensor
from tqdm.std import tqdm

warnings.resetwarnings()

import rtdl_revisiting_models
import rtdl_num_embeddings

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Set random seeds in all libraries.
# This call stays the last expression of the cell so its value is displayed.
delu.random.seed(0)

0

## Dataset

In [3]:
# >>> Dataset.
# The three kinds of tasks this notebook knows how to handle.
TaskType = Literal["regression", "binclass", "multiclass"]

# Default setup: regression on the California housing dataset.
task_type: TaskType = "regression"
n_classes = None
dataset = sklearn.datasets.fetch_california_housing()
X_cont: np.ndarray
Y: np.ndarray
X_cont, Y = dataset["data"], dataset["target"]

# NOTE: uncomment to solve a classification task.
# n_classes = 2
# assert n_classes >= 2
# task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

# >>> Continuous features.
# Continuous features are stored as float32.
X_cont = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# >>> Categorical features.
# NOTE: the above datasets do not have categorical features, but,
# for the demonstration purposes, it is possible to generate them.
cat_cardinalities = [
    # NOTE: uncomment the two lines below to add two categorical features.
    # 4,  # Allowed values: [0, 1, 2, 3].
    # 7,  # Allowed values: [0, 1, 2, 3, 4, 5, 6].
]
if cat_cardinalities:
    # One random integer column per requested cardinality.
    X_cat = np.column_stack(
        [
            np.random.randint(0, cardinality, (len(X_cont),))
            for cardinality in cat_cardinalities
        ]
    )
else:
    X_cat = None

# >>> Labels.
if task_type != "regression":
    # Classification labels are int64 class indices covering every class.
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(Y.tolist()) == set(
        range(n_classes)
    ), "Classification labels must form the range [0, 1, ..., n_classes - 1]"
else:
    # Regression labels must be represented by float32.
    Y = Y.astype(np.float32)

# >>> Split the dataset.
# Two consecutive 80/20 splits: (train + val) vs. test, then train vs. val.
all_idx = np.arange(len(Y))
trainval_idx, test_idx = sklearn.model_selection.train_test_split(
    all_idx, train_size=0.8
)
train_idx, val_idx = sklearn.model_selection.train_test_split(
    trainval_idx, train_size=0.8
)

# Gather the arrays of every part under a common layout:
# data_numpy[part] == {"x_cont": ..., "y": ...[, "x_cat": ...]}
part_indices = {"train": train_idx, "val": val_idx, "test": test_idx}
data_numpy = {
    part: {"x_cont": X_cont[idx], "y": Y[idx]} for part, idx in part_indices.items()
}
if X_cat is not None:
    for part, idx in part_indices.items():
        data_numpy[part]["x_cat"] = X_cat[idx]

## Preprocessing

In [4]:
# >>> Feature preprocessing.
# NOTE
# The choice between preprocessing strategies depends on a task and a model.

# (A) Simple preprocessing strategy.
# preprocessing = sklearn.preprocessing.StandardScaler().fit(
#     data_numpy['train']['x_cont']
# )

# (B) Fancy preprocessing strategy.
# The noise is added to improve the output of QuantileTransformer in some cases.
X_cont_train_numpy = data_numpy["train"]["x_cont"]
noise = (
    np.random.default_rng(0)
    .normal(0.0, 1e-5, X_cont_train_numpy.shape)
    .astype(X_cont_train_numpy.dtype)
)
preprocessing = sklearn.preprocessing.QuantileTransformer(
    n_quantiles=max(min(len(train_idx) // 30, 1000), 10),
    output_distribution="normal",
    subsample=10**9,
).fit(X_cont_train_numpy + noise)
del X_cont_train_numpy

# The preprocessed arrays go into a NEW dict; `data_numpy` keeps the raw data.
# Overwriting `data_numpy` in place would make this cell non-idempotent:
# running it a second time would fit and apply the transformation to features
# that had already been transformed.
data_numpy_processed = {
    part: {**arrays, "x_cont": preprocessing.transform(arrays["x_cont"])}
    for part, arrays in data_numpy.items()
}

# >>> Label preprocessing.
if task_type == "regression":
    # Standardize regression labels with statistics of the (raw) training labels.
    # `Y_std` is used again in `evaluate` to rescale the reported error.
    Y_mean = data_numpy["train"]["y"].mean().item()
    Y_std = data_numpy["train"]["y"].std().item()
    for part in data_numpy_processed:
        data_numpy_processed[part]["y"] = (
            data_numpy_processed[part]["y"] - Y_mean
        ) / Y_std

# >>> Convert data to tensors.
data = {
    part: {k: torch.as_tensor(v, device=device) for k, v in arrays.items()}
    for part, arrays in data_numpy_processed.items()
}

if task_type != "multiclass":
    # Required by F.binary_cross_entropy_with_logits
    for part in data:
        data[part]["y"] = data[part]["y"].float()

## Model

In [5]:
class Model(nn.Module):
    """MLP backbone on top of embeddings for continuous features.

    Continuous features go through `self.cont_embeddings` and are flattened;
    categorical features (if any) are one-hot encoded; the concatenation of
    both is passed to `self.backbone`.

    Args:
        n_cont_features: the number of continuous features.
        cat_cardinalities: the number of distinct values of each categorical
            feature (an empty list if there are no categorical features).
        bins: the bins for the piecewise-linear embeddings; not used by the
            embeddings that are currently enabled.
        mlp_kwargs: keyword arguments forwarded to `rtdl_revisiting_models.MLP`
            (`d_in` is computed here and must not be included).
    """

    def __init__(
        self,
        n_cont_features: int,
        cat_cardinalities: List[int],
        bins: Optional[List[Tensor]],
        mlp_kwargs: dict,
    ) -> None:
        super().__init__()
        self.cat_cardinalities = cat_cardinalities

        # Pick exactly one of the embedding modules below.

        # Model == MLP-PLR.
        d_embedding = 24
        self.cont_embeddings = rtdl_num_embeddings.PeriodicEmbeddings(
            n_cont_features, d_embedding, lite=False
        )
        d_num = n_cont_features * d_embedding

        # Model == MLP-Q or MLP-T depending on how bins were computed.
        # assert bins is not None
        # self.cont_embeddings = rtdl_num_embeddings.PiecewiseLinearEncoding(bins)
        # d_num = sum(len(b) - 1 for b in bins)

        # Model == MLP-QL or MLP-TL depending on how bins were computed.
        # assert bins is not None
        # d_embedding = 8
        # self.cont_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(bins, d_embedding, activation=False)
        # d_num = n_cont_features * d_embedding

        # Model == MLP-LR.
        # d_embedding = 32
        # self.cont_embeddings = rtdl_num_embeddings.LinearReLUEmbeddings(n_cont_features, d_embedding)
        # d_num = n_cont_features * d_embedding

        # Categorical features are one-hot encoded in `forward`, so their total
        # width equals the sum of the cardinalities.
        d_cat = sum(cat_cardinalities)
        backbone_d_in = d_num + d_cat
        self.backbone = rtdl_revisiting_models.MLP(d_in=backbone_d_in, **mlp_kwargs)

    def forward(self, x_cont: Tensor, x_cat: Optional[Tensor]) -> Tensor:
        """Apply the model to a batch.

        Args:
            x_cont: the continuous features.
            x_cat: the categorical features given as integer indices, with one
                column per entry of `self.cat_cardinalities`, or None.

        Returns:
            The output of the backbone.
        """
        # Step 1. Embed the continuous features and flatten the result, because
        # an MLP-like backbone expects one vector per object.
        features = [self.cont_embeddings(x_cont).flatten(1)]

        # Step 2. One-hot encode the categorical features, column by column.
        if x_cat is not None:
            for column, cardinality in zip(x_cat.T, self.cat_cardinalities):
                features.append(F.one_hot(column, cardinality))

        # Steps 3-4. Assemble the backbone input and apply the backbone.
        backbone_input = torch.column_stack(features)
        return self.backbone(backbone_input)


# The bins are consumed only by PiecewiseLinearEncoding and
# PiecewiseLinearEmbeddings; they are computed from the training features.
bins = rtdl_num_embeddings.compute_bins(data["train"]["x_cont"])

model = Model(
    n_cont_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,
    bins=bins,
    mlp_kwargs=dict(
        n_blocks=2,
        d_block=384,
        dropout=0.4,
        # One output per class for multiclass problems, a single output otherwise.
        d_out=1 if task_type != "multiclass" else n_classes,
    ),
)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

## Training

In [6]:
def apply_model(batch: Dict[str, Tensor]) -> Tensor:
    """Run the global `model` on a batch and drop a trailing dimension of size one.

    Args:
        batch: a mapping with the key "x_cont" and, optionally, "x_cat".

    Returns:
        The model output after `.squeeze(-1)`.
    """
    # `.get` yields None when the batch has no categorical features.
    x_cat = batch.get("x_cat")
    output = model(batch["x_cont"], x_cat)
    return output.squeeze(-1)


# Task-specific training loss; any task type without an entry falls back to MSE.
loss_fn = {
    "binclass": F.binary_cross_entropy_with_logits,
    "multiclass": F.cross_entropy,
}.get(task_type, F.mse_loss)


@torch.no_grad()
def evaluate(part: str) -> float:
    """Compute the score of the global `model` on one part of the data.

    Args:
        part: a key of the global `data` dictionary (e.g. "val" or "test").

    Returns:
        The accuracy for classification tasks and the negative RMSE multiplied
        by `Y_std` for regression, so that a higher score is always better.

    Raises:
        AssertionError: if `task_type` is not one of the supported task types.
    """
    # Switch to the evaluation mode; the model is left in this mode on return.
    model.eval()

    eval_batch_size = 8096
    batch_predictions = []
    for batch in delu.iter_batches(data[part], eval_batch_size):
        batch_predictions.append(apply_model(batch))
    y_pred = torch.cat(batch_predictions).cpu().numpy()
    y_true = data[part]["y"].cpu().numpy()

    if task_type in ("binclass", "multiclass"):
        if task_type == "binclass":
            # Logits -> probabilities -> hard 0/1 labels.
            y_pred = np.round(scipy.special.expit(y_pred))
        else:
            # The predicted class is the one with the largest logit.
            y_pred = y_pred.argmax(1)
        return sklearn.metrics.accuracy_score(y_true, y_pred)

    assert task_type == "regression"
    # The higher -- the better, hence the minus sign.
    return -(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5 * Y_std)


# Report the score of the untrained model as a reference point.
print(f"Test score before training: {evaluate('test'):.4f}")

Test score before training: -1.1556


In [7]:
# For demonstration purposes (fast training and bad performance),
# one can set smaller values:
# n_epochs = 20
# patience = 2
n_epochs = 1_000_000_000  # In practice, the loop is ended by early stopping.
patience = 16

batch_size = 256
# The number of batches per epoch (used only as the progress bar total).
epoch_size = math.ceil(len(train_idx) / batch_size)
timer = delu.tools.Timer()
early_stopping = delu.tools.EarlyStopping(patience, mode="max")
# The scores and the index of the epoch with the best validation score so far.
best = dict(val=-math.inf, test=-math.inf, epoch=-1)

print(f"Device: {device.type.upper()}")
print("-" * 88 + "\n")
timer.run()
for epoch in range(n_epochs):
    train_batches = tqdm(
        delu.iter_batches(data["train"], batch_size, shuffle=True),
        desc=f"Epoch {epoch}",
        total=epoch_size,
    )
    for batch in train_batches:
        # `evaluate` leaves the model in the evaluation mode, so the training
        # mode is re-enabled before every optimization step.
        model.train()
        optimizer.zero_grad()
        prediction = apply_model(batch)
        loss = loss_fn(prediction, batch["y"])
        loss.backward()
        optimizer.step()

    val_score, test_score = evaluate("val"), evaluate("test")
    print(f"(val) {val_score:.4f} (test) {test_score:.4f} [time] {timer}")

    # Early stopping is driven by the validation score only.
    early_stopping.update(val_score)
    if early_stopping.should_stop():
        break

    if val_score > best["val"]:
        print("🌸 New best epoch! 🌸")
        best = dict(val=val_score, test=test_score, epoch=epoch)
    print()

print("\n\nResult:")
print(best)

Device: CPU
----------------------------------------------------------------------------------------



Epoch 0:   0%|          | 0/52 [00:00<?, ?it/s]

Epoch 0: 100%|██████████| 52/52 [00:00<00:00, 70.57it/s]


(val) -0.7777 (test) -0.7983 [time] 0:00:00.783607
🌸 New best epoch! 🌸



Epoch 1: 100%|██████████| 52/52 [00:00<00:00, 65.52it/s]


(val) -0.6336 (test) -0.6360 [time] 0:00:01.609783
🌸 New best epoch! 🌸



Epoch 2: 100%|██████████| 52/52 [00:00<00:00, 62.85it/s]


(val) -0.6174 (test) -0.6164 [time] 0:00:02.468450
🌸 New best epoch! 🌸



Epoch 3: 100%|██████████| 52/52 [00:00<00:00, 54.46it/s]


(val) -0.6003 (test) -0.5994 [time] 0:00:03.454586
🌸 New best epoch! 🌸



Epoch 4: 100%|██████████| 52/52 [00:00<00:00, 62.83it/s]


(val) -0.5932 (test) -0.5924 [time] 0:00:04.315101
🌸 New best epoch! 🌸



Epoch 5: 100%|██████████| 52/52 [00:00<00:00, 66.09it/s]


(val) -0.5956 (test) -0.5889 [time] 0:00:05.135632



Epoch 6: 100%|██████████| 52/52 [00:00<00:00, 68.81it/s]


(val) -0.5779 (test) -0.5752 [time] 0:00:05.926288
🌸 New best epoch! 🌸



Epoch 7: 100%|██████████| 52/52 [00:00<00:00, 56.95it/s]


(val) -0.5729 (test) -0.5680 [time] 0:00:06.877613
🌸 New best epoch! 🌸



Epoch 8: 100%|██████████| 52/52 [00:00<00:00, 64.97it/s]


(val) -0.5695 (test) -0.5653 [time] 0:00:07.711176
🌸 New best epoch! 🌸



Epoch 9: 100%|██████████| 52/52 [00:00<00:00, 79.61it/s]


(val) -0.5722 (test) -0.5706 [time] 0:00:08.393044



Epoch 10: 100%|██████████| 52/52 [00:00<00:00, 76.05it/s]


(val) -0.5685 (test) -0.5648 [time] 0:00:09.112112
🌸 New best epoch! 🌸



Epoch 11: 100%|██████████| 52/52 [00:00<00:00, 74.64it/s]


(val) -0.5599 (test) -0.5564 [time] 0:00:09.838392
🌸 New best epoch! 🌸



Epoch 12: 100%|██████████| 52/52 [00:00<00:00, 73.61it/s]


(val) -0.5613 (test) -0.5560 [time] 0:00:10.575555



Epoch 13: 100%|██████████| 52/52 [00:00<00:00, 69.25it/s]


(val) -0.5538 (test) -0.5504 [time] 0:00:11.358541
🌸 New best epoch! 🌸



Epoch 14: 100%|██████████| 52/52 [00:00<00:00, 74.63it/s]


(val) -0.5535 (test) -0.5472 [time] 0:00:12.088588
🌸 New best epoch! 🌸



Epoch 15: 100%|██████████| 52/52 [00:00<00:00, 66.88it/s]


(val) -0.5597 (test) -0.5537 [time] 0:00:12.895674



Epoch 16: 100%|██████████| 52/52 [00:00<00:00, 77.71it/s]


(val) -0.5464 (test) -0.5418 [time] 0:00:13.594061
🌸 New best epoch! 🌸



Epoch 17: 100%|██████████| 52/52 [00:00<00:00, 76.35it/s]


(val) -0.5478 (test) -0.5442 [time] 0:00:14.305069



Epoch 18: 100%|██████████| 52/52 [00:00<00:00, 78.01it/s]


(val) -0.5471 (test) -0.5442 [time] 0:00:15.001982



Epoch 19: 100%|██████████| 52/52 [00:00<00:00, 76.88it/s]


(val) -0.5404 (test) -0.5351 [time] 0:00:15.710724
🌸 New best epoch! 🌸



Epoch 20: 100%|██████████| 52/52 [00:00<00:00, 69.95it/s]


(val) -0.5365 (test) -0.5369 [time] 0:00:16.484272
🌸 New best epoch! 🌸



Epoch 21: 100%|██████████| 52/52 [00:00<00:00, 76.47it/s]


(val) -0.5440 (test) -0.5361 [time] 0:00:17.193300



Epoch 22: 100%|██████████| 52/52 [00:00<00:00, 76.89it/s]


(val) -0.5359 (test) -0.5301 [time] 0:00:17.904786
🌸 New best epoch! 🌸



Epoch 23: 100%|██████████| 52/52 [00:00<00:00, 70.98it/s]


(val) -0.5340 (test) -0.5292 [time] 0:00:18.667242
🌸 New best epoch! 🌸



Epoch 24: 100%|██████████| 52/52 [00:00<00:00, 78.22it/s]


(val) -0.5305 (test) -0.5259 [time] 0:00:19.362088
🌸 New best epoch! 🌸



Epoch 25: 100%|██████████| 52/52 [00:00<00:00, 78.47it/s]


(val) -0.5260 (test) -0.5224 [time] 0:00:20.055360
🌸 New best epoch! 🌸



Epoch 26: 100%|██████████| 52/52 [00:00<00:00, 75.08it/s]


(val) -0.5270 (test) -0.5231 [time] 0:00:20.776708



Epoch 27: 100%|██████████| 52/52 [00:00<00:00, 64.58it/s]


(val) -0.5326 (test) -0.5279 [time] 0:00:21.617215



Epoch 28: 100%|██████████| 52/52 [00:00<00:00, 76.31it/s]


(val) -0.5240 (test) -0.5206 [time] 0:00:22.329100
🌸 New best epoch! 🌸



Epoch 29: 100%|██████████| 52/52 [00:00<00:00, 77.28it/s]


(val) -0.5216 (test) -0.5201 [time] 0:00:23.032810
🌸 New best epoch! 🌸



Epoch 30: 100%|██████████| 52/52 [00:00<00:00, 74.03it/s]


(val) -0.5178 (test) -0.5174 [time] 0:00:23.767504
🌸 New best epoch! 🌸



Epoch 31: 100%|██████████| 52/52 [00:00<00:00, 73.65it/s]


(val) -0.5228 (test) -0.5208 [time] 0:00:24.505040



Epoch 32: 100%|██████████| 52/52 [00:00<00:00, 78.47it/s]


(val) -0.5168 (test) -0.5144 [time] 0:00:25.200334
🌸 New best epoch! 🌸



Epoch 33: 100%|██████████| 52/52 [00:00<00:00, 67.59it/s]


(val) -0.5209 (test) -0.5199 [time] 0:00:25.999622



Epoch 34: 100%|██████████| 52/52 [00:00<00:00, 70.62it/s]


(val) -0.5149 (test) -0.5171 [time] 0:00:26.771421
🌸 New best epoch! 🌸



Epoch 35: 100%|██████████| 52/52 [00:00<00:00, 75.60it/s]


(val) -0.5178 (test) -0.5148 [time] 0:00:27.494847



Epoch 36: 100%|██████████| 52/52 [00:00<00:00, 75.67it/s]


(val) -0.5153 (test) -0.5143 [time] 0:00:28.213514



Epoch 37: 100%|██████████| 52/52 [00:00<00:00, 74.51it/s]


(val) -0.5114 (test) -0.5111 [time] 0:00:28.943544
🌸 New best epoch! 🌸



Epoch 38: 100%|██████████| 52/52 [00:00<00:00, 76.48it/s]


(val) -0.5148 (test) -0.5129 [time] 0:00:29.651076



Epoch 39: 100%|██████████| 52/52 [00:00<00:00, 73.48it/s]


(val) -0.5193 (test) -0.5156 [time] 0:00:30.387862



Epoch 40: 100%|██████████| 52/52 [00:00<00:00, 73.05it/s]


(val) -0.5104 (test) -0.5101 [time] 0:00:31.143107
🌸 New best epoch! 🌸



Epoch 41: 100%|██████████| 52/52 [00:00<00:00, 64.92it/s]


(val) -0.5103 (test) -0.5089 [time] 0:00:31.985849
🌸 New best epoch! 🌸



Epoch 42: 100%|██████████| 52/52 [00:00<00:00, 62.80it/s]


(val) -0.5234 (test) -0.5225 [time] 0:00:32.846015



Epoch 43: 100%|██████████| 52/52 [00:00<00:00, 70.38it/s]


(val) -0.5084 (test) -0.5094 [time] 0:00:33.618419
🌸 New best epoch! 🌸



Epoch 44: 100%|██████████| 52/52 [00:00<00:00, 75.87it/s]


(val) -0.5065 (test) -0.5061 [time] 0:00:34.333827
🌸 New best epoch! 🌸



Epoch 45: 100%|██████████| 52/52 [00:00<00:00, 69.63it/s]


(val) -0.5105 (test) -0.5085 [time] 0:00:35.112714



Epoch 46: 100%|██████████| 52/52 [00:00<00:00, 75.47it/s]


(val) -0.5119 (test) -0.5102 [time] 0:00:35.834531



Epoch 47: 100%|██████████| 52/52 [00:00<00:00, 75.07it/s]


(val) -0.5055 (test) -0.5056 [time] 0:00:36.559820
🌸 New best epoch! 🌸



Epoch 48: 100%|██████████| 52/52 [00:00<00:00, 72.91it/s]


(val) -0.5036 (test) -0.5049 [time] 0:00:37.302773
🌸 New best epoch! 🌸



Epoch 49: 100%|██████████| 52/52 [00:00<00:00, 63.30it/s]


(val) -0.5031 (test) -0.5061 [time] 0:00:38.159594
🌸 New best epoch! 🌸



Epoch 50: 100%|██████████| 52/52 [00:00<00:00, 72.74it/s]


(val) -0.5058 (test) -0.5038 [time] 0:00:38.904748



Epoch 51: 100%|██████████| 52/52 [00:00<00:00, 72.67it/s]


(val) -0.5038 (test) -0.5023 [time] 0:00:39.649371



Epoch 52: 100%|██████████| 52/52 [00:00<00:00, 65.73it/s]


(val) -0.5039 (test) -0.5028 [time] 0:00:40.476049



Epoch 53: 100%|██████████| 52/52 [00:00<00:00, 70.88it/s]


(val) -0.5019 (test) -0.5005 [time] 0:00:41.242209
🌸 New best epoch! 🌸



Epoch 54: 100%|██████████| 52/52 [00:00<00:00, 76.35it/s]


(val) -0.5003 (test) -0.5001 [time] 0:00:41.956375
🌸 New best epoch! 🌸



Epoch 55: 100%|██████████| 52/52 [00:00<00:00, 70.85it/s]


(val) -0.5053 (test) -0.5043 [time] 0:00:42.729031



Epoch 56: 100%|██████████| 52/52 [00:00<00:00, 69.60it/s]


(val) -0.4999 (test) -0.4986 [time] 0:00:43.512425
🌸 New best epoch! 🌸



Epoch 57: 100%|██████████| 52/52 [00:00<00:00, 70.00it/s]


(val) -0.5010 (test) -0.5024 [time] 0:00:44.291349



Epoch 58: 100%|██████████| 52/52 [00:00<00:00, 74.90it/s]


(val) -0.5049 (test) -0.5050 [time] 0:00:45.016439



Epoch 59: 100%|██████████| 52/52 [00:00<00:00, 69.35it/s]


(val) -0.5018 (test) -0.5010 [time] 0:00:45.796623



Epoch 60: 100%|██████████| 52/52 [00:00<00:00, 67.93it/s]


(val) -0.5012 (test) -0.5027 [time] 0:00:46.594661



Epoch 61: 100%|██████████| 52/52 [00:00<00:00, 55.90it/s]


(val) -0.5038 (test) -0.5060 [time] 0:00:47.562994



Epoch 62: 100%|██████████| 52/52 [00:00<00:00, 56.65it/s]


(val) -0.5015 (test) -0.5005 [time] 0:00:48.516118



Epoch 63: 100%|██████████| 52/52 [00:00<00:00, 55.40it/s]


(val) -0.5013 (test) -0.5012 [time] 0:00:49.486839



Epoch 64: 100%|██████████| 52/52 [00:00<00:00, 64.77it/s]


(val) -0.4962 (test) -0.4960 [time] 0:00:50.321588
🌸 New best epoch! 🌸



Epoch 65: 100%|██████████| 52/52 [00:00<00:00, 62.27it/s]


(val) -0.4986 (test) -0.4985 [time] 0:00:51.187292



Epoch 66: 100%|██████████| 52/52 [00:00<00:00, 74.06it/s]


(val) -0.4994 (test) -0.4987 [time] 0:00:51.921865



Epoch 67: 100%|██████████| 52/52 [00:00<00:00, 76.41it/s]


(val) -0.4940 (test) -0.4946 [time] 0:00:52.630203
🌸 New best epoch! 🌸



Epoch 68: 100%|██████████| 52/52 [00:00<00:00, 75.43it/s]


(val) -0.5024 (test) -0.4993 [time] 0:00:53.352808



Epoch 69: 100%|██████████| 52/52 [00:00<00:00, 64.64it/s]


(val) -0.4969 (test) -0.4951 [time] 0:00:54.188798



Epoch 70: 100%|██████████| 52/52 [00:00<00:00, 75.49it/s]


(val) -0.4947 (test) -0.4966 [time] 0:00:54.904888



Epoch 71: 100%|██████████| 52/52 [00:00<00:00, 78.97it/s]


(val) -0.4960 (test) -0.4949 [time] 0:00:55.592324



Epoch 72: 100%|██████████| 52/52 [00:00<00:00, 69.34it/s]


(val) -0.4975 (test) -0.4977 [time] 0:00:56.370665



Epoch 73: 100%|██████████| 52/52 [00:00<00:00, 77.13it/s]


(val) -0.4933 (test) -0.4929 [time] 0:00:57.073650
🌸 New best epoch! 🌸



Epoch 74: 100%|██████████| 52/52 [00:00<00:00, 73.56it/s]


(val) -0.4921 (test) -0.4899 [time] 0:00:57.809498
🌸 New best epoch! 🌸



Epoch 75: 100%|██████████| 52/52 [00:00<00:00, 71.05it/s]


(val) -0.4957 (test) -0.4919 [time] 0:00:58.570345



Epoch 76: 100%|██████████| 52/52 [00:00<00:00, 72.92it/s]


(val) -0.4960 (test) -0.4946 [time] 0:00:59.314562



Epoch 77: 100%|██████████| 52/52 [00:00<00:00, 77.73it/s]


(val) -0.4987 (test) -0.4962 [time] 0:01:00.011554



Epoch 78: 100%|██████████| 52/52 [00:00<00:00, 77.24it/s]


(val) -0.4925 (test) -0.4903 [time] 0:01:00.713512



Epoch 79: 100%|██████████| 52/52 [00:00<00:00, 72.73it/s]


(val) -0.4999 (test) -0.4969 [time] 0:01:01.456478



Epoch 80: 100%|██████████| 52/52 [00:00<00:00, 68.39it/s]


(val) -0.4962 (test) -0.4943 [time] 0:01:02.245972



Epoch 81: 100%|██████████| 52/52 [00:00<00:00, 79.73it/s]


(val) -0.4919 (test) -0.4908 [time] 0:01:02.925611
🌸 New best epoch! 🌸



Epoch 82: 100%|██████████| 52/52 [00:00<00:00, 75.30it/s]


(val) -0.4923 (test) -0.4899 [time] 0:01:03.670569



Epoch 83: 100%|██████████| 52/52 [00:00<00:00, 67.60it/s]


(val) -0.4928 (test) -0.4945 [time] 0:01:04.468671



Epoch 84: 100%|██████████| 52/52 [00:00<00:00, 62.93it/s]


(val) -0.4973 (test) -0.4927 [time] 0:01:05.344641



Epoch 85: 100%|██████████| 52/52 [00:00<00:00, 64.55it/s]


(val) -0.4939 (test) -0.4906 [time] 0:01:06.185510



Epoch 86: 100%|██████████| 52/52 [00:00<00:00, 62.81it/s]


(val) -0.4906 (test) -0.4881 [time] 0:01:07.047094
🌸 New best epoch! 🌸



Epoch 87: 100%|██████████| 52/52 [00:00<00:00, 56.52it/s]


(val) -0.4953 (test) -0.4924 [time] 0:01:07.996084



Epoch 88: 100%|██████████| 52/52 [00:00<00:00, 73.10it/s]


(val) -0.4905 (test) -0.4903 [time] 0:01:08.737195
🌸 New best epoch! 🌸



Epoch 89: 100%|██████████| 52/52 [00:00<00:00, 71.00it/s]


(val) -0.4920 (test) -0.4892 [time] 0:01:09.500282



Epoch 90: 100%|██████████| 52/52 [00:00<00:00, 65.75it/s]


(val) -0.4925 (test) -0.4873 [time] 0:01:10.337718



Epoch 91: 100%|██████████| 52/52 [00:00<00:00, 65.01it/s]


(val) -0.4931 (test) -0.4888 [time] 0:01:11.177083



Epoch 92: 100%|██████████| 52/52 [00:00<00:00, 68.85it/s]


(val) -0.4938 (test) -0.4923 [time] 0:01:11.962380



Epoch 93: 100%|██████████| 52/52 [00:00<00:00, 68.69it/s]


(val) -0.4999 (test) -0.4952 [time] 0:01:12.750272



Epoch 94: 100%|██████████| 52/52 [00:00<00:00, 66.56it/s]


(val) -0.4960 (test) -0.4926 [time] 0:01:13.570959



Epoch 95: 100%|██████████| 52/52 [00:00<00:00, 72.99it/s]


(val) -0.4913 (test) -0.4905 [time] 0:01:14.311482



Epoch 96: 100%|██████████| 52/52 [00:00<00:00, 64.68it/s]


(val) -0.4915 (test) -0.4891 [time] 0:01:15.159559



Epoch 97: 100%|██████████| 52/52 [00:00<00:00, 69.84it/s]


(val) -0.4944 (test) -0.4924 [time] 0:01:15.935413



Epoch 98: 100%|██████████| 52/52 [00:00<00:00, 67.92it/s]


(val) -0.4909 (test) -0.4887 [time] 0:01:16.734217



Epoch 99: 100%|██████████| 52/52 [00:00<00:00, 73.01it/s]


(val) -0.4912 (test) -0.4904 [time] 0:01:17.476971



Epoch 100: 100%|██████████| 52/52 [00:00<00:00, 70.87it/s]


(val) -0.4862 (test) -0.4861 [time] 0:01:18.243725
🌸 New best epoch! 🌸



Epoch 101: 100%|██████████| 52/52 [00:00<00:00, 72.16it/s]


(val) -0.4875 (test) -0.4870 [time] 0:01:19.001550



Epoch 102: 100%|██████████| 52/52 [00:00<00:00, 67.69it/s]


(val) -0.4875 (test) -0.4852 [time] 0:01:19.799826



Epoch 103: 100%|██████████| 52/52 [00:00<00:00, 74.23it/s]


(val) -0.4891 (test) -0.4862 [time] 0:01:20.534735



Epoch 104: 100%|██████████| 52/52 [00:00<00:00, 66.36it/s]


(val) -0.4872 (test) -0.4885 [time] 0:01:21.352413



Epoch 105: 100%|██████████| 52/52 [00:00<00:00, 71.79it/s]


(val) -0.4899 (test) -0.4914 [time] 0:01:22.110138



Epoch 106: 100%|██████████| 52/52 [00:00<00:00, 72.52it/s]


(val) -0.4927 (test) -0.4906 [time] 0:01:22.867869



Epoch 107: 100%|██████████| 52/52 [00:00<00:00, 69.12it/s]


(val) -0.4903 (test) -0.4904 [time] 0:01:23.652128



Epoch 108: 100%|██████████| 52/52 [00:00<00:00, 70.45it/s]


(val) -0.4943 (test) -0.4897 [time] 0:01:24.449741



Epoch 109: 100%|██████████| 52/52 [00:00<00:00, 71.11it/s]


(val) -0.4896 (test) -0.4879 [time] 0:01:25.211694



Epoch 110: 100%|██████████| 52/52 [00:00<00:00, 69.39it/s]


(val) -0.4903 (test) -0.4860 [time] 0:01:26.017504



Epoch 111: 100%|██████████| 52/52 [00:00<00:00, 64.18it/s]


(val) -0.4873 (test) -0.4898 [time] 0:01:26.857759



Epoch 112: 100%|██████████| 52/52 [00:00<00:00, 61.56it/s]


(val) -0.4889 (test) -0.4879 [time] 0:01:27.734361



Epoch 113: 100%|██████████| 52/52 [00:00<00:00, 70.21it/s]


(val) -0.4884 (test) -0.4901 [time] 0:01:28.509992



Epoch 114: 100%|██████████| 52/52 [00:00<00:00, 71.43it/s]


(val) -0.4877 (test) -0.4902 [time] 0:01:29.267472



Epoch 115: 100%|██████████| 52/52 [00:00<00:00, 74.95it/s]


(val) -0.4989 (test) -0.4990 [time] 0:01:29.992830



Epoch 116: 100%|██████████| 52/52 [00:00<00:00, 70.65it/s]


(val) -0.4845 (test) -0.4860 [time] 0:01:30.759019
🌸 New best epoch! 🌸



Epoch 117: 100%|██████████| 52/52 [00:00<00:00, 56.12it/s]


(val) -0.4912 (test) -0.4903 [time] 0:01:31.717450



Epoch 118: 100%|██████████| 52/52 [00:00<00:00, 72.29it/s]


(val) -0.4893 (test) -0.4889 [time] 0:01:32.468096



Epoch 119: 100%|██████████| 52/52 [00:00<00:00, 75.06it/s]


(val) -0.4873 (test) -0.4873 [time] 0:01:33.191455



Epoch 120: 100%|██████████| 52/52 [00:00<00:00, 76.09it/s]


(val) -0.4828 (test) -0.4846 [time] 0:01:33.908960
🌸 New best epoch! 🌸



Epoch 121: 100%|██████████| 52/52 [00:00<00:00, 67.28it/s]


(val) -0.4874 (test) -0.4890 [time] 0:01:34.718659



Epoch 122: 100%|██████████| 52/52 [00:00<00:00, 68.44it/s]


(val) -0.4899 (test) -0.4887 [time] 0:01:35.508652



Epoch 123: 100%|██████████| 52/52 [00:00<00:00, 75.65it/s]


(val) -0.4873 (test) -0.4902 [time] 0:01:36.225237



Epoch 124: 100%|██████████| 52/52 [00:00<00:00, 71.32it/s]


(val) -0.4846 (test) -0.4872 [time] 0:01:36.986730



Epoch 125: 100%|██████████| 52/52 [00:00<00:00, 72.19it/s]


(val) -0.4885 (test) -0.4911 [time] 0:01:37.738599



Epoch 126: 100%|██████████| 52/52 [00:00<00:00, 69.73it/s]


(val) -0.4915 (test) -0.4894 [time] 0:01:38.515422



Epoch 127: 100%|██████████| 52/52 [00:00<00:00, 72.31it/s]


(val) -0.4863 (test) -0.4870 [time] 0:01:39.263381



Epoch 128: 100%|██████████| 52/52 [00:00<00:00, 69.78it/s]


(val) -0.4857 (test) -0.4865 [time] 0:01:40.045609



Epoch 129: 100%|██████████| 52/52 [00:00<00:00, 66.91it/s]


(val) -0.4859 (test) -0.4896 [time] 0:01:40.853560



Epoch 130: 100%|██████████| 52/52 [00:00<00:00, 65.15it/s]


(val) -0.4899 (test) -0.4894 [time] 0:01:41.686223



Epoch 131: 100%|██████████| 52/52 [00:00<00:00, 72.18it/s]


(val) -0.4839 (test) -0.4849 [time] 0:01:42.442002



Epoch 132: 100%|██████████| 52/52 [00:00<00:00, 70.34it/s]


(val) -0.4868 (test) -0.4884 [time] 0:01:43.213030



Epoch 133: 100%|██████████| 52/52 [00:00<00:00, 73.37it/s]


(val) -0.4878 (test) -0.4884 [time] 0:01:43.951236



Epoch 134: 100%|██████████| 52/52 [00:00<00:00, 74.12it/s]


(val) -0.4831 (test) -0.4866 [time] 0:01:44.694586



Epoch 135: 100%|██████████| 52/52 [00:00<00:00, 71.96it/s]


(val) -0.4846 (test) -0.4861 [time] 0:01:45.450784



Epoch 136: 100%|██████████| 52/52 [00:00<00:00, 66.47it/s]

(val) -0.4859 (test) -0.4873 [time] 0:01:46.271441


Result:
{'val': -0.48278530806682063, 'test': -0.4845548791750518, 'epoch': 120}



