Skip to content

Commit

Permalink
Random erase: bypass boxes and masks
Browse files — browse the repository at this point in the history
Revert to the if-return / elif-return / else-return style
  • Loading branch information
vfdev-5 committed Jul 8, 2022
1 parent 014b8c7 commit a16e61d
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 93 deletions.
29 changes: 11 additions & 18 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import numbers
import warnings
from typing import Any, Dict, Tuple
from typing import Any, Dict, Tuple, Union

import PIL.Image
import torch
Expand Down Expand Up @@ -94,11 +94,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
elif isinstance(inpt, PIL.Image.Image):
# TODO: We should implement a fallback to tensor, like gaussian_blur etc
raise RuntimeError("Not implemented")
elif isinstance(inpt, torch.Tensor):
return F.erase_image_tensor(inpt, **params)
raise TypeError(
"RandomErasing transformation does not support bounding boxes, segmentation masks and plain labels"
)
else:
return inpt


class _BaseMixupCutmix(Transform):
Expand All @@ -125,21 +122,19 @@ class RandomMixup(_BaseMixupCutmix):
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(())))

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
def _transform(
self, inpt: Union[features.Image, features.OneHotLabel], params: Dict[str, Any]
) -> Union[features.Image, features.OneHotLabel]:
lam = params["lam"]
if isinstance(inpt, features.Image):
if inpt.ndim < 4:
raise ValueError("Need a batch of images")
output = inpt.clone()
output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
return features.Image.new_like(inpt, output)
if isinstance(inpt, features.OneHotLabel):
else: # inpt is features.OneHotLabel
return self._mixup_onehotlabel(inpt, lam)

raise TypeError(
"RandomMixup transformation does not support bounding boxes, segmentation masks and plain labels"
)


class RandomCutmix(_BaseMixupCutmix):
def _get_params(self, sample: Any) -> Dict[str, Any]:
Expand All @@ -165,7 +160,9 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:

return dict(box=box, lam_adjusted=lam_adjusted)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
def _transform(
self, inpt: Union[features.Image, features.OneHotLabel], params: Dict[str, Any]
) -> Union[features.Image, features.OneHotLabel]:
if isinstance(inpt, features.Image):
box = params["box"]
if inpt.ndim < 4:
Expand All @@ -175,10 +172,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = inpt.clone()
output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
return features.Image.new_like(inpt, output)
if isinstance(inpt, features.OneHotLabel):
else: # inpt is features.OneHotLabel
lam_adjusted = params["lam_adjusted"]
return self._mixup_onehotlabel(inpt, lam_adjusted)

raise TypeError(
"RandomCutmix transformation does not support bounding boxes, segmentation masks and plain labels"
)
55 changes: 33 additions & 22 deletions torchvision/prototype/transforms/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
def adjust_brightness(inpt: DType, brightness_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
else:
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)


adjust_saturation_image_tensor = _FT.adjust_saturation
Expand All @@ -28,9 +29,10 @@ def adjust_brightness(inpt: DType, brightness_factor: float) -> DType:
def adjust_saturation(inpt: DType, saturation_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_saturation(saturation_factor=saturation_factor)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
else:
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)


adjust_contrast_image_tensor = _FT.adjust_contrast
Expand All @@ -40,9 +42,10 @@ def adjust_saturation(inpt: DType, saturation_factor: float) -> DType:
def adjust_contrast(inpt: DType, contrast_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_contrast(contrast_factor=contrast_factor)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
else:
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)


adjust_sharpness_image_tensor = _FT.adjust_sharpness
Expand All @@ -52,9 +55,10 @@ def adjust_contrast(inpt: DType, contrast_factor: float) -> DType:
def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
else:
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)


adjust_hue_image_tensor = _FT.adjust_hue
Expand All @@ -64,9 +68,10 @@ def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType:
def adjust_hue(inpt: DType, hue_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_hue(hue_factor=hue_factor)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
else:
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)


adjust_gamma_image_tensor = _FT.adjust_gamma
Expand All @@ -76,9 +81,10 @@ def adjust_hue(inpt: DType, hue_factor: float) -> DType:
def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_gamma(gamma=gamma, gain=gain)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
else:
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)


posterize_image_tensor = _FT.posterize
Expand All @@ -88,9 +94,10 @@ def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType:
def posterize(inpt: DType, bits: int) -> DType:
if isinstance(inpt, features._Feature):
return inpt.posterize(bits=bits)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return posterize_image_pil(inpt, bits=bits)
return posterize_image_tensor(inpt, bits=bits)
else:
return posterize_image_tensor(inpt, bits=bits)


solarize_image_tensor = _FT.solarize
Expand All @@ -100,9 +107,10 @@ def posterize(inpt: DType, bits: int) -> DType:
def solarize(inpt: DType, threshold: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.solarize(threshold=threshold)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return solarize_image_pil(inpt, threshold=threshold)
return solarize_image_tensor(inpt, threshold=threshold)
else:
return solarize_image_tensor(inpt, threshold=threshold)


autocontrast_image_tensor = _FT.autocontrast
Expand All @@ -112,9 +120,10 @@ def solarize(inpt: DType, threshold: float) -> DType:
def autocontrast(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.autocontrast()
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return autocontrast_image_pil(inpt)
return autocontrast_image_tensor(inpt)
else:
return autocontrast_image_tensor(inpt)


equalize_image_tensor = _FT.equalize
Expand All @@ -124,9 +133,10 @@ def autocontrast(inpt: DType) -> DType:
def equalize(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.equalize()
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt)
return equalize_image_tensor(inpt)
else:
return equalize_image_tensor(inpt)


invert_image_tensor = _FT.invert
Expand All @@ -136,6 +146,7 @@ def equalize(inpt: DType) -> DType:
def invert(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.invert()
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return invert_image_pil(inpt)
return invert_image_tensor(inpt)
else:
return invert_image_tensor(inpt)
Loading

0 comments on commit a16e61d

Please sign in to comment.