Describe the bug
When using additional metrics with PyTorch Lightning forecasting models, PLForecastingModule._calculate_metrics (called from training_step() or validation_step()) calls self.log_dict(...) without the batch_size= parameter. As a result, the Trainer (via _ResultCollection.log()) has to infer the batch size itself and complains about an ambiguous collection (visible when show_warnings=True is enabled on the model):
/opt/homebrew/Caskroom/miniconda/base/envs/tf/lib/python3.12/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the batch_size from an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, use self.log(..., batch_size=batch_size).
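To show why this warning fires, here is a simplified, stdlib-only sketch modeled on the idea behind Lightning's batch-size inference when batch_size= is omitted: walk the batch collection, take the leading dimension of every tensor, and flag the collection as ambiguous as soon as two entries disagree. FakeTensor, iter_batch_sizes and infer_batch_size are hypothetical names for this illustration, not Lightning's API. The assumption that a forecasting batch can contain None placeholders (for example for absent covariates) is also ours.

```python
from dataclasses import dataclass
from typing import Any, Iterator, Optional, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a torch.Tensor: only the shape matters here."""
    shape: Tuple[int, ...]


def iter_batch_sizes(batch: Any) -> Iterator[Optional[int]]:
    """Yield the leading dimension of every leaf in a nested batch.

    Scalars (empty shape) count as size 1; non-tensor leaves such as a None
    placeholder yield None.
    """
    if isinstance(batch, FakeTensor):
        yield batch.shape[0] if batch.shape else 1
    elif isinstance(batch, dict):
        for value in batch.values():
            yield from iter_batch_sizes(value)
    elif isinstance(batch, (list, tuple)):
        for item in batch:
            yield from iter_batch_sizes(item)
    else:
        yield None


def infer_batch_size(batch: Any) -> Tuple[Optional[int], bool]:
    """Return (first size found, whether the collection was ambiguous)."""
    first: Optional[int] = None
    for size in iter_batch_sizes(batch):
        if first is None:
            first = size  # nothing usable seen yet: adopt this size
        elif size != first:
            return first, True  # two leaves disagree: "ambiguous collection"
    return first, False


# A hypothetical batch: past target, a missing covariate slot, future target.
print(infer_batch_size((FakeTensor((32, 24, 1)), None, FakeTensor((32, 12, 1)))))
```

Passing batch_size=target.shape[0] explicitly, as in the suggested modification, means no such inference is needed, so the ambiguity check never runs.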
To Reproduce
Pass a custom metric through the torch_metrics=... model parameter (e.g. of BlockRNNModel), enable show_warnings=True, and run .fit().
Expected behavior
No batch-size ambiguity warning should be emitted.
System (please complete the following information):
Additional context
Suggested code modification (tested):
def _calculate_metrics(self, output, target, metrics):
    if not len(metrics):
        return

    if self.likelihood:
        _metric = metrics(self.likelihood.sample(output), target)
    else:
        # If there's no likelihood, nr_params=1, and we need to squeeze out the
        # last dimension of model output, for properly computing the metric.
        _metric = metrics(output.squeeze(dim=-1), target)

    self.log_dict(
        _metric,
        on_epoch=True,
        on_step=False,
        logger=True,
        prog_bar=True,
        sync_dist=True,
        batch_size=target.shape[0],  # ADD THIS LINE
    )
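Why Lightning asks for an explicit batch_size: with on_epoch=True, a logged mean is, as we understand it, accumulated as a batch-size-weighted average over the steps of the epoch. The sketch below models that assumption with a hypothetical epoch_mean helper (not a Lightning function) and shows how the epoch value would shift if the size used for a smaller final batch ever disagreed with the real one.

```python
from typing import Sequence


def epoch_mean(values: Sequence[float], batch_sizes: Sequence[int]) -> float:
    """Batch-size-weighted mean of per-step metric values.

    Each step contributes value * batch_size; the total is divided by the
    sum of batch sizes (our model of an on_epoch=True mean).
    """
    if len(values) != len(batch_sizes):
        raise ValueError("values and batch_sizes must have the same length")
    total_weight = sum(batch_sizes)
    if total_weight <= 0:
        raise ValueError("sum of batch sizes must be positive")
    return sum(v * b for v, b in zip(values, batch_sizes)) / total_weight


# Three steps; the last batch holds only 8 samples instead of 32.
per_step_metric = [1.0, 1.0, 4.0]
true_sizes = [32, 32, 8]       # what target.shape[0] would report per step
guessed_sizes = [32, 32, 32]   # a wrong size assumed for every step

print(epoch_mean(per_step_metric, true_sizes))     # ≈ 1.333
print(epoch_mean(per_step_metric, guessed_sizes))  # 2.0
```

Using target.shape[0] ties the weight of each logged value to the number of samples the metric was actually computed on.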