From 04478d368dd3523f5dd3c86fa324cd3f5a3c6cea Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Mon, 29 Sep 2025 12:04:22 -0600 Subject: [PATCH] [Backend Tester] Migrate to pytest (#14456) Refactor the backend test suites to use pytest. This includes the following changes: * Define pytest markers for each backend and test flow (recipe). This allows for easy filter, such as by running `pytest some/path/... -m backend_xnnpack`. * Use a parameterized pytest fixture to handle test generation / expansion for each test flow. * Switch to using the pytest-json-report plugin for reporting. Update the markdown generation script to take json. * Shim the existing unittest-based logic for op tests. * I've updated add.py to show what they should look like long-term. I've also just updated the model tests, since there aren't as many. I'll update the remaining op tests later in this stack, though this is purely to clean up the code. The shimming logic makes them work properly with pytest in this PR. * Update the backend test CI to use pytest. This also has the benefit of making the jobs much faster by leveraging parallel execution. I've also added a repro command to the markdown summary. (cherry picked from commit d09dd798ff340983aa11fb63de4e07cacea787e3) --- ...{test_backend_linux.sh => test_backend.sh} | 27 +- .ci/scripts/test_backend_macos.sh | 30 -- .github/workflows/_test_backend.yml | 4 +- backends/test/suite/__init__.py | 6 + backends/test/suite/conftest.py | 182 ++++++++++ backends/test/suite/flow.py | 3 + .../suite/generate_markdown_summary_json.py | 229 +++++++++++++ backends/test/suite/models/__init__.py | 133 -------- backends/test/suite/models/test_torchaudio.py | 122 +++---- .../test/suite/models/test_torchvision.py | 320 ++++++++++-------- backends/test/suite/operators/__init__.py | 135 ++------ backends/test/suite/operators/test_add.py | 109 +++--- backends/test/suite/operators/test_sub.py | 1 - pyproject.toml | 1 + 14 files changed, 757 insertions(+), 545 deletions(-) rename .ci/scripts/{test_backend_linux.sh => test_backend.sh} (64%) delete mode 100755 .ci/scripts/test_backend_macos.sh create mode 100644 backends/test/suite/conftest.py create mode 100644 backends/test/suite/generate_markdown_summary_json.py diff --git a/.ci/scripts/test_backend_linux.sh b/.ci/scripts/test_backend.sh similarity index 64% rename from .ci/scripts/test_backend_linux.sh rename to .ci/scripts/test_backend.sh index d230860875d..df98fb43372 100755 --- a/.ci/scripts/test_backend_linux.sh +++ b/.ci/scripts/test_backend.sh @@ -10,16 +10,26 @@ SUITE=$1 FLOW=$2 ARTIFACT_DIR=$3 -REPORT_FILE="$ARTIFACT_DIR/test-report-$FLOW-$SUITE.csv" +REPORT_FILE="$ARTIFACT_DIR/test-report-$FLOW-$SUITE.json" echo "Running backend test job for suite $SUITE, flow $FLOW." echo "Saving job artifacts to $ARTIFACT_DIR." -# The generic Linux job chooses to use base env, not the one setup by the image eval "$(conda shell.bash hook)" CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" +if [[ "$(uname)" == "Darwin" ]]; then + bash .ci/scripts/setup-conda.sh + eval "$(conda shell.bash hook)" + CONDA_RUN_CMD="${CONDA_RUN} --no-capture-output" + ${CONDA_RUN_CMD} pip install awscli==1.37.21 + IS_MACOS=1 +else + CONDA_RUN_CMD="" + IS_MACOS=0 +fi + export PYTHON_EXECUTABLE=python # CMake options to use, in addition to the defaults. @@ -50,11 +60,14 @@ if [[ "$FLOW" == *arm* ]]; then .ci/scripts/setup-arm-baremetal-tools.sh fi -# We need the runner to test the built library. -PYTHON_EXECUTABLE=python CMAKE_ARGS="$EXTRA_BUILD_ARGS" .ci/scripts/setup-linux.sh --build-tool cmake --build-mode Release --editable true +if [[ $IS_MACOS -eq 1 ]]; then + SETUP_SCRIPT=.ci/scripts/setup-macos.sh +else + SETUP_SCRIPT=.ci/scripts/setup-linux.sh +fi +CMAKE_ARGS="$EXTRA_BUILD_ARGS" ${CONDA_RUN_CMD} $SETUP_SCRIPT --build-tool cmake --build-mode Release --editable true EXIT_CODE=0 -python -m executorch.backends.test.suite.runner $SUITE --flow $FLOW --report "$REPORT_FILE" || EXIT_CODE=$? - +${CONDA_RUN_CMD} pytest -c /dev/nul -n auto backends/test/suite/$SUITE/ -m flow_$FLOW --json-report --json-report-file="$REPORT_FILE" || EXIT_CODE=$? # Generate markdown summary. -python -m executorch.backends.test.suite.generate_markdown_summary "$REPORT_FILE" > ${GITHUB_STEP_SUMMARY:-"step_summary.md"} --exit-code $EXIT_CODE +${CONDA_RUN_CMD} python -m executorch.backends.test.suite.generate_markdown_summary_json "$REPORT_FILE" > ${GITHUB_STEP_SUMMARY:-"step_summary.md"} --exit-code $EXIT_CODE diff --git a/.ci/scripts/test_backend_macos.sh b/.ci/scripts/test_backend_macos.sh deleted file mode 100755 index c31fd504b03..00000000000 --- a/.ci/scripts/test_backend_macos.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -set -eux - -SUITE=$1 -FLOW=$2 -ARTIFACT_DIR=$3 - -REPORT_FILE="$ARTIFACT_DIR/test-report-$FLOW-$SUITE.csv" - -echo "Running backend test job for suite $SUITE, flow $FLOW." -echo "Saving job artifacts to $ARTIFACT_DIR." - -${CONDA_RUN} --no-capture-output pip install awscli==1.37.21 - -bash .ci/scripts/setup-conda.sh -eval "$(conda shell.bash hook)" - -PYTHON_EXECUTABLE=python -${CONDA_RUN} --no-capture-output .ci/scripts/setup-macos.sh --build-tool cmake --build-mode Release - -EXIT_CODE=0 -${CONDA_RUN} --no-capture-output python -m executorch.backends.test.suite.runner $SUITE --flow $FLOW --report "$REPORT_FILE" || EXIT_CODE=$? - -# Generate markdown summary. -${CONDA_RUN} --no-capture-output python -m executorch.backends.test.suite.generate_markdown_summary "$REPORT_FILE" > ${GITHUB_STEP_SUMMARY:-"step_summary.md"} --exit-code $EXIT_CODE diff --git a/.github/workflows/_test_backend.yml b/.github/workflows/_test_backend.yml index 64ade2d84ad..42c00155d57 100644 --- a/.github/workflows/_test_backend.yml +++ b/.github/workflows/_test_backend.yml @@ -57,7 +57,7 @@ jobs: script: | set -eux - source .ci/scripts/test_backend_linux.sh "${{ matrix.suite }}" "${{ matrix.flow }}" "${RUNNER_ARTIFACT_DIR}" + source .ci/scripts/test_backend.sh "${{ matrix.suite }}" "${{ matrix.flow }}" "${RUNNER_ARTIFACT_DIR}" test-backend-macos: if: ${{ inputs.run-macos }} @@ -81,4 +81,4 @@ jobs: # This is needed to get the prebuilt PyTorch wheel from S3 ${CONDA_RUN} --no-capture-output pip install awscli==1.37.21 - source .ci/scripts/test_backend_macos.sh "${{ matrix.suite }}" "${{ matrix.flow }}" "${RUNNER_ARTIFACT_DIR}" + source .ci/scripts/test_backend.sh "${{ matrix.suite }}" "${{ matrix.flow }}" "${RUNNER_ARTIFACT_DIR}" diff --git a/backends/test/suite/__init__.py b/backends/test/suite/__init__.py index 43d4e16818f..734a6690fd2 100644 --- a/backends/test/suite/__init__.py +++ b/backends/test/suite/__init__.py @@ -11,6 +11,7 @@ import os import executorch.backends.test.suite.flow +import torch from executorch.backends.test.suite.flow import TestFlow from executorch.backends.test.suite.runner import runner_main @@ -55,6 +56,11 @@ def get_test_flows() -> dict[str, TestFlow]: return _ALL_TEST_FLOWS +def dtype_to_str(dtype: torch.dtype) -> str: + # Strip off "torch." + return str(dtype)[6:] + + def load_tests(loader, suite, pattern): package_dir = os.path.dirname(__file__) discovered_suite = loader.discover( diff --git a/backends/test/suite/conftest.py b/backends/test/suite/conftest.py new file mode 100644 index 00000000000..70a97454c4e --- /dev/null +++ b/backends/test/suite/conftest.py @@ -0,0 +1,182 @@ +from typing import Any + +import pytest +import torch + +from executorch.backends.test.suite.flow import all_flows +from executorch.backends.test.suite.reporting import _sum_op_counts +from executorch.backends.test.suite.runner import run_test + + +def pytest_configure(config): + backends = set() + + for flow in all_flows().values(): + config.addinivalue_line( + "markers", + f"flow_{flow.name}: mark a test as testing the {flow.name} flow", + ) + + if flow.backend not in backends: + config.addinivalue_line( + "markers", + f"backend_{flow.backend}: mark a test as testing the {flow.backend} backend", + ) + backends.add(flow.backend) + + +class TestRunner: + def __init__(self, flow, test_name, test_base_name): + self._flow = flow + self._test_name = test_name + self._test_base_name = test_base_name + self._subtest = 0 + self._results = [] + + def lower_and_run_model( + self, + model: torch.nn.Module, + inputs: Any, + generate_random_test_inputs=True, + dynamic_shapes=None, + ): + run_summary = run_test( + model, + inputs, + self._flow, + self._test_name, + self._test_base_name, + self._subtest, + None, + generate_random_test_inputs=generate_random_test_inputs, + dynamic_shapes=dynamic_shapes, + ) + + self._subtest += 1 + self._results.append(run_summary) + + if not run_summary.result.is_success(): + if run_summary.result.is_backend_failure(): + raise RuntimeError("Test failure.") from run_summary.error + else: + # Non-backend failure indicates a bad test. Mark as skipped. + pytest.skip( + f"Test failed for reasons other than backend failure. Error: {run_summary.error}" + ) + + +@pytest.fixture( + params=[ + pytest.param( + f, + marks=[ + getattr(pytest.mark, f"flow_{f.name}"), + getattr(pytest.mark, f"backend_{f.backend}"), + ], + ) + for f in all_flows().values() + ], + ids=str, +) +def test_runner(request): + return TestRunner(request.param, request.node.name, request.node.originalname) + + +@pytest.hookimpl(optionalhook=True) +def pytest_json_runtest_metadata(item, call): + # Store detailed results in the test report under the metadata key. + metadata = {"subtests": []} + + if hasattr(item, "funcargs") and "test_runner" in item.funcargs: + runner_instance = item.funcargs["test_runner"] + + for record in runner_instance._results: + subtest_metadata = {} + + error_message = "" + if record.error is not None: + error_str = str(record.error) + if len(error_str) > 400: + error_message = error_str[:200] + "..." + error_str[-200:] + else: + error_message = error_str + + subtest_metadata["Test ID"] = record.name + subtest_metadata["Test Case"] = record.base_name + subtest_metadata["Subtest"] = record.subtest_index + subtest_metadata["Flow"] = record.flow + subtest_metadata["Result"] = record.result.to_short_str() + subtest_metadata["Result Detail"] = record.result.to_detail_str() + subtest_metadata["Error"] = error_message + subtest_metadata["Delegated"] = "True" if record.is_delegated() else "False" + subtest_metadata["Quantize Time (s)"] = ( + f"{record.quantize_time.total_seconds():.3f}" + if record.quantize_time + else None + ) + subtest_metadata["Lower Time (s)"] = ( + f"{record.lower_time.total_seconds():.3f}" + if record.lower_time + else None + ) + + for output_idx, error_stats in enumerate(record.tensor_error_statistics): + subtest_metadata[f"Output {output_idx} Error Max"] = ( + f"{error_stats.error_max:.3f}" + ) + subtest_metadata[f"Output {output_idx} Error MAE"] = ( + f"{error_stats.error_mae:.3f}" + ) + subtest_metadata[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}" + + subtest_metadata["Delegated Nodes"] = _sum_op_counts( + record.delegated_op_counts + ) + subtest_metadata["Undelegated Nodes"] = _sum_op_counts( + record.undelegated_op_counts + ) + if record.delegated_op_counts: + subtest_metadata["Delegated Ops"] = dict(record.delegated_op_counts) + if record.undelegated_op_counts: + subtest_metadata["Undelegated Ops"] = dict(record.undelegated_op_counts) + subtest_metadata["PTE Size (Kb)"] = ( + f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else "" + ) + + metadata["subtests"].append(subtest_metadata) + return metadata + + +@pytest.hookimpl(optionalhook=True) +def pytest_json_modifyreport(json_report): + # Post-process the report, mainly to populate metadata for crashed tests. The runtest_metadata + # hook doesn't seem to be called when there's a native crash, but xdist still creates a report + # entry. + + for test_data in json_report["tests"]: + if "metadata" not in test_data: + test_data["metadata"] = {} + metadata = test_data["metadata"] + if "subtests" not in metadata: + metadata["subtests"] = [] + subtests = metadata["subtests"] + + # Native crashes are recorded differently and won't have the full metadata. + # Pytest-xdist records crash info under the "???" key. + if "???" in test_data: + test_id = test_data["nodeid"].removeprefix("::") # Remove leading :: + test_base_id = test_id.split("[")[ + 0 + ] # Strip parameterization to get the base test case + params = test_id[len(test_base_id) + 1 : -1].split("-") + flow = params[0] + + crashed_test_meta = { + "Test ID": test_id, + "Test Case": test_base_id, + "Flow": flow, + "Result": "Fail", + "Result Detail": "Process Crash", + "Error": test_data["???"].get("longrepr", "Process crashed."), + } + subtests.append(crashed_test_meta) diff --git a/backends/test/suite/flow.py b/backends/test/suite/flow.py index a4b34fee98d..05fc760683d 100644 --- a/backends/test/suite/flow.py +++ b/backends/test/suite/flow.py @@ -44,6 +44,9 @@ class TestFlow: def should_skip_test(self, test_name: str) -> bool: return any(pattern in test_name for pattern in self.skip_patterns) + def __str__(self): + return self.name + def all_flows() -> dict[str, TestFlow]: flows = [] diff --git a/backends/test/suite/generate_markdown_summary_json.py b/backends/test/suite/generate_markdown_summary_json.py new file mode 100644 index 00000000000..4b6edc2a635 --- /dev/null +++ b/backends/test/suite/generate_markdown_summary_json.py @@ -0,0 +1,229 @@ +import argparse +import json + +from dataclasses import dataclass, field + + +@dataclass +class ResultCounts: + """ + Represents aggregated result counts for each status. + """ + + total: int = 0 + passes: int = 0 + fails: int = 0 + skips: int = 0 + by_detail: dict[str, int] = field(default_factory=lambda: {}) + + def add_row(self, result_value: str, result_detail: str) -> None: + """ + Update the result counts for the specified row. + """ + + self.total += 1 + + if result_value == "Pass": + self.passes += 1 + elif result_value == "Fail": + self.fails += 1 + elif result_value == "Skip": + self.skips += 1 + else: + raise RuntimeError(f"Unknown result value {result_value}") + + if result_detail: + if result_detail not in self.by_detail: + self.by_detail[result_detail] = 0 + + self.by_detail[result_detail] += 1 + + +@dataclass +class AggregatedSummary: + """ + Represents aggegrated summary data for the test run. + """ + + counts: ResultCounts + counts_by_params: dict[str, ResultCounts] + failed_tests: list[list[str]] + + +# +# A standalone script to generate a Markdown representation of a test report. +# This is primarily intended to be used with GitHub actions to generate a nice +# representation of the test results when looking at the action run. +# +# Usage: python executorch/backends/test/suite/generate_markdown_summary.py +# Markdown is written to stdout. +# + + +def aggregate_results(json_path: str) -> AggregatedSummary: + with open(json_path) as f: + data = json.load(f) + + # Count results and prepare data + counts = ResultCounts() + failed_tests = [] + counts_by_param = {} + + for test_data in data["tests"]: + result_meta = test_data["metadata"] + for subtest_meta in result_meta["subtests"]: + result = subtest_meta["Result"] + result_detail = subtest_meta.get("Result Detail") or "" + + counts.add_row(result, result_detail) + + test_id = subtest_meta["Test ID"] + base_test = subtest_meta["Test Case"] + params = test_id[len(base_test) + 1 : -1] + + if params: + if params not in counts_by_param: + counts_by_param[params] = ResultCounts() + counts_by_param[params].add_row(result, result_detail) + + if result.lower() == "fail": + failed_tests.append(subtest_meta) + + return AggregatedSummary( + counts=counts, + failed_tests=failed_tests, + counts_by_params=counts_by_param, + ) + + +def escape_for_markdown(text: str) -> str: + """ + Modify a string to properly display in a markdown table cell. + """ + if not text: + return text + + # Replace newlines with
tags + escaped = text.replace("\n", "
") + + # Escape backslashes. + escaped = escaped.replace("\\", "\\\\") + + # Escape pipe characters that would break table structure + escaped = escaped.replace("|", "\\|") + + return escaped + + +def generate_markdown(json_path: str, exit_code: int = 0): # noqa (C901) + results = aggregate_results(json_path) + + # Generate Summary section + print("# Summary\n") + total_excluding_skips = results.counts.passes + results.counts.fails + pass_fraction = results.counts.passes / total_excluding_skips + fail_fraction = results.counts.fails / total_excluding_skips + print( + f"- **Pass**: {results.counts.passes}/{total_excluding_skips} ({pass_fraction*100:.2f}%)" + ) + print( + f"- **Fail**: {results.counts.fails}/{total_excluding_skips} ({fail_fraction*100:.2f}%)" + ) + print(f"- **Skip**: {results.counts.skips}") + + if results.counts_by_params: + print("\n## Results by Parameters\n") + + if len(results.counts_by_params) > 0: + # Create table header + header_cols = ["Params", "Pass", "Fail", "Skip", "Pass %"] + print("| " + " | ".join(header_cols) + " |") + print("|" + "|".join(["---"] * len(header_cols)) + "|") + + # Create table rows + for params_str, counts in results.counts_by_params.items(): + row_values = [params_str] + + # Add parameter values + pass_fraction = counts.passes / (counts.passes + counts.fails) + + # Add count values + row_values.extend( + [ + str(counts.passes), + str(counts.fails), + str(counts.skips), + f"{pass_fraction*100:.2f}%", + ] + ) + + print("| " + " | ".join(row_values) + " |") + + print() + + print("## Failure Breakdown:") + total_rows_with_result_detail = sum(results.counts.by_detail.values()) + for detail, count in sorted(results.counts.by_detail.items()): + print(f"- **{detail}**: {count}/{total_rows_with_result_detail}") + + # Generate Failed Tests section + print("# Failed Tests\n") + print( + "To reproduce, run the following command from the root of the ExecuTorch repository:" + ) + print("```") + print('pytest -c /dev/nul backends/test/suite/ -k ""') + print("```") + if results.failed_tests: + header = build_header(results.failed_tests) + + escaped_header = [escape_for_markdown(col) for col in header.keys()] + print("| " + " | ".join(escaped_header) + " |") + print("|" + "|".join(["---"] * len(escaped_header)) + "|") + for rec in results.failed_tests: + row = build_row(rec, header) + print("| " + " | ".join(row) + " |") + else: + print("No failed tests.\n") + + +def build_header(data) -> dict[str, int]: + """ + Find the union of all keys and return a dict of header keys and indices. Try to preserve + ordering as much as possible. + """ + + keys = max(data, key=len) + + header = {k: i for (i, k) in enumerate(keys)} + + for rec in data: + keys = set(rec.keys()) + for k in keys: + if k not in header: + header[k] = len(header) + + return header + + +def build_row(rec, header: dict[str, int]) -> list[str]: + row = [""] * len(header) + for k, v in rec.items(): + row[header[k]] = escape_for_markdown(str(v)) + return row + + +def main(): + parser = argparse.ArgumentParser( + description="Generate a Markdown representation of a test report." + ) + parser.add_argument("json_path", help="Path to the test report CSV file.") + parser.add_argument( + "--exit-code", type=int, default=0, help="Exit code from the test process." + ) + args = parser.parse_args() + generate_markdown(args.json_path, args.exit_code) + + +if __name__ == "__main__": + main() diff --git a/backends/test/suite/models/__init__.py b/backends/test/suite/models/__init__.py index ea44275a463..6ac1a72bde6 100644 --- a/backends/test/suite/models/__init__.py +++ b/backends/test/suite/models/__init__.py @@ -5,136 +5,3 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe - -import itertools -import os -import unittest -from typing import Any, Callable - -import torch -from executorch.backends.test.suite import get_test_flows -from executorch.backends.test.suite.context import get_active_test_context, TestContext -from executorch.backends.test.suite.flow import TestFlow -from executorch.backends.test.suite.reporting import log_test_summary -from executorch.backends.test.suite.runner import run_test - - -DTYPES: list[torch.dtype] = [ - torch.float16, - torch.float32, -] - - -def load_tests(loader, suite, pattern): - package_dir = os.path.dirname(__file__) - discovered_suite = loader.discover( - start_dir=package_dir, pattern=pattern or "test_*.py" - ) - suite.addTests(discovered_suite) - return suite - - -def _create_test( - cls, - test_func: Callable, - flow: TestFlow, - dtype: torch.dtype, - use_dynamic_shapes: bool, -): - dtype_name = str(dtype)[6:] # strip "torch." - test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}" - if use_dynamic_shapes: - test_name += "_dynamic_shape" - - def wrapped_test(self): - params = { - "dtype": dtype, - "use_dynamic_shapes": use_dynamic_shapes, - } - with TestContext(test_name, test_func.__name__, flow.name, params): - if flow.should_skip_test(test_name): - raise unittest.SkipTest( - f"Skipping test due to matching flow {flow.name} skip patterns" - ) - - test_func(self, flow, dtype, use_dynamic_shapes) - - wrapped_test._name = test_func.__name__ # type: ignore - wrapped_test._flow = flow # type: ignore - - setattr(cls, test_name, wrapped_test) - - -# Expand a test into variants for each registered flow. -def _expand_test(cls, test_name: str) -> None: - test_func = getattr(cls, test_name) - supports_dynamic_shapes = getattr(test_func, "supports_dynamic_shapes", True) - dynamic_shape_values = [True, False] if supports_dynamic_shapes else [False] - dtypes = getattr(test_func, "dtypes", DTYPES) - - for flow, dtype, use_dynamic_shapes in itertools.product( - get_test_flows().values(), dtypes, dynamic_shape_values - ): - _create_test(cls, test_func, flow, dtype, use_dynamic_shapes) - delattr(cls, test_name) - - -def model_test_cls(cls) -> Callable | None: - """Decorator for model tests. Handles generating test variants for each test flow and configuration.""" - for key in dir(cls): - if key.startswith("test_"): - _expand_test(cls, key) - return cls - - -def model_test_params( - supports_dynamic_shapes: bool = True, - dtypes: list[torch.dtype] | None = None, -) -> Callable: - """Optional parameter decorator for model tests. Specifies test pararameters. Only valid with a class decorated by model_test_cls.""" - - def inner_decorator(func: Callable) -> Callable: - func.supports_dynamic_shapes = supports_dynamic_shapes # type: ignore - - if dtypes is not None: - func.dtypes = dtypes # type: ignore - - return func - - return inner_decorator - - -def run_model_test( - model: torch.nn.Module, - inputs: tuple[Any], - flow: TestFlow, - dtype: torch.dtype, - dynamic_shapes: Any | None, -): - model = model.to(dtype) - context = get_active_test_context() - - # This should be set in the wrapped test. See _create_test above. - assert context is not None, "Missing test context." - - run_summary = run_test( - model, - inputs, - flow, - context.test_name, - context.test_base_name, - 0, # subtest_index - currently unused for model tests - context.params, - dynamic_shapes=dynamic_shapes, - ) - - log_test_summary(run_summary) - - if not run_summary.result.is_success(): - if run_summary.result.is_backend_failure(): - raise RuntimeError("Test failure.") from run_summary.error - else: - # Non-backend failure indicates a bad test. Mark as skipped. - raise unittest.SkipTest( - f"Test failed for reasons other than backend failure. Error: {run_summary.error}" - ) diff --git a/backends/test/suite/models/test_torchaudio.py b/backends/test/suite/models/test_torchaudio.py index 69f6de4684f..2287b226c37 100644 --- a/backends/test/suite/models/test_torchaudio.py +++ b/backends/test/suite/models/test_torchaudio.py @@ -9,15 +9,11 @@ import unittest from typing import Tuple +import pytest import torch import torchaudio -from executorch.backends.test.suite.flow import TestFlow -from executorch.backends.test.suite.models import ( - model_test_cls, - model_test_params, - run_model_test, -) +from executorch.backends.test.suite import dtype_to_str from torch.export import Dim # @@ -47,64 +43,68 @@ def forward( return x.transpose(0, 1) -@model_test_cls -class TorchAudio(unittest.TestCase): - @model_test_params(dtypes=[torch.float32], supports_dynamic_shapes=False) - def test_conformer( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - inner_model = torchaudio.models.Conformer( - input_dim=80, - num_heads=4, - ffn_dim=128, - num_layers=4, - depthwise_conv_kernel_size=31, - ) - model = PatchedConformer(inner_model) - lengths = torch.randint(1, 400, (10,)) +@pytest.mark.parametrize("dtype", [torch.float32], ids=dtype_to_str) +@pytest.mark.parametrize("use_dynamic_shapes", [False], ids=["static_shapes"]) +def test_conformer(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + inner_model = torchaudio.models.Conformer( + input_dim=80, + num_heads=4, + ffn_dim=128, + num_layers=4, + depthwise_conv_kernel_size=31, + ) + model = PatchedConformer(inner_model).eval().to(dtype) + lengths = torch.randint(1, 400, (10,)) - encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask( - lengths - ) - inputs = ( - torch.rand(10, int(lengths.max()), 80), - encoder_padding_mask, - ) + encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask(lengths) + inputs = ( + torch.rand(10, int(lengths.max()), 80), + encoder_padding_mask, + ) + + test_runner.lower_and_run_model(model, inputs) - run_model_test(model, inputs, flow, dtype, None) - - @model_test_params(dtypes=[torch.float32]) - def test_wav2letter( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchaudio.models.Wav2Letter() - inputs = (torch.randn(1, 1, 1024, dtype=dtype),) - dynamic_shapes = ( - { - "x": { - 2: Dim("d", min=900, max=1024), - } + +@pytest.mark.parametrize("dtype", [torch.float32], ids=dtype_to_str) +@pytest.mark.parametrize( + "use_dynamic_shapes", [False, True], ids=["static_shapes", "dynamic_shapes"] +) +def test_wav2letter(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchaudio.models.Wav2Letter().to(dtype) + inputs = (torch.randn(1, 1, 1024, dtype=dtype),) + dynamic_shapes = ( + { + "x": { + 2: Dim("d", min=900, max=1024), } - if use_dynamic_shapes - else None - ) - run_model_test(model, inputs, flow, dtype, dynamic_shapes) - - @unittest.skip("This model times out on all backends.") - def test_wavernn( - self, - flow: TestFlow, - dtype: torch.dtype, - use_dynamic_shapes: bool, - ): - model = torchaudio.models.WaveRNN( + } + if use_dynamic_shapes + else None + ) + + test_runner.lower_and_run_model(model, inputs, dynamic_shapes=dynamic_shapes) + + +@pytest.mark.parametrize("dtype", [torch.float32], ids=dtype_to_str) +@pytest.mark.parametrize("use_dynamic_shapes", [False], ids=["static_shapes"]) +@unittest.skip("This model times out on all backends.") +def test_wavernn( + test_runner, + dtype: torch.dtype, + use_dynamic_shapes: bool, +): + model = ( + torchaudio.models.WaveRNN( upsample_scales=[5, 5, 8], n_classes=512, hop_length=200 - ).eval() - - # See https://docs.pytorch.org/audio/stable/generated/torchaudio.models.WaveRNN.html#forward - inputs = ( - torch.randn(1, 1, (64 - 5 + 1) * 200), # waveform - torch.randn(1, 1, 128, 64), # specgram ) + .eval() + .to(dtype) + ) + + # See https://docs.pytorch.org/audio/stable/generated/torchaudio.models.WaveRNN.html#forward + inputs = ( + torch.randn(1, 1, (64 - 5 + 1) * 200).to(dtype), # waveform + torch.randn(1, 1, 128, 64).to(dtype), # specgram + ) - run_model_test(model, inputs, flow, dtype, None) + test_runner.lower_and_run_model(model, inputs) diff --git a/backends/test/suite/models/test_torchvision.py b/backends/test/suite/models/test_torchvision.py index e69de80a871..58cf6a990d4 100644 --- a/backends/test/suite/models/test_torchvision.py +++ b/backends/test/suite/models/test_torchvision.py @@ -6,17 +6,12 @@ # pyre-unsafe -import unittest +import pytest import torch import torchvision +from executorch.backends.test.suite import dtype_to_str -from executorch.backends.test.suite.flow import TestFlow -from executorch.backends.test.suite.models import ( - model_test_cls, - model_test_params, - run_model_test, -) from torch.export import Dim # @@ -25,148 +20,175 @@ # multiple size variants, one small or medium variant is used. # +PARAMETERIZE_DTYPE = pytest.mark.parametrize("dtype", [torch.float32], ids=dtype_to_str) +PARAMETERIZE_DYNAMIC_SHAPES = pytest.mark.parametrize( + "use_dynamic_shapes", [False, True], ids=["static_shapes", "dynamic_shapes"] +) +PARAMETERIZE_STATIC_ONLY = pytest.mark.parametrize( + "use_dynamic_shapes", [False], ids=["static_shapes"] +) + + +def _test_cv_model( + model: torch.nn.Module, + test_runner, + dtype: torch.dtype, + use_dynamic_shapes: bool, +): + model = model.eval().to(dtype) + + # Test a CV model that follows the standard conventions. + inputs = (torch.randn(1, 3, 224, 224, dtype=dtype),) -@model_test_cls -class TorchVision(unittest.TestCase): - def _test_cv_model( - self, - model: torch.nn.Module, - flow: TestFlow, - dtype: torch.dtype, - use_dynamic_shapes: bool, - ): - # Test a CV model that follows the standard conventions. - inputs = (torch.randn(1, 3, 224, 224, dtype=dtype),) - - dynamic_shapes = ( - ( - { - 2: Dim("height", min=1, max=16) * 16, - 3: Dim("width", min=1, max=16) * 16, - }, - ) - if use_dynamic_shapes - else None + dynamic_shapes = ( + ( + { + 2: Dim("height", min=1, max=16) * 16, + 3: Dim("width", min=1, max=16) * 16, + }, ) + if use_dynamic_shapes + else None + ) + + test_runner.lower_and_run_model(model, inputs, dynamic_shapes=dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_alexnet(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.alexnet() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_convnext_small(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.convnext_small() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_densenet161(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.densenet161() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_efficientnet_b4(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.efficientnet_b4() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_efficientnet_v2_s(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.efficientnet_v2_s() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_googlenet(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.googlenet() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_inception_v3(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.inception_v3() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_STATIC_ONLY +def test_maxvit_t(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.maxvit_t() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_mnasnet1_0(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.mnasnet1_0() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_mobilenet_v2(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.mobilenet_v2() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_mobilenet_v3_small(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.mobilenet_v3_small() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_regnet_y_1_6gf(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.regnet_y_1_6gf() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_resnet50(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.resnet50() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_resnext50_32x4d(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.resnext50_32x4d() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_shufflenet_v2_x1_0(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.shufflenet_v2_x1_0() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_squeezenet1_1(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.squeezenet1_1() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_swin_v2_t(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.swin_v2_t() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_vgg11(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.vgg11() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_STATIC_ONLY +def test_vit_b_16(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.vit_b_16() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + - run_model_test(model, inputs, flow, dtype, dynamic_shapes) - - def test_alexnet( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.alexnet() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_convnext_small( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.convnext_small() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_densenet161( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.densenet161() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_efficientnet_b4( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.efficientnet_b4() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_efficientnet_v2_s( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.efficientnet_v2_s() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_googlenet( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.googlenet() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_inception_v3( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.inception_v3() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - @model_test_params(supports_dynamic_shapes=False) - def test_maxvit_t( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.maxvit_t() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_mnasnet1_0( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.mnasnet1_0() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_mobilenet_v2( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.mobilenet_v2() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_mobilenet_v3_small( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.mobilenet_v3_small() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_regnet_y_1_6gf( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.regnet_y_1_6gf() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_resnet50( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.resnet50() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_resnext50_32x4d( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.resnext50_32x4d() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_shufflenet_v2_x1_0( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.shufflenet_v2_x1_0() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_squeezenet1_1( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.squeezenet1_1() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_swin_v2_t( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.swin_v2_t() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_vgg11(self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool): - model = torchvision.models.vgg11() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - @model_test_params(supports_dynamic_shapes=False) - def test_vit_b_16( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.vit_b_16() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_wide_resnet50_2( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.wide_resnet50_2() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_wide_resnet50_2(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.wide_resnet50_2() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) diff --git a/backends/test/suite/operators/__init__.py b/backends/test/suite/operators/__init__.py index 9c550b3a49c..fa5ec2566d4 100644 --- a/backends/test/suite/operators/__init__.py +++ b/backends/test/suite/operators/__init__.py @@ -6,19 +6,14 @@ # pyre-unsafe -import copy import os +import sys import unittest from enum import Enum -from typing import Callable +import pytest import torch -from executorch.backends.test.suite import get_test_flows -from executorch.backends.test.suite.context import get_active_test_context, TestContext -from executorch.backends.test.suite.flow import TestFlow -from executorch.backends.test.suite.reporting import log_test_summary -from executorch.backends.test.suite.runner import run_test def load_tests(loader, suite, pattern): @@ -66,112 +61,46 @@ def dtype_test(func): return func -# Class annotation for operator tests. This triggers the test framework to register -# the tests. -def operator_test(cls): - _create_tests(cls) - return cls - - -# Generate test cases for each backend flow. -def _create_tests(cls): - for key in dir(cls): - if key.startswith("test_"): - _expand_test(cls, key) +class OperatorTest(unittest.TestCase): + pass -# Expand a test into variants for each registered flow. -def _expand_test(cls, test_name: str): - test_func = getattr(cls, test_name) - for flow in get_test_flows().values(): - _create_test_for_backend(cls, test_func, flow) - delattr(cls, test_name) +class TestCaseShim: + def __init__(self, test_runner): + self._test_runner = test_runner + def _test_op(self, model, args, flow, generate_random_test_inputs=True): + self._test_runner.lower_and_run_model(model, args) -def _make_wrapped_test( - test_func: Callable, - test_name: str, - test_base_name: str, - flow: TestFlow, - params: dict | None = None, -): - def wrapped_test(self): - with TestContext(test_name, test_base_name, flow.name, params): - if flow.should_skip_test(test_name): - raise unittest.SkipTest( - f"Skipping test due to matching flow {flow.name} skip patterns" - ) - test_kwargs = copy.copy(params) or {} - test_kwargs["flow"] = flow +def wrap_test(original_func, test_type): + if test_type == TestType.STANDARD: - test_func(self, **test_kwargs) + def wrapped_func(test_runner): + shim = TestCaseShim(test_runner) + original_func(shim, test_runner._flow) - wrapped_test._name = test_name - wrapped_test._flow = flow + return wrapped_func + elif test_type == TestType.DTYPE: - return wrapped_test + @pytest.mark.parametrize("dtype", [torch.float32], ids=lambda s: str(s)[6:]) + def wrapped_func(test_runner, dtype): + shim = TestCaseShim(test_runner) + original_func(shim, test_runner._flow, dtype) + return wrapped_func + else: + raise ValueError() -def _create_test_for_backend( - cls, - test_func: Callable, - flow: TestFlow, -): - test_type = getattr(test_func, "test_type", TestType.STANDARD) - if test_type == TestType.STANDARD: - test_name = f"{test_func.__name__}_{flow.name}" - wrapped_test = _make_wrapped_test( - test_func, test_name, test_func.__name__, flow - ) - setattr(cls, test_name, wrapped_test) - elif test_type == TestType.DTYPE: - for dtype in DTYPES: - dtype_name = str(dtype)[6:] # strip "torch." - test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}" - wrapped_test = _make_wrapped_test( - test_func, - test_name, - test_func.__name__, - flow, - {"dtype": dtype}, - ) - setattr(cls, test_name, wrapped_test) - else: - raise NotImplementedError(f"Unknown test type {test_type}.") +def operator_test(cls): + parent_module = sys.modules[cls.__module__] + for func_name in dir(cls): + if func_name.startswith("test"): + original_func = getattr(cls, func_name) + test_type = getattr(original_func, "test_type", TestType.STANDARD) + wrapped_func = wrap_test(original_func, test_type) + setattr(parent_module, func_name, wrapped_func) -class OperatorTest(unittest.TestCase): - def _test_op( - self, model, inputs, flow: TestFlow, generate_random_test_inputs: bool = True - ): - context = get_active_test_context() - - # This should be set in the wrapped test. See _make_wrapped_test above. - assert context is not None, "Missing test context." - - run_summary = run_test( - model, - inputs, - flow, - context.test_name, - context.test_base_name, - context.subtest_index, - context.params, - generate_random_test_inputs=generate_random_test_inputs, - ) - - log_test_summary(run_summary) - - # This is reset when a new test is started - it creates the context per-test. - context.subtest_index = context.subtest_index + 1 - - if not run_summary.result.is_success(): - if run_summary.result.is_backend_failure(): - raise RuntimeError("Test failure.") from run_summary.error - else: - # Non-backend failure indicates a bad test. Mark as skipped. - raise unittest.SkipTest( - f"Test failed for reasons other than backend failure. Error: {run_summary.error}" - ) + return None diff --git a/backends/test/suite/operators/test_add.py b/backends/test/suite/operators/test_add.py index 6b21c3bf985..850e6f5132c 100644 --- a/backends/test/suite/operators/test_add.py +++ b/backends/test/suite/operators/test_add.py @@ -7,14 +7,8 @@ # pyre-unsafe +import pytest import torch -from executorch.backends.test.suite.flow import TestFlow - -from executorch.backends.test.suite.operators import ( - dtype_test, - operator_test, - OperatorTest, -) class Model(torch.nn.Module): @@ -31,55 +25,52 @@ def forward(self, x, y): return torch.add(x, y, alpha=self.alpha) -@operator_test -class Add(OperatorTest): - @dtype_test - def test_add_dtype(self, flow: TestFlow, dtype) -> None: - self._test_op( - Model(), - ( - (torch.rand(2, 10) * 100).to(dtype), - (torch.rand(2, 10) * 100).to(dtype), - ), - flow, - ) - - def test_add_f32_bcast_first(self, flow: TestFlow) -> None: - self._test_op( - Model(), - ( - torch.randn(5), - torch.randn(1, 5, 1, 5), - ), - flow, - ) - - def test_add_f32_bcast_second(self, flow: TestFlow) -> None: - self._test_op( - Model(), - ( - torch.randn(4, 4, 2, 7), - torch.randn(2, 7), - ), - flow, - ) - - def test_add_f32_bcast_unary(self, flow: TestFlow) -> None: - self._test_op( - Model(), - ( - torch.randn(5), - torch.randn(1, 1, 5), - ), - flow, - ) - - def test_add_f32_alpha(self, flow: TestFlow) -> None: - self._test_op( - ModelAlpha(alpha=2), - ( - torch.randn(1, 25), - torch.randn(1, 25), - ), - flow, - ) +@pytest.mark.parametrize("dtype", [torch.float32], ids=lambda s: str(s)[6:]) +def test_add_dtype(test_runner, dtype) -> None: + test_runner.lower_and_run_model( + Model(), + ( + (torch.rand(2, 10) * 100).to(dtype), + (torch.rand(2, 10) * 100).to(dtype), + ), + ) + + +def test_add_f32_bcast_first(test_runner) -> None: + test_runner.lower_and_run_model( + Model(), + ( + torch.randn(5), + torch.randn(1, 5, 1, 5), + ), + ) + + +def test_add_f32_bcast_second(test_runner) -> None: + test_runner.lower_and_run_model( + Model(), + ( + torch.randn(4, 4, 2, 7), + torch.randn(2, 7), + ), + ) + + +def test_add_f32_bcast_unary(test_runner) -> None: + test_runner.lower_and_run_model( + Model(), + ( + torch.randn(5), + torch.randn(1, 1, 5), + ), + ) + + +def test_add_f32_alpha(test_runner) -> None: + test_runner.lower_and_run_model( + ModelAlpha(alpha=2), + ( + torch.randn(1, 25), + torch.randn(1, 25), + ), + ) diff --git a/backends/test/suite/operators/test_sub.py b/backends/test/suite/operators/test_sub.py index be7b871fdad..2243eb6ee71 100644 --- a/backends/test/suite/operators/test_sub.py +++ b/backends/test/suite/operators/test_sub.py @@ -6,7 +6,6 @@ # pyre-unsafe - import torch from executorch.backends.test.suite.flow import TestFlow diff --git a/pyproject.toml b/pyproject.toml index 00cae6de2e7..fbed875a824 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ dependencies=[ "pytest", "pytest-xdist", "pytest-rerunfailures==15.1", + "pytest-json-report", "pyyaml", "ruamel.yaml", "sympy",