[ONNX] STFT Support (#92087)
This PR addresses issue [#81075](#81075), making `torch.stft` compatible with ONNX Opset 17's `STFT` operator.

The conversion covers _most_ of `torch.stft`'s functionality:

- Batched or unbatched inputs
- Normalization
- Pre-computed windows
- Rectangular windows
- One-sided returns
- Window centering (implicitly supported)

**Complex types** are currently _not_ supported, due to the lack of complex-type conversion between PyTorch and ONNX (#86746).

Regardless, this is easy to work around by setting `return_complex=False` when calling `torch.stft`, as sketched below.
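For illustration, here is a minimal export sketch (hypothetical module and parameter choices, not part of this PR; assumes a PyTorch build with opset 17 support):

```python
import torch

class STFTModel(torch.nn.Module):
    def forward(self, signal):
        # return_complex=False yields a real tensor with a trailing
        # (real, imaginary) dimension, sidestepping the missing complex support.
        return torch.stft(
            signal,
            n_fft=64,
            window=torch.hann_window(64),
            return_complex=False,
        )

model = STFTModel().eval()
signal = torch.randn(1, 256)  # batched 1-D signal
torch.onnx.export(model, (signal,), "stft.onnx", opset_version=17)
```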

Note that there is already a draft PR addressing this (#83944), but it is currently closed and only partially implements the conversion (most of `torch.stft`'s functionality is missing, as are unit tests).
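On the testing side, the new suite skips exactly these complex-output samples via matcher functions (see `SKIP_SUBTESTS` in the diff below). A self-contained sketch of that pattern, with simplified and hypothetical names:

```python
import dataclasses
from typing import Any, Callable, Optional

@dataclasses.dataclass
class SkipRule:
    op_name: str
    reason: str
    # Returns True for sample inputs that the ONNX exporter cannot handle.
    matcher: Optional[Callable[[Any], bool]] = None

SKIP_RULES = (
    SkipRule(
        "stft",
        reason="ONNX STFT does not support complex results",
        matcher=lambda sample: sample.kwargs.get("return_complex") is True,
    ),
)

def skip_reason(op_name: str, sample: Any) -> Optional[str]:
    """Return a skip reason if `sample` matches a rule, else None."""
    for rule in SKIP_RULES:
        if rule.op_name == op_name and rule.matcher is not None and rule.matcher(sample):
            return rule.reason
    return None
```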
Pull Request resolved: #92087
Approved by: https://github.com/justinchuby
urinieto authored and pytorchmergebot committed Mar 10, 2023
1 parent 69d3fa2 commit 5f89d14
Showing 3 changed files with 491 additions and 88 deletions.
207 changes: 125 additions & 82 deletions test/onnx/test_op_consistency.py
@@ -17,25 +17,20 @@
Note:
When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS and
-ALLOWLIST_OP lists. See "Modify this section"
+TESTED_OPS lists. See "Modify this section"
"""

from __future__ import annotations

import copy
import dataclasses
import unittest
-from typing import (
-    AbstractSet,
-    Callable,
-    Collection,
-    Iterable,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
+import warnings
+from typing import Any, Callable, Collection, Iterable, Optional, Sequence, Tuple, Union

import onnx_test_common
import parameterized

import torch
from torch.onnx import _constants
@@ -80,7 +75,7 @@
torch.complex128,
)

-SUPPORTED_DTYPES = (
+TESTED_DTYPES = (
# Boolean
torch.bool,
# Integers
@@ -111,6 +106,7 @@ class DecorateMeta:
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
dtypes: Optional[Collection[torch.dtype]]
reason: str
+    matcher: Optional[Callable[[Any], Any]] = None

def contains_opset(self, opset: int) -> bool:
if self.opsets is None:
@@ -155,6 +151,7 @@ def dont_care(
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
+    matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo that we don't care about.
@@ -166,6 +163,8 @@ def dont_care(
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
+        matcher: A function that matches the test sample input. It is used only when
+            dont_care is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
@@ -174,6 +173,7 @@ def dont_care(
opsets=opsets,
dtypes=dtypes,
reason=reason,
+        matcher=matcher,
)


@@ -184,6 +184,7 @@ def fixme(
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
+    matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo. It should be eventually fixed.
@@ -193,6 +194,8 @@ def fixme(
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
+        matcher: A function that matches the test sample input. It is used only when
+            fixme is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
@@ -201,6 +204,7 @@ def fixme(
opsets=opsets,
dtypes=dtypes,
reason=reason,
+        matcher=matcher,
)


@@ -293,15 +297,16 @@ def reason_flaky() -> str:
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1. Add "ceil" to ALLOWLIST_OP then run pytest.
# 1. Add "ceil" to TESTED_OPS then run pytest.
# 2. If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS.

# TODO: Directly modify DecorateInfo in each OpInfo in op_db when all ops are enabled.
# Ops to be tested for numerical consistency between onnx and pytorch
-ALLOWLIST_OP: AbstractSet[str] = frozenset(
+TESTED_OPS: frozenset[str] = frozenset(
[
"ceil",
"sqrt",
"stft",
"t",
]
)
@@ -323,14 +328,26 @@ def reason_flaky() -> str:
),
fixme("ceil", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
dont_care("sqrt", dtypes=BOOL_TYPES, reason=reason_onnx_does_not_support("Sqrt")),
dont_care("stft", opsets=[opsets_before(17)], reason=reason_onnx_does_not_support("STFT")),
)
# fmt: on

+SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
+    dont_care(
+        "stft",
+        reason="ONNX STFT does not support complex results",
+        matcher=lambda sample: sample.kwargs.get("return_complex") is True,
+    ),
+)
+
# END OF SECTION TO MODIFY #####################################################


OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
+OP_WITH_SKIPPED_SUBTESTS = frozenset(meta.op_name for meta in SKIP_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
+# Assert all ops in TESTED_OPS are in the OPS_DB
+assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"


class SingleOpModel(torch.nn.Module):
@@ -345,80 +362,106 @@ def forward(self, *args):
return self.operator(*args, **self.kwargs)


+def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
+    """Returns a reason if a test sample should be skipped."""
+    if op_name not in OP_WITH_SKIPPED_SUBTESTS:
+        return None
+    for decorator_meta in SKIP_SUBTESTS:
+        # Linear search on SKIP_SUBTESTS. That's fine because the list is small.
+        if decorator_meta.op_name == op_name:
+            assert decorator_meta.matcher is not None, "Matcher must be defined"
+            if decorator_meta.matcher(sample):
+                return decorator_meta.reason
+    return None
+
+
+def _get_test_class_name(cls, num, params_dict) -> str:
+    del cls  # unused
+    del num  # unused
+    return params_dict["name"]
+
+
+@parameterized.parameterized_class(
+    [
+        {
+            "name": f"TestOnnxModelOutputConsistency_opset{opset}",
+            "opset_version": opset,
+        }
+        for opset in TESTED_OPSETS
+    ],
+    class_name_func=_get_test_class_name,
+)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
"""Test output consistency between exported ONNX models and PyTorch eager mode.
This is a parameterized test suite.
"""

-    @classmethod
-    def create_test_base(cls, opset: int):
-        """Returns the base test method for the given opset."""
-
-        def _output_match_base(self, device: str, dtype: torch.dtype, op):
-            """Base test method for testing each opset, used by instantiate_device_type_tests."""
-            # device is provided by instantiate_device_type_tests, but we only want to run in cpu.
-            assert device == "cpu"
-
-            samples = op.sample_inputs(
-                device,
-                dtype,
-                requires_grad=False,
-            )
-
-            for (i, cpu_sample) in enumerate(samples):
-                # Provide the repr to subtest because tensors are not serializable in parallel test runs
-                with self.subTest(
-                    opset=opset,
-                    sample_num=i,
-                    input=repr(cpu_sample.input),
-                    args=repr(cpu_sample.args),
-                    kwargs=repr(cpu_sample.kwargs),
-                ):
-                    model = SingleOpModel(op, cpu_sample.kwargs)
-                    model.eval()
-
-                    # Run the test
-                    inputs = (cpu_sample.input, *cpu_sample.args)
-
-                    self.run_test(model, inputs)
-
-        test_name = f"test_output_match_opset_{opset}"
-        _output_match_base.__name__ = test_name
-        return _output_match_base
-
-    @classmethod
-    def parameterize_opsets(cls, opsets: Sequence[int]):
-        """Parametrizes the TestOnnxModelOutputConsistency class with the given opsets."""
-        for opset in opsets:
-            # Generate a test method for each opset
-            base_method = cls.create_test_base(opset)
-            # Important to rename the test method so that DecorateInfo can find it
-            test_name = base_method.__name__
-
-            # Update the ops to skip in the OpInfo database
-            add_decorate_info(
-                OPS_DB,
-                cls.__name__,
-                test_name,
-                opset=opset,
-                skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
-            )
-
-            # Create parameterized tests for each op
-            filtered_ops = [op for op in OPS_DB if op.name in ALLOWLIST_OP]
-            decorated = common_device_type.ops(
-                filtered_ops,
-                allowed_dtypes=SUPPORTED_DTYPES,
-            )(base_method)
-
-            setattr(cls, test_name, decorated)
-
-
-TestOnnxModelOutputConsistency.parameterize_opsets(TESTED_OPSETS)
-common_device_type.instantiate_device_type_tests(
-    TestOnnxModelOutputConsistency, globals(), only_for="cpu"
-)
+    opset_version = -1
+
+    @common_device_type.ops(
+        [op for op in OPS_DB if op.name in TESTED_OPS],
+        allowed_dtypes=TESTED_DTYPES,
+    )
+    def test_output_match(self, device: str, dtype: torch.dtype, op):
+        """Test the ONNX exporter."""
+        # device is provided by instantiate_device_type_tests, but we only want to run in cpu.
+        assert device == "cpu"
+
+        samples = op.sample_inputs(
+            device,
+            dtype,
+            requires_grad=False,
+        )
+
+        for i, cpu_sample in enumerate(samples):
+            inputs = (cpu_sample.input, *cpu_sample.args)
+            # Provide the repr to subtest because tensors are not serializable in parallel test runs
+
+            with self.subTest(
+                opset=self.opset_version,
+                sample_num=i,
+                inputs=repr(inputs),
+                kwargs=repr(cpu_sample.kwargs),
+            ):
+                skip_reason = _should_skip_test_sample(op.name, cpu_sample)
+                if skip_reason is not None:
+                    # Cannot use self.skip because pytest would skip the entire test
+                    warnings.warn(f"skipped sample {i}. Reason: {skip_reason}")
+                    continue
+
+                model = SingleOpModel(op, cpu_sample.kwargs)
+                model.eval()
+
+                if dtype == torch.float32:
+                    # Relax atol and rtol for float32 based on empirical results
+                    # The current most relaxed values are for aten::stft
+                    rtol = 1e-5
+                    atol = 2e-5
+                elif dtype == torch.float64:
+                    # The current most relaxed values are for aten::stft
+                    rtol = 1e-5
+                    atol = 2e-5
+                else:
+                    rtol = None
+                    atol = None
+                # Run the test
+                self.run_test(model, inputs, rtol=rtol, atol=atol)
+
+
+for opset in TESTED_OPSETS:
+    # The name needs to match the parameterized_class name.
+    test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
+    add_decorate_info(
+        OPS_DB,
+        test_class_name,
+        "test_output_match",
+        opset=opset,
+        skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
+    )
+    common_device_type.instantiate_device_type_tests(
+        globals()[test_class_name], globals(), only_for="cpu"
+    )


if __name__ == "__main__":