Function Registry for extending collate_fn #97498

Open
ZhiyuanChen opened this issue Mar 24, 2023 · 11 comments
Labels
enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: dataloader Related to torch.utils.data.DataLoader and Sampler triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@ZhiyuanChen
Contributor

ZhiyuanChen commented Mar 24, 2023

🚀 The feature, motivation and pitch

Previous improvements to collate_fn have allowed custom types in collate_fn.

In this enhancement proposal, we would like to add a registry to go one step further and make extension even smoother.

Currently, extending collate_fn requires code like the following:

from torch.utils import data


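# PNTensor and NestedTensor are defined elsewhere (they are not part of PyTorch).
# PyTorch's internal collate calls entries of the map as fn(batch, collate_fn_map=...),
# so the function must accept that keyword and return the collated result.
def collate_nested_tensor_fn(batch, *, collate_fn_map=None):
    return NestedTensor(batch)


data._utils.collate.default_collate_fn_map.update({PNTensor: collate_nested_tensor_fn})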

However, default_collate_fn_map is not exported from torch.utils.data and lives in a protected (underscore-prefixed) sub-package.

The process would be much smoother with a registry for collate functions; it would then become:

from torch.utils import data


@data.collate_fns.register(PNTensor)
def collate_nested_tensor_fn(batch):
    return NestedTensor(batch)
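To illustrate what such a registry would do under the hood, here is a minimal, torch-free sketch. The names collate_fn_map, register and collate are hypothetical stand-ins, and toy list/sum operations replace tensor stacking. The lookup order (exact type first, then the first registered base class matched with isinstance) is an assumption modelled on how we understand PyTorch's internal collate helper to work.

```python
from typing import Any, Callable, Dict, List, Type

# Hypothetical type -> collate function table (a plain dict underneath).
collate_fn_map: Dict[Type, Callable[[List[Any]], Any]] = {}


def register(elem_type: Type) -> Callable:
    """Decorator that registers a collate function for `elem_type`."""
    def decorator(fn: Callable[[List[Any]], Any]) -> Callable[[List[Any]], Any]:
        collate_fn_map[elem_type] = fn
        return fn
    return decorator


def collate(batch: List[Any]) -> Any:
    """Collate a batch by dispatching on the type of its first element."""
    elem = batch[0]
    # Exact type match first ...
    if type(elem) in collate_fn_map:
        return collate_fn_map[type(elem)](batch)
    # ... then fall back to the first registered base class that matches.
    for elem_type, fn in collate_fn_map.items():
        if isinstance(elem, elem_type):
            return fn(batch)
    raise TypeError(f"no collate function registered for {type(elem).__name__}")


class Point(tuple):
    """Toy element type standing in for a custom tensor class such as PNTensor."""


@register(int)
def collate_ints(batch):
    return sum(batch)  # toy stand-in for stacking numbers into a tensor


@register(tuple)
def collate_tuples(batch):
    # Transpose a batch of tuples into a tuple of lists.
    return tuple(list(column) for column in zip(*batch))
```

Here collate([1, 2, 3]) uses the int entry directly, while a batch of Point objects has no exact entry and falls back to the tuple entry through isinstance, which is what lets subclasses reuse a registration.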

Alternatives

Alternatively, we could export default_collate_fn_map from torch.utils.data to avoid accessing a protected sub-package.

Additional context

The following code demonstrates how we use a registry to extend PyTorch functions:

from functools import wraps
from typing import Callable, Optional

import torch

from ..registry import Registry  # our own Registry class, shown further below


class TorchFuncRegistry(Registry):
    """
    `TorchFuncRegistry` for extending PyTorch Tensor.
    """

    def implement(self, torch_function: Callable) -> Callable:
        r"""
        Register an implementation for a torch function.

        Args:
            torch_function: The torch function to implement.

        Returns:
            function: A decorator that registers the implementation and returns it unchanged.

        Raises:
            ValueError: If the torch function is already registered and `TorchFuncRegistry.override=False`.

        Examples:
        ```python
        >>> import torch
        >>> registry = TorchFuncRegistry("test")
        >>> @registry.implement(torch.mean)  # pylint: disable=E1101
        ... def mean(input):
        ...     return input.mean()
        >>> registry  # doctest: +ELLIPSIS
        TorchFuncRegistry(
          (<built-in method mean of type object at ...>): <function mean at ...>
        )

        ```
        """

        if torch_function in self and not self.override:
            raise ValueError(f"Torch function {torch_function.__name__} already registered.")

        @wraps(self.register)
        def register(function):
            self.set(torch_function, function)
            return function

        return register


NestedTensorFunc = TorchFuncRegistry()


@NestedTensorFunc.implement(torch.mean)  # pylint: disable=E1101
def mean(
    input,  # pylint: disable=W0622
    dim: Optional[int] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
):
    return input.mean(dim=dim, keepdim=keepdim, dtype=dtype)


@NestedTensorFunc.implement(torch.cat)  # pylint: disable=E1101
def cat(tensors, dim: int = 0):
    if dim != 0:
        raise NotImplementedError(f"NestedTensor only supports cat when dim=0, but got {dim}")
    return NestedTensor([t for tensor in tensors for t in tensor.storage])

Registry is essentially a dict with a few additional methods. The following code shows how we define our registry:

Note that NestedDict is a subclass of dict and can be treated as a plain dict here.

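from copy import deepcopy
from functools import wraps
from typing import Any, Callable, Mapping, Optional, Union

from chanfig import NestedDict


class Registry(NestedDict):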
    """
    `Registry` for components.

    Notes:
        `Registry` inherits from [`NestedDict`](https://chanfig.danling.org/nested_dict/).

        Therefore, `Registry` comes in a nested structure by nature.
        You could create a sub-registry by simply assigning `registry.sub_registry = Registry`,
        and access through `registry.sub_registry.register()`.

    Examples:
    ```python
    >>> registry = Registry("test")
    >>> @registry.register
    ... @registry.register("Module1")
    ... class Module:
    ...     def __init__(self, a, b):
    ...         self.a = a
    ...         self.b = b
    >>> module = registry.register(Module, "Module2")
    >>> registry
    Registry(
      ('Module1'): <class 'danling.registry.Module'>
      ('Module'): <class 'danling.registry.Module'>
      ('Module2'): <class 'danling.registry.Module'>
    )
    >>> registry.lookup("Module")
    <class 'danling.registry.Module'>
    >>> config = {"module": {"name": "Module", "a": 1, "b": 2}}
    >>> # registry.register(Module)
    >>> module = registry.build(config["module"])
    >>> type(module)
    <class 'danling.registry.Module'>
    >>> module.a
    1
    >>> module.b
    2

    ```
    """

    override: bool = False

    def __init__(self, override: bool = False):
        super().__init__()
        self.setattr("override", override)

    def register(self, component: Optional[Callable] = None, name: Optional[str] = None) -> Callable:
        r"""
        Register a new component.

        Args:
            component: The component to register.
            name: The name of the component.

        Returns:
            component: The registered component.

        Raises:
            ValueError: If a component with the same name is already registered and `Registry.override=False`.

        Examples:
        ```python
        >>> registry = Registry("test")
        >>> @registry.register
        ... @registry.register("Module1")
        ... class Module:
        ...     def __init__(self, a, b):
        ...         self.a = a
        ...         self.b = b
        >>> module = registry.register(Module, "Module2")
        >>> registry
        Registry(
          ('Module1'): <class 'danling.registry.Module'>
          ('Module'): <class 'danling.registry.Module'>
          ('Module2'): <class 'danling.registry.Module'>
        )

        ```
        """

        # Registry.register(component, name)
        if name is not None:
            if name in self and not self.override:
                raise ValueError(f"Component with name {name} already registered.")
            self.set(name, component)
            return component

        @wraps(self.register)
        def register(component, name=None):
            if name is None:
                name = component.__name__
            # Check for duplicates on every path, including the decorator forms.
            if name in self and not self.override:
                raise ValueError(f"Component with name {name} already registered.")
            self.set(name, component)
            return component

        # @Registry.register
        if callable(component) and name is None:
            return register(component)

        # @Registry.register("name")
        return lambda x: register(x, component)

    def lookup(self, name: str) -> Any:
        r"""
        Lookup for a component.

        Args:
            name: The name of the component.

        Returns:
            (Any): The component.

        Raises:
            KeyError: If the component is not registered.

        Examples:
        ```python
        >>> registry = Registry("test")
        >>> @registry.register
        ... class Module:
        ...     def __init__(self, a, b):
        ...         self.a = a
        ...         self.b = b
        >>> registry.lookup("Module")
        <class 'danling.registry.Module'>

        ```
        """

        return self.get(name)

    def build(self, name: Union[str, Mapping], *args, **kwargs) -> Any:
        r"""
        Build a component.

        Args:
            name (str | Mapping):
                If it's a `Mapping`, it must contain `"name"` as a key; the remaining items are treated as `**kwargs`.
                Note that values in `kwargs` override values in `name` if it's a `Mapping`.
            *args: The arguments to pass to the component.
            **kwargs: The keyword arguments to pass to the component.

        Returns:
            (Any): The built component.

        Raises:
            KeyError: If the component is not registered.

        Examples:
        ```python
        >>> registry = Registry("test")
        >>> @registry.register
        ... class Module:
        ...     def __init__(self, a, b):
        ...         self.a = a
        ...         self.b = b
        >>> config = {"module": {"name": "Module", "a": 1, "b": 2}}
        >>> # registry.register(Module)
        >>> module = registry.build(**config["module"])
        >>> type(module)
        <class 'danling.registry.Module'>
        >>> module.a
        1
        >>> module.b
        2
        >>> module = registry.build(config["module"], a=2)
        >>> module.a
        2

        ```
        """

        if isinstance(name, Mapping):
            name = deepcopy(name)
            name, kwargs = name.pop("name"), dict(name, **kwargs)  # type: ignore
        return self.get(name)(*args, **kwargs)  # type: ignore

    def __wrapped__(self, *args, **kwargs):
        pass

cc @ssnl @VitalyFedyunin @ejguan @NivekT @dzhulgakov

ngimel added the module: dataloader, enhancement, and triaged labels on Mar 24, 2023
@dzhulgakov
Collaborator

The proposal looks reasonable; please feel free to submit a PR.

I don't think you need the complexity of a Registry class here, as you're not instantiating long-lived components. Also, bringing in such a utility class would require considering other use cases within PyTorch. So maybe just do something like @data.register_default_collate_for and use a regular dict underneath.

@ZhiyuanChen
Contributor Author

ZhiyuanChen commented Mar 27, 2023

Sure thing, I'll submit one this week.

Could I ask your opinion on using it for other utilities in PyTorch? I think bringing in a Registry could also be beneficial for extending PyTorch Tensor.

Currently, the documentation defines __torch_function__ as follows (in its ScalarTensor example):

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
        args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
        return func(*args, **kwargs)
    return HANDLED_FUNCTIONS[func](*args, **kwargs)

It would be simpler if this used a similar registry mechanism.
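As a rough, torch-free illustration of that idea, the sketch below replaces the hand-maintained HANDLED_FUNCTIONS dict with a small registry object whose implements decorator fills it, plus a dispatch hook that falls back to unwrapping the arguments when no override is registered. FunctionRegistry, Wrapped and dispatch are hypothetical names, and the built-ins sum and max merely stand in for torch functions; in real code this logic would sit inside __torch_function__.

```python
from typing import Any, Callable


class FunctionRegistry(dict):
    """Maps an original function to its override; a plain dict underneath."""

    def implements(self, original: Callable) -> Callable:
        """Return a decorator that registers an override for `original`."""
        def decorator(override: Callable) -> Callable:
            if original in self:
                raise ValueError(f"{original.__name__} already has an override")
            self[original] = override
            return override
        return decorator


HANDLED_FUNCTIONS = FunctionRegistry()


class Wrapped:
    """Toy wrapper type standing in for a Tensor subclass such as ScalarTensor."""

    def __init__(self, values):
        self.values = list(values)

    @classmethod
    def dispatch(cls, func: Callable, *args: Any, **kwargs: Any) -> Any:
        # Plays the role of __torch_function__: use a registered override if
        # there is one, otherwise unwrap the arguments and call `func` itself.
        override = HANDLED_FUNCTIONS.get(func)
        if override is not None:
            return override(*args, **kwargs)
        args = tuple(a.values if isinstance(a, cls) else a for a in args)
        return func(*args, **kwargs)


@HANDLED_FUNCTIONS.implements(sum)
def wrapped_sum(w: Wrapped, start: int = 0) -> Wrapped:
    # Override for the built-in `sum` that keeps the result wrapped.
    return Wrapped([sum(w.values, start)])
```

With this, Wrapped.dispatch(sum, Wrapped([3, 1, 2])) goes through the registered override, while Wrapped.dispatch(max, ...) has no entry and simply unwraps. The duplicate check in implements is the main thing the registry adds over a bare dict.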

@vadimkantorov
Contributor

vadimkantorov commented Apr 3, 2023

IMO, it would be better to have a non-global Collater class whose behavior can be modified locally through its instance methods.

Calling this register function will require modifying user code anyway, so it's a good moment to propose a more local and evolvable pattern (a Collater class).

Having a proper Collater class is better, even if a global singleton instance is supported too.

@ZhiyuanChen
Copy link
Contributor Author

IMO, it would be better to have a non-global Collater class that lets modify its behavior locally by operating its intance methods

Could you please elaborate on the Collater class? I'm not sure I understand it correctly.

@ZhiyuanChen
Contributor Author

ZhiyuanChen commented Apr 4, 2023

Why not introduce a new magic method, perhaps named __torch_collate__, so that users don't have to touch torch.utils.data at all?
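A minimal sketch of how such a hook might be consulted, assuming a hypothetical protocol in which an element type can define a __torch_collate__ classmethod that receives the whole batch. Nothing like this exists in PyTorch today; collate and Span are illustrative names, and the fallback returns the batch as a plain list instead of calling default_collate.

```python
from typing import Any, List


def collate(batch: List[Any]) -> Any:
    """Prefer a type's own (hypothetical) __torch_collate__ hook, else fall back."""
    hook = getattr(type(batch[0]), "__torch_collate__", None)
    if hook is not None:
        return hook(batch)
    return list(batch)  # placeholder for default_collate-style behaviour


class Span:
    """Toy element type that knows how to batch itself."""

    def __init__(self, start: int, end: int):
        self.start, self.end = start, end

    @classmethod
    def __torch_collate__(cls, batch: List["Span"]) -> dict:
        # Batch a list of spans into two parallel lists.
        return {"start": [s.start for s in batch], "end": [s.end for s in batch]}
```

The behaviour travels with the type, so no registration call is needed; the flip side is that it cannot easily be changed per DataLoader.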

@vadimkantorov
Contributor

vadimkantorov commented Apr 4, 2023

Old related discussions: #33181. For nested tensor: #27617 #1512

By Collater, I mean something like this:

class BatchCollater:
  def __init__(self, **options):  # with args modifying behavior
    # the registry and other configuration would be contained in local fields, not globally
    # self.registry = ...
    pass
  # methods modifying behavior
  def __call__(self, batch):
    # do the actual collation, using instance fields for configuration lookup
    pass

Then this could be set up like:

collate_fn = BatchCollater()
# some calls collate_fn.set_up_this_or_that_field_collation()
data_loader = DataLoader(..., collate_fn=collate_fn)

and the existing default_collate would also be an instance of this class.
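A runnable, torch-free sketch of the BatchCollater idea above, assuming a per-instance registry dict and a hypothetical register method (only the class name comes from the comment; everything else is illustrative). Two instances can then batch the same type differently without touching global state.

```python
from typing import Any, Callable, Dict, List, Optional, Type


class BatchCollater:
    """Collate function whose type -> handler table lives on the instance."""

    def __init__(self, registry: Optional[Dict[Type, Callable]] = None):
        # Copy so that instances never share (or mutate) each other's tables.
        self.registry: Dict[Type, Callable] = dict(registry or {})

    def register(self, elem_type: Type) -> Callable:
        """Return a decorator that registers a handler on this instance only."""
        def decorator(fn: Callable) -> Callable:
            self.registry[elem_type] = fn
            return fn
        return decorator

    def __call__(self, batch: List[Any]) -> Any:
        for elem_type, fn in self.registry.items():
            if isinstance(batch[0], elem_type):
                return fn(batch)
        return list(batch)  # fallback: leave the batch as a plain list


# Hypothetical global instance, as suggested for default_collate.
default_collate = BatchCollater({int: sum})

# A local variant that starts from the defaults and adds its own rule.
padded = BatchCollater(default_collate.registry)


@padded.register(str)
def pad_strings(batch: List[str]) -> List[str]:
    width = max(len(s) for s in batch)
    return [s.ljust(width, ".") for s in batch]
```

Because DataLoader accepts any callable as collate_fn, an instance like padded could be passed directly, and registering on it leaves default_collate unchanged.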

@ZhiyuanChen
Contributor Author

I'll track this approach in a new PR (#98575).

@ZhiyuanChen
Contributor Author

@vadimkantorov Could you please take a look at #98575 to see if it's what you desired?

@vadimkantorov
Contributor

Kind of, yes! IMO having these things as instance methods is better because different datasets may need different collater objects. Maybe the global instance could directly be called default_collate.

Going further, I'm not the best person to give feedback: since those old times (2017-2019) I have changed my mind on whether collation should be "automatic", because different kinds of input padding may be more convenient in different situations, and I now prefer having it explicitly as code over having to "configure" something.

So I would say the most important thing here is getting the API right. For example, if it works with dicts/tuples as input, then maybe one should configure collation per input key rather than per data type. IMO getting this kind of design right is important, but I don't have very concrete ideas myself.
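To make the per-key alternative concrete, here is a small sketch, assuming each sample is a dict and the user supplies one collate function per key. KeyedCollater and pad_to_longest are hypothetical, and plain lists stand in for tensors.

```python
from typing import Any, Callable, Dict, List


def pad_to_longest(seqs: List[List[int]], pad: int = 0) -> List[List[int]]:
    """Right-pad variable-length sequences to the length of the longest one."""
    width = max(len(s) for s in seqs)
    return [s + [pad] * (width - len(s)) for s in seqs]


class KeyedCollater:
    """Collates a batch of dicts with a separately configured function per key."""

    def __init__(self, per_key: Dict[str, Callable[[List[Any]], Any]],
                 default: Callable[[List[Any]], Any] = list):
        self.per_key = per_key
        self.default = default

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Gather each key's values across the batch, then apply that key's rule.
        return {
            key: self.per_key.get(key, self.default)([sample[key] for sample in batch])
            for key in batch[0]
        }


collate_fn = KeyedCollater({"tokens": pad_to_longest})
```

Since the configuration is attached to keys rather than types, two fields of the same type could be padded and batched differently, which a purely type-keyed registry cannot express.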

@ZhiyuanChen
Contributor Author

@ssnl @VitalyFedyunin @ejguan @NivekT @dzhulgakov
Can I ask your opinions on the approaches?

@NivekT
Contributor

NivekT commented Apr 7, 2023

Thanks for opening this. I commented in #98194

5 participants