diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 4d39d131d1d..3c4e3f13e6f 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -116,6 +116,9 @@ class ModelArgs:
     bos_count: int = -1  # i.e., a single EOS is used as BOS
     eos_count: int = 2
 
+    quantization_args: Optional[dict] = None
+    lora_args: Optional[dict] = None
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index ad997de64cd..c31af23f5b6 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -165,7 +165,7 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            self._transform_for_pre_quantization(checkpoint)
+            self._transform_for_pre_quantization(checkpoint, model_args)
 
             from .source_transformation.pre_quantization import (
                 sanitize_checkpoint_from_pre_quantization,
@@ -174,8 +174,9 @@ def __init__(self, **kwargs):
             sanitize_checkpoint_from_pre_quantization(checkpoint)
         elif hasattr(self.args, "use_qat") and self.args.use_qat:
             print("Using QAT quantization.")
-            self._transform_for_pre_quantization(checkpoint)
+            self._transform_for_pre_quantization(checkpoint, model_args)
             if hasattr(self.args, "use_lora") and self.args.use_lora:
+                assert model_args.lora_args["rank"] == self.args.use_lora
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )
@@ -251,7 +252,7 @@ def get_example_inputs_kvcache_sdpa(self):
                 ),  # start_pos, what token of output are we on.
             )
 
-    def _transform_for_pre_quantization(self, checkpoint):
+    def _transform_for_pre_quantization(self, checkpoint, model_args):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
         assert self.args.preq_mode in [
             "8da4w",
@@ -264,6 +265,7 @@ def _transform_for_pre_quantization(self, checkpoint):
         from .source_transformation.pre_quantization import (
             transform_linear_for_pre_quantization,
         )
+        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
 
         mapping = {
             "fp32": torch.float32,
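For context, a minimal sketch of what the new asserts enforce: the group size and LoRA rank recorded in the checkpoint's model params (surfaced through the new `quantization_args` and `lora_args` fields on `ModelArgs`) must agree with the values supplied at export time (`preq_group_size` and `use_lora`). The `params` dict and the `ExportArgs` stand-in below are hypothetical and only illustrate the shape of the data; they are not part of this diff.

```python
# Hypothetical illustration of the consistency checks added in model.py;
# `params` and `ExportArgs` are assumptions, not code from this diff.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:  # mirrors the two fields added in llama_transformer.py
    quantization_args: Optional[dict] = None
    lora_args: Optional[dict] = None


@dataclass
class ExportArgs:  # stand-in for self.args (export-time flags) in model.py
    preq_group_size: int = 32
    use_lora: int = 16


# e.g. values carried alongside a pre-quantized checkpoint
params = {
    "quantization_args": {"group_size": 32},
    "lora_args": {"rank": 16},
}
model_args = ModelArgs(**params)
args = ExportArgs()

# Group size baked into the pre-quantized checkpoint must match the export flag.
assert args.preq_group_size == model_args.quantization_args["group_size"]
# LoRA rank recorded with the checkpoint must match the `use_lora` export argument.
assert model_args.lora_args["rank"] == args.use_lora
```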