<a target="_blank" href="https://colab.research.google.com/github/yandex-research/rtdl-revisiting-models/blob/main/package/example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---

**See also** [RTDL](https://github.com/yandex-research/rtdl)
-- **other projects on tabular deep learning**.

---

- This notebook provides a usage example of the
  [rtdl_revisiting_models](https://github.com/yandex-research/rtdl-revisiting-models)
  package.
- Hyperparameters are not tuned and may be suboptimal.

In [None]:
%pip install delu
%pip install rtdl_revisiting_models



In [None]:
# ruff: noqa: E402
import math
import warnings
from typing import Dict, Literal

warnings.simplefilter("ignore")
import delu  # Deep Learning Utilities: https://github.com/Yura52/delu
import numpy as np
import scipy.special
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
import torch
import torch.nn.functional as F
import torch.optim
from torch import Tensor
from tqdm.std import tqdm
import pandas as pd

warnings.resetwarnings()

from rtdl_revisiting_models import MLP, ResNet, FTTransformer

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set random seeds in all libraries.
delu.random.seed(0)

0

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Dataset

In [None]:
# >>> Dataset.
TaskType = Literal["regression", "binclass", "multiclass"]

task_type: TaskType = "regression"
n_classes = None
dataset = sklearn.datasets.fetch_california_housing()
X_cont: np.ndarray = dataset["data"]
Y: np.ndarray = dataset["target"]

# NOTE: uncomment to solve a classification task.
# n_classes = 2
# assert n_classes >= 2
# task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

# >>> Continuous features.
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# >>> Categorical features.
# NOTE: the above datasets do not have categorical features, but,
# for the demonstration purposes, it is possible to generate them.
cat_cardinalities = [
    # NOTE: uncomment the two lines below to add two categorical features.
    # 4,  # Allowed values: [0, 1, 2, 3].
    # 7,  # Allowed values: [0, 1, 2, 3, 4, 5, 6].
]
X_cat = (
    np.column_stack(
        [np.random.randint(0, c, (len(X_cont),)) for c in cat_cardinalities]
    )
    if cat_cardinalities
    else None
)

# >>> Labels.
# Regression labels must be represented by float32.
if task_type == "regression":
    Y = Y.astype(np.float32)
else:
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(Y.tolist()) == set(
        range(n_classes)
    ), "Classification labels must form the range [0, 1, ..., n_classes - 1]"

# >>> Split the dataset.
all_idx = np.arange(len(Y))
trainval_idx, test_idx = sklearn.model_selection.train_test_split(
    all_idx, train_size=0.8
)
train_idx, val_idx = sklearn.model_selection.train_test_split(
    trainval_idx, train_size=0.8
)
data_numpy = {
    "train": {"x_cont": X_cont[train_idx], "y": Y[train_idx]},
    "val": {"x_cont": X_cont[val_idx], "y": Y[val_idx]},
    "test": {"x_cont": X_cont[test_idx], "y": Y[test_idx]},
}
if X_cat is not None:
    data_numpy["train"]["x_cat"] = X_cat[train_idx]
    data_numpy["val"]["x_cat"] = X_cat[val_idx]
    data_numpy["test"]["x_cat"] = X_cat[test_idx]

## Preprocessing

In [None]:
# >>> Feature preprocessing.
# NOTE
# The choice between preprocessing strategies depends on a task and a model.

# (A) Simple preprocessing strategy.
# preprocessing = sklearn.preprocessing.StandardScaler().fit(
#     data_numpy['train']['x_cont']
# )

# (B) Fancy preprocessing strategy.
# The noise is added to improve the output of QuantileTransformer in some cases.
X_cont_train_numpy = data_numpy["train"]["x_cont"]
noise = (
    np.random.default_rng(0)
    .normal(0.0, 1e-5, X_cont_train_numpy.shape)
    .astype(X_cont_train_numpy.dtype)
)
preprocessing = sklearn.preprocessing.QuantileTransformer(
    n_quantiles=max(min(len(train_idx) // 30, 1000), 10),
    output_distribution="normal",
    subsample=10**9,
).fit(X_cont_train_numpy + noise)
del X_cont_train_numpy

for part in data_numpy:
    data_numpy[part]["x_cont"] = preprocessing.transform(data_numpy[part]["x_cont"])

# >>> Label preprocessing.
if task_type == "regression":
    Y_mean = data_numpy["train"]["y"].mean().item()
    Y_std = data_numpy["train"]["y"].std().item()
    for part in data_numpy:
        data_numpy[part]["y"] = (data_numpy[part]["y"] - Y_mean) / Y_std

# >>> Convert data to tensors.
X_train = pd.read_csv("/content/drive/MyDrive/GR-II/AGORA/50/X_train_GR-II.csv")
X_val = pd.read_csv("/content/drive/MyDrive/GR-II/AGORA/50/X_val_GR-II.csv")
X_test = pd.read_csv("/content/drive/MyDrive/GR-II/AGORA/50/X_test_GR-II.csv")
y_train = pd.read_csv("/content/drive/MyDrive/GR-II/AGORA/50/y_train_GR-II.csv")
y_val = pd.read_csv("/content/drive/MyDrive/GR-II/AGORA/50/y_val_GR-II.csv")
y_test = pd.read_csv("/content/drive/MyDrive/GR-II/AGORA/50/y_test_GR-II.csv")

data_numpy = {
    'train': {'x_cont': X_train.values, 'y': y_train.values},
    'val': {'x_cont': X_val.values, 'y': y_val.values},
    'test': {'x_cont': X_test.values, 'y': y_test.values},
}

data = {
    part: {k: torch.as_tensor(v, device=device) for k, v in data_numpy[part].items()}
    for part in data_numpy
}

if task_type != "multiclass":
    # Required by F.binary_cross_entropy_with_logits
    for part in data:
        data[part]["y"] = data[part]["y"].float()

## Model

In [None]:
# The output size.
d_out = n_classes if task_type == "multiclass" else 1

# # NOTE: uncomment to train MLP
# model = MLP(
#     d_in=n_cont_features + sum(cat_cardinalities),
#     d_out=d_out,
#     n_blocks=2,
#     d_block=384,
#     dropout=0.1,
# ).to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

# # NOTE: uncomment to train ResNet
# model = ResNet(
#     d_in=n_cont_features + sum(cat_cardinalities),
#     d_out=d_out,
#     n_blocks=2,
#     d_block=192,
#     d_hidden=None,
#     d_hidden_multiplier=2.0,
#     dropout1=0.3,
#     dropout2=0.0,
# ).to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

model = FTTransformer(
    n_cont_features=424,
    cat_cardinalities=cat_cardinalities,
    d_out=d_out,
    **FTTransformer.get_default_kwargs(),
).to(device)
optimizer = model.make_default_optimizer()

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()

## Training

In [None]:
def apply_model(batch: Dict[str, Tensor]) -> Tensor:
    if isinstance(model, (MLP, ResNet)):
        x_cat_ohe = (
            [
                F.one_hot(column, cardinality)
                for column, cardinality in zip(batch["x_cat"].T, cat_cardinalities)
            ]
            if "x_cat" in batch
            else []
        )
        return model(torch.column_stack([batch["x_cont"]] + x_cat_ohe)).squeeze(-1)

    elif isinstance(model, FTTransformer):
        return model(batch["x_cont"], batch.get("x_cat")).squeeze(-1)

    else:
        raise RuntimeError(f"Unknown model type: {type(model)}")


loss_fn = (
    F.binary_cross_entropy_with_logits
    if task_type == "binclass"
    else F.cross_entropy
    if task_type == "multiclass"
    else F.mse_loss
)


@torch.no_grad()
def evaluate(part: str) -> float:
    model.eval()

    eval_batch_size = 4
    y_pred = (
        torch.cat(
            [
                apply_model(batch)
                for batch in delu.iter_batches(data[part], eval_batch_size)
            ]
        )
        .cpu()
        .numpy()
    )
    y_true = data[part]["y"].cpu().numpy()

    if task_type == "binclass":
        y_pred = np.round(scipy.special.expit(y_pred))
        score = sklearn.metrics.accuracy_score(y_true, y_pred)
    elif task_type == "multiclass":
        y_pred = y_pred.argmax(1)
        score = sklearn.metrics.accuracy_score(y_true, y_pred)
    else:
        assert task_type == "regression"
        score = -(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5 * Y_std)
    return score  # The higher -- the better.


print(f'Test score before training: {evaluate("test"):.4f}')

Test score before training: -77.5944


In [None]:
import json

def calculate_and_save_metrics(part: str, save_path: str = "final_metrics.json"):
    model.eval()
    eval_batch_size = 256  # Much more memory-efficient

    y_pred_list = []
    y_true_list = []

    with torch.no_grad():
        for batch in delu.iter_batches(data[part], eval_batch_size):
            preds = apply_model(batch).detach().cpu().numpy()
            y_pred_list.append(preds)
            y_true_list.append(batch["y"].detach().cpu().numpy())

    y_pred = np.concatenate(y_pred_list)
    y_true = np.concatenate(y_true_list)

    metrics = {}

    if task_type == "regression":
        y_pred_rescaled = y_pred * Y_std
        y_true_rescaled = y_true * Y_std
        metrics["rmse"] = float(np.sqrt(sklearn.metrics.mean_squared_error(y_true_rescaled, y_pred_rescaled)))
        metrics["mae"] = float(sklearn.metrics.mean_absolute_error(y_true_rescaled, y_pred_rescaled))
        metrics["r2"] = float(sklearn.metrics.r2_score(y_true_rescaled, y_pred_rescaled))

    elif task_type == "binclass":
        probs = scipy.special.expit(y_pred)
        preds = np.round(probs)
        metrics["accuracy"] = float(sklearn.metrics.accuracy_score(y_true, preds))
        metrics["precision"] = float(sklearn.metrics.precision_score(y_true, preds))
        metrics["recall"] = float(sklearn.metrics.recall_score(y_true, preds))
        metrics["f1"] = float(sklearn.metrics.f1_score(y_true, preds))
        metrics["mcc"] = float(sklearn.metrics.matthews_corrcoef(y_true, preds))
        metrics["auc"] = float(sklearn.metrics.roc_auc_score(y_true, probs))

    elif task_type == "multiclass":
        probs = y_pred
        preds = probs.argmax(1)
        metrics["accuracy"] = float(sklearn.metrics.accuracy_score(y_true, preds))
        metrics["precision"] = float(sklearn.metrics.precision_score(y_true, preds, average="macro"))
        metrics["recall"] = float(sklearn.metrics.recall_score(y_true, preds, average="macro"))
        metrics["f1"] = float(sklearn.metrics.f1_score(y_true, preds, average="macro"))
        metrics["mcc"] = float(sklearn.metrics.matthews_corrcoef(y_true, preds))
        metrics["auc"] = float(sklearn.metrics.roc_auc_score(y_true, probs, multi_class="ovr"))

    with open(save_path, "w") as f:
        json.dump(metrics, f, indent=4)
    return metrics

In [None]:
# For demonstration purposes (fast training and bad performance),
# one can set smaller values:
# n_epochs = 20
# patience = 2
n_epochs = 1000
patience = 16

batch_size = 256
epoch_size = math.ceil(len(X_train) / batch_size)
timer = delu.tools.Timer()
early_stopping = delu.tools.EarlyStopping(patience, mode="max")
best = {
    "val": -math.inf,
    "test": -math.inf,
    "epoch": -1,
}

print(f"Device: {device.type.upper()}")
print("-" * 88 + "\n")
timer.run()
for epoch in range(n_epochs):
    for batch in tqdm(
        delu.iter_batches(data["train"], batch_size, shuffle=True),
        desc=f"Epoch {epoch}",
        total=epoch_size,
    ):
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(apply_model(batch), batch["y"])
        loss.backward()
        optimizer.step()

    val_score = evaluate("val")
    test_score = evaluate("test")
    print(f"(val) {val_score:.4f} (test) {test_score:.4f} [time] {timer}")

    early_stopping.update(val_score)
    if early_stopping.should_stop():
        break

    if val_score > best["val"]:
        print("🌸 New best epoch! 🌸")
        best = {"val": val_score, "test": test_score, "epoch": epoch}

        torch.save(model.state_dict(), "best_model.pt")
        print("Saved model checkpoint to 'best_model.pt'")


print("\n\nResult:")
print(best)

Device: CUDA
----------------------------------------------------------------------------------------



  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 0: 100%|██████████| 171/171 [00:24<00:00,  6.85it/s]


(val) -72.2723 (test) -71.6909 [time] 0:00:44.274975
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 1: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -70.3951 (test) -69.8161 [time] 0:01:28.660878
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 2: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -68.3805 (test) -67.8043 [time] 0:02:13.235441
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 3: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -66.2592 (test) -65.6863 [time] 0:02:57.568919
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 4: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -64.0562 (test) -63.4871 [time] 0:03:42.000805
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 5: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -61.7827 (test) -61.2180 [time] 0:04:26.177425
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 6: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -59.4642 (test) -58.9046 [time] 0:05:10.584537
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 7: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -57.1092 (test) -56.5555 [time] 0:05:54.834167
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 8: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -54.7331 (test) -54.1863 [time] 0:06:39.013476
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 9: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -52.3555 (test) -51.8168 [time] 0:07:23.272996
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 10: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -49.9931 (test) -49.4636 [time] 0:08:07.677394
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 11: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -47.6616 (test) -47.1429 [time] 0:08:51.984878
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 12: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -45.3820 (test) -44.8760 [time] 0:09:36.504745
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 13: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -43.1742 (test) -42.6828 [time] 0:10:20.744426
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 14: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -41.0556 (test) -40.5812 [time] 0:11:04.928238
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 15: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -39.0524 (test) -38.5976 [time] 0:11:49.187067
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 16: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -37.1780 (test) -36.7454 [time] 0:12:33.446876
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 17: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -35.4533 (test) -35.0460 [time] 0:13:17.927576
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 18: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -33.9037 (test) -33.5244 [time] 0:14:02.353799
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 19: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -32.5290 (test) -32.1806 [time] 0:14:47.103877
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 20: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -31.3480 (test) -31.0324 [time] 0:15:31.306862
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 21: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -30.3576 (test) -30.0763 [time] 0:16:15.485176
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 22: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -29.5489 (test) -29.3026 [time] 0:16:59.750378
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 23: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -28.9140 (test) -28.7021 [time] 0:17:43.964760
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 24: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -28.4364 (test) -28.2571 [time] 0:18:28.131032
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 25: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -28.0882 (test) -27.9390 [time] 0:19:12.298963
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 26: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.8442 (test) -27.7220 [time] 0:19:57.030334
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 27: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.6821 (test) -27.5831 [time] 0:20:41.326320
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 28: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.5786 (test) -27.4992 [time] 0:21:25.528614
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 29: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.5163 (test) -27.4525 [time] 0:22:09.785705
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 30: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4812 (test) -27.4291 [time] 0:22:54.152755
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 31: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4608 (test) -27.4179 [time] 0:23:38.485433
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 32: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4496 (test) -27.4135 [time] 0:24:22.862514
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 33: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4441 (test) -27.4124 [time] 0:25:07.004205
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 34: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4413 (test) -27.4125 [time] 0:25:51.149501
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 35: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4400 (test) -27.4127 [time] 0:26:35.399644
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 36: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4387 (test) -27.4133 [time] 0:27:19.593718
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 37: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4381 (test) -27.4137 [time] 0:28:03.766850
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 38: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4378 (test) -27.4139 [time] 0:28:48.095722
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 39: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4378 (test) -27.4139 [time] 0:29:32.223948
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 40: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4380 (test) -27.4138 [time] 0:30:16.474860


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 41: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4383 (test) -27.4135 [time] 0:31:00.974492


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 42: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4377 (test) -27.4140 [time] 0:31:45.238397
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 43: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4386 (test) -27.4134 [time] 0:32:29.436177


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 44: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4382 (test) -27.4136 [time] 0:33:13.611227


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 45: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4378 (test) -27.4139 [time] 0:33:57.725066


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 46: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4381 (test) -27.4137 [time] 0:34:41.823987


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 47: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4383 (test) -27.4136 [time] 0:35:25.988319


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 48: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4388 (test) -27.4132 [time] 0:36:10.215815


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 49: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4374 (test) -27.4143 [time] 0:36:54.612278
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 50: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4374 (test) -27.4144 [time] 0:37:38.847771
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 51: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4376 (test) -27.4141 [time] 0:38:23.240970


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 52: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4390 (test) -27.4131 [time] 0:39:07.710713


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 53: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4378 (test) -27.4139 [time] 0:39:51.881012


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 54: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4370 (test) -27.4148 [time] 0:40:36.218296
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 55: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4374 (test) -27.4144 [time] 0:41:20.484016


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 56: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4374 (test) -27.4144 [time] 0:42:04.936509


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 57: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4376 (test) -27.4142 [time] 0:42:49.245651


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 58: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4386 (test) -27.4133 [time] 0:43:33.860576


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 59: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4390 (test) -27.4131 [time] 0:44:18.360189


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 60: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4388 (test) -27.4132 [time] 0:45:02.818741


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 61: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4374 (test) -27.4143 [time] 0:45:46.852793


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 62: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4375 (test) -27.4142 [time] 0:46:32.045912


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 63: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4378 (test) -27.4139 [time] 0:47:17.338524


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 64: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4382 (test) -27.4136 [time] 0:48:02.246868


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 65: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4370 (test) -27.4148 [time] 0:48:47.069298
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 66: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4372 (test) -27.4146 [time] 0:49:31.679606


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 67: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4385 (test) -27.4134 [time] 0:50:16.541002


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 68: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4376 (test) -27.4141 [time] 0:51:01.443832


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 69: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4384 (test) -27.4135 [time] 0:51:45.425131


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 70: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4368 (test) -27.4152 [time] 0:52:29.579772
🌸 New best epoch! 🌸
Saved model checkpoint to 'best_model.pt'


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 71: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4375 (test) -27.4142 [time] 0:53:13.892843


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 72: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4388 (test) -27.4132 [time] 0:53:58.184916


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 73: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4372 (test) -27.4145 [time] 0:54:42.317477


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 74: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4386 (test) -27.4133 [time] 0:55:26.755435


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 75: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4375 (test) -27.4142 [time] 0:56:11.065957


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 76: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4389 (test) -27.4132 [time] 0:56:55.320858


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 77: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4372 (test) -27.4146 [time] 0:57:39.604985


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 78: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4383 (test) -27.4135 [time] 0:58:24.018569


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 79: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4371 (test) -27.4147 [time] 0:59:08.544419


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 80: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4384 (test) -27.4134 [time] 0:59:52.913944


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 81: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4381 (test) -27.4137 [time] 1:00:37.172680


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 82: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4390 (test) -27.4131 [time] 1:01:21.716735


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 83: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4370 (test) -27.4149 [time] 1:02:05.917728


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 84: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4383 (test) -27.4135 [time] 1:02:50.177731


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 85: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4397 (test) -27.4128 [time] 1:03:34.649816


  loss = loss_fn(apply_model(batch), batch["y"])
  loss = loss_fn(apply_model(batch), batch["y"])
Epoch 86: 100%|██████████| 171/171 [00:24<00:00,  6.86it/s]


(val) -27.4389 (test) -27.4132 [time] 1:04:18.858309


Result:
{'val': -27.436811169404617, 'test': -27.415170845441335, 'epoch': 70}


In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Restore best model
model.load_state_dict(torch.load("best_model.pt"))

# # Calculate and save final metrics
final_metrics = calculate_and_save_metrics("test", save_path="final_metrics.json")
print("Final Test Metrics:", final_metrics)

Final Test Metrics: {'rmse': 27.415169007307227, 'mae': 21.70148468017578, 'r2': -0.0002052783966064453}
