In [8]:
from utils import utils
from _transformers.src.transformers.models.gpt2.modeling_gpt2 import GPT2MLPQ, GPT2AttentionQ
from transformers import AutoModelForCausalLM, AutoTokenizer
import re

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

configs_dict = {i: {"Attention_W_bit": 8, "Attention_A_bit": 8, "MLP_W_bit": 8, "MLP_A_bit": 8} for i in range(12)}

quant_configs = {i: utils.QuantBlockConfig.from_dict(configs_dict[i]) for i in range(0, 12)}


for name, module in list(model.named_modules()):
    class_name = type(module).__name__
    
    parent = model
    name_parts = name.split('.')
    for part in name_parts[:-1]:
        parent = getattr(parent, part)
    attr_name = name_parts[-1]
    
    # Get block index from name_parts: ['transformer', 'h', '<block_idx>', ...]
    if len(name_parts) >= 3 and name_parts[0] == 'transformer' and name_parts[1] == 'h':
        block_idx = int(name_parts[2])
        quant_config = quant_configs[block_idx]
    else:
        continue 
    
    if class_name == "GPT2MLP":
        new_module = GPT2MLPQ(
            intermediate_size=module.c_fc.weight.shape[1],
            config=model.config,
            quant_config=quant_config
        )

        new_module.load_state_dict(module.state_dict(), strict=False)
        setattr(parent, attr_name, new_module)
    
    elif class_name == "GPT2Attention":
        new_module = GPT2AttentionQ(
            config=model.config,
            is_cross_attention=module.is_cross_attention,
            layer_idx=module.layer_idx,
            quant_config=quant_config
        )
        new_module.load_state_dict(module.state_dict(), strict=False)
        setattr(parent, attr_name, new_module)

tokenizer.pad_token = tokenizer.eos_token
model_inputs = tokenizer(["The secret to baking a good cake is ", "What is the meaning of life? "], return_tensors="pt", padding=True).to(model.device)
print(model_inputs.input_ids.shape)
generated_ids = model.generate(**model_inputs, max_length=30, do_sample = False)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


torch.Size([2, 9])


'The secret to baking a good cake is \xa0to baking a good cake.\nThe secret to baking a good cake is to baking a good cake'