Skip to content

Commit

Permalink
Add instructional error message for cudnn RNN double backward workaro…
Browse files Browse the repository at this point in the history
…und (#33884)

Summary:
Pull Request resolved: #33884

Mitigates #5261.

It's not possible for us to support cudnn RNN double backwards due to
limitations in the cudnn API. This PR makes it so that we raise an error
message if users try to get the double backward on a cudnn RNN; in the
error message we suggest using the non-cudnn RNN.

Test Plan: - added some tests to check the error message

Reviewed By: albanD

Differential Revision: D20143544

Pulled By: zou3519

fbshipit-source-id: c2e49b3d8bdb9b34b561f006150e4c7551a78fac
  • Loading branch information
zou3519 authored and facebook-github-bot committed Jan 19, 2021
1 parent 5d64658 commit 1154a85
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 5 deletions.
20 changes: 19 additions & 1 deletion test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
suppress_warnings, slowTest,
load_tests, random_symmetric_matrix,
IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
TemporaryFileName)
TemporaryFileName, TEST_WITH_ROCM)
from torch.autograd import Variable, Function, detect_anomaly, kineto_available
from torch.autograd.function import InplaceFunction
import torch.autograd.forward_ad as fwAD
Expand Down Expand Up @@ -6880,6 +6880,24 @@ def flatten_out(mod, inp):
torch.autograd.gradcheck(gradcheckfunc, inp)
torch.autograd.gradgradcheck(gradcheckfunc, inp)

if inp.is_cuda and not TEST_WITH_ROCM:
# Assert that we have good error message around unsupported CuDNN double backward
# NB: we trigger double backward using .backward() instead of autograd.grad due to
# https://github.com/pytorch/pytorch/issues/37874
with torch.backends.cudnn.flags(enabled=True):
result = gradcheckfunc(inp)
result[0].sum().backward(create_graph=True)
grad0 = next(mod.parameters()).grad
with self.assertRaisesRegex(RuntimeError,
"please disable the CuDNN backend temporarily"):
grad0.sum().backward()

# Here we avoid the backward(create_graph=True) memory leak
# described in https://github.com/pytorch/pytorch/issues/7343
for param in mod.parameters():
param.grad = None
inp.grad = None

def test_LSTM_grad_and_gradgrad(self, device):
hsize = 4
inp = torch.rand(1, 3, hsize, device=device, dtype=torch.float64, requires_grad=True)
Expand Down
8 changes: 8 additions & 0 deletions tools/autograd/derivatives.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1821,6 +1821,14 @@

# Every differentiable input below maps to not_implemented / not_implemented_list
# with kCudnnDoubleBackwardMsg as the reason, so differentiating through
# _cudnn_rnn_backward (i.e. double backward of a cuDNN RNN) raises an error
# that tells the user to disable the CuDNN backend for the forward pass.
- name: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])
dropout_state: non_differentiable
input: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
weight: not_implemented_list("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
hx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
cx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
grad_output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
grad_hy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
grad_cy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)

# miopen

Expand Down
21 changes: 18 additions & 3 deletions torch/csrc/autograd/FunctionsManual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ using at::Scalar;
using at::IntArrayRef;
using at::TensorList;

// Explanation used as the `reason` argument of the not_implemented /
// not_implemented_list formulas for _cudnn_rnn_backward in derivatives.yaml.
// It states that double backward is unsupported for CuDNN RNNs and shows the
// workaround: run the forward pass with the CuDNN backend disabled.
const char* kCudnnDoubleBackwardMsg = "Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: \nwith torch.backends.cudnn.flags(enabled=False):\n output = model(inputs)";


// Reports whether the optional actually carries a tensor that is defined.
// An empty optional counts as "not defined".
bool isDefined(const c10::optional<Tensor>& t) {
  if (!t.has_value()) {
    return false;
  }
  return t->defined();
}
Expand Down Expand Up @@ -71,9 +74,21 @@ Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, c
return grad * ratio;
}

Tensor not_implemented(const char* name) {
throw std::runtime_error(
std::string("the derivative for '") + name + "' is not implemented");
// Builds the "derivative not implemented" error for `name` and throws it.
//
// name:   operator name quoted inside the message.
// reason: optional extra explanation, appended after a single space. A null
//         pointer or an empty string appends nothing.
//
// Always throws std::runtime_error and never returns. The return type T only
// exists so that wrappers with different return types can write
// `return not_implemented_base<T>(...)`.
template <typename T>
T not_implemented_base(const char* name, const char* reason) {
  std::string msg = std::string("the derivative for '") + name + "' is not implemented.";
  // Inspect the first character rather than calling strlen: this is null-safe
  // and does not scan the whole reason just to learn whether it is empty.
  if (reason != nullptr && reason[0] != '\0') {
    msg += ' ';
    msg += reason;
  }
  throw std::runtime_error(msg);
}

// Placeholder derivative for a single-Tensor gradient slot.
// Forwards to not_implemented_base<Tensor>, which throws std::runtime_error
// naming `name` and appending `reason` when it is non-empty; no Tensor is
// ever actually returned.
Tensor not_implemented(const char* name, const char* reason) {
return not_implemented_base<Tensor>(name, reason);
}

// Placeholder derivative for a Tensor-list gradient slot (used for the
// `weight` input of _cudnn_rnn_backward in derivatives.yaml).
// Forwards to not_implemented_base<std::vector<Tensor>>, which throws
// std::runtime_error; no vector is ever actually returned.
std::vector<Tensor> not_implemented_list(const char* name, const char* reason) {
return not_implemented_base<std::vector<Tensor>>(name, reason);
}

Tensor maybe_multiply(const Tensor & t, const Scalar & s) {
Expand Down
5 changes: 4 additions & 1 deletion torch/csrc/autograd/FunctionsManual.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ namespace autograd {
namespace generated {
namespace details {

// Reason text passed to not_implemented / not_implemented_list by the
// _cudnn_rnn_backward formulas; defined in FunctionsManual.cpp.
extern const char* kCudnnDoubleBackwardMsg;

// A simple way to imperatively compute index ranges for slots
// that have been flattened
struct IndexRangeGenerator {
Expand All @@ -37,7 +39,8 @@ bool any_variable_defined(variable_list& variables);
void copy_range(variable_list& out, IndexRange range, const at::Tensor & t);
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<at::Tensor> t);
at::Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result);
at::Tensor not_implemented(const char* name);
// Placeholder derivative formulas: both always throw std::runtime_error
// saying the derivative for `name` is not implemented. `reason` is optional
// extra text appended to that message; the default "" appends nothing.
// The two functions differ only in the declared return type, so they can
// stand in for a single-Tensor gradient or a Tensor-list gradient.
at::Tensor not_implemented(const char* name, const char* reason="");
std::vector<Tensor> not_implemented_list(const char* name, const char* reason="");
at::Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result);
at::Tensor maybe_multiply(const at::Tensor & t, const at::Scalar & s);
int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim);
Expand Down

0 comments on commit 1154a85

Please sign in to comment.