Skip to content

Commit

Permalink
Run mypy over test/test_utils.py (#50278)
Browse files Browse the repository at this point in the history
Summary:
_resubmission of gh-49654, which was reverted due to a cross-merge conflict_

This caught one incorrect annotation in `cpp_extension.load`.

xref gh-16574.

Pull Request resolved: #50278

Reviewed By: walterddr

Differential Revision: D25865278

Pulled By: ezyang

fbshipit-source-id: 25489191628af5cf9468136db36f5a0f72d9d54d
  • Loading branch information
rgommers authored and facebook-github-bot committed Jan 11, 2021
1 parent eb87686 commit e29082b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 29 deletions.
3 changes: 2 additions & 1 deletion mypy.ini
Expand Up @@ -26,7 +26,8 @@ files =
test/test_numpy_interop.py,
test/test_torch.py,
test/test_type_hints.py,
test/test_type_info.py
test/test_type_info.py,
test/test_utils.py


# Minimum version supported - variable annotations were introduced
Expand Down
56 changes: 30 additions & 26 deletions test/test_utils.py
Expand Up @@ -10,6 +10,7 @@
import torch
import torch.nn as nn
import torch.utils.data
from torch.utils.data import DataLoader
import torch.cuda
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.utils.cpp_extension
Expand All @@ -28,7 +29,7 @@
from torch.testing._internal.common_utils import TestCase, run_tests


class RandomDatasetMock(object):
class RandomDatasetMock(torch.utils.data.Dataset):

def __getitem__(self, index):
return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)])
Expand Down Expand Up @@ -190,7 +191,7 @@ def forward(self, a, b):
b = torch.randn(1, 100, requires_grad=True)

with self.assertRaises(TypeError):
checkpoint_sequential(model, 1, a, b)
checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg]

def test_checkpoint_sequential_deprecated_no_args(self):
class Noop(nn.Module):
Expand All @@ -200,7 +201,7 @@ def forward(self):
model = nn.Sequential(Noop())

with self.assertRaises(TypeError):
checkpoint_sequential(model, 1)
checkpoint_sequential(model, 1) # type: ignore[call-arg]

def test_checkpoint_rng_cpu(self):
for _ in range(5):
Expand Down Expand Up @@ -277,15 +278,15 @@ def run_fn(tensor1, tensor2):
out = checkpoint(run_fn, input_var, input_var2)
out[0].sum().backward()

def run_fn(tensor1, tensor2):
def run_fn2(tensor1, tensor2):
return tensor1
input_var = torch.randn(1, 4, requires_grad=False)
input_var2 = torch.randn(1, 4, requires_grad=True)
with self.assertRaisesRegex(
RuntimeError,
r"none of output has requires_grad=True, this checkpoint\(\) is not necessary"
):
out = checkpoint(run_fn, input_var, input_var2)
out = checkpoint(run_fn2, input_var, input_var2)
out.sum().backward()

class TestDataLoader(TestCase):
Expand All @@ -308,35 +309,38 @@ def run():
self.assertEqual(x1, x2)

def test_single_keep(self):
dataloader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
num_workers=0,
drop_last=False)
# self.dataset is a Tensor here; technically not a valid input because
# not a Dataset subclass, but needs to stay working so add ignore's
# for type checking with mypy
dataloader : DataLoader = DataLoader(self.dataset, # type: ignore[arg-type]
batch_size=self.batch_size,
num_workers=0,
drop_last=False)
dataiter = iter(dataloader)
self.assertEqual(len(list(dataiter)), 2)

def test_single_drop(self):
dataloader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
num_workers=0,
drop_last=True)
dataloader : DataLoader = DataLoader(self.dataset, # type: ignore[arg-type]
batch_size=self.batch_size,
num_workers=0,
drop_last=True)
dataiter = iter(dataloader)
self.assertEqual(len(list(dataiter)), 1)

@unittest.skip("FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN")
def test_multi_keep(self):
dataloader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
num_workers=2,
drop_last=False)
dataloader : DataLoader = DataLoader(self.dataset, # type: ignore[arg-type]
batch_size=self.batch_size,
num_workers=2,
drop_last=False)
dataiter = iter(dataloader)
self.assertEqual(len(list(dataiter)), 2)

def test_multi_drop(self):
    """Two-worker DataLoader with drop_last=True discards the trailing
    partial batch: one batch expected."""
    # self.dataset is a Tensor, not a Dataset subclass; kept working on
    # purpose, hence the mypy arg-type ignore.
    dataloader: DataLoader = DataLoader(self.dataset,  # type: ignore[arg-type]
                                        batch_size=self.batch_size,
                                        num_workers=2,
                                        drop_last=True)
    dataiter = iter(dataloader)
    self.assertEqual(len(list(dataiter)), 1)

Expand All @@ -347,7 +351,7 @@ def test_multi_drop(self):
class TestFFI(TestCase):
    """torch.utils.ffi was removed in favor of cpp extensions; importing it
    must raise a deprecation ImportError."""

    def test_deprecated(self):
        with self.assertRaisesRegex(ImportError, "torch.utils.ffi is deprecated. Please use cpp extensions instead."):
            from torch.utils.ffi import create_extension  # type: ignore # noqa: F401


@unittest.skipIf('SKIP_TEST_BOTTLENECK' in os.environ.keys(), 'SKIP_TEST_BOTTLENECK is set')
Expand All @@ -364,9 +368,9 @@ def _run(self, command, timeout=30):
p.kill()
output, err = p.communicate()
rc = p.returncode
output = output.decode("ascii")
err = err.decode("ascii")
return (rc, output, err)
output_str = output.decode("ascii")
err_str = err.decode("ascii")
return (rc, output_str, err_str)

def _run_bottleneck(self, test_file, scriptargs=''):
curdir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -661,7 +665,7 @@ def forward(self, x):
# data can be passed without errors
x = torch.randn(4, 4).fill_(1.0)
ms(x)
with self.assertRaisesRegex(torch.jit.Error, "foo"):
with self.assertRaisesRegex(torch.jit.Error, "foo"): # type: ignore[type-var]
ms(torch.tensor([False], dtype=torch.bool))


Expand Down
4 changes: 2 additions & 2 deletions torch/utils/cpp_extension.py
Expand Up @@ -17,7 +17,7 @@
from ._cpp_extension_versioner import ExtensionVersioner
from .hipify import hipify_python
from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
from typing import List, Optional
from typing import List, Optional, Union

from setuptools.command.build_ext import build_ext
from pkg_resources import packaging # type: ignore
Expand Down Expand Up @@ -980,7 +980,7 @@ def library_paths(cuda: bool = False) -> List[str]:


def load(name,
sources: List[str],
sources: Union[str, List[str]],
extra_cflags=None,
extra_cuda_cflags=None,
extra_ldflags=None,
Expand Down

0 comments on commit e29082b

Please sign in to comment.