2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -91,7 +91,7 @@ If you have modified the code by adding a new feature or a bug-fix, please add u
test:
```bash
pytest test/<test-module.py> -vvv -k <test_myfunc>
# e.g. pytest test/test_transforms.py -vvv -k test_crop
# e.g. pytest test/test_transforms.py -vvv -k test_center_crop
```

If you would like to run all tests:
61 changes: 60 additions & 1 deletion test/test_transforms.py
@@ -1,3 +1,4 @@
import itertools
import os
import torch
import torchvision.transforms as transforms
@@ -29,7 +30,7 @@

class Tester(unittest.TestCase):

    def test_crop(self):
    def test_center_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
@@ -70,6 +71,64 @@ def test_crop(self):
        self.assertGreater(sum2, sum1,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))

    def test_center_crop_2(self):
        """Test center crop when the crop size is larger than the image size along any dimension."""
        even_image_size = (random.randint(10, 32) * 2, random.randint(10, 32) * 2)
        odd_image_size = (even_image_size[0] + 1, even_image_size[1] + 1)

        # Since height is independent of width, we can ignore images with odd height and even width and vice-versa.
        input_image_sizes = [even_image_size, odd_image_size]

        # Get different crop sizes
        delta = random.choice((1, 3, 5))
        crop_size_delta = [-2 * delta, -delta, 0, delta, 2 * delta]
        crop_size_params = itertools.product(input_image_sizes, crop_size_delta, crop_size_delta)

        for (input_image_size, delta_height, delta_width) in crop_size_params:
            img = torch.ones(3, *input_image_size)
            crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width)

            # Test both transforms, one with PIL input and one with tensor
            output_pil = transforms.Compose([
                transforms.ToPILImage(),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor()],
            )(img)
            self.assertEqual(output_pil.size()[1:3], crop_size,
                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))

            output_tensor = transforms.CenterCrop(crop_size)(img)
            self.assertEqual(output_tensor.size()[1:3], crop_size,
                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))

            # Ensure output for PIL and Tensor are equal
            self.assertEqual((output_tensor - output_pil).sum(), 0,
                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))

            # Check if content in center of both image and cropped output is same.
            center_size = (min(crop_size[0], input_image_size[0]), min(crop_size[1], input_image_size[1]))
            crop_center_tl, input_center_tl = [0, 0], [0, 0]
            for index in range(2):
                if crop_size[index] > input_image_size[index]:
                    crop_center_tl[index] = (crop_size[index] - input_image_size[index]) // 2
                else:
                    input_center_tl[index] = (input_image_size[index] - crop_size[index]) // 2

            output_center = output_pil[
                :,
                crop_center_tl[0]:crop_center_tl[0] + center_size[0],
                crop_center_tl[1]:crop_center_tl[1] + center_size[1]
            ]

            img_center = img[
                :,
                input_center_tl[0]:input_center_tl[0] + center_size[0],
                input_center_tl[1]:input_center_tl[1] + center_size[1]
            ]

            self.assertEqual((output_center - img_center).sum(), 0,
                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))

    def test_five_crop(self):
        to_pil_image = transforms.ToPILImage()
        h = random.randint(5, 25)
15 changes: 14 additions & 1 deletion torchvision/transforms/functional.py
@@ -451,7 +451,8 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the image size is smaller than the output size along any edge, the image is padded with 0 and then center cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
@@ -469,6 +470,18 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    image_width, image_height = _get_image_size(img)
    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
        image_width, image_height = _get_image_size(img)
        if crop_width == image_width and crop_height == image_height:
            return img

    crop_top = int(round((image_height - crop_height) / 2.))
    crop_left = int(round((image_width - crop_width) / 2.))
    return crop(img, crop_top, crop_left, crop_height, crop_width)
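The padding split is the subtle part of this hunk: an odd size difference cannot be divided evenly, so `(diff) // 2` pixels go on the left/top and `(diff + 1) // 2` on the right/bottom. A minimal standalone sketch of that arithmetic (`center_pad_ltrb` is a hypothetical helper for illustration, not part of the PR):

```python
def center_pad_ltrb(image_w, image_h, crop_w, crop_h):
    """[left, top, right, bottom] padding that keeps a larger crop centered."""
    pad_w = max(crop_w - image_w, 0)
    pad_h = max(crop_h - image_h, 0)
    # An odd difference puts the extra pixel on the right/bottom edge,
    # mirroring the (diff) // 2 vs (diff + 1) // 2 split in the patch.
    return [pad_w // 2, pad_h // 2, (pad_w + 1) // 2, (pad_h + 1) // 2]


assert center_pad_ltrb(10, 10, 13, 10) == [1, 0, 2, 0]  # odd width diff: 1 left, 2 right
assert center_pad_ltrb(10, 10, 12, 14) == [1, 2, 1, 2]  # even diffs split evenly
```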
3 changes: 2 additions & 1 deletion torchvision/transforms/transforms.py
@@ -290,7 +290,8 @@ def __init__(self, *args, **kwargs):
class CenterCrop(torch.nn.Module):
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the image size is smaller than the output size along any edge, the image is padded with 0 and then center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
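For reviewers who want to see the new behavior end to end, a minimal usage sketch (sizes chosen arbitrarily; assumes a torchvision build that includes this patch):

```python
import torch
import torchvision.transforms as transforms

img = torch.ones(3, 20, 20)           # image smaller than the requested crop
out = transforms.CenterCrop(24)(img)  # pads with 0, then center crops

print(out.shape)       # torch.Size([3, 24, 24])
print(out[0, 0, 0])    # tensor(0.) -- zero padding at the border
print(out[0, 12, 12])  # tensor(1.) -- original content stays centered
```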