diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh index efe3ebd764b..d9e527e7c78 100644 --- a/.ci/scripts/test_llama.sh +++ b/.ci/scripts/test_llama.sh @@ -237,7 +237,7 @@ if [[ "${CUSTOM}" == "ON" ]]; then EXPORT_ARGS="${EXPORT_ARGS} model.use_sdpa_with_kv_cache=true" fi if [[ "${QE}" == "ON" ]]; then - EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,1024\"" + EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,768\"" fi if [[ "${MPS}" == "ON" ]]; then EXPORT_ARGS="${EXPORT_ARGS} backend.mps.enabled=true model.enable_dynamic_shape=false debug.verbose=true" diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 48edc3c0669..af2fa3c74ee 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -23,7 +23,6 @@ from executorch.exir.backend.utils import format_delegated_graph from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes import MemoryPlanningPass -from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.extension.export_util.utils import save_pte_program @@ -211,9 +210,7 @@ def main() -> None: executorch_program = edge_manager.to_executorch( ExecutorchBackendConfig( extract_delegate_segments=True, - passes=[ - QuantFusionPass(), - ], + do_quant_fusion_and_const_prop=True, memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), ) diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py index 8b76b7650fe..7cb65833f98 100644 --- a/examples/models/llama/source_transformation/quantize.py +++ b/examples/models/llama/source_transformation/quantize.py @@ -595,19 +595,16 @@ def __init__( @torch.no_grad() def create_quantized_state_dict(self, packed=False) -> Dict: + from torchao.quantization.granularity import PerAxis, PerGroup + from torchao.quantization.quant_api import ( + IntxWeightOnlyConfig, + MappingType, + quantize_, + ) + cur_state_dict = self.mod.state_dict() - if self.bitwidth == 2: - range_min = -2 - range_max = 1 - elif self.bitwidth == 4: - range_min = -8 - range_max = 7 - elif self.bitwidth == 8: - range_min = -128 - range_max = 127 - else: - raise ValueError(f"Unsupported bitwidth {self.bitwidth}") + assert self.bitwidth in [2, 4, 8], f"Unsupported bitwidth {self.bitwidth}" for fqn, mod in self.mod.named_modules(): if isinstance(mod, nn.Embedding): @@ -619,18 +616,22 @@ def create_quantized_state_dict(self, packed=False) -> Dict: print( f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}" ) - weight, scales, _ = dynamically_quantize_per_channel( - ( - mod.weight.to(dtype=self.precision) - if self.precision - else mod.weight + tmp_model = nn.Embedding(mod.weight.shape[0], mod.weight.shape[1]) + if self.precision: + tmp_model = tmp_model.to(dtype=self.precision) + tmp_model.weight = nn.Parameter(mod.weight) + config = IntxWeightOnlyConfig( + weight_dtype=getattr(torch, f"int{self.bitwidth}"), + granularity=( + PerAxis(0) + if (self.group_size is None or self.group_size == 0) + else PerGroup(self.group_size) ), - range_min, - range_max, - torch.int8, - self.group_size, - scales_dtype=mod.weight.dtype, + mapping_type=MappingType.SYMMETRIC, ) + quantize_(tmp_model, config, lambda m, fqn: isinstance(m, nn.Embedding)) + weight = tmp_model.weight.qdata # pyre-ignore[16] + scales = tmp_model.weight.scale # pyre-ignore[16] if packed: if self.bitwidth == 2: diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py index 54ef522047d..e1678b089b8 100644 --- a/exir/passes/_quant_patterns_and_replacements.py +++ b/exir/passes/_quant_patterns_and_replacements.py @@ -986,25 +986,54 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax): ] -def _get_embedding_ops_patterns_and_replacements_torchao() -> ( # noqa C901 - List[Tuple[Callable, Callable, List[Callable]]] -): +def _get_embedding_ops_patterns_and_replacements_torchao( # noqa C901 + node_value_dict, +) -> List[Tuple[Callable, Callable, List[Callable]]]: + + def get_embedding_replacement_filter(has_nonzero_zero_point): + def _filter(match, original_graph, pattern_graph): + assert node_value_dict is not None, "node_value_dict cannot be None" + + def get_val(name): + node = [n for n in match.nodes_map if n.name == name][0] + val = match.nodes_map[node] + if isinstance(val, torch.fx.Node) and val.target in node_value_dict: + return node_value_dict[val.target] + return val + + zero_point = get_val("zero_point") + all_zero = (zero_point == 0).all().item() + if has_nonzero_zero_point: + return not all_zero + else: + return all_zero + + return _filter + def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point): dq = torch.ops.torchao.dequantize_affine.default( int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127 ) return torch.ops.aten.embedding.default(dq, indices) - def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point): - zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) - return torch.ops.quantized_decomposed.embedding_byte.default( - int_data, - scale, - zero_point_dtype_cast, - -128, - 127, - indices, - ) + def get_embedding_byte_replacement(has_nonzero_zero_point): + def embedding_byte_replacement( + indices, int_data, group_size, scale, zero_point + ): + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) + zero_point_dtype_cast = ( + zero_point_dtype_cast if has_nonzero_zero_point else None + ) + return torch.ops.quantized_decomposed.embedding_byte.default( + int_data, + scale, + zero_point_dtype_cast, + -128, + 127, + indices, + ) + + return embedding_byte_replacement def embedding_byte_dtype_pattern( indices, int_data, group_size, scale, zero_point, output_dtype @@ -1021,19 +1050,25 @@ def embedding_byte_dtype_pattern( ) return torch.ops.aten.embedding.default(dq, indices) - def embedding_byte_dtype_replacement( - indices, int_data, group_size, scale, zero_point, output_dtype - ): - zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) - return torch.ops.quantized_decomposed.embedding_byte.dtype( - int_data, - scale, - zero_point_dtype_cast, - -128, - 127, - indices, - dtype=output_dtype, - ) + def get_embedding_byte_dtype_replacement(has_nonzero_zero_point): + def embedding_byte_dtype_replacement( + indices, int_data, group_size, scale, zero_point, output_dtype + ): + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) + zero_point_dtype_cast = ( + zero_point_dtype_cast if has_nonzero_zero_point else None + ) + return torch.ops.quantized_decomposed.embedding_byte.dtype( + int_data, + scale, + zero_point_dtype_cast, + -128, + 127, + indices, + dtype=output_dtype, + ) + + return embedding_byte_dtype_replacement def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point): dq = torch.ops.torchao.dequantize_affine.default( @@ -1041,14 +1076,22 @@ def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point): ) return torch.ops.aten.embedding.default(dq, indices) - def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point): - packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( - int_data, 2 - ) - zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) - return torch.ops.quantized_decomposed.embedding_2bit.default( - packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices - ) + def get_embedding_2bit_replacement(has_nonzero_zero_point): + def embedding_2bit_replacement( + indices, int_data, group_size, scale, zero_point + ): + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( + int_data, 2 + ) + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) + zero_point_dtype_cast = ( + zero_point_dtype_cast if has_nonzero_zero_point else None + ) + return torch.ops.quantized_decomposed.embedding_2bit.default( + packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices + ) + + return embedding_2bit_replacement def embedding_2bit_dtype_pattern( indices, int_data, group_size, scale, zero_point, output_dtype @@ -1065,22 +1108,28 @@ def embedding_2bit_dtype_pattern( ) return torch.ops.aten.embedding.default(dq, indices) - def embedding_2bit_dtype_replacement( - indices, int_data, group_size, scale, zero_point, output_dtype - ): - packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( - int_data, 2 - ) - zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) - return torch.ops.quantized_decomposed.embedding_2bit.dtype( - packed_int_data, - scale, - zero_point_dtype_cast, - -2, - 1, - indices, - dtype=output_dtype, - ) + def get_embedding_2bit_dtype_replacement(has_nonzero_zero_point): + def embedding_2bit_dtype_replacement( + indices, int_data, group_size, scale, zero_point, output_dtype + ): + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( + int_data, 2 + ) + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) + zero_point_dtype_cast = ( + zero_point_dtype_cast if has_nonzero_zero_point else None + ) + return torch.ops.quantized_decomposed.embedding_2bit.dtype( + packed_int_data, + scale, + zero_point_dtype_cast, + -2, + 1, + indices, + dtype=output_dtype, + ) + + return embedding_2bit_dtype_replacement def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point): dq = torch.ops.torchao.dequantize_affine.default( @@ -1088,14 +1137,22 @@ def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point): ) return torch.ops.aten.embedding.default(dq, indices) - def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point): - packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( - int_data, 4 - ) - zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) - return torch.ops.quantized_decomposed.embedding_4bit.default( - packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices - ) + def get_embedding_4bit_replacement(has_nonzero_zero_point): + def embedding_4bit_replacement( + indices, int_data, group_size, scale, zero_point + ): + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( + int_data, 4 + ) + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) + zero_point_dtype_cast = ( + zero_point_dtype_cast if has_nonzero_zero_point else None + ) + return torch.ops.quantized_decomposed.embedding_4bit.default( + packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices + ) + + return embedding_4bit_replacement def embedding_4bit_dtype_pattern( indices, int_data, group_size, scale, zero_point, output_dtype @@ -1112,53 +1169,97 @@ def embedding_4bit_dtype_pattern( ) return torch.ops.aten.embedding.default(dq, indices) - def embedding_4bit_dtype_replacement( - indices, int_data, group_size, scale, zero_point, output_dtype - ): - packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( - int_data, 4 - ) - zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) - return torch.ops.quantized_decomposed.embedding_4bit.dtype( - packed_int_data, - scale, - zero_point_dtype_cast, - -8, - 7, - indices, - dtype=output_dtype, - ) + def get_embedding_4bit_dtype_replacement(has_nonzero_zero_point): + def embedding_4bit_dtype_replacement( + indices, int_data, group_size, scale, zero_point, output_dtype + ): + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( + int_data, 4 + ) + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) + zero_point_dtype_cast = ( + zero_point_dtype_cast if has_nonzero_zero_point else None + ) + return torch.ops.quantized_decomposed.embedding_4bit.dtype( + packed_int_data, + scale, + zero_point_dtype_cast, + -8, + 7, + indices, + dtype=output_dtype, + ) + + return embedding_4bit_dtype_replacement return [ ( _trace_and_lower_to_edge_ops(embedding_byte_pattern), - _trace_and_lower_to_edge_ops(embedding_byte_replacement), - [], + _trace_and_lower_to_edge_ops(get_embedding_byte_replacement(False)), + [get_embedding_replacement_filter(has_nonzero_zero_point=False)], + ), + ( + _trace_and_lower_to_edge_ops(embedding_byte_pattern), + _trace_and_lower_to_edge_ops(get_embedding_byte_replacement(True)), + [get_embedding_replacement_filter(has_nonzero_zero_point=True)], ), ( _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern), - _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement), - [], + _trace_and_lower_to_edge_ops(get_embedding_byte_dtype_replacement(False)), + [get_embedding_replacement_filter(has_nonzero_zero_point=False)], + ), + ( + _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern), + _trace_and_lower_to_edge_ops(get_embedding_byte_dtype_replacement(True)), + [get_embedding_replacement_filter(has_nonzero_zero_point=True)], ), ( _trace_and_lower_to_edge_ops(embedding_2bit_pattern), - _trace_and_lower_to_edge_ops(embedding_2bit_replacement), - [], + _trace_and_lower_to_edge_ops(get_embedding_2bit_replacement(False)), + [get_embedding_replacement_filter(has_nonzero_zero_point=False)], + ), + ( + _trace_and_lower_to_edge_ops(embedding_2bit_pattern), + _trace_and_lower_to_edge_ops(get_embedding_2bit_replacement(True)), + [get_embedding_replacement_filter(has_nonzero_zero_point=True)], ), ( _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern), - _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement), - [], + _trace_and_lower_to_edge_ops(get_embedding_2bit_dtype_replacement(False)), + [get_embedding_replacement_filter(has_nonzero_zero_point=False)], + ), + ( + _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern), + _trace_and_lower_to_edge_ops(get_embedding_2bit_dtype_replacement(True)), + [get_embedding_replacement_filter(has_nonzero_zero_point=True)], ), ( _trace_and_lower_to_edge_ops(embedding_4bit_pattern), - _trace_and_lower_to_edge_ops(embedding_4bit_replacement), - [], + _trace_and_lower_to_edge_ops( + get_embedding_4bit_replacement(has_nonzero_zero_point=False) + ), + [get_embedding_replacement_filter(has_nonzero_zero_point=False)], + ), + ( + _trace_and_lower_to_edge_ops(embedding_4bit_pattern), + _trace_and_lower_to_edge_ops( + get_embedding_4bit_replacement(has_nonzero_zero_point=True) + ), + [get_embedding_replacement_filter(has_nonzero_zero_point=True)], ), ( _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern), - _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement), - [], + _trace_and_lower_to_edge_ops( + get_embedding_4bit_dtype_replacement(has_nonzero_zero_point=False) + ), + [get_embedding_replacement_filter(has_nonzero_zero_point=False)], + ), + ( + _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern), + _trace_and_lower_to_edge_ops( + get_embedding_4bit_dtype_replacement(has_nonzero_zero_point=True) + ), + [get_embedding_replacement_filter(has_nonzero_zero_point=True)], ), ] @@ -1445,9 +1546,9 @@ def replacement(x, x_scale, x_zero_point, x_qmin, x_qmax): """ -def get_quant_patterns_and_replacements() -> ( - List[Tuple[Callable, Callable, List[Callable]]] -): +def get_quant_patterns_and_replacements( + node_value_dict, +) -> List[Tuple[Callable, Callable, List[Callable]]]: return copy.copy( [ @@ -1457,6 +1558,6 @@ def get_quant_patterns_and_replacements() -> ( *_get_slice_patterns_and_replacements(), # *_get_fixed_qparams_ops_patterns_and_replacements(), *_get_embedding_ops_patterns_and_replacements(), - *_get_embedding_ops_patterns_and_replacements_torchao(), + *_get_embedding_ops_patterns_and_replacements_torchao(node_value_dict), ] ) diff --git a/exir/passes/quant_fusion_pass.py b/exir/passes/quant_fusion_pass.py index 6941fc65229..b46b34f1d19 100644 --- a/exir/passes/quant_fusion_pass.py +++ b/exir/passes/quant_fusion_pass.py @@ -9,6 +9,8 @@ from executorch.exir.pass_base import ExportPass from executorch.exir.passes.constant_prop_pass import constant_prop_pass from torch.export import ExportedProgram +from torch.export.exported_program import InputKind +from torch.export.graph_signature import TensorArgument from torch.fx import GraphModule, subgraph_rewriter from torch.fx.passes.infra.pass_base import PassResult from torch.utils import _pytree as pytree @@ -104,11 +106,27 @@ def _remove_dtype_getattr_nodes(model: GraphModule) -> None: model.recompile() +def _get_node_value_dict(program): + """ + Returns a dict of real tensor values for buffers/parameters in the program + """ + node_value_dict = {} + for input_ in program.graph_signature.input_specs: + if ( + input_.kind in (InputKind.BUFFER, InputKind.PARAMETER) + and isinstance(input_.arg, TensorArgument) + and input_.target in program.state_dict + ): + node_value_dict[input_.arg.name] = program.state_dict[input_.target] + return node_value_dict + + class QuantFusionPass(ExportPass): - def __init__(self, _fix_node_meta_val=False): + def __init__(self, _fix_node_meta_val=False, node_value_dict=None): super().__init__() # TODO This pass violate IR spec because it produces a graph missing node.meta['val'] self._fix_node_meta_val = _fix_node_meta_val + self.node_value_dict = node_value_dict def call(self, graph_module: GraphModule) -> PassResult: """Lower a quantized reference model (with reference quantized operator patterns) @@ -124,7 +142,7 @@ def call(self, graph_module: GraphModule) -> PassResult: pattern, replacement, match_filters, - ) in get_quant_patterns_and_replacements(): + ) in get_quant_patterns_and_replacements(self.node_value_dict): subgraph_rewriter.replace_pattern_with_filters( graph_module, pattern, replacement, match_filters ) @@ -145,7 +163,10 @@ def call(self, graph_module: GraphModule) -> PassResult: def quant_fusion_and_const_prop_pass(program: ExportedProgram) -> ExportedProgram: gm = program.graph_module - gm_res = QuantFusionPass(_fix_node_meta_val=True)(gm) + node_value_dict = _get_node_value_dict(program) + gm_res = QuantFusionPass(_fix_node_meta_val=True, node_value_dict=node_value_dict)( + gm + ) gm = gm_res.graph_module # Do const prop pass to remove packing/dtype conversion ops diff --git a/exir/tests/test_quant_fusion_pass.py b/exir/tests/test_quant_fusion_pass.py index e3073197b2b..8622fca0bd8 100644 --- a/exir/tests/test_quant_fusion_pass.py +++ b/exir/tests/test_quant_fusion_pass.py @@ -14,6 +14,7 @@ from executorch.exir import EdgeCompileConfig, to_edge from executorch.exir.passes.constant_prop_pass import constant_prop_pass from executorch.exir.passes.quant_fusion_pass import ( + _get_node_value_dict, quant_fusion_and_const_prop_pass, QuantFusionPass, ) @@ -36,7 +37,7 @@ from torch.testing import FileCheck from torchao.quantization.granularity import PerAxis, PerGroup -from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ +from torchao.quantization.quant_api import IntxWeightOnlyConfig, MappingType, quantize_ from torchao.quantization.utils import compute_error @@ -383,13 +384,22 @@ def forward(self, indices): # ) def test_embedding_torchao(self) -> None: - for bit_width, use_dtype_variant, test_per_group in zip( - [2, 4, 8], [True, False], [True, False] + for bit_width, use_dtype_variant, test_per_group, mapping_type in zip( + [2, 4, 8], + [True, False], + [True, False], + [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], ): - self._test_embedding_torchao(bit_width, use_dtype_variant, test_per_group) + self._test_embedding_torchao( + bit_width, use_dtype_variant, test_per_group, mapping_type + ) def _test_embedding_torchao( - self, bit_width: int, use_dtype_variant: bool, test_per_group: bool + self, + bit_width: int, + use_dtype_variant: bool, + test_per_group: bool, + mapping_type: MappingType, ) -> None: assert bit_width in [2, 4, 8] embedding_suffix = f"{bit_width}bit" if bit_width < 8 else "byte" @@ -411,7 +421,9 @@ def _test_embedding_torchao( quantize_( model, IntxWeightOnlyConfig( - weight_dtype=getattr(torch, f"int{bit_width}"), granularity=granularity + weight_dtype=getattr(torch, f"int{bit_width}"), + granularity=granularity, + mapping_type=mapping_type, ), lambda m, fqn: isinstance(m, torch.nn.Embedding), ) @@ -439,7 +451,10 @@ def _test_embedding_torchao( m.exported_program().graph_module.code ) - m = m.transform([QuantFusionPass(_fix_node_meta_val=True)]) + node_value_dict = _get_node_value_dict(m.exported_program()) + m = m.transform( + [QuantFusionPass(_fix_node_meta_val=True, node_value_dict=node_value_dict)] + ) # After pass, we see packing op and quantized embedding op, but no torchao dequantize op FileCheck().check_count( @@ -458,6 +473,22 @@ def _test_embedding_torchao( constant_prop_pass(m.exported_program()) + found_embedding_node = False + seeking_suffix = embedding_suffix.replace("_", ".") + seeking = f"quantized_decomposed::embedding_{seeking_suffix}" + for node in m.exported_program().graph.nodes: + if node.op == "call_function" and node.target.name() == seeking: + found_embedding_node = True + if mapping_type == MappingType.SYMMETRIC: + assert ( + node.args[2] is None + ), f"Expected zero_point=None for symmetric quantization, but got {node.args[2]}" + else: + assert node.args[2] is not None + assert ( + found_embedding_node + ), f"Did not find embedding node with target {seeking}" + # After constant prop, we see quantized embedding op, but no packing op FileCheck().check_count( f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",