
[OPS, IMP] New batched_nms implementation #3426

Merged
merged 28 commits into pytorch:master on Mar 8, 2021

Conversation

NicolasHug
Member

Closes #1311

This PR introduces a new implementation of batched_nms, which is faster than the current one in some cases (see the benchmarks in the linked issue).
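For context, the existing batched_nms is commonly described as using a coordinate-offset trick: shift each class's boxes so that boxes of different classes can never overlap, then run a single plain nms over everything. A minimal sketch of that idea (not the torchvision source; the greedy `nms` here is a small O(n²) stand-in for `torchvision.ops.nms`):

```python
import torch

def pairwise_iou(a, b):
    # IoU between each box in a (N, 4) and each box in b (M, 4), xyxy format.
    tl = torch.maximum(a[:, None, :2], b[None, :, :2])
    br = torch.minimum(a[:, None, 2:], b[None, :, 2:])
    inter = (br - tl).clamp(min=0).prod(-1)
    area_a = (a[:, 2:] - a[:, :2]).prod(-1)
    area_b = (b[:, 2:] - b[:, :2]).prod(-1)
    return inter / (area_a[:, None] + area_b[None, :] - inter)

def nms(boxes, scores, iou_threshold):
    # Minimal greedy NMS, illustration only.
    suppressed = torch.zeros(len(boxes), dtype=torch.bool)
    keep = []
    for i in scores.argsort(descending=True).tolist():
        if suppressed[i]:
            continue
        keep.append(i)
        suppressed |= pairwise_iou(boxes[i:i + 1], boxes)[0] > iou_threshold
    return torch.tensor(keep, dtype=torch.int64)

def batched_nms_offset_trick(boxes, scores, idxs, iou_threshold):
    # Shift each class's boxes by a per-class offset larger than any
    # coordinate, so boxes of different classes can never overlap and one
    # plain nms call handles all classes at once.
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64)
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    return nms(boxes + offsets[:, None], scores, iou_threshold)
```

With this trick, two perfectly overlapping boxes of different classes both survive, while plain nms would suppress one of them.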

Member

@fmassa fmassa left a comment


Thanks!

The ONNX failures are related, and I think the easiest workaround is to simplify the checks so that they are independent of the device.

iou_threshold: float,
) -> Tensor:
# Based on Detectron2 implementation
result_mask = scores.new_zeros(scores.size(), dtype=torch.bool)


should this be torch.zeros_like(scores, dtype = torch.bool)?
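For what it's worth, the two spellings produce identical tensors; `torch.zeros_like` is arguably the more direct way to write it. A quick check:

```python
import torch

scores = torch.rand(5)

# Both create a bool tensor of zeros with the same shape and device as scores.
a = scores.new_zeros(scores.size(), dtype=torch.bool)
b = torch.zeros_like(scores, dtype=torch.bool)

assert a.dtype == torch.bool
assert torch.equal(a, b)
```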

@datumbox
Contributor

Looks like a nice boost. :)

Does it make sense to check that this fix does not have any negative side-effect on the validation metrics of all pre-trained models? I think we should confirm by re-estimating all the validation stats and by retraining some of the models just before merge.

@NicolasHug
Member Author

The two implementations are equivalent in terms of results (they both rely on nms). The only difference is the way we account for the classes, as we only want overlaps to be detected within the same class.
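The per-class strategy can be sketched as follows (an illustrative reconstruction along the lines of the Detectron2-style approach and the `result_mask` line quoted in the review above, not the PR's actual code; the greedy `nms` is a small stand-in for `torchvision.ops.nms`):

```python
import torch

def nms(boxes, scores, iou_threshold):
    # Minimal greedy NMS via a pairwise IoU matrix, illustration only.
    tl = torch.maximum(boxes[:, None, :2], boxes[None, :, :2])
    br = torch.minimum(boxes[:, None, 2:], boxes[None, :, 2:])
    inter = (br - tl).clamp(min=0).prod(-1)
    areas = (boxes[:, 2:] - boxes[:, :2]).prod(-1)
    iou = inter / (areas[:, None] + areas[None, :] - inter)
    suppressed = torch.zeros(len(boxes), dtype=torch.bool)
    keep = []
    for i in scores.argsort(descending=True).tolist():
        if suppressed[i]:
            continue
        keep.append(i)
        suppressed |= iou[i] > iou_threshold
    return torch.tensor(keep, dtype=torch.int64)

def batched_nms_per_class(boxes, scores, idxs, iou_threshold):
    # Run nms independently inside each class, so overlaps are only ever
    # detected within the same class; collect kept detections in a boolean
    # mask, then return kept indices sorted by descending score.
    result_mask = torch.zeros_like(scores, dtype=torch.bool)
    for cls in torch.unique(idxs):
        sel = torch.where(idxs == cls)[0]
        keep = nms(boxes[sel], scores[sel], iou_threshold)
        result_mask[sel[keep]] = True
    keep = torch.where(result_mask)[0]
    return keep[scores[keep].argsort(descending=True)]
```

Both strategies delegate the actual suppression to nms; they only differ in how detections are grouped by class.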

@fmassa
Member

fmassa commented Mar 1, 2021

Looks like argsort is not supported in ONNX either, so this change still breaks ONNX export. I think sort is implemented for ONNX, so it should be a quick thing to check.
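A sketch of the suggested swap: `torch.sort` returns a `(values, indices)` pair, and the indices coincide with what `argsort` gives, so it can serve as a drop-in replacement where the exporter lacks `argsort` support (ties aside, where the two may legitimately order elements differently):

```python
import torch

scores = torch.tensor([0.3, 0.9, 0.1])

# sort returns (values, indices); the indices are exactly what argsort gives.
_, order = scores.sort(descending=True)

assert order.tolist() == [1, 0, 2]
assert torch.equal(order, scores.argsort(descending=True))
```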

@NicolasHug NicolasHug changed the title New batched_nms implementation [OPS, IMP] New batched_nms implementation Mar 4, 2021
@codecov

codecov bot commented Mar 5, 2021

Codecov Report

Merging #3426 (4b7e942) into master (c991db8) will increase coverage by 0.04%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #3426      +/-   ##
==========================================
+ Coverage   78.70%   78.75%   +0.04%     
==========================================
  Files         105      105              
  Lines        9735     9748      +13     
  Branches     1563     1565       +2     
==========================================
+ Hits         7662     7677      +15     
+ Misses       1582     1581       -1     
+ Partials      491      490       -1     
Impacted Files Coverage Δ
torchvision/ops/boxes.py 91.26% <100.00%> (+3.48%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update c991db8...4b7e942.

@NicolasHug
Member Author

Since the code has changed due to onnx and torchscript compatibility issues, I re-ran the benchmarks.
They're quite similar to the previous ones so I see no reason to change the thresholds. If CI goes green, I think this is good to go!

[screenshot: updated benchmark results]

Member

@fmassa fmassa left a comment


Thanks!

Comment on lines +470 to +472
boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2
assert max(boxes[:, 1]) < min(boxes[:, 3]) # y1 < y2
Member


nit: it might be preferable to construct boxes which are always well-formed.
So something like

boxes = torch.rand(num_boxes, 4)
boxes[:, 2:] += boxes[:, :2]

is generally better when constructing bounding boxes. It might also be good to move this to a helper function, by the way.
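A hypothetical version of that helper (the name and seed handling are my own additions, not from the PR):

```python
import torch

def make_random_boxes(num_boxes, seed=0):
    # Boxes are well-formed by construction, since
    # (x2, y2) = (x1, y1) + a nonnegative offset.
    g = torch.Generator().manual_seed(seed)
    boxes = torch.rand(num_boxes, 4, generator=g)
    boxes[:, 2:] += boxes[:, :2]
    return boxes

boxes = make_random_boxes(100)
assert (boxes[:, 2] >= boxes[:, 0]).all()  # x1 <= x2
assert (boxes[:, 3] >= boxes[:, 1]).all()  # y1 <= y2
```

Unlike the original construction, this does not force every x1 to be smaller than every x2 across boxes, only within each box, which is all that well-formedness requires.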

@fmassa fmassa merged commit 414427d into pytorch:master Mar 8, 2021
@NicolasHug NicolasHug added the improvement, module: ops, and Perf (for performance improvements) labels on Mar 9, 2021
facebook-github-bot pushed a commit that referenced this pull request Mar 10, 2021
Summary:
* new batched_nms implem

* flake8

* hopefully fix torchscript tests

* Use where instead of nonzero

* Use same threshold (4k) for CPU and GPU

* Remove use of argsort

* use views again

* remove print

* trying stuff, I don't know what's going on

* previous passed onnx checks so the error isn't in _vanilla func. Trying to return vanilla now

* add tracing decorators

* cleanup

* wip

* ignore new path with ONNX

* use vanilla if tracing...????

* Remove script_if_tracing decorator as it was conflicting with _is_tracing

* flake8

* Improve coverage

Reviewed By: NicolasHug, cpuhrsch

Differential Revision: D26945728

fbshipit-source-id: 118a41e03da2939a726e5bd18f5f77b7c0ce6339

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
Successfully merging this pull request may close these issues.

Space/time complexity of batched_nms grows quadratically with batch size
5 participants