Description
🚀 The feature
Supporting arbitrary input structures in custom transforms is very important in the case of transform compositions:
tr = Compose([RandomCrop((128, 128)), CustomTransform()])
This can be done by inheriting from torchvision.transforms.v2.Transform and implementing the private ._transform method, which avoids having to unravel the data structure on your own (since this is done anyway in the .forward method).
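from torchvision.transforms.v2 import Transform
from torchvision.tv_tensors import BoundingBoxes, Image

class CustomTransform(Transform):
    def __init__(self, **kwargs):
        super().__init__()

    def _transform(self, inpt, params):
        # Called by .forward for individual items of the input structure.
        if isinstance(inpt, Image):
            transformed_inpt = inpt  # image-specific logic goes here
        elif isinstance(inpt, BoundingBoxes):
            transformed_inpt = inpt  # bounding-box-specific logic goes here
        else:
            transformed_inpt = inpt  # leave other types unchanged
        return transformed_inpt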
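To illustrate why overriding a single per-item hook is enough, here is a minimal pure-Python sketch of the pattern. It is not torchvision's actual implementation: SketchTransform, Negate, and the list-based Image and BoundingBoxes stand-ins are all hypothetical names made up for this example. The base class walks the nested input and calls the hook for each supported leaf, so the subclass never deals with the surrounding structure.

```python
from typing import Any, Dict


class Image(list):
    """Hypothetical stand-in for an image (not torchvision.tv_tensors.Image)."""


class BoundingBoxes(list):
    """Hypothetical stand-in for bounding boxes."""


class SketchTransform:
    """Simplified stand-in for a Transform base class; not torchvision's real code."""

    def __call__(self, *inputs: Any) -> Any:
        # Accept either one (possibly nested) sample or several positional inputs.
        sample = inputs if len(inputs) > 1 else inputs[0]
        params: Dict[str, Any] = {}  # per-call parameters shared by all leaves
        return self._walk(sample, params)

    def _walk(self, node: Any, params: Dict[str, Any]) -> Any:
        # Leaf types are checked first because the stand-ins subclass list.
        if isinstance(node, (Image, BoundingBoxes)):
            return self._transform(node, params)
        if isinstance(node, dict):
            return {k: self._walk(v, params) for k, v in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(self._walk(v, params) for v in node)
        return node  # anything else passes through untouched

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError


class Negate(SketchTransform):
    """Example subclass: only the per-item hook is implemented."""

    def _transform(self, inpt, params):
        if isinstance(inpt, Image):
            return Image(-v for v in inpt)
        return inpt  # bounding boxes are unaffected by negation


sample = {
    "image": Image([1, 2, 3]),
    "target": {"boxes": BoundingBoxes([0, 0, 4, 4]), "label": 7},
}
out = Negate()(sample)
```

The same Negate instance also works when called with several positional inputs, e.g. Negate()(Image([5]), 3), because the structure handling lives entirely in the base class.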
The method has also been described in the blog post "How to Create Custom Torchvision V2 Transforms", but the official torchvision docs do not yet describe it and instead suggest hard-coding the input structure.
Having to implement a private method for this (even though the class Transform is public) feels very wrong: it means that things could break on our side at any time. I would appreciate it if the ._transform method were made public (-> .transform) and the Transform class received proper documentation on how this method should be implemented for custom transforms.
Motivation, pitch
The torchvision.transforms.v2 API has now been around for quite some time, and it would be nice to give developers the chance to develop transforms of the same quality and flexibility as the built-in ones!
Alternatives
No response
Additional context
No response