Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,16 @@ def _test_classification_model(self, name, input_shape, dev):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
self.assertExpected(out.cpu(), prec=0.1, strip_suffix=f"_{dev}")
self.assertEqual(out.shape[-1], 50)
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

if dev == "cuda":
if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
out = model(x)
# See autocast_flaky_numerics comment at top of file.
if name not in autocast_flaky_numerics:
self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
self.assertExpected(out.cpu(), prec=0.1, strip_suffix=f"_{dev}")
self.assertEqual(out.shape[-1], 50)

def _test_segmentation_model(self, name, dev):
Expand All @@ -94,7 +94,7 @@ def _test_segmentation_model(self, name, dev):
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

if dev == "cuda":
if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
out = model(x)
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
Expand Down Expand Up @@ -143,7 +143,7 @@ def compute_mean_std(tensor):

output = map_nested_tensor_object(out, tensor_map_fn=compact)
prec = 0.01
strip_suffix = "_" + dev
strip_suffix = f"_{dev}"
try:
# We first try to assert the entire output if possible. This is not
# only the best way to assert results but also handles the cases
Expand All @@ -169,7 +169,7 @@ def compute_mean_std(tensor):
full_validation = check_out(out)
self.check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(name, None))

if dev == "cuda":
if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
out = model(model_input)
# See autocast_flaky_numerics comment at top of file.
Expand Down Expand Up @@ -220,7 +220,7 @@ def _test_video_model(self, name, dev):
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
self.assertEqual(out.shape[-1], 50)

if dev == "cuda":
if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
out = model(x)
self.assertEqual(out.shape[-1], 50)
Expand Down Expand Up @@ -380,7 +380,7 @@ def test_generalizedrcnn_transform_repr(self):
self.assertEqual(t.__repr__(), expected_string)


_devs = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
_devs = [torch.device("cpu"), torch.device("cuda")] if torch.cuda.is_available() else [torch.device("cpu")]


for model_name in get_available_classification_models():
Expand All @@ -393,7 +393,7 @@ def do_test(self, model_name=model_name, dev=dev):
input_shape = (1, 3, 299, 299)
self._test_classification_model(model_name, input_shape, dev)

setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
setattr(ModelTester, f"test_{model_name}_{dev}", do_test)


for model_name in get_available_segmentation_models():
Expand All @@ -403,7 +403,7 @@ def do_test(self, model_name=model_name, dev=dev):
def do_test(self, model_name=model_name, dev=dev):
self._test_segmentation_model(model_name, dev)

setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
setattr(ModelTester, f"test_{model_name}_{dev}", do_test)


for model_name in get_available_detection_models():
Expand All @@ -413,7 +413,7 @@ def do_test(self, model_name=model_name, dev=dev):
def do_test(self, model_name=model_name, dev=dev):
self._test_detection_model(model_name, dev)

setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
setattr(ModelTester, f"test_{model_name}_{dev}", do_test)

def do_validation_test(self, model_name=model_name):
self._test_detection_model_validation(model_name)
Expand All @@ -426,7 +426,7 @@ def do_validation_test(self, model_name=model_name):
def do_test(self, model_name=model_name, dev=dev):
self._test_video_model(model_name, dev)

setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
setattr(ModelTester, f"test_{model_name}_{dev}", do_test)

if __name__ == '__main__':
unittest.main()
6 changes: 2 additions & 4 deletions torchvision/models/detection/image_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ class ImageList(object):
and storing in a field the original sizes of each image
"""

def __init__(self, tensors, image_sizes):
# type: (Tensor, List[Tuple[int, int]]) -> None
def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]):
"""
Arguments:
tensors (tensor)
Expand All @@ -22,7 +21,6 @@ def __init__(self, tensors, image_sizes):
self.tensors = tensors
self.image_sizes = image_sizes

def to(self, device):
# type: (Device) -> ImageList # noqa
def to(self, device: torch.device) -> 'ImageList':
cast_tensor = self.tensors.to(device)
return ImageList(cast_tensor, self.image_sizes)
2 changes: 1 addition & 1 deletion torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import torch.nn as nn
from torch import Tensor
from torch.jit.annotations import Dict, List, Tuple
from torch.jit.annotations import Dict, List, Tuple, Optional
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is this needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used in here

# type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. Let's keep this here then, as when we move the # type: annotations to be inline the import will be necessary


from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url
Expand Down