From 336f050d4df25ee01ae669acd716d617143b9e9e Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 15 Jul 2022 20:36:30 -0700 Subject: [PATCH 1/2] Remove nn.Module from aot_module_simplified --- functorch/_src/aot_autograd.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index fe47f0d12..1e5738ff8 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -641,19 +641,21 @@ def aot_function_simplified( compiled_f = aot_function_simplified(functional_call, *top_args, **top_kwargs) - class AOTModule(nn.Module): - def __init__(self): - super(AOTModule, self).__init__() - self.orig_module = mod - - def forward(self, *args, **kwargs): + if top_kwargs: + def forward(*args, **kwargs): return compiled_f( *params_flat, *args, **kwargs, ) + else: + def forward(*args): + return compiled_f( + *params_flat, + *args, + ) - return AOTModule() + return forward compiled_function = aot_function From 1ce2f931dcf6fb178111d54a75a416cb5de16372 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 15 Jul 2022 21:00:46 -0700 Subject: [PATCH 2/2] zero_grad --- functorch/_src/aot_autograd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index 1e5738ff8..e76ea4bb4 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -655,6 +655,7 @@ def forward(*args): *args, ) + forward.zero_grad = mod.zero_grad return forward