Describe the bug
The current metrics loop isn't valid for stateful metrics.
The current metrics calculation boils down to:
1. Do the metric's forward on each batch.
2. Mean the batch results at the end of the epoch/evaluation.
This breaks metrics that shouldn't be computed per batch. E.g., my custom metric is the median MAE across the whole validation set. In the Darts loop it is effectively replaced by the mean of per-batch median MAEs, which is a different statistic - I have noticed considerable divergence in the results.
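To see why these differ, here is a minimal numeric illustration (the error values are made up for the example):

import torch

# Absolute errors over the whole validation set (hypothetical values).
errors = torch.tensor([0.1, 0.2, 0.3, 5.0, 6.0, 7.0])
batches = errors.split(3)  # two batches of three

# Median over the full set (torch.median returns the lower middle value): 0.3
global_median = errors.median()

# Mean of the per-batch medians: (0.2 + 6.0) / 2 = 3.1 - a very different number
mean_of_batch_medians = torch.stack([b.median() for b in batches]).mean()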
To Reproduce
Here is a simple example of such a stateful metric:
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat

class Collector(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Accumulate raw predictions and targets across batches.
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("targets", default=[], dist_reduce_fx="cat")

    def update(self, preds, target) -> None:
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")
        self.preds.append(preds)
        self.targets.append(target)

    def concat(self):
        return dim_zero_cat(self.preds), dim_zero_cat(self.targets)

# dummy tracker: counts the non-zero predictions seen so far
class Non0Preds(Collector):
    def compute(self):
        preds, targets = self.concat()
        return (preds != 0).sum()
Use it in prediction and observe that the metric comes out around the batch size rather than the size of the validation or training set, as it should be. This is obviously a dummy example whose correct value could be reverse-engineered from the per-batch results. Another example would be any order-based metric (medians, AUROC, etc.) where you can't reverse-engineer the correct value.
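For context, a hedged sketch of how such a metric would be plugged into a Darts model, assuming the torch_metrics constructor argument of the torch-based models (NBEATSModel and the series objects are illustrative placeholders):

from darts.models import NBEATSModel

# `series` and `val_series` are placeholder darts TimeSeries objects.
model = NBEATSModel(
    input_chunk_length=24,
    output_chunk_length=12,
    torch_metrics=Non0Preds(),  # the custom stateful metric defined above
)
model.fit(series, val_series=val_series)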
Expected behavior
A better pattern is:
1. Call the metric's .update at every step.
2. Call .compute at the end of the epoch/evaluation.
3. Call .reset on the metric to reset its state (or make sure that users call it in compute) - see the sketch below.
This allows the user to pass any metric and get correct results, AND it is far more computationally efficient, since it avoids the internal state manipulation that Metric.forward does by default.
Note that implementing update and compute is the approach recommended by TorchMetrics itself, and you are not advised to mess with forward.
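A minimal sketch of this update/compute/reset pattern, assuming a plain evaluation loop (model, val_loader, and num_epochs are placeholders, not Darts internals):

metric = Non0Preds()
for epoch in range(num_epochs):
    for x, target in val_loader:
        preds = model(x)
        metric.update(preds, target)  # accumulate state only, no per-batch compute
    epoch_value = metric.compute()    # a single computation over the full set
    metric.reset()                    # clear accumulated state for the next epoch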
See the logic here for an example of an implementation that should work for validation and can be replicated for training:
https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metriccollection
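For completeness, a hedged sketch of that pattern inside a LightningModule, following the linked TorchMetrics overview (LitModule and the wrapped model are illustrative assumptions, not Darts code):

import pytorch_lightning as pl
from torchmetrics import MeanAbsoluteError, MetricCollection

class LitModule(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Stateful metrics whose states accumulate across validation batches.
        self.val_metrics = MetricCollection(
            {"non0_preds": Non0Preds(), "mae": MeanAbsoluteError()}
        )

    def validation_step(self, batch, batch_idx):
        x, target = batch
        preds = self.model(x)
        self.val_metrics.update(preds, target)  # update only; no per-batch compute

    def on_validation_epoch_end(self):
        # Compute once over the whole validation set, then reset the state.
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()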