In [1]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BertTokenizerFast,
    AutoModelForQuestionAnswering,
    get_linear_schedule_with_warmup,
    DataCollatorWithPadding
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
from evaluate import load

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
dataset = load_dataset('rajpurkar/squad')
split = dataset['validation'].train_test_split(test_size=0.5, seed=42)
raw_datasets = {
    'train': dataset['train'],
    'validation': split['train'],
    'test': split['test']
}

In [3]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
max_length = 384
doc_stride = 128

In [4]:
def prepare_features(examples):
    tokenized = tokenizer(
        examples['question'], examples['context'],
        truncation='only_second',
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding='max_length'
    )

    overflow_to_sample_mapping = tokenized.pop('overflow_to_sample_mapping')
    offset_mapping = tokenized.pop('offset_mapping')

    tokenized['start_positions'] = []
    tokenized['end_positions'] = []

    for i, offsets in enumerate(offset_mapping):
        sample_idx = overflow_to_sample_mapping[i]
        answers = examples['answers'][sample_idx]
        cls_index = tokenized['input_ids'][i].index(tokenizer.cls_token_id)

        if len(answers['answer_start']) == 0:
            tokenized['start_positions'].append(cls_index)
            tokenized['end_positions'].append(cls_index)
        else:
            start_char = answers['answer_start'][0]
            end_char = start_char + len(answers['text'][0])
            sequence_ids = tokenized.sequence_ids(i)

            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1
            token_end_index = len(tokenized['input_ids'][i]) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            if offsets[token_start_index][0] > end_char or offsets[token_end_index][1] < start_char:
                tokenized['start_positions'].append(cls_index)
                tokenized['end_positions'].append(cls_index)
            else:
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized['start_positions'].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized['end_positions'].append(token_end_index + 1)

    return tokenized

In [5]:
tokenized_datasets = {k: raw_datasets[k].map(
    prepare_features,
    batched=True,
    remove_columns=raw_datasets[k].column_names
) for k in ['train', 'validation', 'test']}

Map:   0%|          | 0/5285 [00:00<?, ? examples/s]

In [6]:
class QADataset(Dataset):
    def __init__(self, hf_dataset):
        self.dataset = hf_dataset
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        row = self.dataset[idx]
        return {
            'input_ids': torch.tensor(row['input_ids']),
            'attention_mask': torch.tensor(row['attention_mask']),
            'start_positions': torch.tensor(row['start_positions']),
            'end_positions': torch.tensor(row['end_positions'])
        }

In [7]:
train_dataset = QADataset(tokenized_datasets['train'])
val_dataset   = QADataset(tokenized_datasets['validation'])
test_dataset  = QADataset(tokenized_datasets['test'])

data_collator = DataCollatorWithPadding(tokenizer)
batch_size = 8

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, collate_fn=data_collator)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, collate_fn=data_collator)

In [8]:
base_model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')
lora_config = LoraConfig(
    task_type=TaskType.QUESTION_ANS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['query', 'value', 'dense']
)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
model = get_peft_model(base_model, lora_config)
for name, param in model.named_parameters():
    if 'lora_' not in name:
        param.requires_grad = False
model.to(device)

