[OPS, IMP] New batched_nms implementation #3426
Thanks!
The ONNX failures are related, and I think the easiest workaround is to make the checks independent of the device.
torchvision/ops/boxes.py
Outdated
    iou_threshold: float,
) -> Tensor:
    # Based on Detectron2 implementation
    result_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
Should this be torch.zeros_like(scores, dtype=torch.bool)?
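As a quick illustration of the question above, the following sketch compares the two spellings. The scores tensor is made up for the example; as far as I can tell, both calls produce an all-False boolean mask with the same shape and device as scores, so the choice is mostly one of readability.

```python
import torch

# Hypothetical scores tensor, used only for illustration.
scores = torch.tensor([0.9, 0.1, 0.5])

# Spelling used in the diff under review.
mask_a = scores.new_zeros(scores.size(), dtype=torch.bool)
# Spelling suggested in the review comment.
mask_b = torch.zeros_like(scores, dtype=torch.bool)

# Both masks share dtype, shape, and device, and contain only False.
print(mask_a.dtype, tuple(mask_a.shape), mask_a.device)
print(torch.equal(mask_a, mask_b))
```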
Looks like a nice boost. :) Does it make sense to check that this fix does not have any negative side-effect on the validation metrics of all pre-trained models? I think we should confirm by re-estimating all the validation stats and by retraining some of the models just before merge.
The two implementations are equivalent in terms of results (they both rely on |
Looks like |
…ng to return vanilla now
Codecov Report
@@ Coverage Diff @@
## master #3426 +/- ##
==========================================
+ Coverage 78.70% 78.75% +0.04%
==========================================
Files 105 105
Lines 9735 9748 +13
Branches 1563 1565 +2
==========================================
+ Hits 7662 7677 +15
+ Misses 1582 1581 -1
+ Partials 491 490 -1
Thanks!
boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
assert max(boxes[:, 0]) < min(boxes[:, 2])  # x1 < x2
assert max(boxes[:, 1]) < min(boxes[:, 3])  # y1 < y2
nit: it might be preferable to construct boxes that are well-formed by construction.
So something like
boxes = torch.rand(num_boxes, 4)
boxes[:, 2:] += boxes[:, :2]
is generally better when constructing bounding boxes, since x2 >= x1 and y2 >= y1 then hold without extra asserts. It might be good to move this to a helper function, btw.
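One possible shape for such a helper is sketched below. The name make_random_boxes is hypothetical and is not taken from the PR; the body is just the two lines from the suggestion above, wrapped in a function.

```python
import torch


def make_random_boxes(num_boxes: int) -> torch.Tensor:
    """Hypothetical test helper: random boxes in (x1, y1, x2, y2) format."""
    boxes = torch.rand(num_boxes, 4)
    # Treat the last two columns as a non-negative width and height and
    # add the top-left corner, so x2 >= x1 and y2 >= y1 by construction.
    boxes[:, 2:] += boxes[:, :2]
    return boxes


boxes = make_random_boxes(1000)
print(tuple(boxes.shape))
print(bool((boxes[:, 2:] >= boxes[:, :2]).all()))
```

Because the second corner is derived from the first, no assert on the generated data is needed in the test itself.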
Summary:
* new batched_nms implem
* flake8
* hopefully fix torchscipt tests
* Use where instead of nonzero
* Use same threshold (4k) for CPU and GPU
* Remove use of argsort
* use views again
* remove print
* trying stuff, I don't know what's going on
* previous passed onnx checks so the error isn't in _vanilla func. Trying to return vanilla now
* add tracing decorators
* cleanup
* wip
* ignore new path with ONNX
* use vanilla if tracing...????
* Remove script_if_tracing decorator as it was conflicting with _is_tracing
* flake8
* Improve coverage

Reviewed By: NicolasHug, cpuhrsch
Differential Revision: D26945728
fbshipit-source-id: 118a41e03da2939a726e5bd18f5f77b7c0ce6339
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
Closes #1311
This PR introduces a new implementation of batched_nms, which is faster than the current one in some cases (refer to the benchmarks in the issue).
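To make the idea concrete, here is a minimal sketch of two ways to run NMS independently per class: a per-class loop that fills a boolean mask (in the spirit of the Detectron2-based snippet quoted in the review), and an offset-based variant that shifts each class's boxes apart and runs a single NMS pass, which is my understanding of the pre-existing approach. All function names are illustrative, simple_nms is a slow stand-in for torchvision.ops.nms so the example runs with PyTorch alone, and non-empty inputs are assumed. This is not the code merged in the PR.

```python
import torch


def simple_nms(boxes, scores, iou_threshold):
    """Greedy NMS in plain PyTorch (illustrative stand-in, assumes non-empty input)."""
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        rest = order[1:]
        if rest.numel() == 0:
            break
        # IoU between the current top-scoring box and the remaining boxes.
        lt = torch.max(boxes[i, :2], boxes[rest, :2])
        rb = torch.min(boxes[i, 2:], boxes[rest, 2:])
        wh = (rb - lt).clamp(min=0)
        inter = wh[:, 0] * wh[:, 1]
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        area_rest = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
        iou = inter / (area_i + area_rest - inter)
        # Drop boxes that overlap the kept box by more than the threshold.
        order = rest[iou <= iou_threshold]
    return torch.tensor(keep, dtype=torch.long)


def batched_nms_per_class(boxes, scores, idxs, iou_threshold):
    """Per-class loop: run NMS on each class and record survivors in a mask."""
    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
    for class_id in torch.unique(idxs):
        curr = torch.where(idxs == class_id)[0]
        curr_keep = simple_nms(boxes[curr], scores[curr], iou_threshold)
        keep_mask[curr[curr_keep]] = True
    keep = torch.where(keep_mask)[0]
    # Return kept indices sorted by decreasing score.
    return keep[scores[keep].sort(descending=True)[1]]


def batched_nms_offset(boxes, scores, idxs, iou_threshold):
    """Offset variant: shift each class far apart, then run one NMS pass."""
    offsets = idxs.to(boxes) * (boxes.max() + 1)
    return simple_nms(boxes + offsets[:, None], scores, iou_threshold)


# Made-up data: boxes 0 and 1 overlap and share a class; box 2 duplicates
# box 0 but belongs to another class, so it must survive.
boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0],
                      [0.0, 0.0, 10.0, 10.0], [50.0, 50.0, 60.0, 60.0]])
scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
idxs = torch.tensor([0, 0, 1, 1])
print(batched_nms_per_class(boxes, scores, idxs, 0.5).tolist())  # [0, 2, 3]
print(batched_nms_offset(boxes, scores, idxs, 0.5).tolist())     # [0, 2, 3]
```

Both variants return the same indices here. Which one is faster depends on the number of boxes and classes, which is presumably why the commit summary mentions a size threshold for choosing between code paths.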