Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added model_fn to support non-standard model function in create_trainer (#3055) #3074

Merged
merged 6 commits into from Oct 3, 2023

Conversation

invoker-bot
Copy link
Contributor

@invoker-bot invoker-bot commented Sep 27, 2023

Fixes #3055

Description:

Now we can define our custom model_fn in create_supervised_trainer and create_supervised_evaluator.

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@github-actions github-actions bot added the module: engine Engine module label Sep 27, 2023
@vfdev-5
Copy link
Collaborator

vfdev-5 commented Sep 27, 2023

@invoker-bot thanks for the PR, please add also a test for this feature into https://github.com/pytorch/ignite/blob/master/tests/ignite/engine/test_create_supervised.py

@invoker-bot
Copy link
Contributor Author

@invoker-bot thanks for the PR, please add also a test for this feature into https://github.com/pytorch/ignite/blob/master/tests/ignite/engine/test_create_supervised.py

I have made these changes, please check it.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update @invoker-bot
Few improvements to add and it can be good to be merged


loss[0] = mse_loss(_y_pred, _y).item()

# loss[0] = mse_loss(model(_x), _y).item()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove commented code


@trainer.on(Events.ITERATION_COMPLETED(every=gradient_accumulation_steps))
def _():
theta[0] -= accumulation[0] / gradient_accumulation_steps
assert pytest.approx(model.fc.weight.data[0, 0].item(), abs=1.0e-5) == theta[0]
assert pytest.approx(trainer.state.output[-1], abs=1e-5) == loss[0]
print("loss:", loss[0], "theta:", theta[0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove this print

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Sep 29, 2023

@invoker-bot please run code style formatting script to fix CI issues:

bash ./tests/run_code_style.sh install
bash ./tests/run_code_style.sh fmt

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Oct 3, 2023

@invoker-bot can you please address the comment such that the PR can be merged and will be included to the next release?

@invoker-bot
Copy link
Contributor Author

@invoker-bot can you please address the comment such that the PR can be merged and will be included to the next release?

I have fixed this issue now, please check it.

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Oct 3, 2023

ignite/engine/__init__.py Outdated Show resolved Hide resolved
ignite/engine/__init__.py Outdated Show resolved Hide resolved
ignite/engine/__init__.py Outdated Show resolved Hide resolved
ignite/engine/__init__.py Outdated Show resolved Hide resolved
ignite/engine/__init__.py Outdated Show resolved Hide resolved
ignite/engine/__init__.py Outdated Show resolved Hide resolved
ignite/engine/__init__.py Outdated Show resolved Hide resolved
ignite/engine/__init__.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @invoker-bot

@vfdev-5 vfdev-5 merged commit b8751f2 into pytorch:master Oct 3, 2023
17 of 18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: engine Engine module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add a feature "support multi params to call forward method"?
2 participants