
How did you train the large-sized models without out-of-memory? #27

Closed
jiang719 opened this issue Aug 6, 2022 · 3 comments

Comments

jiang719 commented Aug 6, 2022

I would like to fine-tune the 2B model, but I get an out-of-memory error even with the batch size set to 1 (on a single GPU with 24 GB of memory).

I wonder what devices you used to pre-train the 2B and 16B models. How did you address the memory issue? Did you parallelize the model by layers across different GPUs? Thank you.

Nan

enijkamp (Contributor) commented Aug 7, 2022

The models were pre-trained in JAX on TPU-v4 hardware and later converted to PyTorch for sampling.

The training code in JAX will be released soon.

You may try to fine-tune the models in PyTorch using DeepSpeed:

https://news.ycombinator.com/item?id=32331764
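
For context on why this OOMs: plain Adam fine-tuning in fp32 needs roughly 16 bytes per parameter (4 for the weight, 4 for the gradient, 8 for the optimizer moments), so a 2B-parameter model already wants around 32 GB before counting activations, which exceeds a 24 GB GPU. A minimal sketch of what a DeepSpeed ZeRO config for this might look like (the values below are illustrative assumptions, not the authors' settings):

```python
# Illustrative DeepSpeed config (assumed values, not the authors' settings).
# ZeRO stage 2 shards gradients and optimizer states across GPUs, and fp16
# halves the memory needed for weights and activations.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 16,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,                # shard gradients + optimizer states
        "overlap_comm": True,      # overlap reduction with the backward pass
        "contiguous_gradients": True,
    },
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
}
```

Note that on a single 24 GB card there is nothing to shard across, so stage 2 alone may still not fit 2B parameters; the CPU-offloading route in the later comment is the more reliable option there.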

xanderdunn commented:

Training code in JAX has been released: #16 (comment)

enijkamp (Contributor) commented Oct 4, 2022

@jiang719 Here is DeepSpeed fine-tuning code with CPU parameter offloading, which should let you avoid the OOM:

https://github.com/salesforce/jaxformer/blob/main/jaxformer/hf/train.py
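
For readers who don't want to dig through the script, a rough sketch of the idea it relies on: offload parameters (and, here, optimizer states) to CPU RAM via DeepSpeed ZeRO stage 3, wrapped around a Hugging Face CodeGen checkpoint. The model name, batch sizes, and learning rate below are assumptions for illustration, not necessarily what train.py uses.

```python
# Hypothetical sketch: fine-tune CodeGen-2B with DeepSpeed ZeRO-3 and CPU offloading.
# Model name and hyperparameters are assumptions; see the linked train.py for the
# authors' actual configuration.
import deepspeed
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "Salesforce/codegen-2B-mono"  # assumption: any CodeGen checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 16,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,                              # partition params, grads, optimizer states
        "offload_param": {"device": "cpu"},      # keep parameters in CPU RAM when unused
        "offload_optimizer": {"device": "cpu"},  # keep Adam states in CPU RAM
    },
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
}

engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

# One illustrative training step on a dummy batch.
batch = tokenizer("def hello_world():", return_tensors="pt").to(engine.device)
loss = engine(input_ids=batch["input_ids"], labels=batch["input_ids"]).loss
engine.backward(loss)
engine.step()
```

The trade-off is speed: parameters are streamed between CPU and GPU each step, so throughput drops, but peak GPU memory stays well within a 24 GB card.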

enijkamp closed this as completed Oct 4, 2022