
Conversation

@ngimel (Collaborator) commented Apr 3, 2019

Sometimes at::cat gets transposed inputs and falls onto a slow path. Also, make the jit_premul LSTM benchmark add the bias to the whole input tensor, to avoid separate reduction kernels in the backward pass.
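
For context, a minimal sketch of the bias change, assuming the usual premultiplied-LSTM setup (the function name and shapes below are illustrative, not the benchmark's actual code): the input-to-hidden bias is added once to the big premultiplied tensor covering all timesteps, so its gradient in the backward pass is a single reduction rather than one reduction kernel per timestep.

#include <ATen/ATen.h>

// Illustrative sketch (names and shapes are assumptions, not the benchmark code):
// premultiply the inputs for all timesteps in one matmul and fold b_ih into that
// result, so the backward pass reduces over one large tensor instead of launching
// a separate bias-reduction kernel per timestep inside the cell.
at::Tensor premul_inputs_with_bias(const at::Tensor& input,  // (seq_len, batch, input_size)
                                   const at::Tensor& w_ih,   // (4*hidden_size, input_size)
                                   const at::Tensor& b_ih) { // (4*hidden_size)
  const auto seq_len = input.size(0);
  const auto batch = input.size(1);
  const auto input_size = input.size(2);
  // addmm broadcasts b_ih over all seq_len*batch rows of the matmul result.
  auto igates = at::addmm(b_ih, input.reshape({seq_len * batch, input_size}), w_ih.t());
  return igates.reshape({seq_len, batch, w_ih.size(0)});
}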

@ngimel ngimel requested a review from apaszke April 3, 2019 21:22
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Apr 3, 2019
@soumith soumith requested a review from wanchaol April 3, 2019 21:40
@t-vi (Collaborator) left a comment

Great patch! I think it would be nice to add a comment explaining both the matmul->mm and the cat optimizations.
Also, I'm not quite sure about dropping b_ih in the cell but not in the arguments.

Collaborator:

Shouldn't we either take b_ih out of the arguments or keep it here?

Contributor:

Since we know they're 2D, can't we just do the stride check manually? .t().is_contiguous() is quite convenient, but it allocates a whole new tensor, which is total overkill for this case.
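
For illustration, the kind of manual check being suggested could look like the sketch below (the helper name is made up, not the PR's actual code): a 2D tensor that is the transpose of a contiguous matrix is laid out column-major, which can be read straight off its strides.

#include <ATen/ATen.h>

// Hypothetical helper: true iff `t` looks like the transpose of a row-major
// contiguous 2D tensor, i.e. it is column-major. Degenerate sizes (0 or 1
// along a dimension) are ignored for simplicity.
bool is_transposed_2d(const at::Tensor& t) {
  return t.dim() == 2 && t.stride(0) == 1 && t.stride(1) == t.size(0);
}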

Contributor:

nit: this is just

return fmap(inputs, [](const at::Tensor& i) { return i.t(); });

Contributor:

This requires a change of the function name, because you have completely changed the semantics. Is it really slower if the strides are not all the same, or is that just a guess?

Collaborator (Author):

It's so that in the transpose check I could look at the strides of only the first tensor, and honestly it's hard to imagine a graph that would have tensors of the same sizes, eligible for tree reduction, but with different strides. In any case, I'll just make the transpose check go over all the tensors and leave have_same_shape alone.
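
A sketch of what "go over all the tensors" could look like, reusing the hypothetical column-major check from above (again illustrative, not the PR's exact code):

#include <vector>
#include <ATen/ATen.h>

// Illustrative sketch: only take the transposed-cat path when every input is a
// column-major 2D tensor, so no assumption about shared strides is needed.
bool all_inputs_transposed(const std::vector<at::Tensor>& inputs) {
  for (const auto& t : inputs) {
    if (!(t.dim() == 2 && t.stride(0) == 1 && t.stride(1) == t.size(0))) {
      return false;
    }
  }
  return true;
}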

Contributor:

Honestly, I think we should add a new benchmark to see the effect of this change. We've been using this one for quite a while, so I'd rather keep its meaning consistent with what people expect.

Collaborator (Author):

The effect is on the order of a couple of percent; I'll add a separate benchmark.

Collaborator (Author):

Actually, I'm getting a crazy big (7%) improvement from adding the bias on my current system, but it's very system-dependent.

root@7a3abf660096:/workspace/ALL/pytorch_upstream/benchmarks# python -m benchmarks.fastrnns.bench --group rnns --inputSize 1024 --hiddenSize 1024 --rnns jit_premul jit_premul_bias jit cudnn --nloops 100
Namespace(cnns=None, device='cuda', group=['rnns'], hiddenSize=1024, inputSize=1024, miniBatch=64, nloops=100, numLayers=1, print_json=False, rnns=['jit_premul', 'jit_premul_bias', 'jit', 'cudnn'], sep=' ', seqLength=100, variable_lstms=False, warmup=10)
Benchmarking LSTMs...
            name          avg_fwd          std_fwd          avg_bwd          std_bwd
      jit_premul            10.52          0.02504             21.2            1.051
 jit_premul_bias             10.7          0.04063            19.62           0.2769
             jit            11.49          0.02089            20.96           0.2493
           cudnn            9.815          0.04521            18.98          0.09339

@ngimel (Collaborator, Author) commented Apr 9, 2019

@pytorchbot retest this please

@ngimel (Collaborator, Author) commented Apr 10, 2019

@apaszke can you please take a look? CI failures look unrelated.

@ezyang ezyang removed their request for review April 10, 2019 20:36
@zdevito (Contributor) left a comment

Looks like comments are all addressed and the code seems fine to me.

@facebook-github-bot (Contributor) left a comment

@zdevito is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) left a comment

@wanchaol has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) left a comment

@wanchaol is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented

@wanchaol merged this pull request in 3875e1b.

zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request May 6, 2019
…ytorch#18816)

Summary:
Sometimes at::cat gets transposed inputs and goes on a slow path. Also, make jit_premul lstm benchmark add bias to the whole input tensor to avoid separate reduction kernels in the backward pass.
Pull Request resolved: pytorch#18816

Differential Revision: D15013576

Pulled By: wanchaol

fbshipit-source-id: bcfa1cf44180b11b05b0f55f034707012f66281a

Labels

oncall: jit, open source


6 participants