
Conversation

@01-vyom (Contributor) commented Jun 2, 2021

Fixes #1415

Description:
Similar to `Metric`, the `Loss` metric now supports `required_output_keys` (a sketch of the resulting class attribute follows the check list).

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)
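
For context, a minimal sketch of what the change boils down to (illustrative, not the exact diff):

```python
from ignite.metrics import Metric

# Illustrative sketch only: after this PR, Loss declares the keys it
# pulls from a dict-like engine output, so it can consume
# {"y_pred": ..., "y": ..., "criterion_kwargs": ...} directly.
class Loss(Metric):
    required_output_keys = ("y_pred", "y", "criterion_kwargs")
```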

@github-actions github-actions bot added the module: metrics Metrics module label Jun 2, 2021
@01-vyom (Contributor, Author) commented Jun 2, 2021

The following demo code checks the validity of the above change:

```python
import torch
import torch.nn as nn
from torch.nn.functional import nll_loss

from ignite.metrics import Accuracy, Loss
from ignite.engine import create_supervised_evaluator

# NB: nll_loss expects log-probabilities; raw linear outputs are used
# here only to exercise the metric plumbing.
model = nn.Linear(10, 3)

metrics = {
    "Accuracy": Accuracy(),
    "Loss": Loss(nll_loss),
}

# Global criterion kwargs, forwarded to nll_loss by the Loss metric.
criterion_kwargs = {"reduction": "sum"}
# criterion_kwargs = {}

# The output_transform returns a dict; each metric picks out the keys it
# declares in required_output_keys: ("y_pred", "y") for Accuracy and
# ("y_pred", "y", "criterion_kwargs") for Loss.
evaluator = create_supervised_evaluator(
    model,
    metrics=metrics,
    output_transform=lambda x, y, y_pred: {
        "x": x, "y": y, "y_pred": y_pred, "criterion_kwargs": criterion_kwargs},
)
data = [
    (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
    (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
    (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
]
res = evaluator.run(data)
```

As `required_output_keys` contains `criterion_kwargs`, the user has to pass an empty dictionary for the original case, i.e. when the criterion needs no extra kwargs (see the sketch below).
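
A sketch of that plain case, reusing `model` and `metrics` from the demo above:

```python
# No extra kwargs for the criterion: "criterion_kwargs" must still be
# present in the output dict, but can be empty.
evaluator = create_supervised_evaluator(
    model,
    metrics=metrics,
    output_transform=lambda x, y, y_pred: {
        "y": y, "y_pred": y_pred, "criterion_kwargs": {}},
)
```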

@vfdev-5 (Collaborator) commented Jun 2, 2021

@01-vyom thanks for the draft PR!
Let's now add some tests here:

Instead of DummyMetric1, we can create a metric derived from Loss and override the update method to check the passed args, similarly to here:

```python
def update(self, output):
    assert output == self.true_output
```

Please let me know if you need more explanations.

We can also add your demo above as an integration test in https://github.com/pytorch/ignite/blob/master/tests/ignite/metrics/test_loss.py
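
A minimal sketch of the suggested derived-Loss test (`DummyLoss`, the lambda update function, and the stubbed `compute` are illustrative assumptions, not the final test code):

```python
import torch
from torch.nn.functional import nll_loss
from ignite.engine import Engine
from ignite.metrics import Loss

class DummyLoss(Loss):
    # Record the expected tuple and assert that update() receives exactly
    # the values mapped from the output dict via required_output_keys.
    def __init__(self, *args, true_output=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.true_output = true_output

    def update(self, output):
        assert output == self.true_output

    def compute(self):
        return 0.0  # stub so the engine run can complete

y_pred = torch.rand(4, 3)
y = torch.randint(0, 3, size=(4,))
criterion_kwargs = {"reduction": "sum"}

# The engine emits a dict output; the metric should map it to the tuple
# (y_pred, y, criterion_kwargs) before calling update().
engine = Engine(
    lambda e, batch: {"y_pred": y_pred, "y": y, "criterion_kwargs": criterion_kwargs}
)
loss = DummyLoss(nll_loss, true_output=(y_pred, y, criterion_kwargs))
loss.attach(engine, "loss")
engine.run([0], max_epochs=1)
```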

@01-vyom (Contributor, Author) commented Jun 3, 2021

OK, I will add a `test_output_mapping` test as well as the other `none_keys` and `wrong_keys` tests with a dummy `Loss`, plus an integration test similar to this one:

`def test_override_required_output_keys():`

@01-vyom 01-vyom changed the title from "[WIP] ENH: Updated Loss metric to use required_output_keys" to "ENH: Updated Loss metric to use required_output_keys" Jun 3, 2021
@01-vyom 01-vyom requested a review from vfdev-5 June 3, 2021 19:39
@vfdev-5 (Collaborator) left a comment:

Thanks for the PR @01-vyom!

@vfdev-5 vfdev-5 enabled auto-merge (squash) June 3, 2021 20:11
@vfdev-5 vfdev-5 merged commit 786aea8 into pytorch:master Jun 3, 2021
@01-vyom 01-vyom deleted the loss-metric-1415 branch June 3, 2021 20:54