
nn.Module.forward signature with **kwargs #23732

Open
prokotg opened this issue Aug 2, 2019 · 5 comments
Labels
enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: checkpoint Related to torch.utils.checkpoint module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@prokotg
Contributor

prokotg commented Aug 2, 2019

Hi there!

The other day I was working on improving the utils.checkpoint module and found out that there is currently no way to pass **kwargs to a checkpointed function. This is probably to stay compliant with the nn.Module.forward method.

Now, I could not think of any argument against having **kwargs in the forward method, and many people already override this method with **kwargs because it is convenient :) I was about to make the change and create a pull request, but before doing so I decided to be smart and ask you guys whether there is some obvious reason why we should not do it. Otherwise, I am more than happy to start working on it.

Thanks :- )

@fmassa
Member

fmassa commented Aug 2, 2019

Adding a **kwargs argument to the base forward method won't change anything in the functionality, I believe, given that __call__ already takes kwargs and passes them to forward, see

def __call__(self, *input, **kwargs):

That being said, I don't think there is any problem with exposing the kwargs in the base forward.
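fmassa's point can be illustrated with a minimal sketch. The `Module` and `Scale` classes below are hypothetical stand-ins (not the actual torch.nn source): since `__call__` forwards `**kwargs` to `forward`, keyword arguments already reach any overriding `forward`, so a `**kwargs` in the base signature would be purely cosmetic.

```python
# Sketch of how nn.Module.__call__ already forwards **kwargs to forward.
# Hypothetical stand-in classes, not the real torch source.
class Module:
    def __call__(self, *input, **kwargs):
        # keyword arguments pass straight through to forward
        return self.forward(*input, **kwargs)

    def forward(self, *input, **kwargs):
        raise NotImplementedError


class Scale(Module):
    def forward(self, x, factor=1):
        return x * factor


m = Scale()
print(m(3, factor=2))  # the keyword reaches forward via __call__
```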

@vincentqb vincentqb added module: checkpoint Related to torch.utils.checkpoint enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: nn Related to torch.nn triage review labels Aug 2, 2019
@ssnl
Collaborator

ssnl commented Aug 2, 2019

There is nothing that prevents it, but you gain nothing from this change either, because forward is always overridden.

@prokotg
Contributor Author

prokotg commented Aug 2, 2019

True, no gain in functionality, just a change for consistency, I guess. I am going to try to change the checkpoint module to take **kwargs into account, though.

Thank you :)

@vincentqb vincentqb added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels Aug 2, 2019
@saareliad

saareliad commented Aug 12, 2019

Currently this can be worked around if programmers write their forward functions so that checkpoint can pass *args to them positionally (in order) instead of using kwargs.

OR

Pass a dict function_kwargs to checkpoint, which would end up calling:

CheckpointFunction.apply(function, preserve, *args, **function_kwargs)

(and change CheckpointFunction accordingly) instead of:

CheckpointFunction.apply(function, preserve, *args)

For example, function_kwargs could be added as a kwarg to the checkpoint function.

IMO should be easy to implement in a backward compatible way.
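A rough sketch of what the proposed API could look like, purely hypothetical: the `checkpoint` and `layer` functions below are simplified stand-ins, and the real torch.utils.checkpoint does autograd bookkeeping (saving inputs, recomputation) that is omitted here. The only point shown is how a `function_kwargs` parameter would be threaded through to the call.

```python
# Hypothetical sketch of the function_kwargs proposal, NOT the real
# torch.utils.checkpoint: the autograd save/recompute machinery is omitted.
def checkpoint(function, *args, preserve_rng_state=True, function_kwargs=None):
    function_kwargs = function_kwargs or {}
    # stands in for: CheckpointFunction.apply(function, preserve, *args, **function_kwargs)
    return function(*args, **function_kwargs)


def layer(x, y, scale=1):
    return (x + y) * scale


print(checkpoint(layer, 2, 3, function_kwargs={"scale": 10}))  # (2 + 3) * 10
```

Existing callers that pass only positional args are unaffected, which is why the change could be backward compatible.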

@prokotg
Contributor Author

prokotg commented Aug 12, 2019

@saareliad right, that's what the documentation says. The forward function should know what to do with *args, so we have to pass them in some order, and indeed the function has to know what to do with the input tuple. However, your forward method might require, for example, some keyword parameters; you can have multiple of them, and passing them without keywords can get messy.

Besides, the current solution forces developers to rewrite their existing, keyword-based models to be keyword-argument-free, and since PyTorch aims to be Python-first, I think this should not be necessary.
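One workaround that needs no API change is to bind the keyword arguments with functools.partial and hand the resulting positional-only callable to checkpoint. Sketched below with a hypothetical stand-in `checkpoint` that, like the real one, forwards only positional *args (the real one also handles recomputation for the backward pass):

```python
# Workaround: pre-bind keyword arguments so checkpoint only sees *args.
from functools import partial


def checkpoint(function, *args):
    # stand-in for torch.utils.checkpoint.checkpoint, which likewise
    # forwards only positional arguments to the checkpointed function
    return function(*args)


def layer(x, y, scale=1):
    return (x + y) * scale


bound = partial(layer, scale=10)   # keywords are bound up front
print(checkpoint(bound, 2, 3))     # checkpoint passes only positionals
```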

I followed the data-flow path and started changing some code to incorporate kwargs here

{(char*)"apply", (PyCFunction)THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr},

and in the function itself to unpack the **kwargs dict

PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)

and I stumbled upon #16940
