Skip to content

Commit

Permalink
Revert 1692 (#1702)
Browse files Browse the repository at this point in the history
* Revert "Available keys at BasicTransform and BaseCompose (#1692)"

This reverts commit ee3c634.

* Add tests for common pipelines check

* Update test
  • Loading branch information
Dipet committed May 4, 2024
1 parent cd50269 commit 16a55ae
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 122 deletions.
2 changes: 1 addition & 1 deletion albumentations/__init__.py
@@ -1,4 +1,4 @@
__version__ = "1.4.5"
__version__ = "1.4.6"

from .augmentations import *
from .core.composition import *
Expand Down
36 changes: 5 additions & 31 deletions albumentations/core/composition.py
@@ -1,7 +1,7 @@
import random
import warnings
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Union, cast
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union, cast

import cv2
import numpy as np
Expand Down Expand Up @@ -65,9 +65,7 @@ def __init__(self, transforms: TransformsSeqType, p: float):
self.replay_mode = False
self.applied_in_replay = False
self._additional_targets: Dict[str, str] = {}
self._available_keys: Set[str] = set()
self.processors: Dict[str, Union[BboxProcessor, KeypointsProcessor]] = {}
self._set_keys()

def __iter__(self) -> Iterator[TransformType]:
return iter(self.transforms)
Expand All @@ -88,10 +86,6 @@ def __repr__(self) -> str:
def additional_targets(self) -> Dict[str, str]:
return self._additional_targets

@property
def available_keys(self) -> Set[str]:
return self._available_keys

def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str:
args = {k: v for k, v in self.to_dict_private().items() if not (k.startswith("__") or k == "transforms")}
repr_string = self.__class__.__name__ + "(["
Expand Down Expand Up @@ -133,22 +127,11 @@ def add_targets(self, additional_targets: Optional[Dict[str, str]]) -> None:
f"Trying to overwrite existed additional targets. "
f"Key={k} Exists={self._additional_targets[k]} New value: {v}",
)
self._additional_targets.update(additional_targets)
self._additional_targets.update(additional_targets)
for t in self.transforms:
t.add_targets(additional_targets)
for proc in self.processors.values():
proc.add_targets(additional_targets)
self._set_keys()

def _set_keys(self) -> None:
"""Set _available_keys"""
for t in self.transforms:
self._available_keys.update(t.available_keys)
if self.processors:
self._available_keys.update(["labels"])
for proc in self.processors.values():
if proc.params.label_fields:
self._available_keys.update(proc.params.label_fields)

def set_deterministic(self, flag: bool, save_key: str = "replay") -> None:
for t in self.transforms:
Expand Down Expand Up @@ -209,7 +192,6 @@ def __init__(
self._disable_check_args_for_transforms(self.transforms)

self.is_check_shapes = is_check_shapes
self._always_apply = get_always_apply(self.transforms) # transforms list that always apply
self._check_each_transform = tuple( # processors that checks after each transform
proc for proc in self.processors.values() if getattr(proc.params, "check_each_transform", False)
)
Expand All @@ -229,22 +211,18 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> Dict[s
if args:
msg = "You have to pass data to augmentations as named arguments, for example: aug(image=image)"
raise KeyError(msg)
if self.is_check_args:
self._check_args(**data)

if not isinstance(force_apply, (bool, int)):
msg = "force_apply must have bool or int type"
raise TypeError(msg)

need_to_run = force_apply or random.random() < self.p
if not need_to_run and not self._always_apply:
return data

transforms = self.transforms if need_to_run else self._always_apply

if self.is_check_args:
self._check_args(**data)

for p in self.processors.values():
p.ensure_data_valid(data)
transforms = self.transforms if need_to_run else get_always_apply(self.transforms)

for p in self.processors.values():
p.preprocess(data)
Expand Down Expand Up @@ -308,9 +286,6 @@ def _check_args(self, **kwargs: Any) -> None:
check_keypoints_param = ["keypoints"]
shapes = []
for data_name, data in kwargs.items():
if data_name not in self._available_keys and data_name not in ["mask", "masks"]:
msg = f"Key {data_name} is not in available keys."
raise ValueError(msg)
internal_data_name = self._additional_targets.get(data_name, data_name)
if internal_data_name in checked_single:
if not isinstance(data, np.ndarray):
Expand Down Expand Up @@ -518,7 +493,6 @@ def __init__(
super().__init__(transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes)
self.set_deterministic(True, save_key=save_key)
self.save_key = save_key
self._available_keys.add(save_key)

def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> Dict[str, Any]:
kwargs[self.save_key] = defaultdict(dict)
Expand Down
61 changes: 21 additions & 40 deletions albumentations/core/transforms_interface.py
@@ -1,6 +1,6 @@
import random
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
from warnings import warn

import cv2
Expand Down Expand Up @@ -43,12 +43,8 @@ class CombinedMeta(SerializableMeta, ValidatedTransformMeta):


class BasicTransform(Serializable, metaclass=CombinedMeta):
_targets: Union[Tuple[Targets, ...], Targets] # targets that this transform can work on
_available_keys: Set[str] # targets that this transform, as string, lower-cased
_key2func: Dict[
str,
Callable[..., Any],
] # mapping for targets (plus additional targets) and methods for which they depend
# `_targets` defines the types of targets (e.g., image, mask) that the transform can be applied to.
_targets: Union[Tuple[Targets, ...], Targets]
call_backup = None
interpolation: int
fill_value: ColorType
Expand All @@ -68,8 +64,6 @@ def __init__(self, always_apply: bool = False, p: float = 0.5):
self._additional_targets: Dict[str, str] = {}
# replay mode params
self.params: Dict[Any, Any] = {}
self._key2func = {}
self._set_keys()

def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> Any:
if args:
Expand Down Expand Up @@ -103,11 +97,12 @@ def apply_with_params(self, params: Dict[str, Any], *args: Any, **kwargs: Any) -
params = self.update_params(params, **kwargs)
res = {}
for key, arg in kwargs.items():
if key in self._key2func and arg is not None:
target_function = self._key2func[key]
res[key] = target_function(arg, **params)
if arg is not None:
target_function = self._get_target_function(key)
target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])}
res[key] = target_function(arg, **dict(params, **target_dependencies))
else:
res[key] = arg
res[key] = None
return res

