
Compatibility between ORTModule and DeepSpeed #108

Closed
JingyaHuang opened this issue May 4, 2022 · 6 comments


@JingyaHuang

Hi folks,

I have recently been working on validating distributed training features with ORTModule, and here are some incompatibilities that I found during testing (a minimal sketch of the test setup follows the DeepSpeed findings below):

[With DeepSpeed]

  • ZeRO Stage 1 and 2 work well
  • ZeRO Stage 3 ❌

Warnings:

/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_io.py:558: 
UserWarning: This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX. Compute will continue, but unexpected results may occur!
warnings.warn("This model cannot be deep copied (or pickled)  
  • BF16 ❌

Error Message:

RuntimeError: /onnxruntime_src/orttraining/orttraining/python/orttraining_pybind_state.cc:752
onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, 
const pybind11::bytes&, const onnxruntime::training::OrtModuleGraphBuilderConfiguration&)> 
[ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(bfloat16)' of input parameter
(_original_module.distilbert.embeddings.word_embeddings.weight) of operator (ATen) in node (ATen_17) is invalid
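
For reference, here is a minimal sketch of the kind of setup used in these tests. The model name, batch size, and config values are illustrative, not the exact script:

```python
import deepspeed
from torch_ort import ORTModule
from transformers import AutoModelForSequenceClassification

# Wrap the HF model with ORTModule so ONNX Runtime handles forward/backward.
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
model = ORTModule(model)

ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "optimizer": {"type": "AdamW", "params": {"lr": 5e-5}},
    "zero_optimization": {"stage": 3},  # stages 1/2 work; stage 3 triggers the warning above
    "bf16": {"enabled": True},          # BF16 fails with the ATen type error above
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```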

[With Fairscale]

  • Can only shard optimizer state

Environment

  • OS: Ubuntu 20.04
  • CUDA/cuDNN version: 11.3/8
  • onnxruntime-training: 1.11.1+cu113
  • torch: 1.11.0+cu113
  • torch-ort: 1.11.1
  • Python version: 3.8
  • GPU: A100

I would like to confirm with you folks whether these behaviors are intended. And concerning compatibility with DeepSpeed stage 3 and BF16, would it be possible to have some insight into whether they will be supported in the future?

Thanks a lot!

@baijumeswani
Collaborator

baijumeswani commented May 10, 2022

Hi @JingyaHuang.

Warnings:

/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_io.py:558: 
UserWarning: This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX. Compute will continue, but unexpected results may occur!
warnings.warn("This model cannot be deep copied (or pickled) 

ORTModule runs the PyTorch model once before exporting it to ONNX. Because of this requirement, it tries to make a deep copy of the original model and execute the copy (so as not to disturb the state of the original model while the export happens). But because this model is not deep-copyable, we issue the warning to indicate that the exported model might have some state change due to a single model execution before it is exported.
In most cases, this should be a non-issue. If you encounter a problem, please reach out to us.
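
To illustrate the idea, here is a simplified sketch of that logic (the function name is hypothetical; the real code lives in ortmodule/_io.py):

```python
import copy
import warnings

import torch

def copy_model_for_export(model: torch.nn.Module) -> torch.nn.Module:
    # Try to run a *copy* of the model once before export so that the
    # original model's state is untouched by that extra execution.
    try:
        return copy.deepcopy(model)
    except Exception:
        warnings.warn(
            "This model cannot be deep copied (or pickled), which is a required "
            "step for stateful models to be properly exported to ONNX. "
            "Compute will continue, but unexpected results may occur!"
        )
        # Fall back to the original model; its state may change once.
        return model
```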

Error Message:

RuntimeError: /onnxruntime_src/orttraining/orttraining/python/orttraining_pybind_state.cc:752
onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, 
const pybind11::bytes&, const onnxruntime::training::OrtModuleGraphBuilderConfiguration&)> 
[ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(bfloat16)' of input parameter
(_original_module.distilbert.embeddings.word_embeddings.weight) of operator (ATen) in node (ATen_17) is invalid

Looking at the source code, it seems we have not added bfloat16 support for executing an ATen op yet. I believe we have plans to add that soon. Let me circle back on this and provide more details.
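
If it helps with debugging in the meantime, here is a quick sketch for listing the ATen fallback nodes in an exported graph and the dtypes of their initializer inputs (the file path is illustrative; ORTModule's DebugOptions can save the exported graphs):

```python
import onnx

# Path is illustrative; obtain the exported graph e.g. via ORTModule's
# DebugOptions(save_onnx=True, ...).
model = onnx.load("exported_training_model.onnx")

initializer_types = {init.name: init.data_type for init in model.graph.initializer}
for node in model.graph.node:
    if node.op_type == "ATen":
        input_dtypes = [
            onnx.TensorProto.DataType.Name(initializer_types[name])
            for name in node.input
            if name in initializer_types
        ]
        print(node.name, input_dtypes)
```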

Thanks for opening this issue.

@baijumeswani
Collaborator

@iK1D for reference.

@baijumeswani
Collaborator

After syncing internally, I can confirm that we will create a work item to add support for bfloat16 for Aten op execution and plan to have it completed in the near future. I'll leave this issue open and will update it as and when we initiate/complete the work.

@JingyaHuang
Author

Hi @baijumeswani ,
Thanks for the clarification, we will also keep tracking it on our side!

@baijumeswani
Collaborator

This has been addressed in the pull request microsoft/onnxruntime#11546. Please try out the nightly onnxruntime-training to evaluate if the fix works for you.

Closing this issue now. Please re-open or open another one if you need help.

@JingyaHuang
Author

Hi @baijumeswani ,

Thanks for adding BF16 support for the ATen operator! I just tested it with:

  • nightly-built onnxruntime-training 1.12.0.dev20220523001+cu113
  • onnx 1.11.0 (opset=15)
  • DeepSpeed stage 1/2, BF16 enabled

This time, ATen seems to be fine, but I ran into another error, as follows:

RuntimeError: /onnxruntime_src/orttraining/orttraining/python/orttraining_pybind_state.cc:837  
onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, 
onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const 
pybind11::bytes&, const onnxruntime::training::OrtModuleGraphBuilderConfiguration&)> [ONNXRuntimeError] : 10 : 
INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(bfloat16)' of input parameter (onnx::Pow_131) of 
operator (Pow) in node (Pow_26) is invalid.

If I am not mistaken, although Pow with BF16 is supported in ONNX 1.11.0, it (and perhaps other operators essential to transformers) is not registered for BF16 in ONNX Runtime, which leads to the training failure. Is that correctly understood?
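
To double-check the spec side, the Pow schema can be queried from the onnx package; this small sketch only confirms what the ONNX standard allows at opset 15, not which kernels ONNX Runtime actually registers:

```python
from onnx import defs

# List Pow's type constraints at opset 15; 'tensor(bfloat16)' appears
# among the allowed types, so the gap would be on the ORT kernel side.
schema = defs.get_schema("Pow", max_inclusive_version=15)
for constraint in schema.type_constraints:
    print(constraint.type_param_str, constraint.allowed_type_strs)
```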

Besides, one thing I cannot understand well: training with ORTModule and BF16 enabled works well, whereas it fails when DeepSpeed (stage 1 or stage 2) is enabled. Could you explain a little bit more about this?

Thanks!
