diff --git a/test/models/export_delegated_program.py b/test/models/export_delegated_program.py
index 44ae8df544f..25818a721a6 100644
--- a/test/models/export_delegated_program.py
+++ b/test/models/export_delegated_program.py
@@ -90,6 +90,18 @@ def get_random_inputs(self) -> Sequence[torch.Tensor]:
         return (torch.ones(n, n, n), 2 * torch.ones(n, n, n), 3 * torch.ones(n, n, n))
 
 
+class ModuleLinear(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(3, 3)
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
+
+    def get_random_inputs(self):
+        return (torch.randn(3),)
+
+
 #
 # Backends
 #
@@ -116,7 +128,7 @@ def export_module_to_program(
     extract_delegate_segments: bool,
     constant_tensor_alignment: Optional[int] = None,
     delegate_alignment: Optional[int] = None,
-    method: str = "forward",
+    method_name: str = "forward",
 ) -> ExecutorchProgramManager:
     eager_module = module_class().eval()
     inputs = ()
@@ -124,16 +136,15 @@ def export_module_to_program(
         inputs = eager_module.get_random_inputs()  # type: ignore[operator]
 
     class WrapperModule(torch.nn.Module):
-        def __init__(self, fn):
+        def __init__(self, fn, method_name=method_name):
             super().__init__()
             self.fn = fn
+            self.method_name = method_name
 
         def forward(self, *args, **kwargs):
-            return self.fn(*args, **kwargs)
+            return getattr(self.fn, self.method_name)(*args, **kwargs)
 
-    exported_program = export(
-        WrapperModule(getattr(eager_module, method)), args=inputs, strict=True
-    )
+    exported_program = export(WrapperModule(eager_module), args=inputs, strict=True)
 
     edge_config = EdgeCompileConfig(_check_ir_validity=False)
     et_config = exir.ExecutorchBackendConfig(
diff --git a/test/models/targets.bzl b/test/models/targets.bzl
index 6538302c507..ab5fcc8a51d 100644
--- a/test/models/targets.bzl
+++ b/test/models/targets.bzl
@@ -156,6 +156,7 @@ def define_common_targets():
         "ModuleAddMul",
         "ModuleAddLarge",
         "ModuleSubLarge",
+        "ModuleLinear",
     ]
 
     # Name of the backend to use when exporting delegated programs.
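
As a minimal, standalone sketch (assumed names, not code from the repository): the change above stops exporting a bound method fetched with getattr(eager_module, method) and instead exports a wrapper nn.Module that stores the method name and dispatches to it inside forward(). TinyModule and the __main__ driver below are illustrative placeholders only.

import torch
from torch.export import export


class TinyModule(torch.nn.Module):
    """Stand-in for the exported module classes; assumed for illustration."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x: torch.Tensor):
        return self.linear(x)


class WrapperModule(torch.nn.Module):
    """Wraps an eager module and forwards calls to one of its methods by name."""

    def __init__(self, fn, method_name="forward"):
        super().__init__()
        self.fn = fn
        self.method_name = method_name

    def forward(self, *args, **kwargs):
        # Look up the requested method on the wrapped module at call time,
        # so export() traces a plain nn.Module rather than a bound method.
        return getattr(self.fn, self.method_name)(*args, **kwargs)


if __name__ == "__main__":
    eager_module = TinyModule().eval()
    inputs = (torch.randn(3),)
    exported_program = export(WrapperModule(eager_module), args=inputs, strict=True)
    print(exported_program)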