Skip to content

Function Registry for extending collate_fn #97498

@ZhiyuanChen

Description

@ZhiyuanChen

🚀 The feature, motivation and pitch

Previous improvements to collate_fn have allowed custom types in collate_fn.

In this enhancement proposal, we would like to add a registry to go one step further and make extension even smoother.

Currently, extending collate_fn requires code like the following:

from torch.utils import data


def collate_nested_tensor_fn(batch, *, collate_fn_map=None):
    """Collate a batch of `PNTensor` samples into a single `NestedTensor`.

    Args:
        batch: The sequence of samples to collate.
        collate_fn_map: Passed as a keyword by `torch.utils.data` when it
            dispatches through `default_collate_fn_map`. It is unused here, but
            it must be accepted or that dispatch call raises `TypeError`.

    Returns:
        The `NestedTensor` built from `batch`.
    """
    # Without `return`, the collated batch handed to the DataLoader is None.
    return NestedTensor(batch)


# Map PNTensor to its collate function in PyTorch's default type -> collate-fn table,
# which lives in a protected module (the motivation for this proposal).
data._utils.collate.default_collate_fn_map[PNTensor] = collate_nested_tensor_fn

However, default_collate_fn_map is not exported from torch.utils.data; it can only be reached through the protected (underscore-prefixed) sub-package torch.utils.data._utils.

The process would be much smoother if we had a registry for collate functions, so that it would become:

from torch.utils import data


@data.collate_fns.register(PNTensor)
def collate_nested_tensor_fn(batch, *, collate_fn_map=None):
    """Collate a batch of `PNTensor` samples into a single `NestedTensor`.

    Args:
        batch: The sequence of samples to collate.
        collate_fn_map: The keyword `torch.utils.data` passes when dispatching to
            a registered collate function. It is unused here, but it must be
            accepted so the dispatch call does not raise `TypeError`.

    Returns:
        The `NestedTensor` built from `batch`.
    """
    # Without `return`, the collated batch handed to the DataLoader is None.
    return NestedTensor(batch)

Alternatives

Alternatively, we could export default_collate_fn_map from torch.utils.data so that users do not need to access the protected sub-package.

Additional context

The following code demonstrates how we use a registry to extend PyTorch functions:

from functools import wraps
from typing import Callable

from ..registry import Registry


class TorchFuncRegistry(Registry):
    """
    `TorchFuncRegistry` for extending PyTorch Tensor.

    Entries map a torch function (e.g. `torch.mean`) to the custom function that
    implements it.
    """

    def implement(self, torch_function: Callable) -> Callable:
        r"""
        Register an implementation for a torch function.

        Args:
            torch_function: The torch function to implement.

        Returns:
            Callable: A decorator that stores the decorated function as the
                implementation of `torch_function` and returns it unchanged.

        Raises:
            ValueError: If `torch_function` is already registered and `TorchFuncRegistry.override=False`.

        Examples:
        ```python
        >>> import torch
        >>> registry = TorchFuncRegistry()
        >>> @registry.implement(torch.mean)  # pylint: disable=E1101
        ... def mean(input):
        ...     return input.mean()
        >>> registry  # doctest: +ELLIPSIS
        TorchFuncRegistry(
          (<built-in method mean of type object at ...>): <function mean at ...>
        )

        ```
        """

        # The key (the torch function itself) is already known here, so the
        # duplicate check can fail fast before the decorator is applied.
        if torch_function in self and not self.override:
            raise ValueError(f"Torch function {torch_function.__name__} already registered.")

        @wraps(self.register)
        def register(function):
            self.set(torch_function, function)
            return function

        return register


# Registry of torch-function overrides for NestedTensor; duplicate registrations raise.
NestedTensorFunc = TorchFuncRegistry(override=False)


@NestedTensorFunc.implement(torch.mean)  # pylint: disable=E1101
def mean(
    input,  # pylint: disable=W0622
    dim: Optional[int] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
):
    """Forward `torch.mean` to the input's own `mean` method.

    `input` is presumably a `NestedTensor`, given the registry this is
    registered in; all reduction options are passed through unchanged.
    """
    reduction_options = {"dim": dim, "keepdim": keepdim, "dtype": dtype}
    return input.mean(**reduction_options)


@NestedTensorFunc.implement(torch.cat)  # pylint: disable=E1101
def cat(tensors, dim: int = 0):
    """Concatenate nested tensors along dimension 0 by merging their storages.

    Raises:
        NotImplementedError: If `dim` is anything other than 0.
    """
    if dim == 0:
        merged = []
        for tensor in tensors:
            merged.extend(tensor.storage)
        return NestedTensor(merged)
    raise NotImplementedError(f"NestedTensor only supports cat when dim=0, but got {dim}")

Registry is basically a dict with some additional methods. The following code demonstrates how we define our registry:

Note that NestedDict is a subclass of dict; for the purposes of this proposal it can be treated as a plain dict.

