Add a feature "support multi params to call forward method"? #3055

Closed
invoker-bot opened this issue Sep 3, 2023 · 7 comments · Fixed by #3074

@invoker-bot
Contributor

🚀 Feature

I think in some cases we may need to pass multiple arguments to the model, for example
if the model has two or more inputs or carries some state (LSTM, etc.).
This is easy to implement. The relevant source code is here:
https://github.com/pytorch/ignite/blob/34a707e53785cf8a524589f33a570a7516fe064e/ignite/engine/__init__.py#L107C26-L107C26, and we only need to replace

output = model(x)

with

if isinstance(x, tuple):
    output = model(*x)
else:
    output = model(x)

Multiple arguments could then be passed through prepare_batch (by returning a tuple as x), and the change stays compatible with single-argument models.
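A minimal, framework-free sketch of the proposed dispatch, assuming `model` is any callable (a `torch.nn.Module` is called the same way) and that `prepare_batch` returns `(x, y)` where `x` may be a tuple of inputs. The names `call_model`, `TwoInputModel` and `prepare_two_inputs` are hypothetical and only illustrate the idea:

```python
from typing import Any, Callable, Tuple


def call_model(model: Callable[..., Any], x: Any) -> Any:
    # Proposed dispatch: unpack a tuple into positional arguments,
    # otherwise pass the single input unchanged.
    if isinstance(x, tuple):
        return model(*x)
    return model(x)


class TwoInputModel:
    # Hypothetical stand-in for a model whose forward takes two inputs,
    # e.g. a sequence and an initial state.
    def __call__(self, seq: list, state: int) -> int:
        return sum(seq) + state


def prepare_two_inputs(batch: Tuple[list, int, int]) -> Tuple[Tuple[list, int], int]:
    # Hypothetical prepare_batch: pack the model inputs into a tuple for x.
    seq, state, target = batch
    return (seq, state), target


x, y = prepare_two_inputs(([1, 2, 3], 10, 16))
print(call_model(TwoInputModel(), x))  # tuple input -> model(*x), prints 16
print(call_model(len, [1, 2, 3]))      # single input -> model(x), prints 3
```

Note that with this check, a model whose single input is genuinely a tuple would have it unpacked as well, so the dispatch rule is implicit rather than chosen by the user.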

@vfdev-5
Collaborator

vfdev-5 commented Sep 3, 2023

Thanks for the feature request @invoker-bot !

Usually, for such cases we suggest writing a custom train_step like:

def train_step(engine, batch):
    ...
    output = model(*x)
    ...

trainer = Engine(train_step)

The only problem I see with your proposal is that the if/else check would run on every training iteration, even though we technically know in advance which branch will be taken.

I wonder whether providing a new argument in create_supervised_trainer, as below, would help:

def create_supervised_trainer(model, ..., model_fn=lambda model, x: model(x)):
    def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        ...
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        output = model_fn(model, x)
        ...

    return Engine(update)

What do you think?
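To illustrate the idea, here is a simplified, framework-free stand-in for the proposed signature. It assumes the update step only needs `model`, `loss_fn`, `prepare_batch` and a `model_fn` hook; the real `create_supervised_trainer` also handles the optimizer, device placement and more. `make_update` is a hypothetical name. The point is that the caller picks the call style once, when building the step, so no per-iteration `isinstance` check is needed:

```python
from typing import Any, Callable, Sequence, Tuple


def make_update(
    model: Callable[..., Any],
    loss_fn: Callable[[Any, Any], Any],
    prepare_batch: Callable[[Sequence[Any]], Tuple[Any, Any]] = lambda batch: (batch[0], batch[1]),
    model_fn: Callable[[Callable[..., Any], Any], Any] = lambda model, x: model(x),
) -> Callable[[Any, Sequence[Any]], Any]:
    # Hypothetical, simplified update closure: no optimizer and no
    # backward pass, just prepare -> forward -> loss.
    def update(engine: Any, batch: Sequence[Any]) -> Any:
        x, y = prepare_batch(batch)
        output = model_fn(model, x)  # call style fixed once via model_fn
        return loss_fn(output, y)

    return update


def two_input_model(a: float, b: float) -> float:
    # Hypothetical model with two positional inputs.
    return a * b


def squared_error(pred: float, target: float) -> float:
    return (pred - target) ** 2


# Default model_fn: single input, model(x).
single = make_update(lambda x: 2 * x, squared_error)
# Custom model_fn: unpack a tuple of inputs, model(*x).
multi = make_update(two_input_model, squared_error, model_fn=lambda model, x: model(*x))

print(single(None, (3.0, 5.0)))         # (2*3 - 5)^2 = 1.0
print(multi(None, ((2.0, 4.0), 10.0)))  # (2*4 - 10)^2 = 4.0
```

Passing `None` as the engine works here only because the sketch ignores that argument; in Ignite the closure would be wrapped with `Engine(update)`.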

@invoker-bot
Contributor Author

Thank you for your reply. I think adding a parameter like model_fn is a good idea. If we write a custom update function, we also need to write another one for the evaluator. I think the great thing about ignite is that it takes care of this boilerplate work and keeps the code concise.
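The point about the evaluator can be sketched as follows: one `model_fn` is written once and handed to both a training step and an evaluation step, instead of duplicating a custom update function for each. `make_train_step`, `make_eval_step` and `unpack_inputs` are hypothetical helpers. Returning `(y_pred, y)` from the evaluation step mirrors the default output of Ignite's `create_supervised_evaluator`, but nothing else here is Ignite API:

```python
from typing import Any, Callable, Sequence, Tuple

ModelFn = Callable[[Callable[..., Any], Any], Any]


def unpack_inputs(model: Callable[..., Any], x: Tuple[Any, ...]) -> Any:
    # Shared model_fn: forward a tuple of inputs as positional arguments.
    return model(*x)


def make_train_step(
    model: Callable[..., Any], loss_fn: Callable[[Any, Any], Any], model_fn: ModelFn
) -> Callable[[Any, Sequence[Any]], Any]:
    # Hypothetical trainer step: returns the loss value.
    def step(engine: Any, batch: Sequence[Any]) -> Any:
        x, y = batch
        return loss_fn(model_fn(model, x), y)

    return step


def make_eval_step(
    model: Callable[..., Any], model_fn: ModelFn
) -> Callable[[Any, Sequence[Any]], Tuple[Any, Any]]:
    # Hypothetical evaluator step: returns (y_pred, y) for metrics.
    def step(engine: Any, batch: Sequence[Any]) -> Tuple[Any, Any]:
        x, y = batch
        return model_fn(model, x), y

    return step


def add_model(a: int, b: int) -> int:
    # Hypothetical two-input model.
    return a + b


train_step = make_train_step(add_model, lambda p, t: abs(p - t), unpack_inputs)
eval_step = make_eval_step(add_model, unpack_inputs)

batch = ((2, 3), 4)
print(train_step(None, batch))  # |5 - 4| = 1
print(eval_step(None, batch))   # (5, 4)
```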

@github-actions

github-actions bot commented Sep 3, 2023

Hey 👋, I've just created a thread for this issue on PyTorch-Ignite Discord where you can quickly talk to the community on the topic.

🤖 This comment was automatically posted by Discuss on Discord

@vfdev-5
Collaborator

vfdev-5 commented Sep 3, 2023

@invoker-bot sounds good. If you would like to help with coding this feature, you are very welcome!

@invoker-bot
Contributor Author

@vfdev-5 I'd like to. I will make a PR soon.

@vfdev-5
Collaborator

vfdev-5 commented Sep 26, 2023

Hey @invoker-bot will you be able to send a PR with the feature?

@invoker-bot
Contributor Author

Sorry, it took me some time to get familiar with the code and set up the development environment. I just sent a PR, please take a look.

vfdev-5 added a commit that referenced this issue Oct 3, 2023
…_trainer (#3055) (#3074)

* feat: Added model_fn to support non-standard model function in create_trainer (#3055)

* Added versionchanged docstring and tests for model_fn (#3074)

* fix style formatting

* Update test_create_supervised.py

* Apply suggestions from code review

---------

Co-authored-by: vfdev <vfdev.5@gmail.com>