
Conversation

@ejguan (Contributor) commented Oct 22, 2020

Stack from ghstack:

Fixes #43192

Differential Revision: D24739840

@ejguan ejguan requested review from albanD and apaszke as code owners October 22, 2020 17:58
ejguan added a commit that referenced this pull request Oct 22, 2020
ghstack-source-id: db0cc44
Pull Request resolved: #46726
@ejguan ejguan linked an issue Oct 22, 2020 that may be closed by this pull request
@albanD (Collaborator) left a comment

Change looks good. Just small comments.

@ejguan (Contributor, Author) commented Oct 22, 2020

Tests (using torch.utils.benchmark):

```python
x = torch.randn(3, 4, requires_grad=True)

def test_repeat(x):
    y = x.repeat(20, 10, 15, 25)
    out = y.sum()
    out.backward()
```

Average time for 10 * 10000 runs:

|     | Before     | After      | Improvement |
|-----|------------|------------|-------------|
| CPU | 4.236 ms   | 4.124 ms   | 2.64%       |
| GPU | 760.262 us | 259.652 us | 65.85%      |
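
The exact timing harness isn't shown above; a minimal sketch of how a comparable measurement could be set up with torch.utils.benchmark (the Timer arguments and run count here are assumptions, and the CUDA case would additionally move x to the GPU) might look like this:

```python
import torch
from torch.utils import benchmark

x = torch.randn(3, 4, requires_grad=True)

def test_repeat(x):
    y = x.repeat(20, 10, 15, 25)
    out = y.sum()
    out.backward()

# Time forward + backward; gradients simply accumulate into x.grad across
# runs, which is fine for a timing-only comparison.
timer = benchmark.Timer(
    stmt="test_repeat(x)",
    globals={"test_repeat": test_repeat, "x": x},
)
print(timer.timeit(10000))
```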

ejguan added a commit that referenced this pull request Oct 23, 2020
ghstack-source-id: 0026638
Pull Request resolved: #46726
@ejguan ejguan changed the title from "[WIP] Optimize backward for torch.repeat" to "Optimize backward for torch.repeat" Oct 26, 2020
@ejguan ejguan requested a review from zou3519 October 26, 2020 15:20
@zou3519 (Contributor) commented Oct 27, 2020

> Tests (using torch.utils.benchmark):

I'm curious to see the equivalent torch.expand / torch.repeat_interleave calls for the examples in #43192 (comment).

Also, it would be nice to see some benchmark figures on larger tensors (a tensor of size [3, 4] is pretty small; in practice users have larger tensors).

zou3519 previously approved these changes Oct 27, 2020
@zou3519 (Contributor) left a comment

This LGTM (with some minor comments). I'm curious to see some more performance numbers.

dr-ci bot commented Oct 27, 2020

💊 CI failures summary and remediations

As of commit 565d4b2 (more details on the Dr. CI page):


  • 1/1 failures possibly introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed



ejguan added a commit that referenced this pull request Oct 27, 2020
ghstack-source-id: 52b835e
Pull Request resolved: #46726
@ejguan (Contributor, Author) commented Oct 28, 2020

Benchmark for two different implementations.
Test 1:

```python
x = torch.randn(320, 480, requires_grad=True)
x_t = torch.randn(320, 480, requires_grad=True).t()  # non-contiguous
...
y = x.repeat(16, 16)
...
```

Test 2:

```python
x = torch.randn(160, 240, requires_grad=True)
x_t = torch.randn(160, 240, requires_grad=True).t()  # non-contiguous
...
y = x.repeat(32, 32)
...
```

Test 3:

```python
x = torch.randn(80, 120, requires_grad=True)
x_t = torch.randn(80, 120, requires_grad=True).t()  # non-contiguous
...
y = x.repeat(64, 64)
...
```

Test 4:

```python
x = torch.randn(3, 4, requires_grad=True)
x_t = torch.randn(3, 4, requires_grad=True).t()  # non-contiguous
...
y = x.repeat(16, 32)
...
```

Average time for 1,000 runs:

|                | CPU       | GPU       | CPU (NC)  | GPU (NC)  |
|----------------|-----------|-----------|-----------|-----------|
| Before (1)     | 125.31 ms | 4.46 ms   | 155.39 ms | 4.72 ms   |
| Multi-time (1) | 127.16 ms | 4.02 ms   | 165.37 ms | 4.14 ms   |
| One-time (1)   | 117.31 ms | 3.97 ms   | 155.49 ms | 4.22 ms   |
| Before (2)     | 115.17 ms | 4.57 ms   | 138.76 ms | 4.98 ms   |
| Multi-time (2) | 122.30 ms | 4.02 ms   | 146.28 ms | 4.25 ms   |
| One-time (2)   | 114.66 ms | 3.98 ms   | 139.86 ms | 4.42 ms   |
| Before (3)     | 115.41 ms | 4.86 ms   | 141.28 ms | 6.57 ms   |
| Multi-time (3) | 122.91 ms | 3.97 ms   | 146.72 ms | 5.71 ms   |
| One-time (3)   | 118.73 ms | 3.99 ms   | 144.87 ms | 5.66 ms   |
| Before (4)     | 324.34 us | 822.94 us | 399.30 us | 779.03 us |
| Multi-time (4) | 129.49 us | 189.01 us | 120.85 us | 203.14 us |
| One-time (4)   | 105.63 us | 178.14 us | 124.81 us | 182.41 us |

NC refers to non-contiguous input.
"Multi-time" refers to the implementation that reshapes and sums once per repeated dimension; "one-time" refers to the implementation that does a single reshape followed by a single sum over all repeat dimensions (a sketch contrasting the two follows the conclusion below).

Conclusion:

  • Both of the new implementations clearly perform better than the existing backward.
  • The one-time strategy performs slightly better, especially on CPU.
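
To make the two strategies concrete, here is a hedged Python sketch of the idea on a 2-D example. The function names and pure-Python structure are illustrative assumptions (the PR itself implements the backward in C++), and it assumes len(repeats) equals the input rank, but the reshape-and-sum logic matches what is described above.

```python
import torch

def repeat_backward_multi_time(grad, input_shape, repeats):
    # "Multi-time": fold one repeated dimension at a time, summing after
    # each reshape.
    for dim, (r, s) in enumerate(zip(repeats, input_shape)):
        if r == 1:
            continue
        shape = list(grad.shape)
        shape[dim:dim + 1] = [r, s]          # split dim into (repeat, size)
        grad = grad.reshape(shape).sum(dim)  # sum out the repeat axis
    return grad

def repeat_backward_one_time(grad, input_shape, repeats):
    # "One-time": build one interleaved shape [r0, s0, r1, s1, ...] and sum
    # over all repeat axes with a single reshape + sum.
    grad_size, sum_dims = [], []
    for i, (r, s) in enumerate(zip(repeats, input_shape)):
        sum_dims.append(2 * i)
        grad_size += [r, s]
    return grad.reshape(grad_size).sum(sum_dims)

x = torch.randn(3, 4)
grad = torch.ones(3 * 16, 4 * 32)   # gradient flowing into x.repeat(16, 32)
g_multi = repeat_backward_multi_time(grad, x.shape, (16, 32))
g_one = repeat_backward_one_time(grad, x.shape, (16, 32))
assert g_multi.shape == g_one.shape == x.shape
assert torch.allclose(g_multi, g_one)
```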

ejguan added a commit that referenced this pull request Oct 28, 2020
@ejguan (Contributor, Author) commented Oct 28, 2020

Benchmark for repeat/repeat_interleave/expand:

```python
x = torch.rand((1, 1920)).requires_grad_()

repeated = x.repeat(1280, 1)
repeated = x.repeat_interleave(1280, dim=0)
repeated = x.expand(1280, 1920)
```

Average time for 10,000 runs:

|     | previous repeat | repeat    | repeat_interleave | expand    |
|-----|-----------------|-----------|-------------------|-----------|
| CPU | 6.71 ms         | 2.13 ms   | 1.37 ms           | 1.34 ms   |
| GPU | 15.72 ms        | 256.01 us | 566.06 us         | 159.51 us |
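
The snippet above only shows the forward calls; a possible way to time the full forward + backward of the three variants with torch.utils.benchmark is sketched below (the .sum() reduction and run count are assumptions, since the exact harness isn't shown).

```python
import torch
from torch.utils import benchmark

x = torch.rand((1, 1920)).requires_grad_()

variants = {
    "repeat": lambda: x.repeat(1280, 1),
    "repeat_interleave": lambda: x.repeat_interleave(1280, dim=0),
    "expand": lambda: x.expand(1280, 1920),
}

for name, fn in variants.items():
    # Reduce to a scalar so backward() can be called on each variant.
    t = benchmark.Timer(stmt="fn().sum().backward()", globals={"fn": fn})
    print(name, t.timeit(10000))
```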

cc: @zou3519

@ejguan ejguan requested review from zou3519 and removed request for apaszke October 29, 2020 14:12
@zou3519 zou3519 dismissed their stale review October 29, 2020 15:11

doing a re-review

ejguan added a commit that referenced this pull request Oct 29, 2020
ghstack-source-id: b86f13d
Pull Request resolved: #46726
@ejguan (Contributor, Author) commented Oct 29, 2020

> I have some suggestions on reframing the algorithm in terms of input_shape and repeats instead of using grad_shape and repeats that may make the logic easier to implement. Aside from that, I think there are some UBs in the code as-is, let me know what you think.

PR and the benchmark table are updated.

@zou3519 (Contributor) left a comment

lgtm!

ejguan added a commit that referenced this pull request Nov 2, 2020
ghstack-source-id: efcafb1
Pull Request resolved: #46726

@ejguan merged this pull request in 4e6f244.

@ejguan ejguan reopened this Nov 4, 2020
Review thread on the new backward implementation:

```cpp
// grad_size [4, 2, 3, 9, 4, 3, 5]
// sum_dims  [0, 3, 5]
grad = grad.reshape(grad_size);
grad = grad.sum(sum_dims);
```
@ejguan (Contributor, Author) commented:

When the repeat count is 1 along every dimension, sum_dims becomes empty, which leads to summing over the whole grad instead of leaving it unchanged.

A reviewer (Contributor) replied:
#29137 strikes again
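
For reference, a hedged sketch of the edge case being discussed: the helper below is an illustrative Python reconstruction (not the PR's C++ code) of the grad_size / sum_dims construction from the snippet above, with an explicit early return when every repeat is 1. Without that guard, sum_dims would be empty and grad.sum([]) would reduce over all dimensions instead of being a no-op, which is the behavior referenced as #29137.

```python
import torch

def repeat_backward_sketch(grad, input_shape, repeats):
    grad_size, sum_dims = [], []
    for r, s in zip(repeats, input_shape):
        if r == 1:
            grad_size.append(s)               # no extra repeat axis needed
        else:
            sum_dims.append(len(grad_size))   # position of the repeat axis
            grad_size += [r, s]
    if not sum_dims:
        # All repeats are 1: return grad unchanged rather than calling
        # grad.sum([]), which reduces over every dimension (issue #29137).
        return grad
    return grad.reshape(grad_size).sum(sum_dims)

# One input_shape / repeats combination that reproduces the values in the
# snippet above: grad_size [4, 2, 3, 9, 4, 3, 5] and sum_dims [0, 3, 5].
input_shape, repeats = (2, 3, 4, 5), (4, 1, 9, 3)
grad = torch.ones([r * s for r, s in zip(repeats, input_shape)])
print(repeat_backward_sketch(grad, input_shape, repeats).shape)   # (2, 3, 4, 5)
print(repeat_backward_sketch(torch.ones(input_shape), input_shape, (1, 1, 1, 1)).shape)
```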

ejguan added a commit that referenced this pull request Nov 4, 2020
ejguan added a commit that referenced this pull request Nov 5, 2020
@zou3519 zou3519 self-requested a review November 9, 2020 18:49
@zou3519 (Contributor) commented Nov 9, 2020

If "import to phabricator" or ghimport fails, you might have to unlink this github PR from the original diff.

@facebook-github-bot facebook-github-bot deleted the gh/ejguan/5/head branch November 13, 2020 15:22


Successfully merging this pull request may close these issues.

backward of torch.repeat slower than for torch.repeat_interleave
