Update Precision, Recall, add Accuracy (Binary and Categorical combined) #275
Conversation
…dated Precision and Recall similarly
@anmolsjoshi thanks a lot for taking this!
@vfdev-5 could you explain merging the two classes into Accuracy? Sorry, I'm unfamiliar with that. Do you mean creating binary and categorical update functions, and calling them appropriately?
@anmolsjoshi no problem, they could be something like this
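What follows is only a hedged sketch, not the actual snippet vfdev-5 posted (which is not preserved in this transcript): a single Accuracy metric whose update dispatches to a binary or categorical path based on the prediction shape.

```python
import torch

# Hypothetical sketch, NOT the code from the PR: one Accuracy metric whose
# update() picks a binary or categorical path from the prediction shape.
class Accuracy:
    def __init__(self):
        self._num_correct = 0
        self._num_examples = 0

    def update(self, output):
        y_pred, y = output
        if y.ndimension() + 1 == y_pred.ndimension():
            # categorical case: y_pred is (batch_size, num_classes, ...)
            indices = torch.max(y_pred, dim=1)[1]
        else:
            # binary case: y_pred is (batch_size, ...) with probabilities
            indices = torch.round(y_pred).type(y.type())
        correct = torch.eq(indices, y).view(-1)
        self._num_correct += torch.sum(correct).item()
        self._num_examples += correct.shape[0]

    def compute(self):
        return self._num_correct / self._num_examples
```

For example, feeding a 1D probability tensor takes the binary path, while a (batch_size, num_classes) tensor takes the categorical one.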
Thanks for the example. Will get this worked on tomorrow.
@vfdev-5 could you clarify this test? y has 4 elements with shape (1, 2, 2), while y_pred has 8 elements with shape (2, 2, 2), and correct is equal to …. Given the docs, y_pred should be of shape (batch_size, ...) or (batch_size, num_classes, ...), which is not the case here. Given that this is a test for binary accuracy, shouldn't y_pred have shape (1, 2, 2) or (1, 2, 2, 2)? If it is the second case, this test should be handled in CategoricalAccuracy. Or am I understanding this incorrectly? Also, I tried using the existing CategoricalAccuracy:

```python
acc = CategoricalAccuracy()
y_pred = torch.FloatTensor([[[0.3, 0.7],
                             [0.9, 0.1]],
                            [[0.2, 0.8],
                             [0.4, 0.6]]])
y = torch.ones(1, 2, 2).type(torch.LongTensor)
acc.update((y_pred.unsqueeze(0), y))
acc.compute()  # ---> 0.5
```

I'm not sure the answer should be 0.5; taking a max over each index of y_pred, the result should be [[1, 0], [1, 1]], giving 0.75. I think the issue is with how … Thoughts?
@anmolsjoshi this is the case of computing accuracy between images (2x2), i.e. a segmentation task. See #128.
Yes, maybe the batch dimension is missing here.
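To make the two readings concrete, here is a small sketch (mine, not from the PR) of what the argmax gives once an explicit batch dimension is added:

```python
import torch

# With an explicit batch dimension, y_pred is (N, C, H, W) = (1, 2, 2, 2)
# and the class dimension is dim 1.
y_pred = torch.FloatTensor([[[0.3, 0.7],
                             [0.9, 0.1]],
                            [[0.2, 0.8],
                             [0.4, 0.6]]]).unsqueeze(0)

class_dim = torch.max(y_pred, 1)[1]             # argmax over the class dim (dim 1)
last_dim = torch.max(y_pred.squeeze(0), -1)[1]  # argmax over the last dim

print(class_dim.tolist())  # [[[0, 1], [0, 1]]] -> 2 of 4 pixels equal 1
print(last_dim.tolist())   # [[1, 0], [1, 1]]   -> 3 of 4 pixels equal 1
```

Against y = torch.ones(1, 2, 2), the dim-1 reading reproduces CategoricalAccuracy's 0.5, while the last-dim reading gives the 0.75 computed by hand above; the disagreement is purely about which dimension is taken as the class dimension.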
@vfdev-5 I have updated all the tests appropriately and incorporated your comments. Please let me know your thoughts!
… Accuracy metric, updated tests
@vfdev-5 any suggestions on the codecov failures?
@anmolsjoshi take a look here. In red are the lines that are not covered by tests (i.e. code that is not executed).
@anmolsjoshi I was thinking about your implementation of unified binary and categorical accuracy, and maybe we can unify it even more (I'm not a fan of huge if cases). The supported cases are:
Binary classification/regression:
Binary segmentation on ND array:
Categorical classification/regression:
Categorical segmentation on ND array:
As you may have already noticed, the only thing that changes between the current implementations is how indices is computed. Here is the update method of binary accuracy:

```python
def update(self, output):
    y_pred, y = output
    indices = torch.round(y_pred).type(y.type())
    correct = torch.eq(indices, y).view(-1)
    self._num_correct += torch.sum(correct).item()
    self._num_examples += correct.shape[0]
```

and here is the update method of categorical accuracy:

```python
def update(self, output):
    y_pred, y = output
    indices = torch.max(y_pred, 1)[1]
    correct = torch.eq(indices, y).view(-1)
    self._num_correct += torch.sum(correct).item()
    self._num_examples += correct.shape[0]
```

So, there are questions for a unified update function:
Following the supported cases above, we can check that:

```python
assert y.ndimension() >= 1 and y_pred.ndimension() >= 1
assert y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()
y_shape = y.shape
y_pred_shape = y_pred.shape
if y.ndimension() + 1 == y_pred.ndimension():
    y_pred_shape = (y_pred_shape[0], ) + y_pred_shape[2:]
assert y_shape == y_pred_shape
```
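As a sanity check, those assertions can be exercised against the four supported cases (the concrete shapes below are illustrative, not from the PR):

```python
import torch

def check_shapes(y, y_pred):
    # Same shape checks as above, wrapped in a helper for testing.
    assert y.ndimension() >= 1 and y_pred.ndimension() >= 1
    assert (y.ndimension() == y_pred.ndimension()
            or y.ndimension() + 1 == y_pred.ndimension())
    y_pred_shape = y_pred.shape
    if y.ndimension() + 1 == y_pred.ndimension():
        # drop the num_classes dimension before comparing with y
        y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]
    assert y.shape == y_pred_shape

check_shapes(torch.zeros(4), torch.zeros(4))                 # binary classification
check_shapes(torch.zeros(4, 8, 8), torch.zeros(4, 8, 8))     # binary segmentation
check_shapes(torch.zeros(4), torch.zeros(4, 3))              # categorical classification
check_shapes(torch.zeros(4, 8, 8), torch.zeros(4, 3, 8, 8))  # categorical segmentation
```

Shapes outside these four patterns fail the assertions, which is the point of the check.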
I think this is just:
Here I propose to map the binary case into a 2-class categorical one:

```python
y_pred = y_pred.unsqueeze(dim=1)
y_pred = torch.cat([1.0 - y_pred, y_pred], dim=1)
```

Of course, in terms of performance we need to benchmark this. And finally, ideally we should also store somehow what type of data (categorical or binary) was used in the previous … What do you think?
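A quick sketch (my own, assuming plain probabilities in y_pred) shows that this mapping makes the categorical argmax reproduce the binary round:

```python
import torch

y_pred = torch.FloatTensor([0.2, 0.8, 0.45, 0.9])

# Proposed mapping: binary probabilities -> 2-class categorical scores
two_class = y_pred.unsqueeze(dim=1)
two_class = torch.cat([1.0 - two_class, two_class], dim=1)

categorical = torch.max(two_class, dim=1)[1]  # categorical path
binary = torch.round(y_pred).long()           # binary path

print(torch.equal(categorical, binary))  # True: both give [0, 1, 0, 1]
```

The two paths can only disagree at exactly 0.5, where torch.round's rounding and argmax tie-breaking may differ, so both performance and that edge case deserve a look.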
@vfdev-5 I like that a lot.
@vfdev-5 thank you so much for your input. I agree with it and will get started on these changes. I'm definitely trying to get a handle on programming, and feedback like this is really helpful.
@anmolsjoshi @vfdev-5 …
@rpatrik96, thanks for the proposition! What kind of metrics do you think would be useful with multi-label targets?
@vfdev-5 …
@rpatrik96 how would you threshold …?
@vfdev-5 I think the following way would be logical: if …
@rpatrik96 maybe we could provide a custom function to override …
@vfdev-5 …
@rpatrik96 it's a basic decision rule: probabilities to predictions.
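For illustration, such a decision rule for multi-label outputs could look like this (a hypothetical sketch; threshold_fn and the 0.5 cutoff are my assumptions, not part of the PR):

```python
import torch

# Hypothetical pluggable decision rule: map per-label probabilities to 0/1.
def threshold_fn(y_pred, threshold=0.5):
    return (y_pred > threshold).long()

y_pred = torch.FloatTensor([[0.9, 0.2, 0.7],
                            [0.1, 0.8, 0.4]])
y = torch.LongTensor([[1, 0, 1],
                      [0, 1, 1]])

predictions = threshold_fn(y_pred)
correct = torch.eq(predictions, y).view(-1)
accuracy = torch.sum(correct).item() / correct.shape[0]
print(accuracy)  # 5 of 6 labels match
```

Allowing the user to override threshold_fn would cover cases where a per-label cutoff other than 0.5 makes sense.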
@vfdev-5 Oh, fine, I thought there was something more complicated to it :). So how should we proceed?
@vfdev-5 thanks for the review! I'll write a few more tests as you have suggested.
Thanks @anmolsjoshi, just a few remarks if you could address them please. There is a problem with a sklearn test:
You can check it here.
…d mnist examples and progress bar test with correct Accuracy
@vfdev-5 updated the tests to prevent that sklearn warning, and updated mnist and test_pbar as well. Getting deprecation warnings, but I hope that's OK since it is the point of this PR.
@anmolsjoshi thanks!
Yes, this is OK. We are almost done. Concerning the sklearn tests, I thought about adding them to all tests: test_binary_compute, test_binary_compute_batch_images, test_categorical_compute, test_categorical_compute_batch_images. Let me add them myself.
@vfdev-5 feel free to use this latest commit; I can remove it. But it replaces all the accuracy tests to compare against sklearn, using random values for y_pred. I can extend this to precision and recall if needed.
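The sklearn-comparison style of test described here might look roughly like this (a sketch, not the actual test code from the commit):

```python
import numpy as np
import torch
from sklearn.metrics import accuracy_score

# Random probabilities and labels; compare a manual binary accuracy
# against sklearn's accuracy_score on the same predictions.
np.random.seed(0)
y_pred = torch.from_numpy(np.random.rand(100).astype(np.float32))
y = torch.from_numpy(np.random.randint(0, 2, size=100).astype(np.int64))

indices = torch.round(y_pred).long()
ours = torch.eq(indices, y).float().mean().item()
ref = accuracy_score(y.numpy(), indices.numpy())

assert abs(ours - ref) < 1e-6
```

Using random inputs rather than hand-crafted examples makes the tests less likely to accidentally encode the same bug as the implementation.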
Awesome @anmolsjoshi! Haven't had time yet to look at what you just committed. If you can extend it to precision and recall too, that would be really great.
@vfdev-5 I have changed most tests of Precision and Recall to add a comparison with sklearn; I kept the naive examples because they are good sanity checks. Let me know if any changes are needed. I think we should add some information to the docstrings about Precision and Recall not being batch-wise metrics. Basically, they shouldn't be used with the progress bar or as a running average, and should only be computed at the end of an epoch. Happy to discuss!
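To illustrate why batch-wise use is problematic (my own toy example, not from the PR): averaging per-batch precision generally differs from precision computed over the whole epoch.

```python
import torch

def precision(y_pred, y):
    # Simple binary precision: TP / (TP + FP)
    tp = ((y_pred == 1) & (y == 1)).sum().item()
    fp = ((y_pred == 1) & (y == 0)).sum().item()
    return tp / (tp + fp)

b1_pred, b1_y = torch.tensor([1, 1, 0, 0]), torch.tensor([1, 0, 0, 0])
b2_pred, b2_y = torch.tensor([1, 0, 0, 0]), torch.tensor([1, 1, 0, 0])

per_batch_avg = (precision(b1_pred, b1_y) + precision(b2_pred, b2_y)) / 2
full_epoch = precision(torch.cat([b1_pred, b2_pred]), torch.cat([b1_y, b2_y]))

print(per_batch_avg)  # 0.75
print(full_epoch)     # 0.666...
```

Because the denominator (TP + FP) varies per batch, the mean of ratios is not the ratio of sums, so the running-average value would be misleading.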
@anmolsjoshi thanks a lot! LGTM
I think this is better to be added to …
@jasonkriss or @alykhantejani, please review and merge if OK.
This looks very thorough. Great work on this @anmolsjoshi!
I think we can maybe extract some common logic from these metrics at some point, but we don't need to worry about that now.
LGTM (pending the warning comment). @alykhantejani I'll let you merge when ready.
@vfdev-5 @jasonkriss is there a way to ignore the W504 flake8 error? If I kept the warning as one line, it would be too long and throw another error lol. EDIT: fixed it, although it's worth considering.
@anmolsjoshi I think it is OK like you did, with a line break inside the string.
Yea, looks good to me. Thanks @anmolsjoshi!
This PR looks great, thanks @anmolsjoshi! Thanks Victor for the thorough review.
Fixes #262
Description:
This PR updates Precision and Recall, and adds Accuracy to handle binary and categorical cases for different types of input.
Check list: