
[BUG] TorchMetrics loop implementation isn't compatible with stateful metrics #2390

Closed
tRosenflanz opened this issue May 20, 2024 · 1 comment · Fixed by #2391
Labels
improvement New feature or improvement

Comments

@tRosenflanz
Contributor

tRosenflanz commented May 20, 2024

Describe the bug
The current metrics loop is not valid for stateful metrics.
The current metrics calculation boils down to:

  1. Call the metric's forward on each batch.
  2. Average the per-batch results at the end of the epoch/evaluation.

This breaks metrics that must not be computed per batch. For example, my custom metric is the median MAE across the whole validation set; in the Darts loop it is effectively replaced by the mean of the per-batch median MAEs, which is a different statistic. I have noticed considerable divergence in results.

To Reproduce
Here is a simple example of such stateful metric

import torch
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat


class Collector(Metric):
    """Accumulates predictions and targets across the whole epoch."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("targets", default=[], dist_reduce_fx="cat")

    def update(self, preds, target) -> None:
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")
        self.preds.append(preds)
        self.targets.append(target)

    def concat(self):
        return dim_zero_cat(self.preds), dim_zero_cat(self.targets)


# dummy tracker: counts non-zero predictions across the whole dataset
class Non0Preds(Collector):
    def compute(self):
        preds, targets = self.concat()
        return (preds != 0).sum()

Use it in prediction and observe that the metric value is around the batch size rather than the size of the validation or train set, as it should be. This is obviously a dummy example whose correct value could still be reverse-engineered from the per-batch results. Any ordering metric (median, AUROC, etc.) is a better example, since there the correct value cannot be reconstructed at all.
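To make the divergence concrete, here is a minimal self-contained sketch (plain Python rather than torchmetrics, with made-up error values) showing that the mean of per-batch medians is generally not the dataset-level median:

```python
# Demonstrates why averaging per-batch statistics diverges from the
# true dataset-level statistic. The same effect applies to the
# forward()-per-batch / mean-at-epoch-end pattern described above.
from statistics import median

# hypothetical absolute errors over a small "validation set"
errors = [0, 0, 0, 9, 9, 9, 9, 9]
batches = [errors[i:i + 4] for i in range(0, len(errors), 4)]

# pattern used by the current loop: median per batch, then mean
per_batch = [median(b) for b in batches]
mean_of_medians = sum(per_batch) / len(per_batch)

# correct stateful pattern: accumulate everything, compute once
global_median = median(errors)

print(mean_of_medians, global_median)  # the two values differ
```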

Expected behavior
A better pattern is:

  1. Call metrics .update at every step
  2. Call .compute at the end of the epoch/evaluation
  3. Call .reset on the metric to reset its state (or make sure that users call it in compute)

This allows the user to pass any metric and get correct results, and it is also more computationally efficient, since it avoids the internal state manipulation that Metric.forward performs by default.
Note that implementing update and compute is the pattern recommended by TorchMetrics itself, and users are advised not to override forward.

See the logic here for an example implementation that should work for validation and can be replicated for training:
https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metriccollection
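The three steps above can be sketched roughly as follows. This is a hypothetical illustration, not Darts' actual loop code, and `RunningMax` is a made-up stand-in for a stateful torchmetrics.Metric:

```python
# Sketch of the proposed update/compute/reset pattern.
class RunningMax:
    """Minimal stand-in for a stateful metric with accumulated state."""

    def __init__(self):
        self.reset()

    def reset(self):
        # step 3: clear accumulated state between epochs
        self.values = []

    def update(self, preds):
        # step 1: accumulate state at every step, no per-batch result
        self.values.extend(preds)

    def compute(self):
        # step 2: a single computation over the whole epoch's state
        return max(self.values)


metric = RunningMax()
for batch in ([1, 5, 2], [7, 3], [4]):
    metric.update(batch)       # called once per batch
result = metric.compute()      # called once at epoch end
metric.reset()                 # state cleared for the next epoch
print(result)
```

With this pattern the epoch-level result is computed from the full accumulated state, so order-dependent metrics (median, AUROC, etc.) come out correct.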

System (please complete the following information):

  • Python version: 3.10
  • darts version: 0.29.0
@tRosenflanz tRosenflanz added bug Something isn't working triage Issue waiting for triaging labels May 20, 2024
@dennisbader
Collaborator

Hi @tRosenflanz, and thanks for raising this issue. I like your proposed solution. Would you like to contribute to this?

@madtoinou madtoinou added improvement New feature or improvement and removed bug Something isn't working triage Issue waiting for triaging labels May 21, 2024