[prototype] Optimize Center Crop performance #6880
@datumbox thanks for the PR, do you have a table with the runtime speed-up?
@@ -101,7 +101,7 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> feat
 def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
-    num_channels, height, width = get_dimensions_image_tensor(image)
+    num_channels, height, width = image.shape[-3:]
Aligns the idiom with other parts of the code-base where we directly rely on shape to determine the dimensions:
new_height, new_width = image.shape[-2:]
At that point the image is a pure tensor, so making additional method calls to fetch the sizes is unnecessary.
shape = image.shape
num_channels, height, width = shape[-3:]
Minor nit to align idioms across the code-base. See ref:
shape = bounding_box.shape
Originally I wanted to do something like *extra_dims, num_channels, height, width but that's not JIT-scriptable. So I opted for keeping the whole original shape, as we do in other parts of the code-base.
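A minimal sketch of the two idioms discussed in this comment. It uses a plain Python tuple as a stand-in for `image.shape` so that it runs without PyTorch (an assumption for illustration; `torch.Size` is a tuple subclass and slices the same way). The helper name `split_shape` is hypothetical and not part of the PR.

```python
from typing import Tuple


def split_shape(shape: Tuple[int, ...]) -> Tuple[Tuple[int, ...], int, int, int]:
    """Split a (..., C, H, W) shape into its leading dims and C, H, W using slices only."""
    # Slice-based form: keep the whole shape around and index from the end.
    num_channels, height, width = shape[-3:]
    extra_dims = shape[:-3]
    return extra_dims, num_channels, height, width


# Stand-in for image.shape of a batch of 16 three-channel images of size 40x50.
shape = (16, 3, 40, 50)

# Star-unpacking form: valid in eager Python, but per the comment above not JIT-scriptable.
*star_extra, star_c, star_h, star_w = shape

extra_dims, num_channels, height, width = split_shape(shape)
# Both forms recover the same dimensions.
assert (tuple(star_extra), star_c, star_h, star_w) == (extra_dims, num_channels, height, width)
print(extra_dims, num_channels, height, width)
```

The slice form also handles an unbatched `(C, H, W)` shape, in which case the leading dimensions are simply empty.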
if image.numel() == 0:
    return image.reshape(shape[:-2] + (crop_height, crop_width))
This is needed because the original pad_image_tensor method below had a mitigation for zero-batch images, so I chose to handle that case early.
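To illustrate why the early return is safe, here is a small sketch of the shape arithmetic, again with plain tuples standing in for tensor shapes (an assumption so the snippet runs without PyTorch). The helpers `numel` and `cropped_shape` are hypothetical names for illustration only.

```python
import math
from typing import Tuple


def numel(shape: Tuple[int, ...]) -> int:
    """Number of elements implied by a shape (product of its dimensions)."""
    return math.prod(shape)


def cropped_shape(shape: Tuple[int, ...], crop_height: int, crop_width: int) -> Tuple[int, ...]:
    """Output shape of a center crop: leading dims are kept, the last two are replaced."""
    return shape[:-2] + (crop_height, crop_width)


# A zero-batch input, as in the mitigation mentioned above.
empty_batch = (0, 3, 40, 50)
out = cropped_shape(empty_batch, 224, 224)

# Both shapes hold zero elements, so a reshape between them is valid
# and neither padding nor cropping has to run.
assert numel(empty_batch) == 0 and numel(out) == 0
print(out)
```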
 if crop_height > image_height or crop_width > image_width:
     padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-    image = pad_image_tensor(image, padding_ltrb, fill=0)
+    image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
There is no real reason to hit pad_image_tensor() and go through input validation, parameter parsing and multiple method calls just to reach PyTorch's pad. We should try to be as explicit as possible in the internal implementations, as this pays off. As we can see below, the performance gains are significant.
Here are V1 vs V2 benchmarks on latest main:
[-------------- CenterCrop cpu torch.float32 -------------]
| old | new
1 threads: ------------------------------------------------
(3, 40, 50) | 42 (+- 0) us | 57 (+- 0) us
(16, 3, 40, 50) | 528 (+- 1) us | 555 (+- 2) us
6 threads: ------------------------------------------------
(3, 40, 50) | 88 (+- 1) us | 104 (+- 3) us
(16, 3, 40, 50) | 615 (+- 5) us | 643 (+- 19) us
Times are in microseconds (us).
[------------- CenterCrop cuda torch.float32 -------------]
| old | new
1 threads: ------------------------------------------------
(3, 40, 50) | 37 (+- 0) us | 48 (+- 0) us
(16, 3, 40, 50) | 31 (+- 0) us | 48 (+- 0) us
6 threads: ------------------------------------------------
(3, 40, 50) | 37 (+- 1) us | 48 (+- 1) us
(16, 3, 40, 50) | 31 (+- 0) us | 48 (+- 1) us
Times are in microseconds (us).
[--------------- CenterCrop cpu torch.uint8 --------------]
| old | new
1 threads: ------------------------------------------------
(3, 40, 50) | 68 (+- 0) us | 77 (+- 0) us
(16, 3, 40, 50) | 760 (+- 4) us | 785 (+- 4) us
6 threads: ------------------------------------------------
(3, 40, 50) | 106 (+- 1) us | 116 (+- 2) us
(16, 3, 40, 50) | 837 (+- 6) us | 861 (+- 18) us
Times are in microseconds (us).
[-------------- CenterCrop cuda torch.uint8 --------------]
| old | new
1 threads: ------------------------------------------------
(3, 40, 50) | 37 (+- 0) us | 48 (+- 0) us
(16, 3, 40, 50) | 31 (+- 0) us | 48 (+- 0) us
6 threads: ------------------------------------------------
(3, 40, 50) | 37 (+- 1) us | 48 (+- 1) us
(16, 3, 40, 50) | 31 (+- 0) us | 48 (+- 1) us
Times are in microseconds (us).
Here are the same benchmarks after the PR:
[-------------- CenterCrop cpu torch.float32 -------------]
| old | new
1 threads: ------------------------------------------------
(3, 40, 50) | 45 (+- 0) us | 42 (+- 0) us
(16, 3, 40, 50) | 525 (+- 2) us | 535 (+- 1) us
6 threads: ------------------------------------------------
(3, 40, 50) | 88 (+- 3) us | 88 (+- 1) us
(16, 3, 40, 50) | 615 (+- 23) us | 626 (+- 18) us
Times are in microseconds (us).
[------------- CenterCrop cuda torch.float32 -------------]
| old | new
1 threads: ------------------------------------------------
(3, 40, 50) | 37 (+- 0) us | 35 (+- 0) us
(16, 3, 40, 50) | 31 (+- 0) us | 35 (+- 0) us
6 threads: ------------------------------------------------
(3, 40, 50) | 37 (+- 0) us | 35 (+- 0) us
(16, 3, 40, 50) | 31 (+- 0) us | 35 (+- 0) us
Times are in microseconds (us).
[--------------- CenterCrop cpu torch.uint8 --------------]
| old | new
1 threads: ------------------------------------------------
(3, 40, 50) | 68 (+- 0) us | 65 (+- 0) us
(16, 3, 40, 50) | 763 (+- 4) us | 773 (+- 6) us
6 threads: ------------------------------------------------
(3, 40, 50) | 106 (+- 3) us | 104 (+- 2) us
(16, 3, 40, 50) | 839 (+- 22) us | 851 (+- 23) us
Times are in microseconds (us).
[-------------- CenterCrop cuda torch.uint8 --------------]
| old | new
1 threads: ------------------------------------------------
(3, 40, 50) | 37 (+- 0) us | 34 (+- 0) us
(16, 3, 40, 50) | 31 (+- 0) us | 35 (+- 0) us
6 threads: ------------------------------------------------
(3, 40, 50) | 37 (+- 0) us | 35 (+- 0) us
(16, 3, 40, 50) | 31 (+- 0) us | 35 (+- 1) us
Times are in microseconds (us).
This benchmark forces the padding branch by using images of size (40, 50) and requesting a crop of (224, 224).
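As a rough illustration of how much padding this benchmark configuration implies, the sketch below computes left/top/right/bottom padding for a center crop. It is not the actual `_center_crop_compute_padding` implementation: the function name `center_crop_padding_ltrb` is hypothetical, and the choice to put any odd remainder on the right/bottom side is an assumption.

```python
from typing import List


def center_crop_padding_ltrb(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> List[int]:
    """Return [left, top, right, bottom] padding so the image is at least crop-sized."""
    # Only pad along a dimension where the image is smaller than the crop.
    pad_w = max(crop_width - image_width, 0)
    pad_h = max(crop_height - image_height, 0)
    # Assumption: split evenly, with any odd remainder going to the right/bottom.
    return [pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2]


# Benchmark configuration: (40, 50) images, (224, 224) crop.
left, top, right, bottom = center_crop_padding_ltrb(224, 224, 40, 50)
padded_height = 40 + top + bottom
padded_width = 50 + left + right

# After padding, the image is exactly crop-sized, so no further cropping is needed.
assert (padded_height, padded_width) == (224, 224)
print([left, top, right, bottom])
```

With these sizes the image is smaller than the crop in both dimensions, so every benchmarked call goes through the padding path.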
LGTM, thanks!
Summary:
* Reducing unnecessary method calls
* Optimize pad branch.
* Remove unnecessary call to crop_image_tensor
* Fix linter

Reviewed By: datumbox
Differential Revision: D41020555
fbshipit-source-id: 55d55d80993830d0b70ad4140d55fab2cba9d21e
Co-authored-by: vfdev <vfdev.5@gmail.com>
Related to #6818
CenterCrop is a very popular inference transform that is often executed on mobile devices, where performance matters a lot. This PR focuses on closing the performance gap with V1.
cc @vfdev-5 @bjuncek @pmeier