Function Registry for extending collate_fn #97498
Comments
The proposal looks reasonable; please feel free to submit a PR. I don't think you need
Sure thing, I'll submit one this week.
Can I ask your opinion on other utilities in PyTorch? I think bringing Currently, the

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
            args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
            return func(*args, **kwargs)
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

It would be easier if it uses a similar Registry mechanism.
IMO, it would be better to have a non-global Collater class that lets users modify its behavior locally through its instance methods. Calling this register function will require modifying the user code anyway, so it's a good moment to propose a more local and evolvable pattern (a Collater class). Having a proper Collater class is better, even if some global singleton instance is supported too.
Could you please elaborate more on the Collater class? I'm not sure if I understand it correctly.
Why not introduce a new magic function, maybe named as |
Old related discussions: #33181. For nested tensor: #27617 #1512

Collater meaning something like this:

    class BatchCollater:
        def __init__(self, *args, **kwargs):  # with args modifying behavior
            # the registry and other configuration would be contained in local fields and not globally
            # self.registry = ...
            pass

        # methods modifying behavior

        def __call__(self, *args):
            # do the actual collation, reading configuration from instance fields
            pass

then this could be set up like:

    collate_fn = BatchCollater()
    # some calls collate_fn.set_up_this_or_that_field_collation()
    data_loader = DataLoader(..., collate_fn=collate_fn)

and existing
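As a rough illustration of the instance-local idea sketched in this comment, here is a minimal pure-Python version. The `register` method, the MRO-based lookup, and the list fallback are assumptions made for the example; they are not part of any PyTorch API or of the linked PRs.

```python
from typing import Any, Callable, Dict, Sequence


class BatchCollater:
    """Hypothetical collater whose type registry lives on the instance."""

    def __init__(self) -> None:
        # Local registry: element type -> function merging a list of such elements.
        self.registry: Dict[type, Callable[[Sequence[Any]], Any]] = {}

    def register(self, elem_type: type, fn: Callable[[Sequence[Any]], Any]) -> None:
        # Hypothetical method name; it only mutates this instance.
        self.registry[elem_type] = fn

    def __call__(self, batch: Sequence[Any]) -> Any:
        elem = batch[0]
        # Walk the MRO so a handler for a base class also covers subclasses.
        for klass in type(elem).__mro__:
            if klass in self.registry:
                return self.registry[klass](batch)
        # Assumed fallback: return the samples as a plain list.
        return list(batch)


# Two datasets can use differently configured collaters side by side.
text_collate = BatchCollater()
text_collate.register(str, lambda batch: " ".join(batch))

number_collate = BatchCollater()
number_collate.register(int, lambda batch: sum(batch))
```

Because `DataLoader` accepts any callable as `collate_fn`, an instance like `text_collate` could be passed directly, and registering a handler on one instance leaves the other untouched.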
I'll track this approach under a new PR (#98575)
@vadimkantorov Could you please take a look at #98575 to see if it's what you desired?
Kind of, yes! IMO having these things as instance methods is better because different datasets may need different collater objects. Maybe the global instance could directly be called

Going further, I'm not the best person to give feedback, because since those old times (2017-2019) I have changed my mind on whether collation should be "automatic": different kinds of input padding may be more convenient in different situations, and having to "configure" something can be less clear than having it explicitly as code. So I would say the most important thing here is getting the API right. E.g. if it's working with dicts/tuples as input, then maybe one should configure collation per input key rather than per data type. So IMO getting this kind of design right is important, but I myself don't have very concrete ideas.
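To make the "per input key rather than per data type" suggestion concrete, here is one possible shape for it in plain Python. `KeyedCollater`, `pad_sequences`, and `stack_as_list` are hypothetical names invented for this sketch, and the right-padding behavior is only one of the padding choices the comment alludes to.

```python
from typing import Any, Callable, Dict, List, Mapping


def stack_as_list(values: List[Any]) -> List[Any]:
    # Assumed default: keep the per-sample values as a list.
    return list(values)


def pad_sequences(values: List[List[int]], pad_value: int = 0) -> List[List[int]]:
    # Right-pad every sequence to the length of the longest one.
    max_len = max(len(v) for v in values)
    return [v + [pad_value] * (max_len - len(v)) for v in values]


class KeyedCollater:
    """Hypothetical collater configured per dict key, not per element type."""

    def __init__(self, per_key: Mapping[str, Callable[[List[Any]], Any]]) -> None:
        self.per_key = dict(per_key)

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        for key in batch[0]:
            values = [sample[key] for sample in batch]
            # Keys without an explicit rule fall back to the default.
            out[key] = self.per_key.get(key, stack_as_list)(values)
        return out


collate = KeyedCollater({"tokens": pad_sequences})
batch = collate([
    {"tokens": [1, 2, 3], "label": 0},
    {"tokens": [4], "label": 1},
])
```

Here two fields of the same Python type could still be collated differently, which a purely type-keyed registry cannot express.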
@ssnl @VitalyFedyunin @ejguan @NivekT @dzhulgakov
Thanks for opening this. I commented in #98194
🚀 The feature, motivation and pitch
Previous improvements to `collate_fn` have allowed custom types in `collate_fn`.
In this Enhancement Proposal, we would like to add a Registry to take this one step further and make extension even smoother.
Currently, to extend `collate_fn`, we need to have the following code.
However, `default_collate_fn_map` is not exported in `torch.utils.data` and is under a protected sub-package.
The process would be much smoother if we had a registry for `collate_fn`s, so that the process would become:
Alternatives
Alternatively, we could export `default_collate_fn_map` in `torch.utils.data` to avoid accessing a protected sub-package.
Additional context
The following code demonstrates how we use the registry in extending PyTorch functions:
Registry is basically a dict with some additional methods. The following code demonstrates how we define our registry:
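The author's registry code is not reproduced here. As a stand-in, the following sketches what "a dict with some additional methods" could look like. It subclasses plain `dict`, and the `register` decorator and `lookup` helper are assumed names, not the author's actual implementation.

```python
from typing import Any, Callable, Optional


class Registry(dict):
    """Hypothetical registry: a dict plus a decorator-style ``register``."""

    def register(self, key: Any) -> Callable[[Callable], Callable]:
        def decorator(fn: Callable) -> Callable:
            self[key] = fn
            return fn  # returned unchanged, so the function stays usable by name

        return decorator

    def lookup(self, obj: Any) -> Optional[Callable]:
        # Resolve by walking the object's MRO; None if nothing is registered.
        for klass in type(obj).__mro__:
            if klass in self:
                return self[klass]
        return None


COLLATE_FNS = Registry()


@COLLATE_FNS.register(str)
def collate_str(batch):
    return ", ".join(batch)
```

Since the registry is still a `dict`, existing code that reads a type-to-function mapping could consume it unchanged, while new handlers are added through the decorator.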
Note that `NestedDict` is a subclass of `dict`, and can be considered identical to `dict`.
cc @ssnl @VitalyFedyunin @ejguan @NivekT @dzhulgakov