From 07a8d370e0a6c043a84f7cd96a51d9647def798b Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Wed, 2 Oct 2024 11:38:15 -0700
Subject: [PATCH 1/2] add support for loading qat_lora checkpoints

This PR adds support for loading qat_lora checkpoints. It does two main things:
- Refactor the existing quantization flow for SpinQuant into a separate function, which is also used to load QAT checkpoints since they share the same format.
- For QAT_LoRA checkpoints, do one extra step after quantization: replace `Int8DynActInt4WeightLinear` layers with `Int8DynActInt4WeightLinearLoRA`, which contains the LoRA adaptor.

Differential Revision: [D63714794](https://our.internmc.facebook.com/intern/diff/D63714794/)

[ghstack-poisoned]
---
 examples/models/llama2/TARGETS             |   1 +
 examples/models/llama2/export_llama_lib.py |  17 +++
 examples/models/llama2/model.py            | 132 +++++++++-------
 .../llama2/source_transformation/lora.py   | 141 ++++++++++++++++++
 4 files changed, 234 insertions(+), 57 deletions(-)
 create mode 100644 examples/models/llama2/source_transformation/lora.py

diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS
index 40822e574c3..a80c62514df 100644
--- a/examples/models/llama2/TARGETS
+++ b/examples/models/llama2/TARGETS
@@ -80,6 +80,7 @@ runtime.python_library(
         "export_llama_lib.py",
         "model.py",
         "source_transformation/apply_spin_quant_r1_r2.py",
+        "source_transformation/lora.py",
         "source_transformation/pre_quantization.py",
         "source_transformation/prune_output.py",
         "source_transformation/quantize.py",
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index a39bb048200..cf8d221c8e5 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -390,6 +390,23 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Use SpinQuant for better quantization performance. Only support cuda and native.",
     )
 
+    parser.add_argument(
+        "-qat",
+        "--use_qat",
+        default=False,
+        action="store_true",
+        help="Whether the checkpoint is pre-quantized with QAT or not.",
+    )
+
+    parser.add_argument(
+        "-lora",
+        "--use_lora",
+        type=int,
+        default=0,
+        help="Whether the checkpoint contains LoRA adaptors or not. 0: no LoRA adaptors; "
+        "otherwise, it means the rank of LoRA adaptors. Currently it only works if QAT is enabled.",
+    )
+
     parser.add_argument(
         "--preq_mode",
         type=str,
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index a4081d1bd57..9e7cc7f734f 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -13,6 +13,7 @@
 import torch
 
 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
+from executorch.extension.llm.export.builder import DType
 
 try:
     from .fairseq2 import convert_to_llama_checkpoint
@@ -191,73 +192,31 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-            assert self.args.preq_mode in [
-                "8da4w",
-                "8da4w_output_8da8w",
-            ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
- assert hasattr( - self.args, "preq_group_size" - ), "preq_group_size must be specified" - assert hasattr( - self.args, "dtype_override" - ), "dtype_override must be specified" + self._transform_for_pre_quantization(checkpoint) + from .source_transformation.pre_quantization import ( sanitize_checkpoint_from_pre_quantization, - transform_linear_for_pre_quantization, - ) - - mapping = { - "fp32": torch.float32, - "fp16": torch.float16, - "bf16": torch.bfloat16, - } - - # Transform the output layer first if needed. - if self.args.preq_mode == "8da4w_output_8da8w": - from .source_transformation.pre_quantization import ( - transform_output_linear_for_pre_quantization, - ) - - self.model_ = transform_output_linear_for_pre_quantization( - module=self.model_, - checkpoint=checkpoint, - dtype=mapping[self.args.dtype_override], - ) - - self.model_ = transform_linear_for_pre_quantization( - self.model_, - checkpoint, - self.args.preq_group_size, - mapping[self.args.dtype_override], ) - embedding_bit_width, embedding_group_size = None, None - if hasattr(self.args, "preq_embedding_quantize"): - embedding_bit_width, embedding_group_size = ( - self.args.preq_embedding_quantize.split(",") - ) - from .source_transformation.pre_quantization import ( - transform_embedding_for_pre_quantization, + sanitize_checkpoint_from_pre_quantization(checkpoint) + elif hasattr(self.args, "use_qat") and self.args.use_qat: + print("Using QAT quantization.") + self._transform_for_pre_quantization(checkpoint) + if hasattr(self.args, "use_lora") and self.args.use_lora: + from .source_transformation.lora import ( + transform_linear_for_lora_after_quantization, ) - if ( - embedding_group_size == "none" - or embedding_group_size == "None" - or embedding_group_size == "0" - ): - embedding_group_size = None - else: - embedding_group_size = int(embedding_group_size) - - self.model_ = transform_embedding_for_pre_quantization( + self.model_ = transform_linear_for_lora_after_quantization( self.model_, checkpoint, - mapping[self.args.dtype_override], - int(embedding_bit_width), - embedding_group_size, + self.args.use_lora, ) + from .source_transformation.pre_quantization import ( + sanitize_checkpoint_from_pre_quantization, + ) + sanitize_checkpoint_from_pre_quantization(checkpoint) # assign=True: load params/buffers by assignment instead of performing an in-place copy. @@ -318,3 +277,62 @@ def get_example_inputs_kvcache_sdpa(self): [0], dtype=torch.long ), # start_pos, what token of output are we on. ) + + def _transform_for_pre_quantization(self, checkpoint): + assert hasattr(self.args, "preq_mode"), "preq_mode must be specified" + assert self.args.preq_mode in [ + "8da4w", + "8da4w_output_8da8w", + ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant." + assert hasattr( + self.args, "preq_group_size" + ), "preq_group_size must be specified" + assert hasattr(self.args, "dtype_override"), "dtype_override must be specified" + from .source_transformation.pre_quantization import ( + transform_linear_for_pre_quantization, + ) + + # Transform the output layer first if needed. 
+        if self.args.preq_mode == "8da4w_output_8da8w":
+            from .source_transformation.pre_quantization import (
+                transform_output_linear_for_pre_quantization,
+            )
+
+            self.model_ = transform_output_linear_for_pre_quantization(
+                module=self.model_,
+                checkpoint=checkpoint,
+                dtype=DType[self.args.dtype_override].to_torch_dtype(),
+            )
+
+        self.model_ = transform_linear_for_pre_quantization(
+            self.model_,
+            checkpoint,
+            self.args.preq_group_size,
+            DType[self.args.dtype_override].to_torch_dtype(),
+        )
+
+        embedding_bit_width, embedding_group_size = None, None
+        if hasattr(self.args, "preq_embedding_quantize"):
+            embedding_bit_width, embedding_group_size = (
+                self.args.preq_embedding_quantize.split(",")
+            )
+            from .source_transformation.pre_quantization import (
+                transform_embedding_for_pre_quantization,
+            )
+
+            if (
+                embedding_group_size == "none"
+                or embedding_group_size == "None"
+                or embedding_group_size == "0"
+            ):
+                embedding_group_size = None
+            else:
+                embedding_group_size = int(embedding_group_size)
+
+            self.model_ = transform_embedding_for_pre_quantization(
+                self.model_,
+                checkpoint,
+                DType[self.args.dtype_override].to_torch_dtype(),
+                int(embedding_bit_width),
+                embedding_group_size,
+            )
diff --git a/examples/models/llama2/source_transformation/lora.py b/examples/models/llama2/source_transformation/lora.py
new file mode 100644
index 00000000000..11fcba76c77
--- /dev/null
+++ b/examples/models/llama2/source_transformation/lora.py
@@ -0,0 +1,141 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+# Helper functions for transforming the model to be able to load checkpoints with
+# LoRA adaptors. See https://arxiv.org/abs/2106.09685 for more details about LoRA.
+
+from typing import Any
+
+import torch
+from torch import nn
+from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+
+
+class LoRAAdaptorLinear(nn.Module):
+    """
+    LoRA adaptor for linear layers.
+
+    This class implements Low-Rank Adaptation (LoRA) for linear layers.
+    See more details about LoRA here https://arxiv.org/abs/2106.09685.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        rank: int,
+        scale: float = 2.0,
+        dtype=torch.float32,
+        device=None,
+    ) -> None:
+        super().__init__()
+        self.scale = scale
+        self.A = nn.Linear(in_features, rank, bias=False, dtype=dtype, device=device)
+        self.B = nn.Linear(rank, out_features, bias=False, dtype=dtype, device=device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.scale * self.B(self.A(x))  # pyre-ignore[7]
+
+
+class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
+    """
+    Int8DynActInt4WeightLinear with LoRA adaptor.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        lora_rank: int,
+        bias=True,
+        device=None,
+        groupsize: int = 256,
+        precision: torch.dtype = torch.float32,
+        scales_precision: torch.dtype = torch.float32,
+        lora_adaptor_precision: torch.dtype = torch.bfloat16,
+        lora_scale: float = 2.0,
+    ) -> None:
+        super().__init__(
+            in_features,
+            out_features,
+            bias=bias,
+            device=device,
+            groupsize=groupsize,
+            precision=precision,
+            scales_precision=scales_precision,
+        )
+        self.adaptor = LoRAAdaptorLinear(
+            in_features,
+            out_features,
+            lora_rank,
+            scale=lora_scale,
+            dtype=lora_adaptor_precision,
+            device=device,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return super().forward(input) + self.adaptor(input).to(dtype=self.precision)
+
+
+def _replace_linear_8da4w_for_lora(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    lora_rank: int,
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        # Only replace linear layers where the checkpoint contains explicit adaptors
+        adaptor_A_key = f"{cur_fqn}.adaptor.A.weight"
+        adaptor_B_key = f"{cur_fqn}.adaptor.B.weight"
+        if (
+            isinstance(child, Int8DynActInt4WeightLinear)
+            and adaptor_A_key in checkpoint
+            and adaptor_B_key in checkpoint
+        ):
+            assert checkpoint[adaptor_A_key].dtype == torch.bfloat16
+            assert checkpoint[adaptor_A_key].shape[0] == lora_rank
+            assert checkpoint[adaptor_A_key].shape[1] == child.in_features
+            assert checkpoint[adaptor_B_key].dtype == torch.bfloat16
+            assert checkpoint[adaptor_B_key].shape[0] == child.out_features
+            assert checkpoint[adaptor_B_key].shape[1] == lora_rank
+            return True
+        return False
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        new_linear = Int8DynActInt4WeightLinearLoRA(
+            child.in_features,
+            child.out_features,
+            lora_rank=lora_rank,
+            bias=False,
+            device=child.weight.device,
+            groupsize=child.groupsize,
+            precision=child.precision,
+            scales_precision=child.scales.dtype,
+        )
+        return new_linear
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def transform_linear_for_lora_after_quantization(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    lora_rank: int,
+) -> torch.nn.Module:
+    """
+    Transform the model so that it can load checkpoints with LoRA adaptors.
+    The model should already have been transformed to load pre-quantized
+    checkpoints, and the checkpoint itself should have been pre-quantized
+    and contain the LoRA adaptor weights.
+    """
+    _replace_linear_8da4w_for_lora(
+        module,
+        checkpoint,
+        lora_rank,
+    )
+    return module

From 96e6198bfa84d879c33117f12e646a1e3d601208 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Wed, 2 Oct 2024 12:10:57 -0700
Subject: [PATCH 2/2] Update on "add support for loading qat_lora checkpoints"

This PR adds support for loading qat_lora checkpoints. It does two main things:
- Refactor the existing quantization flow for SpinQuant into a separate function, which is also used to load QAT checkpoints since they share the same format.
- For QAT_LoRA checkpoints, do one extra step after quantization: replace `Int8DynActInt4WeightLinear` layers with `Int8DynActInt4WeightLinearLoRA`, which contains the LoRA adaptor (see the sketch below).
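
For reviewers, here is a minimal sketch (not part of the diff) of the computation that `Int8DynActInt4WeightLinearLoRA.forward` performs: the output of the quantized base linear plus a scaled low-rank update `scale * B(A(x))`. Plain `nn.Linear` modules stand in for the quantized layer, and the names `base`, `lora_a`, `lora_b`, the shapes, and the explicit bfloat16 cast are illustrative assumptions to keep the snippet self-contained.

```python
import torch
from torch import nn

# Illustrative shapes only; in the real flow these come from the checkpoint.
in_features, out_features, rank, scale = 4096, 4096, 16, 2.0

base = nn.Linear(in_features, out_features, bias=False)  # stand-in for the quantized 8da4w linear
lora_a = nn.Linear(in_features, rank, bias=False, dtype=torch.bfloat16)   # plays the role of adaptor.A
lora_b = nn.Linear(rank, out_features, bias=False, dtype=torch.bfloat16)  # plays the role of adaptor.B

x = torch.randn(2, in_features)
# Base output + scale * B(A(x)), with the LoRA branch cast back to the base precision.
y = base(x) + (scale * lora_b(lora_a(x.to(torch.bfloat16)))).to(x.dtype)
print(y.shape)  # torch.Size([2, 4096])
```
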
Differential Revision: [D63714794](https://our.internmc.facebook.com/intern/diff/D63714794/) [ghstack-poisoned] --- examples/models/llama2/model.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py index 9e7cc7f734f..d8d0ff00ffa 100644 --- a/examples/models/llama2/model.py +++ b/examples/models/llama2/model.py @@ -13,7 +13,6 @@ import torch from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer -from executorch.extension.llm.export.builder import DType try: from .fairseq2 import convert_to_llama_checkpoint @@ -292,6 +291,12 @@ def _transform_for_pre_quantization(self, checkpoint): transform_linear_for_pre_quantization, ) + mapping = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + } + # Transform the output layer first if needed. if self.args.preq_mode == "8da4w_output_8da8w": from .source_transformation.pre_quantization import ( @@ -301,14 +306,14 @@ def _transform_for_pre_quantization(self, checkpoint): self.model_ = transform_output_linear_for_pre_quantization( module=self.model_, checkpoint=checkpoint, - dtype=DType[self.args.dtype_override].to_torch_dtype(), + dtype=mapping[self.args.dtype_override], ) self.model_ = transform_linear_for_pre_quantization( self.model_, checkpoint, self.args.preq_group_size, - DType[self.args.dtype_override].to_torch_dtype(), + mapping[self.args.dtype_override], ) embedding_bit_width, embedding_group_size = None, None @@ -332,7 +337,7 @@ def _transform_for_pre_quantization(self, checkpoint): self.model_ = transform_embedding_for_pre_quantization( self.model_, checkpoint, - DType[self.args.dtype_override].to_torch_dtype(), + mapping[self.args.dtype_override], int(embedding_bit_width), embedding_group_size, )
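
Usage note (not part of the patches above): with these changes, loading a QAT + LoRA checkpoint roughly follows the order sketched below. The `load_qat_lora` helper, its argument values, `strict=False`, and the `model`/`checkpoint` placeholders are assumptions for illustration; the imported transforms and the assign-based `load_state_dict` are the pieces this PR wires together in `model.py`.

```python
import torch

from executorch.examples.models.llama2.source_transformation.lora import (
    transform_linear_for_lora_after_quantization,
)
from executorch.examples.models.llama2.source_transformation.pre_quantization import (
    sanitize_checkpoint_from_pre_quantization,
    transform_linear_for_pre_quantization,
)


def load_qat_lora(model: torch.nn.Module, checkpoint: dict, group_size: int, lora_rank: int):
    # 1. Swap in pre-quantized (8da4w) linears so the module tree matches the QAT checkpoint format.
    model = transform_linear_for_pre_quantization(model, checkpoint, group_size, torch.float32)
    # 2. Replace Int8DynActInt4WeightLinear with Int8DynActInt4WeightLinearLoRA wherever the
    #    checkpoint carries adaptor.A / adaptor.B weights.
    model = transform_linear_for_lora_after_quantization(model, checkpoint, lora_rank)
    # 3. Clean up checkpoint keys left over from pre-quantization.
    sanitize_checkpoint_from_pre_quantization(checkpoint)
    # 4. Load by assignment so the quantized tensors are taken as-is.
    model.load_state_dict(checkpoint, strict=False, assign=True)
    return model
```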