New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Added model_fn to support non-standard model function in create_trainer (#3055) #3074
Conversation
@invoker-bot thanks for the PR, please add also a test for this feature into https://github.com/pytorch/ignite/blob/master/tests/ignite/engine/test_create_supervised.py |
I have made these changes, please check it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update @invoker-bot
Few improvements to add and it can be good to be merged
|
||
loss[0] = mse_loss(_y_pred, _y).item() | ||
|
||
# loss[0] = mse_loss(model(_x), _y).item() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's remove commented code
|
||
@trainer.on(Events.ITERATION_COMPLETED(every=gradient_accumulation_steps)) | ||
def _(): | ||
theta[0] -= accumulation[0] / gradient_accumulation_steps | ||
assert pytest.approx(model.fc.weight.data[0, 0].item(), abs=1.0e-5) == theta[0] | ||
assert pytest.approx(trainer.state.output[-1], abs=1e-5) == loss[0] | ||
print("loss:", loss[0], "theta:", theta[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please remove this print
@invoker-bot please run code style formatting script to fix CI issues: bash ./tests/run_code_style.sh install
bash ./tests/run_code_style.sh fmt |
@invoker-bot can you please address the comment such that the PR can be merged and will be included to the next release? |
I have fixed this issue now, please check it. |
@invoker-bot thanks, please also check above comments:
-> let me fix them myself to accelerate the review process |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @invoker-bot
Fixes #3055
Description:
Now we can define our custom
model_fn
increate_supervised_trainer
andcreate_supervised_evaluator
.Check list: