Closed
Labels: bug
Description
Your current environment
The output of python collect_env.py
==============================
System Info
==============================
OS : Red Hat Enterprise Linux release 8.10 (Ootpa) (x86_64)
GCC version : (GCC) 8.5.0 20210514 (Red Hat 8.5.0-26)
Clang version : Could not collect
CMake version : Could not collect
Libc version : glibc-2.28
==============================
PyTorch Info
==============================
PyTorch version : 2.9.0+cu128
Is debug build : False
CUDA used to build PyTorch : 12.8
ROCM used to build PyTorch : N/A
==============================
Python Environment
==============================
Python version : 3.10.19 (main, Oct 21 2025, 16:43:05) [GCC 11.2.0] (64-bit runtime)
Python platform : Linux-4.18.0-553.50.1.el8_10.x86_64-x86_64-with-glibc2.28
==============================
CUDA / GPU Info
==============================
Is CUDA available : True
CUDA runtime version : Could not collect
CUDA_MODULE_LOADING set to :
GPU models and configuration :
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
Nvidia driver version : 575.51.03
cuDNN version : Could not collect
HIP runtime version : N/A
MIOpen runtime version : N/A
Is XNNPACK available : True
==============================
CPU Info
==============================
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 8
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7742 64-Core Processor
Stepping: 0
CPU MHz: 2250.000
CPU max MHz: 2250.0000
CPU min MHz: 1500.0000
BogoMIPS: 4491.72
Virtualization: AMD-V
L1d cache: 32K
L1i cache: 32K
L2 cache: 512K
L3 cache: 16384K
NUMA node0 CPU(s): 0-15
NUMA node1 CPU(s): 16-31
NUMA node2 CPU(s): 32-47
NUMA node3 CPU(s): 48-63
NUMA node4 CPU(s): 64-79
NUMA node5 CPU(s): 80-95
NUMA node6 CPU(s): 96-111
NUMA node7 CPU(s): 112-127
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
==============================
Versions of relevant libraries
==============================
[pip3] flashinfer-python==0.5.2
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-frontend==1.16.0
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-cufile-cu12==1.13.1.3
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-cutlass-dsl==4.3.0
[pip3] nvidia-ml-py==13.580.82
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvshmem-cu12==3.3.20
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pyzmq==27.1.0
[pip3] torch==2.9.0
[pip3] torchaudio==2.9.0
[pip3] torchvision==0.24.0
[pip3] transformers==4.57.1
[pip3] triton==3.5.0
[conda] flashinfer-python 0.5.2 pypi_0 pypi
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cudnn-frontend 1.16.0 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-cufile-cu12 1.13.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-cutlass-dsl 4.3.0 pypi_0 pypi
[conda] nvidia-ml-py 13.580.82 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvshmem-cu12 3.3.20 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] pyzmq 27.1.0 pypi_0 pypi
[conda] torch 2.9.0 pypi_0 pypi
[conda] torchaudio 2.9.0 pypi_0 pypi
[conda] torchvision 0.24.0 pypi_0 pypi
[conda] transformers 4.57.1 pypi_0 pypi
[conda] triton 3.5.0 pypi_0 pypi
==============================
vLLM Info
==============================
ROCM Version : Could not collect
vLLM Version : 0.11.2.dev154+g57430fc95 (git sha: 57430fc95)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
      GPU0  GPU1  NIC0  NIC1  NIC2  NIC3  NIC4  CPU Affinity  NUMA Affinity  GPU NUMA ID
GPU0  X     NV12  PXB   PXB   SYS   SYS   SYS   16-31         1              N/A
GPU1  NV12  X     PXB   PXB   SYS   SYS   SYS   16-31         1              N/A
NIC0  PXB   PXB   X     PXB   SYS   SYS   SYS
NIC1  PXB   PXB   PXB   X     SYS   SYS   SYS
NIC2  SYS   SYS   SYS   SYS   X     SYS   SYS
NIC3  SYS   SYS   SYS   SYS   SYS   X     PIX
NIC4  SYS   SYS   SYS   SYS   SYS   PIX   X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
NIC4: mlx5_4
==============================
Environment Variables
==============================
LD_LIBRARY_PATH=/opt/software/slurm/24.11.5-1/lib
TORCHINDUCTOR_CACHE_DIR=/scratch/hpc-prf-haqc/vllm-compile-cache
CUDA_VISIBLE_DEVICES=0,1
TORCH_COMPILE_CACHE=/scratch/hpc-prf-haqc/haikai/vllm-cache/torch_compile_cache
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
🐛 Describe the bug
import io
import os
from typing import Any, Dict, List, Optional, Tuple

import torch
from PIL import Image

# load_vision_models_llava_next, extract_image_binary_from_pope_data,
# get_inputs_embeds, _generate_prompt and DEFAULT_SYSTEM_PROMPT are
# defined elsewhere in trim_pope_llavanext.py.


def execute_batch_pope_with_pruned_embeddings(
    modelname: str,
    fields: List[Dict[str, Any]],
    query: str,
    keep_ratio: float,
    typed_fields: List[Tuple[str, str]],
    reordered_columns: List[str],
    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
    guided_choice: Optional[List[str]] = None,
    base_url: Optional[str] = None,  # Not used in offline mode, but kept for signature compatibility
) -> Tuple[List[str], float]:
    # --- PHASE 1: Pruning / Embedding Generation (PyTorch) ---
    print("Phase 1: Loading Vision Model for Pruning...")
    vision_tower, model, processor = load_vision_models_llava_next(device='cuda')
    user_prompts = []
    all_pruned_embeddings = []
    batch_pruning_time = 0.0
    try:
        for field_dict in fields:
            user_prompt = ""
            pruned_embeddings_for_this_prompt = []
            for field_name in reordered_columns:
                # ... (Existing logic to find field type) ...
                field_type = next((ft for fn, ft in typed_fields if fn == field_name), None)
                if field_type == "text":
                    value = field_dict.get(field_name, "")
                    user_prompt += f"{field_name}: {value}\n"
                elif field_type == "image":
                    # IMPORTANT: vLLM needs the <image> token to know where to inject embeddings
                    user_prompt += f"{field_name}: <image>\n"
                    image_data = field_dict.get(field_name)
                    if image_data is not None:
                        image_binary = extract_image_binary_from_pope_data(image_data)
                        # Decode the image once and reuse it for the chat template and the processor
                        image = Image.open(io.BytesIO(image_binary))
                        messages = [
                            {
                                "role": "user",
                                "content": [{"type": "text", "text": ""}, {"type": "image", "image": image}],
                            }
                        ]
                        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                        # Process Image
                        inputs = processor(text=prompt, images=image, return_tensors="pt")
                        # --- PRUNING LOGIC ---
                        # (Your existing timing and pruning logic)
                        # reduced_tokens = model.prumerge(inputs)
                        reduced_tokens = get_inputs_embeds(model, inputs)
                        # --- CRITICAL FIX FOR VLLM ---
                        # Assumed: vLLM expects [Batch, Sequence, Hidden], while
                        # prumerge returns [Sequence, Hidden], e.g. [2642, 4096].
                        # Note: this line does NOT add a batch dimension; the tensor is
                        # stored with whatever shape get_inputs_embeds returns.
                        # Move to CPU to save GPU memory for vLLM later
                        pruned_embeddings_for_this_prompt.append(reduced_tokens.detach().cpu())
            user_prompts.append(user_prompt.strip())
            # Store this prompt's list of image embeddings (at most one is expected,
            # matching limit_mm_per_prompt={"image": 1}), or None if it has no image
            all_pruned_embeddings.append(
                pruned_embeddings_for_this_prompt if pruned_embeddings_for_this_prompt else None
            )
    finally:
        # --- MEMORY MANAGEMENT ---
        # We MUST destroy the PyTorch model to free up VRAM for vLLM
        print("Phase 1 Complete. Unloading PyTorch model to free VRAM...")
        del vision_tower
        del model
        del processor
        torch.cuda.empty_cache()

    # --- PHASE 2: Inference (vLLM) ---
    print("Phase 2: Loading vLLM Engine...")
    scratch_path = "/scratch/hpc-prf-haqc"
    user_path = f"{scratch_path}/haikai"
    # Set the cache locations before importing vLLM
    os.environ["HF_HOME"] = f"{user_path}/hf-cache"
    os.environ["VLLM_CACHE_DIR"] = f"{user_path}/vllm-cache"
    os.environ["TORCH_COMPILE_CACHE"] = f"{user_path}/vllm-cache/torch_compile_cache"
    os.environ["TRANSFORMERS_CACHE"] = f"{user_path}/hf-cache"
    # Triton/Inductor caches (critical for vLLM performance/compilation)
    os.environ["TRITON_CACHE_DIR"] = f"{scratch_path}/vllm-compile-cache"
    os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{scratch_path}/vllm-compile-cache"
    # -----------------------------------------------------------------------
    # 2. NOW IMPORT AND INITIALIZE VLLM
    # -----------------------------------------------------------------------
    # It is safe to import vllm now that the paths are set
    from vllm import LLM as VllmEngine
    from vllm import SamplingParams

    # Initialize
    llm = VllmEngine(
        model=modelname,
        limit_mm_per_prompt={"image": 1},
        trust_remote_code=True,
        gpu_memory_utilization=0.5,
        enforce_eager=True,
        enable_mm_embeds=True,
        download_dir=os.environ["HF_HOME"],  # Explicitly enforce it here too just in case
    )

    # Configure Sampling Params
    sampling_args = {
        "temperature": 0.0,
        "max_tokens": 512,
    }
    # Guided decoding is not wired up here yet: guided_choice is currently unused
    sampling_params = SamplingParams(**sampling_args)

    # Extract Text
    final_results = []
    for i, user_prompt in enumerate(user_prompts):
        # Generate full prompt (System + User)
        full_prompt = _generate_prompt(user_prompt=user_prompt, system_prompt=system_prompt)
        input_item = {
            "prompt": full_prompt,
        }
        # Inject Embeddings if they exist
        if all_pruned_embeddings[i] is not None:
            # This is the list of tensors collected in Phase 1 (one per image), not a bare tensor
            embed_tensor = all_pruned_embeddings[i]
            input_item["multi_modal_data"] = {
                "image": embed_tensor
            }
        outputs = llm.generate(input_item, sampling_params=sampling_params)
        for output in outputs:
            if output.outputs:
                final_results.append(output.outputs[0].text)
            else:
                final_results.append("")
    return final_results, batch_pruning_time
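The prompt builder above writes one <image> per image field, but only attaches an embedding when image_data is not None, and the engine is created with limit_mm_per_prompt={"image": 1}. A mismatch between placeholders and supplied image items is an easy way to get confusing processor errors, so a cheap pre-flight check on the final prompt can help rule it out. This is a hypothetical helper (not a vLLM API), assuming the literal "<image>" string survives _generate_prompt unchanged.

```python
IMAGE_PLACEHOLDER = "<image>"  # placeholder string the prompts above insert per image field


def check_image_placeholders(prompt: str, num_image_items: int, limit: int = 1) -> int:
    """Hypothetical pre-flight check (not part of vLLM).

    Returns the number of placeholders in `prompt`; raises ValueError if that
    number differs from the image items supplied, or exceeds the per-prompt
    limit (mirroring the limit_mm_per_prompt setting used above).
    """
    count = prompt.count(IMAGE_PLACEHOLDER)
    if count != num_image_items:
        raise ValueError(
            f"prompt has {count} {IMAGE_PLACEHOLDER} placeholder(s) "
            f"but {num_image_items} image item(s) were supplied"
        )
    if count > limit:
        raise ValueError(f"{count} images exceed limit_mm_per_prompt={{'image': {limit}}}")
    return count
```

For example, calling check_image_placeholders(full_prompt, len(all_pruned_embeddings[i] or [])) just before llm.generate would flag a field whose image was missing but whose placeholder was still written.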
I am trying to pass precomputed image embeddings to vLLM through its image_embeds input feature, using LLaVA v1.6 (LLaVA-NeXT).
When the input embedding is a 2D tensor, I get:
File "/scratch/hpc-prf-haqc/haikai/LLM-Multimodal/trim_pope_llavanext.py", line 391, in llm_udf_embedding_batch
outputs, batch_pruning_time = execute_batch_pope_with_pruned_embeddings(
File "/scratch/hpc-prf-haqc/haikai/LLM-Multimodal/trim_pope_llavanext.py", line 229, in execute_batch_pope_with_pruned_embeddings
outputs = llm.generate(input_item, sampling_params=sampling_params)
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/entrypoints/llm.py", line 439, in generate
self._validate_and_add_requests(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/entrypoints/llm.py", line 1612, in _validate_and_add_requests
raise e
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/entrypoints/llm.py", line 1612, in _validate_and_add_requests
raise e
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/entrypoints/llm.py", line 1699, in _add_request
engine_request, tokenization_kwargs = self._process_inputs(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/entrypoints/llm.py", line 1679, in _process_inputs
engine_request = self.processor.process_inputs(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/v1/engine/processor.py", line 442, in process_inputs
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/inputs/preprocess.py", line 699, in preprocess
res = self._preprocess(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/inputs/preprocess.py", line 685, in _preprocess
return self._process_decoder_only_prompt(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/inputs/preprocess.py", line 654, in _process_decoder_only_prompt
prompt_comps = self._prompt_to_llm_inputs(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/inputs/preprocess.py", line 425, in _prompt_to_llm_inputs
return self._process_text(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/inputs/preprocess.py", line 378, in _process_text
inputs = self._process_multimodal(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/inputs/preprocess.py", line 270, in _process_multimodal
mm_input = mm_processor.apply(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/multimodal/processing.py", line 2079, in apply
) = self._cached_apply_hf_processor(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/multimodal/processing.py", line 1858, in _cached_apply_hf_processor
) = self._apply_hf_processor_main(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/multimodal/processing.py", line 1595, in _apply_hf_processor_main
mm_processed_data = self._apply_hf_processor_mm_only(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/multimodal/processing.py", line 1553, in _apply_hf_processor_mm_only
_, mm_processed_data, _ = self._apply_hf_processor_text_mm(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/multimodal/processing.py", line 1480, in _apply_hf_processor_text_mm
processed_data = self._call_hf_processor(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/multimodal/processing.py", line 1440, in _call_hf_processor
return self.info.ctx.call_hf_processor(
File "/scratch/hpc-prf-haqc/haikai/vllm/vllm/multimodal/processing.py", line 1093, in call_hf_processor
raise ValueError(msg) from exc
ValueError: Failed to apply LlavaNextProcessor on data={'text': '<image>', 'images': [tensor([[[-4.5471e-03, 3.5286e-04, -5.1575e-03, ..., 4.0770e-05,
-1.0300e-03, -1.2779e-04],
[-4.4861e-03, -5.1975e-05, -1.8234e-03, ..., 2.4128e-04,
4.0588e-03, 4.3678e-04],
[ 1.5068e-04, 3.4142e-04, -2.4261e-03, ..., -2.5787e-03,
3.2501e-03, -2.8687e-03],
...,
[-8.0490e-04, 1.2512e-03, -6.4850e-04, ..., 3.5667e-04,
-4.1809e-03, 1.8692e-04],
[ 1.5068e-04, 3.4142e-04, -2.4261e-03, ..., -2.5787e-03,
3.2501e-03, -2.8687e-03],
[-4.0283e-03, -1.8082e-03, 3.7384e-03, ..., -2.1839e-04,
-1.0376e-03, 3.2959e-03]]], dtype=torch.float16)]} with kwargs={'truncation': False}
But if I add reduced_tokens = reduced_tokens.unsqueeze(0) to turn it back into a 3D tensor, I get:
(EngineCore_DP0 pid=547579) AssertionError: Expected multimodal embeddings to be a sequence of 2D tensors,
but got tensors with shapes [torch.Size([1, 2155, 4096])] instead. This is most likely due to incorrect implementation of the model's `embed_multimodal` method.
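The AssertionError above says vLLM wants the multimodal embeddings as a sequence of 2D tensors. A minimal sketch of normalizing the collected embeddings into that shape before calling llm.generate, assuming (not verified against this vLLM build) that multi_modal_data={"image": [...]} with enable_mm_embeds=True accepts a list of [num_tokens, hidden] tensors, one per image, and that reduced_tokens holds only the image feature tokens rather than the whole prompt's inputs_embeds. to_image_embeds_list is a hypothetical helper, not a vLLM API.

```python
from typing import List, Optional, Sequence, Union

import torch


def to_image_embeds_list(
    embeds: Union[torch.Tensor, Sequence[torch.Tensor]],
    hidden_size: Optional[int] = None,
) -> List[torch.Tensor]:
    """Hypothetical helper: flatten embeddings into 2D [num_tokens, hidden] tensors, one per image.

    A 3D tensor is treated as [num_images, num_tokens, hidden] and split along
    dim 0; 2D tensors are kept as-is. Optionally checks the last dim against
    the language model's hidden size.
    """
    if isinstance(embeds, torch.Tensor):
        embeds = [embeds]
    items: List[torch.Tensor] = []
    for t in embeds:
        if t.ndim == 3:
            # e.g. [1, 2155, 4096] -> one [2155, 4096] tensor per image
            items.extend(t.unbind(0))
        elif t.ndim == 2:
            items.append(t)
        else:
            raise ValueError(f"expected a 2D or 3D tensor, got shape {tuple(t.shape)}")
    if hidden_size is not None:
        for t in items:
            if t.shape[-1] != hidden_size:
                raise ValueError(
                    f"last dim {t.shape[-1]} does not match the LM hidden size {hidden_size}"
                )
    return items
```

In the function above, this would replace the raw list in Phase 2, e.g. input_item["multi_modal_data"] = {"image": to_image_embeds_list(all_pruned_embeddings[i], hidden_size=4096)}, where 4096 is the hidden size seen in the shapes above. Whether the token count must also match what vLLM expands <image> into for LLaVA-NeXT is not checked here.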
Is there a complete example showing how to pass image-embedding inputs to LLaVA 1.6 or Qwen2.5-VL?