Hello!
I'm looking into fine-tuning LLaMA-7b with EasyLM on a TPU v3-8. From my initial runs, I get around 975 tokens/sec. I've tested every flag combination I can think of, but I'm unable to increase the batch size or the gradient accumulation steps beyond 1 without OOMing.
I saw that you achieved a high throughput of 2,200 tokens/sec per TPU-v4 chip on OpenLLaMA-7b, and that mesh-transformer-jax gets ~5k tokens/sec on a v3-8 for GPT-J, so I was curious whether there is an issue in my config.
Here's how I'm running it:
# Removed "jax_enable_async_all_gather", as it causes a crash on a v3-8.
# Without these flags, the throughput is 590 tokens/sec.
export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'
# bf16 causes errors - is it only intended for serving?
python -m EasyLM.models.llama.llama_train \
    --dtype='fp32' \
    --mesh_dim='1,-1,1' \
    --load_llama_config='7b' \
    --optimizer.type='adamw' \
    --train_dataset.json_dataset.seq_length=2048 \
    --train_dataset.json_dataset.batch_size=1
# ... omitting other flags which shouldn't affect throughput
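In case it helps, here's my mental model of what --mesh_dim='1,-1,1' resolves to on a v3-8. This is a sketch only; the (dp, fsdp, mp) axis naming is my assumption about EasyLM's mesh convention, not code from the repo:

# Hypothetical sketch: mesh_dim='1,-1,1' as a (dp, fsdp, mp) device mesh,
# where -1 absorbs the remaining devices, i.e. all 8 v3-8 cores on fsdp.
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices()).reshape(1, 8, 1)  # 1 x 8 x 1 over 8 cores
mesh = Mesh(devices, axis_names=("dp", "fsdp", "mp"))
print(dict(mesh.shape))  # {'dp': 1, 'fsdp': 8, 'mp': 1}

If that's right, all eight cores sit on the FSDP axis.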
Do you have any tips? Or is higher throughput only expected on larger TPU pods?
Thanks!
With full FSDP, you should be able to set the batch size to at least 8 on a TPU v3-8 (with FSDP, the batch size should always be at least the size of the FSDP axis). The bottleneck in your case is most likely memory: storing the weights and optimizer states takes over 84GB, which leaves very little room for activations. You might need a larger pod to realize the higher throughput.
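For reference, a back-of-the-envelope sketch of where that 84GB figure comes from (assuming fp32 weights plus AdamW's two moment buffers; the exact EasyLM allocations will differ somewhat):

# Rough memory estimate for LLaMA-7B + AdamW in fp32 (illustrative only).
params = 7e9                  # ~7B parameters
weights = params * 4          # fp32 weights: ~28 GB
adam_m_v = 2 * params * 4     # AdamW first/second moments: ~56 GB
total = weights + adam_m_v    # ~84 GB before gradients and activations

hbm = 8 * 16e9                # TPU v3-8: 8 cores x 16 GB HBM = 128 GB
print(f"{total / 1e9:.0f} GB of {hbm / 1e9:.0f} GB HBM")  # 84 GB of 128 GB
# The remaining ~44 GB must hold gradients, activations, and XLA scratch
# space across all cores, which is why batch sizes above 1 OOM here.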