Optimize backward for torch.repeat #46726
Conversation
Change looks good. Just small comments.
Tests (using x = torch.randn(3, 4, requires_grad=True)):

def test_repeat(x):
    y = x.repeat(20, 10, 15, 25)
    out = y.sum()
    out.backward()

Average time for 10 * 10000 runs.
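For context, timings like this could be gathered with a small harness along these lines (a sketch only; the author's actual measurement script is not shown, and the use of timeit here is an assumption):

```python
import timeit

import torch

x = torch.randn(3, 4, requires_grad=True)

def test_repeat(x):
    y = x.repeat(20, 10, 15, 25)
    out = y.sum()
    out.backward()

# 10 repetitions of 10000 calls each, matching the "10 * 10000 runs" above.
# Gradient accumulation on x is ignored; only wall-clock time is measured.
times = timeit.repeat(lambda: test_repeat(x), repeat=10, number=10000)
print(sum(times) / (10 * 10000))  # average seconds per forward + backward
```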
Fixes #43192 Differential Revision: [D24481801](https://our.internmc.facebook.com/intern/diff/D24481801) [ghstack-poisoned]
I'm curious to see the equivalent torch.expand / torch.repeat_interleave call for the examples in #43192 (comment). Also, it would be nice to see some benchmark figures on larger tensors (a tensor of size [3, 4] is pretty small; in practice users have larger tensors).
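For instance (an illustrative sketch, not taken from the linked comment), when only one dimension of size 1 is tiled, the three calls produce the same values:

```python
import torch

x = torch.rand(1, 1920)

a = x.repeat(1280, 1)                 # copies the row 1280 times
b = x.repeat_interleave(1280, dim=0)  # same result here because dim 0 has size 1
c = x.expand(1280, 1920)              # view only: stride 0 along dim 0, no copy

assert torch.equal(a, b) and torch.equal(a, c)
```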
this lgtm (with some minor comments). I'm curious to see some more performance numbers
Benchmark for two different implementations.

Test 1:
x = torch.randn(320, 480, requires_grad=True)
x_t = torch.randn(320, 480, requires_grad=True).t()  # non-contiguous
...
y = x.repeat(16, 16)
...

Test 2:
x = torch.randn(160, 240, requires_grad=True)
x_t = torch.randn(160, 240, requires_grad=True).t()  # non-contiguous
...
y = x.repeat(32, 32)
...

Test 3:
x = torch.randn(80, 120, requires_grad=True)
x_t = torch.randn(80, 120, requires_grad=True).t()  # non-contiguous
...
y = x.repeat(64, 64)
...

Test 4:
x = torch.randn(3, 4, requires_grad=True)
x_t = torch.randn(3, 4, requires_grad=True).t()  # non-contiguous
...
y = x.repeat(16, 32)
...

Average time for 1,000 runs (NC refers to non-contiguous):

Conclusion:
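(As an aside on the non-contiguous inputs used in these tests: `.t()` returns a transposed view with swapped strides rather than a copy, which is what makes `x_t` non-contiguous. A small illustrative check, not part of the benchmark script:)

```python
import torch

x = torch.randn(320, 480, requires_grad=True)
x_t = torch.randn(320, 480, requires_grad=True).t()  # as in Test 1 above

print(x.is_contiguous())    # True
print(x_t.is_contiguous())  # False
print(x_t.stride())         # (1, 480): strides swapped, same storage as the original
```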
Benchmark for repeat / repeat_interleave / expand:

x = torch.rand((1, 1920)).requires_grad_()
repeated = x.repeat(1280, 1)
repeated = x.repeat_interleave(1280, dim=0)
repeated = x.expand(1280, 1920)

Average time for 10,000 runs:

cc: @zou3519
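A brief note on why `expand` is so much cheaper: it returns a broadcasting view with a zero stride along the expanded dimension, while `repeat` (and `repeat_interleave`) materialize a full copy. A small sketch that makes the difference visible (illustrative only, not part of the benchmark):

```python
import torch

x = torch.rand(1, 1920)

r = x.repeat(1280, 1)     # allocates and fills a new 1280 x 1920 buffer
e = x.expand(1280, 1920)  # view over the original 1 x 1920 buffer

print(r.stride())                    # (1920, 1): contiguous copy
print(e.stride())                    # (0, 1): stride 0 along the expanded dim
print(e.data_ptr() == x.data_ptr())  # True: no data was copied
print(r.data_ptr() == x.data_ptr())  # False: repeat materialized a copy
```

Since expand's forward is pure metadata manipulation and its backward reduces to a single sum over the broadcast dimension, it is expected to be the cheapest of the three.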
The PR and the benchmark table have been updated.
lgtm!
// grad_size [4, 2, 3, 9, 4, 3, 5]
// sum_dims [0, 3, 5]
grad = grad.reshape(grad_size);
grad = grad.sum(sum_dims);
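To illustrate what these two lines do (a simplified Python sketch of the reshape-and-sum technique, not the C++ implementation itself): for an input of shape (s0, s1) repeated (r0, r1) times, the output gradient of shape (r0*s0, r1*s1) can be reshaped to (r0, s0, r1, s1) and summed over the repeat dimensions, which matches the naive backward of summing every repeated copy.

```python
import torch

def repeat_backward_2d(grad_out, input_shape, repeats):
    """Gradient of x.repeat(r0, r1) for a 2-D x of shape (s0, s1)."""
    (s0, s1), (r0, r1) = input_shape, repeats
    # Split every output dim into (repeat, size), then sum away the repeat dims.
    return grad_out.reshape(r0, s0, r1, s1).sum(dim=(0, 2))

# Sanity check against autograd.
x = torch.randn(3, 5, requires_grad=True)
y = x.repeat(4, 2)
grad_out = torch.randn_like(y)
y.backward(grad_out)
assert torch.allclose(repeat_backward_2d(grad_out, x.shape, (4, 2)), x.grad)
```

The grad_size / sum_dims comment above is the same idea generalized to arbitrary rank, with dims whose repeat count is 1 left unsplit; the quoted example appears to correspond to an input of shape [2, 3, 4, 5] repeated (4, 1, 9, 3) times.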
When we repeat with 1 in all dimensions, sum_dims becomes empty, which leads to a sum over the whole grad instead of leaving it unchanged.
#29137 strikes again
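In Python terms, the edge case looks roughly like this (a simplified sketch assuming repeats has the same length as the input shape, and assuming the behavior described in the comment above, where an empty dim list reduces over every dimension):

```python
import torch

def repeat_backward(grad_out, input_shape, repeats):
    grad_size, sum_dims = [], []
    for r, s in zip(repeats, input_shape):
        if r == 1:
            grad_size.append(s)        # repeat of 1: dim stays as-is
        else:
            sum_dims.append(len(grad_size))
            grad_size.extend([r, s])   # split dim into (repeat, size)
    grad = grad_out.reshape(grad_size)
    # Guard the all-ones case: with nothing to reduce, passing an empty dim
    # list to sum() would collapse the whole tensor instead of being a no-op.
    if sum_dims:
        grad = grad.sum(dim=sum_dims)
    return grad

g = torch.randn(3, 5)
assert torch.equal(repeat_backward(g, (3, 5), (1, 1)), g)  # all repeats are 1
assert repeat_backward(torch.randn(6, 5), (3, 5), (2, 1)).shape == (3, 5)
```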
Fixes #43192 Differential Revision: [D24739840](https://our.internmc.facebook.com/intern/diff/D24739840) [ghstack-poisoned]
If "import to phabricator" or ghimport fails, you might have to unlink this github PR from the original diff. |
Stack from ghstack:
Fixes #43192
Differential Revision: D24739840