
[create_supervised_trainer] add automatic mixed precision #1589

Merged: 54 commits into pytorch:master from ydcjeff:engine/create_supervised_trainer, Feb 21, 2021

Conversation

@ydcjeff (Contributor) commented Jan 29, 2021

Fixes #1235

Description: Add automatic mixed precision using torch.cuda.amp and apex.

Usage:

import torch
from torch import nn
from torch.cuda.amp import GradScaler
from ignite.engine import create_supervised_trainer

# the usual required arguments (toy placeholders to make the snippet self-contained)
model = nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# using autocast only
trainer = create_supervised_trainer(model, optimizer, loss_fn, amp_mode='amp')

# using autocast and the default scaler
trainer = create_supervised_trainer(model, optimizer, loss_fn, amp_mode='amp', scaler=True)
# trainer state will have the attribute scaler
print(trainer.state.scaler)
# <torch.cuda.amp.grad_scaler.GradScaler object at 0x7f8e0dac7b80>

# using autocast and a custom scaler
# trainer state will not have the attribute scaler if a scaler instance is passed
scaler = GradScaler(2**10)
trainer = create_supervised_trainer(model, optimizer, loss_fn, amp_mode='amp', scaler=scaler)

# using apex
trainer = create_supervised_trainer(model, optimizer, loss_fn, amp_mode='apex')

# the scaler will be ignored and a warning will show up
trainer = create_supervised_trainer(model, optimizer, loss_fn, amp_mode='apex', scaler=True)

Checklist:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

https://deploy-preview-1589--pytorch-ignite-preview.netlify.app/engine.html#

@vfdev-5 (Collaborator) left a comment

Thanks a lot for the PR @ydcjeff!

@vfdev-5 (Collaborator) commented Jan 29, 2021

Currently, I launched it manually: https://app.circleci.com/pipelines/github/pytorch/ignite/1195/workflows/27d5b840-72bb-41e1-8d1c-84640f1f623c, but I think either the next commit or a new PR will run automatically on GPUs.

@ydcjeff (Contributor, Author) commented Jan 29, 2021

> Currently, I launched it manually: https://app.circleci.com/pipelines/github/pytorch/ignite/1195/workflows/27d5b840-72bb-41e1-8d1c-84640f1f623c, but I think either the next commit or a new PR will run automatically on GPUs.

Thank you!

@ydcjeff (Contributor, Author) commented Jan 30, 2021

I need help with the tests, specifically with apex and GradScaler.

@vfdev-5 (Collaborator) commented Jan 30, 2021

> I need help with the tests, specifically with apex and GradScaler.

I'll try to implement something from my side and we'll see.

@vfdev-5 (Collaborator) commented Feb 10, 2021

We discussed this PR and the related issue with the team, and we think we should explore a slightly different approach. The helper method create_supervised_trainer is roughly made of two things: the update function definition and the Engine setup.

It would probably be more helpful to provide public methods like:

  • supervised_training_step
  • supervised_training_step_tpu
  • supervised_training_step_apex
  • supervised_training_step_amp

and inside create_supervised_trainer we could set up the trainer according to the provided options without lots of if/else. Maybe we can skip, for instance, grad norm.

Basically, the idea is something like this:

from ignite.engine import Engine

# each factory returns an update function with the desired behaviour baked in
def get_training_step_1(a):
    def training_step(e, b):
        print(a, e, b)
    return training_step

def get_training_step_2(a):
    def training_step(e, b):
        print(a, e, b, "with amp")
    return training_step

def create_supervised_trainer(a, opt):
    # pick the update function according to the provided option,
    # then do the Engine setup in one place
    training_step = None
    if opt == 1:
        training_step = get_training_step_1(a)
    elif opt == 2:
        training_step = get_training_step_2(a)

    e = Engine(training_step)
    return e
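
For illustration, here is a minimal sketch of what the amp variant from the list above could look like (assuming the proposed name supervised_training_step_amp; this is a sketch of the idea, not the final implementation):

from torch.cuda.amp import autocast

def supervised_training_step_amp(model, optimizer, loss_fn, scaler=None):
    # factory: returns an update function whose forward pass runs under
    # autocast, with optional loss scaling via a GradScaler instance
    def update(engine, batch):
        model.train()
        x, y = batch
        optimizer.zero_grad()
        with autocast():
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        return loss.item()
    return update

create_supervised_trainer would then just select the right factory and pass the result to Engine, as in the snippet above.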

cc @sdesrozis any other ideas or thoughts?

@sdesrozis (Contributor) commented

It would be great for users to have these functions; they would be helpful for checking what happens under the hood.

My thoughts on this topic are about the update function. The dream would be to pass a generic function and have automatic (or nearly automatic) tools to adapt it to features like amp, tpu, etc.
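
As a rough sketch of that idea (all names here are hypothetical, nothing from this PR): a generic forward function could be adapted to amp by a small wrapper that puts autocast around the forward pass only, as the torch.cuda.amp docs recommend:

from torch.cuda.amp import autocast

def adapt_to_amp(forward_fn, optimizer, scaler):
    # forward_fn(engine, batch) computes and returns the loss;
    # the wrapper adds autocast, scaling, backward and the optimizer step
    def update(engine, batch):
        optimizer.zero_grad()
        with autocast():
            loss = forward_fn(engine, batch)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        return loss.item()
    return update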

@ydcjeff (Contributor, Author) commented Feb 13, 2021

Shall we also accept a scaler argument, or create one only internally?

@vfdev-5 (Collaborator) left a comment

Thanks for the update!

@vfdev-5 (Collaborator) commented Feb 15, 2021

@sdesrozis can you review the PR, please?

@ydcjeff (Contributor, Author) commented Feb 20, 2021

Thank you @sdesrozis for your review.
If the GPU tests pass, we are ready to merge.

@vfdev-5 (Collaborator) left a comment

Looks good to me as well! Thanks a lot @ydcjeff!
I left a few nit comments about removing TPU mentions where they are inappropriate.
The comment about the warning and adding usage examples of these features could be done in a follow-up PR...

@sdesrozis (Contributor) commented

Could we add tests to decrease the codecov warnings?

@vfdev-5 (Collaborator) commented Feb 20, 2021

> Could we add tests to decrease the codecov warnings?

Let's do that all in a follow-up PR :)

@ydcjeff (Contributor, Author) commented Feb 21, 2021

> Could we add tests to decrease the codecov warnings?

I think those warnings are from one_gpu_tests failing to upload coverage.

> The comment about the warning and adding usage examples of these features could be done in a follow-up PR...

Will do that.

@ydcjeff changed the title from "[create_supervised_trainer] add amp and grad_norm" to "[create_supervised_trainer] add automatic mixed precision" on Feb 21, 2021
@ydcjeff (Contributor, Author) commented Feb 21, 2021

Found out that the amp module is available in torch 1.5, which doesn't have autocast yet.
Changed to catching ImportError to handle all torch versions.
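
The guard presumably boils down to something like this (a sketch; the exact message in the PR may differ):

try:
    from torch.cuda.amp import autocast
except ImportError:
    # torch 1.5 ships a torch.cuda.amp module without autocast, so catching
    # ImportError covers both a missing module and a missing name
    raise ImportError("Please install torch>=1.6.0 to use amp_mode='amp'.")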

@sdesrozis (Contributor) commented

> I think those warnings are from one_gpu_tests failing to upload coverage.

Do we know why it does not work?

@ydcjeff (Contributor, Author) commented Feb 21, 2021

>> I think those warnings are from one_gpu_tests failing to upload coverage.
>
> Do we know why it does not work?

I don't know exactly, but it may be that codecov failed to upload the report to its server.

@vfdev-5 (Collaborator) commented Feb 21, 2021

>> I think those warnings are from one_gpu_tests failing to upload coverage.
>
> Do we know why it does not work?

Asked here: codecov/codecov-bash#411

Anyway, if there is no way to fix the uploading, we can remove the -Z option and silently ignore the uploading issue.

@vfdev-5 (Collaborator) commented Feb 21, 2021

@sdesrozis can you please merge this PR once the CI is OK?

@sdesrozis merged commit f379b18 into pytorch:master on Feb 21, 2021
@ydcjeff deleted the engine/create_supervised_trainer branch on February 21, 2021 at 10:31
@ydcjeff (Contributor, Author) commented Feb 21, 2021

Thank you for your help and reviews.
