gradient checkpointing failed in xla_device #5766
Comments
We used https://github.com/pytorch/xla/blob/master/torch_xla/utils/checkpoint.py, but it is pretty much copied from upstream and adds the …
It works; convert …
@JackCaoG, the work-around here was to switch from torch.utils.checkpoint.checkpoint to torch_xla.utils.checkpoint.checkpoint. However, it would be better to restore the ability to use torch.utils.checkpoint.checkpoint, which worked in 1.13. What are your thoughts?
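For reference, the work-around mentioned above amounts to a one-line import change (assuming torch-xla 2.1 is installed; whether the XLA variant is a drop-in replacement for every use of the upstream API is not confirmed here):

```diff
-from torch.utils.checkpoint import checkpoint
+from torch_xla.utils.checkpoint import checkpoint
```

For Hugging Face models that call the upstream function internally (e.g. via `gradient_checkpointing_enable()`), one possible, unverified approach is to rebind `torch.utils.checkpoint.checkpoint` to the torch_xla implementation before enabling gradient checkpointing.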
IMO …
❓ Questions and Help
I am trying to fine-tune a large language model (from Hugging Face) on an XLA device, and an error is reported. The cause appears to be that torch.utils.checkpoint.checkpoint is decorated with _disable_dynamo:
https://github.com/pytorch/pytorch/blob/0d95378341b4eb19849295c7ccab08cc9be328a7/torch/utils/checkpoint.py#L341
Does this mean that if the model's device is set to xla, then torch.utils.checkpoint.checkpoint cannot be used?
If so, are there any alternative approaches that avoid gradient checkpointing when fine-tuning an LLM?
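As background on what gradient checkpointing buys you (and what you give up by disabling it): instead of storing every activation for the backward pass, only every k-th activation is saved, and the dropped ones are recomputed from the nearest checkpoint during backward. A minimal plain-Python sketch with toy scalar "layers" (not the torch implementation; all names here are illustrative):

```python
import math

def layer(a, w):
    # One nonlinear "layer"; its backward pass needs the input activation a.
    return math.sin(w * a)

def dlayer_da(a, w):
    # d/da sin(w * a) = w * cos(w * a)
    return w * math.cos(w * a)

def grad_naive(x, weights):
    # Plain backprop: store every activation (O(n) memory for n layers).
    acts = [x]
    for w in weights:
        acts.append(layer(acts[-1], w))
    g = 1.0
    for i in reversed(range(len(weights))):
        g *= dlayer_da(acts[i], weights[i])
    return g

def grad_checkpointed(x, weights, seg=4):
    # Forward: keep only every seg-th activation (O(n / seg) memory).
    ckpts = {0: x}
    a = x
    for i, w in enumerate(weights):
        a = layer(a, w)
        if (i + 1) % seg == 0:
            ckpts[i + 1] = a
    g = 1.0
    n = len(weights)
    for start in reversed(range(0, n, seg)):
        # Recompute the dropped activations inside this segment...
        a = ckpts[start]
        seg_acts = [a]
        end = min(start + seg, n)
        for i in range(start, end):
            a = layer(a, weights[i])
            seg_acts.append(a)
        # ...then apply the chain rule backward through the segment.
        for i in reversed(range(start, end)):
            g *= dlayer_da(seg_acts[i - start], weights[i])
    return g
```

Both functions produce the same gradient; the checkpointed version trades one extra forward pass per segment for much lower peak activation memory, which is why turning it off entirely is usually not practical for a 13B model.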
Any help on this would be greatly appreciated!
fine-tune.py
fine-tune.py can be run by executing the command
dataset: https://github.com/baichuan-inc/Baichuan2/tree/main/fine-tune/data
model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/tree/main
torch: 2.1.0
torch-xla: 2.1.0