In [1]:
from utils.utils import QuantBlockConfig
from utils import utils 
from _transformers.src.transformers.models.gpt2.modeling_gpt2 import GPT2MLPQ, GPT2AttentionQ
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel
from utils import lora 
from transformers import GPT2Model
import torch


# ---- Model, tokenizer and per-block quantization configs ----
MODEL_ID = "openai-community/gpt2"
NUM_BLOCKS = 12  # number of per-block quant configs built below

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = GPT2LMHeadModel.from_pretrained(MODEL_ID)

# Live config objects handed to quantize_model; later updated via
# utils.set_active_quant_config.
QUANT_CONFIGS = {block: utils.QuantBlockConfig() for block in range(NUM_BLOCKS)}

precisions = ["test1", "test2"]


def _uniform_bits(attn_w, attn_a, mlp_w, mlp_a):
    """Return one fresh bit-width dict per block, all with the same settings."""
    return {
        block: {
            "Attention_W_bit": attn_w,
            "Attention_A_bit": attn_a,
            "MLP_W_bit": mlp_w,
            "MLP_A_bit": mlp_a,
        }
        for block in range(NUM_BLOCKS)
    }


dict_configs = {
    "test1": _uniform_bits(16, 8, 8, 8),
    "test2": _uniform_bits(4, 4, 4, 32),
}

# Convert the raw dicts into QuantBlockConfig objects, keyed by precision name
# and then by block index.
configs = {}
for precision_name, per_block in dict_configs.items():
    configs[precision_name] = {
        block: QuantBlockConfig.from_dict(per_block[block])
        for block in range(NUM_BLOCKS)
    }


def _activate(precision_name):
    """Switch both the quantization config and the LoRA adapters to `precision_name`."""
    utils.set_active_quant_config(QUANT_CONFIGS, configs[precision_name])
    lora.set_active_quant_config(precision_name)


def _greedy_generate(encoded):
    """Run `model.generate` on `encoded` with sampling disabled (greedy), up to 30 new tokens."""
    with torch.no_grad():
        return model.generate(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_new_tokens=30,
            do_sample=False,  # greedy decoding
            pad_token_id=tokenizer.eos_token_id,
        )


# ---- Wrap the model: quantized layers first, then LoRA adapters per precision ----
utils.quantize_model(model, QUANT_CONFIGS)
lora.apply_lora_to_model(model, precisions, r=4, alpha=1.0)

_activate("test1")

print(model)

# Pad with the EOS token so the tokenizer can pad the prompt batch.
tokenizer.pad_token = tokenizer.eos_token
prompt_batch = ["My name is Pranav. I am from India. I am "]
model_inputs = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device)
print(model_inputs.input_ids.shape)

# ---- Generation under "test1" ----
generated_ids = _greedy_generate(model_inputs)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])

# ---- Same prompt under "test2" ----
_activate("test2")
generated_ids = _greedy_generate(model_inputs)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])

  from .autonotebook import tqdm as notebook_tqdm


