10 changes: 6 additions & 4 deletions backends/arm/test/passes/test_broadcast_args_pass.py
@@ -4,25 +4,27 @@
# LICENSE file in the root directory of this source tree.

import operator
from typing import Tuple
from typing import Callable, Tuple

import torch
from executorch.backends.arm._passes import BroadcastArgsPass

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline

input_t = Tuple[torch.Tensor] # Input x
input_t = Tuple[torch.Tensor, torch.Tensor]


class NeedsMultipleBroadcastsModel(torch.nn.Module):
test_data = (torch.rand(1, 10), torch.rand(10, 1))

def __init__(self, op: operator):
def __init__(
self, op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) -> None:
self.op = op
super().__init__()

def forward(self, x: torch.Tensor, y: torch.Tensor):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return self.op(x, y)
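
A minimal sketch (not part of the diff) of the `Callable`-typed pattern this change adopts: annotating `op` as `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]` admits any binary tensor function, such as `operator.add` or `operator.mul`. The class name below is a hypothetical stand-in for the test model.

```python
import operator
from typing import Callable

import torch


class _BinaryOpModel(torch.nn.Module):  # hypothetical stand-in
    def __init__(
        self, op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
    ) -> None:
        super().__init__()
        self.op = op

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return self.op(x, y)


# Shapes (1, 10) and (10, 1) broadcast to (10, 10), matching test_data above.
out = _BinaryOpModel(operator.add)(torch.rand(1, 10), torch.rand(10, 1))
assert out.shape == (10, 10)
```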


2 changes: 1 addition & 1 deletion backends/arm/test/passes/test_cast_int64_pass.py
@@ -21,7 +21,7 @@ class Int64Model(torch.nn.Module):
"rand": (torch.rand(4),),
}

def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + 3


@@ -20,17 +20,17 @@ class Expand(torch.nn.Module):
Basic expand model using torch.Tensor.expand function
"""

def __init__(self):
super(Expand, self).__init__()
def __init__(self) -> None:
super().__init__()

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.expand(3, 4)

def get_inputs(self) -> input_t:
return (torch.rand(3, 1),)
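
For context, a short sketch of the equivalence this pass presumably relies on (my inference from the test name, not stated in the diff): expanding a size-1 dimension yields the same values as repeating it.

```python
import torch

x = torch.rand(3, 1)
# x.expand(3, 4) broadcasts the size-1 dim; x.repeat(1, 4) tiles it 4 times.
assert torch.equal(x.expand(3, 4), x.repeat(1, 4))
```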


def test_expand_to_repeat_tosa_INT():
def test_expand_to_repeat_tosa_INT() -> None:
module = Expand()
pipeline = PassPipeline[input_t](
module,
92 changes: 53 additions & 39 deletions backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py
@@ -3,7 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple, Union
from typing import Callable, ClassVar, Dict, Tuple, Union

import pytest

@@ -22,18 +22,21 @@
input_t1 = Tuple[torch.Tensor] # Input x
input_t2 = Tuple[torch.Tensor, torch.Tensor] # Input x, y

Scalar = Union[bool, float, int]
ArangeNoneParam = Tuple[Callable[[], input_t1], Tuple[Scalar, Scalar, Scalar]]
FullNoneParam = Tuple[Callable[[], input_t1], Tuple[Tuple[int, ...], Scalar]]
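
As an illustration of what these aliases encode (values made up to mirror the test data below, and assuming the aliases above are in scope): a zero-argument factory producing the pipeline inputs, paired with the scalars handed to the module constructor.

```python
example: ArangeNoneParam = (lambda: (torch.randn(10, 1),), (0, 10, 1))
input_factory, init_data = example
(x,) = input_factory()         # factory -> Tuple[torch.Tensor]
start, stop, step = init_data  # forwarded to torch.arange(start, stop, step)
```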


#####################################################
## Test arange(dtype=int64) -> arange(dtype=int32) ##
#####################################################


class ArangeDefaultIncrementViewLessThan(torch.nn.Module):

def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return (torch.arange(10, dtype=torch.int64) + 1).view(-1, 1) < x

test_data = {
test_data: ClassVar[Dict[str, input_t1]] = {
"randint": (
torch.randint(
0,
@@ -46,7 +49,9 @@ def forward(self, x: torch.Tensor):


@common.parametrize("test_data", ArangeDefaultIncrementViewLessThan.test_data)
def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1):
def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP(
test_data: input_t1,
) -> None:
module = ArangeDefaultIncrementViewLessThan()
aten_ops_checks = [
"torch.ops.aten.lt.Tensor",
@@ -67,7 +72,9 @@ def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP(test_data: inp


@common.parametrize("test_data", ArangeDefaultIncrementViewLessThan.test_data)
def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1):
def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT(
test_data: input_t1,
) -> None:
module = ArangeDefaultIncrementViewLessThan()
aten_ops_checks = [
"torch.ops.aten.lt.Tensor",
@@ -88,11 +95,10 @@ def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT(test_data: in


class ArangeStartIncrementViewLessThan(torch.nn.Module):

def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return (torch.arange(0, 10, dtype=torch.int64) + 1).view(-1, 1) < x

test_data = {
test_data: ClassVar[Dict[str, input_t1]] = {
"randint": (
torch.randint(
0,
@@ -105,7 +111,9 @@ def forward(self, x: torch.Tensor):


@common.parametrize("test_data", ArangeStartIncrementViewLessThan.test_data)
def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1):
def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP(
test_data: input_t1,
) -> None:
module = ArangeStartIncrementViewLessThan()
aten_ops_checks = [
"torch.ops.aten.lt.Tensor",
@@ -126,7 +134,9 @@ def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP(test_data: input


@common.parametrize("test_data", ArangeStartIncrementViewLessThan.test_data)
def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1):
def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT(
test_data: input_t1,
) -> None:
module = ArangeStartIncrementViewLessThan()
aten_ops_checks = [
"torch.ops.aten.lt.Tensor",
@@ -147,11 +157,10 @@ def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT(test_data: inpu


class ArangeStartStepIncrementViewLessThan(torch.nn.Module):

def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return (torch.arange(0, 10, 2, dtype=torch.int64) + 1).view(-1, 1) < x

test_data = {
test_data: ClassVar[Dict[str, input_t1]] = {
"randint": (
torch.randint(
0,
@@ -166,7 +175,7 @@ def forward(self, x: torch.Tensor):
@common.parametrize("test_data", ArangeStartStepIncrementViewLessThan.test_data)
def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_FP(
test_data: input_t1,
):
) -> None:
module = ArangeStartStepIncrementViewLessThan()
aten_ops_checks = [
"torch.ops.aten.lt.Tensor",
Expand All @@ -189,7 +198,7 @@ def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_FP(
@common.parametrize("test_data", ArangeStartStepIncrementViewLessThan.test_data)
def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_INT(
test_data: input_t1,
):
) -> None:
module = ArangeStartStepIncrementViewLessThan()
aten_ops_checks = [
"torch.ops.aten.lt.Tensor",
@@ -225,7 +234,7 @@ def __init__(self, start: float, stop: float, step: float):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.arange(*self.args) + x

test_data = {
test_data: ClassVar[Dict[str, ArangeNoneParam]] = {
"int64": (lambda: (torch.randn(10, 1),), (0, 10, 1)),
"float32_start": (lambda: (torch.randn(10, 1),), (0.0, 10, 1)),
"float32_stop": (lambda: (torch.randn(10, 1),), (0, 10.0, 1)),
@@ -238,23 +247,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


@common.parametrize("test_data", ArangeAddDtypeNone.test_data)
def test_arange_dtype_none_tosa_FP(test_data):
input_data, init_data = test_data
def test_arange_dtype_none_tosa_FP(test_data: ArangeNoneParam) -> None:
input_factory, init_data = test_data
pipeline = TosaPipelineFP[input_t1](
ArangeAddDtypeNone(*init_data),
input_data(),
input_factory(),
ArangeAddDtypeNone.aten_op,
ArangeAddDtypeNone.exir_op,
)
pipeline.run()
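
One line of context on why the `dtype=None` cases split on argument types (a torch behavior, not diff content): with `dtype=None`, `torch.arange` infers `int64` when start, stop, and step are all Python ints, and the default float dtype as soon as any of them is a float.

```python
import torch

# All-integer arguments infer int64; any float argument infers the default
# float dtype (float32 unless torch.set_default_dtype was changed).
assert torch.arange(0, 10, 1).dtype == torch.int64
assert torch.arange(0.0, 10, 1).dtype == torch.float32
```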


@common.parametrize("test_data", ArangeAddDtypeNone.test_data)
def test_arange_dtype_none_tosa_INT(test_data):
input_data, init_data = test_data
def test_arange_dtype_none_tosa_INT(test_data: ArangeNoneParam) -> None:
input_factory, init_data = test_data
pipeline = TosaPipelineINT[input_t1](
ArangeAddDtypeNone(*init_data),
input_data(),
input_factory(),
ArangeAddDtypeNone.aten_op,
ArangeAddDtypeNone.exir_op,
)
@@ -268,8 +277,7 @@ def test_arange_dtype_none_tosa_INT(test_data):


class FullIncrementViewMulXLessThanY(torch.nn.Module):

def forward(self, x: torch.Tensor, y: torch.Tensor):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return (
(
torch.full(
@@ -286,7 +294,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
* x
) < y

test_data = {
test_data: ClassVar[Dict[str, input_t2]] = {
"randint": (
torch.randint(
0,
@@ -305,7 +313,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):


@common.parametrize("test_data", FullIncrementViewMulXLessThanY.test_data)
def test_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1):
def test_convert_full_int64_dtype_to_int32_pass_tosa_FP(
test_data: input_t2,
) -> None:
"""
There are four int64 placeholders in the original graph:
1. _lifted_tensor_constant0: 1
@@ -347,7 +357,9 @@ def test_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1):


@common.parametrize("test_data", FullIncrementViewMulXLessThanY.test_data)
def test_convert_full_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1):
def test_convert_full_int64_dtype_to_int32_pass_tosa_INT(
test_data: input_t2,
) -> None:
"""
For INT profile, _lifted_tensor_constant0 is still int64 after applying ConvertInt64ConstOpsToInt32Pass().
And an int64->int32 cast is inserted at the beginning of the graph.
@@ -380,8 +392,7 @@ def test_convert_full_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1):


class RejectFullIncrementViewMulXLessThanY(torch.nn.Module):

def forward(self, x: torch.Tensor, y: torch.Tensor):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return (
(
torch.full(
@@ -398,7 +409,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
* x
) < y

test_data = {
test_data: ClassVar[Dict[str, input_t2]] = {
"randint": (
torch.randint(
0,
@@ -420,7 +431,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
@pytest.mark.xfail(
reason="MLETORCH-1254: Add operator support check for aten.arange and aten.full"
)
def test_reject_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1):
def test_reject_convert_full_int64_dtype_to_int32_pass_tosa_FP(
test_data: input_t2,
) -> None:
module = RejectFullIncrementViewMulXLessThanY()
aten_ops_checks = [
"torch.ops.aten.full.default",
@@ -469,23 +482,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


@common.parametrize("test_data", AddConstFullDtypeNone.test_data)
def test_full_dtype_none_tosa_FP(test_data):
input_data, init_data = test_data
def test_full_dtype_none_tosa_FP(test_data: FullNoneParam) -> None:
input_factory, init_data = test_data
pipeline = TosaPipelineFP[input_t1](
AddConstFullDtypeNone(*init_data),
input_data(),
input_factory(),
aten_op=[],
exir_op=AddConstFullDtypeNone.exir_op,
)
pipeline.run()
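
Analogous context for `torch.full` with `dtype=None` (torch behavior on recent releases, not diff content): the Python type of the fill value drives the inferred dtype, which is what the FP, bool, and INT variants here split on.

```python
import torch

# The fill value's Python type selects the dtype when dtype=None.
assert torch.full((2, 2), 1).dtype == torch.int64       # int -> int64
assert torch.full((2, 2), 1.0).dtype == torch.float32   # float -> default float
assert torch.full((2, 2), True).dtype == torch.bool     # bool -> bool
```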


@common.parametrize("test_data", AddConstFullDtypeNone.test_data_bool)
def test_full_dtype_none_tosa_FP_bool(test_data):
input_data, init_data = test_data
def test_full_dtype_none_tosa_FP_bool(test_data: FullNoneParam) -> None:
input_factory, init_data = test_data
pipeline = TosaPipelineFP[input_t1](
AddConstFullDtypeNone(*init_data),
input_data(),
input_factory(),
aten_op=[],
exir_op=AddConstFullDtypeNone.exir_op,
)
@@ -501,9 +514,10 @@ def test_full_dtype_none_tosa_FP_bool(test_data):
)
def test_full_dtype_none_tosa_INT(test_data):
input_data, init_data = test_data
input_factory, init_data = test_data
pipeline = TosaPipelineINT[input_t1](
AddConstFullDtypeNone(*init_data),
input_data(),
input_factory(),
aten_op=[],
exir_op=AddConstFullDtypeNone.exir_op,
)