-
Notifications
You must be signed in to change notification settings - Fork 21.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MAINT: pytorchify torch._numpy tests: core/ and fft/ (#109815)
1. Inherit from TestCase 2. Use pytorch parametrization 3. Use unittest.expectedFailure to mark xfails, also unittest skips All this to make pytest-less invocation work: $ python test/torch_np/test_basic.py cross-ref #109593, #109718, #109775 Pull Request resolved: #109815 Approved by: https://github.com/lezcano
- Loading branch information
1 parent
8140494
commit 132a138
Showing
16 changed files
with
1,369 additions
and
1,179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import pathlib | ||
import sys | ||
import textwrap | ||
|
||
|
||
def check(path): | ||
"""Check a test file for common issues with pytest->pytorch conversion.""" | ||
print(path.name) | ||
print("=" * len(path.name), "\n") | ||
|
||
src = path.read_text().split("\n") | ||
for num, line in enumerate(src): | ||
if is_comment(line): | ||
continue | ||
|
||
# module level test functions | ||
if line.startswith("def test"): | ||
report_violation(line, num, header="Module-level test function") | ||
|
||
# test classes must inherit from TestCase | ||
if line.startswith("class Test") and "TestCase" not in line: | ||
report_violation( | ||
line, num, header="Test class does not inherit from TestCase" | ||
) | ||
|
||
# last vestiges of pytest-specific stuff | ||
if "pytest.mark" in line: | ||
report_violation(line, num, header="pytest.mark.something") | ||
|
||
for part in ["pytest.xfail", "pytest.skip", "pytest.param"]: | ||
if part in line: | ||
report_violation(line, num, header=f"stray {part}") | ||
|
||
if textwrap.dedent(line).startswith("@parametrize"): | ||
# backtrack to check | ||
nn = num | ||
for nn in range(num, -1, -1): | ||
ln = src[nn] | ||
if "class Test" in ln: | ||
# hack: large indent => likely an inner class | ||
if len(ln) - len(ln.lstrip()) < 8: | ||
break | ||
else: | ||
report_violation(line, num, "off-class parametrize") | ||
if not src[nn - 1].startswith("@instantiate_parametrized_tests"): | ||
# breakpoint() | ||
report_violation( | ||
line, num, f"missing instantiation of parametrized tests in {ln}?" | ||
) | ||
|
||
|
||
def is_comment(line): | ||
return textwrap.dedent(line).startswith("#") | ||
|
||
|
||
def report_violation(line, lineno, header): | ||
print(f">>>> line {lineno} : {header}\n {line}\n") | ||
|
||
|
||
if __name__ == "__main__": | ||
argv = sys.argv | ||
if len(argv) != 2: | ||
raise ValueError("Usage : python check_tests_conform path/to/file/or/dir") | ||
|
||
path = pathlib.Path(argv[1]) | ||
|
||
if path.is_dir(): | ||
# run for all files in the directory (no subdirs) | ||
for this_path in path.glob("test*.py"): | ||
# breakpoint() | ||
check(this_path) | ||
else: | ||
check(path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
132a138
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reverted #109815 on behalf of https://github.com/PaliC due to causing various slow tests to fail (comment)