
Update Precision, Recall, add Accuracy (Binary and Categorical combined) #275

Merged
merged 24 commits into pytorch:master from the updatemetrics branch on Nov 7, 2018

Conversation

anmolsjoshi
Contributor

Fixes #262

Description:
This PR updates Precision and Recall, and adds Accuracy to handle binary and categorical cases for different types of input.

Checklist:

  • New tests are added.
  • Updated docstrings to RST format.
  • Edited metrics.rst to add information about Accuracy.

@vfdev-5
Collaborator

vfdev-5 commented Sep 25, 2018

@anmolsjoshi thanks a lot for taking this !
Just a small clarification on how to handle the old BinaryAccuracy and CategoricalAccuracy:

merge the classes into Accuracy, but still keep BinaryAccuracy and CategoricalAccuracy; have their constructors just create the appropriate instance of Accuracy and emit a warning that these classes will be removed in 0.1.2.

@anmolsjoshi
Contributor Author

@vfdev-5 could you explain merging the two classes into Accuracy? Sorry, I'm unfamiliar with that. Do you mean creating binary and categorical update functions and calling them appropriately?

@vfdev-5
Collaborator

vfdev-5 commented Sep 25, 2018

@anmolsjoshi no problem, they could be something like this
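
A minimal sketch of what such deprecation wrappers could look like, assuming a unified Accuracy class already exists in ignite.metrics.accuracy (the import path, class layout, and exact warning text are assumptions; the 0.1.2 version comes from the suggestion above):

import warnings

from ignite.metrics.accuracy import Accuracy


class BinaryAccuracy(Accuracy):
    """Deprecated: use Accuracy instead."""

    def __init__(self):
        warnings.warn("BinaryAccuracy is deprecated and will be removed in 0.1.2; "
                      "use Accuracy instead.", DeprecationWarning)
        super(BinaryAccuracy, self).__init__()


class CategoricalAccuracy(Accuracy):
    """Deprecated: use Accuracy instead."""

    def __init__(self):
        warnings.warn("CategoricalAccuracy is deprecated and will be removed in 0.1.2; "
                      "use Accuracy instead.", DeprecationWarning)
        super(CategoricalAccuracy, self).__init__()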

@anmolsjoshi
Contributor Author

Thanks for the example. Will get this worked on tomorrow.

@anmolsjoshi
Contributor Author

anmolsjoshi commented Sep 26, 2018

@vfdev-5 could you clarify this test ?

y has 4 elements with shape (1, 2, 2), while y_pred has 8 elements with shape (2, 2, 2). correct is equal to tensor([0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.uint8).

Given the docs, y_pred should be of shape (batch_size, ...) or (batch_size, num_classes, ...), which is not the case here. Given that this is a test for binary accuracy, shouldn't y_pred have a shape of (1, 2, 2) or (1, 2, 2, 2)? If it is the second case, this test should be handled by CategoricalAccuracy.

Or am I understanding this incorrectly?

Also, I tried using the existing CategoricalAccuracy.

import torch
from ignite.metrics import CategoricalAccuracy

acc = CategoricalAccuracy()
y_pred = torch.FloatTensor([[[0.3, 0.7],
                             [0.9, 0.1]],
                            [[0.2, 0.8],
                             [0.4, 0.6]]])        # shape (2, 2, 2)
y = torch.ones(1, 2, 2).type(torch.LongTensor)    # shape (1, 2, 2), all ones
acc.update((y_pred.unsqueeze(0), y))              # y_pred becomes (1, 2, 2, 2)
acc.compute()  # ---> 0.5

I'm not sure the answer should be 0.5: taking the max at each position of y_pred, the result should be [[1, 0], [1, 1]], giving 0.75. I think the issue is with how correct is calculated for y_pred with more than 2 dimensions; the dimension along which the maximum is taken needs to be changed. It could be solved by changing the axis to -1.
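
For reference, here is a small sketch (not part of the PR) showing where the two readings diverge: the existing metric takes the argmax over dim 1 once a batch dimension is added, while the 0.75 reading takes it over the last dimension.

import torch

y_pred = torch.FloatTensor([[[0.3, 0.7],
                             [0.9, 0.1]],
                            [[0.2, 0.8],
                             [0.4, 0.6]]])

# CategoricalAccuracy's reading: (1, 2, 2, 2) with the class dimension at dim 1
print(torch.max(y_pred.unsqueeze(0), dim=1)[1])  # tensor([[[0, 1], [0, 1]]]) -> 0.5 vs an all-ones y

# the alternative reading: argmax over the last dimension of the original (2, 2, 2) tensor
print(torch.max(y_pred, dim=-1)[1])              # tensor([[1, 0], [1, 1]]) -> 0.75 vs an all-ones y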

Thoughts?

@vfdev-5
Collaborator

vfdev-5 commented Sep 26, 2018

@anmolsjoshi this is the case of computing accuracy between 2x2 images (a segmentation task).

See #128

Given the docs, y_pred should be of shape (batch_size, ...) or (batch_size, num_classes, ...), which is not the case here. Given that this is a test for binary accuracy, shouldn't y_pred have a shape of (1, 2, 2) or (1, 2, 2, 2)?

Yes, maybe the batch dimension is missing here.

@anmolsjoshi
Contributor Author

@vfdev-5 I have updated all the tests appropriately and incorporated your comments. Please let me know your thoughts!

@anmolsjoshi
Contributor Author

@vfdev-5 any suggestions on the codecov failures?

@vfdev-5
Collaborator

vfdev-5 commented Sep 28, 2018

any suggestions on the codecov failures?

@anmolsjoshi take a look here: the lines in red are not covered by tests (that code is not executed).

@vfdev-5
Collaborator

vfdev-5 commented Sep 28, 2018

@anmolsjoshi I was thinking about your implementation of unified binary and categorical accuracy, and maybe we can unify it even more (I'm not a fan of huge if branches like if y.ndimension() > 1: and if is_categorical:). Let me explain:
We would like to cover the following cases (a small example of each shape is sketched after the list):

Binary classification/regression:

  • y_pred can be (batch_size, )
  • y_true can be (batch_size, )

Binary segmentation on ND array:

  • y_pred can be (batch_size, H, W, ...)
  • y_true can be (batch_size, H, W, ...)

Categorical classification/regression:

  • y_pred can be (batch_size, C)
  • y_true can be (batch_size, )

Categorical segmentation on ND array:

  • y_pred can be (batch_size, C, H, W, ...)
  • y_true can be (batch_size, H, W, ...)
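
For concreteness, here is a small sketch of tensors matching each of the four cases above (assuming a recent PyTorch version; the dimension sizes are arbitrary):

import torch

batch_size, num_classes, H, W = 4, 3, 2, 2

# binary classification/regression: y_pred and y are both (batch_size,)
y_pred = torch.rand(batch_size)
y = torch.randint(0, 2, (batch_size,)).float()

# binary segmentation: y_pred and y are both (batch_size, H, W)
y_pred = torch.rand(batch_size, H, W)
y = torch.randint(0, 2, (batch_size, H, W)).float()

# categorical classification: y_pred is (batch_size, C), y is (batch_size,)
y_pred = torch.softmax(torch.rand(batch_size, num_classes), dim=1)
y = torch.randint(0, num_classes, (batch_size,))

# categorical segmentation: y_pred is (batch_size, C, H, W), y is (batch_size, H, W)
y_pred = torch.softmax(torch.rand(batch_size, num_classes, H, W), dim=1)
y = torch.randint(0, num_classes, (batch_size, H, W))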

As you may have already noticed, the only thing that changes between the current implementations is how indices is computed.
Here is the update method of binary accuracy (the indices variable is made explicit):

    def update(self, output):
        y_pred, y = output
        # binary case: round the probabilities, i.e. threshold them at 0.5
        indices = torch.round(y_pred).type(y.type())
        correct = torch.eq(indices, y).view(-1)
        self._num_correct += torch.sum(correct).item()
        self._num_examples += correct.shape[0]

and here is the update method of categorical accuracy:

    def update(self, output):
        y_pred, y = output
        # categorical case: take the argmax over the class dimension (dim 1)
        indices = torch.max(y_pred, 1)[1]
        correct = torch.eq(indices, y).view(-1)
        self._num_correct += torch.sum(correct).item()
        self._num_examples += correct.shape[0]

So, there are three questions for a unified update function:

  • how to validate the input data?
  • how to tell which case we are in: categorical or binary?
  • how to compute indices in a unified way?

1. How to validate the input data?

Following the supported cases above, we can check that:

assert y.ndimension() >= 1 and y_pred.ndimension() >= 1
# y_pred either has the same shape as y (binary) or one extra class dimension (categorical)
assert y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()
y_shape = y.shape
y_pred_shape = y_pred.shape
if y.ndimension() + 1 == y_pred.ndimension():
    # drop the class dimension C from (batch_size, C, ...) before comparing shapes
    y_pred_shape = (y_pred_shape[0], ) + y_pred_shape[2:]
assert y_shape == y_pred_shape
2. How to tell which case we are in: categorical or binary?

I think this is just:

  • if y.ndimension() == y_pred.ndimension() -> binary
  • if y.ndimension() + 1 == y_pred.ndimension() -> categorical
3. How to compute indices in a unified way?

Here I propose to map the binary case onto a 2-class categorical one:

y_pred = y_pred.unsqueeze(dim=1)
y_pred = torch.cat([1.0 - y_pred, y_pred], dim=1)

Of course, in terms of performance we would need to benchmark this.

And finally, ideally we should also store what type of data (categorical or binary) was used in the previous update and raise an error if the user mixes the types.
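
Putting the three pieces together, here is a minimal sketch of what such a unified update method could look like (the class name UnifiedAccuracy and the exact error message are illustrative assumptions; this is a sketch of the proposal above, not the code that was merged in this PR):

import torch


class UnifiedAccuracy:
    """Sketch of a unified binary/categorical accuracy metric."""

    def __init__(self):
        self._type = None  # 'binary' or 'categorical', fixed by the first update
        self._num_correct = 0
        self._num_examples = 0

    def update(self, output):
        y_pred, y = output

        # 1. validate the input shapes
        assert y.ndimension() >= 1 and y_pred.ndimension() >= 1
        assert y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()
        y_pred_shape = y_pred.shape
        if y.ndimension() + 1 == y_pred.ndimension():
            # drop the class dimension C from (batch_size, C, ...)
            y_pred_shape = (y_pred_shape[0], ) + y_pred_shape[2:]
        assert y.shape == y_pred_shape

        # 2. detect the case and forbid mixing types across updates
        update_type = 'binary' if y.ndimension() == y_pred.ndimension() else 'categorical'
        if self._type is None:
            self._type = update_type
        elif self._type != update_type:
            raise RuntimeError("Input type has changed from {} to {}".format(self._type, update_type))

        # 3. map the binary case onto a 2-class categorical problem
        if update_type == 'binary':
            y_pred = y_pred.unsqueeze(dim=1)
            y_pred = torch.cat([1.0 - y_pred, y_pred], dim=1)

        # indices are now computed the same way in both cases
        indices = torch.max(y_pred, dim=1)[1]
        correct = torch.eq(indices, y.type(indices.type())).view(-1)
        self._num_correct += torch.sum(correct).item()
        self._num_examples += correct.shape[0]

    def compute(self):
        return self._num_correct / self._num_examples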

What do you think?

@jasonkriss
Contributor

@vfdev-5 I like that a lot.

@anmolsjoshi
Contributor Author

@vfdev-5 thank you so much for your input. I agree with your suggestions and will get started on these changes. I'm definitely trying to get a better handle on programming, and feedback like this is really helpful.

@rpatrik96
Contributor

@anmolsjoshi @vfdev-5
Hi, are you also considering tasks with multiple target labels? I currently have a problem where this would come in handy, and I am eager to help.

@vfdev-5
Collaborator

vfdev-5 commented Sep 29, 2018

@rpatrik96, thanks for the offer! What kind of metrics do you think would be useful with multi-label targets?

@rpatrik96
Contributor

rpatrik96 commented Sep 29, 2018

@vfdev-5
In my opinion, it would be great if Precision and Recall could be used for these tasks too; BinaryAccuracy is OK for them. I wonder how to generalize the current metrics.
The easiest way would be to specify a flag for multi-target, but maybe it can be done automagically: I mean, if we have multiple labels the target is a one-hot vector and we should compare only against that, or, a bit more generally, we could construct the confusion matrix and deduce the metrics from it. What do you think?

@vfdev-5
Collaborator

vfdev-5 commented Sep 29, 2018

@rpatrik96 how would you threshold y_pred to choose which tags are on and which are off?

@rpatrik96
Contributor

rpatrik96 commented Sep 29, 2018

@vfdev-5 I think the following would be logical: if y_true is a one-hot vector with N 1s in it, then I would sort the y_pred vector and use the top N entries, like in the case of TopKCategoricalAccuracy, but this time the mapping is not many-to-one but many-to-many.
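
A hypothetical sketch of this idea (the helper name and behaviour are assumptions, not part of ignite): keep the top-N scores as predicted tags, where N is the number of positive labels in y_true.

import torch


def multilabel_predictions(y_pred, y_true):
    n = int(y_true.sum().item())                # N = number of 1s in the target vector
    top_indices = torch.topk(y_pred, k=n, dim=-1)[1]
    predictions = torch.zeros_like(y_true)
    predictions.scatter_(-1, top_indices, 1)    # mark the top-N positions as predicted tags
    return predictions


y_true = torch.tensor([1, 0, 1, 0])
y_pred = torch.tensor([0.1, 0.6, 0.7, 0.3])
print(multilabel_predictions(y_pred, y_true))   # tensor([0, 1, 1, 0])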

@vfdev-5
Collaborator

vfdev-5 commented Sep 29, 2018

@rpatrik96 maybe we could provide a custom function to override torch.round when transforming the probabilities y_pred into predicted tags, and then continue as in scikit-learn?

@rpatrik96
Contributor

@vfdev-5
I like your proposal, just one question: why would we need to override torch.round? That part is not crystal clear to me.

@vfdev-5
Collaborator

vfdev-5 commented Sep 30, 2018

@rpatrik96 it's a basic decision rule: probabilities to predictions, e.g. round([0.1, 0.6, 0.7, 0.3]) -> [0, 1, 1, 0].
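
A small illustration of this decision-rule idea (the decision_rule argument is hypothetical and not an existing ignite parameter): any callable that maps probabilities to tags can stand in for torch.round.

import torch


def to_tags(y_pred, decision_rule=torch.round):
    # apply the chosen decision rule to turn probabilities into predicted tags
    return decision_rule(y_pred)


probas = torch.tensor([0.1, 0.6, 0.7, 0.3])
print(to_tags(probas))                                 # tensor([0., 1., 1., 0.])
print(to_tags(probas, lambda p: (p > 0.65).float()))   # tensor([0., 0., 1., 0.])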

@rpatrik96
Contributor

@vfdev-5 Oh, fine, I thought there was something more complicated to it :). So how should we proceed?

@anmolsjoshi
Contributor Author

@vfdev-5 thanks for the review! I'll write a few more tests as you have suggested.

@vfdev-5
Collaborator

vfdev-5 commented Oct 30, 2018

Thanks @anmolsjoshi, just a few remarks if you could address them please:
Checking the tests from Travis, could you please replace CategoricalAccuracy in this test and in all MNIST examples, otherwise warnings are raised.

There is a problem with a sklearn test:

tests/ignite/metrics/test_precision.py::test_sklearn_compare
  /home/travis/miniconda/envs/test-environment/lib/python2.7/site-packages/sklearn/metrics/classification.py:1143: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples.
    'precision', 'predicted', average, warn_for)

You can check it here

…d mnist examples and progress bar test with correct Accuracy
@anmolsjoshi
Contributor Author

@vfdev-5 I have updated the tests to prevent that sklearn warning, and updated the MNIST examples and test_pbar as well.

I'm getting deprecation warnings; I hope that's OK since that is the point of this PR.

@vfdev-5
Collaborator

vfdev-5 commented Oct 31, 2018

@anmolsjoshi thanks !

I'm getting deprecation warnings; I hope that's OK since that is the point of this PR.

Yes, this is OK.

We are almost done. Concerning the sklearn tests, I was thinking of adding them to all tests: test_binary_compute, test_binary_compute_batch_images, test_categorical_compute, test_categorical_compute_batch_images. Let me add them myself.

@anmolsjoshi
Contributor Author

anmolsjoshi commented Oct 31, 2018

@vfdev-5 feel free to use this latest commit (I can remove it if you prefer), but it replaces all the accuracy tests with comparisons against sklearn, using random values for y_pred. I can extend this to precision and recall if needed.

@vfdev-5
Collaborator

vfdev-5 commented Oct 31, 2018

Awesome @anmolsjoshi! I haven't yet had time to do what you just committed. If you can extend it to precision and recall too, that would be really great.

@anmolsjoshi
Contributor Author

anmolsjoshi commented Oct 31, 2018

@vfdev-5 sounds good, will get that done. I'm going to remove the sklearn-specific test from commit bdf9a85 because now all tests use sklearn.

@anmolsjoshi
Contributor Author

anmolsjoshi commented Oct 31, 2018

@vfdev-5 I have changed most of the Precision and Recall tests to add a comparison with sklearn; I kept the naive examples because they are good sanity checks. Let me know if any changes are needed.

I think we should add some information in the docstring about Precision and Recall not being batch-wise metrics.

Basically, they shouldn't be used with the progress bar or as a running average, and should only be computed at the end of an epoch. Happy to discuss!

@vfdev-5
Collaborator

vfdev-5 commented Oct 31, 2018

@anmolsjoshi thanks a lot ! LGTM

I think we should add some information in the docstring about Precision and Recall not being batch-wise metrics.

I think this would be better added to RunningAverage, and even there what happens is that the metric is computed batch-wise and then smoothed.
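
As a generic illustration of "computed batch-wise and then smoothed", here is an exponential moving average sketch (an illustration only, not necessarily the exact formula RunningAverage uses):

def smoothed_metric(batch_values, alpha=0.98):
    """Exponential moving average over per-batch metric values."""
    value = None
    for batch_value in batch_values:
        value = batch_value if value is None else alpha * value + (1.0 - alpha) * batch_value
    return value


print(smoothed_metric([0.5, 0.75, 0.8], alpha=0.5))  # 0.7125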

@jasonkriss or @alykhantejani please review and merge if OK

Contributor

@jasonkriss left a comment

This looks very thorough. Great work on this @anmolsjoshi!

I think we can maybe extract some common logic from these metrics at some point, but we don't need to worry about that now.

LGTM (pending the warning comment). @alykhantejani I'll let you merge when ready.

@anmolsjoshi
Contributor Author

anmolsjoshi commented Nov 1, 2018

@vfdev-5 @jasonkriss is there a way to ignore the W504 flake8 error? If I keep the warning as one line, it is too long and throws another error, lol.

EDIT: fixed it, although it's worth considering.

@vfdev-5
Collaborator

vfdev-5 commented Nov 1, 2018

@anmolsjoshi I think it is OK the way you did it, with a line break inside the string.

@jasonkriss
Contributor

Yea looks good to me. Thanks @anmolsjoshi!

@alykhantejani
Contributor

This PR looks great, thanks @anmolsjoshi! Thanks Victor for the thorough review.

@alykhantejani merged commit cca0b99 into pytorch:master on Nov 7, 2018
@anmolsjoshi deleted the updatemetrics branch on January 8, 2019
@vfdev-5 mentioned this pull request on Jan 15, 2019