support amp (auto mixed precision) #2654
Conversation
@yaochengji Sorry for the delay, will take a look today.
Thanks, @JackCaoG. Note the torch_xla/amp folder is almost the same as the amp folder in pytorch.
@yaochengji Thanks for contributing to pytorch/xla! Would you mind adding a bit more context here, e.g. what's your main use case for pytorch/xla + amp, and with this PR have you seen meaningful perf gains on some common models? Thanks a lot!
Hi @ailzhang, I mainly use torch/xla to accelerate pytorch training on GPUs, and for almost all CV models it is better to enable amp. With the tensorflow fix, torch/xla amp training of resnet-50 on a V100 reaches about 1150 images/s, roughly a 50% speedup over pytorch-amp (refer here), and only a little slower than tf amp (refer here).
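For reference, here is a minimal sketch of what an amp training step could look like with this PR. It assumes the new torch_xla/amp folder exposes autocast and GradScaler mirroring torch.cuda.amp; the import path, the exact names, and the toy model are assumptions for illustration, not the PR's verified API.

```python
import torch
import torch_xla.core.xla_model as xm
# Assumed import path: this PR adds a torch_xla/amp folder that mirrors
# pytorch's amp, so autocast/GradScaler are assumed to live there.
from torch_xla.amp import autocast, GradScaler

device = xm.xla_device()
model = torch.nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()

def train_step(data, target):
    optimizer.zero_grad()
    with autocast():               # run the forward pass in mixed precision
        output = model(data)
        loss = torch.nn.functional.cross_entropy(output, target)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscale grads; skip the step if inf/nan was found
    scaler.update()                # adjust the loss scale for the next iteration
    xm.mark_step()                 # cut and execute the pending XLA graph
```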
Wow, that performance improvement with xla:gpu is amazing. Copying files from pytorch doesn't sound like a very good idea, though. If the amp logic is device agnostic, is there a way to make xla share the same code instead of copying it?
Hi @JackCaoG, I made some modifications in pytorch so that I could simplify the code change here.
Sorry for the delay, I need a bit of time to take a look at this one.
Force-pushed from 6305eaa to 3d3662c
Sorry for the delay, I will try to take a look soon.
Thanks, @JackCaoG.
@yaochengji Thanks for contributing! I have not finished reviewing the whole thing, will circle back after vacation.
torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp (outdated review thread, resolved)
Mostly LGTM, some nits
torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp (outdated review thread, resolved)
torch_xla/csrc/xla_lower_util.cpp (outdated)
std::vector<xla::XlaOp> BuildAmpForachNonFiniteCheckAndUnscale(
I think you mentioned "And my XlaOp could only handle the scalar one." Could you point me to where that limitation comes from?
E.g. https://github.com/pytorch/xla/pull/2654/files#diff-5555c0238ce581790db5f184fb764328a269c974d5019e8d5b66bcedbc545fefR761: the result of xla::AllReduce is a scalar.
Sorry, could you give a bit more detail? Why does the ReduceAll return type matter here?
The result of xla::AllReduce is a scalar, which cannot be used together with an XlaOp of shape (1,).
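To illustrate the shape mismatch in plain torch terms (a conceptual sketch, not the XLA C++ lowering itself): the reduction over the non-finite mask yields a 0-d scalar tensor, while the found_inf buffer that GradScaler passes to the unscale op has shape (1,), so the scalar has to be reshaped or broadcast before it can be written into found_inf.

```python
import torch

grads = torch.tensor([1.0, float("inf"), 3.0])
non_finite = (~torch.isfinite(grads)).to(torch.float32)
reduced = non_finite.max()            # 0-d tensor, i.e. a "scalar"
found_inf = torch.zeros(1)            # shape (1,), as allocated by GradScaler
found_inf.copy_(reduced.reshape(1))   # explicit reshape to match shape (1,)
print(found_inf)                      # tensor([1.])
```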
BuildAmpForachNonFiniteCheckAndUnscale: is "Forach" a typo for "ForEach"?
Good catch! Fixed.
Thanks for making the change, I think we are pretty close!
Force-pushed from eb9c01d to efd803d
@yaochengji Can you rebase this branch? I think you need #2685.
@JackCaoG rebased; CI won't pass until the corresponding pytorch PR is merged.
@yaochengji You can add a …
The gpu test failed with message …
@yaochengji Could you take a look? Thanks!
I'm looking into it. It is because I changed the …
I could fix …
@JackCaoG, I already fixed the shape (1,) issue, and I ran … BTW, …
@yaochengji thanks for the update, I will try to take a look later today. So the …
Yes, it seems from the …
Thanks @yaochengji! Btw, did you ever try a 4-GPU setup? Did you ever encounter the error described in #2758?
I just reproduced the error on another clean machine on … The errors in the 2-GPU and 4-GPU setups are the same; the 1-GPU case passes, which is why …
To summarize: on …
Thanks for the confirmation. The team is pretty busy with the upcoming 1.8 release and tpuvm public preview work, but I will try to take a look before tomorrow. Since this error is not introduced by this PR, I think we can work around it for now using …
We can merge this to master first, I will take a look at the multi GPU test failure tmr.
Thanks @yaochengji for contributing!
Great work @yaochengji, thanks for the contribution! My hands are a bit tied with tpuvm stuff right now, but I will try to review and import the XRT PR when I am a bit more free.
Note this PR should be used together with the pytorch PR.