PeftModelForQuestionAnswering(
  (base_model): LoraModel(
    (model): BertForQuestionAnswering(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=768, out_features=768, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(

In [None]:
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
num_epochs = 3
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=total_steps//10, num_training_steps=total_steps)
scaler = torch.cuda.amp.GradScaler()
criterion = nn.CrossEntropyLoss()
metric = load("squad")

  scaler = torch.cuda.amp.GradScaler()


In [11]:
def train_epoch(epoch):
    model.train()
    running_loss = 0.0
    for step, batch in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_pos = batch['start_positions'].to(device)
        end_pos = batch['end_positions'].to(device)

        with torch.cuda.amp.autocast():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            start_logits, end_logits = outputs.start_logits, outputs.end_logits
            loss = (criterion(start_logits, start_pos) + criterion(end_logits, end_pos)) / 2

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        running_loss += loss.item()

        if step % 100 == 0:
            print(f"Epoch {epoch} Step {step}/{len(train_loader)} - Loss: {running_loss/step:.4f}")

    return running_loss / len(train_loader)

In [12]:
@torch.no_grad()
def eval_epoch(data_loader, raw_data):
    model.eval()
    eval_loss = 0.0
    predictions = []
    references = []

    for i, batch in enumerate(data_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_pos = batch['start_positions'].to(device)
        end_pos = batch['end_positions'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        start_logits, end_logits = outputs.start_logits, outputs.end_logits

        eval_loss += ((criterion(start_logits, start_pos) + criterion(end_logits, end_pos)) / 2).item()

        for b in range(len(input_ids)):
            start_idx = torch.argmax(start_logits[b]).item()
            end_idx = torch.argmax(end_logits[b]).item()
            if start_idx > end_idx:
                answer = ""
            else:
                tokens = input_ids[b][start_idx:end_idx+1]
                answer = tokenizer.decode(tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)

            sample_index = i * data_loader.batch_size + b
            if sample_index < len(raw_data):
                predictions.append({"id": raw_data[sample_index]["id"], "prediction_text": answer})
                references.append({"id": raw_data[sample_index]["id"], "answers": raw_data[sample_index]["answers"]})

    metrics = metric.compute(predictions=predictions, references=references)
    return eval_loss / len(data_loader), metrics

In [None]:
for epoch in range(1, num_epochs + 1):
    print(f"\n=== Epoch {epoch}/{num_epochs} ===")
    train_loss = train_epoch(epoch)
    print(f"Train Loss: {train_loss:.4f}")
    val_loss, val_metrics = eval_epoch(val_loader, raw_datasets['validation'])
    print(f"Val Loss: {val_loss:.4f} | EM: {val_metrics['exact_match']:.2f} | F1: {val_metrics['f1']:.2f}")


=== Epoch 1/3 ===
Epoch 1 Step 100/11066 - Loss: 5.6841
Epoch 1 Step 200/11066 - Loss: 5.5798
Epoch 1 Step 300/11066 - Loss: 5.5353
Epoch 1 Step 400/11066 - Loss: 5.4284
Epoch 1 Step 500/11066 - Loss: 5.3769
Epoch 1 Step 600/11066 - Loss: 5.2902
Epoch 1 Step 700/11066 - Loss: 5.1808
Epoch 1 Step 800/11066 - Loss: 5.1156
Epoch 1 Step 900/11066 - Loss: 5.0580
Epoch 1 Step 1000/11066 - Loss: 4.9739
Epoch 1 Step 1100/11066 - Loss: 4.9443
Epoch 1 Step 1200/11066 - Loss: 4.9132
Epoch 1 Step 1300/11066 - Loss: 4.8123
Epoch 1 Step 1400/11066 - Loss: 4.7338
Epoch 1 Step 1500/11066 - Loss: 4.6851
Epoch 1 Step 1600/11066 - Loss: 4.6096
Epoch 1 Step 1700/11066 - Loss: 4.5711
Epoch 1 Step 1800/11066 - Loss: 4.5051
Epoch 1 Step 1900/11066 - Loss: 4.4612
Epoch 1 Step 2000/11066 - Loss: 4.4023
Epoch 1 Step 2100/11066 - Loss: 4.3590
Epoch 1 Step 2200/11066 - Loss: 4.2943
Epoch 1 Step 2300/11066 - Loss: 4.2329
Epoch 1 Step 2400/11066 - Loss: 4.1827
Epoch 1 Step 2500/11066 - Loss: 4.1066
Epoch 1 Step 26

In [None]:
print("\n=== Test Performance ===")
test_loss, test_metrics = eval_epoch(test_loader, raw_datasets['test'])
print(f"Test Loss: {test_loss:.4f} | EM: {test_metrics['exact_match']:.2f} | F1: {test_metrics['f1']:.2f}")


=== Test Performance ===
Test Loss: 0.8653 | EM: 62.11 | F1: 62.58