class Registry(NestedDict):
    """
    `Registry` for components.

    A `Registry` maps names to components (classes, functions, ...). Components
    are added with `register`, retrieved with `lookup`, and instantiated from a
    name or config mapping with `build`.

    Notes:
        `Registry` inherits from [`NestedDict`](https://chanfig.danling.org/nested_dict/).

        Therefore, `Registry` comes in a nested structure by nature.
        You could create a sub-registry by simply calling `registry.sub_registry = Registry`,
        and access through `registry.sub_registry.register()`.

    Examples:
    ```python
    >>> registry = Registry()
    >>> @registry.register
    ... @registry.register("Module1")
    ... class Module:
    ...     def __init__(self, a, b):
    ...         self.a = a
    ...         self.b = b
    >>> module = registry.register(Module, "Module2")
    >>> registry
    Registry(
      ('Module1'): <class 'danling.registry.Module'>
      ('Module'): <class 'danling.registry.Module'>
      ('Module2'): <class 'danling.registry.Module'>
    )
    >>> registry.lookup("Module")
    <class 'danling.registry.Module'>
    >>> config = {"module": {"name": "Module", "a": 1, "b": 2}}
    >>> # registry.register(Module)
    >>> module = registry.build(config["module"])
    >>> type(module)
    <class 'danling.registry.Module'>
    >>> module.a
    1
    >>> module.b
    2

    ```
    """

    # When False, registering a name that is already present raises ValueError.
    override: bool = False

    def __init__(self, override: bool = False):
        super().__init__()
        # NOTE(review): stored via NestedDict.setattr, presumably so that it is kept
        # as an attribute rather than a registry entry — confirm against chanfig.
        self.setattr("override", override)

    def register(self, component: Optional[Callable] = None, name: Optional[str] = None) -> Callable:
        r"""
        Register a new component.

        Three calling forms are supported:

        - `registry.register(component, name)`: register `component` under `name`.
        - `@registry.register`: register the decorated object under its `__name__`.
        - `@registry.register("name")` / `@registry.register()`: register the
          decorated object under `"name"`, or under its `__name__` if omitted.

        Args:
            component: The component to register. In the decorator-factory form,
                this is instead the name to register the decorated object under.
            name: The name of the component. Defaults to `component.__name__`.

        Returns:
            The registered component. In the decorator-factory form, this is
            instead a decorator that registers and returns its argument.

        Raises:
            ValueError: If the component with the same name already registered and `Registry.override=False`.

        Examples:
        ```python
        >>> registry = Registry()
        >>> @registry.register
        ... @registry.register("Module1")
        ... class Module:
        ...     def __init__(self, a, b):
        ...         self.a = a
        ...         self.b = b
        >>> module = registry.register(Module, "Module2")
        >>> registry
        Registry(
          ('Module1'): <class 'danling.registry.Module'>
          ('Module'): <class 'danling.registry.Module'>
          ('Module2'): <class 'danling.registry.Module'>
        )

        ```
        """

        def register(component: Callable, name: Optional[str] = None) -> Callable:
            # Resolve the key before the duplicate check so it tests the name that
            # is actually stored, whichever calling form was used.
            if name is None:
                name = component.__name__
            if name in self and not self.override:
                raise ValueError(f"Component with name {name} already registered.")
            self.set(name, component)
            return component

        # registry.register(component, name)
        if name is not None:
            return register(component, name)

        # @registry.register
        if callable(component):
            return register(component)

        # @registry.register("name") or @registry.register()
        return lambda x: register(x, component)

    def lookup(self, name: str) -> Any:
        r"""
        Lookup for a component.

        Args:
            name: The name the component was registered under.

        Returns:
            (Any): The component.

        Raises:
            KeyError: If the component is not registered.

        Examples:
        ```python
        >>> registry = Registry()
        >>> @registry.register
        ... class Module:
        ...     def __init__(self, a, b):
        ...         self.a = a
        ...         self.b = b
        >>> registry.lookup("Module")
        <class 'danling.registry.Module'>

        ```
        """

        # NOTE(review): the KeyError above relies on NestedDict.get raising for a
        # missing name; a plain dict.get would return None instead — confirm.
        return self.get(name)

    def build(self, name: Union[str, Mapping], *args, **kwargs) -> Any:
        r"""
        Build a component.

        Args:
            name (str | Mapping):
                If it's a `Mapping`, it must contain `"name"` as a member; the rest will be treated as `**kwargs`.
                Note that values in `kwargs` override values in `name` if it's a `Mapping`.
            *args: The arguments to pass to the component.
            **kwargs: The keyword arguments to pass to the component.

        Returns:
            (Any): The result of calling the registered component with `*args, **kwargs`.

        Raises:
            KeyError: If the component is not registered.

        Examples:
        ```python
        >>> registry = Registry()
        >>> @registry.register
        ... class Module:
        ...     def __init__(self, a, b):
        ...         self.a = a
        ...         self.b = b
        >>> config = {"module": {"name": "Module", "a": 1, "b": 2}}
        >>> # registry.register(Module)
        >>> module = registry.build(**config["module"])
        >>> type(module)
        <class 'danling.registry.Module'>
        >>> module.a
        1
        >>> module.b
        2
        >>> module = registry.build(config["module"], a=2)
        >>> module.a
        2

        ```
        """

        if isinstance(name, Mapping):
            # Copy first so popping "name" does not mutate the caller's config.
            name = deepcopy(name)
            # Explicit **kwargs win over values taken from the mapping.
            name, kwargs = name.pop("name"), dict(name, **kwargs)  # type: ignore
        return self.get(name)(*args, **kwargs)  # type: ignore

    def __wrapped__(self, *args, **kwargs):
        # NOTE(review): presumably defined so that tools probing `__wrapped__`
        # (e.g. inspect.unwrap / doctest) do not trigger NestedDict's dynamic
        # attribute handling — confirm the intent.
        pass

cc @ssnl @VitalyFedyunin @ejguan @NivekT @dzhulgakov

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancement — Not as big of a feature, but technically not a bug. Should be easy to fix.
    module: dataloader — Related to torch.utils.data.DataLoader and Sampler.
    triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions