Commit 7da3a4f

ezyang authored and etaf committed
Make PT2 compile backprop through custom op without autograd key a hard error (#166367)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #166367
Approved by: https://github.com/bdhirsh
1 parent d3514d1

7 files changed: +91 -48 lines changed

test/distributed/test_inductor_collectives.py

Lines changed: 5 additions & 6 deletions
@@ -414,6 +414,7 @@ def forward(self, x, world_size, tag, ranks, group_size):
 
         with _dynamo_dist_per_rank_init(self.rank, self.world_size):
             model = Model().to(self.device)
+            model.emb.weight.requires_grad = False
             model_compiled = torch.compile(model)
             inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device=self.device)
             out = model_compiled(inp, self.world_size, **self.get_world_trs())
@@ -1340,13 +1341,11 @@ def func(inp, *, tag, ranks, group_size):
         assert counter.op_count == 3  # It generates 2 getattr to unpack the array
         assert same(out, correct)
 
+    # This doesn't work in all cases, and now we properly loudly error.
+    # See: https://github.com/pytorch/pytorch/issues/151240
+    # When differentiable funcols are implemented can revert.
+    @unittest.expectedFailure
     def test_backwards(self):
-        """
-        It's probably not that common to need backwards support for collectives.
-
-        However, I wanted to at least see if it was possible to support it as a design goal.
-        """
-
         def func(inp):
             ar = _functional_collectives.all_reduce(inp, "sum", "0")
             return ar

test/dynamo/test_misc.py

Lines changed: 11 additions & 0 deletions
@@ -9757,6 +9757,17 @@ def test_validate_outputs_unbacked_by_custom_op(self):
             def foo_impl(x, y):
                 return torch.cat([x, y])
 
+            def setup_context(ctx, inputs, output):
+                (x, _) = inputs
+                ctx.xs = x.shape[0]
+
+            def foo_backward(ctx, grad):
+                return grad[: ctx.xs], grad[ctx.xs :]
+
+            torch.library.register_autograd(
+                "mylib::foo", foo_backward, setup_context=setup_context
+            )
+
             @torch.compile(backend="aot_eager", fullgraph=True)
             def f(x, i):
                 i0, i1 = i.tolist()
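The fix the test adopts is to give the custom op a backward via torch.library.register_autograd together with a setup_context hook. A self-contained sketch of the same pattern follows, assuming a hypothetical demo::cat2 op (the real test defines mylib::foo earlier in the file):

import torch

# Illustrative op: concatenates two tensors; backward splits the incoming grad.
lib = torch.library.Library("demo", "FRAGMENT")
lib.define("cat2(Tensor x, Tensor y) -> Tensor")
lib.impl("cat2", lambda x, y: torch.cat([x, y]), "CPU")

def setup_context(ctx, inputs, output):
    (x, _) = inputs
    ctx.xs = x.shape[0]          # remember where to split the gradient

def cat2_backward(ctx, grad):
    return grad[: ctx.xs], grad[ctx.xs :]

torch.library.register_autograd("demo::cat2", cat2_backward, setup_context=setup_context)

x = torch.randn(3, requires_grad=True)
y = torch.randn(2, requires_grad=True)
torch.ops.demo.cat2(x, y).sum().backward()
print(x.grad.shape, y.grad.shape)  # torch.Size([3]) torch.Size([2])

register_autograd installs a real Autograd kernel for the op, so compiled backprop no longer reaches the not-implemented fallback at all.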

test/dynamo/test_structured_trace.py

Lines changed: 6 additions & 0 deletions
@@ -1254,6 +1254,8 @@ def forward(self, x):
             torch._dynamo.reset()
 
             mod = SimpleModule().cuda()
+            for p in mod.parameters():
+                p.requires_grad = False
             compiled = torch.compile(mod, backend="inductor")
             compiled(torch.randn(4, 4, device="cuda"))
 
@@ -1321,6 +1323,8 @@ def forward(self, x):
             torch._dynamo.reset()
 
             mod = MixedModule().cuda()
+            for p in mod.parameters():
+                p.requires_grad = False
             compiled = torch.compile(mod, backend="inductor")
             compiled(torch.randn(4, 4, device="cuda"))
 
@@ -1375,6 +1379,8 @@ def forward(self, x):
         with self._setup_runtime_estimates_capture() as payload_buffer:
             torch._dynamo.reset()
             mod = Mixed().cuda()
+            for p in mod.parameters():
+                p.requires_grad = False
             compiled = torch.compile(mod, backend="inductor")
             compiled(torch.randn(4, 4, device="cuda"))
             payload = payload_buffer.getvalue().strip()

test/test_autograd_fallback.py

Lines changed: 1 addition & 10 deletions
@@ -6,6 +6,7 @@
 import numpy as np
 
 import torch
+from torch._library.autograd import autograd_fallback_mode
 from torch.library import _scoped_library
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -15,16 +16,6 @@
 )
 
 
-@contextlib.contextmanager
-def autograd_fallback_mode(mode):
-    prev = torch._C._get_autograd_fallback_mode()
-    try:
-        torch._C._set_autograd_fallback_mode(mode)
-        yield
-    finally:
-        torch._C._set_autograd_fallback_mode(prev)
-
-
 class TestAutogradFallback(TestCase):
     test_ns = "_test_autograd_fallback"
 

torch/_functorch/aot_autograd.py

Lines changed: 4 additions & 0 deletions
@@ -26,6 +26,7 @@
 from torch._guards import detect_fake_mode
 from torch._inductor.cudagraph_utils import BoxedDeviceIndex
 from torch._inductor.utils import BoxedBool
+from torch._library.autograd import autograd_fallback_mode
 from torch._subclasses import FakeTensor, FakeTensorMode
 from torch.export._tree_utils import reorder_kwargs
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -528,6 +529,9 @@ def create_aot_state(
         stack.enter_context(
             torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing()
         )
+        # Make it an error to backprop through PT2 compliant ops that silently
+        # detach autograd
+        stack.enter_context(autograd_fallback_mode("error"))
 
         from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
 
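The enter_context call ties the new mode to the lifetime of AOTAutograd's ExitStack, so the hard-error behavior only applies while tracing and the previous mode is restored when the stack unwinds. A minimal standalone sketch of that pattern (not the actual create_aot_state code):

import contextlib
from torch._library.autograd import autograd_fallback_mode

with contextlib.ExitStack() as stack:
    # While the stack is alive, backprop through an op with no Autograd
    # kernel is a hard error rather than a warning.
    stack.enter_context(autograd_fallback_mode("error"))
    ...  # tracing would happen here
# On exit the previous fallback mode is restored, so eager code outside the
# compile path keeps the default behavior.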

torch/_library/autograd.py

Lines changed: 11 additions & 0 deletions
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import contextlib
 import dataclasses
 from collections.abc import Callable
 from dataclasses import dataclass
@@ -235,6 +236,16 @@ def not_list_of_optional_tensor(tree):
     return True
 
 
+@contextlib.contextmanager
+def autograd_fallback_mode(mode):
+    prev = _C._get_autograd_fallback_mode()
+    try:
+        _C._set_autograd_fallback_mode(mode)
+        yield
+    finally:
+        _C._set_autograd_fallback_mode(prev)
+
+
 flatten = _pytree.tree_flatten
 unflatten = _pytree.tree_unflatten
 spec_t = _pytree.TreeSpec
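Callers such as test_autograd_fallback.py now import this helper instead of redefining it. A usage sketch, assuming the mode strings accepted by torch._C._set_autograd_fallback_mode are "nothing", "warn", and the newly allowed "error":

from torch._library.autograd import autograd_fallback_mode

with autograd_fallback_mode("warn"):
    ...  # legacy behavior: backprop through an Autograd-less op only warns
with autograd_fallback_mode("error"):
    ...  # new AOTAutograd behavior: the same situation raises a hard error
# In both cases the previous process-wide mode is restored on exit, even if
# the body raised.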

torch/csrc/autograd/autograd_not_implemented_fallback.cpp

Lines changed: 53 additions & 32 deletions
@@ -50,49 +50,68 @@ AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn;
 } // namespace
 
 void setAutogradFallbackMode(AutogradFallbackMode mode) {
-  TORCH_CHECK(mode != AutogradFallbackMode::Error, "NYI: mode='error'");
   kAutogradFallbackMode = mode;
 }
 
 AutogradFallbackMode getAutogradFallbackMode() {
   return kAutogradFallbackMode;
 }
 
-static void warnAutogradNotImplemented(const std::string& op_name) {
-  TORCH_WARN(
-      op_name,
-      ": an autograd kernel was not registered to the Autograd key(s) ",
-      "but we are trying to backprop through it. This may lead to silently incorrect behavior. ",
-      "This behavior is deprecated and will be removed in a future version of PyTorch. ",
-      "If your operator is differentiable, please ensure you have registered an "
-      "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
-      "DispatchKey::CompositeImplicitAutograd). If your operator is not "
-      "differentiable, or to squash this warning and use the previous behavior, "
-      "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
+static void reportAutogradNotImplemented(
+    const std::string& op_name,
+    bool is_warn) {
+  if (is_warn) {
+    TORCH_WARN(
+        op_name,
+        ": an autograd kernel was not registered to the Autograd key(s) ",
+        "but we are trying to backprop through it. This may lead to silently incorrect behavior. ",
+        "This behavior is deprecated and will be removed in a future version of PyTorch. ",
+        "If your operator is differentiable, please ensure you have registered an "
+        "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
+        "DispatchKey::CompositeImplicitAutograd). If your operator is not "
+        "differentiable, or to squash this warning and use the previous behavior, "
+        "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
+  } else {
+    TORCH_CHECK(
+        0,
+        op_name,
+        ": an autograd kernel was not registered to the Autograd key(s) ",
+        "but we are trying to backprop through it. This can lead to silently incorrect behavior. ",
+        "If your operator is differentiable, please ensure you have registered an "
+        "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
+        "). If your operator is not "
+        "differentiable and ensure NO gradients flow through this operator, "
+        "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.")
+  }
 }
 
-struct WarnNotImplemented : public Node {
-  WarnNotImplemented(
+struct NotImplementedBackward : public Node {
+  NotImplementedBackward(
       std::string op_name,
       size_t num_outputs,
+      bool is_warn,
       edge_list&& next_edges)
       : Node(std::move(next_edges)),
         op_name(std::move(op_name)),
-        num_outputs(num_outputs) {}
+        num_outputs(num_outputs),
+        is_warn(is_warn) {}
 
-  WarnNotImplemented(std::string op_name, size_t num_outputs)
-      : op_name(std::move(op_name)), num_outputs(num_outputs) {}
+  NotImplementedBackward(std::string op_name, size_t num_outputs, bool is_warn)
+      : op_name(std::move(op_name)),
+        num_outputs(num_outputs),
+        is_warn(is_warn) {}
 
   variable_list apply(variable_list&& inputs) override;
 
   std::string op_name;
   size_t num_outputs;
+  bool is_warn;
 };
 
 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
-auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list {
+auto NotImplementedBackward::apply(variable_list&& inputs) -> variable_list {
   auto inputsLocal = std::move(inputs);
-  warnAutogradNotImplemented(op_name);
+  reportAutogradNotImplemented(op_name, is_warn);
   std::vector<at::Tensor> output(num_outputs);
   return output;
 }
@@ -111,8 +130,6 @@ static void basicAutogradNotImplementedFallbackImpl(
     op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
     return;
   }
-  TORCH_INTERNAL_ASSERT(
-      getAutogradFallbackMode() == AutogradFallbackMode::Warn);
 
   bool any_input_requires_grad = false;
   _foreach_tensor(
@@ -128,7 +145,9 @@ static void basicAutogradNotImplementedFallbackImpl(
   // by putting it after the requires_grad checks.
   any_input_requires_grad = any_input_requires_grad && GradMode::is_enabled();
 
-  std::shared_ptr<WarnNotImplemented> grad_fn;
+  bool is_warn = getAutogradFallbackMode() == AutogradFallbackMode::Warn;
+
+  std::shared_ptr<NotImplementedBackward> grad_fn;
   if (any_input_requires_grad) {
     // NB: It is standard to collect edges from all tensors
     // (see generated/VariableTypeEverything.cpp for examples)
@@ -140,8 +159,9 @@ static void basicAutogradNotImplementedFallbackImpl(
         stack,
         stack_start,
         num_arguments);
-    grad_fn = std::shared_ptr<WarnNotImplemented>(
-        new WarnNotImplemented(op_name, all_tensors_on_stack.size()),
+    grad_fn = std::shared_ptr<NotImplementedBackward>(
+        new NotImplementedBackward(
+            op_name, all_tensors_on_stack.size(), is_warn),
         deleteNode);
     grad_fn->set_next_edges(collect_next_edges(all_tensors_on_stack));
   }
@@ -177,8 +197,8 @@ static void basicAutogradNotImplementedFallbackImpl(
         // >>> y = op(k)
         // >>> torch.autograd.grad(z.sum(), w)
         if (t.requires_grad()) {
-          t.register_hook([op_name](const at::Tensor& grad) {
-            warnAutogradNotImplemented(op_name);
+          t.register_hook([op_name, is_warn](const at::Tensor& grad) {
+            reportAutogradNotImplemented(op_name, is_warn);
           });
           // If history is rebased, then we will attempt to warn
           // on the view's base. This will catch most cases (because
@@ -188,18 +208,19 @@ static void basicAutogradNotImplementedFallbackImpl(
             const auto& base = t._base();
             if (base.requires_grad()) {
               // Can only register_hook on tensors that require grad.
-              base.register_hook([op_name](const at::TensorBase& grad) {
-                warnAutogradNotImplemented(op_name);
-              });
+              base.register_hook(
+                  [op_name, is_warn](const at::TensorBase& grad) {
+                    reportAutogradNotImplemented(op_name, is_warn);
+                  });
             }
           }
          return;
        }
 
        // If the post-autograd implementation returns any Tensors that
-        // don't require grad, then we install the WarnNotImplemented grad_fn.
-        // This grad_fn warns in backward and returns undefined tensor
-        // gradients.
+        // don't require grad, then we install the NotImplementedBackward
+        // grad_fn. This grad_fn warns in backward and returns undefined
+        // tensor gradients.
        //
        // NOTE [autograd fallback and in-place operations]
        // If the schema says the output is mutable, and the output
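For an operator that genuinely never carries gradients, the new error message points at the same escape hatch as the old warning: register a fallthrough for the Autograd key. In C++ that is torch::CppFunction::makeFallthrough(); below is a hedged Python sketch of the equivalent using torch.library.fallthrough_kernel, with an illustrative namespace and op.

import torch

lib = torch.library.Library("demo_nd", "FRAGMENT")
lib.define("sign_bits(Tensor x) -> Tensor")

def sign_bits_impl(x):
    # Integer output: there is nothing to differentiate here.
    return (x > 0).to(torch.int64)

lib.impl("sign_bits", sign_bits_impl, "CompositeExplicitAutograd")
# Declare the op non-differentiable: the Autograd key falls through straight
# to the backend kernel, so the autograd-not-implemented fallback (and its
# reporting path) never fires for this op.
lib.impl("sign_bits", torch.library.fallthrough_kernel, "Autograd")

x = torch.randn(4, requires_grad=True)
print(torch.ops.demo_nd.sign_bits(x))  # runs; output does not require grad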
