Description
Hi, I want to implement the median absolute error (MedAE) as both a loss function and a metric for training an FT-Transformer.
Here is my MedAE code:
```python
import torch

# Loss function
class MedianAbsoluteErrorLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        # Calculate the MedAE between the true and predicted values
        errors = torch.abs(y_true - y_pred)
        median_error = torch.median(errors)
        return median_error

# Metric
def median_absolute_error(y_pred, y_true):
    # Calculate the MedAE between the true and predicted values
    errors = torch.abs(y_true - y_pred)
    median_error = torch.median(errors)
    return median_error
```
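One way to sanity-check the loss on its own, before involving PyTorch Tabular, is to call it on a toy batch and backpropagate. The sketch below uses made-up tensors and drops `__init__` entirely (the module has no parameters, so the inherited constructor is enough). It also shows two properties of `torch.median` worth knowing before training with MedAE: the gradient flows only to the element(s) whose error equals the median, and for an even number of elements `torch.median` returns the lower of the two middle values rather than their average (so it can differ from NumPy- or scikit-learn-based MedAE).

```python
import torch


class MedianAbsoluteErrorLoss(torch.nn.Module):
    """Median absolute error (MedAE) as a loss module."""

    def forward(self, y_pred, y_true):
        # Element-wise absolute errors, reduced with the median over all elements
        return torch.median(torch.abs(y_true - y_pred))


# Toy batch with an odd number of elements so the median is unique
y_true = torch.tensor([1.0, 2.0, 3.0])
y_pred = torch.tensor([1.5, 3.0, 0.0], requires_grad=True)

loss_fn = MedianAbsoluteErrorLoss()  # an instance, not the class itself
loss = loss_fn(y_pred, y_true)
loss.backward()

print(loss.item())  # absolute errors are [0.5, 1.0, 3.0], so the median is 1.0
print(y_pred.grad)  # non-zero only for the element that produced the median
```

Because only the median element drives each update, training with MedAE as the loss may converge more slowly than with MAE; that follows from the median itself, not from this particular code.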
And the model definition is as follows:
```python
import numpy as np
from pytorch_tabular import TabularModel
from pytorch_tabular.models import FTTransformerConfig

# head_config, data_config, optimizer_config, trainer_config and train_df
# are defined in earlier notebook cells (not shown)

params = {
    'num_attn_blocks': 16,
    'attn_dropout': 0.3407836558733435,
    'add_norm_dropout': 0.11142324367590875,
    'ff_dropout': 0.35890766953158,
    'learning_rate': 0.002034424182751788
}

# Define model
model_config = FTTransformerConfig(
    task="regression",
    head="LinearHead",  # Linear head
    head_config=head_config,  # Linear head config
    target_range=[(1, 10)],  # Range of the Mohs hardness scale
    **params
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)

tabular_model.fit(train=train_df,
                  loss=MedianAbsoluteErrorLoss,
                  metrics=[median_absolute_error])
```
But I'm getting this error:
```
Cell In[26], line 29
     14 model_config = FTTransformerConfig(
     15     task="regression",
     16     head = "LinearHead", #Linear Head
   (...)
     19     **params
     20 )
     22 tabular_model = TabularModel(
     23     data_config=data_config,
     24     model_config=model_config,
     25     optimizer_config=optimizer_config,
     26     trainer_config=trainer_config,
     27 )
---> 29 tabular_model.fit(train=train_df,
     30                   loss=MedianAbsoluteErrorLoss,
     31                   metrics=[median_absolute_error],)

File /opt/conda/lib/python3.10/site-packages/pytorch_tabular/tabular_model.py:667, in TabularModel.fit(self, train, validation, test, loss, metrics, metrics_prob_inputs, optimizer, optimizer_params, train_sampler, target_transform, max_epochs, min_epochs, seed, callbacks, datamodule)
    663 assert (
    664     self.config.task != "ssl"
    665 ), "`fit` is not valid for SSL task. Please use `pretrain` for semi-supervised learning"
    666 if metrics is not None:
--> 667     assert len(metrics) == len(
    668         metrics_prob_inputs
    669     ), "The length of `metrics` and `metrics_prob_inputs` should be equal"
    670 seed = seed if seed is not None else self.config.seed
    671 seed_everything(seed)

TypeError: object of type 'NoneType' has no len()
```
I suspect the cause is something simple, but I can't spot it. Any help would be appreciated.
Regards.
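The assertion in the traceback compares `len(metrics)` with `len(metrics_prob_inputs)`. Since `metrics_prob_inputs` was not passed to `fit`, it is `None`, and `len(None)` raises the `TypeError` before the assertion message is ever evaluated. The sketch below re-creates that check as a standalone function; `check_metric_args` is a hypothetical name, not part of PyTorch Tabular. Under this reading, a likely fix is `tabular_model.fit(train=train_df, loss=MedianAbsoluteErrorLoss(), metrics=[median_absolute_error], metrics_prob_inputs=[False])`: one flag per metric (`False`, since MedAE takes raw regression outputs rather than probabilities), plus a loss *instance* instead of the class, because the loss is presumably called on predictions and targets during training.

```python
from typing import Callable, List, Optional


def check_metric_args(
    metrics: Optional[List[Callable]],
    metrics_prob_inputs: Optional[List[bool]],
) -> None:
    # Hypothetical stand-alone re-creation of the check shown in the traceback
    # (TabularModel.fit, tabular_model.py:667); not PyTorch Tabular's own code.
    if metrics is not None:
        assert len(metrics) == len(
            metrics_prob_inputs
        ), "The length of `metrics` and `metrics_prob_inputs` should be equal"


def median_absolute_error(y_pred, y_true):
    # Placeholder standing in for the MedAE metric from the question
    return 0.0


metrics = [median_absolute_error]

# What fit() received in the question: metrics given, metrics_prob_inputs left as None
try:
    check_metric_args(metrics, None)
    error_message = None
except TypeError as exc:
    # len(None) fails before the assertion message is reached
    error_message = str(exc)
print(error_message)  # object of type 'NoneType' has no len()

# One flag per metric; False because MedAE works on raw regression outputs
check_metric_args(metrics, [False])
print("metrics_prob_inputs=[False] passes the check")
```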