diff --git a/backends/cadence/aot/program_builder.py b/backends/cadence/aot/program_builder.py index 0bb71c95a4a..d73cc9fcfbf 100644 --- a/backends/cadence/aot/program_builder.py +++ b/backends/cadence/aot/program_builder.py @@ -2,14 +2,15 @@ # pyre-strict +from enum import auto, Enum from typing import Optional from executorch.backends.cadence.aot.graph_builder import GraphBuilder from executorch.exir import EdgeCompileConfig, EdgeProgramManager from executorch.exir.pass_base import ProxyValue from executorch.exir.verification.verifier import EXIREdgeDialectVerifier - from torch import Tensor +from torch._export.verifier import Verifier from torch.export import ExportedProgram from torch.export.graph_signature import ( ExportGraphSignature, @@ -21,14 +22,20 @@ ) +class IrMode(Enum): + EXIR = auto() + ATEN = auto() + + class ProgramBuilder(GraphBuilder): """Utility class to build a program from a graph module.""" - def __init__(self) -> None: + def __init__(self, mode: Optional[IrMode] = None) -> None: self.input_specs: list[InputSpec] = [] self.output_specs: list[OutputSpec] = [] self.constants: dict[str, Tensor] = {} self.state_dict: dict[str, Tensor] = {} + self.mode: IrMode = mode or IrMode.EXIR super().__init__() def insert_input_spec( @@ -68,6 +75,16 @@ def output( ) return super().output(results) + def get_verifiers(self) -> Optional[list[Verifier]]: + if self.mode == IrMode.ATEN: + return None + return [ + EXIREdgeDialectVerifier( + edge_compile_config=EdgeCompileConfig(_check_ir_validity=False), + class_only=True, + ) + ] + def get_program(self) -> ExportedProgram: gm = self.get_graph_module() return ExportedProgram( @@ -81,12 +98,8 @@ def get_program(self) -> ExportedProgram: state_dict=self.state_dict, range_constraints={}, module_call_graph=[], - verifiers=[ - EXIREdgeDialectVerifier( - edge_compile_config=EdgeCompileConfig(_check_ir_validity=False), - class_only=True, - ) - ], + # pyre-ignore[6]: Incompatible parameter type. + verifiers=self.get_verifiers(), ) def get_edge_program(self) -> EdgeProgramManager: diff --git a/backends/cadence/aot/tests/test_program_builder.py b/backends/cadence/aot/tests/test_program_builder.py index f2c138dce80..a16d42e2378 100644 --- a/backends/cadence/aot/tests/test_program_builder.py +++ b/backends/cadence/aot/tests/test_program_builder.py @@ -1,10 +1,11 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # pyre-strict - import torch -from executorch.backends.cadence.aot.program_builder import ProgramBuilder +from executorch.backends.cadence.aot.program_builder import IrMode, ProgramBuilder +from executorch.exir.dialects._ops import ops as exir_ops from later.unittest import TestCase +from torch._export.verifier import SpecViolationError from torch.export.graph_signature import InputKind, OutputKind @@ -120,3 +121,102 @@ def test_user_input_mutation(self) -> None: self.assertEqual( program.graph_signature.output_specs[0].kind, OutputKind.USER_INPUT_MUTATION ) + + def test_get_verifier_exir_mode(self) -> None: + """Test that get_verifier returns EXIREdgeDialectVerifier for EXIR mode.""" + builder = ProgramBuilder(mode=IrMode.EXIR) + verifiers = builder.get_verifiers() + self.assertIsNotNone(verifiers) + self.assertEqual(len(verifiers), 1) + + def test_get_verifier_aten_mode(self) -> None: + """Test that get_verifier returns None for ATEN mode.""" + builder = ProgramBuilder(mode=IrMode.ATEN) + verifiers = builder.get_verifiers() + self.assertIsNone(verifiers) + + def test_get_verifier_default_mode(self) -> None: + """Test that get_verifier returns EXIREdgeDialectVerifier for default mode.""" + builder = ProgramBuilder() # Should default to EXIR + self.assertEqual(builder.mode, IrMode.EXIR) + verifiers = builder.get_verifiers() + self.assertIsNotNone(verifiers) + self.assertEqual(len(verifiers), 1) + + def test_aten_add_tensor_exir_mode(self) -> None: + """Test using torch.ops.aten.add.Tensor with EXIR mode.""" + inp = torch.randn([3, 5]) + buffer = torch.randn([5]) + + builder = ProgramBuilder(mode=IrMode.EXIR) + inp_proxy = builder.placeholder("inp", inp) + buffer_proxy = builder.placeholder( + "buffer", buffer, input_kind=InputKind.BUFFER + ) + add = builder.call_operator( + torch.ops.aten.add.Tensor, (inp_proxy, buffer_proxy) + ) + builder.output([add]) + builder.get_program() + + def test_aten_add_tensor_aten_mode(self) -> None: + """Test using torch.ops.aten.add.Tensor with ATEN mode.""" + inp = torch.randn([3, 5]) + buffer = torch.randn([5]) + + builder = ProgramBuilder(mode=IrMode.ATEN) + inp_proxy = builder.placeholder("inp", inp) + buffer_proxy = builder.placeholder( + "buffer", buffer, input_kind=InputKind.BUFFER + ) + add = builder.call_operator( + torch.ops.aten.add.Tensor, (inp_proxy, buffer_proxy) + ) + builder.output([add]) + program = builder.get_program() + + # Verify the program was created successfully + self.assertEqual(len(program.graph_signature.input_specs), 2) + self.assertEqual(len(program.graph_signature.output_specs), 1) + self.assertEqual(builder.mode, IrMode.ATEN) + + def test_exir_edge_aten_add_tensor_exir_mode(self) -> None: + """Test using exir_ops.edge.aten.add.Tensor with EXIR mode.""" + inp = torch.randn([3, 5]) + buffer = torch.randn([5]) + + builder_exir = ProgramBuilder(mode=IrMode.EXIR) + inp_proxy_exir = builder_exir.placeholder("inp", inp) + buffer_proxy_exir = builder_exir.placeholder( + "buffer", buffer, input_kind=InputKind.BUFFER + ) + add_exir = builder_exir.call_operator( + exir_ops.edge.aten.add.Tensor, (inp_proxy_exir, buffer_proxy_exir) + ) + builder_exir.output([add_exir]) + program_exir = builder_exir.get_program() + + # Verify the program was created successfully + self.assertEqual(len(program_exir.graph_signature.input_specs), 2) + self.assertEqual(len(program_exir.graph_signature.output_specs), 1) + self.assertEqual(builder_exir.mode, IrMode.EXIR) + + def test_exir_edge_aten_add_tensor_aten_mode(self) -> None: + """Test using exir_ops.edge.aten.add.Tensor with ATEN mode.""" + inp = torch.randn([3, 5]) + buffer = torch.randn([5]) + + builder_aten = ProgramBuilder(mode=IrMode.ATEN) + inp_proxy_aten = builder_aten.placeholder("inp", inp) + buffer_proxy_aten = builder_aten.placeholder( + "buffer", buffer, input_kind=InputKind.BUFFER + ) + add_aten = builder_aten.call_operator( + exir_ops.edge.aten.add.Tensor, (inp_proxy_aten, buffer_proxy_aten) + ) + builder_aten.output([add_aten]) + + with self.assertRaises( + SpecViolationError, msg="Operator '" + ): + builder_aten.get_program()