Skip to content

Commit

Permalink
MAINT: pytorchify torch._numpy tests: core/ and fft/ (#109815)
Browse files Browse the repository at this point in the history
1. Inherit from TestCase
2. Use pytorch parametrization
3. Use unittest.expectedFailure to mark xfails, also unittest skips

All this to make pytest-less invocation work:

$ python test/torch_np/test_basic.py

cross-ref #109593, #109718, #109775

Pull Request resolved: #109815
Approved by: https://github.com/lezcano
  • Loading branch information
ev-br authored and pytorchmergebot committed Sep 26, 2023
1 parent 8140494 commit 132a138
Show file tree
Hide file tree
Showing 16 changed files with 1,369 additions and 1,179 deletions.
73 changes: 73 additions & 0 deletions test/torch_np/check_tests_conform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pathlib
import sys
import textwrap


def check(path):
"""Check a test file for common issues with pytest->pytorch conversion."""
print(path.name)
print("=" * len(path.name), "\n")

src = path.read_text().split("\n")
for num, line in enumerate(src):
if is_comment(line):
continue

# module level test functions
if line.startswith("def test"):
report_violation(line, num, header="Module-level test function")

# test classes must inherit from TestCase
if line.startswith("class Test") and "TestCase" not in line:
report_violation(
line, num, header="Test class does not inherit from TestCase"
)

# last vestiges of pytest-specific stuff
if "pytest.mark" in line:
report_violation(line, num, header="pytest.mark.something")

for part in ["pytest.xfail", "pytest.skip", "pytest.param"]:
if part in line:
report_violation(line, num, header=f"stray {part}")

if textwrap.dedent(line).startswith("@parametrize"):
# backtrack to check
nn = num
for nn in range(num, -1, -1):
ln = src[nn]
if "class Test" in ln:
# hack: large indent => likely an inner class
if len(ln) - len(ln.lstrip()) < 8:
break
else:
report_violation(line, num, "off-class parametrize")
if not src[nn - 1].startswith("@instantiate_parametrized_tests"):
# breakpoint()
report_violation(
line, num, f"missing instantiation of parametrized tests in {ln}?"
)


def is_comment(line):
return textwrap.dedent(line).startswith("#")


def report_violation(line, lineno, header):
print(f">>>> line {lineno} : {header}\n {line}\n")


if __name__ == "__main__":
argv = sys.argv
if len(argv) != 2:
raise ValueError("Usage : python check_tests_conform path/to/file/or/dir")

path = pathlib.Path(argv[1])

if path.is_dir():
# run for all files in the directory (no subdirs)
for this_path in path.glob("test*.py"):
# breakpoint()
check(this_path)
else:
check(path)
32 changes: 21 additions & 11 deletions test/torch_np/numpy_tests/core/test_dlpack.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,57 @@
# Owner(s): ["module: dynamo"]

import functools
import sys

from unittest import expectedFailure as xfail, skipIf as skipif

import pytest

import torch

import torch._numpy as np
from torch._numpy.testing import assert_array_equal
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)

skip = functools.partial(skipif, True)

IS_PYPY = False


class TestDLPack:
@pytest.mark.xfail(reason="pytorch seems to handle refcounts differently")
@pytest.mark.skipif(IS_PYPY, reason="PyPy can't get refcounts.")
@instantiate_parametrized_tests
class TestDLPack(TestCase):
@xfail # (reason="pytorch seems to handle refcounts differently")
@skipif(IS_PYPY, reason="PyPy can't get refcounts.")
def test_dunder_dlpack_refcount(self):
x = np.arange(5)
y = x.__dlpack__()
assert sys.getrefcount(x) == 3
del y
assert sys.getrefcount(x) == 2

@pytest.mark.xfail(reason="pytorch does not raise")
@xfail # (reason="pytorch does not raise")
def test_dunder_dlpack_stream(self):
x = np.arange(5)
x.__dlpack__(stream=None)

with pytest.raises(RuntimeError):
x.__dlpack__(stream=1)

@pytest.mark.xfail(reason="pytorch seems to handle refcounts differently")
@pytest.mark.skipif(IS_PYPY, reason="PyPy can't get refcounts.")
@xfail # (reason="pytorch seems to handle refcounts differently")
@skipif(IS_PYPY, reason="PyPy can't get refcounts.")
def test_from_dlpack_refcount(self):
x = np.arange(5)
y = np.from_dlpack(x)
assert sys.getrefcount(x) == 3
del y
assert sys.getrefcount(x) == 2

@pytest.mark.parametrize(
@parametrize(
"dtype",
[
np.int8,
Expand Down Expand Up @@ -79,7 +91,7 @@ def test_non_contiguous(self):
y5 = np.diagonal(x).copy()
assert_array_equal(y5, np.from_dlpack(y5))

@pytest.mark.parametrize("ndim", range(33))
@parametrize("ndim", range(33))
def test_higher_dims(self, ndim):
shape = (1,) * ndim
x = np.zeros(shape, dtype=np.float64)
Expand All @@ -103,7 +115,7 @@ def test_dlpack_destructor_exception(self):
with pytest.raises(RuntimeError):
self.dlpack_deleter_exception()

@pytest.mark.skip(reason="no readonly arrays in pytorch")
@skip(reason="no readonly arrays in pytorch")
def test_readonly(self):
x = np.arange(5)
x.flags.writeable = False
Expand All @@ -127,6 +139,4 @@ def test_to_torch(self):


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

run_tests()
Loading

1 comment on commit 132a138

@pytorchmergebot
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted #109815 on behalf of https://github.com/PaliC due to causing various slow tests to fail (comment)

Please sign in to comment.