Skip to content

Commit

Permalink
Make RandomHorizontalFlip torchscriptable (#2282)
Browse files Browse the repository at this point in the history
* Make RandomHorizontalFlip torchscriptable

* Make _is_tensor_a_torch_image more generic

* Make RandomVerticalFlip torchscriptable (#2283)

* Make RandomVerticalFlip torchscriptable

* Fix lint
  • Loading branch information
fmassa committed Jun 4, 2020
1 parent de52437 commit 11a39aa
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 25 deletions.
44 changes: 44 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image

import numpy as np

import unittest


class Tester(unittest.TestCase):
    """Consistency tests for the flip transforms across PIL, Tensor and TorchScript."""

    def _create_data(self, height=3, width=3, channels=3):
        """Return a random uint8 CHW tensor and a PIL image built from the same pixels."""
        # torch.randint excludes the upper bound, so 256 is needed for 255 to occur.
        tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8)
        pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
        return tensor, pil_img

    def compareTensorToPIL(self, tensor, pil_image):
        """Assert that a CHW tensor holds exactly the pixels of an HWC PIL image."""
        pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
        self.assertTrue(tensor.equal(pil_tensor))

    def _test_flip(self, func, method):
        """Check functional ``func`` and transform class ``method`` on Tensor, PIL and scripted paths.

        Args:
            func (str): Name of the function in ``torchvision.transforms.functional``.
            method (str): Name of the transform class in ``torchvision.transforms``.
        """
        tensor, pil_img = self._create_data()
        flip_tensor = getattr(F, func)(tensor)
        flip_pil_img = getattr(F, func)(pil_img)
        self.compareTensorToPIL(flip_tensor, flip_pil_img)

        scripted_fn = torch.jit.script(getattr(F, func))
        flip_tensor_script = scripted_fn(tensor)
        self.assertTrue(flip_tensor.equal(flip_tensor_script))

        # test for class interface
        f = getattr(T, method)()
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        # The default p is random, so pin the outcome: torch.rand samples from
        # [0, 1), hence p=1.0 always flips and p=0.0 never does.
        for p, expected in ((1.0, flip_tensor), (0.0, tensor)):
            transform = getattr(T, method)(p=p)
            self.assertTrue(transform(tensor).equal(expected))
            self.assertTrue(torch.jit.script(transform)(tensor).equal(expected))

    def test_random_horizontal_flip(self):
        self._test_flip('hflip', 'RandomHorizontalFlip')

    def test_random_vertical_flip(self):
        self._test_flip('vflip', 'RandomVerticalFlip')


if __name__ == '__main__':
unittest.main()
34 changes: 22 additions & 12 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch import Tensor
import math
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
try:
Expand All @@ -11,6 +12,9 @@
from collections.abc import Sequence, Iterable
import warnings

from . import functional_pil as F_pil
from . import functional_tensor as F_t


def _is_pil_image(img):
if accimage is not None:
Expand Down Expand Up @@ -434,19 +438,22 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE
return img


def hflip(img):
"""Horizontally flip the given PIL Image.
def hflip(img: Tensor) -> Tensor:
"""Horizontally flip the given PIL Image or torch Tensor.
Args:
img (PIL Image): Image to be flipped.
img (PIL Image or Torch Tensor): Image to be flipped. If img
is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of leading
dimensions.
Returns:
PIL Image or Tensor: Horizontally flipped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not isinstance(img, torch.Tensor):
return F_pil.hflip(img)

return img.transpose(Image.FLIP_LEFT_RIGHT)
return F_t.hflip(img)


def _parse_fill(fill, img, min_pil_version):
Expand Down Expand Up @@ -536,19 +543,22 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N
return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts)


def vflip(img):
"""Vertically flip the given PIL Image.
def vflip(img: Tensor) -> Tensor:
"""Vertically flip the given PIL Image or torch Tensor.
Args:
img (PIL Image): Image to be flipped.
img (PIL Image or Torch Tensor): Image to be flipped. If img
is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of leading
dimensions.
Returns:
PIL Image or Tensor: Vertically flipped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not isinstance(img, torch.Tensor):
return F_pil.vflip(img)

return img.transpose(Image.FLIP_TOP_BOTTOM)
return F_t.vflip(img)


def five_crop(img, size):
Expand Down
46 changes: 46 additions & 0 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch
try:
import accimage
except ImportError:
accimage = None
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION


@torch.jit.unused
def _is_pil_image(img):
    """Return True if ``img`` is a PIL Image, or an accimage Image when accimage is available."""
    # accimage is an optional backend; it is None when the import failed.
    accepted = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, accepted)


@torch.jit.unused
def hflip(img):
    """Mirror the given PIL Image left-to-right.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Horizontally flipped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


@torch.jit.unused
def vflip(img):
    """Mirror the given PIL Image top-to-bottom.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Vertically flipped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_TOP_BOTTOM)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
3 changes: 1 addition & 2 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import torch
import torchvision.transforms.functional as F
from torch import Tensor
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


def _is_tensor_a_torch_image(input):
return len(input.shape) == 3
return input.ndim >= 2


def vflip(img):
Expand Down
30 changes: 19 additions & 11 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,51 +500,59 @@ def __repr__(self):
return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)


class RandomHorizontalFlip(object):
"""Horizontally flip the given PIL Image randomly with a given probability.
class RandomHorizontalFlip(torch.nn.Module):
"""Horizontally flip the given image randomly with a given probability.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions
Args:
p (float): probability of the image being flipped. Default value is 0.5
"""

def __init__(self, p=0.5):
super().__init__()
self.p = p

def __call__(self, img):
def forward(self, img):
"""
Args:
img (PIL Image): Image to be flipped.
img (PIL Image or Tensor): Image to be flipped.
Returns:
PIL Image: Randomly flipped image.
PIL Image or Tensor: Randomly flipped image.
"""
if random.random() < self.p:
if torch.rand(1) < self.p:
return F.hflip(img)
return img

def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)


class RandomVerticalFlip(object):
class RandomVerticalFlip(torch.nn.Module):
"""Vertically flip the given image randomly with a given probability.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions
Args:
p (float): probability of the image being flipped. Default value is 0.5
"""

def __init__(self, p=0.5):
super().__init__()
self.p = p

def __call__(self, img):
def forward(self, img):
"""
Args:
img (PIL Image): Image to be flipped.
img (PIL Image or Tensor): Image to be flipped.
Returns:
PIL Image: Randomly flipped image.
PIL Image or Tensor: Randomly flipped image.
"""
if random.random() < self.p:
if torch.rand(1) < self.p:
return F.vflip(img)
return img

Expand Down

0 comments on commit 11a39aa

Please sign in to comment.