Skip to content
7 changes: 4 additions & 3 deletions torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,10 @@ def call_method(
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
):
if name == "backward":
if name == "apply":
options = VariableTracker.propagate(self, args, kwargs.values())
return self.call_apply(tx, args, kwargs).add_options(options)
elif name == "backward":
with tx.strict_translation_mode():
if isinstance(self.fn_cls.backward, types.FunctionType):
backward = UserFunctionVariable(self.fn_cls.backward)
Expand Down Expand Up @@ -642,8 +645,6 @@ def reconstruct(self, codegen):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if isinstance(self.obj, AutogradFunctionVariable) and self.name == "apply":
return self.obj.call_apply(tx, args, kwargs).add_options(self)
return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)

def call_method(
Expand Down