-
Notifications
You must be signed in to change notification settings - Fork 24.7k
Description
🚀 The feature, motivation and pitch
Previous improvements to `collate_fn` have made it possible to support custom types in `collate_fn`.
In this enhancement proposal, we would like to add a registry to take this one step further and make extension even smoother.
Currently, to extend `collate_fn`, we need the following code:
from torch.utils import data
def collate_nested_tensor_fn(batch, *, collate_fn_map=None):
    """Collate a batch of ``PNTensor`` samples into a single ``NestedTensor``.

    ``torch.utils.data`` invokes entries of its collate-function map as
    ``fn(batch, collate_fn_map=collate_fn_map)``. The keyword-only
    ``collate_fn_map`` parameter must therefore be accepted even though it is
    unused here.

    Args:
        batch: The list of samples to collate.
        collate_fn_map: The type-to-collate-function mapping forwarded by
            ``torch.utils.data``; unused.

    Returns:
        A ``NestedTensor`` built from ``batch``.
    """
    # The return value *is* the collated batch; without it the loader gets None.
    return NestedTensor(batch)


# Reaches into the protected ``_utils`` sub-package to register the function for
# ``PNTensor`` elements; this is the friction the proposal aims to remove.
data._utils.collate.default_collate_fn_map.update({PNTensor: collate_nested_tensor_fn})
However, `default_collate_fn_map` is not exported in `torch.utils.data` and lives under a protected sub-package (`torch.utils.data._utils`).
The process would be much smoother if we had a registry for `collate_fn`s, so that the process would become:
from torch.utils import data
# Proposed API: ``data.collate_fns`` is the registry this proposal asks for. The
# decorator would replace the manual ``default_collate_fn_map.update(...)`` call.
@data.collate_fns.register(PNTensor)
def collate_nested_tensor_fn(batch, *, collate_fn_map=None):
    """Collate a batch of ``PNTensor`` samples into a single ``NestedTensor``.

    Keeps the keyword-only ``collate_fn_map`` parameter so the function stays
    compatible with how ``torch.utils.data`` calls entries of its
    collate-function map.

    Returns:
        A ``NestedTensor`` built from ``batch``.
    """
    return NestedTensor(batch)
Alternatives
Alternatively, we could export `default_collate_fn_map` in `torch.utils.data` to avoid accessing a protected sub-package.
Additional context
The following code demonstrates how we use a registry to extend PyTorch functions:
from functools import wraps
from typing import Callable, Optional

import torch

from ..registry import Registry
class TorchFuncRegistry(Registry):
    """
    `TorchFuncRegistry` for extending PyTorch Tensor.

    Entries are keyed by the torch function object itself (e.g. `torch.mean`).
    Each value is the implementation registered for that function.
    """

    def implement(self, torch_function: Callable) -> Callable:
        r"""
        Register an implementation for a torch function.

        Args:
            torch_function: The torch function to implement.

        Returns:
            A decorator. It stores the decorated function as the implementation of
            `torch_function` and returns that function unchanged.

        Raises:
            ValueError: If `torch_function` is already registered and `TorchFuncRegistry.override=False`.

        Examples:
        ```python
        >>> import torch
        >>> registry = TorchFuncRegistry()
        >>> @registry.implement(torch.mean)  # pylint: disable=E1101
        ... def mean(input):
        ...     return input.mean()
        >>> registry  # doctest: +ELLIPSIS
        TorchFuncRegistry(
          (<built-in method mean of type object at ...>): <function mean at ...>
        )
        ```
        """
        if torch_function in self and not self.override:
            # Not every callable has ``__name__`` (e.g. ``functools.partial``).
            # Fall back to ``repr`` so the intended ValueError is still raised.
            label = getattr(torch_function, "__name__", repr(torch_function))
            raise ValueError(f"Torch function {label} already registered.")

        # A plain closure (no ``functools.wraps``), so introspection reports this
        # decorator's own signature rather than ``Registry.register``'s.
        def register(function: Callable) -> Callable:
            self.set(torch_function, function)
            return function

        return register
# Registry of NestedTensor implementations, keyed by the torch function each one replaces.
NestedTensorFunc = TorchFuncRegistry()


@NestedTensorFunc.implement(torch.mean)  # pylint: disable=E1101
def mean(
    input,  # pylint: disable=W0622
    dim: Optional[int] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
):
    """Implementation registered for ``torch.mean``; delegates to ``input.mean``."""
    reduction_options = {"dim": dim, "keepdim": keepdim, "dtype": dtype}
    return input.mean(**reduction_options)


@NestedTensorFunc.implement(torch.cat)  # pylint: disable=E1101
def cat(tensors, dim: int = 0):
    """Implementation registered for ``torch.cat``; only ``dim=0`` is supported.

    Builds a new ``NestedTensor`` from the ``storage`` items of every input, in order.
    """
    if dim != 0:
        raise NotImplementedError(f"NestedTensor only supports cat when dim=0, but got {dim}")
    pieces = []
    for nested in tensors:
        pieces.extend(nested.storage)
    return NestedTensor(pieces)
