Is there a training script for pre-training a LLaMA-7B model on GPUs such as A100s? The current examples are based on TPUs, and I'm not sure whether anything needs to change. Thanks.
I believe the configuration would be very similar, although you might need to tune the mesh dimensions according to your cluster configuration and network topology to get the best performance. Specifically, you'll want to add these options when training on GPUs in a multihost environment:
python -m EasyLM.models.llama.llama_train \
--jax_distributed.initialize_jax_distributed=True \
--jax_distributed.coordinator_address=<your coordinator (process 0) address and port> \
--jax_distributed.num_processes=<total number of processes (hosts)> \
--jax_distributed.process_id=<current process id>
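For example, on a hypothetical two-host A100 cluster (the address, port, and process count below are placeholder values, not defaults), the launch on the second host might look like:

python -m EasyLM.models.llama.llama_train \
    --jax_distributed.initialize_jax_distributed=True \
    --jax_distributed.coordinator_address='10.0.0.1:1234' \
    --jax_distributed.num_processes=2 \
    --jax_distributed.process_id=1

The first host would run the same command with --jax_distributed.process_id=0, and the coordinator address on every host must point to process 0.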