
Advice/expectations on throughput #61

Closed

float-trip opened this issue Jun 17, 2023 · 2 comments

@float-trip
Hello!

I'm looking into fine-tuning LLaMA-7b with EasyLM on a TPU v3-8. From my initial runs, I'm getting around 975 tokens/sec. I've tested all the flag combinations I can think of, but I'm unable to increase the batch size or gradient accumulation steps beyond 1 without OOMing.

I saw that you achieved a high throughput of 2,200 tokens/sec per TPU v4 chip on OpenLLaMA-7b, and mesh-transformer-jax gets around 5k tokens/sec on a v3-8 for GPT-J, so I was curious whether there's an issue in my config.

Here's how I'm running it:

# Removed "jax_enable_async_all_gather", as it causes a crash on a v3-8. Without these flags, the throughput is 590 tokens/sec.
export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'

python -m EasyLM.models.llama.llama_train \
    --dtype='fp32' \ # bf16 causes errors - is it only intended for serving?
    --mesh_dim='1,-1,1' \
    --load_llama_config='7b' \
    --optimizer.type='adamw' \
    --train_dataset.json_dataset.seq_length=2048 \
    --train_dataset.json_dataset.batch_size=1 \
    # ... omitting other flags which shouldn't affect throughput
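For reference, here's my rough understanding of what mesh_dim='1,-1,1' resolves to on this hardware - just a sanity check I put together with plain JAX, assuming EasyLM orders the mesh axes as (dp, fsdp, mp), which I haven't verified:

# Sanity check: what does mesh_dim='1,-1,1' resolve to on a v3-8?
# Assumes EasyLM orders the mesh axes as (dp, fsdp, mp) - unverified.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

n_devices = jax.device_count()    # 8 on a TPU v3-8
mesh_shape = (1, n_devices, 1)    # the -1 axis absorbs all remaining devices
devices = mesh_utils.create_device_mesh(mesh_shape)
mesh = Mesh(devices, axis_names=("dp", "fsdp", "mp"))
print(mesh.shape)                 # dp=1, fsdp=8, mp=1

So, if I'm reading it right, all 8 cores end up on the FSDP axis.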

Do you have any tips? Or is higher throughput only expected on larger TPU pods?

Thanks!

@young-geng
Owner

With full FSDP, you should be able to set the batch size to at least 8 on a TPU v3-8 (for FSDP, the batch size should always be at least the size of the FSDP axis). The bottleneck in your case is most likely memory: storing the weights and optimizer states takes over 84GB, which leaves very little memory for activations. You might need a larger pod to realize the high throughput.
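A rough back-of-the-envelope sketch of that memory figure (assuming fp32 weights and the usual two fp32 AdamW state tensors; the exact numbers depend on your config):

# Rough memory budget for LLaMA-7B + AdamW on a TPU v3-8.
# Assumes fp32 weights and two fp32 AdamW state tensors (m and v).
n_params = 7e9
bytes_per_fp32 = 4

weights_gb = n_params * bytes_per_fp32 / 1e9             # ~28 GB
adamw_states_gb = 2 * n_params * bytes_per_fp32 / 1e9    # ~56 GB
params_plus_opt_gb = weights_gb + adamw_states_gb        # ~84 GB

v3_8_hbm_gb = 8 * 16                                     # 16 GB HBM per v3 core
remaining_gb = v3_8_hbm_gb - params_plus_opt_gb          # ~44 GB for everything else

print(f"weights {weights_gb:.0f} GB + optimizer {adamw_states_gb:.0f} GB "
      f"= {params_plus_opt_gb:.0f} GB, leaving ~{remaining_gb:.0f} GB on a v3-8")

That remaining ~44GB has to hold gradients, activations, and XLA temporaries across the whole slice, which is why the batch size can't grow much.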

@float-trip
Author

That makes sense, I appreciate the explanation.
