Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up equalize transform: use bincount instead of histc #3493

Merged
merged 8 commits into from
Mar 8, 2021

Conversation

NicolasHug
Copy link
Member

@NicolasHug NicolasHug commented Mar 3, 2021

bincount is faster than histc on CPU, so this should speed up (~2X) the equalize transforms.

This is sort of a follow-up to #3334 and #3173

On GPU:

In [16]: img = torch.randint(0, 256, size=(460 * 680,)).to('cuda')

In [17]: %timeit torch.bincount(img, minlength=256)
1.52 ms ± 25.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [18]: %timeit torch.histc(img.to(torch.float32), bins=256, min=0, max=255)
127 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

on CPU (different machine):

img = torch.randint(0, 256, size=(460 * 680,))

%timeit torch.histc(img.to(torch.float32), bins=256, min=0, max=255)
696 µs ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit torch.bincount(img, minlength=256)
289 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

@NicolasHug NicolasHug changed the title WIP use bincount instead of histc [TRANS, IMP] Speed up equalize transform: use bincount instead of histc Mar 3, 2021
@codecov
Copy link

codecov bot commented Mar 3, 2021

Codecov Report

Merging #3493 (c294ec0) into master (414427d) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master    #3493   +/-   ##
=======================================
  Coverage   78.75%   78.75%           
=======================================
  Files         105      105           
  Lines        9748     9750    +2     
  Branches     1565     1566    +1     
=======================================
+ Hits         7677     7679    +2     
  Misses       1581     1581           
  Partials      490      490           
Impacted Files Coverage Δ
torchvision/transforms/functional_tensor.py 79.92% <100.00%> (+0.07%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 414427d...c294ec0. Read the comment docs.

@fmassa
Copy link
Member

fmassa commented Mar 3, 2021

There is a huge difference on CUDA. I think it would be worth opening an issue in PyTorch to point out that the implementation of bincount is suboptimal?

Also, it would be good to check the number of CUDA again, but this time using torch.cuda.synchronize() to get more accurate timings

if img_chan.is_cuda:
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
else:
hist = torch.bincount(img_chan.view(-1), minlength=256)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to clip values between 0 and 255 to keep things consistent between cuda/cpu?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing minlength=256 ensures that bincount will count the number of values for all values in [0, 255], so histc and bincount would be equivalent.

In the histc call we have bins=256, min=0, max=255, which IIUC assumes that the image is within [0, 255] already, is this not the case?

@NicolasHug NicolasHug changed the title [TRANS, IMP] Speed up equalize transform: use bincount instead of histc Speed up equalize transform: use bincount instead of histc Mar 8, 2021
@NicolasHug
Copy link
Member Author

I created a new "Improvement" label, in accordance with #3351 (comment)

Copy link
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@fmassa fmassa merged commit 77e4187 into pytorch:master Mar 8, 2021
@NicolasHug NicolasHug added the Perf For performance improvements label Mar 9, 2021
facebook-github-bot pushed a commit that referenced this pull request Mar 10, 2021
…3493)

Summary:
* use bincount instead of hist

* only use bincount when on CPU

* Added equality test for CPU vs cuda

* Fix flake8 and tests

* tuple instead of int for size

Reviewed By: NicolasHug, cpuhrsch

Differential Revision: D26945736

fbshipit-source-id: 5b13a01e1b04d8c92317d3478bf9c9bb1c7d1375
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants