File tree Expand file tree Collapse file tree 2 files changed +29
-2
lines changed Expand file tree Collapse file tree 2 files changed +29
-2
lines changed Original file line number Diff line number Diff line change @@ -924,6 +924,33 @@ def foo(x):
924924 foo (torch .randn (2 ))
925925 foo (torch .randn (2 , requires_grad = True ))
926926
927+ def test_tuple_arg (self ):
928+ cnt = torch ._dynamo .testing .CompileCounter ()
929+
930+ class TupleArgFunc (torch .autograd .Function ):
931+ @staticmethod
932+ def forward (ctx , x , shape ):
933+ ctx .save_for_backward (torch .randn (shape ))
934+ return x + 1
935+
936+ @staticmethod
937+ def backward (ctx , grad_output ):
938+ (result ,) = ctx .saved_tensors
939+ return result , None
940+
941+ @torch .compile (backend = cnt , fullgraph = True )
942+ def fn ():
943+ return TupleArgFunc .apply (x , shape )
944+
945+ shape = (10 , 10 )
946+ x = torch .randn (shape , requires_grad = True )
947+ out = fn ()
948+ out .sum ().backward ()
949+ self .assertEqual (out , x + 1 )
950+ self .assertEqual (x .grad .shape , shape )
951+ self .assertEqual (cnt .frame_count , 1 )
952+ self .assertEqual (cnt .op_count , 2 )
953+
927954 @requires_cuda
928955 @skipIfRocm
929956 def test_triton_kernel_basic (self ):
Original file line number Diff line number Diff line change @@ -137,7 +137,7 @@ def validate_args_and_maybe_create_graph_inputs(
137137 set_subgraph_inputs ,
138138 description ,
139139):
140- from . import AutogradFunctionContextVariable , EnumVariable
140+ from . import AutogradFunctionContextVariable
141141 from .builder import wrap_fx_proxy_cls
142142
143143 assert tracer .parent is not None
@@ -166,7 +166,7 @@ def validate_args_and_maybe_create_graph_inputs(
166166 args .append (a )
167167 continue
168168
169- if isinstance ( a , ( ConstantVariable , EnumVariable ) ):
169+ if a . is_python_constant ( ):
170170 # This arg is not used in the body of the higher order op.
171171 # Currently, this new input is added to make the calls
172172 # happy, which expect a fixed number of arguments. In
You can’t perform that action at this time.
0 commit comments