def set_deterministic(self, flag: bool, save_key: str = "replay") -> "BasicTransform":
Expand All @@ -130,6 +125,14 @@ def __repr__(self) -> str:
state.update(self.get_transform_init_args())
return f"{self.__class__.__name__}({format_args(state)})"

def _get_target_function(self, key: str) -> Callable[..., Any]:
"""Returns function to process target"""
transform_key = key
if key in self._additional_targets:
transform_key = self._additional_targets.get(key, key)

return self.targets.get(transform_key, lambda x, **p: x)

def apply(self, img: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
"""Apply transform on image."""
raise NotImplementedError
Expand All @@ -146,23 +149,6 @@ def targets(self) -> Dict[str, Callable[..., Any]]:
# >> {"masks": self.apply_to_masks}
raise NotImplementedError

def _set_keys(self) -> None:
"""Set _available_keys"""
if not hasattr(self, "_targets"):
self._available_keys = set()
else:
self._available_keys = {
target.value.lower()
for target in (self._targets if isinstance(self._targets, tuple) else [self._targets])
}
self._available_keys.update(self.targets.keys())
self._key2func = {key: self.targets[key] for key in self._available_keys if key in self.targets}

@property
def available_keys(self) -> Set[str]:
"""Returns set of available keys"""
return self._available_keys

def update_params(self, params: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
"""Update parameters with transform specific params"""
if hasattr(self, "interpolation"):
Expand All @@ -174,6 +160,10 @@ def update_params(self, params: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]
params.update({"cols": kwargs["image"].shape[1], "rows": kwargs["image"].shape[0]})
return params

@property
def target_dependence(self) -> Dict[str, Any]:
return {}

def add_targets(self, additional_targets: Dict[str, str]) -> None:
"""Add targets to transform them the same way as one of existing targets
ex: {'target_image': 'image'}
Expand All @@ -184,16 +174,7 @@ def add_targets(self, additional_targets: Dict[str, str]) -> None:
additional_targets (dict): keys - new target name, values - old target name. ex: {'image2': 'image'}
"""
for k, v in additional_targets.items():
if k in self._additional_targets and v != self._additional_targets[k]:
raise ValueError(
f"Trying to overwrite existed additional targets. "
f"Key={k} Exists={self._additional_targets[k]} New value: {v}",
)
if v in self._available_keys:
self._additional_targets[k] = v
self._key2func[k] = self.targets[v]
self._available_keys.add(k)
self._additional_targets = {**self._additional_targets, **additional_targets}

@property
def targets_as_params(self) -> List[str]:
Expand Down
6 changes: 3 additions & 3 deletions albumentations/pytorch/transforms.py
Expand Up @@ -4,7 +4,6 @@
import torch

from albumentations.core.transforms_interface import BasicTransform
from albumentations.core.types import Targets

__all__ = ["ToTensorV2"]

Expand All @@ -24,8 +23,6 @@ class ToTensorV2(BasicTransform):
"""

_targets = (Targets.IMAGE, Targets.MASK)

def __init__(self, transpose_mask: bool = False, always_apply: bool = True, p: float = 1.0):
super().__init__(always_apply=always_apply, p=p)
self.transpose_mask = transpose_mask
Expand Down Expand Up @@ -54,3 +51,6 @@ def apply_to_masks(self, masks: List[np.ndarray], **params: Any) -> List[torch.T

def get_transform_init_args_names(self) -> Tuple[str, ...]:
return ("transpose_mask",)

def get_params_dependent_on_targets(self, params: Any) -> Dict[str, Any]:
return {}

0 comments on commit 16a55ae

Please sign in to comment.