Enable to register custom transform kernel #8223

Open
sungchul2 opened this issue Jan 19, 2024 · 0 comments
🚀 The feature

It would be good to be able to register custom transform kernels via `v2.functional.register_kernel`.

Motivation, pitch

If I want to register a kernel for a custom functional, one that is not part of the built-in torchvision transforms and functional API but operates on the built-in tv_tensor classes, registration is blocked by the checks below, which require the functional to come from the built-in namespace and the tv_tensor class to not be a built-in one:

```python
elif not (
    callable(functional)
    and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional")
):
    raise ValueError(
        f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, "
        f"but got {functional}."
    )

if tv_tensor_cls in _BUILTIN_DATAPOINT_TYPES:
    raise ValueError(f"Kernels cannot be registered for the builtin tv_tensor classes, but got {tv_tensor_cls}")
```

```python
from typing import Any

import torch
import torchvision.transforms.v2 as tvt_v2
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


def custom_transform_kernel(inpt: torch.Tensor) -> torch.Tensor:
    ...


# Raises ValueError: custom_transform_kernel is not from the
# torchvision.transforms.v2.functional namespace, and TVTensor is built-in.
@F.register_kernel(custom_transform_kernel, tv_tensors.TVTensor)
def _custom_transform_kernel_dispatch(inpt: tv_tensors.TVTensor) -> tv_tensors.TVTensor:
    output = custom_transform_kernel(inpt.as_subclass(torch.Tensor))
    return tv_tensors.wrap(output, like=inpt)


class CustomTransform(tvt_v2.Transform):
    def _transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        return self._call_kernel(custom_transform_kernel, inpt)
```

It would be more flexible if registering such incompatible custom transform kernels were possible.

Alternatives

I tried using `@F._utils._register_kernel_internal` instead, and it works.
However, relying on a private API does not seem like a safe approach.

```python
from typing import Any

import torch
import torchvision.transforms.v2 as tvt_v2
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


def custom_transform_kernel(inpt: torch.Tensor) -> torch.Tensor:
    ...


# Works, but uses the private _register_kernel_internal helper, which skips
# the namespace and built-in class checks of the public register_kernel.
@F._utils._register_kernel_internal(custom_transform_kernel, tv_tensors.TVTensor, tv_tensor_wrapper=False)
def _custom_transform_kernel_dispatch(inpt: tv_tensors.TVTensor) -> tv_tensors.TVTensor:
    output = custom_transform_kernel(inpt.as_subclass(torch.Tensor))
    return tv_tensors.wrap(output, like=inpt)


class CustomTransform(tvt_v2.Transform):
    def _transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        return self._call_kernel(custom_transform_kernel, inpt)
```

Additional context

No response
