Description
There is a FIXME comment at line 103 of torchvision/transforms/autoaugment.py:
# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
Four classes (AutoAugment, RandAugment, TrivialAugmentWide, AugMix) all independently repeat the same fill-standardization logic in their forward() methods:
```python
fill = self.fill
channels, height, width = F.get_dimensions(img)
if isinstance(img, Tensor):
    if isinstance(fill, (int, float)):
        fill = [float(fill)] * channels
    elif fill is not None:
        fill = [float(f) for f in fill]
```

They also independently inherit from `torch.nn.Module` and define `interpolation` and `fill` in their `__init__`.
Proposal
Create an _AutoAugmentBase class (as v2 already does in torchvision/transforms/v2/_auto_augment.py) that:
- Inherits from `torch.nn.Module`
- Holds the common `__init__` params: `interpolation` and `fill`
- Provides a `_get_fill()` helper to eliminate the duplicated fill-standardization logic
This is a purely internal refactor: no public API changes and no behavior changes. All existing tests should pass as-is.
Note: _augmentation_space() is intentionally not unified since each class uses different signatures and contents.
cc @pmeier