Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change antialias default from None to True #7949

Merged
merged 7 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
preprocess = weights.transforms(antialias=(device != "mps")) # antialias not supported on MPS

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
Expand Down
18 changes: 0 additions & 18 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import itertools
import math
import os
import warnings
from functools import partial
from typing import Sequence

Expand Down Expand Up @@ -569,23 +568,6 @@ def test_resize_antialias(device, dt, size, interpolation):
assert_equal(resized_tensor, resize_result)


def test_resize_antialias_default_warning():
    """Check when ``F.resize`` / ``F.resized_crop`` warn about the ``antialias`` default.

    Calls that leave ``interpolation`` and ``antialias`` unset are expected to
    emit a ``UserWarning``; calls using ``NEAREST`` must not emit any warning.
    """
    tensor_img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)
    expected_msg = "The default value of the antialias"

    with pytest.warns(UserWarning, match=expected_msg):
        F.resize(tensor_img, size=(20, 20))
    with pytest.warns(UserWarning, match=expected_msg):
        F.resized_crop(tensor_img, 0, 0, 10, 10, size=(20, 20))

    # For modes that aren't bicubic or bilinear, don't throw a warning.
    # Promoting warnings to errors turns any stray warning into a test failure.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        F.resize(tensor_img, size=(20, 20), interpolation=NEAREST)
        F.resized_crop(tensor_img, 0, 0, 10, 10, size=(20, 20), interpolation=NEAREST)


def check_functional_vs_PIL_vs_scripted(
fn, fn_pil, fn_t, config, device, dtype, channels=3, tol=2.0 + 1e-10, agg_method="max"
):
Expand Down
20 changes: 0 additions & 20 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,25 +1057,5 @@ def test_raft(model_fn, scripted):
_assert_expected(flow_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1)


def test_presets_antialias():
    """Check the ``antialias`` default warning emitted by the weights' preset transforms.

    The ResNet18 and DeepLabV3 presets are expected to warn when ``antialias``
    is not passed and to stay silent when ``antialias=True`` is given. The
    Faster R-CNN, R3D and RAFT presets are called with their defaults.
    """
    batch = torch.randint(0, 256, size=(1, 3, 224, 224), dtype=torch.uint8)
    expected_msg = "The default value of the antialias parameter"

    with pytest.warns(UserWarning, match=expected_msg):
        models.ResNet18_Weights.DEFAULT.transforms()(batch)
    with pytest.warns(UserWarning, match=expected_msg):
        models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT.transforms()(batch)

    # Warnings are promoted to errors, so any warning below fails the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        models.ResNet18_Weights.DEFAULT.transforms(antialias=True)(batch)
        models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT.transforms(antialias=True)(batch)

        # NOTE(review): indentation was lost in this view; these three calls are
        # assumed to belong to the catch_warnings block above - confirm.
        models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()(batch)
        models.video.R3D_18_Weights.DEFAULT.transforms()(batch)
        models.optical_flow.Raft_Small_Weights.DEFAULT.transforms()(batch, batch)


# Run this module's tests through pytest when the file is executed as a script.
if __name__ == "__main__":
    pytest.main([__file__])
11 changes: 0 additions & 11 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import random
import re
import textwrap
import warnings
from functools import partial

import numpy as np
Expand Down Expand Up @@ -440,16 +439,6 @@ def test_resize_antialias_error():
t(img)


def test_resize_antialias_default_warning():
    """Ensure ``Resize`` and ``RandomResizedCrop`` emit no warning for PIL images."""
    pil_img = Image.new("RGB", size=(10, 10), color=127)

    # We make sure we don't warn for PIL images since the default behaviour doesn't change
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        for transform_cls in (transforms.Resize, transforms.RandomResizedCrop):
            transform_cls((20, 20))(pil_img)


@pytest.mark.parametrize("height, width", ((32, 64), (64, 32)))
def test_resize_size_equals_small_edge_size(height, width):
# Non-regression test for https://github.com/pytorch/vision/issues/5405
Expand Down
17 changes: 0 additions & 17 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import sys
import warnings

import numpy as np
import PIL.Image
Expand Down Expand Up @@ -428,22 +427,6 @@ def test_resized_crop_save_load(self, tmpdir):
fn = T.RandomResizedCrop(size=[32], antialias=True)
_test_fn_save_load(fn, tmpdir)

def test_antialias_default_warning(self):
    """Check when ``T.Resize`` / ``T.RandomResizedCrop`` warn about the ``antialias`` default.

    With default arguments both transforms are expected to emit a
    ``UserWarning`` on a tensor input; with ``NEAREST`` interpolation they
    must not emit any warning.
    """
    tensor_img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)
    expected_msg = "The default value of the antialias"
    resizing_transforms = (T.Resize, T.RandomResizedCrop)

    for transform_cls in resizing_transforms:
        with pytest.warns(UserWarning, match=expected_msg):
            transform_cls((20, 20))(tensor_img)

    # For modes that aren't bicubic or bilinear, don't throw a warning
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        for transform_cls in resizing_transforms:
            transform_cls((20, 20), interpolation=NEAREST)(tensor_img)


