⚙️ Your current environment
The output of python collect_env.py
### Environment Information ###
Operating System: `Linux-6.14.0-1019-gcp-x86_64-with-glibc2.39`
Python Version: `3.11.13 (main, Jun 5 2025, 13:12:00) [GCC 11.2.0]`
llm-compressor Version: `0.8.1`
compressed-tensors Version: `0.12.2`
transformers Version: `4.57.1`
torch Version: `2.8.0`
CUDA Devices: `['NVIDIA L4', 'NVIDIA L4']`
AMD Devices: `None`
🐛 Describe the bug
I can't run AWQ compression on the GPU. The same script runs fine on the CPU, but it fails when the model is placed on the GPU.
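Note that the script below loads the model with `device_map=None`; the exact GPU placement used in the failing run isn't shown, so the following snippet is only a hypothetical illustration of a GPU-dispatched load with accelerate (the `accelerate/hooks.py` frames in the traceback suggest such a dispatch was active). It is an assumption, not a confirmed part of the reproduction:
# Hypothetical illustration only (assumption, not confirmed by the script below):
# loading with device_map="auto" dispatches layers across the two L4 GPUs via
# accelerate, which installs the pre-forward hooks visible in the traceback.
model = Qwen3VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",  # assumption: the GPU run uses accelerate dispatch
    torch_dtype="auto",
    local_files_only=True,
)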
🛠️ Steps to reproduce
Code
import base64
import argparse
from io import BytesIO
import torch
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modeling import replace_modules_for_calibration
# Load model.
model_id = "Qwen3-VL-8B-Instruct"
model = Qwen3VLForConditionalGeneration.from_pretrained(model_id, device_map=None,torch_dtype="auto",local_files_only=True)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
# for name, module in model.named_modules():
# print(name, ":", module)
# Oneshot arguments
DATASET_ID = "lmms-lab/flickr30k"
DATASET_SPLIT = "test[:12]"
NUM_CALIBRATION_SAMPLES = 12
MAX_SEQUENCE_LENGTH = 2048
# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42)
# Apply chat template and tokenize inputs.
def preprocess_and_tokenize(example):
    # preprocess
    buffered = BytesIO()
    example["image"].save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue())
    encoded_image_text = encoded_image.decode("utf-8")
    base64_qwen = f"data:image;base64,{encoded_image_text}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": base64_qwen},
                {"type": "text", "text": "What does the image show?"},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    # tokenize
    return processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
    )
ds = ds.map(preprocess_and_tokenize, remove_columns=ds.column_names)
# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}
# Recipe
recipe = AWQModifier(
    config_groups={
        "group_0": {
            "targets": ["Linear"],
            "weights": {
                "num_bits": 8,
                "type": "int",
                "strategy": "group",
                "observer": "mse",
                "group_size": 32,  # set desired group size here
                "symmetric": True,  # False
            },
            "input_activations": None,
            "output_activations": None,
        }
    },
    ignore=["re:.*lm_head", "re:.*visual.*"],  # , "re:.*mlp.gate$"
    duo_scaling=False,
)
# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
    sequential_targets=["Qwen3VLTextDecoderLayer"],
)
# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "http://images.cocodataset.org/train2017/000000231895.jpg",
            },
            {"type": "text", "text": "Please describe the animal in this image\n"},
        ],
    }
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[prompt],
    images=image_inputs,
    videos=video_inputs,
    padding=False,
    max_length=MAX_SEQUENCE_LENGTH,
    truncation=True,
    return_tensors="pt",
).to(model.device)
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")
# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W8A8_ViT-W4A8-LLM"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
Output
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|███████████████████████████████████████| 4/4 [00:00<00:00, 16.06it/s]
2025-11-25T11:04:25.259691+0000 | reset | INFO - Compression lifecycle reset
2025-11-25T11:04:25.277264+0000 | _create_default_logger | INFO - Logging all LLM Compressor modifier-level logs to sparse_logs/25-11-2025_11.04.25.log
2025-11-25T11:04:25.278124+0000 | from_modifiers | INFO - Creating recipe from modifiers
2025-11-25T11:04:25.433593+0000 | on_initialize | INFO - No AWQModifier.mappings provided, inferring from model...
2025-11-25T11:04:25.433998+0000 | get_layer_mappings_from_architecture | INFO - Architecture Qwen3VLForConditionalGeneration not found in mappings. Using default mappings: [AWQMapping(smooth_layer='re:.*input_layernorm$', balance_layers=['re:.*q_proj$', 're:.*k_proj$', 're:.*v_proj$']), AWQMapping(smooth_layer='re:.*v_proj$', balance_layers=['re:.*o_proj$']), AWQMapping(smooth_layer='re:.*post_attention_layernorm$', balance_layers=['re:.*gate_proj$', 're:.*up_proj$']), AWQMapping(smooth_layer='re:.*up_proj$', balance_layers=['re:.*down_proj$'])]
Resolving mapping 1/4 (0 skipped): : 36it [00:00, 1209.13it/s]
Resolving mapping 2/4 (35 skipped): : 36it [00:00, 2422.90it/s]
Resolving mapping 3/4 (0 skipped): : 36it [00:00, 1521.94it/s]
Resolving mapping 4/4 (0 skipped): : 36it [00:00, 2556.42it/s]
2025-11-25T11:04:25.518866+0000 | initialize | INFO - Compression lifecycle initialized for 1 modifiers
2025-11-25T11:04:25.519197+0000 | IndependentPipeline | INFO - Inferred `SequentialPipeline` for `AWQModifier`
Preparing cache: 100%|███████████████████████████████████████████████| 12/12 [00:02<00:00, 4.77it/s]
(1/37): Calibrating: 0%| | 0/12 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/llmcompressor/pipelines/sequential/helpers.py", line 73, in forward
outputs = forward_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<string>", line 19, in forward
File "Qwen3VLModel_8572268571825_autowrapped", line 33, in wrapped_2
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 1061, in get_image_features
image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 716, in forward
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 685, in fast_pos_embed_interpolate
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/accelerate/hooks.py", line 170, in new_forward
args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/compressed_tensors/utils/offload.py", line 574, in keep_onload_pre_forward
ret = original_pre_forward(self, module, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/accelerate/hooks.py", line 369, in pre_forward
return send_to_device(args, self.execution_device), send_to_device(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/accelerate/utils/operations.py", line 169, in send_to_device
return honor_type(
^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/accelerate/utils/operations.py", line 81, in honor_type
return type(obj)(generator)
^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/accelerate/utils/operations.py", line 170, in <genexpr>
tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/accelerate/utils/operations.py", line 153, in send_to_device
return tensor.to(device, non_blocking=non_blocking)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/user/qwen3_vl/main_table_data_gpu.py", line 98, in <module>
oneshot(
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/llmcompressor/entrypoints/oneshot.py", line 330, in oneshot
one_shot()
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/llmcompressor/entrypoints/oneshot.py", line 158, in __call__
self.apply_recipe_modifiers(
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/llmcompressor/entrypoints/oneshot.py", line 201, in apply_recipe_modifiers
pipeline(
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/llmcompressor/pipelines/independent/pipeline.py", line 45, in __call__
pipeline(model, dataloader, dataset_args)
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/llmcompressor/pipelines/sequential/pipeline.py", line 104, in __call__
subgraph.forward(model, **inputs)
File "/root/anaconda3/envs/vllm/lib/python3.11/site-packages/llmcompressor/pipelines/sequential/helpers.py", line 75, in forward
raise RuntimeError(
RuntimeError: Raised an exception during execution of the following code:
1
2 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_0")
3 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_1")
4 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_2")
5 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_3")
6 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_5")
7 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_4")
8 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_6")
9 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_7")
10 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_8")
11 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_9")
12 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_10")
13 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_11")
14
15 def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor, pixel_values : torch.Tensor, image_grid_thw : torch.Tensor):
16 wrapped_0 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_0(input_ids, None); wrapped_0 = None
17 wrapped_1 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_1(input_ids, None)
18 getitem = wrapped_1[0]; wrapped_1 = None
19 wrapped_2 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_2(image_grid_thw, None, input_ids, getitem, pixel_values); getitem = pixel_values = None
20 getitem_1 = wrapped_2[0]
21 getitem_2 = wrapped_2[1]
22 getitem_3 = wrapped_2[2]; getitem_3 = None
23 getitem_4 = wrapped_2[3]
24 getitem_5 = wrapped_2[4]; wrapped_2 = None
25 wrapped_3 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_3(getitem_1, input_ids, getitem_5, None, None, None); getitem_1 = getitem_5 = None
26 getitem_6 = wrapped_3[0]
27 getitem_7 = wrapped_3[1]
28 getitem_8 = wrapped_3[2]
29 getitem_9 = wrapped_3[3]; getitem_9 = None
30 getitem_10 = wrapped_3[4]; wrapped_3 = None
31 wrapped_5 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_5(getitem_6, attention_mask, None, image_grid_thw, input_ids, getitem_8, None, None, None); getitem_6 = image_grid_thw = input_ids = None
32 wrapped_6 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_4(None, getitem_8); wrapped_6 = None
33 wrapped_7 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_6(None, getitem_8); getitem_8 = None
34 wrapped_4 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_7(getitem_2, getitem_7, None, getitem_4, getitem_10, None); getitem_2 = getitem_7 = getitem_4 = getitem_10 = None
35 getitem_20 = wrapped_5[0]; getitem_20 = None
36 getitem_21 = wrapped_5[1]; getitem_21 = None
37 getitem_22 = wrapped_5[2]; getitem_22 = None
38 getitem_23 = wrapped_5[3]; getitem_23 = None
39 getitem_24 = wrapped_5[4]
40 getitem_25 = wrapped_5[5]; getitem_25 = None
41 getitem_26 = wrapped_5[6]; getitem_26 = None
42 getitem_27 = wrapped_5[7]; getitem_27 = None
43 getitem_28 = wrapped_5[8]; wrapped_5 = getitem_28 = None
44 getitem_29 = wrapped_7[0]; wrapped_7 = None
45 getitem_11 = wrapped_4[0]
46 getitem_12 = wrapped_4[1]; getitem_12 = None
47 getitem_13 = wrapped_4[2]; getitem_13 = None
48 getitem_14 = wrapped_4[3]; getitem_14 = None
49 getitem_15 = wrapped_4[4]; getitem_15 = None
50 getitem_16 = wrapped_4[5]; getitem_16 = None
51 getitem_17 = wrapped_4[6]; getitem_17 = None
52 getitem_18 = wrapped_4[7]; getitem_18 = None
53 getitem_19 = wrapped_4[8]; wrapped_4 = None
54 wrapped_8 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_8(None, getitem_29, None)
55 getitem_30 = wrapped_8[0]
56 getitem_31 = wrapped_8[1]; wrapped_8 = getitem_31 = None
57 wrapped_9 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_9(getitem_30, getitem_29, getitem_24); getitem_24 = None
58 getitem_32 = wrapped_9[0]; wrapped_9 = None
59 wrapped_10 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_10(getitem_32); getitem_32 = None
60 getitem_33 = wrapped_10[0]
61 getitem_34 = wrapped_10[1]; wrapped_10 = None
62 model_language_model_rotary_emb = self.model.language_model.rotary_emb(getitem_29, getitem_33); getitem_33 = None
63 wrapped_11 = transformers_models_qwen3_vl_modeling_qwen3_vl_wrapped_11(attention_mask, getitem_30, getitem_29, None, getitem_34); attention_mask = None
64 model_language_model_layers_0 = getattr(self.model.language_model.layers, "0")(getitem_29, attention_mask = wrapped_11, position_ids = getitem_34, past_key_values = None, cache_position = getitem_30, position_embeddings = model_language_model_rotary_emb); getitem_29 = None
65 return {'getitem_11': getitem_11, 'getitem_19': getitem_19, 'getitem_30': getitem_30, 'getitem_34': getitem_34, 'model_language_model_rotary_emb': model_language_model_rotary_emb, 'wrapped_11': wrapped_11, 'model_language_model_layers_0': model_language_model_layers_0}
66
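For reference, the error is raised when accelerate's pre-forward hook tries to move the inputs of the vision tower's `pos_embed` module to its execution device and hits a meta tensor ("Cannot copy out of meta tensor; no data!"). A small diagnostic sketch (hypothetical, not part of the run above) to check whether anything is still on the meta device before calling oneshot:
# Hypothetical diagnostic, not part of the original run: list parameters/buffers that
# are still on the meta device before oneshot(), since send_to_device() fails on them.
meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"]
meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"]
print(f"meta parameters: {len(meta_params)}, meta buffers: {len(meta_buffers)}")
print(meta_params[:10])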