
Simplify convolution double backward gradInput formulas #54840

Closed · wants to merge 4 commits

Conversation
Conversation

@ngimel (Collaborator) commented Mar 28, 2021

Currently, in convolution double backward, the gradient of the input is computed as convT(ggW, gO.T). Note that the first argument has the size of the convolution weight and the second has the size of gradOutput, which is the inverse of how convolutions are regularly called, and the sizes are far from what cudnn heuristics are trained for and what cudnn is guaranteed to have efficient kernels for. This takes cudnn 8 to some dark places, calling kernels that take 20-100 s.

Luckily for us, convT is a commutative operation (unlike conv), so convT(ggW, gO) is the same as convT(gO, ggW), modulo some transposes due to conventions around the weight size, so we can use convT(gO, ggW) instead. As an added bonus, we don't need a special branch for groups with this formulation.
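To make this concrete, here is a minimal sketch (not the PR's code, and only for the simplest stride-1, groups=1 case) checking that the ggW contribution to the input gradient is exactly convT(gO, ggW), i.e. a transposed convolution called the "regular" way, with a gradOutput-sized input and a weight-sized filter; all tensor shapes are illustrative:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8, dtype=torch.double, requires_grad=True)
w = torch.randn(6, 4, 3, 3, dtype=torch.double, requires_grad=True)

out = F.conv2d(x, w, padding=1)
gO = torch.randn_like(out)

# first backward: gradient w.r.t. the weight, keeping the graph
gW, = torch.autograd.grad(out, w, gO, create_graph=True)

# double backward: contribution of ggW to the gradient of the original input
ggW = torch.randn_like(w)
gI_autograd, = torch.autograd.grad(gW, x, ggW)

# the same quantity as a transposed convolution with gO as the input
# and ggW playing the role of the weight (stride 1, padding 1 here)
gI_manual = F.conv_transpose2d(gO, ggW, padding=1)

print(torch.allclose(gI_autograd, gI_manual))  # True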
For the following fairly standard convolution:

  • cudnn 7.6 + old formulation takes 7.5 ms for double backward,
  • cudnn 8 + old formulation takes ~40 s,
  • cudnn 8 + new formulation takes 1.8 ms with benchmark enabled,
  • cudnn 8 + new formulation takes 4 ms with benchmark disabled.

The benchmarking script is below:
import torch
import time

#torch.backends.cudnn.benchmark=True
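# ggI below times the second backward pass (double backward) through the weight gradient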

def ggI(conv, inp):
    out = conv(inp)
    grads = torch.autograd.grad(out, conv.weight, torch.rand_like(out), create_graph=True, retain_graph=True)
    torch.cuda.synchronize()
    start = time.time()
    grads[0].backward(torch.rand_like(grads[0]))
    torch.cuda.synchronize()
    print("db time: ", time.time()-start)
    return inp.grad

conv = torch.nn.Conv2d(512,256,kernel_size=3, padding=1, groups=2).cuda()
inp = torch.randn(1,512,128,128, device="cuda", requires_grad=True)
for _ in range(20):
    ggI(conv, inp)
torch.cuda.synchronize()
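The new formulation only changes which kernels get called, not the values computed; numerically, the convolution double backward formulas can be checked with torch.autograd.gradgradcheck. A minimal sketch (not the actual test code), on a small double-precision problem:

import torch
import torch.nn.functional as F
from torch.autograd import gradgradcheck

# small double-precision inputs keep the finite-difference checks tight
x = torch.randn(1, 2, 5, 5, dtype=torch.double, requires_grad=True)
w = torch.randn(3, 2, 3, 3, dtype=torch.double, requires_grad=True)

def conv(inp, weight):
    return F.conv2d(inp, weight, padding=1)

# raises if the analytical double backward disagrees with finite differences
print(gradgradcheck(conv, (x, w)))  # True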

facebook-github-bot commented Mar 28, 2021

💊 CI failures summary and remediations

As of commit 7bde291 (more details on the Dr. CI page):


  • 2/2 failures introduced in this PR

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_build (1/2)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .github/workflows/lint.yml
Auto-merging .github/workflows/lint.yml
CONFLICT (add/add): Merge conflict in .circleci/verbatim-sources/workflows/workflows-scheduled-ci.yml
Auto-merging .circleci/verbatim-sources/workflows/workflows-scheduled-ci.yml
CONFLICT (add/add): Merge conflict in .circleci/scripts/windows_cuda_install.sh
Auto-merging .circleci/scripts/windows_cuda_install.sh
CONFLICT (add/add): Merge conflict in .circleci/docker/common/install_base.sh
Auto-merging .circleci/docker/common/install_base.sh
CONFLICT (add/add): Merge conflict in .circleci/config.yml
Auto-merging .circleci/config.yml
Automatic merge failed; fix conflicts and then commit the result.


Exited with code exit status 1

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_build (2/2)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Same add/add merge conflicts as in the build above; exited with code exit status 1.



@ngimel requested a review from albanD, Mar 28, 2021 05:27
@facebook-github-bot

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel changed the title from "Simplify double backward gI formulas" to "Simplify convolution double backward gradInput formulas", Mar 28, 2021
codecov bot commented Mar 28, 2021

Codecov Report

Merging #54840 (8c4ce54) into master (1442a92) will decrease coverage by 0.00%.
The diff coverage is 57.14%.

❗ Current head 8c4ce54 differs from pull request most recent head 7bde291. Consider uploading reports for the commit 7bde291 to get more accurate results

@@            Coverage Diff             @@
##           master   #54840      +/-   ##
==========================================
- Coverage   77.47%   77.47%   -0.01%     
==========================================
  Files        1893     1893              
  Lines      185963   185949      -14     
==========================================
- Hits       144077   144056      -21     
- Misses      41886    41893       +7     

@albanD (Collaborator) left a comment
Looks much simpler, thanks! Just a question about the dilation/stride change.

Also the slowdown on cudnn 8 is... wow

        if (expected_input_shape != input_shape[0]) {
          gi_conv_params.output_padding[1] = input_shape[0] - expected_input_shape;
        }
      } else {
        for(size_t i = 0; i < kernel_size.size(); ++i) {
          // Check if whole input has been used or not
-         auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.stride[i]
+         auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.dilation[i]
albanD (Collaborator):
Why did you change this? Was it wrong before?

ngimel (Author):
It was correct before, because stride and dilation were swapped. Changing the order of arguments also requires swapping stride and dilation, so I removed that std::swap call; but to make sure the expected input shape stays the same (as it should), I had to make this change.
This is tested, btw; when I had it wrong here, tests were failing.
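For context on why expected_input_shape and output_padding matter here: with stride > 1, several input sizes map to the same output size, so when gradInput is rebuilt with a transposed convolution, output_padding selects which input size to reproduce. A small illustration (not the PR's code; plain PyTorch, no dilation):

import torch
import torch.nn.functional as F

w = torch.randn(3, 2, 3, 3)

# with stride 2, inputs of size 7 and 8 both give an output of size 3
for n in (7, 8):
    x = torch.randn(1, 2, n, n)
    print(n, "->", F.conv2d(x, w, stride=2).shape[-1])

# rebuilding gradInput with a transposed convolution is therefore ambiguous;
# output_padding picks which of the two input sizes to reproduce
gO = torch.randn(1, 3, 3, 3)
print(F.conv_transpose2d(gO, w, stride=2).shape[-1])                    # 7
print(F.conv_transpose2d(gO, w, stride=2, output_padding=1).shape[-1])  # 8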

@@ -1212,7 +1212,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const c10::option
    }
  }

- // Compute gI = convT(ggW, gO.t()) if !transposed
+ // Compute gI = convT(gO, ggW) if !transposed
albanD (Collaborator):
nit: re-align the "if"?
Also do we want to mention here where it comes from?

ngimel (Author):
Tbh, I don't know where it comes from; it's your code :-). I also don't know exactly what .t() meant in the original comment for a 4d tensor.

@facebook-github-bot

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry (Collaborator) left a comment
Stamped

@facebook-github-bot

@ngimel merged this pull request in fb1c193.
