diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 0e015418d42..c4334443f23 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -682,6 +682,10 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
                 args.enable_dynamic_shape,
             )
         )
+        # Apply XNNPACK after Vulkan so that undelegated ops can be accelerated by XNNPACK
+        partitioners.append(
+            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
+        )
         modelname = f"vulkan_{modelname}"

     if args.mps:
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index d168b7efcdc..f8952ad0e53 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -157,7 +157,17 @@ def quantize(  # noqa C901
         model = gptq_quantizer.quantize(model, inputs)
         return model
     elif qmode == "vulkan_4w":
-        model = VkInt4WeightOnlyQuantizer().quantize(model)
+        q_group_size = 256 if group_size is None else group_size
+        model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model)
+
+        # Apply additional quantizer for linear layers that aren't lowered to Vulkan
+        # at the moment
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+
+        model = Int8DynActInt4WeightQuantizer(
+            precision=torch_dtype, groupsize=q_group_size
+        ).quantize(model)
+
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
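
Background on the partitioner change above: partitioners are applied in list order, so nodes the Vulkan partitioner does not claim remain available for the XNNPACK partitioner that is appended after it. The snippet below is a minimal, self-contained sketch of that fallback idea under assumed semantics; the names (vulkan_supports, xnnpack_supports, assign_backends) are hypothetical and are not the ExecuTorch API.

# Sketch only: illustrates first-match fallback across an ordered list of
# partitioners, the reason XNNPACK is appended after Vulkan in the patch.

def vulkan_supports(op: str) -> bool:
    # Hypothetical support set for a Vulkan-like delegate.
    return op in {"linear", "add"}

def xnnpack_supports(op: str) -> bool:
    # Hypothetical support set for an XNNPACK-like delegate.
    return op in {"linear", "softmax", "embedding"}

def assign_backends(ops, partitioners):
    """Assign each op to the first partitioner that claims it; later
    partitioners only see what earlier ones left undelegated."""
    assignment = {}
    for op in ops:
        for name, supports in partitioners:
            if supports(op):
                assignment[op] = name
                break
        else:
            assignment[op] = "portable (undelegated)"
    return assignment

if __name__ == "__main__":
    graph_ops = ["embedding", "linear", "add", "softmax", "topk"]
    # Vulkan listed first, XNNPACK second: XNNPACK picks up what Vulkan skips.
    order = [("vulkan", vulkan_supports), ("xnnpack", xnnpack_supports)]
    for op, backend in assign_backends(graph_ops, order).items():
        print(f"{op:10s} -> {backend}")

The quantize.py hunk follows the same fallback pattern on the quantization side: VkInt4WeightOnlyQuantizer handles the linear layers the Vulkan delegate will lower, and the torchao Int8DynActInt4WeightQuantizer is then run so the remaining linear layers (which fall through to XNNPACK) are also quantized, using the same group size for both passes.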