Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions backends/arm/test/misc/test_tosa_dialect_shape_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,55 @@ def test_mul_mixed_shape():
assert _expr_equals(result[0], sympy.Integer(3) * sympy.Symbol("s0"))


# Test MOD_SHAPE with constant values, which should perform modulo and return a constant shape.
def test_mod_const_shape_no_target():
shape_env = ShapeEnv()
with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
), FakeTensorMode():
const_0 = exir_ops.backend.tosa.CONST_SHAPE.default([8, 21])
const_1 = exir_ops.backend.tosa.CONST_SHAPE.default([3, 5])
result = exir_ops.backend.tosa.MOD_SHAPE.default(const_0, const_1)
assert len(result) == 2
assert result == [2, 1]


# Test MOD_SHAPE with symbolic values, which should produce a Mod expression.
def test_mod_symbolic_shape_no_target():
shape_env = ShapeEnv()
s0 = _make_symint(shape_env, "s0", hint=8)
s1 = _make_symint(shape_env, "s1", hint=3)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
), FakeTensorMode(shape_env=shape_env) as mode:
s0_tensor = torch.empty(size=(1, 3, s0))
s1_tensor = torch.empty(size=(1, 3, s1))
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
dim_s1 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s1_tensor), axis=2)
result = exir_ops.backend.tosa.MOD_SHAPE.default(dim_s0, dim_s1)
assert len(result) == 1
assert isinstance(result[0], torch.SymInt)
assert _expr_equals(result[0], sympy.Mod(sympy.Symbol("s0"), sympy.Symbol("s1")))


def test_mod_mixed_shape_no_target():
shape_env = ShapeEnv()
s0 = _make_symint(shape_env, "s0", hint=4)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
), FakeTensorMode(shape_env=shape_env) as mode:
const_shape = exir_ops.backend.tosa.CONST_SHAPE.default([8])
s0_tensor = torch.empty(size=(1, 3, s0))
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
result = exir_ops.backend.tosa.MOD_SHAPE.default(const_shape, dim_s0)

assert len(result) == 1
assert isinstance(result[0], torch.SymInt)
assert _expr_equals(result[0], sympy.Mod(sympy.Integer(8), sympy.Symbol("s0")))


# Test DIV_FLOOR_SHAPE with constant values, which should perform floor division and return a constant shape.
def test_div_floor_const_shape():
shape_env = ShapeEnv()
Expand Down
15 changes: 15 additions & 0 deletions backends/arm/tosa/dialect/ops/shape_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,18 @@ def MUL_SHAPE(
"""

return _combine_shapes(shape1, shape2, lambda a, b: a * b)


@register_fake_tosa_op(
"MOD_SHAPE(SymInt[] shape1, SymInt[] shape2) -> SymInt[]", # schema
TosaSpecification.all_profiles_for_version("1.1"),
)
def MOD_SHAPE(
shape1: list[IntLikeType],
shape2: list[IntLikeType],
) -> list[IntLikeType]:
"""MOD_SHAPE operator computes the element-wise modulo of the first shape
tensor by the second.
"""

return _combine_shapes(shape1, shape2, lambda a, b: a % b)
Loading