Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow len 1 sequences for fill with PIL #7928

Merged
merged 5 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 0 additions & 3 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -2581,9 +2581,6 @@ def test_transform(self, param, value, make_input):
# 2. the fill parameter only has an affect if we need padding
kwargs["size"] = [s + 4 for s in self.INPUT_SIZE]

if isinstance(input, PIL.Image.Image) and isinstance(value, (tuple, list)) and len(value) == 1:
pytest.xfail("F._pad_image_pil does not support sequences of length 1 for fill.")

if isinstance(input, tv_tensors.Mask) and isinstance(value, (tuple, list)):
pytest.skip("F.pad_mask doesn't support non-scalar fill.")

Expand Down
37 changes: 0 additions & 37 deletions test/transforms_v2_dispatcher_infos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import collections.abc

import pytest
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors
Expand Down Expand Up @@ -112,32 +110,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
multi_crop_skips.append(skip_dispatch_tv_tensor)


def xfails_pil(reason, *, condition=None):
return [
TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
for test_name in ["test_dispatch_pil", "test_pil_output_type"]
]


def fill_sequence_needs_broadcast(args_kwargs):
(image_loader, *_), kwargs = args_kwargs
try:
fill = kwargs["fill"]
except KeyError:
return False

if not isinstance(fill, collections.abc.Sequence) or len(fill) > 1:
return False

return image_loader.num_channels > 1


xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
"PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
condition=fill_sequence_needs_broadcast,
)


DISPATCHER_INFOS = [
DispatcherInfo(
F.resized_crop,
Expand All @@ -159,14 +131,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
test_marks=[
*xfails_pil(
reason=(
"PIL kernel doesn't support sequences of length 1 for argument `fill` and "
"`padding_mode='constant'`, if the number of color channels is larger."
),
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
),
xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
xfail_jit_python_scalar_arg("padding"),
],
Expand All @@ -181,7 +145,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
test_marks=[
*xfails_pil_if_fill_sequence_needs_broadcast,
xfail_jit_python_scalar_arg("fill"),
],
),
Expand Down
6 changes: 4 additions & 2 deletions torchvision/transforms/_functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,13 @@ def _parse_fill(
if isinstance(fill, (int, float)) and num_channels > 1:
fill = tuple([fill] * num_channels)
if isinstance(fill, (list, tuple)):
if len(fill) != num_channels:
if len(fill) == 1:
fill = fill * num_channels
elif len(fill) != num_channels:
msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
raise ValueError(msg.format(len(fill), num_channels))

fill = tuple(fill)
fill = tuple(fill) # type: ignore[arg-type]

if img.mode != "F":
if isinstance(fill, (list, tuple)):
Expand Down