From ef5fa4601bb0d77aa021bc5422c36c739872ce45 Mon Sep 17 00:00:00 2001
From: Onkar Chougule
Date: Mon, 25 Nov 2024 03:46:10 +0530
Subject: [PATCH 1/4] Enabled repeat KV heads for AWQ/GPTQ models

Signed-off-by: Onkar Chougule
---
 QEfficient/transformers/quantizers/auto.py    |  28 +++
 .../transformers/quantizers/quantizer_gptq.py |   4 +-
 QEfficient/utils/__init__.py                  |   4 +
 .../perplexity_computation/README.md          |   0
 .../perplexity_computation/__init__.py        |   0
 .../calculate_perplexity.py                   |   0
 .../replicate_kv_head/README.md               |   0
 .../replicate_kv_head/__init__.py             |   0
 .../replicate_kv_head/replicate_kv_heads.py   | 172 ++++++++++++++++++
 .../replicate_kv_head/replicate_kv_heads.py   |  83 ---------
 .../models/test_causal_lm_models.py           |  11 +-
 11 files changed, 207 insertions(+), 95 deletions(-)
 rename {scripts => examples}/perplexity_computation/README.md (100%)
 rename {scripts => examples}/perplexity_computation/__init__.py (100%)
 rename {scripts => examples}/perplexity_computation/calculate_perplexity.py (100%)
 rename {scripts => examples}/replicate_kv_head/README.md (100%)
 rename {scripts => examples}/replicate_kv_head/__init__.py (100%)
 create mode 100644 examples/replicate_kv_head/replicate_kv_heads.py
 delete mode 100644 scripts/replicate_kv_head/replicate_kv_heads.py

diff --git a/QEfficient/transformers/quantizers/auto.py b/QEfficient/transformers/quantizers/auto.py
index aa84f9084..f4cec3b54 100644
--- a/QEfficient/transformers/quantizers/auto.py
+++ b/QEfficient/transformers/quantizers/auto.py
@@ -6,12 +6,17 @@
 # ----------------------------------------------------------------------------
 
 from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING
+from transformers.quantizers.quantizer_awq import AwqQuantizer
+from transformers.quantizers.quantizer_gptq import GptqHfQuantizer
+from transformers.utils.quantization_config import AwqConfig, GPTQConfig
 
 from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer
 from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer
 
 QEFF_AUTO_QUANTIZER_MAPPING = {"awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer}
 QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": QEffAwqConfig, "gptq": QEffGPTQConfig}
+DUPLICATE_AUTO_QUANTIZER_MAPPING = {"awq": AwqQuantizer, "gptq": GptqHfQuantizer}
+DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": AwqConfig, "gptq": GPTQConfig}
 
 
 def with_replaced_quantizers(func):
@@ -39,3 +44,26 @@ def wrapper(*args, **kwargs):
         return out
 
     return wrapper
+
+
+def replace_transformers_quantizers():
+    """
+    This method lets you load AWQ/GPTQ models on CPU, bypassing the
+    transformers requirement of a GPU for these quantization methods.
+    Call this method before using
+    `transformers.AutoModelForCausalLM.from_pretrained`, and any AWQ/GPTQ model
+    supported by QEfficient will be loaded on CPU.
+    """
+    AUTO_QUANTIZER_MAPPING.update(QEFF_AUTO_QUANTIZER_MAPPING)
+    AUTO_QUANTIZATION_CONFIG_MAPPING.update(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING)
+
+
+# TODO: Make this a fixture? Or better, always update the quantizer and config in transformers.
+# When a user imports QEfficient, these are always available.
+def undo_transformers_quantizers():
+    """
+    This method undoes the effect of `replace_transformers_quantizers`.
+    After this is called, the original transformers quantizers will be used for loading AWQ/GPTQ models.
+ """ + AUTO_QUANTIZER_MAPPING.update(DUPLICATE_AUTO_QUANTIZER_MAPPING) + AUTO_QUANTIZATION_CONFIG_MAPPING.update(DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING) diff --git a/QEfficient/transformers/quantizers/quantizer_gptq.py b/QEfficient/transformers/quantizers/quantizer_gptq.py index 76dfe3718..675e6258d 100644 --- a/QEfficient/transformers/quantizers/quantizer_gptq.py +++ b/QEfficient/transformers/quantizers/quantizer_gptq.py @@ -6,7 +6,7 @@ # ----------------------------------------------------------------------------- import torch -from transformers.quantizers.quantizer_gptq import HfQuantizer +from transformers.quantizers.quantizer_gptq import GptqHfQuantizer from transformers.utils.quantization_config import GPTQConfig from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ @@ -38,7 +38,7 @@ def post_init(self): raise ValueError("damp_percent must be between 0 and 1.") -class QEffGPTQQuantizer(HfQuantizer): +class QEffGPTQQuantizer(GptqHfQuantizer): """ Quantizer class for QEffGPTQ, extending HfQuantizer. This class handles the initialization, environment validation, dtype updating, and model processing for quantization. diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index b9efbf720..2506b9233 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -5,6 +5,10 @@ # # ----------------------------------------------------------------------------- +from QEfficient.transformers.quantizers.auto import ( # noqa: F401 + replace_transformers_quantizers, + undo_transformers_quantizers, +) from QEfficient.utils._utils import ( # noqa: F401 check_and_assign_cache_dir, get_num_layers_from_config, diff --git a/scripts/perplexity_computation/README.md b/examples/perplexity_computation/README.md similarity index 100% rename from scripts/perplexity_computation/README.md rename to examples/perplexity_computation/README.md diff --git a/scripts/perplexity_computation/__init__.py b/examples/perplexity_computation/__init__.py similarity index 100% rename from scripts/perplexity_computation/__init__.py rename to examples/perplexity_computation/__init__.py diff --git a/scripts/perplexity_computation/calculate_perplexity.py b/examples/perplexity_computation/calculate_perplexity.py similarity index 100% rename from scripts/perplexity_computation/calculate_perplexity.py rename to examples/perplexity_computation/calculate_perplexity.py diff --git a/scripts/replicate_kv_head/README.md b/examples/replicate_kv_head/README.md similarity index 100% rename from scripts/replicate_kv_head/README.md rename to examples/replicate_kv_head/README.md diff --git a/scripts/replicate_kv_head/__init__.py b/examples/replicate_kv_head/__init__.py similarity index 100% rename from scripts/replicate_kv_head/__init__.py rename to examples/replicate_kv_head/__init__.py diff --git a/examples/replicate_kv_head/replicate_kv_heads.py b/examples/replicate_kv_head/replicate_kv_heads.py new file mode 100644 index 000000000..417fcde7a --- /dev/null +++ b/examples/replicate_kv_head/replicate_kv_heads.py @@ -0,0 +1,172 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM, export +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers +from QEfficient.transformers.quantizers.awq import WQLinear_GEMM +from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ + + +def main(args): + # Replace quantizers for loading Quantized AWQ/GPTQ models on CPU. + replace_transformers_quantizers() + # Load the model and tokenizer + model_name = args.model_name + model_base_name = model_name.split("/")[-1] + model = AutoModelForCausalLM.from_pretrained( + model_name, # num_hidden_layers=1, + attn_implementation="eager", + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + inputs = tokenizer(args.prompt, return_tensors="pt") + + # Generate original outputs and tokens + with torch.inference_mode(): + _ = model(**inputs) # original output + orig_tokens = model.generate(**inputs, max_new_tokens=10, num_beams=1, do_sample=False) + + # Modify the number of key-value heads + repeat = args.repeat + orig_kv_heads = model.config.num_key_value_heads + new_kv_heads = repeat * orig_kv_heads + model.config.num_key_value_heads = new_kv_heads + + print("Original KV heads:", orig_kv_heads) + print("Modified KV heads:", new_kv_heads) + + # Update the model's attention layers with new key-value heads + for block in model.model.layers: + attn = block.self_attn + attn.num_key_value_heads = new_kv_heads + attn.num_key_value_groups = block.self_attn.num_heads // new_kv_heads + k_proj = attn.k_proj + v_proj = attn.v_proj + if isinstance(attn.k_proj, (WQLinear_GEMM, QuantLinearGPTQ)): + if attn.head_dim % 8 != 0: + raise ValueError(f"the value attn.head_dim={attn.head_dim} is not divisible by 8 which is \ + according to the assumption that model is 4-bit quantized.") + if attn.hidden_size % k_proj.group_size != 0 or attn.hidden_size % v_proj.group_size: + raise ValueError(f"The value of attn.hidden_size={attn.hidden_size} is not divisible by \ + K_proj.group_size={k_proj.group_size}") + + # Key projection duplication + # Duplication of quantized weights + k_proj.qweight.data = ( + torch.repeat_interleave( + k_proj.qweight.data.T.view(orig_kv_heads, attn.head_dim // 8, attn.hidden_size), repeat, 0 + ) + .view((new_kv_heads * attn.head_dim) // 8, attn.hidden_size) + .T + ) + # Duplication of quantized zero points + k_proj.qzeros.data = ( + torch.repeat_interleave( + k_proj.qzeros.data.T.view(orig_kv_heads, attn.head_dim // 8, attn.hidden_size // k_proj.group_size), + repeat, + 0, + ) + .view((new_kv_heads * attn.head_dim) // 8, attn.hidden_size // k_proj.group_size) + .T + ) + # Duplication of quantization scales + k_proj.scales.data = ( + torch.repeat_interleave( + k_proj.scales.data.T.view(orig_kv_heads, attn.head_dim, attn.hidden_size // k_proj.group_size), + repeat, + 0, + ) + .view(new_kv_heads * attn.head_dim, attn.hidden_size // k_proj.group_size) + .T + ) + k_proj.out_features = k_proj.out_features * repeat + else: + attn.k_proj.weight.data = torch.repeat_interleave( + attn.k_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0 + ).view(new_kv_heads * attn.head_dim, attn.hidden_size) + + if isinstance(v_proj, (WQLinear_GEMM, QuantLinearGPTQ)): + if attn.head_dim % 8 != 0: + raise ValueError(f"the value 
attn.head_dim={attn.head_dim} is not divisible by 8 which is \ + according to the assumption that model is 4-bit quantized.") + if attn.hidden_size % v_proj.group_size: + raise ValueError(f"The value of attn.hidden_size={attn.hidden_size} is not divisible by \ + v_proj.group_size = {v_proj.group_size}") + + # Value projection duplication + # Duplication of quantized weights + v_proj.qweight.data = ( + torch.repeat_interleave( + v_proj.qweight.data.T.view(orig_kv_heads, attn.head_dim // 8, attn.hidden_size), repeat, 0 + ) + .view((new_kv_heads * attn.head_dim) // 8, attn.hidden_size) + .T + ) + # Duplication of quantized zero points + v_proj.qzeros.data = ( + torch.repeat_interleave( + v_proj.qzeros.data.T.view(orig_kv_heads, attn.head_dim // 8, attn.hidden_size // v_proj.group_size), + repeat, + 0, + ) + .view((new_kv_heads * attn.head_dim) // 8, attn.hidden_size // v_proj.group_size) + .T + ) + # Duplication of quantization scales + v_proj.scales.data = ( + torch.repeat_interleave( + v_proj.scales.data.T.view(orig_kv_heads, attn.head_dim, attn.hidden_size // v_proj.group_size), + repeat, + 0, + ) + .view(new_kv_heads * attn.head_dim, attn.hidden_size // v_proj.group_size) + .T + ) + v_proj.out_features = v_proj.out_features * repeat + else: + attn.v_proj.weight.data = torch.repeat_interleave( + attn.v_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0 + ).view(new_kv_heads * attn.head_dim, attn.hidden_size) + + # Generate modified outputs and tokens + with torch.inference_mode(): + _ = model(**inputs) # Modified output + mod_tokens = model.generate(**inputs, max_new_tokens=10, num_beams=1, do_sample=False) + + # Print the original and modified token outputs + print("Original:", tokenizer.batch_decode(orig_tokens)) + print("Modified:", tokenizer.batch_decode(mod_tokens)) + + # Export the modified model + q_model = QEFFAutoModelForCausalLM(model, model_name) + export( + model_name, + q_model, + tokenizer=tokenizer, + onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads", + ) + + # Undo the effect of replace_transformers_quantizers + undo_transformers_quantizers() + + +if __name__ == "__main__": + # Set up argument parser + parser = argparse.ArgumentParser(description="Modify and export a causal language model.") + parser.add_argument( + "--model_name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct", help="Name of the model to use." + ) + parser.add_argument("--prompt", type=str, default="My name is", help="Prompt to use for the model.") + parser.add_argument("--repeat", type=int, default=2, help="Factor to repeat key-value heads.") + + args = parser.parse_args() + main(args) diff --git a/scripts/replicate_kv_head/replicate_kv_heads.py b/scripts/replicate_kv_head/replicate_kv_heads.py deleted file mode 100644 index 844b8957d..000000000 --- a/scripts/replicate_kv_head/replicate_kv_heads.py +++ /dev/null @@ -1,83 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -import argparse - -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -from QEfficient import QEFFAutoModelForCausalLM, export - - -def main(args): - # Load the model and tokenizer - model_name = args.model_name - model_base_name = model_name.split("/")[-1] - model = AutoModelForCausalLM.from_pretrained( - model_name, # num_hidden_layers=2, - attn_implementation="eager", - ) - - tokenizer = AutoTokenizer.from_pretrained(model_name) - inputs = tokenizer(args.prompt, return_tensors="pt") - - # Generate original outputs and tokens - with torch.inference_mode(): - _ = model(**inputs) # original output - orig_tokens = model.generate(**inputs, max_new_tokens=10, num_beams=1, do_sample=False) - - # Modify the number of key-value heads - repeat = args.repeat - orig_kv_heads = model.config.num_key_value_heads - new_kv_heads = repeat * orig_kv_heads - model.config.num_key_value_heads = new_kv_heads - - print("Original KV heads:", orig_kv_heads) - print("Modified KV heads:", new_kv_heads) - - # Update the model's attention layers with new key-value heads - for block in model.model.layers: - attn = block.self_attn - attn.num_key_value_heads = new_kv_heads - attn.num_key_value_groups = block.self_attn.num_heads // new_kv_heads - attn.k_proj.weight.data = torch.repeat_interleave( - attn.k_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0 - ).view(new_kv_heads * attn.head_dim, attn.hidden_size) - attn.v_proj.weight.data = torch.repeat_interleave( - attn.v_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0 - ).view(new_kv_heads * attn.head_dim, attn.hidden_size) - - # Generate modified outputs and tokens - with torch.inference_mode(): - _ = model(**inputs) # Modified output - mod_tokens = model.generate(**inputs, max_new_tokens=10, num_beams=1, do_sample=False) - - # Print the original and modified token outputs - print("Original:", tokenizer.batch_decode(orig_tokens)) - print("Modified:", tokenizer.batch_decode(mod_tokens)) - - # Export the modified model - q_model = QEFFAutoModelForCausalLM(model, model_name) - export( - model_name, - q_model, - tokenizer=tokenizer, - onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads", - ) - - -if __name__ == "__main__": - # Set up argument parser - parser = argparse.ArgumentParser(description="Modify and export a causal language model.") - parser.add_argument( - "--model_name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct", help="Name of the model to use." 
- ) - parser.add_argument("--prompt", type=str, default="My name is", help="Prompt to use for the model.") - parser.add_argument("--repeat", type=int, default=2, help="Factor to repeat key-value heads.") - - args = parser.parse_args() - main(args) diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index e5cc6325c..2f3401223 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -8,12 +8,10 @@ import numpy as np import pytest from transformers import AutoModelForCausalLM -from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM -from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer -from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers from QEfficient.utils import hf_download from QEfficient.utils._utils import load_hf_tokenizer from QEfficient.utils.constants import Constants @@ -41,13 +39,6 @@ ] -# TODO: Make this a fixture? Or better, always update the quantizer and config in transformers. -# When a user imports QEfficient, these are always available. -def replace_transformers_quantizers(): - AUTO_QUANTIZER_MAPPING.update({"awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer}) - AUTO_QUANTIZATION_CONFIG_MAPPING.update({"awq": QEffAwqConfig, "gptq": QEffGPTQConfig}) - - def load_causal_lm_model(model_config): """ Function to load model from huggingface and transform to KV model From aea230d03bccef09bd2a356613483629921c50ca Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 25 Nov 2024 04:06:43 +0530 Subject: [PATCH 2/4] undo location Signed-off-by: Onkar Chougule --- {examples => scripts}/perplexity_computation/README.md | 0 {examples => scripts}/perplexity_computation/__init__.py | 0 .../perplexity_computation/calculate_perplexity.py | 0 {examples => scripts}/replicate_kv_head/README.md | 0 {examples => scripts}/replicate_kv_head/__init__.py | 0 {examples => scripts}/replicate_kv_head/replicate_kv_heads.py | 0 6 files changed, 0 insertions(+), 0 deletions(-) rename {examples => scripts}/perplexity_computation/README.md (100%) rename {examples => scripts}/perplexity_computation/__init__.py (100%) rename {examples => scripts}/perplexity_computation/calculate_perplexity.py (100%) rename {examples => scripts}/replicate_kv_head/README.md (100%) rename {examples => scripts}/replicate_kv_head/__init__.py (100%) rename {examples => scripts}/replicate_kv_head/replicate_kv_heads.py (100%) diff --git a/examples/perplexity_computation/README.md b/scripts/perplexity_computation/README.md similarity index 100% rename from examples/perplexity_computation/README.md rename to scripts/perplexity_computation/README.md diff --git a/examples/perplexity_computation/__init__.py b/scripts/perplexity_computation/__init__.py similarity index 100% rename from examples/perplexity_computation/__init__.py rename to scripts/perplexity_computation/__init__.py diff --git a/examples/perplexity_computation/calculate_perplexity.py b/scripts/perplexity_computation/calculate_perplexity.py similarity index 100% rename from examples/perplexity_computation/calculate_perplexity.py rename to 
scripts/perplexity_computation/calculate_perplexity.py diff --git a/examples/replicate_kv_head/README.md b/scripts/replicate_kv_head/README.md similarity index 100% rename from examples/replicate_kv_head/README.md rename to scripts/replicate_kv_head/README.md diff --git a/examples/replicate_kv_head/__init__.py b/scripts/replicate_kv_head/__init__.py similarity index 100% rename from examples/replicate_kv_head/__init__.py rename to scripts/replicate_kv_head/__init__.py diff --git a/examples/replicate_kv_head/replicate_kv_heads.py b/scripts/replicate_kv_head/replicate_kv_heads.py similarity index 100% rename from examples/replicate_kv_head/replicate_kv_heads.py rename to scripts/replicate_kv_head/replicate_kv_heads.py From 8a30aae716fff1cb2ef5c7a829f9e78c2fc190fc Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 25 Nov 2024 04:07:54 +0530 Subject: [PATCH 3/4] bugfix Signed-off-by: Onkar Chougule --- scripts/replicate_kv_head/replicate_kv_heads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/replicate_kv_head/replicate_kv_heads.py b/scripts/replicate_kv_head/replicate_kv_heads.py index 417fcde7a..dcb6ff5e4 100644 --- a/scripts/replicate_kv_head/replicate_kv_heads.py +++ b/scripts/replicate_kv_head/replicate_kv_heads.py @@ -55,7 +55,7 @@ def main(args): if attn.head_dim % 8 != 0: raise ValueError(f"the value attn.head_dim={attn.head_dim} is not divisible by 8 which is \ according to the assumption that model is 4-bit quantized.") - if attn.hidden_size % k_proj.group_size != 0 or attn.hidden_size % v_proj.group_size: + if attn.hidden_size % k_proj.group_size != 0: raise ValueError(f"The value of attn.hidden_size={attn.hidden_size} is not divisible by \ K_proj.group_size={k_proj.group_size}") From 0393aff175bd263d72e6f4858fa01e518604332c Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 25 Nov 2024 13:23:03 +0530 Subject: [PATCH 4/4] fixed CI bug, simplified replication script Signed-off-by: Onkar Chougule --- .../transformers/quantizers/quantizer_gptq.py | 4 +- .../replicate_kv_head/replicate_kv_heads.py | 142 ++++++------------ 2 files changed, 50 insertions(+), 96 deletions(-) diff --git a/QEfficient/transformers/quantizers/quantizer_gptq.py b/QEfficient/transformers/quantizers/quantizer_gptq.py index 675e6258d..76f6efa79 100644 --- a/QEfficient/transformers/quantizers/quantizer_gptq.py +++ b/QEfficient/transformers/quantizers/quantizer_gptq.py @@ -6,7 +6,7 @@ # ----------------------------------------------------------------------------- import torch -from transformers.quantizers.quantizer_gptq import GptqHfQuantizer +from transformers.quantizers import HfQuantizer from transformers.utils.quantization_config import GPTQConfig from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ @@ -38,7 +38,7 @@ def post_init(self): raise ValueError("damp_percent must be between 0 and 1.") -class QEffGPTQQuantizer(GptqHfQuantizer): +class QEffGPTQQuantizer(HfQuantizer): """ Quantizer class for QEffGPTQ, extending HfQuantizer. This class handles the initialization, environment validation, dtype updating, and model processing for quantization. 
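
For reference, the calling pattern these patches set up for the two helpers added in PATCH 1 is roughly the sketch below; the model identifier is a placeholder, not one used in this series:

    from transformers import AutoModelForCausalLM

    from QEfficient.transformers.quantizers.auto import (
        replace_transformers_quantizers,
        undo_transformers_quantizers,
    )

    # Swap in the QEfficient quantizers so an AWQ/GPTQ checkpoint can be loaded on CPU.
    replace_transformers_quantizers()
    model = AutoModelForCausalLM.from_pretrained("<awq-or-gptq-model-id>", attn_implementation="eager")
    # Restore the stock transformers quantizers once the weights are in memory.
    undo_transformers_quantizers()
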
diff --git a/scripts/replicate_kv_head/replicate_kv_heads.py b/scripts/replicate_kv_head/replicate_kv_heads.py
index dcb6ff5e4..2fdbaf883 100644
--- a/scripts/replicate_kv_head/replicate_kv_heads.py
+++ b/scripts/replicate_kv_head/replicate_kv_heads.py
@@ -16,17 +16,54 @@
 from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
 
 
+def duplicate_weights_for_linear_layer(
+    layer: torch.nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
+):
+    new_kv_heads = repeat * orig_kv_heads
+    if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ)):
+        if head_dim % 8 != 0:
+            raise ValueError(f"head_dim={head_dim} is not divisible by 8, which is assumed \
+                for a 4-bit quantized model.")
+        if hidden_size % layer.group_size != 0:
+            raise ValueError(f"hidden_size={hidden_size} is not divisible by \
+                group_size={layer.group_size} of this layer.")
+
+        # Duplication of quantized weights
+        layer.qweight.data = torch.repeat_interleave(
+            layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1
+        ).view(hidden_size, (new_kv_heads * head_dim) // 8)
+        # Duplication of quantized zero points
+        layer.qzeros.data = torch.repeat_interleave(
+            layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8),
+            repeat,
+            1,
+        ).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8)
+        # Duplication of quantization scales
+        layer.scales.data = torch.repeat_interleave(
+            layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim),
+            repeat,
+            1,
+        ).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
+        layer.out_features = layer.out_features * repeat
+    else:
+        layer.weight.data = torch.repeat_interleave(
+            layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+        ).view(new_kv_heads * head_dim, hidden_size)
+
+
 def main(args):
-    # Replace quantizers for loading Quantized AWQ/GPTQ models on CPU.
-    replace_transformers_quantizers()
     # Load the model and tokenizer
     model_name = args.model_name
     model_base_name = model_name.split("/")[-1]
+    # Replace quantizers for loading Quantized AWQ/GPTQ models on CPU.
+ replace_transformers_quantizers() model = AutoModelForCausalLM.from_pretrained( - model_name, # num_hidden_layers=1, + model_name, + num_hidden_layers=1, attn_implementation="eager", ) - + # Undo the effect of replace_transformers_quantizers + undo_transformers_quantizers() tokenizer = AutoTokenizer.from_pretrained(model_name) inputs = tokenizer(args.prompt, return_tensors="pt") @@ -49,93 +86,8 @@ def main(args): attn = block.self_attn attn.num_key_value_heads = new_kv_heads attn.num_key_value_groups = block.self_attn.num_heads // new_kv_heads - k_proj = attn.k_proj - v_proj = attn.v_proj - if isinstance(attn.k_proj, (WQLinear_GEMM, QuantLinearGPTQ)): - if attn.head_dim % 8 != 0: - raise ValueError(f"the value attn.head_dim={attn.head_dim} is not divisible by 8 which is \ - according to the assumption that model is 4-bit quantized.") - if attn.hidden_size % k_proj.group_size != 0: - raise ValueError(f"The value of attn.hidden_size={attn.hidden_size} is not divisible by \ - K_proj.group_size={k_proj.group_size}") - - # Key projection duplication - # Duplication of quantized weights - k_proj.qweight.data = ( - torch.repeat_interleave( - k_proj.qweight.data.T.view(orig_kv_heads, attn.head_dim // 8, attn.hidden_size), repeat, 0 - ) - .view((new_kv_heads * attn.head_dim) // 8, attn.hidden_size) - .T - ) - # Duplication of quantized zero points - k_proj.qzeros.data = ( - torch.repeat_interleave( - k_proj.qzeros.data.T.view(orig_kv_heads, attn.head_dim // 8, attn.hidden_size // k_proj.group_size), - repeat, - 0, - ) - .view((new_kv_heads * attn.head_dim) // 8, attn.hidden_size // k_proj.group_size) - .T - ) - # Duplication of quantization scales - k_proj.scales.data = ( - torch.repeat_interleave( - k_proj.scales.data.T.view(orig_kv_heads, attn.head_dim, attn.hidden_size // k_proj.group_size), - repeat, - 0, - ) - .view(new_kv_heads * attn.head_dim, attn.hidden_size // k_proj.group_size) - .T - ) - k_proj.out_features = k_proj.out_features * repeat - else: - attn.k_proj.weight.data = torch.repeat_interleave( - attn.k_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0 - ).view(new_kv_heads * attn.head_dim, attn.hidden_size) - - if isinstance(v_proj, (WQLinear_GEMM, QuantLinearGPTQ)): - if attn.head_dim % 8 != 0: - raise ValueError(f"the value attn.head_dim={attn.head_dim} is not divisible by 8 which is \ - according to the assumption that model is 4-bit quantized.") - if attn.hidden_size % v_proj.group_size: - raise ValueError(f"The value of attn.hidden_size={attn.hidden_size} is not divisible by \ - v_proj.group_size = {v_proj.group_size}") - - # Value projection duplication - # Duplication of quantized weights - v_proj.qweight.data = ( - torch.repeat_interleave( - v_proj.qweight.data.T.view(orig_kv_heads, attn.head_dim // 8, attn.hidden_size), repeat, 0 - ) - .view((new_kv_heads * attn.head_dim) // 8, attn.hidden_size) - .T - ) - # Duplication of quantized zero points - v_proj.qzeros.data = ( - torch.repeat_interleave( - v_proj.qzeros.data.T.view(orig_kv_heads, attn.head_dim // 8, attn.hidden_size // v_proj.group_size), - repeat, - 0, - ) - .view((new_kv_heads * attn.head_dim) // 8, attn.hidden_size // v_proj.group_size) - .T - ) - # Duplication of quantization scales - v_proj.scales.data = ( - torch.repeat_interleave( - v_proj.scales.data.T.view(orig_kv_heads, attn.head_dim, attn.hidden_size // v_proj.group_size), - repeat, - 0, - ) - .view(new_kv_heads * attn.head_dim, attn.hidden_size // v_proj.group_size) - .T - ) - v_proj.out_features = v_proj.out_features * 
repeat
-        else:
-            attn.v_proj.weight.data = torch.repeat_interleave(
-                attn.v_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0
-            ).view(new_kv_heads * attn.head_dim, attn.hidden_size)
+        duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)
+        duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)
 
     # Generate modified outputs and tokens
     with torch.inference_mode():
@@ -146,6 +98,11 @@ def main(args):
     print("Original:", tokenizer.batch_decode(orig_tokens))
     print("Modified:", tokenizer.batch_decode(mod_tokens))
 
+    if not torch.all(orig_tokens == mod_tokens):
+        raise RuntimeError(
+            "Something went wrong while duplicating KV head weights, output tokens don't match after modification"
+        )
+
     # Export the modified model
     q_model = QEFFAutoModelForCausalLM(model, model_name)
     export(
@@ -155,9 +112,6 @@ def main(args):
         onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads",
     )
 
-    # Undo the effect of replace_transformers_quantizers
-    undo_transformers_quantizers()
-
 
 if __name__ == "__main__":
     # Set up argument parser
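
As a quick sanity check of the refactored helper from PATCH 4, a minimal sketch on an unquantized projection; it assumes scripts/replicate_kv_head is importable (e.g. on sys.path), and the head/hidden sizes below are illustrative only:

    import torch

    # Assumes the script's directory is importable; adjust the import path to your checkout.
    from replicate_kv_heads import duplicate_weights_for_linear_layer

    orig_kv_heads, head_dim, hidden_size, repeat = 8, 128, 4096, 2
    k_proj = torch.nn.Linear(hidden_size, orig_kv_heads * head_dim, bias=False)

    # The non-quantized branch repeats each KV head's rows along the output dimension.
    duplicate_weights_for_linear_layer(k_proj, orig_kv_heads, repeat, head_dim, hidden_size)
    assert k_proj.weight.shape == (repeat * orig_kv_heads * head_dim, hidden_size)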