diff --git a/test/test_transforms.py b/test/test_transforms.py
index 7346e2c5094..b8045703267 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -219,6 +219,11 @@ def test_resize(self):
         width = random.randint(24, 32) * 2
         osize = random.randint(5, 12) * 2
 
+        # TODO: Check output size check for bug-fix, improve this later
+        t = transforms.Resize(osize)
+        self.assertTrue(isinstance(t.size, int))
+        self.assertEqual(t.size, osize)
+
         img = torch.ones(3, height, width)
         result = transforms.Compose([
             transforms.ToPILImage(),
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 51c9cd0280b..1fc0ab61ec4 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -280,6 +280,17 @@ def test_ten_crop(self):
         )
 
     def test_resize(self):
+
+        # TODO: Minimal check for bug-fix, improve this later
+        x = torch.rand(3, 32, 46)
+        t = T.Resize(size=38)
+        y = t(x)
+        # If size is an int, smaller edge of the image will be matched to this number.
+        # i.e, if height > width, then image will be rescaled to (size * height / width, size).
+        self.assertTrue(isinstance(y, torch.Tensor))
+        self.assertEqual(y.shape[1], 38)
+        self.assertEqual(y.shape[2], int(38 * 46 / 32))
+
         tensor, _ = self._create_data(height=34, width=36, device=self.device)
         batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
         script_fn = torch.jit.script(F.resize)
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 6491ed38c20..bd1479b7cf0 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -249,7 +249,11 @@ class Resize(torch.nn.Module):
 
     def __init__(self, size, interpolation=Image.BILINEAR):
         super().__init__()
-        self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
+        if not isinstance(size, (int, Sequence)):
+            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
+        if isinstance(size, Sequence) and len(size) not in (1, 2):
+            raise ValueError("If size is a sequence, it should have 1 or 2 values")
+        self.size = size
         self.interpolation = interpolation
 
     def forward(self, img):