Skip to content

Commit

Permalink
Enable UFMT format on test/test_throughput_benchmark.py test/test_typ…
Browse files Browse the repository at this point in the history
…e_hints.py test/test_type_info.py (pytorch#125906)

Fixes some files in pytorch#123062

Run lintrunner on files:
test/test_throughput_benchmark.py
test/test_type_hints.py
test/test_type_info.py

```bash
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: pytorch#125906
Approved by: https://github.com/shink, https://github.com/soulitzer, https://github.com/malfet
  • Loading branch information
zeshengzong authored and tinglvv committed May 14, 2024
1 parent 37335bd commit 873a4c2
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 33 deletions.
3 changes: 0 additions & 3 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -1139,11 +1139,8 @@ exclude_patterns = [
'test/test_tensorexpr.py',
'test/test_tensorexpr_pybind.py',
'test/test_testing.py',
'test/test_throughput_benchmark.py',
'test/test_torch.py',
'test/test_transformers.py',
'test/test_type_hints.py',
'test/test_type_info.py',
'test/test_type_promotion.py',
'test/test_unary_ufuncs.py',
'test/test_utils.py',
Expand Down
8 changes: 5 additions & 3 deletions test/test_throughput_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Owner(s): ["module: unknown"]

import torch

from torch.testing._internal.common_utils import run_tests, TemporaryFileName, TestCase
from torch.utils import ThroughputBenchmark

from torch.testing._internal.common_utils import run_tests, TestCase, TemporaryFileName

class TwoLayerNet(torch.jit.ScriptModule):
def __init__(self, D_in, H, D_out):
Expand All @@ -19,6 +20,7 @@ def forward(self, x1, x2):
y_pred = self.linear2(cat)
return y_pred


class TwoLayerNetModule(torch.nn.Module):
def __init__(self, D_in, H, D_out):
super().__init__()
Expand All @@ -32,6 +34,7 @@ def forward(self, x1, x2):
y_pred = self.linear2(cat)
return y_pred


class TestThroughputBenchmark(TestCase):
def linear_test(self, Module, profiler_output_path=""):
D_in = 10
Expand Down Expand Up @@ -67,7 +70,6 @@ def linear_test(self, Module, profiler_output_path=""):

print(stats)


def test_script_module(self):
self.linear_test(TwoLayerNet)

Expand All @@ -79,5 +81,5 @@ def test_profiling(self):
self.linear_test(TwoLayerNetModule, profiler_output_path=fname)


if __name__ == '__main__':
if __name__ == "__main__":
run_tests()
36 changes: 20 additions & 16 deletions test/test_type_hints.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# Owner(s): ["module: typing"]

import unittest
from torch.testing._internal.common_utils import TestCase, run_tests, set_cwd
import tempfile
import torch
import doctest
import os
import inspect
import os
import tempfile
import unittest
from pathlib import Path

import torch
from torch.testing._internal.common_utils import run_tests, set_cwd, TestCase

try:
import mypy.api

HAVE_MYPY = True
except ImportError:
HAVE_MYPY = False
Expand All @@ -22,7 +24,7 @@ def get_examples_from_docstring(docstr):
in docstrings; returns a list of lines.
"""
examples = doctest.DocTestParser().get_examples(docstr)
return [f' {l}' for e in examples for l in e.source.splitlines()]
return [f" {l}" for e in examples for l in e.source.splitlines()]


def get_all_examples():
Expand Down Expand Up @@ -79,7 +81,7 @@ def test_doc_examples(self):
"""
Run documentation examples through mypy.
"""
fn = Path(__file__).resolve().parent / 'generated_type_hints_smoketest.py'
fn = Path(__file__).resolve().parent / "generated_type_hints_smoketest.py"
with open(fn, "w") as f:
print(get_all_examples(), file=f)

Expand Down Expand Up @@ -116,23 +118,25 @@ def test_doc_examples(self):
try:
os.symlink(
os.path.dirname(torch.__file__),
os.path.join(tmp_dir, 'torch'),
target_is_directory=True
os.path.join(tmp_dir, "torch"),
target_is_directory=True,
)
except OSError:
raise unittest.SkipTest('cannot symlink') from None
raise unittest.SkipTest("cannot symlink") from None
repo_rootdir = Path(__file__).resolve().parent.parent
# TODO: Would be better not to chdir here, this affects the
# entire process!
with set_cwd(str(repo_rootdir)):
(stdout, stderr, result) = mypy.api.run([
'--cache-dir=.mypy_cache/doc',
'--no-strict-optional', # needed because of torch.lu_unpack, see gh-36584
str(fn),
])
(stdout, stderr, result) = mypy.api.run(
[
"--cache-dir=.mypy_cache/doc",
"--no-strict-optional", # needed because of torch.lu_unpack, see gh-36584
str(fn),
]
)
if result != 0:
self.fail(f"mypy failed:\n{stderr}\n{stdout}")


if __name__ == '__main__':
if __name__ == "__main__":
run_tests()
50 changes: 39 additions & 11 deletions test/test_type_info.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,48 @@
# Owner(s): ["module: typing"]

from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY, load_tests, set_default_dtype
from torch.testing._internal.common_utils import (
load_tests,
run_tests,
set_default_dtype,
TEST_NUMPY,
TestCase,
)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

import sys
import torch
import unittest

import torch

if TEST_NUMPY:
import numpy as np


class TestDTypeInfo(TestCase):

def test_invalid_input(self):
for dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.bool]:
for dtype in [
torch.float16,
torch.float32,
torch.float64,
torch.bfloat16,
torch.complex64,
torch.complex128,
torch.bool,
]:
with self.assertRaises(TypeError):
_ = torch.iinfo(dtype)

for dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool]:
for dtype in [
torch.int64,
torch.int32,
torch.int16,
torch.int8,
torch.uint8,
torch.bool,
]:
with self.assertRaises(TypeError):
_ = torch.finfo(dtype)
with self.assertRaises(RuntimeError):
Expand All @@ -41,7 +62,13 @@ def test_iinfo(self):

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_finfo(self):
for dtype in [torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128]:
for dtype in [
torch.float16,
torch.float32,
torch.float64,
torch.complex64,
torch.complex128,
]:
x = torch.zeros((2, 2), dtype=dtype)
xinfo = torch.finfo(x.dtype)
xn = x.cpu().numpy()
Expand All @@ -61,8 +88,8 @@ def test_finfo(self):
x = torch.zeros((2, 2), dtype=torch.bfloat16)
xinfo = torch.finfo(x.dtype)
self.assertEqual(xinfo.bits, 16)
self.assertEqual(xinfo.max, 3.38953e+38)
self.assertEqual(xinfo.min, -3.38953e+38)
self.assertEqual(xinfo.max, 3.38953e38)
self.assertEqual(xinfo.min, -3.38953e38)
self.assertEqual(xinfo.eps, 0.0078125)
self.assertEqual(xinfo.tiny, 1.17549e-38)
self.assertEqual(xinfo.tiny, xinfo.smallest_normal)
Expand All @@ -76,7 +103,7 @@ def test_finfo(self):
self.assertEqual(xinfo.bits, 8)
self.assertEqual(xinfo.max, 57344.0)
self.assertEqual(xinfo.min, -57344.0)
self.assertEqual(xinfo.eps, .25)
self.assertEqual(xinfo.eps, 0.25)
self.assertEqual(xinfo.tiny, 6.10352e-05)
self.assertEqual(xinfo.resolution, 1.0)
self.assertEqual(xinfo.dtype, "float8_e5m2")
Expand All @@ -86,7 +113,7 @@ def test_finfo(self):
self.assertEqual(xinfo.bits, 8)
self.assertEqual(xinfo.max, 448.0)
self.assertEqual(xinfo.min, -448.0)
self.assertEqual(xinfo.eps, .125)
self.assertEqual(xinfo.eps, 0.125)
self.assertEqual(xinfo.tiny, 0.015625)
self.assertEqual(xinfo.resolution, 1.0)
self.assertEqual(xinfo.dtype, "float8_e4m3fn")
Expand All @@ -111,6 +138,7 @@ def test_to_real(self):
self.assertEqual(torch.complex64.to_real(), torch.float32)
self.assertEqual(torch.complex32.to_real(), torch.float16)

if __name__ == '__main__':

if __name__ == "__main__":
TestCase._default_dtype_check_enabled = True
run_tests()

0 comments on commit 873a4c2

Please sign in to comment.