Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/compile/test_full_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,21 @@ def test_custom_compile_config(
run_model(compilation_config, model, model_kwargs)


@pytest.mark.parametrize(
"optimization_level",
[CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
)
def test_fp8_kv_scale_compile(optimization_level: int):
model = "Qwen/Qwen2-0.5B"
model_kwargs = {
"quantization": "fp8",
"kv_cache_dtype": "fp8_e4m3",
"calculate_kv_scales": True,
"max_model_len": 512,
}
run_model(optimization_level, model, model_kwargs)


def test_inductor_graph_partition_attn_fusion(caplog_vllm):
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
Expand Down
43 changes: 40 additions & 3 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,8 @@ def forward(
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name)

output_dtype = query.dtype
if self.query_quant is not None:
Expand Down Expand Up @@ -554,6 +553,44 @@ def maybe_save_kv_layer_to_connector(
attn_metadata[layer_name])


def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata

if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]

if attn_metadata is None or not getattr(
attn_metadata, 'enable_kv_scales_calculation', False):
return

self = forward_context.no_compile_layers[layer_name]
self.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
return


direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=["query", "key", "value"],
fake_impl=maybe_calc_kv_scales_fake,
)


def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
Expand Down
9 changes: 9 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,6 +2285,15 @@ def execute_model(
self.cudagraph_dispatcher.dispatch(batch_descriptor,
use_cascade_attn)

# Set cudagraph mode to none if calc_kv_scales is true.
if attn_metadata is not None:
metadata_list = (attn_metadata.values() if isinstance(
attn_metadata, dict) else [attn_metadata])
if any(
getattr(m, 'enable_kv_scales_calculation', False)
for m in metadata_list):
cudagraph_runtime_mode = CUDAGraphMode.NONE

# This is currently to get around the assert in the DPMetadata
# where it wants `num_tokens_across_dp` to align with `num_tokens`
if ubatch_slices is not None:
Expand Down