Use re to thoroughly remove prefix on "[ONNX] Run ONNX tests as part of standard run_test script"


<!--
copilot:all
-->
### <samp>🤖 Generated by Copilot at 8277fb8</samp>

### Summary
🧪🛠️🚀

<!--
1.  🧪 for simplifying the test script for onnx
2.  🛠️ for refactoring the test module for onnx exporter api
3.  🚀 for enabling more testing features and improving performance
-->
This pull request refactors and improves the onnx testing code in `pytorch/pytorch`. It uses the common_utils module to simplify and standardize the test cases, adds a helper function to normalize the scope names of onnx nodes, enables running onnx tests separately and in subtests with the `run_test.py` script, and simplifies the `scripts/onnx/test.sh` script.

> _To test onnx without much hassle_
> _The code uses `run_test.py` as a vessel_
> _It normalizes the scopes_
> _And refactors the tropes_
> _Of the exporter api and its tassel_

### Walkthrough
*  Simplify onnx test script by using `run_test.py` with `--onnx` flag ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-0017f5b22ae1329acb0f54af8d9811c9b6180a72dac70d7a5b89d7c23c958198L44-R46))
*  Remove unused import of `unittest` from onnx exporter api test module ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L4))
*  Change base class of onnx exporter api test classes from `unittest.TestCase` to `common_utils.TestCase` to inherit useful methods and attributes ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L19-R18), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L29-R28), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L71-R70), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L147-R146))
*  Change import and usage of `TemporaryFileName` from global to class attribute of `common_utils.TestCase` to avoid name clashes and improve consistency ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L92-R91), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L110-R109), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L129-R128))
*  Change main entry point of onnx exporter api test module from `unittest.main` to `common_utils.run_tests` to make it compatible with `run_test.py` script ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L155-R154))
*  Add helper function to remove test environment prefix from scope name of onnx nodes in `test_utility_funs.py` to make scope name comparisons more robust and consistent ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7R31-R81)); a minimal sketch of the idea follows this list
*  Move and rename nested classes `_N` and `_M` from `test_node_scope` function to top level of `test_utility_funs.py` to fix class names for scope name normalization ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1067-R1130))
*  Use helper function to remove test environment prefix from scope name of onnx nodes in `test_node_scope` function and update expected scope names ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1119-R1149))
*  Update expected layer scope name to use new class name `_M` in `test_scope_of_constants_when_combined_by_cse_pass` and `test_scope_of_nodes_when_combined_by_cse_pass` functions ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1172-R1186), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1222-R1236))
*  Replace `onnx` test with `onnx_caffe2` test in `run_test.py` script to reflect test split ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7L127-R127))
*  Add two new onnx subtests to `run_test.py` script that were previously excluded due to high memory usage or long running time ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R154-R155))
*  Add list of onnx subtests that cannot run in parallel due to high memory usage and run them serially when `--onnx` flag is used ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R296-R301), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R1120))
*  Add list of onnx tests that includes all subtests with onnx prefix and use it to filter selected tests based on `--onnx` flag ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R370), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R1158-R1165))
*  Add `--onnx` flag to `run_test.py` script to allow user to only run onnx tests or exclude them ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R920-R928))
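For reference, here is a minimal sketch of the prefix-stripping idea. The helper name `_strip_test_prefix` and the example scope strings are illustrative only; the actual implementation is `_remove_test_environment_prefix_from_scope_name` in the diff below, which uses the same regex.

```python
import re


def _strip_test_prefix(scope_name: str) -> str:
    # Hypothetical stand-in for _remove_test_environment_prefix_from_scope_name.
    # Strips "<module>." plus any trailing "...<locals>." segment that appears
    # when the model class is defined inside a test function, so scope-name
    # comparisons do not depend on how the test was launched.
    prefixes_to_remove = ["test_utility_funs", "__main__"]
    for prefix in prefixes_to_remove:
        scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
    return scope_name


assert _strip_test_prefix("test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::") == "M::"
assert _strip_test_prefix("__main__.M::") == "M::"
```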



cc seemethere malfet pytorch/pytorch-dev-infra

[ghstack-poisoned]
BowenBao committed Apr 17, 2023
2 parents 8277fb8 + 7f9aec4 commit dcbf7e2
Showing 1 changed file with 54 additions and 47 deletions.
101 changes: 54 additions & 47 deletions test/onnx/test_utility_funs.py
@@ -3,6 +3,7 @@
 import copy
 import functools
 import io
+import re
 import warnings
 from typing import Callable

@@ -41,44 +42,20 @@ def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
     >>> )
     "M"
+    >>> _remove_test_environment_prefix_from_scope_name(
+    >>>     "test_utility_funs.test_abc.<locals>.M"
+    >>> )
+    "M"
     >>> _remove_test_environment_prefix_from_scope_name(
     >>>     "__main__.M"
     >>> )
     "M"
     """
-    prefixes_to_remove = ["test_utility_funs.", "__main__."]
+    prefixes_to_remove = ["test_utility_funs", "__main__"]
     for prefix in prefixes_to_remove:
-        scope_name = scope_name.replace(prefix, "")
+        scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
     return scope_name
 
 
-class _N(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.relu = torch.nn.ReLU()
-
-    def forward(self, x):
-        return self.relu(x)
-
-
-class _M(torch.nn.Module):
-    def __init__(self, num_layers):
-        super().__init__()
-        self.num_layers = num_layers
-        self.lns = torch.nn.ModuleList(
-            [torch.nn.LayerNorm(3, eps=float(i)) for i in range(num_layers)]
-        )
-        self.gelu1 = torch.nn.GELU()
-        self.gelu2 = torch.nn.GELU()
-        self.relu = _N()
-
-    def forward(self, x, y, z):
-        res1 = self.gelu1(x)
-        res2 = self.gelu2(y)
-        for ln in self.lns:
-            z = ln(z)
-        return res1 + res2, self.relu(z)
-
-
 class _BaseTestCase(pytorch_test_common.ExportTestCase):
     def _model_to_graph(
         self,
@@ -1115,19 +1092,45 @@ def forward(self, x):
         self.assertIn(ln_node.attribute[1], expected_ln_attrs)
 
     def test_node_scope(self):
+        class N(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                return self.relu(x)
+
+        class M(torch.nn.Module):
+            def __init__(self, num_layers):
+                super().__init__()
+                self.num_layers = num_layers
+                self.lns = torch.nn.ModuleList(
+                    [torch.nn.LayerNorm(3, eps=float(i)) for i in range(num_layers)]
+                )
+                self.gelu1 = torch.nn.GELU()
+                self.gelu2 = torch.nn.GELU()
+                self.relu = N()
+
+            def forward(self, x, y, z):
+                res1 = self.gelu1(x)
+                res2 = self.gelu2(y)
+                for ln in self.lns:
+                    z = ln(z)
+                return res1 + res2, self.relu(z)
+
         x = torch.randn(2, 3)
         y = torch.randn(2, 3)
         z = torch.randn(2, 3)
 
-        model = _M(3)
+        model = M(3)
         expected_scope_names = {
-            "_M::/torch.nn.modules.activation.GELU::gelu1",
-            "_M::/torch.nn.modules.activation.GELU::gelu2",
-            "_M::/torch.nn.modules.normalization.LayerNorm::lns.0",
-            "_M::/torch.nn.modules.normalization.LayerNorm::lns.1",
-            "_M::/torch.nn.modules.normalization.LayerNorm::lns.2",
-            "_M::/_N::relu/torch.nn.modules.activation.ReLU::relu",
-            "_M::",
+            "M::/torch.nn.modules.activation.GELU::gelu1",
+            "M::/torch.nn.modules.activation.GELU::gelu2",
+            "M::/torch.nn.modules.normalization.LayerNorm::lns.0",
+            "M::/torch.nn.modules.normalization.LayerNorm::lns.1",
+            "M::/torch.nn.modules.normalization.LayerNorm::lns.2",
+            "M::/N::relu/torch.nn.modules.activation.ReLU::relu",
+            "M::",
         }
 
         graph, _, _ = self._model_to_graph(
@@ -1181,9 +1184,8 @@ def forward(self, x):
         # so we expect 3 constants with different scopes. The 3 constants are for the 3 layers.
         # If CSE in exporter is improved later, this test needs to be updated.
         # It should expect 1 constant, with same scope as root.
-        scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_constants_when_combined_by_cse_pass.<locals>"
-        expected_root_scope_name = f"{scope_prefix}.N::"
-        expected_layer_scope_name = f"{scope_prefix}._M::layers"
+        expected_root_scope_name = "N::"
+        expected_layer_scope_name = "M::layers"
         expected_constant_scope_name = [
             f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
             for i in range(layer_num)
@@ -1192,7 +1194,9 @@
         constant_scope_names = []
         for node in graph.nodes():
             if node.kind() == "onnx::Constant":
-                constant_scope_names.append(node.scopeName())
+                constant_scope_names.append(
+                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
+                )
         self.assertEqual(constant_scope_names, expected_constant_scope_name)
 
     def test_scope_of_nodes_when_combined_by_cse_pass(self):
@@ -1231,9 +1235,8 @@ def forward(self, x):
         graph, _, _ = self._model_to_graph(
             N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
         )
-        scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_nodes_when_combined_by_cse_pass.<locals>"
-        expected_root_scope_name = f"{scope_prefix}.N::"
-        expected_layer_scope_name = f"{scope_prefix}._M::layers"
+        expected_root_scope_name = "N::"
+        expected_layer_scope_name = "M::layers"
         expected_add_scope_names = [
             f"{expected_root_scope_name}/{expected_layer_scope_name}.0"
         ]
@@ -1246,9 +1249,13 @@
         mul_scope_names = []
         for node in graph.nodes():
             if node.kind() == "onnx::Add":
-                add_scope_names.append(node.scopeName())
+                add_scope_names.append(
+                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
+                )
             elif node.kind() == "onnx::Mul":
-                mul_scope_names.append(node.scopeName())
+                mul_scope_names.append(
+                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
+                )
         self.assertEqual(add_scope_names, expected_add_scope_names)
         self.assertEqual(mul_scope_names, expected_mul_scope_names)