🚨 `quant_configs` is part of GPT2Model.__init__'s signature, but not documented. Make sure to add it to the docstring of the function in /Users/pranavponnusamy/Documents/EICTest/_transformers/src/transformers/models/gpt2/modeling_gpt2.py.
None
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2AttentionQ(
          (c_attn): LoraAttentionKV(
            (layer): QuantLinear()
            (lora_A_k): ParameterDict(
                (test1): Parameter containing: [torch.FloatTensor of size 768x4]
                (test2): Parameter containing: [torch.FloatTensor of size 768x4]
            )
            (lora_B_k): ParameterDict(
                (test1): Parameter containing: [torch.FloatTensor of size 4x768]
                (test2): Parameter contai

In [2]:
# ============ Load Dataset ============
from datasets import load_dataset
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Collate SQuAD-style examples into parallel lists of strings.

    Each item in `batch` is indexed as item['context'], item['question'] and
    item['answers']['text']. The result holds one list per field; for
    'answers' only the first answer text is kept, or "" when the example has
    no answer texts.
    """
    contexts = []
    questions = []
    first_answers = []
    for example in batch:
        contexts.append(example['context'])
        questions.append(example['question'])
        answer_texts = example['answers']['text']
        # Fall back to an empty string when the answer list is empty.
        first_answers.append(answer_texts[0] if answer_texts else "")
    return {
        'context': contexts,
        'question': questions,
        'answers': first_answers,
    }

# Load SQuAD and keep only the first 1000 training rows as a quick test subset.
ds = load_dataset("rajpurkar/squad")
subset_indices = range(1000)
train_data = ds['train'].select(subset_indices)

# Shuffled mini-batches of 4 examples; collate_fn converts each batch into
# lists of context / question / answer strings.
loader_options = dict(batch_size=4, shuffle=True, collate_fn=collate_fn)
train_loader = DataLoader(train_data, **loader_options)

num_samples = len(train_data)
num_batches = len(train_loader)
print(f"Training samples: {num_samples}")
print(f"Batches per epoch: {num_batches}")


Training samples: 1000
Batches per epoch: 250


In [3]:
def create_masked_labels(tokenizer, contexts, questions, answers, max_length=512):
    """
    Create input_ids, attention_mask and labels where only answer tokens
    contribute to the loss.

    Each example is rendered as
    "Context: ...\\nQuestion: ...\\nAnswer: <answer><eos>", tokenized and padded
    to `max_length`. Labels use -100 for masked positions (prompt and padding),
    which CrossEntropyLoss ignores.

    Sequences longer than `max_length` are truncated from the LEFT (the start of
    the context is dropped) so that the answer tokens at the end always survive.
    Truncating from the right would cut the answer off whenever the prompt
    alone reaches `max_length`, leaving every label at -100 and the example
    without any supervised token.

    Args:
        tokenizer: callable tokenizer exposing `eos_token`; it is called with
            `padding='max_length'` and `return_tensors='pt'`.
            NOTE(review): the prompt mask assumes right-side padding and that
            the prompt's tokens are a prefix of the full text's tokens --
            confirm for tokenizers other than the one used in this notebook.
        contexts, questions, answers: parallel iterables of strings.
        max_length: fixed sequence length of every returned row.

    Returns:
        dict with 'input_ids', 'attention_mask' and 'labels', each a tensor of
        shape (num_examples, max_length). An empty batch yields tensors of
        shape (0, max_length).
    """
    input_ids_list = []
    attention_mask_list = []
    labels_list = []

    for context, question, answer in zip(contexts, questions, answers):
        # Build the full prompt
        prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
        full_text = f"{prompt} {answer}{tokenizer.eos_token}"

        # Tokenize prompt (context + question) separately to get its length
        prompt_tokens = tokenizer(prompt, add_special_tokens=True)
        prompt_length = len(prompt_tokens.input_ids)

        # Tokenize the full sequence WITHOUT truncation: short sequences are
        # padded up to max_length, long ones come back at their full length
        # and are trimmed from the left below.
        full_tokens = tokenizer(
            full_text,
            max_length=max_length,
            truncation=False,
            padding='max_length',
            return_tensors='pt'
        )

        input_ids = full_tokens.input_ids.squeeze(0)
        attention_mask = full_tokens.attention_mask.squeeze(0)

        # Drop overflow from the front so the answer stays in the window, and
        # shift the prompt boundary by the same amount.
        overflow = input_ids.shape[0] - max_length
        if overflow > 0:
            input_ids = input_ids[overflow:]
            attention_mask = attention_mask[overflow:]
            prompt_length = max(prompt_length - overflow, 0)

        # Create labels: -100 for prompt, actual token ids for answer
        labels = input_ids.clone()
        labels[:prompt_length] = -100  # Mask context + question

        # Also mask padding tokens
        labels[attention_mask == 0] = -100

        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)
        labels_list.append(labels)

    # torch.stack rejects an empty list; return correctly shaped empty tensors.
    if not input_ids_list:
        return {
            'input_ids': torch.empty((0, max_length), dtype=torch.long),
            'attention_mask': torch.empty((0, max_length), dtype=torch.long),
            'labels': torch.empty((0, max_length), dtype=torch.long)
        }

    return {
        'input_ids': torch.stack(input_ids_list),
        'attention_mask': torch.stack(attention_mask_list),
        'labels': torch.stack(labels_list)
    }

In [None]:
from utils import lora
import torch

# Precision configurations that are trained jointly on every batch.
precisions = ["test1", "test2"]
num_epochs = 1

# Per-configuration multiplier applied to the loss before backpropagation.
loss_scale = {
    "test1": 1.0,
    "test2": 1.0,
}

# Collect every LoRA parameter (all configs), freeze the whole model, then
# re-enable gradients for the LoRA parameters only.
lora_params = [
    param for param_name, param in model.named_parameters() if 'lora_' in param_name
]
model.requires_grad_(False)
for param in lora_params:
    param.requires_grad_(True)

optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

model.train()
for epoch in range(num_epochs):
    for batch in train_loader:
        # Tokenize once per batch; the same tensors feed every precision config.
        batch_data = create_masked_labels(
            tokenizer,
            batch['context'],
            batch['question'],
            batch['answers'],
        )
        device = model.device
        input_ids = batch_data['input_ids'].to(device)
        attention_mask = batch_data['attention_mask'].to(device)
        labels = batch_data['labels'].to(device)

        # Clear gradients once; each config's backward pass then accumulates
        # into the same .grad buffers.
        optimizer.zero_grad()

        loss_values = {}

        for precision in precisions:
            # Switch quantization settings and LoRA adapters to this config.
            utils.set_active_quant_config(QUANT_CONFIGS, configs[precision])
            lora.set_active_quant_config(precision)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            # NOTE(review): if every label in the batch is -100 the loss is
            # presumably NaN -- confirm whether such batches can occur.
            unscaled_loss = outputs.loss
            scaled_loss = unscaled_loss * loss_scale[precision]

            # Accumulate gradients only; the optimizer step happens after all configs.
            scaled_loss.backward()

            # Log the unscaled loss.
            loss_values[precision] = unscaled_loss.item()

            # Release references so the graph can be freed before the next config.
            del outputs, unscaled_loss, scaled_loss

        # Optional gradient clipping, currently disabled:
        # torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)

        # One optimizer step per batch, using gradients summed over all configs.
        optimizer.step()

        loss_str = " | ".join(
            f"{name}: {value:.4f}" for name, value in loss_values.items()
        )
        print(f"Losses: {loss_str}")

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Losses: test1: 5.2204 | test2: 10.0681
Losses: test1: 4.7503 | test2: 9.8488
Losses: test1: 3.0438 | test2: 8.5001
Losses: test1: 4.3545 | test2: 8.3515
Losses: test1: 4.7495 | test2: 9.0765
Losses: test1: 3.6963 | test2: 8.6566
Losses: test1: 3.0536 | test2: 8.5534
Losses: test1: 3.8697 | test2: 8.9160
Losses: test1: 3.1377 | test2: 8.3708
Losses: test1: 1.9212 | test2: 9.4158
Losses: test1: 2.9648 | test2: 7.8286
Losses: test1: 2.7820 | test2: 8.4585
Losses: test1: 3.5794 | test2: 8.4243
Losses: test1: 3.2107 | test2: 8.3756
Losses: test1: 2.1760 | test2: 8.8208
Losses: test1: 2.7039 | test2: 9.0153
Losses: test1: 3.1041 | test2: 8.1269
Losses: test1: 2.3237 | test2: 8.8550
Losses: test1: 2.1907 | test2: 7.3412