def _test_random_affine_helper(device, **kwargs):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
Expand Down
41 changes: 0 additions & 41 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pathlib
import pickle
import random
import warnings

import numpy as np

Expand Down Expand Up @@ -726,46 +725,6 @@ def test__transform(self, inpt):
assert output.dtype == inpt.dtype


# TODO: remove this test in 0.17 when the default of antialias changes to True
def test_antialias_warning():
    """Check the ``antialias`` default warning across the v2 resizing transforms.

    Tensor images and videos are expected to trigger a ``UserWarning`` when
    ``antialias`` is left unset. PIL images, and tensor inputs with an explicit
    ``antialias=True``, must not trigger any warning.
    """
    pil_img = PIL.Image.new("RGB", size=(10, 10), color=127)
    tensor_img = torch.randint(0, 256, size=(3, 10, 10), dtype=torch.uint8)
    tensor_video = torch.randint(0, 256, size=(2, 3, 10, 10), dtype=torch.uint8)

    expected_msg = "The default value of the antialias parameter"

    # One factory per transform so the same constructor calls can be replayed
    # with or without an explicit ``antialias`` keyword.
    transform_factories = (
        lambda **kwargs: transforms.RandomResizedCrop((20, 20), **kwargs),
        lambda **kwargs: transforms.ScaleJitter((20, 20), **kwargs),
        lambda **kwargs: transforms.RandomShortestSize((20, 20), **kwargs),
        lambda **kwargs: transforms.RandomResize(10, 20, **kwargs),
    )

    for make_transform in transform_factories:
        with pytest.warns(UserWarning, match=expected_msg):
            make_transform()(tensor_img)

    with pytest.warns(UserWarning, match=expected_msg):
        F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20))

    with pytest.warns(UserWarning, match=expected_msg):
        F.resize(tv_tensors.Video(tensor_video), (20, 20))
    with pytest.warns(UserWarning, match=expected_msg):
        F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20))

    # Warnings are promoted to errors, so any warning below fails the test.
    # NOTE(review): indentation was lost in this view; every call from here on
    # is assumed to belong to this catch_warnings block - confirm.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        for make_transform in transform_factories:
            make_transform()(pil_img)

        for make_transform in transform_factories:
            make_transform(antialias=True)(tensor_img)

        F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
        F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
Expand Down
24 changes: 0 additions & 24 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import torchvision.transforms.v2 as transforms
from common_utils import (
assert_equal,
assert_no_warnings,
cache,
cpu_and_cuda,
freeze_rng_state,
Expand Down Expand Up @@ -350,12 +349,6 @@ def adapt_fill(value, *, dtype):
]


@contextlib.contextmanager
def assert_warns_antialias_default_value():
    """Context manager asserting that the wrapped code emits the antialias-default ``UserWarning``."""
    expected_msg = "The default value of the antialias parameter of all the resizing transforms"
    with pytest.warns(UserWarning, match=expected_msg):
        yield


def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
format = bounding_boxes.format
canvas_size = new_canvas_size or bounding_boxes.canvas_size
Expand Down Expand Up @@ -684,23 +677,6 @@ def test_max_size_error(self, size, make_input):
with pytest.raises(ValueError, match=match):
F.resize(make_input(self.INPUT_SIZE), size=size, max_size=max_size, antialias=True)

@pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
@pytest.mark.parametrize(
    "make_input",
    [make_image_tensor, make_image, make_video],
)
def test_antialias_warning(self, interpolation, make_input):
    """Check that ``F.resize`` warns about the ``antialias`` default only for bilinear/bicubic.

    For every other interpolation mode the call must not emit any warning.
    """
    modes_expected_to_warn = {transforms.InterpolationMode.BILINEAR, transforms.InterpolationMode.BICUBIC}

    if interpolation in modes_expected_to_warn:
        expectation = assert_warns_antialias_default_value()
    else:
        expectation = assert_no_warnings()

    with expectation:
        F.resize(
            make_input(self.INPUT_SIZE),
            size=self.OUTPUT_SIZES[0],
            interpolation=interpolation,
        )

@pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
@pytest.mark.parametrize(
"make_input",
Expand Down
4 changes: 1 addition & 3 deletions torchvision/transforms/_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,7 @@ def resize(
img: Tensor,
size: List[int],
interpolation: str = "bilinear",
# TODO: in v0.17, change the default to True. This will be a private function
# by then, so we don't care about warning here.
antialias: Optional[bool] = None,
antialias: Optional[bool] = True,
) -> Tensor:
_assert_image_tensor(img)

Expand Down
6 changes: 3 additions & 3 deletions torchvision/transforms/_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
This file is part of the private API. Please do not use directly these classes as they will be modified on
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
"""
from typing import Optional, Tuple, Union
from typing import Optional, Tuple

import torch
from torch import nn, Tensor
Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(
mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
std: Tuple[float, ...] = (0.229, 0.224, 0.225),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
antialias: Optional[bool] = True,
) -> None:
super().__init__()
self.crop_size = [crop_size]
Expand Down Expand Up @@ -151,7 +151,7 @@ def __init__(
mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
std: Tuple[float, ...] = (0.229, 0.224, 0.225),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
antialias: Optional[bool] = True,
) -> None:
super().__init__()
self.resize_size = [resize_size] if resize_size is not None else None
Expand Down
50 changes: 8 additions & 42 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,19 +393,12 @@
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
antialias: Optional[bool] = True,

Check warning on line 396 in torchvision/transforms/functional.py

View workflow job for this annotation

GitHub Actions / bc

Function resize: antialias changed from Optional[Union[str, bool]] to Optional[bool]
) -> Tensor:
r"""Resize the input image to the given size.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

.. warning::
The output image might be different depending on its type: when downsampling, the interpolation of PIL images
and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
closer.

Args:
img (PIL Image or Tensor): Image to be resized.
size (sequence or int): Desired output size. If size is a sequence like
Expand Down Expand Up @@ -437,7 +430,7 @@
tensors), antialiasing makes no sense and this parameter is ignored.
Possible values are:

- ``True``: will apply antialiasing for bilinear or bicubic modes.
- ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
Other modes aren't affected. This is probably what you want to use.
- ``False``: will not apply antialiasing for tensors on any mode. PIL
images are still antialiased on bilinear or bicubic modes, because
Expand All @@ -446,8 +439,8 @@
PIL images. This value exists for legacy reasons and you probably
don't want to use it unless you really know what you are doing.

The current default is ``None`` **but will change to** ``True`` **in
v0.17** for the PIL and Tensor backends to be consistent.
The default value changed from ``None`` to ``True`` in
v0.17, for the PIL and Tensor backends to be consistent.

Returns:
PIL Image or Tensor: Resized image.
Expand Down Expand Up @@ -481,8 +474,6 @@
if [image_height, image_width] == output_size:
return img

antialias = _check_antialias(img, antialias, interpolation)

if not isinstance(img, torch.Tensor):
if antialias is False:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
Expand Down Expand Up @@ -615,7 +606,7 @@
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
antialias: Optional[bool] = True,

Check warning on line 609 in torchvision/transforms/functional.py

View workflow job for this annotation

GitHub Actions / bc

Function resized_crop: antialias changed from Optional[Union[str, bool]] to Optional[bool]
) -> Tensor:
"""Crop the given image and resize it to desired size.
If the image is torch Tensor, it is expected
Expand Down Expand Up @@ -643,7 +634,7 @@
tensors), antialiasing makes no sense and this parameter is ignored.
Possible values are:

- ``True``: will apply antialiasing for bilinear or bicubic modes.
- ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
Other modes aren't affected. This is probably what you want to use.
- ``False``: will not apply antialiasing for tensors on any mode. PIL
images are still antialiased on bilinear or bicubic modes, because
Expand All @@ -652,8 +643,8 @@
PIL images. This value exists for legacy reasons and you probably
don't want to use it unless you really know what you are doing.

The current default is ``None`` **but will change to** ``True`` **in
v0.17** for the PIL and Tensor backends to be consistent.
The default value changed from ``None`` to ``True`` in
v0.17, for the PIL and Tensor backends to be consistent.
Returns:
PIL Image or Tensor: Cropped image.
"""
Expand Down Expand Up @@ -1590,28 +1581,3 @@
if not isinstance(img, torch.Tensor):
output = to_pil_image(output, mode=img.mode)
return output


# TODO in v0.17: remove this helper and change default of antialias to True everywhere
def _check_antialias(
    img: Tensor, antialias: Optional[Union[str, bool]], interpolation: InterpolationMode
) -> Optional[bool]:
    """Resolve a string ``antialias`` placeholder into ``None``, warning when relevant.

    Args:
        img: Input image; the warning is only emitted when it is a ``Tensor``.
        antialias: Value received by the caller. Non-string values are
            returned unchanged.
        interpolation: Interpolation mode; the warning is only emitted for
            ``BILINEAR`` and ``BICUBIC``.

    Returns:
        ``antialias`` unchanged when it is not a string, otherwise ``None``.
    """
    if not isinstance(antialias, str):
        return antialias

    # it should be "warn", but we don't bother checking against that
    is_tensor_input = isinstance(img, Tensor)
    if is_tensor_input and (
        interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC
    ):
        message = (
            "The default value of the antialias parameter of all the resizing transforms "
            "(Resize(), RandomResizedCrop(), etc.) "
            "will change from None to True in v0.17, "
            "in order to be consistent across the PIL and Tensor backends. "
            "To suppress this warning, directly pass "
            "antialias=True (recommended, future default), antialias=None (current default, "
            "which means False for Tensors and True for PIL), "
            "or antialias=False (only works on Tensors - PIL will still use antialiasing). "
            "This also applies if you are using the inference transforms from the models weights: "
            "update the call to weights.transforms(antialias=True)."
        )
        warnings.warn(message)

    # The string placeholder always resolves to None, whether or not we warned.
    return None