Torch.compile Performance Tracking #1008

Open

merrymercy (Contributor) opened this issue Aug 9, 2024 · 0 comments

merrymercy commented Aug 9, 2024

torch.compile can accelerate decoding at small batch sizes for llama-3 8B. However, it is sometimes slower at large batch sizes or with tensor parallelism. We use this issue to track the performance and potential fixes.
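
For context, `--enable-torch-compile` conceptually wraps the model's decode forward pass in `torch.compile`. The sketch below is only an illustration of that idea, not sglang's actual integration; the linear layer and shapes are placeholders:

import torch
import torch.nn as nn

# Stand-in for a decode-step forward pass; sglang compiles the real model graph.
model = nn.Linear(4096, 4096, device="cuda", dtype=torch.float16)

# mode="reduce-overhead" replays the compiled graph via CUDA graphs, which is
# why torch.compile helps most at small batch sizes, where kernel launch
# overhead dominates the decode step.
compiled = torch.compile(model, mode="reduce-overhead")

x = torch.randn(1, 4096, device="cuda", dtype=torch.float16)  # bs=1 decode-like input
with torch.no_grad():
    for _ in range(3):  # warmup iterations trigger compilation and graph capture
        compiled(x)
    out = compiled(x)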

Instructions and results

# Benchmark llama-3-8B (TP=1, bs=1) with cuda graph
# Decode.  median latency: 0.00737 s, median throughput:    135.64 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8

# Benchmark llama-3-8B (TP=1, bs=1) with torch.compile
# Decode.  median latency: 0.00642 s, median throughput:    155.67 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --enable-torch-compile


# Benchmark llama-3-8B (TP=1, bs=128) with cuda graph
# Decode.  median latency: 0.01184 s, median throughput:  10815.07 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 128 --input 128 --output 8

# Benchmark llama-3-8B (TP=1, bs=128) with torch.compile
# Decode.  median latency: 0.01231 s, median throughput:  10401.75 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 128 --input 128 --output 8 --enable-torch-compile


# Benchmark llama-3-8B (TP=8, bs=1) with cuda graph
# Decode.  median latency: 0.00335 s, median throughput:    298.53 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --tp 8

# Benchmark llama-3-8B (TP=8, bs=1) with torch.compile
# Decode.  median latency: 0.00351 s, median throughput:    284.51 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --tp 8 --enable-torch-compile


# Benchmark llama-3-70B (TP=8, bs=1) with cuda graph
# Decode.  median latency: 0.01220 s, median throughput:     82.00 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-70B --batch-size 1 --input 128 --output 8 --tp 8

# Benchmark llama-3-70B (TP=8, bs=1) with torch.compile
# Decode.  median latency: 0.01211 s, median throughput:     82.57 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-70B --batch-size 1 --input 128 --output 8 --tp 8 --enable-torch-compile
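
Summary of the median decode throughput from the runs above:

Config                        CUDA graph          torch.compile
llama-3-8B   TP=1  bs=1         135.64 token/s      155.67 token/s
llama-3-8B   TP=1  bs=128     10815.07 token/s    10401.75 token/s
llama-3-8B   TP=8  bs=1         298.53 token/s      284.51 token/s
llama-3-70B  TP=8  bs=1          82.00 token/s       82.57 token/s

torch.compile wins at TP=1, bs=1 and is roughly neutral for llama-3-70B at TP=8; it is slower at bs=128 and for llama-3-8B at TP=8.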

Environment

python3 -m sglang.check_env

GPU 0,1,2,3,4,5,6,7: NVIDIA H100 80GB HBM3
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0

NVCC: Cuda compilation tools, release 12.3, V12.3.107
CUDA Driver Version: 545.23.08

PyTorch: 2.4.0+cu121
flashinfer: 0.1.6+cu121torch2.4
triton: 3.0.0
vllm: 0.5.5
NVIDIA Topology: mostly NV18

commit: 79ece2c51f47ee6b792c6282a6f76987892c5f8d (Fri Aug 30)
merrymercy changed the title from "Torch.compile Performance Track" to "Torch.compile Performance Tracking" on Aug 10, 2024