From 1295171d9982fbfe25eca2755edbe90f790e9186 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 14 Aug 2025 22:36:55 -0700 Subject: [PATCH 1/2] [ET-VK] Enable IntxWeightOnlyConfig ## Motivation Be able to test Vulkan lowering via optimum-executorch. ## Context Very similar to the below PR, Int4 weight only quantization is currently enabled in Vulkan via a custom source transform quantizer that replaces linear layers with a custom linear layer that calls a custom weight only quantized linear op. This diff aims to make it so that no Vulkan specific source transforms need to be applied by adding a fusion pattern for weight only quantized linear. ## Changes * Introduce a fusable graph pattern for weight only quantized linear * Add fusion logic for weight only quantized linear in the fuse patterns pass * Add `4w` qmode to the export llama script Differential Revision: [D80293302](https://our.internmc.facebook.com/intern/diff/D80293302/) [ghstack-poisoned] --- backends/vulkan/_passes/TARGETS | 2 + backends/vulkan/_passes/fuse_patterns.py | 187 ++++++++++++++++++ .../_passes/int4_weight_only_quantizer.py | 3 + .../vulkan/partitioner/vulkan_partitioner.py | 1 + backends/vulkan/patterns/TARGETS | 1 + backends/vulkan/patterns/__init__.py | 4 + backends/vulkan/patterns/quantized_linear.py | 117 +++++++++++ backends/vulkan/test/test_vulkan_delegate.py | 88 ++++++++- backends/vulkan/vulkan_preprocess.py | 3 + examples/models/llama/export_llama_lib.py | 2 +- .../llama/source_transformation/quantize.py | 15 ++ extension/llm/export/config/llm_config.py | 8 +- 12 files changed, 428 insertions(+), 3 deletions(-) create mode 100644 backends/vulkan/patterns/quantized_linear.py diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 3263d273b72..2ccd48eb5d1 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -126,7 +126,9 @@ runtime.python_library( ], deps = [ "//caffe2:torch", + "//executorch/backends/transforms:utils", "//executorch/backends/vulkan/patterns:vulkan_patterns", + "//executorch/backends/vulkan:utils_lib", "//executorch/exir:lib", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", diff --git a/backends/vulkan/_passes/fuse_patterns.py b/backends/vulkan/_passes/fuse_patterns.py index b320dc973a0..3a5694385ed 100644 --- a/backends/vulkan/_passes/fuse_patterns.py +++ b/backends/vulkan/_passes/fuse_patterns.py @@ -8,8 +8,12 @@ from typing import Callable, List, Optional import executorch.backends.vulkan.patterns as vk_patterns +import executorch.backends.vulkan.utils as utils import torch +import torch.nn.functional as F + +from executorch.backends.transforms.utils import get_param_tensor, is_param_node from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops @@ -108,6 +112,182 @@ def create_rotary_emb_custom_op( xk_out.replace_all_uses_with(getitem_1) +## +## Quantized Linear +## + + +def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor: + """ + Given a 8-bit weight tensor containing values quantized to 4 bits, create a packed + weight tensor by packing 2 4-bit values in one unsigned 8-bit value. + + An input weight tensor of shape (M, K) will produce a packed weight tensor of shape + (M, K / 2). + """ + + # Assert we got a properly quantized tensor. 
+ min, max = inp.min().item(), inp.max().item() + assert ( + max <= 7 and min >= -8 + ), f"pack_4bit_weight_tensor: [min,max] out of [-8, 7] range, got [{min}, {max}]" + + # Assuming we have a 2d tensor + if inp.ndim != 2: + inp = inp.squeeze() + assert ( + inp.ndim == 2 + ), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {inp.ndim}" + + # pad ic + if inp.shape[-1] % 2 != 0: + inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0) + + # Shape after padding + oc, ic = inp.shape + assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even" + + # Adjust inp tensor for zp + inp = inp.to(dtype=torch.uint8) + 8 + # Pack each 4-bit value into a single 8-bit value + return inp[::, ::2] << 4 | inp[::, 1::2] + + +def make_combined_scales_and_zeros_tensor( + scales: torch.Tensor, zeros: torch.Tensor +) -> torch.Tensor: + """ + Given a scales and zeros tensor, create a combined tensor by packing the values + into a single tensor. + + An input scales tensor of shape (M,) and zeros tensor of shape (K,) will produce a + combined tensor of shape (M, K). + """ + scales_reshaped = scales.transpose(0, 1).unsqueeze(2) + zeros_reshaped = zeros.transpose(0, 1).unsqueeze(2) + + zeros_scaled = zeros_reshaped * scales_reshaped * -1 + return torch.cat((scales_reshaped, zeros_scaled), dim=2) + + +def identify_wo_quantized_linear_io_nodes( # noqa: C901 + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: InternalMatch, +) -> Optional[List[torch.fx.Node]]: + dequant_node = None + # First, find the dequant node + for node in match.nodes_map.values(): + if utils.is_dequant_node(node): + dequant_node = node + break + + if dequant_node is None: + return None + + quantized_weight = dequant_node.args[0] + quant_scales = dequant_node.args[2] + quant_zeros = dequant_node.args[3] + + if not isinstance(quantized_weight, torch.fx.Node) or not is_param_node( + ep, quantized_weight + ): + return None + if not isinstance(quant_scales, torch.fx.Node) or not is_param_node( + ep, quant_scales + ): + return None + if not isinstance(quant_zeros, torch.fx.Node) or not is_param_node(ep, quant_zeros): + return None + + input_nodes = match.placeholder_nodes + if len(input_nodes) != 4: + return None + + in_tensor_node = None + for node in input_nodes: + if node not in dequant_node.args: + in_tensor_node = node + break + + if in_tensor_node is None: + return None + + output_nodes = match.returning_nodes + + if len(output_nodes) != 1: + return None + + out_tensor_node = output_nodes[0] + if not isinstance(out_tensor_node, torch.fx.Node): + return None + + return [ + in_tensor_node, + quantized_weight, + quant_scales, + quant_zeros, + out_tensor_node, + ] + + +# wo = "weight only" +def create_wo_quantized_linear_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: InternalMatch, +): + io_nodes = identify_wo_quantized_linear_io_nodes(ep, graph_module, match) + if io_nodes is None: + return + + assert len(io_nodes) == 5 + in_tensor, quantized_weight, quant_scales, quant_zeros, out_tensor = io_nodes + + quantized_weight_tensor = get_param_tensor(ep, quantized_weight) + if not isinstance(quantized_weight_tensor, torch.Tensor): + return + packed_quantized_weight_tensor = pack_4bit_weight_tensor(quantized_weight_tensor) + utils.update_program_state_dict( + ep, quantized_weight.name, packed_quantized_weight_tensor + ) + quantized_weight.meta["val"] = quantized_weight.meta["val"][:, ::2].to(torch.uint8) + + quant_scales_tensor = get_param_tensor(ep, quant_scales) + 
quant_zeros_tensor = get_param_tensor(ep, quant_zeros) + + assert quantized_weight_tensor is not None + assert quant_scales_tensor is not None + assert quant_zeros_tensor is not None + + group_size = quantized_weight_tensor.shape[1] // quant_scales_tensor.shape[1] + + combined_scales_zeros_tensor = make_combined_scales_and_zeros_tensor( + quant_scales_tensor, quant_zeros_tensor + ) + + combined_scales_zeros_name = f"{quantized_weight.name}_scales_zeros" + graph_module.register_parameter( + combined_scales_zeros_name, torch.nn.Parameter(combined_scales_zeros_tensor) + ) + + with graph_module.graph.inserting_before(out_tensor): + combined_scales_zeros = graph_module.graph.get_attr(combined_scales_zeros_name) + wo_qlinear = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.linear_weight_int4.default, + args=(in_tensor, quantized_weight, group_size, combined_scales_zeros, 1), + ) + + if hasattr(out_tensor, "meta") and "val" in out_tensor.meta: + wo_qlinear.meta["val"] = out_tensor.meta["val"] + + out_tensor.replace_all_uses_with(wo_qlinear) + + # Clean up dead code + graph_module.graph.eliminate_dead_code() + + class FusePatternsPass(ExportPass): def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() @@ -123,6 +303,13 @@ def call(self, graph_module: torch.fx.GraphModule): create_rotary_emb_custom_op, ) + total_replaced += fuse_pattern( + self.program, + graph_module, + vk_patterns.get_torchao_wo_quantized_linear_graphs(), + create_wo_quantized_linear_custom_op, + ) + if total_replaced > 0: graph_module.recompile() # Re-trace the graph diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py index 34ff5937822..22e7d2e40b7 100644 --- a/backends/vulkan/_passes/int4_weight_only_quantizer.py +++ b/backends/vulkan/_passes/int4_weight_only_quantizer.py @@ -246,6 +246,9 @@ def _create_quantized_state_dict( self.groupsize, self.precision, # dtype for scales_and_zeros ) + + print(w_int4x8.shape) + print(scales_and_zeros.shape) # If the packing of 2 4-bit values into a single 8-bit value was not # performed in the previous function call, then do it manually now. if w_int4x8.shape == weight.shape: diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index fa0cd107a3b..10235baf3ba 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -309,6 +309,7 @@ def get_fusable_subgraphs(graph_module: torch.fx.GraphModule) -> List[InternalMa fuse_patterns = [] fuse_patterns.extend(vk_patterns.get_rope_graphs()) + fuse_patterns.extend(vk_patterns.get_torchao_wo_quantized_linear_graphs()) for pattern in fuse_patterns: sm = SubgraphMatcher(pattern.graph, ignore_literals=True) diff --git a/backends/vulkan/patterns/TARGETS b/backends/vulkan/patterns/TARGETS index 7068799d02e..a213a1aef36 100644 --- a/backends/vulkan/patterns/TARGETS +++ b/backends/vulkan/patterns/TARGETS @@ -8,6 +8,7 @@ runtime.python_library( srcs = [ "__init__.py", "rope.py", + "quantized_linear.py", ], visibility = [ "//executorch/backends/...", diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index 189f01d67a6..d992bc6c06d 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -4,6 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from executorch.backends.vulkan.patterns.quantized_linear import ( + get_torchao_wo_quantized_linear_graphs, +) from executorch.backends.vulkan.patterns.rope import ( get_rope_graphs, RotaryEmbeddingPattern, @@ -11,6 +14,7 @@ __all__ = [ + "get_torchao_wo_quantized_linear_graphs", "get_rope_graphs", "RotaryEmbeddingPattern", ] diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py new file mode 100644 index 00000000000..f4673c98bb7 --- /dev/null +++ b/backends/vulkan/patterns/quantized_linear.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache +from typing import Callable, List, Optional + +import torch +from executorch.exir import EdgeCompileConfig, to_edge +from torch.export import export + +# Import torchao modules conditionally to avoid import errors during pattern matching +from torchao.quantization.granularity import PerGroup +from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ +from torchao.utils import unwrap_tensor_subclass + + +class TorchAOWeightOnlyQuantizedLinearPattern(torch.nn.Module): + """ + Quantized linear pattern produced when quantizing linear layers using + `torchao.quantization.quant_api.quantize_()` with IntxWeightOnlyConfig. + """ + + def __init__( + self, + in_features: int = 512, + out_features: int = 256, + bias: bool = False, + group_size: int = 64, + weight_bits: int = 4, + granularity_class: Optional[Callable] = None, + ) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=bias) + self.group_size = group_size + self.weight_bits = weight_bits + + if self.weight_bits == 4: + # pyre-ignore[16] + self.weight_dtype = torch.int4 + else: + self.weight_dtype = torch.int8 + + if granularity_class is not None: + self.quant_granularity = granularity_class(self.group_size) + else: + self.quant_granularity = PerGroup(self.group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + def apply_quantization(self): + q_config = IntxWeightOnlyConfig( + weight_dtype=self.weight_dtype, + granularity=self.quant_granularity, + ) + quantize_(self, q_config) + unwrap_tensor_subclass(self) + return self + + +@lru_cache(maxsize=None) +def get_torchao_wo_quantized_linear_graphs() -> List[torch.fx.GraphModule]: + graphs = [] + + # Different configurations to test + configs = [ + # gemv pattern + (1, 1, 128, 128, False, 64, 4, PerGroup), + # gemm pattern + (1, 8, 128, 128, False, 64, 4, PerGroup), + ] + + for ( + batch_size, + seq_len, + in_features, + out_features, + bias, + group_size, + weight_bits, + granularity_class, + ) in configs: + for dtype in [torch.float32]: + xs = [] + xs.append(torch.randn(batch_size, seq_len, in_features, dtype=dtype)) + if batch_size == 1: + xs.append(torch.randn(seq_len, in_features, dtype=dtype)) + + for x in xs: + # Create and quantize the pattern + pattern = TorchAOWeightOnlyQuantizedLinearPattern( + in_features=in_features, + out_features=out_features, + bias=bias, + group_size=group_size, + weight_bits=weight_bits, + granularity_class=granularity_class, + ) + + # Apply quantization + pattern = pattern.apply_quantization() + + # Export the quantized pattern + edge = to_edge( + export( + pattern, + (x,), + ), + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + gm = 
edge.exported_program().graph_module + graphs.append(gm) + + return graphs diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 6bf6a68090a..33536acb662 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -24,10 +24,13 @@ ExecutorchProgramManager, ) from torch.export import Dim, export, export_for_training, ExportedProgram +from torchao.quantization.granularity import PerGroup from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ +from torchao.utils import unwrap_tensor_subclass ctypes.CDLL("libvulkan.so.1") @@ -84,7 +87,7 @@ def quantize_and_lower_module( model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True ).module() - program = prepare_pt2e(program, quantizer) # pyre-ignore + program = prepare_pt2e(program, quantizer) # Calibrate program(*sample_inputs) @@ -2294,3 +2297,86 @@ def forward(self, x1, x2, x3, x4, x5, x6): dynamic_shapes=dynamic_shapes, test_inputs=test_inputs, ) + + def test_vulkan_backend_torchao_wo_quantized_linear(self): + in_features = 1024 + out_features = 512 + bias = False + group_size = 64 + weight_bits = 4 + + class TorchAOQuantizedLinearModule(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + group_size: int = 64, + weight_bits: int = 4, + ): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=bias) + self.group_size = group_size + self.weight_bits = weight_bits + + if self.weight_bits == 4: + self.weight_dtype = torch.int4 + else: + self.weight_dtype = torch.int8 + + self.quant_granularity = PerGroup(self.group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + def apply_quantization(self): + """Apply TorchAO weight-only quantization to the linear layer.""" + q_config = IntxWeightOnlyConfig( + weight_dtype=self.weight_dtype, + granularity=self.quant_granularity, + ) + quantize_(self, q_config) + unwrap_tensor_subclass(self) + return self + + # Test with GEMV pattern (batch_size=1, seq_len=1) + quantized_linear_module = TorchAOQuantizedLinearModule( + in_features=in_features, + out_features=out_features, + bias=bias, + group_size=group_size, + weight_bits=weight_bits, + ) + + # Apply quantization + quantized_linear_module = quantized_linear_module.apply_quantization() + + # Test with 2D input (GEMV pattern) + sample_inputs = (torch.randn(size=(1, in_features), dtype=torch.float32),) + + # Use higher tolerance since quantization introduces some error + self.lower_module_and_test_output( + quantized_linear_module, sample_inputs, atol=1e-2, rtol=1e-2 + ) + + # Test with GEMM pattern (batch_size > 1) + quantized_linear_module_gemm = TorchAOQuantizedLinearModule( + in_features=in_features, + out_features=out_features, + bias=bias, + group_size=group_size, + weight_bits=weight_bits, + ) + + # Apply quantization + quantized_linear_module_gemm = quantized_linear_module_gemm.apply_quantization() + + # Test with 3D input (GEMM pattern) + sample_inputs_gemm = ( + torch.randn(size=(1, 248, in_features), dtype=torch.float32), + ) + + # Use higher tolerance since quantization introduces some error + self.lower_module_and_test_output( + quantized_linear_module_gemm, sample_inputs_gemm, atol=1e-2, rtol=1e-2 + ) diff --git a/backends/vulkan/vulkan_preprocess.py 
b/backends/vulkan/vulkan_preprocess.py index 8c1165a89df..83c96790bab 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -149,6 +149,9 @@ def preprocess( # noqa: C901 program = unsafe_remove_auto_functionalized_pass(program) + print("\n\nVulkanBackend preprocess") + print(program.graph_module.graph) + print("VulkanBackend preprocess\n\n") # First, apply passes that fuse/remove operators to consolidate the graph # structure but still preserve an "ATen-compliant" graph structure (i.e. all # arguments to ATen operators must match the ATen function schema). diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 18700acade2..bced97beef0 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -789,7 +789,7 @@ def get_quantizer_and_quant_params(llm_config): def _qmode_type(value): - choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"] + choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w", "4w"] patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"] if value in choices: diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py index fed36c39081..0278bc6e912 100644 --- a/examples/models/llama/source_transformation/quantize.py +++ b/examples/models/llama/source_transformation/quantize.py @@ -165,6 +165,21 @@ def quantize( # noqa C901 q_group_size = 256 if group_size is None else group_size model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model) + return model + elif qmode == "4w": + from torchao.quantization.granularity import PerGroup + from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + from torchao.utils import unwrap_tensor_subclass + + q_group_size = 256 if group_size is None else group_size + q_config = IntxWeightOnlyConfig( + # pyre-ignore[16] + weight_dtype=torch.int4, + granularity=PerGroup(q_group_size), + ) + quantize_(model, q_config) + model = unwrap_tensor_subclass(model) + return model else: raise Exception(f"Unrecognized quantize mode: {qmode}") diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index de5564cae4f..8f8646e88cc 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -315,7 +315,13 @@ class QuantizationConfig: """ # Constants. - QMODE_OPTIONS: ClassVar[List[str]] = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"] + QMODE_OPTIONS: ClassVar[List[str]] = [ + "int8", + "8da4w", + "8da4w-gptq", + "vulkan_4w", + "4w", + ] AO_QUANT_PATTERNS: ClassVar[List[str]] = [ r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w", From c98ea5c2da6d6059f54190c13111d9e86edaadca Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 14 Aug 2025 22:46:55 -0700 Subject: [PATCH 2/2] Update on "[ET-VK] Enable IntxWeightOnlyConfig" ## Motivation Be able to test Vulkan lowering via optimum-executorch. ## Context Very similar to the below PR, Int4 weight only quantization is currently enabled in Vulkan via a custom source transform quantizer that replaces linear layers with a custom linear layer that calls a custom weight only quantized linear op. This diff aims to make it so that no Vulkan specific source transforms need to be applied by adding a fusion pattern for weight only quantized linear. 
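For reference, the torchao flow that produces this pattern looks roughly as follows (a minimal sketch mirroring the new pattern and test code in this diff; the wrapper model and group size below are illustrative):

```python
import torch
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.utils import unwrap_tensor_subclass

# Illustrative model; any nn.Linear inside a module is handled the same way.
model = torch.nn.Sequential(torch.nn.Linear(1024, 512, bias=False))

# 4-bit weight-only quantization with per-group granularity (group size 64 here).
quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(64)))

# Replace tensor subclasses with plain tensors so that torch.export produces a
# dequantize -> linear subgraph, which is what the new fusion pattern matches.
unwrap_tensor_subclass(model)
```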
## Changes
* Introduce a fusable graph pattern for weight only quantized linear
* Add fusion logic for weight only quantized linear in the fuse patterns pass
* Add `4w` qmode to the export llama script

Differential Revision: [D80293302](https://our.internmc.facebook.com/intern/diff/D80293302/)

[ghstack-poisoned]
---
 backends/vulkan/_passes/fuse_patterns.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/backends/vulkan/_passes/fuse_patterns.py b/backends/vulkan/_passes/fuse_patterns.py
index 3a5694385ed..3e1770ffaa7 100644
--- a/backends/vulkan/_passes/fuse_patterns.py
+++ b/backends/vulkan/_passes/fuse_patterns.py
@@ -284,9 +284,6 @@ def create_wo_quantized_linear_custom_op(
 
     out_tensor.replace_all_uses_with(wo_qlinear)
 
-    # Clean up dead code
-    graph_module.graph.eliminate_dead_code()
-
 
 class FusePatternsPass(ExportPass):
     def __init__(self, exported_program: ExportedProgram) -> None:
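For readers of the `pack_4bit_weight_tensor` helper added to fuse_patterns.py in this stack: two 4-bit values are stored per byte, with the even column in the high nibble and the odd column in the low nibble, after shifting values from [-8, 7] into [0, 15]. A minimal round-trip sketch of that layout (the `unpack_4bit_weight_tensor` helper below is illustrative only and not part of this diff):

```python
import torch

def unpack_4bit_weight_tensor(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_4bit_weight_tensor: split each byte into two 4-bit values
    # and undo the +8 offset that mapped [-8, 7] onto [0, 15].
    high = (packed >> 4).to(torch.int8) - 8
    low = (packed & 0x0F).to(torch.int8) - 8
    # Even columns were packed into the high nibble, odd columns into the low nibble.
    return torch.stack((high, low), dim=-1).reshape(packed.shape[0], -1)

# Round-trip check on a small 4-bit-quantized weight tensor.
w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
shifted = w.to(torch.uint8) + 8  # wraps negatives into [0, 15]
packed = (shifted[:, ::2] << 4) | shifted[:, 1::2]
assert torch.equal(unpack_4bit_weight_tensor(packed), w)
```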