Improve compare kernel #29743
Conversation
The PR looks good. You are right, the type promotion logic is becoming very complicated, so it would be good to have a benchmarking script (`/benchmarks` looks like a natural place).
Will merge conflict with #30065
@VitalyFedyunin I just saw #30065. I suggest landing #30065 first, because I feel it might be easier to fix the merge conflicts here. (The merge conflict would tell me about …)
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Currently, the way the compare kernels handle dtypes is very funny (this behavior was introduced in #28427 and I only realized it today):

Let's say `a`, `b` are two float tensors on CUDA. If you do `a < b`, this is what would happen inside the loop:

- load `a` and `b`, dynamically cast them from `float` to `float` (i.e. check the scalar type to figure out if it needs a cast; it doesn't, so do nothing)
- compute `a < b`, get a `bool` result
- convert the `bool` result to `float`, then dynamically cast it from `float` to `bool` and store the value

And if you do `a.lt_(b)`, this is what would happen:

- load `a` and `b`, no casting
- compute `a < b`, get a `bool` result
- convert the result to `float` and store the value
Although the dynamic casting happens in registers, it still hurts performance a bit (~8%).

This PR fixes the issue. Now, for compare kernels, if the output is `bool` and the inputs have the same dtype, there is no dynamic casting; otherwise, there is dynamic casting for each input and output. That is, the dynamic-casting behavior of the two cases described above is swapped.
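The new rule can be sketched as a small predicate (the helper name is hypothetical, not the actual TensorIterator API):

```python
# Hypothetical helper sketching the rule this PR introduces for compare
# kernels: skip dynamic casting exactly when the output is bool and all
# inputs share one dtype; otherwise dynamically cast every input and output.

def needs_dynamic_casting(input_dtypes, output_dtype):
    """Return True if the compare kernel must dynamically cast operands."""
    same_input_dtype = len(set(input_dtypes)) == 1
    return not (output_dtype == "bool" and same_input_dtype)

# Fast path: float < float -> bool no longer does any dynamic casting.
print(needs_dynamic_casting(["float", "float"], "bool"))   # False
# a.lt_(b) writes a float output, so it now takes the casting path.
print(needs_dynamic_casting(["float", "float"], "float"))  # True
```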
Benchmark on `a < b` for a tensor of 1000000000 fp32 elements:

- Before #28427: 6.35 ms
- Current master: 6.88 ms
- With this PR: 6.36 ms
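For reference, the relative regression and recovery implied by these timings (computed from the numbers quoted above):

```python
# Relative overhead of the dynamic-casting path, derived from the
# benchmark timings quoted above (all in ms).
before_28427, master, this_pr = 6.35, 6.88, 6.36

regression = (master - before_28427) / before_28427 * 100
recovered = (master - this_pr) / master * 100
print(f"master vs before #28427: +{regression:.1f}%")  # the ~8% mentioned above
print(f"this PR vs master:       -{recovered:.1f}%")
```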
Benchmark on `a.lt_(b)` does not show any difference across versions.

Besides this, what worries me most is that, with type promotion, the logic of the tensor iterator is becoming super complicated, and it is hard to see whether one change causes a performance regression elsewhere. I suggest we create a script that benchmarks the tensor iterator as a whole, review that code, and put it somewhere inside the repository (maybe under `/tools` or `/test/scripts`?), so that whenever we are not certain about performance we can run it to check. (I guess not on this PR, but on PRs after the script is done: if there are worries about performance, the PR author should run the script manually, and the reviewer should remind the author to do so if necessary.) If this is a good idea, I will send a PR for the script.
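A minimal, stdlib-only sketch of what such a benchmarking script could look like (the structure and names are my assumption, not an agreed design); a real version would register TensorIterator-backed workloads such as `a < b` across dtype combinations instead of the toy workload below:

```python
# Sketch of a tiny benchmark harness (hypothetical design, stdlib only).
# A real script under /tools or /test/scripts would register tensor-iterator
# workloads (e.g. lambda: a < b for various dtype combinations) instead of
# the toy list comprehension used here.
import timeit

def bench(name, fn, number=1000, repeat=5):
    """Run fn `number` times per trial; return the best trial's per-call ms."""
    best = min(timeit.repeat(fn, number=number, repeat=repeat))
    per_call_ms = best / number * 1e3
    print(f"{name:>20}: {per_call_ms:.4f} ms/call")
    return per_call_ms

# Toy stand-in workload; swap in real tensor ops when running against torch.
xs = list(range(1000))
t = bench("toy compare", lambda: [x < 500 for x in xs])
```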