[Bug]: Dynamic FP8 Marlin quantization fails on 0.5.4 #7216

Closed
mgoin opened this issue Aug 6, 2024 · 2 comments · Fixed by #7219
Labels
bug Something isn't working

Comments

@mgoin
Collaborator

mgoin commented Aug 6, 2024

Your current environment

PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.2
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-107-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.5.82
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 550.54.15
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             128
On-line CPU(s) list:                0-127
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8462Y+
CPU family:                         6
Model:                              143
Thread(s) per core:                 2
Core(s) per socket:                 32
Socket(s):                          2
Stepping:                           8
BogoMIPS:                           5600.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          3 MiB (64 instances)
L1i cache:                          2 MiB (64 instances)
L2 cache:                           128 MiB (64 instances)
L3 cache:                           120 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46,48,50,52,54,56,58,60,62,64,66,68,70,72,74,76,78,80,82,84,86,88,90,92,94,96,98,100,102,104,106,108,110,112,114,116,118,120,122,124,126
NUMA node1 CPU(s):                  1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47,49,51,53,55,57,59,61,63,65,67,69,71,73,75,77,79,81,83,85,87,89,91,93,95,97,99,101,103,105,107,109,111,113,115,117,119,121,123,125,127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.44.0
[pip3] triton==3.0.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.4
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    NIC6    NIC7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV12    NV12    NV12    NV12    NV12    NV12    NV12    PXB     SYS     SYS     SYS     SYS     SYS     SYS     SYS     0,2,4,6,8,10    0               N/A
GPU1    NV12     X      NV12    NV12    NV12    NV12    NV12    NV12    SYS     PXB     SYS     SYS     SYS     SYS     SYS     SYS     0,2,4,6,8,10    0               N/A
GPU2    NV12    NV12     X      NV12    NV12    NV12    NV12    NV12    SYS     SYS     PXB     SYS     SYS     SYS     SYS     SYS     0,2,4,6,8,10    0               N/A
GPU3    NV12    NV12    NV12     X      NV12    NV12    NV12    NV12    SYS     SYS     SYS     PXB     SYS     SYS     SYS     SYS     0,2,4,6,8,10    0               N/A
GPU4    NV12    NV12    NV12    NV12     X      NV12    NV12    NV12    SYS     SYS     SYS     SYS     PXB     SYS     SYS     SYS     1,3,5,7,9,11    1               N/A
GPU5    NV12    NV12    NV12    NV12    NV12     X      NV12    NV12    SYS     SYS     SYS     SYS     SYS     PXB     SYS     SYS     1,3,5,7,9,11    1               N/A
GPU6    NV12    NV12    NV12    NV12    NV12    NV12     X      NV12    SYS     SYS     SYS     SYS     SYS     SYS     PXB     SYS     1,3,5,7,9,11    1               N/A
GPU7    NV12    NV12    NV12    NV12    NV12    NV12    NV12     X      SYS     SYS     SYS     SYS     SYS     SYS     SYS     PXB     1,3,5,7,9,11    1               N/A
NIC0    PXB     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS     SYS     SYS     SYS     SYS     SYS     SYS
NIC1    SYS     PXB     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS     SYS     SYS     SYS     SYS     SYS
NIC2    SYS     SYS     PXB     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS     SYS     SYS     SYS     SYS
NIC3    SYS     SYS     SYS     PXB     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS     SYS     SYS     SYS
NIC4    SYS     SYS     SYS     SYS     PXB     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS     SYS     SYS
NIC5    SYS     SYS     SYS     SYS     SYS     PXB     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS     SYS
NIC6    SYS     SYS     SYS     SYS     SYS     SYS     PXB     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS
NIC7    SYS     SYS     SYS     SYS     SYS     SYS     SYS     PXB     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7

🐛 Describe the bug

Dynamic FP8 quantization works fine on H100 but fails on A100. The failure is in the dynamic FP8 Marlin backend, which vLLM falls back to for weight-only FP8 on GPUs without native FP8 support.
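
For context, a minimal sketch of why the two GPUs take different code paths. It assumes that native FP8 matmul needs compute capability 8.9 or higher, so H100 (9.0) qualifies and A100 (8.0) does not. The helper `fp8_path` and the threshold constant are hypothetical names for illustration, not vLLM APIs.

```python
# Hypothetical helper, not part of vLLM: maps a CUDA compute capability
# to the FP8 execution path one would expect.

# Assumption: FP8 tensor cores start at compute capability 8.9.
NATIVE_FP8_MIN_CAPABILITY = (8, 9)

# Compute capabilities of the two GPUs named in this report.
CAPABILITIES = {
    "H100": (9, 0),
    "A100": (8, 0),
}


def fp8_path(capability):
    """Return 'native' when FP8 hardware is expected, else 'marlin-weight-only'."""
    # Tuples compare element by element, so (8, 0) < (8, 9) < (9, 0).
    if capability >= NATIVE_FP8_MIN_CAPABILITY:
        return "native"
    return "marlin-weight-only"


for name, capability in CAPABILITIES.items():
    print(name, capability, fp8_path(capability))
```

On a real machine the capability tuple can be read with `torch.cuda.get_device_capability()` instead of the hard-coded table.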

vllm serve meta-llama/Meta-Llama-3-8B-Instruct --quantization="fp8" --port 9000 
...
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils.py", line 172, in marlin_permute_scales
    s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
RuntimeError: shape '[-1, 32]' is invalid for input of size 1
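
The error message reports an input of size 1, which suggests the scale being permuted is a single per-tensor value rather than one value per output channel. The sketch below mimics the reshape-then-permute step with plain Python lists so it runs without torch. The permutation, the layer width, and the "expand to per-channel" workaround are all assumptions for illustration; this is not the vLLM implementation and not necessarily what the fix in #7219 does.

```python
def reshape_rows(flat, width):
    """Mimic reshape((-1, width)) on a flat list: the length must be a multiple of width."""
    if len(flat) % width != 0:
        raise ValueError(
            f"shape '[-1, {width}]' is invalid for input of size {len(flat)}"
        )
    return [flat[i:i + width] for i in range(0, len(flat), width)]


def permute_scales(flat, perm):
    """Reshape into rows of len(perm), then reorder the columns of each row by perm."""
    rows = reshape_rows(flat, len(perm))
    return [[row[j] for j in perm] for row in rows]


# Placeholder permutation of length 32; the real scale_perm_single has a different order.
perm = list(reversed(range(32)))

# A single per-tensor scale, matching the "input of size 1" in the traceback.
per_tensor_scale = [0.02]
try:
    permute_scales(per_tensor_scale, perm)
except ValueError as err:
    print(err)

# Assumed workaround: repeat the scale once per output channel before permuting.
out_channels = 64  # hypothetical layer width, chosen as a multiple of 32
per_channel_scales = per_tensor_scale * out_channels
permuted = permute_scales(per_channel_scales, perm)
print(len(permuted), len(permuted[0]))
```

The point of the sketch is only the shape arithmetic: a length-1 scale cannot be split into rows of 32, while a per-channel scale whose length is a multiple of 32 can.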

On A100, it does work fine with models that are already quantized to FP8:

vllm serve neuralmagic/Meta-Llama-3-8B-Instruct-FP8 --quantization="fp8" --port 9000
...
INFO:     Uvicorn running on http://0.0.0.0:9000 (Press CTRL+C to quit)

Full command and output/stacktrace:

vllm serve meta-llama/Meta-Llama-3-8B-Instruct --quantization="fp8" --port 9000     
INFO 08-06 19:27:35 api_server.py:339] vLLM API server version 0.5.4
INFO 08-06 19:27:35 api_server.py:340] args: Namespace(model_tag='meta-llama/Meta-Llama-3-8B-Instruct', host=None, port=9000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, model='meta-llama/Meta-Llama-3-8B-Instruct', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=None, guided_decoding_backend='outlines', distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, disable_sliding_window=False, use_v2_block_manager=False, num_lookahead_slots=0, seed=0, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=256, max_logprobs=20, disable_log_stats=False, quantization='fp8', rope_scaling=None, rope_theta=None, enforce_eager=False, max_context_len_to_capture=None, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, 
num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, engine_use_ray=False, disable_log_requests=False, max_log_len=None, dispatch_function=<function serve at 0x7f5440e617e0>)
WARNING 08-06 19:27:35 config.py:1454] Casting torch.bfloat16 to torch.float16.
INFO 08-06 19:27:35 llm_engine.py:174] Initializing an LLM engine (v0.5.4) with config: model='meta-llama/Meta-Llama-3-8B-Instruct', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=meta-llama/Meta-Llama-3-8B-Instruct, use_v2_block_manager=False, enable_prefix_caching=False)
INFO 08-06 19:27:38 model_runner.py:720] Starting to load model meta-llama/Meta-Llama-3-8B-Instruct...
INFO 08-06 19:27:39 weight_utils.py:225] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:00,  7.55it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:00<00:00,  2.52it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  1.97it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:01<00:00,  1.81it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:01<00:00,  2.02it/s]

WARNING 08-06 19:27:41 utils.py:578] Your GPU does not have native support for FP8 computation but FP8 quantization is being used. Weight-only FP8 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
Process Process-1:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/entrypoints/openai/rpc/server.py", line 217, in run_rpc_server
    server = AsyncEngineRPCServer(async_engine_args, usage_context, port)
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/entrypoints/openai/rpc/server.py", line 25, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 471, in from_engine_args
    engine = cls(
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 381, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 552, in _init_engine
    return engine_class(*args, **kwargs)
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 249, in __init__
    self.model_executor = executor_class(
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 47, in __init__
    self._init_executor()
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 36, in _init_executor
    self.driver_worker.load_model()
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/worker/worker.py", line 139, in load_model
    self.model_runner.load_model()
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 722, in load_model
    self.model = get_model(model_config=self.model_config,
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/model_executor/model_loader/__init__.py", line 21, in get_model
    return loader.load_model(model_config=model_config,
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 344, in load_model
    quant_method.process_weights_after_loading(module)
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/fp8.py", line 212, in process_weights_after_loading
    prepare_fp8_layer_for_marlin(layer)
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py", line 80, in prepare_fp8_layer_for_marlin
    marlin_scales = marlin_permute_scales(s=scales,
  File "/home/mgoin/venvs/vllm-rel/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils.py", line 172, in marlin_permute_scales
    s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
RuntimeError: shape '[-1, 32]' is invalid for input of size 1
@garycaokai

Same problem: 0.5.4 works on L20 but fails on A30.

@fozziethebeat

Oh! I was just about to file this. I was running the Docker image vllm/vllm-openai:v0.5.4 and hit the same problem.
