Add a feature "support multi params to call forward method"? #3055
Thanks for the feature request @invoker-bot! Usually, for such cases we suggest writing a custom `train_step`:

```python
def train_step(engine, batch):
    ...
    output = model(*x)
    ...

trainer = Engine(train_step)
```

The only problem I see with your proposal is that on every training iteration we would be executing the if/else code, even though we technically know in advance which branch should be taken. I wonder if we could instead add a new argument to `create_supervised_trainer`:

```python
def create_supervised_trainer(model, ..., model_fn=lambda model, x: model(x)):

    def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        ...
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        output = model_fn(model, x)
        ...

    return Engine(update)
```

What do you think?
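To make the trade-off concrete, here is a minimal, framework-free sketch of the `model_fn` idea in plain Python. It assumes nothing from Ignite or PyTorch: `make_update`, `ToyTwoInputModel` and `default_prepare_batch` are hypothetical stand-ins for `create_supervised_trainer`, an `nn.Module` and `prepare_batch`, and the `engine` argument, loss and backward pass are omitted. The point it illustrates is that the choice between `model(x)` and `model(*x)` is made once, when the update function is built, so no if/else runs per iteration.

```python
from typing import Any, Callable, Sequence, Tuple


# Hypothetical stand-in for a model with two inputs.
class ToyTwoInputModel:
    def __call__(self, a: float, b: float) -> float:
        return a * b


# Hypothetical stand-in for prepare_batch: split a batch into (x, y).
def default_prepare_batch(batch: Sequence[Any]) -> Tuple[Any, Any]:
    x, y = batch
    return x, y


def make_update(
    model: Callable[..., Any],
    prepare_batch: Callable[[Sequence[Any]], Tuple[Any, Any]] = default_prepare_batch,
    model_fn: Callable[[Callable[..., Any], Any], Any] = lambda model, x: model(x),
) -> Callable[[Sequence[Any]], Tuple[Any, Any]]:
    # model_fn is bound once here; the returned update function never branches.
    def update(batch: Sequence[Any]) -> Tuple[Any, Any]:
        x, y = prepare_batch(batch)
        output = model_fn(model, x)
        return output, y

    return update


# Multi-input model: unpack x into positional arguments.
update = make_update(ToyTwoInputModel(), model_fn=lambda model, x: model(*x))
print(update(((3.0, 4.0), 12.0)))  # (12.0, 12.0)

# Single-input model keeps the default model(x) behaviour.
square = make_update(lambda x: x * x)
print(square((5, 25)))  # (25, 25)
```

Because the default is `lambda model, x: model(x)`, existing single-input code is unaffected; only users with non-standard models pass a custom `model_fn`.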
Thank you for your reply. I think it's a good idea to add a parameter like `model_fn`. If we have to write a custom update function, we also have to write another one when creating an evaluator. I think the great thing about Ignite is that it simplifies the boring work and makes the code more concise.
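The point about evaluators can be illustrated the same way: if both the trainer and evaluator factories accept the same `model_fn`, the multi-input call is written once and shared. This is a minimal plain-Python sketch; `make_train_step`, `make_eval_step`, `LinearTwoInput` and the toy squared-error loss are hypothetical and only mimic the shape of Ignite's trainer/evaluator factories, without PyTorch, autograd or an optimizer step.

```python
from typing import Any, Callable, Dict, Sequence

ModelFn = Callable[[Any, Any], Any]


# Hypothetical two-input "model": y_pred = w1 * a + w2 * b
class LinearTwoInput:
    def __init__(self, w1: float, w2: float) -> None:
        self.w1, self.w2 = w1, w2
        self.training = True

    def __call__(self, a: float, b: float) -> float:
        return self.w1 * a + self.w2 * b


def make_train_step(model: Any, model_fn: ModelFn) -> Callable[[Sequence[Any]], float]:
    def train_step(batch: Sequence[Any]) -> float:
        model.training = True  # analogous to model.train()
        x, y = batch
        y_pred = model_fn(model, x)
        return (y_pred - y) ** 2  # toy loss; no gradient step in this sketch

    return train_step


def make_eval_step(model: Any, model_fn: ModelFn) -> Callable[[Sequence[Any]], Dict[str, float]]:
    def eval_step(batch: Sequence[Any]) -> Dict[str, float]:
        model.training = False  # analogous to model.eval()
        x, y = batch
        return {"y_pred": model_fn(model, x), "y": y}

    return eval_step


# One model_fn shared by both steps: unpack (a, b) into positional arguments.
unpack: ModelFn = lambda model, x: model(*x)
model = LinearTwoInput(w1=2.0, w2=1.0)
train_step = make_train_step(model, unpack)
eval_step = make_eval_step(model, unpack)

print(train_step(((1.0, 3.0), 4.0)))  # 1.0
print(eval_step(((2.0, 0.5), 4.5)))   # {'y_pred': 4.5, 'y': 4.5}
```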
Hey 👋, I've just created a thread for this issue on PyTorch-Ignite Discord where you can quickly talk to the community on the topic. 🤖 This comment was automatically posted by Discuss on Discord
@invoker-bot, sounds good. If you would like to help with coding this feature, you are very welcome!
@vfdev-5 I'd like to. I will make a PR soon.
Hey @invoker-bot, will you be able to send a PR with the feature?
Sorry, it took me some time to familiarize myself with the code and configure the development environment. I just sent a PR; please take a look.
…_trainer (#3055) (#3074)

* feat: Added model_fn to support non-standard model function in create_trainer (#3055)
* Added versionchanged docstring and tests for model_fn (#3074)
* fix style formatting
* Update test_create_supervised.py
* Apply suggestions from code review

---------

Co-authored-by: vfdev <vfdev.5@gmail.com>
🚀 Feature
I think in some cases we may need to pass multiple params to the model, for example, if the model has two or more inputs or carries some state (LSTM, etc.).
It is easy to implement this feature. The relevant source code is here:
https://github.com/pytorch/ignite/blob/34a707e53785cf8a524589f33a570a7516fe064e/ignite/engine/__init__.py#L107C26-L107C26, and we only need to replace
with
. Then multiple params can be passed via `prepare_batch`, and this change is still compatible with a single param.
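The exact lines the proposal replaces are not preserved here. Judging from the reply about executing if/else code on every iteration, the proposal presumably dispatched on the type of `x`. The sketch below is an assumption along those lines, not the proposed patch: `call_model` is a hypothetical helper that unpacks a tuple or list returned by `prepare_batch` into positional arguments and passes any other value through unchanged, which is what keeps single-input models working.

```python
from typing import Any, Callable


def call_model(model: Callable[..., Any], x: Any) -> Any:
    # Runs on every iteration: unpack sequences, otherwise pass x as-is.
    if isinstance(x, (tuple, list)):
        return model(*x)
    return model(x)


# Two-input model, e.g. an LSTM-like step taking (input, state).
def step(inp: int, state: int) -> int:
    return inp + state


print(call_model(step, (2, 10)))    # 12
print(call_model(lambda v: -v, 7))  # -7
```

One caveat of this kind of dispatch, and a reason an explicit `model_fn` can be preferable: a single-input model whose input is itself a tuple or list would be unpacked by mistake.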