try to make at::cat in mm_tree_reduction operate on contig tensors #18816
Conversation
Great patch! I think it would be nice to add commentary to both the matmul->mm and the cat optimizations. Also, I'm not quite sure about dropping b_ih in the cell but not in the arguments.
benchmarks/fastrnns/cells.py (outdated)
Shouldn't we either take b_ih out of the arguments or keep it here?
torch/csrc/jit/passes/batch_mm.cpp (outdated)
Since we know that they're 2D, can't we just do the stride check manually? .t().is_contiguous() is quite convenient, but it allocates a whole new tensor, which is total overkill for this case.
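For readers following along, here is a minimal sketch of the kind of manual stride check being suggested (the helper name is hypothetical and not taken from this PR; it assumes non-degenerate 2D sizes, under which t.t().is_contiguous() reduces to two stride comparisons):

#include <ATen/ATen.h>

// Hypothetical helper: true when `t` is the transpose of a contiguous 2D
// matrix, i.e. what t.t().is_contiguous() would report, but computed directly
// from the strides without creating the intermediate .t() view.
bool is_transposed_2d(const at::Tensor& t) {
  return t.dim() == 2 && t.stride(0) == 1 && t.stride(1) == t.size(0);
}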
torch/csrc/jit/passes/batch_mm.cpp (outdated)
nit: this is just
return fmap(inputs, [](const at::Tensor& i) { return i.t(); });
torch/csrc/jit/passes/batch_mm.cpp (outdated)
This requires a change of the function name, because you have completely changed the semantics. Is it really slower if the strides are not all the same, or is it just a guess?
It's so that in the transpose check I could check the strides of only the first tensor, and honestly it's hard to imagine a graph that would have tensors of the same sizes eligible for tree reduction, but with different strides. In any case, I'll just make the transpose check go over all the tensors and leave have_same_shape alone.
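A rough sketch of that direction (hypothetical names, not the actual batch_mm.cpp code): apply the same per-tensor stride check to every input instead of only the first one.

#include <ATen/ATen.h>
#include <vector>

// Returns true only if every input is a 2D transposed-contiguous matrix.
bool all_inputs_transposed(const std::vector<at::Tensor>& inputs) {
  for (const at::Tensor& t : inputs) {
    if (t.dim() != 2 || t.stride(0) != 1 || t.stride(1) != t.size(0)) {
      return false;
    }
  }
  return true;
}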
benchmarks/fastrnns/factory.py (outdated)
Honestly, I think we should add a new benchmark to see the effect of this change. We've been using this one for quite a while, so I'd rather keep its meaning consistent with what people expect.
The effect is on the order of a couple percent; I'll add a separate benchmark.
Actually, I'm getting a crazy big (7%) improvement from adding the bias on my current system, but it's very system-dependent.
root@7a3abf660096:/workspace/ALL/pytorch_upstream/benchmarks# python -m benchmarks.fastrnns.bench --group rnns --inputSize 1024 --hiddenSize 1024 --rnns jit_premul jit_premul_bias jit cudnn --nloops 100
Namespace(cnns=None, device='cuda', group=['rnns'], hiddenSize=1024, inputSize=1024, miniBatch=64, nloops=100, numLayers=1, print_json=False, rnns=['jit_premul', 'jit_premul_bias', 'jit', 'cudnn'], sep=' ', seqLength=100, variable_lstms=False, warmup=10)
Benchmarking LSTMs...
name               avg_fwd   std_fwd   avg_bwd   std_bwd
jit_premul         10.52     0.02504   21.2      1.051
jit_premul_bias    10.7      0.04063   19.62     0.2769
jit                11.49     0.02089   20.96     0.2493
cudnn              9.815     0.04521   18.98     0.09339
@pytorchbot retest this please
@apaszke can you please take a look? CI failures look unrelated.
Looks like comments are all addressed and the code seems fine to me.
@zdevito is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@wanchaol has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@wanchaol is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
try to make at::cat in mm_tree_reduction operate on contig tensors (pytorch#18816)

Summary: Sometimes at::cat gets transposed inputs and goes on a slow path. Also, make jit_premul lstm benchmark add bias to the whole input tensor to avoid separate reduction kernels in the backward pass.

Pull Request resolved: pytorch#18816
Differential Revision: D15013576
Pulled By: wanchaol
fbshipit-source-id: bcfa1cf44180b11b05b0f55f034707012f66281a
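To illustrate the idea in the summary (a hedged sketch with a hypothetical helper name, not the actual batch_mm.cpp code): when the matrices being concatenated are transposed views, at::cat has to gather strided data; since .t() is a cheap view, concatenating the transposed views along the other dimension and transposing the result yields the same matrix while letting at::cat copy contiguous rows.

#include <ATen/ATen.h>
#include <vector>

// Hypothetical helper: equivalent to at::cat(mats, /*dim=*/1) when every
// element of `mats` is the transpose of a contiguous 2D matrix, but performs
// the actual concatenation over contiguous inputs.
at::Tensor cat_columns_via_transpose(const std::vector<at::Tensor>& mats) {
  std::vector<at::Tensor> transposed;
  transposed.reserve(mats.size());
  for (const at::Tensor& m : mats) {
    transposed.push_back(m.t());  // cheap view; contiguous if m was transposed
  }
  return at::cat(transposed, /*dim=*/0).t();
}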