[prototype] Optimize Center Crop performance #6880

Merged: 5 commits, Nov 1, 2022

Contributor

@datumbox datumbox commented Nov 1, 2022

Related to #6818

CenterCrop is a very popular inference transform that is often executed on mobile devices, where performance matters a lot. This PR focuses on closing the performance gap from V1:

  • Significantly speeds up the code branch that requires padding.
  • Marginally improves the speed of the code branch that doesn't require padding.
  • Contains a few nits that align code-base practices across methods for how we determine input shapes within tensor kernels.

cc @vfdev-5 @bjuncek @pmeier

@vfdev-5
Collaborator

vfdev-5 commented Nov 1, 2022

@datumbox thanks for the PR. Do you have a table with the runtime speed-up?

@@ -101,7 +101,7 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> feat


 def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
-    num_channels, height, width = get_dimensions_image_tensor(image)
+    num_channels, height, width = image.shape[-3:]
Contributor Author

@datumbox datumbox Nov 1, 2022

Aligns the idiom with other parts of the code-base where we directly rely on shape to determine the dimensions:

new_height, new_width = image.shape[-2:]

At that point the image is a pure tensor and making additional method calls to fetch the sizes is unnecessary.

Comment on lines +275 to +276
shape = image.shape
num_channels, height, width = shape[-3:]
Contributor Author

Minor nit to align idioms across the code-base. See ref:

Originally I wanted to do something like *extra_dims, num_channels, height, width, but that's not JIT-scriptable. So I opted to keep the whole original shape, as we do in other parts of the code-base.
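The shape idiom described above can be sketched as follows. This is a minimal illustration, not code from the PR: `flatten_batch_dims` is a hypothetical helper, and the only claim is that this particular form (keep the full shape, then slice the trailing three dimensions) compiles under `torch.jit.script`.

```python
import torch


def flatten_batch_dims(image: torch.Tensor) -> torch.Tensor:
    # Keep the whole shape and slice it, instead of star-unpacking the leading dims.
    shape = image.shape
    num_channels, height, width = shape[-3:]
    # Collapse any number of leading batch dimensions into a single one.
    return image.reshape(-1, num_channels, height, width)


# The same function works in eager mode and when scripted.
scripted = torch.jit.script(flatten_batch_dims)
batch = torch.zeros(2, 5, 3, 40, 50)
print(scripted(batch).shape)
```

Slicing from the end means the function does not need to know how many batch dimensions precede the channel dimension.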

Comment on lines +1163 to +1164
if image.numel() == 0:
    return image.reshape(shape[:-2] + (crop_height, crop_width))
Contributor Author

This is needed because the original pad_image_tensor method below had a mitigation for zero-batch images, so I decided to handle that case early.
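A small sketch of that early return, assuming a batch with zero images. `reshape_if_empty` is a hypothetical wrapper used only for illustration; the two lines inside the `if` mirror the snippet quoted above.

```python
import torch


def reshape_if_empty(image: torch.Tensor, crop_height: int, crop_width: int) -> torch.Tensor:
    shape = image.shape
    if image.numel() == 0:
        # No pixel data to pad or crop: only the trailing (H, W) dims change.
        return image.reshape(shape[:-2] + (crop_height, crop_width))
    # Non-empty inputs would continue to the regular pad/crop logic (omitted here).
    return image


empty_batch = torch.empty(0, 3, 40, 50)
print(reshape_if_empty(empty_batch, 224, 224).shape)  # torch.Size([0, 3, 224, 224])
```

The reshape is valid because both the input and the output shape describe zero elements.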


 if crop_height > image_height or crop_width > image_width:
     padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-    image = pad_image_tensor(image, padding_ltrb, fill=0)
+    image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
Contributor Author

@datumbox datumbox Nov 1, 2022

There is no real reason to call pad_image_tensor() and go through input validation, parameter parsing and multiple method calls just to reach PyTorch's pad. We should try to be as explicit as possible in the internal implementations, as this pays off. As we can see below, the performance gains are significant.
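For illustration, here is a rough equivalent of the direct call using only the public `torch.nn.functional.pad`. The helper name `pad_ltrb` is hypothetical, and the reordering of `[left, top, right, bottom]` into the `(left, right, top, bottom)` layout that `F.pad` expects is an assumption about what the private `_parse_pad_padding` helper does for this case.

```python
import torch
import torch.nn.functional as F


def pad_ltrb(image: torch.Tensor, padding_ltrb: list) -> torch.Tensor:
    left, top, right, bottom = padding_ltrb
    # F.pad takes pairs starting from the last dimension: width first, then height.
    return F.pad(image, [left, right, top, bottom], value=0.0)


image = torch.ones(3, 40, 50)
padded = pad_ltrb(image, [87, 92, 87, 92])
print(padded.shape)  # torch.Size([3, 224, 224])
```

The call goes straight to constant-mode padding, with none of the argument validation that a public entry point would perform.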

Here are the V1 (old) vs V2 (new) benchmarks on the latest main, before this PR:

[-------------- CenterCrop cpu torch.float32 -------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   42 (+-  0) us  |   57 (+-  0) us
      (16, 3, 40, 50)  |  528 (+-  1) us  |  555 (+-  2) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   88 (+-  1) us  |  104 (+-  3) us
      (16, 3, 40, 50)  |  615 (+-  5) us  |  643 (+- 19) us

Times are in microseconds (us).

[------------- CenterCrop cuda torch.float32 -------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   48 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   48 (+-  0) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  1) us  |   48 (+-  1) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   48 (+-  1) us

Times are in microseconds (us).

[--------------- CenterCrop cpu torch.uint8 --------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   68 (+-  0) us  |   77 (+-  0) us
      (16, 3, 40, 50)  |  760 (+-  4) us  |  785 (+-  4) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |  106 (+-  1) us  |  116 (+-  2) us
      (16, 3, 40, 50)  |  837 (+-  6) us  |  861 (+- 18) us

Times are in microseconds (us).

[-------------- CenterCrop cuda torch.uint8 --------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   48 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   48 (+-  0) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  1) us  |   48 (+-  1) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   48 (+-  1) us

Times are in microseconds (us).

Here are the results after the PR:

[-------------- CenterCrop cpu torch.float32 -------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   45 (+-  0) us  |   42 (+-  0) us
      (16, 3, 40, 50)  |  525 (+-  2) us  |  535 (+-  1) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   88 (+-  3) us  |   88 (+-  1) us
      (16, 3, 40, 50)  |  615 (+- 23) us  |  626 (+- 18) us

Times are in microseconds (us).

[------------- CenterCrop cuda torch.float32 -------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   35 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   35 (+-  0) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   35 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   35 (+-  0) us

Times are in microseconds (us).

[--------------- CenterCrop cpu torch.uint8 --------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   68 (+-  0) us  |   65 (+-  0) us
      (16, 3, 40, 50)  |  763 (+-  4) us  |  773 (+-  6) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |  106 (+-  3) us  |  104 (+-  2) us
      (16, 3, 40, 50)  |  839 (+- 22) us  |  851 (+- 23) us

Times are in microseconds (us).

[-------------- CenterCrop cuda torch.uint8 --------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   34 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   35 (+-  0) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   35 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   35 (+-  1) us

Times are in microseconds (us).
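As a quick worked example, the relative change in the "new" column can be computed from the two sets of tables. The numbers below are copied from the single-thread rows for input (3, 40, 50); `relative_reduction` is a hypothetical helper name.

```python
# (before PR, after PR) timings in us of the "new" column, 1 thread, input (3, 40, 50)
new_timings_us = {
    "cpu float32": (57, 42),
    "cuda float32": (48, 35),
    "cpu uint8": (77, 65),
    "cuda uint8": (48, 34),
}


def relative_reduction(before: float, after: float) -> float:
    # Fraction of the original runtime that was removed.
    return (before - after) / before


for name, (before, after) in new_timings_us.items():
    print(f"{name}: {before} -> {after} us ({relative_reduction(before, after):.1%} lower)")
```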

This benchmark forces the padding branch by using images of size (40, 50) and requesting a crop of (224, 224).
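To see why these sizes hit the padding branch, the per-side padding can be worked out by hand. This is a sketch under the assumption that the deficit is split evenly, with any odd pixel going to the right/bottom; `center_pad_ltrb` is a hypothetical stand-in for `_center_crop_compute_padding`, not its actual source.

```python
def center_pad_ltrb(crop_height: int, crop_width: int, image_height: int, image_width: int) -> list:
    # How many pixels are missing along each axis (0 if the image is large enough).
    pad_w = max(crop_width - image_width, 0)
    pad_h = max(crop_height - image_height, 0)
    # Order is [left, top, right, bottom].
    return [pad_w // 2, pad_h // 2, (pad_w + 1) // 2, (pad_h + 1) // 2]


print(center_pad_ltrb(224, 224, 40, 50))  # [87, 92, 87, 92]
```

With 87 pixels on each side of the width (50 + 174 = 224) and 92 on each side of the height (40 + 184 = 224), the padded image already matches the requested crop size.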

@datumbox datumbox changed the title [WIP] [prototype] Optimize Center Crop performance [prototype] Optimize Center Crop performance Nov 1, 2022
Collaborator

@vfdev-5 vfdev-5 left a comment

LGTM, thanks!

@datumbox datumbox merged commit d95fbaf into pytorch:main Nov 1, 2022
@datumbox datumbox deleted the prototype/center_crop_opts branch November 1, 2022 16:56
facebook-github-bot pushed a commit that referenced this pull request Nov 4, 2022
Summary:
* Reducing unnecessary method calls

* Optimize pad branch.

* Remove unnecessary call to crop_image_tensor

* Fix linter

Reviewed By: datumbox

Differential Revision: D41020555

fbshipit-source-id: 55d55d80993830d0b70ad4140d55fab2cba9d21e

Co-authored-by: vfdev <vfdev.5@gmail.com>