28 changes: 28 additions & 0 deletions QEfficient/transformers/quantizers/auto.py
@@ -6,12 +6,17 @@
# ----------------------------------------------------------------------------

from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING
from transformers.quantizers.quantizer_awq import AwqQuantizer
from transformers.quantizers.quantizer_gptq import GptqHfQuantizer
from transformers.utils.quantization_config import AwqConfig, GPTQConfig

from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer
from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer

QEFF_AUTO_QUANTIZER_MAPPING = {"awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer}
QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": QEffAwqConfig, "gptq": QEffGPTQConfig}
DUPLICATE_AUTO_QUANTIZER_MAPPING = {"awq": AwqQuantizer, "gptq": GptqHfQuantizer}
DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": AwqConfig, "gptq": GPTQConfig}


def with_replaced_quantizers(func):
@@ -39,3 +44,26 @@ def wrapper(*args, **kwargs):
return out

return wrapper
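

Most of with_replaced_quantizers is collapsed in this view; from the visible signature and the wrapper/return-wrapper tail it is a decorator, presumably swapping the quantizer mappings around the wrapped call. A minimal usage sketch under that assumption (the function name below is made up):

from transformers import AutoModelForCausalLM

@with_replaced_quantizers
def load_quantized_model(model_name: str):
    # While this runs, the QEfficient AWQ/GPTQ quantizers are assumed to be registered.
    return AutoModelForCausalLM.from_pretrained(model_name)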


def replace_transformers_quantizers():
"""
This method lets you load AWQ/GPTQ models on CPU, bypassing the GPU
requirement that transformers normally enforces for these quantizers.
Just call this method before
`transformers.AutoModelForCausalLM.from_pretrained`, and any AWQ/GPTQ model
supported by QEfficient will be loaded on CPU.
"""
AUTO_QUANTIZER_MAPPING.update(QEFF_AUTO_QUANTIZER_MAPPING)
AUTO_QUANTIZATION_CONFIG_MAPPING.update(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING)


# TODO: Make this a fixture? Or better, always update the quantizer and config in transformers.
# When a user imports QEfficient, these are always available.
def undo_transformers_quantizers():
"""
This method undoes the effects of `replace_transformers_quantizers`.
After it is called, the original transformers quantizers are used for loading AWQ/GPTQ models.
"""
AUTO_QUANTIZER_MAPPING.update(DUPLICATE_AUTO_QUANTIZER_MAPPING)
AUTO_QUANTIZATION_CONFIG_MAPPING.update(DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING)
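
A minimal usage sketch for these two helpers; the checkpoint name below is a placeholder, and this assumes the model's AWQ/GPTQ configuration is one QEfficient supports:

from transformers import AutoModelForCausalLM

from QEfficient.transformers.quantizers.auto import (
    replace_transformers_quantizers,
    undo_transformers_quantizers,
)

# Swap in the QEfficient quantizers so the AWQ/GPTQ checkpoint loads on CPU.
replace_transformers_quantizers()
model = AutoModelForCausalLM.from_pretrained("org/awq-quantized-model")  # placeholder name

# Restore the stock transformers quantizers once loading is done.
undo_transformers_quantizers()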
2 changes: 1 addition & 1 deletion QEfficient/transformers/quantizers/quantizer_gptq.py
@@ -6,7 +6,7 @@
# -----------------------------------------------------------------------------

import torch
from transformers.quantizers.quantizer_gptq import HfQuantizer
from transformers.quantizers import HfQuantizer
from transformers.utils.quantization_config import GPTQConfig

from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
4 changes: 4 additions & 0 deletions QEfficient/utils/__init__.py
@@ -5,6 +5,10 @@
#
# -----------------------------------------------------------------------------

from QEfficient.transformers.quantizers.auto import ( # noqa: F401
replace_transformers_quantizers,
undo_transformers_quantizers,
)
from QEfficient.utils._utils import ( # noqa: F401
check_and_assign_cache_dir,
get_num_layers_from_config,
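Because the helpers are re-exported here (with noqa: F401 to silence the unused-import lint), callers can also import them from the utils package directly; a one-line sketch:

from QEfficient.utils import replace_transformers_quantizers, undo_transformers_quantizers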
59 changes: 51 additions & 8 deletions scripts/replicate_kv_head/replicate_kv_heads.py
@@ -11,17 +11,59 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM, export
from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ


def duplicate_weights_for_linear_layer(
layer: torch.nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
):
new_kv_heads = repeat * orig_kv_heads
if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ)):
if head_dim % 8 != 0:
raise ValueError(f"the value head_dim={head_dim} is not divisible by 8 which is \
according to the assumption that model is 4-bit quantized.")
if hidden_size % layer.group_size != 0:
raise ValueError(f"The value of hidden_size={hidden_size} is not divisible by \
K_proj.group_size={layer.group_size}")

# Duplication of quantized weights
layer.qweight.data = torch.repeat_interleave(
layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1
).view(hidden_size, (new_kv_heads * head_dim) // 8)
# Duplication of quantized zero points
layer.qzeros.data = torch.repeat_interleave(
layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8),
repeat,
1,
).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8)
# Duplication of quantization scales
layer.scales.data = torch.repeat_interleave(
layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim),
repeat,
1,
).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
layer.out_features = layer.out_features * repeat
else:
layer.weight.data = torch.repeat_interleave(
layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)
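
A small standalone sketch of the view/repeat_interleave pattern used in the unquantized branch above, with made-up sizes, showing how each KV head's block of rows is duplicated while head order is preserved:

import torch

orig_kv_heads, head_dim, hidden_size, repeat = 2, 4, 8, 3
new_kv_heads = repeat * orig_kv_heads

# Dummy k_proj-style weight of shape (orig_kv_heads * head_dim, hidden_size).
weight = torch.randn(orig_kv_heads * head_dim, hidden_size)

# Repeat each head's block of rows `repeat` times along the head dimension.
expanded = torch.repeat_interleave(
    weight.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)

assert expanded.shape == (new_kv_heads * head_dim, hidden_size)
# The first head now appears `repeat` times before the second head starts.
assert torch.equal(expanded[:head_dim], weight[:head_dim])
assert torch.equal(expanded[head_dim : 2 * head_dim], weight[:head_dim])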


def main(args):
# Load the model and tokenizer
model_name = args.model_name
model_base_name = model_name.split("/")[-1]
# Replace quantizers for loading Quantized AWQ/GPTQ models on CPU.
replace_transformers_quantizers()
model = AutoModelForCausalLM.from_pretrained(
model_name, # num_hidden_layers=2,
model_name,
num_hidden_layers=1,
attn_implementation="eager",
)

# Undo the effect of replace_transformers_quantizers
undo_transformers_quantizers()
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer(args.prompt, return_tensors="pt")

@@ -44,12 +86,8 @@ def main(args):
attn = block.self_attn
attn.num_key_value_heads = new_kv_heads
attn.num_key_value_groups = block.self_attn.num_heads // new_kv_heads
attn.k_proj.weight.data = torch.repeat_interleave(
attn.k_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0
).view(new_kv_heads * attn.head_dim, attn.hidden_size)
attn.v_proj.weight.data = torch.repeat_interleave(
attn.v_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0
).view(new_kv_heads * attn.head_dim, attn.hidden_size)
duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)
duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)

# Generate modified outputs and tokens
with torch.inference_mode():
@@ -60,6 +98,11 @@
print("Original:", tokenizer.batch_decode(orig_tokens))
print("Modified:", tokenizer.batch_decode(mod_tokens))

if not torch.all(orig_tokens == mod_tokens):
raise RuntimeError(
"Something went wrong while duplicating KV heads weights, output token don't match after modification"
)

# Export the modified model
q_model = QEFFAutoModelForCausalLM(model, model_name)
export(
11 changes: 1 addition & 10 deletions tests/transformers/models/test_causal_lm_models.py
@@ -8,12 +8,10 @@
import numpy as np
import pytest
from transformers import AutoModelForCausalLM
from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer
from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer
from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers
from QEfficient.utils import hf_download
from QEfficient.utils._utils import load_hf_tokenizer
from QEfficient.utils.constants import Constants
@@ -41,13 +39,6 @@
]


# TODO: Make this a fixture? Or better, always update the quantizer and config in transformers.
# When a user imports QEfficient, these are always available.
def replace_transformers_quantizers():
AUTO_QUANTIZER_MAPPING.update({"awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer})
AUTO_QUANTIZATION_CONFIG_MAPPING.update({"awq": QEffAwqConfig, "gptq": QEffGPTQConfig})


def load_causal_lm_model(model_config):
"""
Function to load model from huggingface and transform to KV model