
tv_tensors.wrap doesn't work with subclasses of BoundingBoxes or KeyPoints #9328

@jd-anam


🐛 Describe the bug

When tv_tensors.wrap is called with like= set to an instance of a subclass of BoundingBoxes or KeyPoints, it returns an instance of the base class instead of the subclass, because those base classes are hard-coded in wrap. For example:

>>> import torch
>>> from torchvision import tv_tensors
>>> class MyKeyPoints(tv_tensors.KeyPoints):
...     pass
...
>>> kp = MyKeyPoints(torch.tensor([[0, 1], [2, 3]]), canvas_size=(4, 8))
>>> kp
MyKeyPoints([[0, 1],
             [2, 3]], canvas_size=(4, 8))
>>> tv_tensors.wrap(kp, like=kp)
KeyPoints([[0, 1],
           [2, 3]], canvas_size=(4, 8))
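The loss of the subclass can be reproduced without torchvision. The sketch below is a simplified stand-in and not torchvision's actual code. `FakeKeyPoints`, `MyKeyPoints` and `fake_wrap_current` are hypothetical names, and the only assumption is that the current wrap names the base class explicitly when it re-wraps:

```python
# Hypothetical stand-ins for illustration only; NOT torchvision's implementation.
class FakeKeyPoints:
    def __init__(self, data, canvas_size):
        self.data = data
        self.canvas_size = canvas_size

    @classmethod
    def _wrap(cls, data, *, canvas_size):
        # Build an instance of whichever class this classmethod is called on,
        # skipping __init__ the way a cheap re-wrap would.
        obj = cls.__new__(cls)
        obj.data = data
        obj.canvas_size = canvas_size
        return obj


class MyKeyPoints(FakeKeyPoints):
    pass


def fake_wrap_current(wrappee, *, like, **kwargs):
    # Hard-coded dispatch: the base class is named explicitly,
    # so the runtime type of `like` is ignored.
    return FakeKeyPoints._wrap(
        wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size)
    )


kp = MyKeyPoints([[0, 1], [2, 3]], canvas_size=(4, 8))
out = fake_wrap_current(kp.data, like=kp)
print(type(out).__name__)  # FakeKeyPoints, not MyKeyPoints
print(out.canvas_size)     # (4, 8)
```

The metadata (canvas_size) survives, but the subclass identity is dropped. That is the silent mismatch the issue describes.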

🚀 Proposal

If the hard-coded call in tv_tensors.wrap

KeyPoints._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))

were changed to

type(like)._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))

then the subclass would be propagated. One possible downside is that a subclass might not accept exactly the same keyword arguments as its base class (e.g. canvas_size for KeyPoints). However, failing explicitly in that case and forcing a different approach seems preferable to the potentially silent errors that occur today when the return type does not match the type of like.
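The proposed dispatch can be sketched with the same kind of hypothetical stand-in classes. This is not torchvision's code: `FakeKeyPoints`, `MyKeyPoints`, `NoCanvasKeyPoints` and `fake_wrap_proposed` are illustrative names. The sketch shows both effects described above. A subclass is preserved, and a subclass whose `_wrap` does not accept `canvas_size` fails loudly with a `TypeError` instead of silently returning the base class.

```python
# Hypothetical stand-ins for illustration only; NOT torchvision's implementation.
class FakeKeyPoints:
    @classmethod
    def _wrap(cls, data, *, canvas_size):
        obj = cls.__new__(cls)
        obj.data = data
        obj.canvas_size = canvas_size
        return obj


class MyKeyPoints(FakeKeyPoints):
    pass


class NoCanvasKeyPoints(FakeKeyPoints):
    # Hypothetical subclass whose _wrap does NOT accept canvas_size.
    @classmethod
    def _wrap(cls, data):
        obj = cls.__new__(cls)
        obj.data = data
        obj.canvas_size = None
        return obj


def fake_wrap_proposed(wrappee, *, like, **kwargs):
    # Proposed dispatch: call _wrap on the runtime type of `like`.
    return type(like)._wrap(
        wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size)
    )


like = MyKeyPoints._wrap([[0, 1], [2, 3]], canvas_size=(4, 8))
out = fake_wrap_proposed([[4, 5]], like=like)
print(type(out).__name__)  # MyKeyPoints

odd = NoCanvasKeyPoints._wrap([[0, 1]])
try:
    fake_wrap_proposed([[4, 5]], like=odd)
except TypeError as exc:
    # The signature mismatch surfaces immediately instead of being hidden.
    print("explicit failure:", exc)
```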
