@@ -50,49 +50,68 @@ AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn;
 } // namespace
 
 void setAutogradFallbackMode(AutogradFallbackMode mode) {
-  TORCH_CHECK(mode != AutogradFallbackMode::Error, "NYI: mode='error'");
   kAutogradFallbackMode = mode;
 }
 
 AutogradFallbackMode getAutogradFallbackMode() {
   return kAutogradFallbackMode;
 }
 
-static void warnAutogradNotImplemented(const std::string& op_name) {
-  TORCH_WARN(
-      op_name,
-      ": an autograd kernel was not registered to the Autograd key(s) ",
-      "but we are trying to backprop through it. This may lead to silently incorrect behavior. ",
-      "This behavior is deprecated and will be removed in a future version of PyTorch. ",
-      "If your operator is differentiable, please ensure you have registered an "
-      "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
-      "DispatchKey::CompositeImplicitAutograd). If your operator is not "
-      "differentiable, or to squash this warning and use the previous behavior, "
-      "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
+static void reportAutogradNotImplemented(
+    const std::string& op_name,
+    bool is_warn) {
+  if (is_warn) {
+    TORCH_WARN(
+        op_name,
+        ": an autograd kernel was not registered to the Autograd key(s) ",
+        "but we are trying to backprop through it. This may lead to silently incorrect behavior. ",
+        "This behavior is deprecated and will be removed in a future version of PyTorch. ",
+        "If your operator is differentiable, please ensure you have registered an "
+        "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
+        "DispatchKey::CompositeImplicitAutograd). If your operator is not "
+        "differentiable, or to squash this warning and use the previous behavior, "
+        "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
+  } else {
+    TORCH_CHECK(
+        0,
+        op_name,
+        ": an autograd kernel was not registered to the Autograd key(s) ",
+        "but we are trying to backprop through it. This can lead to silently incorrect behavior. ",
+        "If your operator is differentiable, please ensure you have registered an "
+        "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
+        "DispatchKey::CompositeImplicitAutograd). If your operator is not "
+        "differentiable and you want to ensure NO gradients flow through this operator, "
+        "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
+  }
 }
 
-struct WarnNotImplemented : public Node {
-  WarnNotImplemented(
+struct NotImplementedBackward : public Node {
+  NotImplementedBackward(
       std::string op_name,
       size_t num_outputs,
+      bool is_warn,
       edge_list&& next_edges)
       : Node(std::move(next_edges)),
         op_name(std::move(op_name)),
-        num_outputs(num_outputs) {}
+        num_outputs(num_outputs),
+        is_warn(is_warn) {}
 
-  WarnNotImplemented(std::string op_name, size_t num_outputs)
-      : op_name(std::move(op_name)), num_outputs(num_outputs) {}
+  NotImplementedBackward(std::string op_name, size_t num_outputs, bool is_warn)
+      : op_name(std::move(op_name)),
+        num_outputs(num_outputs),
+        is_warn(is_warn) {}
 
   variable_list apply(variable_list&& inputs) override;
 
   std::string op_name;
   size_t num_outputs;
+  bool is_warn;
 };
 
 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
-auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list {
+auto NotImplementedBackward::apply(variable_list&& inputs) -> variable_list {
   auto inputsLocal = std::move(inputs);
-  warnAutogradNotImplemented(op_name);
+  reportAutogradNotImplemented(op_name, is_warn);
   std::vector<at::Tensor> output(num_outputs);
   return output;
 }
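
Aside (not part of the diff): NotImplementedBackward::apply can get away with returning num_outputs default-constructed tensors because autograd treats an undefined at::Tensor as "no gradient". A minimal sketch of that convention using the public C++ custom-function API; the NoGradCos name is invented for illustration:

    #include <iostream>
    #include <torch/torch.h>

    using torch::autograd::AutogradContext;
    using torch::autograd::Function;
    using torch::autograd::variable_list;

    // Illustrative only: a backward that hands autograd an undefined Tensor,
    // i.e. "no gradient for this input" -- the same convention
    // NotImplementedBackward::apply uses for all of its outputs.
    struct NoGradCos : public Function<NoGradCos> {
      static at::Tensor forward(AutogradContext* /*ctx*/, const at::Tensor& x) {
        return x.cos();
      }
      static variable_list backward(
          AutogradContext* /*ctx*/,
          variable_list /*grad_outputs*/) {
        // One entry per forward input; default-constructed == undefined == no grad.
        return {at::Tensor()};
      }
    };

    int main() {
      auto x = torch::randn({3}, torch::requires_grad());
      auto y = NoGradCos::apply(x);
      y.sum().backward();
      // x.grad() stays undefined, mirroring what the fallback's node produces.
      std::cout << "x.grad() defined? " << x.grad().defined() << "\n";
    }
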
@@ -111,8 +130,6 @@ static void basicAutogradNotImplementedFallbackImpl(
     op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
     return;
   }
-  TORCH_INTERNAL_ASSERT(
-      getAutogradFallbackMode() == AutogradFallbackMode::Warn);
 
   bool any_input_requires_grad = false;
   _foreach_tensor(
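
Aside (not part of the diff): with the TORCH_INTERNAL_ASSERT above removed, AutogradFallbackMode::Error can now reach the rest of this function. A rough usage sketch, assuming this fallback is the one installed for the Autograd keys (as in stock builds) and that setAutogradFallbackMode / AutogradFallbackMode are reachable through torch/csrc/autograd/autograd_not_implemented_fallback.h; the myops::my_cos operator is invented for the example:

    #include <iostream>
    #include <ATen/core/dispatch/Dispatcher.h>
    #include <torch/library.h>
    #include <torch/torch.h>
    #include <torch/csrc/autograd/autograd_not_implemented_fallback.h>

    // A toy op with a CPU kernel but no Autograd kernel, so autograd for it
    // is handled by the boxed fallback defined in this file.
    static at::Tensor my_cos(const at::Tensor& x) {
      return x.cos();
    }

    TORCH_LIBRARY(myops, m) {
      m.def("my_cos(Tensor x) -> Tensor");
    }

    TORCH_LIBRARY_IMPL(myops, CPU, m) {
      m.impl("my_cos", my_cos);
    }

    int main() {
      // With the assert gone, 'Error' is a legal mode to request.
      torch::autograd::setAutogradFallbackMode(
          torch::autograd::AutogradFallbackMode::Error);

      auto op = c10::Dispatcher::singleton()
                    .findSchemaOrThrow("myops::my_cos", "")
                    .typed<at::Tensor(const at::Tensor&)>();

      auto x = torch::randn({3}, torch::requires_grad());
      auto y = op.call(x);
      try {
        // Warn mode: TORCH_WARN fires here; Error mode: TORCH_CHECK throws.
        y.sum().backward();
      } catch (const c10::Error& e) {
        std::cout << "autograd fallback raised: "
                  << e.what_without_backtrace() << "\n";
      }
    }
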
@@ -128,7 +145,9 @@ static void basicAutogradNotImplementedFallbackImpl(
   // by putting it after the requires_grad checks.
   any_input_requires_grad = any_input_requires_grad && GradMode::is_enabled();
 
-  std::shared_ptr<WarnNotImplemented> grad_fn;
+  bool is_warn = getAutogradFallbackMode() == AutogradFallbackMode::Warn;
+
+  std::shared_ptr<NotImplementedBackward> grad_fn;
   if (any_input_requires_grad) {
     // NB: It is standard to collect edges from all tensors
     // (see generated/VariableTypeEverything.cpp for examples)
@@ -140,8 +159,9 @@ static void basicAutogradNotImplementedFallbackImpl(
         stack,
         stack_start,
         num_arguments);
-    grad_fn = std::shared_ptr<WarnNotImplemented>(
-        new WarnNotImplemented(op_name, all_tensors_on_stack.size()),
+    grad_fn = std::shared_ptr<NotImplementedBackward>(
+        new NotImplementedBackward(
+            op_name, all_tensors_on_stack.size(), is_warn),
         deleteNode);
     grad_fn->set_next_edges(collect_next_edges(all_tensors_on_stack));
   }
@@ -177,8 +197,8 @@ static void basicAutogradNotImplementedFallbackImpl(
         // >>> y = op(k)
         // >>> torch.autograd.grad(z.sum(), w)
         if (t.requires_grad()) {
-          t.register_hook([op_name](const at::Tensor& grad) {
-            warnAutogradNotImplemented(op_name);
+          t.register_hook([op_name, is_warn](const at::Tensor& grad) {
+            reportAutogradNotImplemented(op_name, is_warn);
           });
           // If history is rebased, then we will attempt to warn
           // on the view's base. This will catch most cases (because
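
Aside (not part of the diff): Tensor::register_hook only runs its callback when a gradient actually flows through that tensor during backward, which is why the fallback can defer the warning or error to backprop time instead of raising at dispatch time. A small self-contained sketch of that behavior:

    #include <iostream>
    #include <torch/torch.h>

    int main() {
      auto w = torch::randn({2}, torch::requires_grad());
      auto y = w * 2;
      // The hook runs only if backward routes a gradient through y.
      y.register_hook([](const torch::Tensor& grad) {
        std::cout << "gradient reached y: " << grad << "\n";
      });
      y.sum().backward();               // hook fires during this call
      auto z = w.sum();                 // a separate graph that never touches y
      torch::autograd::grad({z}, {w});  // hook stays silent here
    }
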
@@ -188,18 +208,19 @@ static void basicAutogradNotImplementedFallbackImpl(
             const auto& base = t._base();
             if (base.requires_grad()) {
               // Can only register_hook on tensors that require grad.
-              base.register_hook([op_name](const at::TensorBase& grad) {
-                warnAutogradNotImplemented(op_name);
-              });
+              base.register_hook(
+                  [op_name, is_warn](const at::TensorBase& grad) {
+                    reportAutogradNotImplemented(op_name, is_warn);
+                  });
             }
           }
           return;
         }
 
         // If the post-autograd implementation returns any Tensors that
-        // don't require grad, then we install the WarnNotImplemented grad_fn.
-        // This grad_fn warns in backward and returns undefined tensor
-        // gradients.
+        // don't require grad, then we install the NotImplementedBackward
+        // grad_fn. This grad_fn warns (or errors) in backward and returns
+        // undefined tensor gradients.
         //
         // NOTE [autograd fallback and in-place operations]
         // If the schema says the output is mutable, and the output