Skip to content

Commit

Permalink
Add xy masking (#1521)
Browse files Browse the repository at this point in the history
* temp commit

* Added an option for fixed mask length

* Added an option for fixed mask length

* Added tests for cutout

* Added updated readme
  • Loading branch information
ternaus committed Feb 17, 2024
1 parent be6a217 commit 004fabb
Show file tree
Hide file tree
Showing 14 changed files with 419 additions and 39 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -236,6 +236,7 @@ Spatial-level transforms will simultaneously change both an input image as well
| [SmallestMaxSize](https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/#albumentations.augmentations.geometric.resize.SmallestMaxSize) |||||
| [Transpose](https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.Transpose) |||||
| [VerticalFlip](https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.VerticalFlip) |||||
| [XYMasking](https://albumentations.ai/docs/api_reference/augmentations/dropout/xy_masking/#albumentations.augmentations.dropout.xy_masking.XYMasking) ||| ||

## A few more examples of augmentations

Expand Down
1 change: 1 addition & 0 deletions albumentations/augmentations/__init__.py
Expand Up @@ -11,6 +11,7 @@
from .dropout.functional import *
from .dropout.grid_dropout import *
from .dropout.mask_dropout import *
from .dropout.xy_masking import *
from .functional import *
from .geometric.functional import *
from .geometric.resize import *
Expand Down
1 change: 1 addition & 0 deletions albumentations/augmentations/dropout/__init__.py
Expand Up @@ -2,3 +2,4 @@
from .coarse_dropout import *
from .grid_dropout import *
from .mask_dropout import *
from .xy_masking import *
23 changes: 9 additions & 14 deletions albumentations/augmentations/dropout/coarse_dropout.py
Expand Up @@ -5,7 +5,7 @@

from ...core.transforms_interface import DualTransform
from ...core.types import KeypointType, ScalarType
from .functional import cutout
from .functional import cutout, keypoint_in_hole

__all__ = ["CoarseDropout"]

Expand Down Expand Up @@ -79,7 +79,8 @@ def __init__(
if not 0 < self.min_width <= self.max_width:
raise ValueError(f"Invalid combination of min_width and max_width. Got: {[min_width, max_width]}")

def check_range(self, dimension: ScalarType) -> None:
@staticmethod
def check_range(dimension: ScalarType) -> None:
    """Validate a hole dimension that was given as a float.

    Only ``float`` values are checked; any other type is accepted as is.

    Args:
        dimension: The dimension value to validate.

    Raises:
        ValueError: If ``dimension`` is a float outside the half-open range [0.0, 1.0).
    """
    if not isinstance(dimension, float):
        return

    # Kept as a positive range test that is then negated, so a NaN (which fails
    # every comparison) is still rejected.
    within_unit_interval = 0 <= dimension < 1.0
    if not within_unit_interval:
        raise ValueError(f"Invalid value {dimension}. If using floats, the value should be in the range [0.0, 1.0)")

Expand Down Expand Up @@ -108,7 +109,7 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
height, width = img.shape[:2]

holes = []
for _n in range(random.randint(self.min_holes, self.max_holes)):
for _ in range(random.randint(self.min_holes, self.max_holes)):
if all(
[
isinstance(self.min_height, int),
Expand Down Expand Up @@ -156,20 +157,14 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
def targets_as_params(self) -> List[str]:
return ["image"]

def _keypoint_in_hole(self, keypoint: KeypointType, hole: Tuple[int, int, int, int]) -> bool:
x1, y1, x2, y2 = hole
x, y = keypoint[:2]
return x1 <= x < x2 and y1 <= y < y2

def apply_to_keypoints(
self, keypoints: Sequence[KeypointType], holes: Iterable[Tuple[int, int, int, int]] = (), **params: Any
) -> List[KeypointType]:
result = set(keypoints)
for hole in holes:
for kp in keypoints:
if self._keypoint_in_hole(kp, hole):
result.discard(kp)
return list(result)
filtered_keypoints = []
for keypoint in keypoints:
if not any(keypoint_in_hole(keypoint, hole) for hole in holes):
filtered_keypoints.append(keypoint)
return filtered_keypoints

def get_transform_init_args_names(self) -> Tuple[str, ...]:
return (
Expand Down
20 changes: 13 additions & 7 deletions albumentations/augmentations/dropout/functional.py
Expand Up @@ -3,13 +3,12 @@
import numpy as np

from albumentations.augmentations.utils import preserve_shape

__all__ = ["channel_dropout"]
from albumentations.core.types import ColorType, KeypointType


@preserve_shape
def channel_dropout(
img: np.ndarray, channels_to_drop: Union[int, Tuple[int, ...], np.ndarray], fill_value: Union[int, float] = 0
img: np.ndarray, channels_to_drop: Union[int, Tuple[int, ...], np.ndarray], fill_value: ColorType = 0
) -> np.ndarray:
if len(img.shape) == 2 or img.shape[2] == 1:
raise NotImplementedError("Only one channel. ChannelDropout is not defined.")
Expand All @@ -19,11 +18,18 @@ def channel_dropout(
return img


def cutout(
img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]], fill_value: Union[int, float] = 0
) -> np.ndarray:
# Make a copy of the input image since we don't want to modify it directly
def cutout(img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]], fill_value: ColorType = 0) -> np.ndarray:
    """Return a copy of ``img`` with every hole rectangle overwritten by ``fill_value``.

    Args:
        img: Source array; it is copied, so the caller's array is left untouched.
        holes: Rectangles given as ``(x1, y1, x2, y2)``. Each one is applied as the
            slice ``[y1:y2, x1:x2]``, so ``x2`` and ``y2`` are exclusive.
        fill_value: Scalar, or a tuple/list of values, assigned to every hole.

    Returns:
        The filled copy of ``img``.
    """
    result = img.copy()

    # Tuples and lists are turned into an array once, up front, so the assignment
    # below broadcasts the same way for every hole.
    fill = np.array(fill_value) if isinstance(fill_value, (tuple, list)) else fill_value

    for left, top, right, bottom in holes:
        result[top:bottom, left:right] = fill
    return result


def keypoint_in_hole(keypoint: KeypointType, hole: Tuple[int, int, int, int]) -> bool:
    """Tell whether a keypoint lies inside a hole rectangle.

    Args:
        keypoint: Sequence whose first two items are the ``x`` and ``y`` coordinates;
            any further items are ignored.
        hole: Rectangle given as ``(x1, y1, x2, y2)``.

    Returns:
        ``True`` when ``x1 <= x < x2`` and ``y1 <= y < y2`` (the upper bounds are
        exclusive), otherwise ``False``.
    """
    kp_x, kp_y = keypoint[:2]
    left, top, right, bottom = hole
    inside_columns = left <= kp_x < right
    inside_rows = top <= kp_y < bottom
    return inside_columns and inside_rows
211 changes: 211 additions & 0 deletions albumentations/augmentations/dropout/xy_masking.py
@@ -0,0 +1,211 @@
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

import numpy as np

from albumentations.core.types import ColorType, KeypointType, ScaleIntType

from ...core.transforms_interface import DualTransform, to_tuple
from .functional import cutout, keypoint_in_hole

__all__ = ["XYMasking"]


class XYMasking(DualTransform):
    """
    Applies masking strips to an image along the X axis, the Y axis, or both, simulating occlusions.

    A mask generated for the X axis covers a randomly positioned range of `mask_x_length` columns and
    spans the full image height. A mask generated for the Y axis covers a randomly positioned range of
    `mask_y_length` rows and spans the full image width. This is useful for training models to
    recognize images with varied visibility conditions, for example by masking strips of spectrogram
    images.

    An axis only produces masks when both its number of masks and its mask length are non-zero.

    Args:
        num_masks_x (Union[int, Tuple[int, int]]): Number of masks along the X axis. If a tuple
            (min, max) is provided, the number is sampled from that inclusive range on every call.
            Defaults to 0.
        num_masks_y (Union[int, Tuple[int, int]]): Number of masks along the Y axis. Same semantics
            as `num_masks_x`. Defaults to 0.
        mask_x_length (Union[int, Tuple[int, int]]): Length of the masks along the X (horizontal) axis.
            If an integer is provided, it sets a fixed mask length. If a tuple of two integers
            (min, max) is provided, the length is sampled from that inclusive range for each mask.
            Defaults to 0.
        mask_y_length (Union[int, Tuple[int, int]]): Length of the masks along the Y (vertical) axis.
            Same semantics as `mask_x_length`. Defaults to 0.
        fill_value (Union[int, float, Sequence[int], Sequence[float]]): Value written into the masked
            regions of the image. Defaults to 0.
        mask_fill_value (Optional[Union[int, float, Sequence[int], Sequence[float]]]): Value written
            into the masked regions of the mask target. If `None`, the mask target is returned
            unchanged. Defaults to 0.
        always_apply (bool): Forwarded to the base transform. Defaults to False.
        p (float): Probability of applying the transform. Defaults to 0.5.

    Targets:
        image, mask, keypoints

    Image types:
        uint8, float32

    Raises:
        ValueError: On construction, if both `mask_x_length` and `mask_y_length` are non-positive
            numbers, if both `num_masks_x` and `num_masks_y` are non-positive integers, or if a
            `num_masks_x` / `num_masks_y` range contains a non-positive value. When the transform is
            applied, if a mask length is negative or larger than the matching image dimension.

    Note: At least one of `mask_x_length` or `mask_y_length` must be positive, and at least one of
        `num_masks_x` or `num_masks_y` must be positive.
    """

    def __init__(
        self,
        num_masks_x: ScaleIntType = 0,
        num_masks_y: ScaleIntType = 0,
        mask_x_length: ScaleIntType = 0,
        mask_y_length: ScaleIntType = 0,
        fill_value: ColorType = 0,
        mask_fill_value: Optional[ColorType] = 0,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)

        # Only scalar lengths are checked here; (min, max) ranges are checked against the actual
        # image size when the transform is applied.
        if (
            isinstance(mask_x_length, (int, float))
            and mask_x_length <= 0
            and isinstance(mask_y_length, (int, float))
            and mask_y_length <= 0
        ):
            raise ValueError("At least one of `mask_x_length` or `mask_y_length` Should be a positive number.")

        if isinstance(num_masks_x, int) and num_masks_x <= 0 and isinstance(num_masks_y, int) and num_masks_y <= 0:
            raise ValueError(
                "At least one of `num_masks_x` or `num_masks_y` "
                "should be a positive number or tuple of two positive numbers."
            )

        # A range is rejected as soon as it contains zero or a negative value, so the messages
        # say "positive" to match what is actually enforced.
        if isinstance(num_masks_x, (tuple, list)) and min(num_masks_x) <= 0:
            raise ValueError("All values in `num_masks_x` should be positive integers.")

        if isinstance(num_masks_y, (tuple, list)) and min(num_masks_y) <= 0:
            raise ValueError("All values in `num_masks_y` should be positive integers.")

        self.num_masks_x = num_masks_x
        self.num_masks_y = num_masks_y

        self.mask_x_length = mask_x_length
        self.mask_y_length = mask_y_length
        self.fill_value = fill_value
        self.mask_fill_value = mask_fill_value

    def apply(
        self,
        img: np.ndarray,
        masks_x: List[Tuple[int, int, int, int]],
        masks_y: List[Tuple[int, int, int, int]],
        **params: Any,
    ) -> np.ndarray:
        """Fill every generated X and Y mask region of the image with `fill_value`."""
        return cutout(img, masks_x + masks_y, self.fill_value)

    def apply_to_mask(
        self,
        mask: np.ndarray,
        masks_x: List[Tuple[int, int, int, int]],
        masks_y: List[Tuple[int, int, int, int]],
        **params: Any,
    ) -> np.ndarray:
        """Fill the same regions of the mask target with `mask_fill_value`.

        The mask is returned unchanged when `mask_fill_value` is `None`.
        """
        if self.mask_fill_value is None:
            return mask
        return cutout(mask, masks_x + masks_y, self.mask_fill_value)

    def validate_mask_length(
        self, mask_length: Optional[ScaleIntType], dimension_size: int, dimension_name: str
    ) -> None:
        """
        Validate the mask length against the corresponding image dimension size.

        Args:
            mask_length (Optional[Union[int, Tuple[int, int]]]): The length of the mask to be validated.
                `None` is accepted without any check.
            dimension_size (int): The size of the image dimension (width or height)
                against which to validate the mask length.
            dimension_name (str): The name used for the dimension in the error message.

        Raises:
            ValueError: If the length, or any end of the range, is negative or larger than
                `dimension_size`.
        """
        if mask_length is None:
            return

        if isinstance(mask_length, (tuple, list)):
            # `generate_mask_size` samples between min() and max() of the range, so the same
            # extremes are validated here. A reversed range such as (10, 5) is therefore
            # rejected up front instead of failing later inside `random.randint`.
            if min(mask_length) < 0 or max(mask_length) > dimension_size:
                raise ValueError(
                    f"{dimension_name} range {mask_length} is out of valid range [0, {dimension_size}]"
                )
        elif mask_length < 0 or mask_length > dimension_size:
            raise ValueError(f"{dimension_name} {mask_length} exceeds image {dimension_name} {dimension_size}")

    def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, List[Tuple[int, int, int, int]]]:
        """Generate the mask rectangles for the current image.

        Args:
            params: Must contain the image under the key "image".

        Returns:
            A dict with the keys "masks_x" and "masks_y", each a list of (x1, y1, x2, y2) rectangles.

        Raises:
            ValueError: If a configured mask length does not fit the image (see `validate_mask_length`).
        """
        img = params["image"]
        height, width = img.shape[:2]

        # Lengths can only be checked once the image size is known.
        self.validate_mask_length(self.mask_x_length, width, "mask_x_length")
        self.validate_mask_length(self.mask_y_length, height, "mask_y_length")

        masks_x = self.generate_masks(self.num_masks_x, width, height, self.mask_x_length, axis="x")
        masks_y = self.generate_masks(self.num_masks_y, width, height, self.mask_y_length, axis="y")

        return {"masks_x": masks_x, "masks_y": masks_y}

    @staticmethod
    def generate_mask_size(mask_length: ScaleIntType) -> int:
        """Return a fixed integer length as is, or sample one from an inclusive (min, max) range."""
        if isinstance(mask_length, int):
            return mask_length

        return random.randint(min(mask_length), max(mask_length))

    def generate_masks(
        self,
        num_masks: ScaleIntType,
        width: int,
        height: int,
        max_length: Optional[ScaleIntType],
        axis: str,
    ) -> List[Tuple[int, int, int, int]]:
        """Generate randomly positioned mask rectangles along one axis.

        Args:
            num_masks: Fixed number of masks, or a (min, max) range to sample the number from.
            width: Image width.
            height: Image height.
            max_length: Fixed mask length, or a (min, max) range sampled separately for each mask.
                `None` or 0 disables masking for this axis.
            axis: "x" produces full-height strips; any other value produces full-width strips.

        Returns:
            A list of (x1, y1, x2, y2) rectangles, empty when masking is disabled for this axis.
        """
        length_disabled = max_length is None or max_length == 0
        count_disabled = isinstance(num_masks, (int, float)) and num_masks == 0
        if length_disabled or count_disabled:
            return []

        if isinstance(num_masks, int):
            num_masks_integer = num_masks
        else:
            num_masks_integer = random.randint(num_masks[0], num_masks[1])

        masks = []
        for _ in range(num_masks_integer):
            length = self.generate_mask_size(max_length)

            if axis == "x":
                # The strip covers `length` columns and the full image height.
                x1 = random.randint(0, width - length)
                y1 = 0
                x2, y2 = x1 + length, height
            else:  # axis == 'y'
                # The strip covers `length` rows and the full image width.
                y1 = random.randint(0, height - length)
                x1 = 0
                x2, y2 = width, y1 + length

            masks.append((x1, y1, x2, y2))
        return masks

    @property
    def targets_as_params(self) -> List[str]:
        """Targets needed by `get_params_dependent_on_targets`: only the image."""
        return ["image"]

    def apply_to_keypoints(
        self,
        keypoints: Sequence[KeypointType],
        masks_x: List[Tuple[int, int, int, int]],
        masks_y: List[Tuple[int, int, int, int]],
        **params: Any,
    ) -> List[KeypointType]:
        """Return the keypoints that do not fall inside any mask, in their original order."""
        # Build the combined list once instead of once per keypoint.
        holes = masks_x + masks_y
        return [keypoint for keypoint in keypoints if not any(keypoint_in_hole(keypoint, hole) for hole in holes)]

    def get_transform_init_args_names(self) -> Tuple[str, ...]:
        """Names of the constructor arguments reported for this transform."""
        return (
            "num_masks_x",
            "num_masks_y",
            "mask_x_length",
            "mask_y_length",
            "fill_value",
            "mask_fill_value",
        )
5 changes: 3 additions & 2 deletions albumentations/core/transforms_interface.py
Expand Up @@ -10,6 +10,7 @@
from .types import (
BoxInternalType,
BoxType,
ColorType,
KeypointInternalType,
KeypointType,
ScalarType,
Expand Down Expand Up @@ -74,8 +75,8 @@ def __init__(self, downscale: int = cv2.INTER_NEAREST, upscale: int = cv2.INTER_
class BasicTransform(Serializable):
call_backup = None
interpolation: Union[int, Interpolation]
fill_value: ScalarType
mask_fill_value: Optional[ScalarType]
fill_value: ColorType
mask_fill_value: Optional[ColorType]

def __init__(self, always_apply: bool = False, p: float = 0.5):
self.p = p
Expand Down
2 changes: 1 addition & 1 deletion albumentations/core/types.py
Expand Up @@ -3,7 +3,7 @@
import numpy as np

ScalarType = Union[int, float]
ColorType = Union[int, float, Tuple[int, int, int], Tuple[float, float, float]]
ColorType = Union[int, float, Sequence[int], Sequence[float]]
SizeType = Sequence[int]

BoxInternalType = Tuple[float, float, float, float]
Expand Down

0 comments on commit 004fabb

Please sign in to comment.