-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make RandomHorizontalFlip torchscriptable (#2282)
* Make RandomHorizontalFlip torchscriptable * Make _is_tensor_a_torch_image more generic * Make RandomVerticalFlip torchscriptable (#2283) * Make RandomVerticalFlip torchscriptable * Fix lint
- Loading branch information
Showing
5 changed files
with
132 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
from torchvision import transforms as T | ||
from torchvision.transforms import functional as F | ||
from PIL import Image | ||
|
||
import numpy as np | ||
|
||
import unittest | ||
|
||
|
||
class Tester(unittest.TestCase): | ||
def _create_data(self, height=3, width=3, channels=3): | ||
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8) | ||
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy()) | ||
return tensor, pil_img | ||
|
||
def compareTensorToPIL(self, tensor, pil_image): | ||
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) | ||
self.assertTrue(tensor.equal(pil_tensor)) | ||
|
||
def _test_flip(self, func, method): | ||
tensor, pil_img = self._create_data() | ||
flip_tensor = getattr(F, func)(tensor) | ||
flip_pil_img = getattr(F, func)(pil_img) | ||
self.compareTensorToPIL(flip_tensor, flip_pil_img) | ||
|
||
scripted_fn = torch.jit.script(getattr(F, func)) | ||
flip_tensor_script = scripted_fn(tensor) | ||
self.assertTrue(flip_tensor.equal(flip_tensor_script)) | ||
|
||
# test for class interface | ||
f = getattr(T, method)() | ||
scripted_fn = torch.jit.script(f) | ||
scripted_fn(tensor) | ||
|
||
def test_random_horizontal_flip(self): | ||
self._test_flip('hflip', 'RandomHorizontalFlip') | ||
|
||
def test_random_vertical_flip(self): | ||
self._test_flip('vflip', 'RandomVerticalFlip') | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import torch | ||
try: | ||
import accimage | ||
except ImportError: | ||
accimage = None | ||
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION | ||
|
||
|
||
@torch.jit.unused | ||
def _is_pil_image(img): | ||
if accimage is not None: | ||
return isinstance(img, (Image.Image, accimage.Image)) | ||
else: | ||
return isinstance(img, Image.Image) | ||
|
||
|
||
@torch.jit.unused | ||
def hflip(img): | ||
"""Horizontally flip the given PIL Image. | ||
Args: | ||
img (PIL Image): Image to be flipped. | ||
Returns: | ||
PIL Image: Horizontally flipped image. | ||
""" | ||
if not _is_pil_image(img): | ||
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) | ||
|
||
return img.transpose(Image.FLIP_LEFT_RIGHT) | ||
|
||
|
||
@torch.jit.unused | ||
def vflip(img): | ||
"""Vertically flip the given PIL Image. | ||
Args: | ||
img (PIL Image): Image to be flipped. | ||
Returns: | ||
PIL Image: Vertically flipped image. | ||
""" | ||
if not _is_pil_image(img): | ||
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) | ||
|
||
return img.transpose(Image.FLIP_TOP_BOTTOM) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters