Skip to content

Commit

Permalink
Add JPEG augmentation (#8316)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
  • Loading branch information
3 people committed Mar 18, 2024
1 parent 2ba586d commit 924b162
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ Miscellaneous
v2.SanitizeBoundingBoxes
v2.ClampBoundingBoxes
v2.UniformTemporalSubsample
v2.JPEG

Functionals

Expand All @@ -419,6 +420,7 @@ Functionals
v2.functional.sanitize_bounding_boxes
v2.functional.clamp_bounding_boxes
v2.functional.uniform_temporal_subsample
v2.functional.jpeg

.. _conversion_transforms:

Expand Down
11 changes: 11 additions & 0 deletions gallery/transforms/plot_transforms_illustrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,17 @@
equalized_imgs = [equalizer(orig_img) for _ in range(4)]
plot([orig_img] + equalized_imgs)

# %%
# JPEG
# ~~~~
# The :class:`~torchvision.transforms.v2.JPEG` transform
# (see also :func:`~torchvision.transforms.v2.functional.jpeg`)
# applies JPEG compression to the given image with a random
# degree of compression.
jpeg = v2.JPEG((5, 50))
jpeg_imgs = [jpeg(orig_img) for _ in range(4)]
plot([orig_img] + jpeg_imgs)

# %%
# Augmentation Transforms
# -----------------------
Expand Down
83 changes: 83 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5932,3 +5932,86 @@ def test_errors_functional(self):

with pytest.raises(ValueError, match="bouding_boxes must be a tv_tensors.BoundingBoxes instance or a"):
F.sanitize_bounding_boxes(good_bbox.tolist())


class TestJPEG:
    """Tests for the JPEG (de)compression augmentation: the low-level kernels
    (``F.jpeg_image`` / ``F.jpeg_video``), the ``F.jpeg`` functional dispatcher,
    and the ``transforms.JPEG`` transform class."""

    @pytest.mark.parametrize("quality", [5, 75])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    def test_kernel_image(self, quality, color_space):
        # Smoke-test the image kernel for both 3-channel and 1-channel inputs.
        check_kernel(F.jpeg_image, make_image(color_space=color_space), quality=quality)

    def test_kernel_video(self):
        check_kernel(F.jpeg_video, make_video(), quality=5)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_functional(self, make_input):
        # The dispatcher must accept plain tensors, PIL images, tv_tensors, and videos.
        check_functional(F.jpeg, make_input(), quality=5)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.jpeg_image, torch.Tensor),
            (F._jpeg_image_pil, PIL.Image.Image),
            (F.jpeg_image, tv_tensors.Image),
            (F.jpeg_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        # Each registered kernel must expose the same signature as the dispatcher.
        check_functional_kernel_signature_match(F.jpeg, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    def test_transform(self, make_input, quality, color_space):
        # Cover both the fixed-quality and (min, max) range forms of the transform.
        check_transform(transforms.JPEG(quality=quality), make_input(color_space=color_space))

    @pytest.mark.parametrize("quality", [5])
    def test_functional_image_correctness(self, quality):
        # Compare the tensor path against a PIL round-trip of the same image.
        image = make_image()

        actual = F.jpeg(image, quality=quality)
        expected = F.to_image(F.jpeg(F.to_pil_image(image), quality=quality))

        # NOTE: this will fail if torchvision and Pillow use different JPEG encoder/decoder
        torch.testing.assert_close(actual, expected, rtol=0, atol=1)

    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_image_correctness(self, quality, color_space, seed):
        # With the RNG pinned, the transform must sample the same quality for the
        # tensor and the PIL input, so the two outputs should match.
        image = make_image(color_space=color_space)

        transform = transforms.JPEG(quality=quality)

        with freeze_rng_state():
            torch.manual_seed(seed)
            actual = transform(image)

            torch.manual_seed(seed)
            expected = F.to_image(transform(F.to_pil_image(image)))

        torch.testing.assert_close(actual, expected, rtol=0, atol=1)

    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("seed", list(range(10)))
    def test_transform_get_params_bounds(self, quality, seed):
        # Sampled quality must equal the fixed value, or fall inside the
        # (inclusive) requested range.
        transform = transforms.JPEG(quality=quality)

        with freeze_rng_state():
            torch.manual_seed(seed)
            params = transform._get_params([])

        if isinstance(quality, int):
            assert params["quality"] == quality
        else:
            assert quality[0] <= params["quality"] <= quality[1]

    @pytest.mark.parametrize("quality", [[0], [0, 0, 0]])
    def test_transform_sequence_len_error(self, quality):
        # A sequence quality must have exactly two entries: (min, max).
        with pytest.raises(ValueError, match="quality should be a sequence of length 2"):
            transforms.JPEG(quality=quality)

    @pytest.mark.parametrize("quality", [-1, 0, 150])
    def test_transform_invalid_quality_error(self, quality):
        # Quality values outside [1, 100] are rejected.
        with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"):
            transforms.JPEG(quality=quality)
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ._transform import Transform # usort: skip

from ._augment import CutMix, MixUp, RandomErasing
from ._augment import CutMix, JPEG, MixUp, RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Expand Down
40 changes: 38 additions & 2 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import numbers
import warnings
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

import PIL.Image
import torch
Expand All @@ -11,7 +11,7 @@
from torchvision.transforms.v2 import functional as F

from ._transform import _RandomApplyTransform, Transform
from ._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size


class RandomErasing(_RandomApplyTransform):
Expand Down Expand Up @@ -317,3 +317,39 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return output
else:
return inpt


class JPEG(Transform):
    """Apply JPEG compression and decompression to the given images.

    If the input is a :class:`torch.Tensor`, it is expected
    to be of dtype uint8, on CPU, and have [..., 3 or 1, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        quality (sequence or number): JPEG quality, from 1 to 100. Lower means more compression.
            If quality is a sequence like (min, max), it specifies the range of JPEG quality to
            randomly select from (inclusive of both ends).

    Returns:
        image with JPEG compression.
    """

    def __init__(self, quality: Union[int, Sequence[int]]):
        super().__init__()
        if isinstance(quality, int):
            # Normalize a scalar to a degenerate [q, q] range so _get_params
            # can treat both forms uniformly.
            quality = [quality, quality]
        else:
            _check_sequence_input(quality, "quality", req_sizes=(2,))

        # Check types before comparing: a non-numeric entry would otherwise
        # raise a TypeError from the comparison instead of the documented
        # ValueError.
        if not (
            isinstance(quality[0], int)
            and isinstance(quality[1], int)
            and 1 <= quality[0] <= quality[1] <= 100
        ):
            raise ValueError(f"quality must be an integer from 1 to 100, got {quality =}")

        self.quality = quality

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # torch.randint's upper bound is exclusive, so add 1 to make the
        # sampled range inclusive of quality[1].
        quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item()
        return dict(quality=quality)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Dispatch to the functional; the sampled quality is shared across all
        # inputs of one call.
        return self._call_kernel(F.jpeg, inpt, quality=params["quality"])
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
get_size,
) # usort: skip

from ._augment import _erase_image_pil, erase, erase_image, erase_video
from ._augment import _erase_image_pil, _jpeg_image_pil, erase, erase_image, erase_video, jpeg, jpeg_image, jpeg_video
from ._color import (
_adjust_brightness_image_pil,
_adjust_contrast_image_pil,
Expand Down
43 changes: 43 additions & 0 deletions torchvision/transforms/v2/functional/_augment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import io

import PIL.Image

import torch
from torchvision import tv_tensors
from torchvision.io import decode_jpeg, encode_jpeg
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once

Expand Down Expand Up @@ -53,3 +56,43 @@ def erase_video(
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)


def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.JPEG` for details."""
    # TorchScript cannot go through the dynamic kernel registry; call the
    # tensor kernel directly when scripting.
    if torch.jit.is_scripting():
        return jpeg_image(image, quality=quality)

    _log_api_usage_once(jpeg)

    # Look up and invoke the kernel registered for this input type.
    return _get_kernel(jpeg, type(image))(image, quality=quality)


@_register_kernel_internal(jpeg, torch.Tensor)
@_register_kernel_internal(jpeg, tv_tensors.Image)
def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor:
    # Round-trip each (C, H, W) image through JPEG encode/decode at the given
    # quality, flattening any leading batch dimensions first.
    original_shape = image.shape
    # reshape (not view) so non-contiguous inputs (e.g. slices or permuted
    # tensors) are handled too; view would raise on them.
    batch = image.reshape((-1,) + image.shape[-3:])

    if batch.shape[0] == 0:  # degenerate: nothing to encode
        return batch.reshape(original_shape).clone()

    # encode_jpeg/decode_jpeg operate on single (C, H, W) uint8 images, so
    # process the flattened batch one image at a time.
    compressed = [decode_jpeg(encode_jpeg(img, quality=quality)) for img in batch]
    return torch.stack(compressed, dim=0).reshape(original_shape)


@_register_kernel_internal(jpeg, tv_tensors.Video)
def jpeg_video(video: torch.Tensor, quality: int) -> torch.Tensor:
    # jpeg_image flattens all leading dimensions, so a (..., T, C, H, W) video
    # is compressed frame by frame with the same quality.
    return jpeg_image(video, quality=quality)


@_register_kernel_internal(jpeg, PIL.Image.Image)
def _jpeg_image_pil(image: PIL.Image.Image, quality: int) -> PIL.Image.Image:
    """Round-trip a PIL image through an in-memory JPEG at the given quality."""
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=quality)

    # Copy the reopened image: PIL.Image.open() returns a
    # PIL.JpegImagePlugin.JpegImageFile, a sub-class of PIL.Image.Image, which
    # would fail the check_transform() test; .copy() yields a plain Image.
    with PIL.Image.open(buffer) as reopened:
        return reopened.copy()

0 comments on commit 924b162

Please sign in to comment.