CUDA graph compilation #154

Merged
merged 66 commits into main from cuda-graph
Jan 4, 2024

Conversation

@tgaddair (Contributor) commented Jan 2, 2024

This PR adds support for compiling the model into a static CUDA graph. See Accelerating PyTorch with CUDA Graphs for more details on CUDA graphs and how they can reduce latency.

To enable this (experimental) feature:

lorax-launcher ... --compile
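
For background, the capture-and-replay pattern that underlies this feature looks roughly like the following (a minimal standalone PyTorch sketch, not LoRAX code; the model and shapes are placeholders):

```python
import torch

# Placeholder model and a static input buffer; CUDA graphs require input
# shapes (and tensor addresses) to stay fixed across replays.
model = torch.nn.Linear(1024, 1024).cuda()
static_input = torch.randn(8, 1024, device="cuda")

# Warm up on a side stream so lazy kernel/cuBLAS init happens outside capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy fresh data into the captured input buffer, then relaunch
# the entire kernel sequence with a single call.
static_input.copy_(torch.randn(8, 1024, device="cuda"))
graph.replay()
result = static_output.clone()
```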

There is a tradeoff to be aware of when using CUDA graphs: they increase memory overhead by 3-10 GB depending on model size. However, the observed decrease in latency can be as much as 50%, so if you don't need very large batch sizes and are more constrained by latency than by throughput, this is a very compelling feature to enable.

In practice, CUDA graphs are most useful when there are spare GPU FLOPs available, as is the case during decoding. As such, we do not use the compiled version of the model during prefill, only during the decoding steps. This means the benefits of enabling compilation will be most pronounced when generating longer sequences (for which more time is spent decoding).

Current limitations:

  • Batch size < 256
  • LoRA rank >= 8 and <= 32
  • Only one LoRA rank in the batch

If any of these conditions are not met, LoRAX will fall back to eager execution for the batch, as sketched below.
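
To make the dispatch concrete, a guard along these lines decides per batch whether the compiled graph can be replayed (a hypothetical sketch; `use_cuda_graph`, `graph_runner`, and the `batch` fields are invented for illustration, not the actual LoRAX internals):

```python
MAX_BATCH_SIZE = 256
MIN_RANK, MAX_RANK = 8, 32

def use_cuda_graph(batch) -> bool:
    # Mirror the limitations above: the captured graph only supports small
    # batches with at most one LoRA rank, within the supported rank range.
    ranks = {adapter.rank for adapter in batch.adapters}
    return (
        batch.size < MAX_BATCH_SIZE
        and len(ranks) <= 1
        and all(MIN_RANK <= r <= MAX_RANK for r in ranks)
    )

def decode_step(model, graph_runner, batch):
    if use_cuda_graph(batch):
        return graph_runner.replay(batch)  # compiled CUDA graph path
    return model.forward(batch)            # eager fallback
```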

Thanks to the folks on the Punica team for updating their kernels to support graph tracing. Additionally, we modified the kernels to support padding with -1 (necessary because CUDA graphs require input shapes to be constant across batches).
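
Conceptually, the padding looks something like this (an illustrative sketch; the helper name and tensor layout are assumptions, not the modified Punica API):

```python
import torch

PAD = -1  # sentinel value the kernels were updated to skip over

def pad_to_static(indices: torch.Tensor, static_size: int) -> torch.Tensor:
    """Pad a per-request adapter-index tensor out to the captured graph's
    fixed batch size, filling unused slots with -1 so shapes stay constant
    across replays while the kernel ignores the padded entries."""
    padded = torch.full(
        (static_size,), PAD, dtype=indices.dtype, device=indices.device
    )
    padded[: indices.numel()] = indices
    return padded

# e.g. a batch of 3 requests replayed through a graph captured at size 8:
idx = torch.tensor([0, 0, 1], device="cuda")
print(pad_to_static(idx, 8))  # tensor([ 0,  0,  1, -1, -1, -1, -1, -1], device='cuda:0')
```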

Comparison:

gpt2-medium, time to generate 100 tokens:

no adapter

  • baseline: 1.044 s
  • cuda graph: 0.422 s

1 adapter (rank 16)

  • baseline: 1.503 s
  • cuda graph: 0.583 s

@tgaddair tgaddair marked this pull request as ready for review January 2, 2024 21:50
@tgaddair tgaddair merged commit f20789d into main Jan 4, 2024
1 check passed
@tgaddair tgaddair deleted the cuda-graph branch January 4, 2024 16:40