Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,45 @@ def forward(self, x, y):
)
edge_manager.to_executorch()

def test_data_dependent(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch.library.define(
"mylib::foo1",
"(Tensor a, Tensor b) -> Tensor",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)

@torch.library.impl("mylib::foo1", "cpu", lib=lib)
def foo_impl(a, b):
return a + b

@torch.library.register_fake("mylib::foo1", lib=lib)
def mylib_foo_default_fake(*args, **kwargs):
ctx = torch.library.get_ctx()
fake_shape = ctx.new_dynamic_size()
return torch.empty(fake_shape, dtype=torch.float32, device="cpu")

class M(torch.nn.Module):
def forward(self, a, b, c):
res = torch.ops.mylib.foo1(a, b)

c_item = c.item()
torch._check_is_size(c_item)
torch._check(c_item < res.shape[0])
return res[:c_item]

inp = (torch.randn(10), torch.randn(10), torch.tensor(3))

ep = export(M(), inp)
edge = to_edge(ep)
self.assertTrue(
torch.allclose(
edge.exported_program().module()(*inp),
M()(*inp),
)
)

def test_edge_manager_transform(self):
edge_manager: EdgeProgramManager = to_edge(
get_exported_programs(), get_config_methods()
Expand Down
34 changes: 0 additions & 34 deletions exir/verification/test/test_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,40 +36,6 @@ def test_edge_verifier_check_valid_op_succeed_given_custom_op(self) -> None:
verifier.check_valid_edge_op(edge_op)
verifier.check_valid_op(edge_op)

def test_edge_verifier_enablement(self) -> None:
class M(torch.nn.Module):
def forward(self, x, y):
z = y.item()
torch._check(z > 0)
torch._check(z < 4)
return x[z : z + y.shape[0]]

ep = torch.export.export(M(), (torch.randn(10), torch.tensor([3])), strict=True)

compile_config_with_disable_ir_validity = EdgeCompileConfig(
_check_ir_validity=False
)
edge_manager = to_edge(
ep, compile_config=compile_config_with_disable_ir_validity
)

normal_verifier = EXIREdgeDialectVerifier()
disable_ir_validity_verifier = EXIREdgeDialectVerifier(
compile_config_with_disable_ir_validity
)

# exported model can not pass normal verifier due to
# aten.sym_constrain_range.default is illegal to be edge op
with self.assertRaises(SpecViolationError):
normal_verifier(edge_manager.exported_program())

# exported model can pass disable_ir_validity_verifier due to verifier
# is disabled by compile_config_with_disable_ir_validity
# (_check_ir_validity=False). Noted that this verifation has been done
# when calling `to_edge`. Explicitly calling verifier here just for better
# demonstration and is unnecessary in real world for ir verification.
disable_ir_validity_verifier(edge_manager.exported_program())

def test_edge_verifier_check_edge_op(self) -> None:
class Model(torch.nn.Module):
def __init__(self):
Expand Down
36 changes: 19 additions & 17 deletions exir/verification/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from executorch.exir.error import ExportError, ExportErrorType
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
from executorch.exir.passes.replace_aten_with_edge_pass import DISALLOW_LIST
from executorch.exir.verification.arg_validator import (
EdgeOpArgValidator,
RunHigherOrderOperatorError,
Expand Down Expand Up @@ -99,16 +101,20 @@ def __init__(self) -> None:
self._exception_list = exception_list if exception_list else []

def _get_exception_list(self) -> List[torch._ops.OpOverload]:
exception_list = [
torch.ops.aten.mkldnn_rnn_layer.default,
torch.ops.aten._upsample_bilinear2d_aa.default,
torch.ops.aten.quantize_per_tensor.default,
torch.ops.aten.dequantize.self,
torch.ops.aten.max.default, # TODO(T188268054)
torch.ops.aten.min.default, # TODO(T188268054)
torch.ops.aten.full_like.default, # TODO(T183507359)
]
exception_list += self._exception_list
exception_list = (
[
torch.ops.aten.mkldnn_rnn_layer.default,
torch.ops.aten._upsample_bilinear2d_aa.default,
torch.ops.aten.quantize_per_tensor.default,
torch.ops.aten.dequantize.self,
torch.ops.aten.max.default, # TODO(T188268054)
torch.ops.aten.min.default, # TODO(T188268054)
torch.ops.aten.full_like.default, # TODO(T183507359)
]
+ list(_EXECUTORCH_SYM_OPS)
+ DISALLOW_LIST
+ self._exception_list
)

return exception_list

Expand Down Expand Up @@ -249,13 +255,9 @@ def check_valid_edge_op(self, op):
return
if (
op
in [
operator.getitem,
torch.ops.aten.sym_size.int,
torch.ops.aten.scalar_tensor.default,
torch.ops.aten._assert_async.msg,
torch.ops.aten._assert_scalar.default,
]
in [operator.getitem]
+ DISALLOW_LIST
+ list(_EXECUTORCH_SYM_OPS)
+ self._exception_list
):
return
Expand Down