Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
TilmannR committed Oct 3, 2023
2 parents 43b452c + 48f8473 commit 5d87844
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 173 deletions.
55 changes: 0 additions & 55 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,61 +122,6 @@ def test_check_transformed_types(self, inpt_type, mocker):
t(inpt)


class TestToImage:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch(
"torchvision.transforms.v2.functional.to_image",
return_value=torch.rand(1, 3, 8, 8),
)

inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImage()
transform(inpt)
if inpt_type in (tv_tensors.BoundingBoxes, tv_tensors.Image, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt)


class TestToPILImage:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image")

inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToPILImage()
transform(inpt)
if inpt_type in (PIL.Image.Image, tv_tensors.BoundingBoxes, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt, mode=transform.mode)


class TestToTensor:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.functional.to_tensor")

inpt = mocker.MagicMock(spec=inpt_type)
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToTensor()
transform(inpt)
if inpt_type in (tv_tensors.Image, torch.Tensor, tv_tensors.BoundingBoxes, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt)


class TestContainers:
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
def test_assertions(self, transform_cls):
Expand Down
23 changes: 0 additions & 23 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,6 @@ def __init__(
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

CONSISTENCY_CONFIGS = [
ConsistencyConfig(
v2_transforms.ToPILImage,
legacy_transforms.ToPILImage,
[NotScriptableArgsKwargs()],
make_images_kwargs=dict(
color_spaces=[
"GRAY",
"GRAY_ALPHA",
"RGB",
"RGBA",
],
extra_dims=[()],
),
supports_pil=False,
),
ConsistencyConfig(
v2_transforms.Lambda,
legacy_transforms.Lambda,
Expand All @@ -97,14 +82,6 @@ def __init__(
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
ConsistencyConfig(
v2_transforms.PILToTensor,
legacy_transforms.PILToTensor,
),
ConsistencyConfig(
v2_transforms.ToTensor,
legacy_transforms.ToTensor,
),
ConsistencyConfig(
v2_transforms.Compose,
legacy_transforms.Compose,
Expand Down
82 changes: 81 additions & 1 deletion test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -2491,7 +2491,7 @@ def _make_displacement(self, inpt):
interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
fill=EXHAUSTIVE_TYPE_FILLS,
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8, torch.float16])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image(self, param, value, dtype, device):
image = make_image_tensor(dtype=dtype, device=device)
Expand All @@ -2502,6 +2502,7 @@ def test_kernel_image(self, param, value, dtype, device):
displacement=self._make_displacement(image),
**{param: value},
check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
check_cuda_vs_cpu=dtype is not torch.float16,
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
Expand Down Expand Up @@ -5046,3 +5047,82 @@ def test_transform_error_cuda(self):
ValueError, match="Input tensor should be on the same device as transformation matrix and mean vector"
):
transform(input)


def make_image_numpy(*args, **kwargs):
image = make_image_tensor(*args, **kwargs)
return image.permute((1, 2, 0)).numpy()


class TestToImage:
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_image_numpy])
@pytest.mark.parametrize("fn", [F.to_image, transform_cls_to_functional(transforms.ToImage)])
def test_functional_and_transform(self, make_input, fn):
input = make_input()
output = fn(input)

assert isinstance(output, tv_tensors.Image)

input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input)
assert F.get_size(output) == input_size

if isinstance(input, torch.Tensor):
assert output.data_ptr() == input.data_ptr()

def test_functional_error(self):
with pytest.raises(TypeError, match="Input can either be a pure Tensor, a numpy array, or a PIL image"):
F.to_image(object())


class TestToPILImage:
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_numpy])
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize("fn", [F.to_pil_image, transform_cls_to_functional(transforms.ToPILImage)])
def test_functional_and_transform(self, make_input, color_space, fn):
input = make_input(color_space=color_space)
output = fn(input)

assert isinstance(output, PIL.Image.Image)

input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input)
assert F.get_size(output) == input_size

def test_functional_error(self):
with pytest.raises(TypeError, match="pic should be Tensor or ndarray"):
F.to_pil_image(object())

for ndim in [1, 4]:
with pytest.raises(ValueError, match="pic should be 2/3 dimensional"):
F.to_pil_image(torch.empty(*[1] * ndim))

with pytest.raises(ValueError, match="pic should not have > 4 channels"):
num_channels = 5
F.to_pil_image(torch.empty(num_channels, 1, 1))


class TestToTensor:
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_image_numpy])
def test_smoke(self, make_input):
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToTensor()

input = make_input()
output = transform(input)

input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input)
assert F.get_size(output) == input_size


class TestPILToTensor:
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize("fn", [F.pil_to_tensor, transform_cls_to_functional(transforms.PILToTensor)])
def test_functional_and_transform(self, color_space, fn):
input = make_image_pil(color_space=color_space)
output = fn(input)

assert isinstance(output, torch.Tensor) and not isinstance(output, tv_tensors.TVTensor)
assert F.get_size(output) == F.get_size(input)

def test_functional_error(self):
with pytest.raises(TypeError, match="pic should be PIL Image"):
F.pil_to_tensor(object())

0 comments on commit 5d87844

Please sign in to comment.