Simplify convolution double backward gradInput formulas #54840
Conversation
💊 CI failures summary (as of commit 7bde291): 2 new failures recognized by patterns; they do not appear to be due to upstream breakages: pytorch_xla_linux_bionic_py3_6_clang9_build, step "(Optional) Merge target branch".
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Codecov Report

@@            Coverage Diff             @@
##           master   #54840      +/-   ##
==========================================
- Coverage   77.47%   77.47%   -0.01%
==========================================
  Files        1893     1893
  Lines      185963   185949      -14
==========================================
- Hits       144077   144056      -21
- Misses      41886    41893       +7
Looks much simpler, thanks! Just a question about the dilation/stride change.
Also, the slowdown on cudnn 8 is... wow.
        if (expected_input_shape != input_shape[0]) {
          gi_conv_params.output_padding[1] = input_shape[0] - expected_input_shape;
        }
      } else {
        for(size_t i = 0; i < kernel_size.size(); ++i) {
          // Check if whole input has been used or not
-         auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.stride[i]
+         auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.dilation[i]
Why did you change this? Was it wrong before?
It was correct before, because stride and dilation were swapped. Changing the order of arguments also requires swapping stride and dilation, so I removed that std::swap call; but to make sure that the expected input shape remains the same (as it should), I had to make this change.
This is tested, btw, so when I had it wrong here, the tests were failing.
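For context, a minimal Python sketch of the size bookkeeping being discussed (made-up sizes, not the ATen code): the gradInput pass is a transposed convolution, whose output length follows the standard formula, and output_padding absorbs any mismatch with the original input size. Note that dilation, not stride, multiplies (kernel - 1), which is why the factor changes once stride and dilation are no longer swapped.

```python
# Illustrative sketch, not the ATen implementation. Standard size formulas:
#   forward conv:    out = (in + 2*padding - dilation*(kernel - 1) - 1) // stride + 1
#   transposed conv: out = (in - 1)*stride - 2*padding + dilation*(kernel - 1) + 1 + output_padding
def conv_transpose_output_size(in_size, kernel, stride=1, padding=0,
                               dilation=1, output_padding=0):
    return ((in_size - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + 1 + output_padding)

# Example: input length 10, kernel 3, stride 2 gives forward output length 4.
input_len, kernel, stride, dilation = 10, 3, 2, 1
out_len = (input_len - dilation * (kernel - 1) - 1) // stride + 1  # = 4

# Reconstructing from length 4 without output_padding yields 9, not 10;
# output_padding = 10 - 9 = 1 recovers the original size, mirroring the
# gi_conv_params.output_padding assignment in the diff above.
expected = conv_transpose_output_size(out_len, kernel, stride=stride, dilation=dilation)
print(expected, input_len - expected)  # 9 1
```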
@@ -1212,7 +1212,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const c10::option
     }
   }

-  // Compute gI = convT(ggW, gO.t()) if !transposed
+  // Compute gI = convT(gO, ggW) if !transposed
nit: re-align the "if"?
Also, do we want to mention here where it comes from?
Tbh, I don't know where it's coming from; it's your code :-). I also don't know exactly what .t() meant in the original comment for a 4d tensor.
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Stamped
Currently, in convolution double backward, the grad of input is computed as convT(ggW, gO.t()). Notice how the first argument is, in fact, of the size that the convolution weight has, and the second is of the size of gradOutput, which is an inverted order compared to how convolutions are regularly called, and the sizes are far from what cudnn heuristics are trained for and what cudnn is guaranteed to have efficient kernels for. This takes cudnn 8 to some dark places, calling kernels that take 20-100 s. But, luckily for us, convT is a commutative operation (unlike conv), so convT(ggW, gO) is actually the same as convT(gO, ggW), modulo some transposes because of the conventions around the weight size, so we can use convT(gO, ggW). As an added bonus, we don't need a special branch for groups with this formulation.
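As a sanity check of the identity this relies on (the shapes below are arbitrary illustrations, not taken from the PR): the gradient of a convolution's input is itself a transposed convolution of gradOutput by the weight, so the weight-shaped tensor can go in the weight slot rather than the input slot.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration only.
x = torch.randn(2, 4, 16, 16, requires_grad=True)
w = torch.randn(8, 4, 3, 3)
y = F.conv2d(x, w)          # stride=1, padding=0
gO = torch.randn_like(y)

# gradInput via autograd...
gI, = torch.autograd.grad(y, x, grad_outputs=gO)

# ...equals a transposed convolution with gO in the input slot and the
# weight-shaped tensor in the weight slot, the argument order this PR uses.
gI_direct = F.conv_transpose2d(gO, w)
print(torch.allclose(gI, gI_direct, atol=1e-4))  # True
```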
For the following pretty standard convolution, the benchmarking script is below:
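The PR's actual configuration and script are not preserved in this excerpt; what follows is a minimal sketch of what such a benchmark might look like, under assumed shapes (the sizes and timing harness are illustrative, not the author's).

```python
import time
import torch
import torch.nn.functional as F

# Assumed, illustrative configuration; the PR's real shapes are not shown here.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(32, 64, 56, 56, device=device, requires_grad=True)
w = torch.randn(64, 64, 3, 3, device=device, requires_grad=True)

y = F.conv2d(x, w, padding=1)
gO = torch.randn_like(y)

# First backward with create_graph=True so we can differentiate again.
gW, = torch.autograd.grad(y, w, grad_outputs=gO, create_graph=True)

# Differentiating gW w.r.t. x with cotangent ggW exercises the gradInput
# path of convolution double backward (gI = convT(gO, ggW) after this PR).
ggW = torch.randn_like(gW)
if device == "cuda":
    torch.cuda.synchronize()
t0 = time.perf_counter()
gI, = torch.autograd.grad(gW, x, grad_outputs=ggW, retain_graph=True)
if device == "cuda":
    torch.cuda.synchronize()
print(f"gradInput double backward: {time.perf_counter() - t0:.4f} s")
```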