diff --git a/exir/pass_base.py b/exir/pass_base.py index db6bef8e3f6..9c97921f51b 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -318,7 +318,11 @@ def call_function( if target == operator.getitem: value, key = args return self.callback.call_getitem(value, key, meta) - elif getattr(target, "__module__", None) in {"_operator", "math"}: + elif getattr(target, "__module__", None) in { + "_operator", + "builtins", + "math", + }: assert callable(target) return self.callback.call_sym(target, args, meta) elif target in _TORCH_SYM_OPS: diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 7a0623040f8..fdb954010ca 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -339,7 +339,7 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule: self.call(get_submodule(node.args[0])) self.call(get_submodule(node.args[1])) continue - elif getattr(target, "__module__", None) == "_operator": + elif getattr(target, "__module__", None) in ("builtins", "_operator"): continue elif target in to_out_var_skiplist: continue diff --git a/exir/passes/executorch_prim_ops_registry.py b/exir/passes/executorch_prim_ops_registry.py index 4af233aaa66..fa1c2e6913f 100644 --- a/exir/passes/executorch_prim_ops_registry.py +++ b/exir/passes/executorch_prim_ops_registry.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import builtins import math import operator -from typing import Dict, Set, Union +from typing import Any, Dict, Set, Union # necessary to ensure the ops are registered import torch @@ -94,12 +95,24 @@ def neg(a: _SymScalar) -> _SymScalar: return -a # pyre-ignore +@bind_pattern_to_op(executorch_prims_lib, "ceil.Scalar(Scalar a) -> Scalar") +def ceil(a: _SymScalar) -> _SymScalar: + return math.ceil(a) # pyre-ignore + + +@bind_pattern_to_op(executorch_prims_lib, "round.Scalar(Scalar a) -> Scalar") +def builtin_round(a: _SymScalar) -> _SymScalar: + return round(a) # pyre-ignore + + @bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar") def trunc(a: _SymScalar) -> _SymScalar: return math.trunc(a) # pyre-ignore -_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = { +_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = { + builtins.round: ops.backend.executorch_prim.round.Scalar, + math.ceil: ops.backend.executorch_prim.ceil.Scalar, math.trunc: ops.backend.executorch_prim.trunc.Scalar, operator.sub: ops.backend.executorch_prim.sub.Scalar, operator.mul: ops.backend.executorch_prim.mul.Scalar, diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp index 5755ab8d66e..38901bb8407 100644 --- a/kernels/prim_ops/register_prim_ops.cpp +++ b/kernels/prim_ops/register_prim_ops.cpp @@ -303,6 +303,51 @@ static Kernel prim_ops[] = { } }), + // ceil.Scalar(Scalar a) -> Scalar + Kernel( + "executorch_prim::ceil.Scalar", + [](KernelRuntimeContext& context, EValue** stack) { + (void)context; + EValue& a = *stack[0]; + EValue& out = *stack[1]; + if (a.isDouble()) { + out = EValue(static_cast<int64_t>(ceil(a.toDouble()))); + } else { + ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag); + } + }), + + // round.Scalar(Scalar a) -> Scalar + Kernel( + "executorch_prim::round.Scalar", + [](KernelRuntimeContext& context, EValue** stack) { + (void)context; + EValue& a = *stack[0]; + EValue& out = *stack[1]; + 
if (a.isDouble()) { + // Round half to even to match Python round(). Need an explicit + // implementation as not all platforms support fenv rounding modes. + // See + // https://codeyarns.com/tech/2018-08-17-how-to-round-half-to-even.html + const auto val = a.toDouble(); + const auto r = round(val); + const auto d = r - val; + auto res = 0.0; + + if (std::abs(d) != 0.5) { + res = r; + } else if (fmod(r, 2.0) == 0.0) { + res = r; + } else { + res = val - d; + } + + out = EValue(static_cast<int64_t>(res)); + } else { + ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag); + } + }), + // trunc.Scalar(Scalar a) -> Scalar Kernel( "executorch_prim::trunc.Scalar", diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp index 3581a470da7..ab6bd28e6cc 100644 --- a/kernels/prim_ops/test/prim_ops_test.cpp +++ b/kernels/prim_ops/test/prim_ops_test.cpp @@ -503,6 +503,47 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) { getOpsFn("executorch_prim::et_view.default")(context, bad_stack), ""); } +TEST_F(RegisterPrimOpsTest, TestCeil) { + std::array<double, 10> inputs = { + 0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999}; + std::array<int64_t, 10> expected = {0, 1, 1, 1, 1, 2, 0, -1, -1, 10}; + + for (auto i = 0; i < inputs.size(); i++) { + EValue values[2]; + values[0] = EValue(inputs[i]); + values[1] = EValue(0.0); + + EValue* stack[2]; + for (size_t j = 0; j < 2; j++) { + stack[j] = &values[j]; + } + + getOpsFn("executorch_prim::ceil.Scalar")(context, stack); + EXPECT_EQ(stack[1]->toInt(), expected[i]); + } +} + +TEST_F(RegisterPrimOpsTest, TestRound) { + // Note that Python uses round-to-even for halfway values. 
+ std::array<double, 10> inputs = { + 0.0, 0.25, 0.5, 0.75, 1.0, 1.5, -0.5, -1.0, -1.5, 9.999999}; + std::array<int64_t, 10> expected = {0, 0, 0, 1, 1, 2, 0, -1, -2, 10}; + + for (auto i = 0; i < inputs.size(); i++) { + EValue values[2]; + values[0] = EValue(inputs[i]); + values[1] = EValue(0.0); + + EValue* stack[2]; + for (size_t j = 0; j < 2; j++) { + stack[j] = &values[j]; + } + + getOpsFn("executorch_prim::round.Scalar")(context, stack); + EXPECT_EQ(stack[1]->toInt(), expected[i]); + } +} + TEST_F(RegisterPrimOpsTest, TestTrunc) { std::array<double, 10> inputs = { 0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999};