[ONNX] STFT Support #92087

Closed
wants to merge 12 commits
207 changes: 125 additions & 82 deletions in test/onnx/test_op_consistency.py
@@ -17,25 +17,20 @@
Note:

When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS and
ALLOWLIST_OP lists. See "Modify this section"
TESTED_OPS lists. See "Modify this section"

"""

from __future__ import annotations

import copy
import dataclasses
import unittest
from typing import (
AbstractSet,
Callable,
Collection,
Iterable,
Optional,
Sequence,
Tuple,
Union,
)
import warnings
from typing import Any, Callable, Collection, Iterable, Optional, Sequence, Tuple, Union

import onnx_test_common
import parameterized

import torch
from torch.onnx import _constants
@@ -80,7 +75,7 @@
torch.complex128,
)

SUPPORTED_DTYPES = (
TESTED_DTYPES = (
# Boolean
torch.bool,
# Integers
@@ -111,6 +106,7 @@ class DecorateMeta:
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
dtypes: Optional[Collection[torch.dtype]]
reason: str
matcher: Optional[Callable[[Any], Any]] = None

def contains_opset(self, opset: int) -> bool:
if self.opsets is None:
@@ -155,6 +151,7 @@ def dont_care(
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo that we don't care about.

@@ -166,6 +163,8 @@ def dont_care(
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
dont_care is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
@@ -174,6 +173,7 @@
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)


@@ -184,6 +184,7 @@ def fixme(
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo. It should be eventually fixed.

@@ -193,6 +194,8 @@ def fixme(
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
fixme is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
@@ -201,6 +204,7 @@
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)


@@ -293,15 +297,16 @@ def reason_flaky() -> str:
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1. Add "ceil" to ALLOWLIST_OP then run pytest.
# 1. Add "ceil" to TESTED_OPS then run pytest.
# 2. If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS.
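#
# Illustration (not new to this change, drawn from the entries further down): a
# step-2 entry for a dtype-specific failure takes the form
#     fixme("ceil", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
# while an opset-specific gap uses opsets instead, as the "stft" entry does with
# opsets=[opsets_before(17)].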

# TODO: Directly modify DecorateInfo in each OpInfo in ob_db when all ops are enabled.
# Ops to be tested for numerical consistency between onnx and pytorch
ALLOWLIST_OP: AbstractSet[str] = frozenset(
TESTED_OPS: frozenset[str] = frozenset(
[
"ceil",
"sqrt",
"stft",
"t",
]
)
@@ -323,14 +328,26 @@ def reason_flaky() -> str:
),
fixme("ceil", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
dont_care("sqrt", dtypes=BOOL_TYPES, reason=reason_onnx_does_not_support("Sqrt")),
dont_care("stft", opsets=[opsets_before(17)], reason=reason_onnx_does_not_support("STFT")),
)
# fmt: on

SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
dont_care(
"stft",
reason="ONNX STFT does not support complex results",
matcher=lambda sample: sample.kwargs.get("return_complex") is True,
),
)
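# Illustration only (hypothetical entry, not part of this diff): an additional
# SKIP_SUBTESTS item follows the same shape, pairing an op name with a matcher
# over the OpInfo sample, e.g.
#     dont_care(
#         "stft",
#         reason="example: skip samples that do not pass an explicit window",
#         matcher=lambda sample: sample.kwargs.get("window") is None,
#     ),
# Each matcher is evaluated per sample by _should_skip_test_sample below.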

# END OF SECTION TO MODIFY #####################################################


OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
OP_WITH_SKIPPED_SUBTESTS = frozenset(meta.op_name for meta in SKIP_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all ops in TESTED_OPS are in the OPS_DB
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"


class SingleOpModel(torch.nn.Module):
@@ -345,80 +362,106 @@ def forward(self, *args):
return self.operator(*args, **self.kwargs)


def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
"""Returns a reason if a test sample should be skipped."""
if op_name not in OP_WITH_SKIPPED_SUBTESTS:
return None
for decorator_meta in SKIP_SUBTESTS:
# Linear search on SKIP_SUBTESTS. That's fine because the list is small.
if decorator_meta.op_name == op_name:
assert decorator_meta.matcher is not None, "Matcher must be defined"
if decorator_meta.matcher(sample):
return decorator_meta.reason
return None


def _get_test_class_name(cls, num, params_dict) -> str:
del cls # unused
del num # unused
return params_dict["name"]


@parameterized.parameterized_class(
[
{
"name": f"TestOnnxModelOutputConsistency_opset{opset}",
"opset_version": opset,
}
for opset in TESTED_OPSETS
],
class_name_func=_get_test_class_name,
)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
"""Test output consistency between exported ONNX models and PyTorch eager mode.

This is a parameterized test suite.
"""

@classmethod
def create_test_base(cls, opset: int):
"""Returns the base test method for the given opset."""

def _output_match_base(self, device: str, dtype: torch.dtype, op):
"""Base test method for testing each opset, used by instantiate_device_type_tests."""
# device is provided by instantiate_device_type_tests, but we only want to run in cpu.
assert device == "cpu"

samples = op.sample_inputs(
device,
dtype,
requires_grad=False,
)

for (i, cpu_sample) in enumerate(samples):
# Provide the repr to subtest because tensors are not serializable in parallel test runs
with self.subTest(
opset=opset,
sample_num=i,
input=repr(cpu_sample.input),
args=repr(cpu_sample.args),
kwargs=repr(cpu_sample.kwargs),
):
model = SingleOpModel(op, cpu_sample.kwargs)
model.eval()

# Run the test
inputs = (cpu_sample.input, *cpu_sample.args)

self.run_test(model, inputs)

test_name = f"test_output_match_opset_{opset}"
_output_match_base.__name__ = test_name
return _output_match_base

@classmethod
def parameterize_opsets(cls, opsets: Sequence[int]):
"""Parametrizes the TestOnnxModelOutputConsistency class with the given opsets."""
for opset in opsets:
# Generate a test method for each opset
base_method = cls.create_test_base(opset)
# Important to rename the test method so that DecorateInfo can find it
test_name = base_method.__name__

# Update the ops to skip in the OpInfo database
add_decorate_info(
OPS_DB,
cls.__name__,
test_name,
opset=opset,
skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
)

# Create parameterized tests for each op
filtered_ops = [op for op in OPS_DB if op.name in ALLOWLIST_OP]
decorated = common_device_type.ops(
filtered_ops,
allowed_dtypes=SUPPORTED_DTYPES,
)(base_method)

setattr(cls, test_name, decorated)


TestOnnxModelOutputConsistency.parameterize_opsets(TESTED_OPSETS)
common_device_type.instantiate_device_type_tests(
TestOnnxModelOutputConsistency, globals(), only_for="cpu"
)
opset_version = -1

@common_device_type.ops(
[op for op in OPS_DB if op.name in TESTED_OPS],
allowed_dtypes=TESTED_DTYPES,
)
def test_output_match(self, device: str, dtype: torch.dtype, op):
"""Test the ONNX exporter."""
# device is provided by instantiate_device_type_tests, but we only want to run in cpu.
assert device == "cpu"

samples = op.sample_inputs(
device,
dtype,
requires_grad=False,
)

for i, cpu_sample in enumerate(samples):
inputs = (cpu_sample.input, *cpu_sample.args)
# Provide the repr to subtest because tensors are not serializable in parallel test runs

with self.subTest(
opset=self.opset_version,
sample_num=i,
inputs=repr(inputs),
kwargs=repr(cpu_sample.kwargs),
):
skip_reason = _should_skip_test_sample(op.name, cpu_sample)
if skip_reason is not None:
# Cannot use self.skip because pytest would skip the entire test
warnings.warn(f"skipped sample {i}. Reason: {skip_reason}")
continue

model = SingleOpModel(op, cpu_sample.kwargs)
model.eval()

if dtype == torch.float32:
# Relax atol and rtol for float32 based on empirical results
# The current most relaxed values are for aten::stft
rtol = 1e-5
atol = 2e-5
elif dtype == torch.float64:
# The current most relaxed values are for aten::stft
rtol = 1e-5
atol = 2e-5
else:
rtol = None
atol = None
# Run the test
self.run_test(model, inputs, rtol=rtol, atol=atol)


for opset in TESTED_OPSETS:
# The name needs to match the parameterized_class name.
test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
add_decorate_info(
OPS_DB,
test_class_name,
"test_output_match",
opset=opset,
skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
)
common_device_type.instantiate_device_type_tests(
globals()[test_class_name], globals(), only_for="cpu"
)


if __name__ == "__main__":