Closed
Description
🐛 Describe the bug
When tv_tensors.wrap is called with like set to an instance of a subclass of either BoundingBoxes or KeyPoints, it returns an instance of the base class instead of the subclass, because those two classes are hard-coded in wrap. For example:
>>> import torch
>>> from torchvision import tv_tensors
>>> class MyKeyPoints(tv_tensors.KeyPoints):
...     pass
...
>>> kp = MyKeyPoints(torch.tensor([[0, 1], [2, 3]]), canvas_size=(4, 8))
>>> kp
MyKeyPoints([[0, 1],
             [2, 3]], canvas_size=(4, 8))
>>> tv_tensors.wrap(kp, like=kp)
KeyPoints([[0, 1],
           [2, 3]], canvas_size=(4, 8))

🚀 Proposal
If, instead of

KeyPoints._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))

it were changed to

type(like)._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))

then the subclass would be propagated. One possible downside is that a subclass might not accept exactly the same kwargs (e.g. canvas_size for KeyPoints). Even so, failing explicitly in that case and forcing a different approach seems preferable to the silent errors that can occur today when the return type does not match the type of like.
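To illustrate the difference without depending on torchvision, here is a minimal, self-contained sketch. TVTensor, KeyPoints, MyKeyPoints, StrictKeyPoints, wrap_current and wrap_proposed are hypothetical stand-ins that only mimic the KeyPoints dispatch described above; they hold plain lists rather than tensors, and the fallback branch for other types is a simplification, not the real torchvision code path. StrictKeyPoints is an assumed subclass whose _wrap requires an extra keyword, used to show the "explicit failure" case.

```python
from __future__ import annotations

from typing import Any


class TVTensor:
    """Hypothetical stand-in for tv_tensors.TVTensor; holds plain data, not a tensor."""

    def __init__(self, data: Any) -> None:
        self.data = data

    @classmethod
    def _wrap(cls, data: Any) -> TVTensor:
        return cls(data)


class KeyPoints(TVTensor):
    """Hypothetical stand-in for tv_tensors.KeyPoints that carries canvas_size."""

    def __init__(self, data: Any, *, canvas_size: tuple[int, int]) -> None:
        super().__init__(data)
        self.canvas_size = canvas_size

    @classmethod
    def _wrap(cls, data: Any, *, canvas_size: tuple[int, int]) -> KeyPoints:
        # Builds the result from cls, so calling it on a subclass yields the subclass.
        return cls(data, canvas_size=canvas_size)


def wrap_current(wrappee: Any, *, like: TVTensor, **kwargs: Any) -> TVTensor:
    # Behaviour quoted in the issue: KeyPoints is hard-coded as the receiver.
    if isinstance(like, KeyPoints):
        return KeyPoints._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))
    # Simplified fallback for other types (assumption, not torchvision's code).
    return type(like)._wrap(wrappee)


def wrap_proposed(wrappee: Any, *, like: TVTensor, **kwargs: Any) -> TVTensor:
    # Proposed behaviour: dispatch on type(like) so the subclass is preserved.
    if isinstance(like, KeyPoints):
        return type(like)._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))
    return type(like)._wrap(wrappee)


class MyKeyPoints(KeyPoints):
    pass


class StrictKeyPoints(KeyPoints):
    # Hypothetical subclass whose _wrap needs an extra keyword argument.
    @classmethod
    def _wrap(cls, data: Any, *, canvas_size: tuple[int, int], visibility: Any) -> StrictKeyPoints:
        obj = cls(data, canvas_size=canvas_size)
        obj.visibility = visibility
        return obj


kp = MyKeyPoints([[0, 1], [2, 3]], canvas_size=(4, 8))
print(type(wrap_current(kp, like=kp)).__name__)   # KeyPoints
print(type(wrap_proposed(kp, like=kp)).__name__)  # MyKeyPoints

strict = StrictKeyPoints([[0, 1]], canvas_size=(4, 8))
try:
    wrap_proposed(strict, like=strict)
except TypeError as exc:
    # Raised because visibility was not supplied: a loud failure instead of a silent KeyPoints.
    print(f"explicit failure: {exc}")
```

The trade-off from the proposal shows up directly: wrap_current never fails but silently downgrades MyKeyPoints to KeyPoints, while wrap_proposed keeps the subclass and raises a TypeError when a subclass's _wrap signature is incompatible.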