Skip to content

Commit 8069469

Browse files
jansel authored and pytorchmergebot committed
[dynamo] Support Tuple[int] args to autograd.Function (#123887)
Pull Request resolved: #123887 Approved by: https://github.com/anijain2305 ghstack dependencies: #123700, #123705, #123786, #123790, #123803, #123804, #123896
1 parent 70b8c58 commit 8069469

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

test/dynamo/test_autograd_function.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,33 @@ def foo(x):
924924
foo(torch.randn(2))
925925
foo(torch.randn(2, requires_grad=True))
926926

927+
def test_tuple_arg(self):
    """A ``Tuple[int]`` argument to ``autograd.Function.apply`` should
    compile as a single dynamo frame (regression test for #123887)."""
    counter = torch._dynamo.testing.CompileCounter()

    class TupleArgFunc(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, shape):
            # The tuple `shape` is consumed as a plain Python constant here.
            ctx.save_for_backward(torch.randn(shape))
            return x + 1

        @staticmethod
        def backward(ctx, grad_output):
            # Gradient w.r.t. `x` is the saved random tensor; `shape`
            # is a non-tensor input, so its gradient slot is None.
            (saved,) = ctx.saved_tensors
            return saved, None

    @torch.compile(backend=counter, fullgraph=True)
    def fn():
        return TupleArgFunc.apply(x, shape)

    shape = (10, 10)
    x = torch.randn(shape, requires_grad=True)

    result = fn()
    result.sum().backward()

    self.assertEqual(result, x + 1)
    self.assertEqual(x.grad.shape, shape)
    # One frame, two ops => no graph break on the tuple argument.
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 2)
953+
927954
@requires_cuda
928955
@skipIfRocm
929956
def test_triton_kernel_basic(self):

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def validate_args_and_maybe_create_graph_inputs(
137137
set_subgraph_inputs,
138138
description,
139139
):
140-
from . import AutogradFunctionContextVariable, EnumVariable
140+
from . import AutogradFunctionContextVariable
141141
from .builder import wrap_fx_proxy_cls
142142

143143
assert tracer.parent is not None
@@ -166,7 +166,7 @@ def validate_args_and_maybe_create_graph_inputs(
166166
args.append(a)
167167
continue
168168

169-
if isinstance(a, (ConstantVariable, EnumVariable)):
169+
if a.is_python_constant():
170170
# This arg is not used in the body of the higher order op.
171171
# Currently, this new input is added to make the calls
172172
# happy, which expect a fixed number of arguments. In

0 commit comments

Comments (0)