
Fixing interpolate on uint8 unsqueezed 3D CL tensor #100258

Closed

Conversation

vfdev-5
Collaborator

@vfdev-5 vfdev-5 commented Apr 28, 2023

Description:

  • Fixed a memory format bug:

When the input is a channels-last 4D tensor produced as follows:

```
t = torch.ones(1, 3, 32, 32).contiguous(memory_format=torch.channels_last)
t = t[0]
t = t[None, ...]
```

then, due to a surprising behaviour of suggest_memory_format(), the output tensor is allocated as contiguous instead of channels_last. Until now, the uint8 AVX bilinear interpolation code assumed that the output format was the same as the input format; this is unfortunately not the case here, and the output values ended up completely wrong.
This PR is a band-aid fix that addresses the issue in the short term by converting the output from channels_last to contiguous when needed, to match the (incorrect) format suggested by suggest_memory_format().
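The failure mode itself can be illustrated without torch: if a kernel writes its result in channels_last (NHWC) order into a buffer that the caller then reads as contiguous (NCHW), the values come out scrambled. A stdlib-only sketch (the helper name is illustrative, not a torch API):

```python
# Reorder a flat NCHW value list into NHWC order -- i.e. what lands in
# memory when a kernel writes channels_last output.
def nchw_to_flat_nhwc(data, c, h, w):
    out = []
    for y in range(h):
        for x in range(w):
            for ch in range(c):
                out.append(data[ch * h * w + y * w + x])
    return out

c, h, w = 2, 2, 2
nchw = list(range(c * h * w))  # values 0..7, intended NCHW layout

# If these NHWC-ordered values are then interpreted as NCHW, almost
# every element is read back in the wrong place:
print(nchw_to_flat_nhwc(nchw, c, h, w))  # [0, 4, 1, 5, 2, 6, 3, 7]
```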

Here is repro code showing that the nightly build is broken for this particular case:

```
import torch

torch.manual_seed(0)

input = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8).contiguous(memory_format=torch.channels_last)
input = input[0]
input = input[None, ...]

assert input.is_contiguous(memory_format=torch.channels_last)

output = torch.nn.functional.interpolate(input, (224, 224), mode="bilinear", antialias=True)
expected = torch.nn.functional.interpolate(input.float(), (224, 224), mode="bilinear", antialias=True)

assert output.is_contiguous()
assert expected.is_contiguous()

torch.testing.assert_close(expected, output.float(), atol=1, rtol=1)
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
#   File "/pytorch/torch/testing/_comparison.py", line 1511, in assert_close
#     raise error_metas[0].to_error(msg)
# AssertionError: Tensor-likes are not close!
#
# Mismatched elements: 14120 / 150528 (9.4%)
# Greatest absolute difference: 214.6112518310547 at index (0, 1, 152, 13) (up to 1 allowed)
# Greatest relative difference: 17.005144119262695 at index (0, 2, 26, 2) (up to 1 allowed)
```
  • Also renamed needs_unpacking to skip_unpacking.
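The layout ambiguity behind suggest_memory_format()'s surprising suggestion can be sketched with plain stride arithmetic (stdlib-only; the helper names are illustrative, not torch APIs):

```python
# channels_last (NHWC) strides for an NCHW-shaped tensor: (C*H*W, 1, W*C, C).
def channels_last_strides(n, c, h, w):
    return (c * h * w, 1, w * c, c)

# A size-1 dimension can carry any stride without changing the memory
# layout, so layout checks conventionally skip such dimensions.
def is_channels_last_like(shape, strides):
    n, c, h, w = shape
    expected = channels_last_strides(n, c, h, w)
    return all(size == 1 or actual == want
               for size, actual, want in zip(shape, strides, expected))

# t = torch.ones(1, 3, 32, 32).contiguous(memory_format=torch.channels_last)
shape, strides = (1, 3, 32, 32), channels_last_strides(1, 3, 32, 32)

# t = t[0] drops the leading dim; t = t[None, ...] re-adds a size-1 dim
# whose stride value is essentially arbitrary.
shape3, strides3 = shape[1:], strides[1:]
shape4, strides4 = (1,) + shape3, (strides3[0],) + strides3

# The round-tripped tensor still satisfies the channels_last layout test,
# yet ATen's suggest_memory_format() suggests contiguous for it -- the
# mismatch this PR works around.
print(is_channels_last_like(shape4, strides4))  # True
```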

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @NicolasHug

@pytorch-bot

pytorch-bot bot commented Apr 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100258

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 39d59f6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: nn release notes category label Apr 28, 2023
@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Apr 28, 2023
@vfdev-5
Collaborator Author

vfdev-5 commented May 1, 2023

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased fix-interp-uint8-CL-sq-unsq-input onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix-interp-uint8-CL-sq-unsq-input && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the fix-interp-uint8-CL-sq-unsq-input branch from 4f1a315 to 2460501 Compare May 1, 2023 12:26
Member

@NicolasHug NicolasHug left a comment


Thanks @vfdev-5 , made a minor suggestion to save a copy (I think it works?). LMK what you think

Review suggestions on aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h (outdated, resolved)
@vfdev-5
Collaborator Author

vfdev-5 commented May 2, 2023

@NicolasHug thanks for the review. I pushed a new commit with your suggestions. I also ran benchmarks to make sure there is no regression; the new code's results are comparable (+/- noise in measurements):

[------------------------------------------------------------------------------------ Resize -----------------------------------------------------------------------------------]
                                                                                 |  torch (2.1.0a0+git642ff29) PR  |  Pillow (9.0.0.post1)  |  torch (2.1.0a0+git2b75955) nightly
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_first bilinear (256, 256) -> (32, 32) aa=True       |        130.454 (+-1.253)        |    38.932 (+-0.141)    |          130.156 (+-1.198)
      3 torch.uint8 channels_first bilinear (256, 256) -> (32, 32) aa=False      |        113.436 (+-0.886)        |                        |          113.090 (+-0.845)
      3 torch.uint8 channels_first bilinear (256, 256) -> (224, 224) aa=True     |        299.910 (+-1.663)        |   128.562 (+-1.052)    |          297.497 (+-1.513)
      3 torch.uint8 channels_first bilinear (256, 256) -> (224, 224) aa=False    |        289.811 (+-1.445)        |                        |          284.572 (+-1.724)
      3 torch.uint8 channels_first bilinear (256, 256) -> (320, 320) aa=True     |        435.729 (+-2.656)        |   177.710 (+-1.145)    |          428.966 (+-2.281)
      3 torch.uint8 channels_first bilinear (256, 256) -> (320, 320) aa=False    |        431.445 (+-7.529)        |                        |          427.157 (+-3.360)
      3 torch.uint8 channels_first bilinear (520, 520) -> (32, 32) aa=True       |        436.363 (+-15.860)       |   113.192 (+-1.564)    |          435.647 (+-2.457)
      3 torch.uint8 channels_first bilinear (520, 520) -> (32, 32) aa=False      |        367.168 (+-2.358)        |                        |          368.374 (+-2.392)
      3 torch.uint8 channels_first bilinear (520, 520) -> (224, 224) aa=True     |        670.886 (+-2.795)        |   281.037 (+-1.682)    |          672.237 (+-3.140)
      3 torch.uint8 channels_first bilinear (520, 520) -> (224, 224) aa=False    |        604.945 (+-2.238)        |                        |          599.064 (+-2.170)
      3 torch.uint8 channels_first bilinear (712, 712) -> (32, 32) aa=True       |        773.890 (+-8.635)        |   186.807 (+-1.402)    |          773.957 (+-2.273)
      3 torch.uint8 channels_first bilinear (712, 712) -> (32, 32) aa=False      |        653.021 (+-2.553)        |                        |          652.533 (+-2.453)
      3 torch.uint8 channels_first bilinear (712, 712) -> (224, 224) aa=True     |       1076.630 (+-14.478)       |   409.352 (+-1.459)    |         1077.477 (+-27.304)
      3 torch.uint8 channels_first bilinear (712, 712) -> (224, 224) aa=False    |        936.126 (+-4.393)        |                        |          927.276 (+-3.197)
      3 torch.uint8 channels_first bilinear (64, 64) -> (224, 224) aa=True       |        162.954 (+-1.953)        |                        |          161.050 (+-1.879)
      3 torch.uint8 channels_first bilinear (224, 224) -> (270, 268) aa=True     |        326.284 (+-2.572)        |                        |          321.202 (+-1.606)
      3 torch.uint8 channels_first bilinear (256, 256) -> (1024, 1024) aa=True   |       3038.104 (+-297.238)      |                        |         2611.023 (+-15.586)
      3 torch.uint8 channels_first bilinear (224, 224) -> (64, 64) aa=True       |        132.138 (+-0.999)        |                        |          132.621 (+-0.996)
      3 torch.uint8 channels_first bilinear (270, 268) -> (224, 224) aa=True     |        316.963 (+-2.040)        |                        |          316.663 (+-6.584)
      3 torch.uint8 channels_first bilinear (1024, 1024) -> (256, 256) aa=True   |       2117.061 (+-133.743)      |                        |         2444.293 (+-250.152)
      3 torch.uint8 channels_first bilinear (64, 64) -> (224, 224) aa=False      |        161.159 (+-1.408)        |                        |          159.147 (+-1.471)
      3 torch.uint8 channels_first bilinear (224, 224) -> (270, 268) aa=False    |        324.444 (+-1.822)        |                        |          319.925 (+-1.869)
      3 torch.uint8 channels_first bilinear (256, 256) -> (1024, 1024) aa=False  |       2664.916 (+-42.042)       |                        |         2793.362 (+-62.544)
      3 torch.uint8 channels_first bilinear (224, 224) -> (64, 64) aa=False      |        114.110 (+-1.273)        |                        |          114.003 (+-0.867)
      3 torch.uint8 channels_first bilinear (270, 268) -> (224, 224) aa=False    |        302.792 (+-2.563)        |                        |          298.243 (+-1.933)
      3 torch.uint8 channels_first bilinear (1024, 1024) -> (256, 256) aa=False  |       1760.817 (+-68.172)       |                        |         1993.225 (+-119.838)
      4 torch.uint8 channels_first bilinear (256, 256) -> (32, 32) aa=True       |         72.175 (+-0.282)        |                        |           71.982 (+-0.401)
      4 torch.uint8 channels_first bilinear (256, 256) -> (32, 32) aa=False      |         55.412 (+-0.222)        |                        |           54.667 (+-0.204)
      4 torch.uint8 channels_first bilinear (256, 256) -> (224, 224) aa=True     |        254.876 (+-1.962)        |                        |          251.780 (+-1.419)
      4 torch.uint8 channels_first bilinear (256, 256) -> (224, 224) aa=False    |        246.226 (+-2.131)        |                        |          239.121 (+-1.706)
      4 torch.uint8 channels_first bilinear (256, 256) -> (320, 320) aa=True     |        408.890 (+-4.142)        |                        |          401.858 (+-3.618)
      4 torch.uint8 channels_first bilinear (256, 256) -> (320, 320) aa=False    |        409.223 (+-3.658)        |                        |          397.509 (+-3.603)
      4 torch.uint8 channels_first bilinear (520, 520) -> (32, 32) aa=True       |        194.260 (+-1.202)        |                        |          195.091 (+-1.881)
      4 torch.uint8 channels_first bilinear (520, 520) -> (32, 32) aa=False      |        126.143 (+-0.685)        |                        |          126.012 (+-0.718)
      4 torch.uint8 channels_first bilinear (520, 520) -> (224, 224) aa=True     |        443.454 (+-2.013)        |                        |          444.901 (+-2.627)
      4 torch.uint8 channels_first bilinear (520, 520) -> (224, 224) aa=False    |        379.477 (+-1.950)        |                        |          371.343 (+-1.855)
      4 torch.uint8 channels_first bilinear (712, 712) -> (32, 32) aa=True       |        322.995 (+-10.864)       |                        |          323.207 (+-1.854)
      4 torch.uint8 channels_first bilinear (712, 712) -> (32, 32) aa=False      |        200.938 (+-1.068)        |                        |          201.037 (+-1.042)
      4 torch.uint8 channels_first bilinear (712, 712) -> (224, 224) aa=True     |        641.310 (+-5.975)        |                        |          638.586 (+-3.610)
      4 torch.uint8 channels_first bilinear (712, 712) -> (224, 224) aa=False    |        505.239 (+-6.366)        |                        |          488.475 (+-2.911)
      4 torch.uint8 channels_first bilinear (64, 64) -> (224, 224) aa=True       |        172.068 (+-1.276)        |                        |          171.226 (+-1.142)
      4 torch.uint8 channels_first bilinear (224, 224) -> (270, 268) aa=True     |        300.086 (+-1.594)        |                        |          296.751 (+-2.027)
      4 torch.uint8 channels_first bilinear (256, 256) -> (1024, 1024) aa=True   |       4411.853 (+-113.128)      |                        |         2929.503 (+-21.368)
      4 torch.uint8 channels_first bilinear (224, 224) -> (64, 64) aa=True       |         88.924 (+-0.438)        |                        |           89.530 (+-0.358)
      4 torch.uint8 channels_first bilinear (270, 268) -> (224, 224) aa=True     |        266.332 (+-1.870)        |                        |          265.148 (+-2.375)
      4 torch.uint8 channels_first bilinear (1024, 1024) -> (256, 256) aa=True   |       1421.863 (+-174.917)      |                        |         1197.055 (+-45.796)
      4 torch.uint8 channels_first bilinear (64, 64) -> (224, 224) aa=False      |        172.632 (+-1.933)        |                        |          169.611 (+-1.288)
      4 torch.uint8 channels_first bilinear (224, 224) -> (270, 268) aa=False    |        300.601 (+-2.266)        |                        |          294.500 (+-1.712)
      4 torch.uint8 channels_first bilinear (256, 256) -> (1024, 1024) aa=False  |       3022.943 (+-89.118)       |                        |         2931.171 (+-87.910)
      4 torch.uint8 channels_first bilinear (224, 224) -> (64, 64) aa=False      |         70.667 (+-0.302)        |                        |           70.614 (+-0.377)
      4 torch.uint8 channels_first bilinear (270, 268) -> (224, 224) aa=False    |        253.305 (+-1.862)        |                        |          246.772 (+-1.822)
      4 torch.uint8 channels_first bilinear (1024, 1024) -> (256, 256) aa=False  |       1136.408 (+-197.936)      |                        |          911.411 (+-4.646)

Times are in microseconds (us).

@vfdev-5 vfdev-5 force-pushed the fix-interp-uint8-CL-sq-unsq-input branch from 642ff29 to 39d59f6 Compare May 3, 2023 13:46
Member

@NicolasHug NicolasHug left a comment


Thanks a lot for the fix @vfdev-5 , LGTM.

@vfdev-5
Collaborator Author

vfdev-5 commented May 4, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 4, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@NicolasHug
Member

For completeness: on top of the benchmarks above (#100258 (comment)) showing no regression for "regular" 4D tensors, @vfdev-5 also benchmarked the sliced-and-unsqueezed channels_last tensors affected by this issue, i.e. those whose output gets allocated as contiguous instead of channels_last (the case this PR addresses). Those results show that we now pay a ~2X slow-down for such tensors.

It's worth stressing that the 2X slow-down comes from the unexpected behaviour of suggest_memory_format(), which suggests the contiguous format when (as far as we understand) it should suggest channels_last; it is not due to this PR doing anything unexpected or inefficient.

Hopefully, #100373 will help fix this problem at its root, which would let us revert this PR and win back the 2X perf we just lost.

@vfdev-5 vfdev-5 deleted the fix-interp-uint8-CL-sq-unsq-input branch May 4, 2023 14:28