`Registry` is essentially a `dict` with some additional methods. The following code demonstrates how we define our registry.
Note that `NestedDict` is a subclass of `dict` and, for the purposes of this proposal, can be considered identical to `dict`.
class Registry(NestedDict):
    """
    `Registry` for components.

    Notes:
        `Registry` inherits from [`NestedDict`](https://chanfig.danling.org/nested_dict/).
        Therefore, `Registry` comes in a nested structure by nature.
        You could create a sub-registry by simply calling `registry.sub_registry = Registry`,
        and access through `registry.sub_registry.register()`.

    Examples:
    ```python
    >>> registry = Registry()
    >>> @registry.register
    ... @registry.register("Module1")
    ... class Module:
    ...     def __init__(self, a, b):
    ...         self.a = a
    ...         self.b = b
    >>> module = registry.register(Module, "Module2")
    >>> registry
    Registry(
      ('Module1'): <class 'danling.registry.Module'>
      ('Module'): <class 'danling.registry.Module'>
      ('Module2'): <class 'danling.registry.Module'>
    )
    >>> registry.lookup("Module")
    <class 'danling.registry.Module'>
    >>> config = {"module": {"name": "Module", "a": 1, "b": 2}}
    >>> # registry.register(Module)
    >>> module = registry.build(config["module"])
    >>> type(module)
    <class 'danling.registry.Module'>
    >>> module.a
    1
    >>> module.b
    2
    ```
    """

    # When False, registering a name that is already present raises ValueError.
    override: bool = False

    def __init__(self, override: bool = False):
        super().__init__()
        # NOTE(review): presumably ``NestedDict.setattr`` stores a real attribute
        # rather than a registry entry, so ``override`` never shows up as a component.
        self.setattr("override", override)

    def register(self, component: Optional[Callable] = None, name: Optional[str] = None) -> Callable:
        r"""
        Register a new component.

        Supported call styles:

        - `registry.register(component, name)`: registers immediately and returns `component`.
        - `@registry.register`: registers the decorated object under its `__name__`.
        - `@registry.register("Name")` or `@registry.register(name="Name")`: returns a
          decorator that registers under `"Name"`.
        - `@registry.register()`: returns a decorator that registers under `__name__`.

        Args:
            component: The component to register. When it is not callable and `name`
                is omitted, it is treated as the name (decorator-factory form).
            name: The name of the component.

        Returns:
            component: The registered component, or a decorator when no component is given yet.

        Raises:
            ValueError: If the component with the same name already registered and `Registry.override=False`.

        Examples:
        ```python
        >>> registry = Registry()
        >>> @registry.register
        ... @registry.register("Module1")
        ... class Module:
        ...     def __init__(self, a, b):
        ...         self.a = a
        ...         self.b = b
        >>> module = registry.register(Module, "Module2")
        >>> registry
        Registry(
          ('Module1'): <class 'danling.registry.Module'>
          ('Module'): <class 'danling.registry.Module'>
          ('Module2'): <class 'danling.registry.Module'>
        )
        ```
        """

        def _register(obj, name=None):
            # Store ``obj`` under ``name`` (default: its ``__name__``) and return it unchanged.
            if name is None:
                name = obj.__name__
            # Checked here, once the final name is known, so every call style honours ``override``.
            if name in self and not self.override:
                raise ValueError(f"Component with name {name} already registered.")
            self.set(name, obj)
            return obj

        # registry.register(component, name)
        if component is not None and name is not None:
            return _register(component, name)
        # @registry.register
        if callable(component):
            return _register(component)
        # @registry.register("Name") / @registry.register(name="Name") / @registry.register()
        target_name = name if component is None else component
        return lambda obj: _register(obj, target_name)

    def lookup(self, name: str) -> Any:
        r"""
        Lookup for a component.

        Args:
            name: The name the component was registered under.

        Returns:
            (Any): The component.

        Raises:
            KeyError: If the component is not registered.

        Examples:
        ```python
        >>> registry = Registry()
        >>> @registry.register
        ... class Module:
        ...     def __init__(self, a, b):
        ...         self.a = a
        ...         self.b = b
        >>> registry.lookup("Module")
        <class 'danling.registry.Module'>
        ```
        """
        # NOTE(review): relies on ``NestedDict.get`` raising KeyError for unknown
        # names (a plain ``dict.get`` would return None instead) — confirm.
        return self.get(name)

    def build(self, name: Union[str, Mapping], *args, **kwargs) -> Any:
        r"""
        Build a component.

        Args:
            name (str | Mapping):
                If it is a `Mapping`, it must contain `"name"` as a member; the rest will be treated as `**kwargs`.
                Note that values in `kwargs` will override values in `name` if it is a `Mapping`.
            *args: The arguments to pass to the component.
            **kwargs: The keyword arguments to pass to the component.

        Returns:
            (Any): The result of calling the registered component.

        Raises:
            KeyError: If the component is not registered, or if a `Mapping` `name` has no `"name"` key.

        Examples:
        ```python
        >>> registry = Registry()
        >>> @registry.register
        ... class Module:
        ...     def __init__(self, a, b):
        ...         self.a = a
        ...         self.b = b
        >>> config = {"module": {"name": "Module", "a": 1, "b": 2}}
        >>> # registry.register(Module)
        >>> module = registry.build(**config["module"])
        >>> type(module)
        <class 'danling.registry.Module'>
        >>> module.a
        1
        >>> module.b
        2
        >>> module = registry.build(config["module"], a=2)
        >>> module.a
        2
        ```
        """
        if isinstance(name, Mapping):
            # Copy into a plain dict first so read-only mappings are accepted.
            # Then deep-copy so the built component cannot mutate the caller's config.
            config = deepcopy(dict(name))
            name = config.pop("name")
            # Explicit keyword arguments win over values from the mapping.
            kwargs = {**config, **kwargs}
        return self.get(name)(*args, **kwargs)  # type: ignore

    def __wrapped__(self, *args, **kwargs):
        # NOTE(review): presumably defined so introspection tools that probe
        # ``__wrapped__`` (e.g. ``inspect.unwrap`` during doctest collection) find a
        # real attribute. Without it, NestedDict's attribute-to-entry access would
        # handle the lookup. Confirm before removing.
        pass