Help: Custom loss and custom metric implementation #314

@LuisFerTR

Hi, I want to implement median absolute error as a loss function and as a metric to train FT-Transformer.

My code for median absolute error is as follows:

import torch

# Loss function
class MedianAbsoluteErrorLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        # Calculate MedAE between true values and predicted values
        errors = torch.abs(y_true - y_pred)
        median_error = torch.median(errors)

        return median_error

# Metric    
def median_absolute_error(y_pred, y_true):
    # Calculate MedAE between true values and predicted values
    errors = torch.abs(y_true - y_pred)
    median_error = torch.median(errors)

    return median_error
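
As a quick sanity check of the loss above, here is a minimal sketch on made-up toy tensors (the values are illustrative, not from my dataset). It assumes only `torch`. Note that `torch.median` returns the lower of the two middle values when the number of elements is even, and that the gradient reaches only the sample whose error equals the median.

```python
import torch


class MedianAbsoluteErrorLoss(torch.nn.Module):
    def forward(self, y_pred, y_true):
        # Median of the element-wise absolute errors
        return torch.median(torch.abs(y_true - y_pred))


# Toy values chosen so the absolute errors are 0.5, 2.0, 1.0, 4.0
y_true = torch.tensor([1.0, 2.0, 3.0, 4.0])
y_pred = torch.tensor([1.5, 4.0, 2.0, 8.0], requires_grad=True)

loss = MedianAbsoluteErrorLoss()(y_pred, y_true)
loss.backward()

print(loss.item())   # lower median of the sorted errors [0.5, 1.0, 2.0, 4.0]
print(y_pred.grad)   # non-zero only at the sample that produced the median
```

Because only one sample receives a gradient per step, this loss may train slowly or noisily; that is a property of the median, not of this particular implementation.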

And the model definition is as follows:

import numpy as np
from pytorch_tabular import TabularModel
from pytorch_tabular.models import FTTransformerConfig
    
params = {
    'num_attn_blocks': 16,
    'attn_dropout': 0.3407836558733435,
    'add_norm_dropout': 0.11142324367590875,
    'ff_dropout': 0.35890766953158,
    'learning_rate': 0.002034424182751788
}
    
# Define model
model_config = FTTransformerConfig(
    task="regression",
    head = "LinearHead", #Linear Head
    head_config = head_config, # Linear Head Config
    target_range = [(1, 10)], # Range of Mohs Hardness scale
    **params
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)

tabular_model.fit(train=train_df,               
    loss=MedianAbsoluteErrorLoss,
    metrics=[median_absolute_error])

But I'm getting this error:

Cell In[26], line 29
     14 model_config = FTTransformerConfig(
     15     task="regression",
     16     head = "LinearHead", #Linear Head
   (...)
     19     **params
     20 )
     22 tabular_model = TabularModel(
     23     data_config=data_config,
     24     model_config=model_config,
     25     optimizer_config=optimizer_config,
     26     trainer_config=trainer_config,
     27 )
---> 29 tabular_model.fit(train=train_df,               
     30     loss=MedianAbsoluteErrorLoss,
     31     metrics=[median_absolute_error],)

File /opt/conda/lib/python3.10/site-packages/pytorch_tabular/tabular_model.py:667, in TabularModel.fit(self, train, validation, test, loss, metrics, metrics_prob_inputs, optimizer, optimizer_params, train_sampler, target_transform, max_epochs, min_epochs, seed, callbacks, datamodule)
    663 assert (
    664     self.config.task != "ssl"
    665 ), "`fit` is not valid for SSL task. Please use `pretrain` for semi-supervised learning"
    666 if metrics is not None:
--> 667     assert len(metrics) == len(
    668         metrics_prob_inputs
    669     ), "The length of `metrics` and `metrics_prob_inputs` should be equal"
    670 seed = seed if seed is not None else self.config.seed
    671 seed_everything(seed)

TypeError: object of type 'NoneType' has no len()
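
Reading the traceback, the failing line compares `len(metrics)` with `len(metrics_prob_inputs)`, and `metrics_prob_inputs` was not passed to `fit`, so it is `None`. The sketch below reproduces that behavior without pytorch_tabular. `check_metric_args` is a hypothetical stand-in written from lines 666-669 of the traceback, not the library's actual function, and the metric is a placeholder.

```python
def check_metric_args(metrics=None, metrics_prob_inputs=None):
    """Hypothetical stand-in for the check shown in the traceback."""
    if metrics is not None:
        assert len(metrics) == len(
            metrics_prob_inputs
        ), "The length of `metrics` and `metrics_prob_inputs` should be equal"
    return True


def median_absolute_error(y_pred, y_true):
    # Placeholder metric; its body is irrelevant to the argument check
    raise NotImplementedError


# Passing only `metrics` evaluates len(None) and raises the same TypeError
try:
    check_metric_args(metrics=[median_absolute_error])
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)

# One entry per metric makes the two lengths match, so the check passes
ok = check_metric_args(
    metrics=[median_absolute_error],
    metrics_prob_inputs=[False],
)
print(ok)
```

If this reading is right, the `fit` call would probably need a `metrics_prob_inputs` list with one boolean per metric, for example `metrics_prob_inputs=[False]`. Choosing `False` assumes the flag controls whether a metric receives probabilities instead of raw predictions. The call may also need an instance, `MedianAbsoluteErrorLoss()`, rather than the class itself. Neither point has been checked against the pytorch_tabular documentation.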

I suspect the error is something simple, but I can't find it. I would appreciate any help.

Regards.
