Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions backends/cadence/aot/program_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

# pyre-strict

from enum import auto, Enum
from typing import Optional

from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.exir.pass_base import ProxyValue
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier

from torch import Tensor
from torch._export.verifier import Verifier
from torch.export import ExportedProgram
from torch.export.graph_signature import (
ExportGraphSignature,
Expand All @@ -21,14 +22,20 @@
)


class IrMode(Enum):
EXIR = auto()
ATEN = auto()


class ProgramBuilder(GraphBuilder):
"""Utility class to build a program from a graph module."""

def __init__(self) -> None:
def __init__(self, mode: Optional[IrMode] = None) -> None:
self.input_specs: list[InputSpec] = []
self.output_specs: list[OutputSpec] = []
self.constants: dict[str, Tensor] = {}
self.state_dict: dict[str, Tensor] = {}
self.mode: IrMode = mode or IrMode.EXIR
super().__init__()

def insert_input_spec(
Expand Down Expand Up @@ -68,6 +75,16 @@ def output(
)
return super().output(results)

def get_verifiers(self) -> Optional[list[Verifier]]:
if self.mode == IrMode.ATEN:
return None
return [
EXIREdgeDialectVerifier(
edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
class_only=True,
)
]

def get_program(self) -> ExportedProgram:
gm = self.get_graph_module()
return ExportedProgram(
Expand All @@ -81,12 +98,8 @@ def get_program(self) -> ExportedProgram:
state_dict=self.state_dict,
range_constraints={},
module_call_graph=[],
verifiers=[
EXIREdgeDialectVerifier(
edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
class_only=True,
)
],
# pyre-ignore[6]: Incompatible parameter type.
verifiers=self.get_verifiers(),
)

def get_edge_program(self) -> EdgeProgramManager:
Expand Down
104 changes: 102 additions & 2 deletions backends/cadence/aot/tests/test_program_builder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# pyre-strict

import torch
from executorch.backends.cadence.aot.program_builder import ProgramBuilder
from executorch.backends.cadence.aot.program_builder import IrMode, ProgramBuilder
from executorch.exir.dialects._ops import ops as exir_ops
from later.unittest import TestCase
from torch._export.verifier import SpecViolationError
from torch.export.graph_signature import InputKind, OutputKind


Expand Down Expand Up @@ -120,3 +121,102 @@ def test_user_input_mutation(self) -> None:
self.assertEqual(
program.graph_signature.output_specs[0].kind, OutputKind.USER_INPUT_MUTATION
)

def test_get_verifier_exir_mode(self) -> None:
"""Test that get_verifier returns EXIREdgeDialectVerifier for EXIR mode."""
builder = ProgramBuilder(mode=IrMode.EXIR)
verifiers = builder.get_verifiers()
self.assertIsNotNone(verifiers)
self.assertEqual(len(verifiers), 1)

def test_get_verifier_aten_mode(self) -> None:
"""Test that get_verifier returns None for ATEN mode."""
builder = ProgramBuilder(mode=IrMode.ATEN)
verifiers = builder.get_verifiers()
self.assertIsNone(verifiers)

def test_get_verifier_default_mode(self) -> None:
"""Test that get_verifier returns EXIREdgeDialectVerifier for default mode."""
builder = ProgramBuilder() # Should default to EXIR
self.assertEqual(builder.mode, IrMode.EXIR)
verifiers = builder.get_verifiers()
self.assertIsNotNone(verifiers)
self.assertEqual(len(verifiers), 1)

def test_aten_add_tensor_exir_mode(self) -> None:
"""Test using torch.ops.aten.add.Tensor with EXIR mode."""
inp = torch.randn([3, 5])
buffer = torch.randn([5])

builder = ProgramBuilder(mode=IrMode.EXIR)
inp_proxy = builder.placeholder("inp", inp)
buffer_proxy = builder.placeholder(
"buffer", buffer, input_kind=InputKind.BUFFER
)
add = builder.call_operator(
torch.ops.aten.add.Tensor, (inp_proxy, buffer_proxy)
)
builder.output([add])
builder.get_program()

def test_aten_add_tensor_aten_mode(self) -> None:
"""Test using torch.ops.aten.add.Tensor with ATEN mode."""
inp = torch.randn([3, 5])
buffer = torch.randn([5])

builder = ProgramBuilder(mode=IrMode.ATEN)
inp_proxy = builder.placeholder("inp", inp)
buffer_proxy = builder.placeholder(
"buffer", buffer, input_kind=InputKind.BUFFER
)
add = builder.call_operator(
torch.ops.aten.add.Tensor, (inp_proxy, buffer_proxy)
)
builder.output([add])
program = builder.get_program()

# Verify the program was created successfully
self.assertEqual(len(program.graph_signature.input_specs), 2)
self.assertEqual(len(program.graph_signature.output_specs), 1)
self.assertEqual(builder.mode, IrMode.ATEN)

def test_exir_edge_aten_add_tensor_exir_mode(self) -> None:
"""Test using exir_ops.edge.aten.add.Tensor with EXIR mode."""
inp = torch.randn([3, 5])
buffer = torch.randn([5])

builder_exir = ProgramBuilder(mode=IrMode.EXIR)
inp_proxy_exir = builder_exir.placeholder("inp", inp)
buffer_proxy_exir = builder_exir.placeholder(
"buffer", buffer, input_kind=InputKind.BUFFER
)
add_exir = builder_exir.call_operator(
exir_ops.edge.aten.add.Tensor, (inp_proxy_exir, buffer_proxy_exir)
)
builder_exir.output([add_exir])
program_exir = builder_exir.get_program()

# Verify the program was created successfully
self.assertEqual(len(program_exir.graph_signature.input_specs), 2)
self.assertEqual(len(program_exir.graph_signature.output_specs), 1)
self.assertEqual(builder_exir.mode, IrMode.EXIR)

def test_exir_edge_aten_add_tensor_aten_mode(self) -> None:
"""Test using exir_ops.edge.aten.add.Tensor with ATEN mode."""
inp = torch.randn([3, 5])
buffer = torch.randn([5])

builder_aten = ProgramBuilder(mode=IrMode.ATEN)
inp_proxy_aten = builder_aten.placeholder("inp", inp)
buffer_proxy_aten = builder_aten.placeholder(
"buffer", buffer, input_kind=InputKind.BUFFER
)
add_aten = builder_aten.call_operator(
exir_ops.edge.aten.add.Tensor, (inp_proxy_aten, buffer_proxy_aten)
)
builder_aten.output([add_aten])

with self.assertRaises(
SpecViolationError, msg="Operator '<EdgeOpOverload: aten.add.Tensor>"
):
builder_aten.get_program()
Loading