diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 1875b905..1f1ded60 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -664,7 +664,7 @@ def fit( ), "`fit` is not valid for SSL task. Please use `pretrain` for semi-supervised learning" if metrics is not None: assert len(metrics) == len( - metrics_prob_inputs + metrics_prob_inputs or [] ), "The length of `metrics` and `metrics_prob_inputs` should be equal" seed = seed or self.config.seed if seed: