[ONNX] Run ONNX tests as part of standard run_test script
ghstack-source-id: 82ab204175e7a0dcd81405ad4d9913fef481a7c1
Pull Request resolved: #99215
BowenBao committed Apr 17, 2023
1 parent 0711bff commit fa5dd0f
Showing 4 changed files with 94 additions and 84 deletions.
41 changes: 2 additions & 39 deletions scripts/onnx/test.sh
@@ -41,46 +41,9 @@ args+=("--cov-report")
 args+=("xml:test/coverage.xml")
 args+=("--cov-append")
 
-args_parallel=()
-if [[ $PARALLEL == 1 ]]; then
-  args_parallel+=("-n")
-  args_parallel+=("auto")
-fi
-
-# onnxruntime only supports py3
-# "Python.h" not found in py2, needed by TorchScript custom op compilation.
-if [[ "${SHARD_NUMBER}" == "1" ]]; then
-  # These exclusions are for tests that take a long time / a lot of GPU
-  # memory to run; they should be passing (and you will test them if you
-  # run them locally)
-  pytest "${args[@]}" "${args_parallel[@]}" \
-    --ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" \
-    --ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
-    --ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py" \
-    --ignore "$top_dir/test/onnx/test_custom_ops.py" \
-    --ignore "$top_dir/test/onnx/test_utility_funs.py" \
-    --ignore "$top_dir/test/onnx/test_models.py" \
-    --ignore "$top_dir/test/onnx/test_models_quantized_onnxruntime.py" \
-    "${test_paths[@]}"
-
-  # Heavy memory usage tests that cannot run in parallel.
-  pytest "${args[@]}" \
-    "$top_dir/test/onnx/test_custom_ops.py" \
-    "$top_dir/test/onnx/test_utility_funs.py" \
-    "$top_dir/test/onnx/test_models_onnxruntime.py" "-k" "not TestModelsONNXRuntime"
-fi
-
-if [[ "${SHARD_NUMBER}" == "2" ]]; then
-  # Heavy memory usage tests that cannot run in parallel.
-  # TODO(#79802): Parameterize test_models.py
-  pytest "${args[@]}" \
-    "$top_dir/test/onnx/test_models.py" \
-    "$top_dir/test/onnx/test_models_quantized_onnxruntime.py" \
-    "$top_dir/test/onnx/test_models_onnxruntime.py" "-k" "TestModelsONNXRuntime"
-
-  pytest "${args[@]}" "${args_parallel[@]}" \
-    "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py"
+time python "${top_dir}/test/run_test.py" --onnx --shard "$SHARD_NUMBER" 2 --verbose
+
 if [[ "$SHARD_NUMBER" == "2" ]]; then
   # xdoctests on onnx
   xdoctest torch.onnx --style=google --options="+IGNORE_WHITESPACE"
 fi
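For reference, the `--shard "$SHARD_NUMBER" 2` arguments follow run_test.py's generic sharding convention: a 1-based shard index followed by the total shard count. A minimal sketch of one way such a split can work (illustrative only; the real scheduler may also balance shards using recorded test times):

from typing import List


def naive_shard(which_shard: int, num_shards: int, tests: List[str]) -> List[str]:
    """Round-robin split; shard indices are 1-based, as in `--shard 1 2`."""
    return [t for i, t in enumerate(tests) if i % num_shards == which_shard - 1]


print(naive_shard(2, 2, ["a", "b", "c", "d"]))  # -> ['b', 'd']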
17 changes: 8 additions & 9 deletions test/onnx/dynamo/test_exporter_api.py
@@ -1,7 +1,6 @@
 # Owner(s): ["module: onnx"]
 import io
 import logging
-import unittest
 
 import onnx
 
@@ -16,7 +15,7 @@
     ResolvedExportOptions,
 )
 
-from torch.testing._internal.common_utils import TemporaryFileName
+from torch.testing._internal import common_utils
 

class SampleModel(torch.nn.Module):
@@ -26,7 +25,7 @@ def forward(self, x):
         return (y, z)
 
 
-class TestExportOptionsAPI(unittest.TestCase):
+class TestExportOptionsAPI(common_utils.TestCase):
     def test_opset_version_default(self):
         options = ResolvedExportOptions(None)
         self.assertEquals(options.opset_version, _DEFAULT_OPSET_VERSION)
@@ -68,7 +67,7 @@ def test_logger_explicit(self):
         self.assertNotEquals(options.logger, logging.getLogger().getChild("torch.onnx"))
 
 
-class TestDynamoExportAPI(unittest.TestCase):
+class TestDynamoExportAPI(common_utils.TestCase):
     def test_default_export(self):
         output = dynamo_export(SampleModel(), torch.randn(1, 1, 2))
         self.assertIsInstance(output, ExportOutput)
@@ -89,7 +88,7 @@ def test_export_with_options(self):
         )
 
     def test_save_to_file_default_serializer(self):
-        with TemporaryFileName() as path:
+        with common_utils.TemporaryFileName() as path:
             dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(path)
             onnx.load(path)
 
@@ -107,7 +106,7 @@ def serialize(
             ) -> None:
                 destination.write(expected_buffer.encode())
 
-        with TemporaryFileName() as path:
+        with common_utils.TemporaryFileName() as path:
             dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
                 path, serializer=CustomSerializer()
             )
@@ -126,7 +125,7 @@ def serialize(
             ) -> None:
                 destination.write(expected_buffer.encode())
 
-        with TemporaryFileName() as path:
+        with common_utils.TemporaryFileName() as path:
             dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
                 path, serializer=CustomSerializer()
             )
@@ -144,12 +143,12 @@ def test_raise_on_invalid_save_argument_type(self):
             export_output.model_proto
 
 
-class TestProtobufExportOutputSerializerAPI(unittest.TestCase):
+class TestProtobufExportOutputSerializerAPI(common_utils.TestCase):
    def test_raise_on_invalid_argument_type(self):
        with self.assertRaises(roar.BeartypeException):
            serializer = ProtobufExportOutputSerializer()
            serializer.serialize(None, None)  # type: ignore[arg-type]
 
 
 if __name__ == "__main__":
-    unittest.main()
+    common_utils.run_tests()
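The mechanical change in this file, `unittest.TestCase` → `common_utils.TestCase` and `unittest.main()` → `common_utils.run_tests()`, is what lets test/run_test.py discover and drive these tests. A minimal sketch of the pattern (the test class here is a hypothetical example, not part of the commit):

from torch.testing._internal import common_utils


class MyOnnxApiTest(common_utils.TestCase):  # hypothetical example class
    def test_addition(self):
        self.assertEqual(1 + 1, 2)


if __name__ == "__main__":
    common_utils.run_tests()  # cooperates with test/run_test.py's test runner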
91 changes: 56 additions & 35 deletions test/onnx/test_utility_funs.py
@@ -3,6 +3,7 @@
 import copy
 import functools
 import io
+import re
 import warnings
 from typing import Callable
 
@@ -28,6 +29,33 @@
 from verify import verify
 
 
+def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
+    """Remove test environment prefix added to module.
+
+    Remove prefix to normalize scope names, since different test environments add
+    prefixes with slight differences.
+
+    Example:
+
+        >>> _remove_test_environment_prefix_from_scope_name(
+        ...     "test_utility_funs.M"
+        ... )
+        "M"
+        >>> _remove_test_environment_prefix_from_scope_name(
+        ...     "test_utility_funs.test_abc.<locals>.M"
+        ... )
+        "M"
+        >>> _remove_test_environment_prefix_from_scope_name(
+        ...     "__main__.M"
+        ... )
+        "M"
+    """
+    prefixes_to_remove = ["test_utility_funs", "__main__"]
+    for prefix in prefixes_to_remove:
+        scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
+    return scope_name
+
+
 class _BaseTestCase(pytorch_test_common.ExportTestCase):
     def _model_to_graph(
         self,
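Since re.sub replaces every match, nested scope segments are normalized too, which is what the rewritten expectations in test_node_scope below rely on. A standalone sketch of the behavior, with the helper inlined and scope names taken from this file's tests:

import re


def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
    # Same logic as the helper above, inlined for a runnable demo.
    for prefix in ("test_utility_funs", "__main__"):
        scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
    return scope_name


# Outer and nested prefixes are both stripped:
name = (
    "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
    "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.N::relu"
)
assert _remove_test_environment_prefix_from_scope_name(name) == "M::/N::relu"
assert _remove_test_environment_prefix_from_scope_name("__main__.M::") == "M::"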
@@ -1096,43 +1124,32 @@ def forward(self, x, y, z):
 
         model = M(3)
         expected_scope_names = {
-            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
-            "torch.nn.modules.activation.GELU::gelu1",
-            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
-            "torch.nn.modules.activation.GELU::gelu2",
-            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
-            "torch.nn.modules.normalization.LayerNorm::lns.0",
-            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
-            "torch.nn.modules.normalization.LayerNorm::lns.1",
-            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
-            "torch.nn.modules.normalization.LayerNorm::lns.2",
-            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
-            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.N::relu/"
-            "torch.nn.modules.activation.ReLU::relu",
-            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::",
+            "M::/torch.nn.modules.activation.GELU::gelu1",
+            "M::/torch.nn.modules.activation.GELU::gelu2",
+            "M::/torch.nn.modules.normalization.LayerNorm::lns.0",
+            "M::/torch.nn.modules.normalization.LayerNorm::lns.1",
+            "M::/torch.nn.modules.normalization.LayerNorm::lns.2",
+            "M::/N::relu/torch.nn.modules.activation.ReLU::relu",
+            "M::",
         }
 
         graph, _, _ = self._model_to_graph(
             model, (x, y, z), input_names=[], dynamic_axes={}
         )
         for node in graph.nodes():
-            self.assertIn(node.scopeName(), expected_scope_names)
-
-        expected_torch_script_scope_names = {
-            "test_utility_funs.M::/torch.nn.modules.activation.GELU::gelu1",
-            "test_utility_funs.M::/torch.nn.modules.activation.GELU::gelu2",
-            "test_utility_funs.M::/torch.nn.modules.normalization.LayerNorm::lns.0",
-            "test_utility_funs.M::/torch.nn.modules.normalization.LayerNorm::lns.1",
-            "test_utility_funs.M::/torch.nn.modules.normalization.LayerNorm::lns.2",
-            "test_utility_funs.M::/test_utility_funs.N::relu/torch.nn.modules.activation.ReLU::relu",
-            "test_utility_funs.M::",
-        }
+            self.assertIn(
+                _remove_test_environment_prefix_from_scope_name(node.scopeName()),
+                expected_scope_names,
+            )
 
         graph, _, _ = self._model_to_graph(
             torch.jit.script(model), (x, y, z), input_names=[], dynamic_axes={}
         )
         for node in graph.nodes():
-            self.assertIn(node.scopeName(), expected_torch_script_scope_names)
+            self.assertIn(
+                _remove_test_environment_prefix_from_scope_name(node.scopeName()),
+                expected_scope_names,
+            )
 
     def test_scope_of_constants_when_combined_by_cse_pass(self):
         layer_num = 3
@@ -1167,9 +1184,8 @@ def forward(self, x):
         # so we expect 3 constants with different scopes. The 3 constants are for the 3 layers.
         # If CSE in exporter is improved later, this test needs to be updated.
         # It should expect 1 constant, with same scope as root.
-        scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_constants_when_combined_by_cse_pass.<locals>"
-        expected_root_scope_name = f"{scope_prefix}.N::"
-        expected_layer_scope_name = f"{scope_prefix}.M::layers"
+        expected_root_scope_name = "N::"
+        expected_layer_scope_name = "M::layers"
         expected_constant_scope_name = [
             f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
             for i in range(layer_num)
@@ -1178,7 +1194,9 @@ def forward(self, x):
         constant_scope_names = []
         for node in graph.nodes():
             if node.kind() == "onnx::Constant":
-                constant_scope_names.append(node.scopeName())
+                constant_scope_names.append(
+                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
+                )
         self.assertEqual(constant_scope_names, expected_constant_scope_name)
 
     def test_scope_of_nodes_when_combined_by_cse_pass(self):
@@ -1217,9 +1235,8 @@ def forward(self, x):
         graph, _, _ = self._model_to_graph(
             N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
         )
-        scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_nodes_when_combined_by_cse_pass.<locals>"
-        expected_root_scope_name = f"{scope_prefix}.N::"
-        expected_layer_scope_name = f"{scope_prefix}.M::layers"
+        expected_root_scope_name = "N::"
+        expected_layer_scope_name = "M::layers"
         expected_add_scope_names = [
             f"{expected_root_scope_name}/{expected_layer_scope_name}.0"
         ]
@@ -1232,9 +1249,13 @@ def forward(self, x):
         mul_scope_names = []
         for node in graph.nodes():
             if node.kind() == "onnx::Add":
-                add_scope_names.append(node.scopeName())
+                add_scope_names.append(
+                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
+                )
             elif node.kind() == "onnx::Mul":
-                mul_scope_names.append(node.scopeName())
+                mul_scope_names.append(
+                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
+                )
         self.assertEqual(add_scope_names, expected_add_scope_names)
         self.assertEqual(mul_scope_names, expected_mul_scope_names)
 
29 changes: 28 additions & 1 deletion test/run_test.py
@@ -124,7 +124,7 @@ def skip_test_p(name: str) -> bool:
         "fx",  # executed by test_fx.py
         "jit",  # executed by test_jit.py
         "mobile",
-        "onnx",
+        "onnx_caffe2",
         "package",  # executed by test_package.py
         "quantization",  # executed by test_quantization.py
         "autograd",  # executed by test_autograd.py
@@ -151,6 +151,8 @@ def skip_test_p(name: str) -> bool:
         "distributed/test_c10d_spawn",
         "distributions/test_transforms",
         "distributions/test_utils",
+        "onnx/test_pytorch_onnx_onnxruntime_cuda",
+        "onnx/test_models",
     ],
     extra_tests=[
         "test_cpp_extensions_aot_ninja",
@@ -291,6 +293,12 @@ def skip_test_p(name: str) -> bool:
     "test_native_mha",  # OOM
     "test_module_hooks",  # OOM
 ]
+# A subset of onnx tests that cannot run in parallel due to high memory usage.
+ONNX_SERIAL_LIST = [
+    "onnx/test_models",
+    "onnx/test_models_quantized_onnxruntime",
+    "onnx/test_models_onnxruntime",
+]
 
 # A subset of our TEST list that validates PyTorch's ops, modules, and autograd function as expected
 CORE_TEST_LIST = [
@@ -359,6 +367,7 @@ def skip_test_p(name: str) -> bool:
 
 DISTRIBUTED_TESTS = [test for test in TESTS if test.startswith("distributed")]
 FUNCTORCH_TESTS = [test for test in TESTS if test.startswith("functorch")]
+ONNX_TESTS = [test for test in TESTS if test.startswith("onnx")]
 
 TESTS_REQUIRING_LAPACK = [
     "distributions/test_constraints",
@@ -907,6 +916,15 @@ def parse_args():
         help="Only run core tests, or tests that validate PyTorch's ops, modules,"
         "and autograd. They are defined by CORE_TEST_LIST.",
     )
+    parser.add_argument(
+        "--onnx",
+        "--onnx",
+        action="store_true",
+        help=(
+            "Only run ONNX tests, or tests that validate PyTorch's ONNX export. "
+            "If this flag is not present, we will exclude ONNX tests."
+        ),
+    )
     parser.add_argument(
         "-pt",
         "--pytest",
@@ -1099,6 +1117,7 @@ def must_serial(file: str) -> bool:
         or file in RUN_PARALLEL_BLOCKLIST
         or file in CI_SERIAL_LIST
         or file in JIT_EXECUTOR_TESTS
+        or file in ONNX_SERIAL_LIST
     )
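A condensed, self-contained sketch of how ONNX_SERIAL_LIST feeds this predicate (the real must_serial in test/run_test.py also checks the other lists shown in the hunk above; list contents copied from the diff):

ONNX_SERIAL_LIST = [
    "onnx/test_models",
    "onnx/test_models_quantized_onnxruntime",
    "onnx/test_models_onnxruntime",
]


def must_serial(file: str) -> bool:
    # Condensed: the real predicate also consults RUN_PARALLEL_BLOCKLIST,
    # CI_SERIAL_LIST, JIT_EXECUTOR_TESTS and other criteria.
    return file in ONNX_SERIAL_LIST


assert must_serial("onnx/test_models")
assert not must_serial("onnx/test_utility_funs")  # still eligible to run in parallel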


@@ -1136,6 +1155,14 @@ def get_selected_tests(options):
         # Exclude all mps tests otherwise
         options.exclude.extend(["test_mps", "test_metal"])
 
+    # Filter to only run onnx tests when --onnx option is specified
+    onnx_tests = [tname for tname in selected_tests if tname in ONNX_TESTS]
+    if options.onnx:
+        selected_tests = onnx_tests
+    else:
+        # Exclude all onnx tests otherwise
+        options.exclude.extend(onnx_tests)
+
     # process reordering
     if options.bring_to_front:
         to_front = set(options.bring_to_front)
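The net effect of the run_test.py changes: ONNX tests become opt-in via --onnx and are excluded from default runs. A small sketch of that gating with illustrative test names (mirroring the get_selected_tests hunk above):

from typing import List, Tuple

ONNX_TESTS = ["onnx/test_utility_funs", "onnx/test_models"]  # illustrative subset


def gate_onnx(selected: List[str], onnx_flag: bool) -> Tuple[List[str], List[str]]:
    """Return (selected tests, newly excluded tests) after the --onnx gating."""
    onnx_tests = [t for t in selected if t in ONNX_TESTS]
    if onnx_flag:
        return onnx_tests, []
    return selected, onnx_tests


all_tests = ["test_autograd", "onnx/test_models"]
assert gate_onnx(all_tests, True) == (["onnx/test_models"], [])
assert gate_onnx(all_tests, False) == (all_tests, ["onnx/test_models"])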
