diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index fe47f0d12..e76ea4bb4 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -641,19 +641,22 @@ def aot_function_simplified( compiled_f = aot_function_simplified(functional_call, *top_args, **top_kwargs) - class AOTModule(nn.Module): - def __init__(self): - super(AOTModule, self).__init__() - self.orig_module = mod - - def forward(self, *args, **kwargs): + if top_kwargs: + def forward(*args, **kwargs): return compiled_f( *params_flat, *args, **kwargs, ) + else: + def forward(*args): + return compiled_f( + *params_flat, + *args, + ) - return AOTModule() + forward.zero_grad = mod.zero_grad + return forward compiled_function = aot_function