# Extractive Question Answering


## Step 1: Imports and setup


In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from datasets import load_dataset
import random

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


PyTorch version: 2.5.1+cu121
CUDA available: True


## Step 2: Loading the SQuAD dataset

SQuAD (Stanford Question Answering Dataset) is a reading-comprehension dataset in which the model must find the answer to a question inside a given context.

Each example contains:
- **context**: the paragraph of text in which the answer must be found
- **question**: the question
- **answers**: the list of reference answers (`text`) with their character start offsets in the context (`answer_start`); the end offset is not stored and has to be computed as `answer_start + len(text)`
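
To make the offset convention concrete, here is a minimal sketch that derives the end offset from a record in the SQuAD layout. The record and the helper name `answer_span` are hypothetical and used only for illustration; no dataset download is involved.

```python
# Hypothetical record in the SQuAD layout (not taken from the dataset).
record = {
    "context": "The Eiffel Tower is located in Paris, France.",
    "question": "Where is the Eiffel Tower located?",
    "answers": {"text": ["Paris"], "answer_start": [31]},
}


def answer_span(record, which=0):
    """Return (start, end) character offsets of one reference answer; end is exclusive."""
    text = record["answers"]["text"][which]
    start = record["answers"]["answer_start"][which]
    return start, start + len(text)


start, end = answer_span(record)
recovered = record["context"][start:end]
print(start, end, recovered)
```

Slicing the context with the derived offsets should give back the answer text, which is a cheap way to confirm that the offsets are character-based and that the end is exclusive.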


In [2]:
print("Loading SQuAD dataset...")
dataset = load_dataset("squad")

print(f"\nDataset structure:")
print(dataset)

print(f"\nTrain set size: {len(dataset['train'])}")
print(f"Validation set size: {len(dataset['validation'])}")

print("\nSample example:")
example = dataset['train'][0]
print(f"Context: {example['context'][:200]}...")
print(f"\nQuestion: {example['question']}")
print(f"\nAnswers: {example['answers']}")
print(f"\nAnswer text: {example['answers']['text'][0]}")
print(f"Answer start: {example['answers']['answer_start'][0]}")


Loading SQuAD dataset...

Dataset structure:
DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

Train set size: 87599
Validation set size: 10570

Sample example:
Context: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper sta...

Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?

Answers: {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}

Answer text: Saint Bernadette Soubirous
Answer start: 515


## Step 3: Loading the data into memory

We pull the columns we need out of the dataset and into memory so they are convenient to work with.


In [3]:
train_data = dataset['train']
val_data = dataset['validation']

train_contexts = train_data['context']
train_questions = train_data['question']
train_answers = train_data['answers']

val_contexts = val_data['context']
val_questions = val_data['question']
val_answers = val_data['answers']

print(f"Loaded {len(train_contexts)} training examples")
print(f"Loaded {len(val_contexts)} validation examples")

print(f"\nExample from training set:")
print(f"Context length: {len(train_contexts[0])} characters")
print(f"Question: {train_questions[0]}")
print(f"Answer text: {train_answers[0]['text'][0]}")
print(f"Answer start position: {train_answers[0]['answer_start'][0]}")
print(f"Answer end position: {train_answers[0]['answer_start'][0] + len(train_answers[0]['text'][0])}")


Loaded 87599 training examples
Loaded 10570 validation examples

Example from training set:
Context length: 695 characters
Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer text: Saint Bernadette Soubirous
Answer start position: 515
Answer end position: 541


## Step 4: Creating a PyTorch Dataset wrapper

We create a `QADataset` class that inherits from `torch.utils.data.Dataset` and provides a convenient interface to the QA data.

Each dataset item will contain:
- context: the context text
- question: the question
- answer_text: the answer text (the first reference answer)
- answer_start: the character offset at which the answer starts in the context
- answer_end: the character offset at which the answer ends in the context (exclusive)
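
Since `answer_end` is derived from `answer_start` and the answer length, a wrong offset would silently produce a wrong span. The sketch below shows one possible dataset-wide sanity check; `find_misaligned` is a hypothetical helper and the toy data is made up, with the second offset deliberately off by one.

```python
def find_misaligned(contexts, answers):
    """Return indices whose first reference answer does not match its context slice."""
    bad = []
    for idx, (context, answer) in enumerate(zip(contexts, answers)):
        text = answer["text"][0]
        start = answer["answer_start"][0]
        if context[start:start + len(text)] != text:
            bad.append(idx)
    return bad


# Hypothetical toy data: the second offset is deliberately off by one.
contexts = ["BERT was released in 2018.", "SQuAD was released in 2016."]
answers = [
    {"text": ["2018"], "answer_start": [21]},
    {"text": ["2016"], "answer_start": [21]},
]
misaligned = find_misaligned(contexts, answers)
print(misaligned)
```

An empty result on the real data would indicate that every stored offset agrees with its answer text.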


In [4]:
class QADataset(Dataset):
    
    def __init__(self, contexts, questions, answers):
        self.contexts = contexts
        self.questions = questions
        self.answers = answers
        
        self.answer_texts = []
        self.answer_starts = []
        self.answer_ends = []
        
        for answer_dict in answers:
            answer_text = answer_dict['text'][0]
            answer_start = answer_dict['answer_start'][0]
            answer_end = answer_start + len(answer_text)
            
            self.answer_texts.append(answer_text)
            self.answer_starts.append(answer_start)
            self.answer_ends.append(answer_end)
    
    def __len__(self):
        return len(self.contexts)
    
    def __getitem__(self, idx):
        return {
            'context': self.contexts[idx],
            'question': self.questions[idx],
            'answer_text': self.answer_texts[idx],
            'answer_start': self.answer_starts[idx],
            'answer_end': self.answer_ends[idx]
        }


## Step 5: Instantiating the datasets

We create the train and validation datasets using our wrapper.


In [5]:
train_dataset = QADataset(train_contexts, train_questions, train_answers)
val_dataset = QADataset(val_contexts, val_questions, val_answers)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

sample = train_dataset[0]
print(f"\nSample from dataset:")
print(f"Context: {sample['context'][:150]}...")
print(f"Question: {sample['question']}")
print(f"Answer text: {sample['answer_text']}")
print(f"Answer start: {sample['answer_start']}")
print(f"Answer end: {sample['answer_end']}")

context_slice = sample['context'][sample['answer_start']:sample['answer_end']]
print(f"\nVerification - context slice at answer position: '{context_slice}'")
print(f"Matches answer text: {context_slice == sample['answer_text']}")


Train dataset size: 87599
Validation dataset size: 10570

Sample from dataset:
Context: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front o...
Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer text: Saint Bernadette Soubirous
Answer start: 515
Answer end: 541

Verification - context slice at answer position: 'Saint Bernadette Soubirous'
Matches answer text: True


## Step 6: Testing the DataLoader

We create DataLoaders to process the data in batches.
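
Because each `QADataset` item is a dict, a batch comes back as a dict whose string fields are lists with one entry per example. The sketch below is a simplified stand-in for that collation step, written in plain Python; `collate_dicts` is a hypothetical name, and unlike PyTorch's default collate function it leaves numeric fields as lists instead of turning them into tensors.

```python
def collate_dicts(samples):
    """Turn a list of per-example dicts into one dict of per-field lists."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}


# Hypothetical samples shaped like the items of QADataset.
samples = [
    {"question": "Who?", "answer_text": "Ada", "answer_start": 0},
    {"question": "When?", "answer_text": "1843", "answer_start": 12},
]
batch = collate_dicts(samples)
print(batch["question"])
print(batch["answer_start"])
```

This is why `len(batch['context'])` can be used as the batch size in the test loop.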


In [6]:
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0
)

print("Testing DataLoader...")
for i, batch in enumerate(train_loader):
    print(f"\nBatch {i+1}:")
    print(f"  Batch size: {len(batch['context'])}")
    print(f"  Context lengths: {[len(ctx) for ctx in batch['context']]}")
    print(f"  Questions: {batch['question']}")
    print(f"  Answer texts: {batch['answer_text']}")
    if i >= 1:
        break

print("\nDataLoader works correctly!")


Testing DataLoader...

Batch 1:
  Batch size: 4
  Context lengths: [687, 905, 849, 253]
  Questions: ['The royal courts sponsored both Buddhism and what?', 'What was the controversial domestic surveillance operation in this era?', 'How can religious beliefs contribute to a person remaining in pain?', 'Aside from the koofiyad, what do Somali men wear on their head?']
  Answer texts: ['Saivism', 'COINTELPRO', 'prevent the individual from seeking help', 'turban']

Batch 2:
  Batch size: 4
  Context lengths: [526, 867, 817, 768]
  Questions: ['Where is the Gold State Coach kept?', 'How is DNA grouping superior?', 'When was Montini absent from the conclave?', 'When were women first admitted to Northwestern?']
  Answer texts: ['the Royal Mews', 'genetic code itself is used', '1958', '1869']

DataLoader works correctly!


## Step 7: Loading the BERT tokenizer

BERT requires its matching tokenizer, which splits text into tokens and adds the special tokens [CLS] and [SEP].
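
For a question/context pair the special tokens delimit the two segments. The sketch below lays out such a pair by hand from already tokenized word lists; `build_pair` is a hypothetical helper that imitates the layout only and performs no real WordPiece tokenization.

```python
def build_pair(question_tokens, context_tokens):
    """Lay out a (question, context) pair as [CLS] question [SEP] context [SEP]."""
    tokens = ["[CLS]"] + question_tokens + ["[SEP]"] + context_tokens + ["[SEP]"]
    # Segment 0 covers [CLS], the question and the first [SEP]; segment 1 covers the rest.
    token_type_ids = [0] * (len(question_tokens) + 2) + [1] * (len(context_tokens) + 1)
    return tokens, token_type_ids


tokens, token_type_ids = build_pair(
    ["who", "wrote", "it", "?"],
    ["ada", "wrote", "it", "."],
)
print(tokens)
print(token_type_ids)
```

The position right after the first [SEP] is where the context begins, which matters once answer positions have to be expressed as token indices.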


In [7]:
from transformers import BertTokenizerFast

model_name = 'bert-base-uncased'
tokenizer = BertTokenizerFast.from_pretrained(model_name)

print(f"Tokenizer loaded: {model_name}")
print(f"Vocab size: {tokenizer.vocab_size}")

sample_text = "Hello, this is a test."
encoded = tokenizer(sample_text, return_tensors='pt', padding=True, truncation=True)
print(f"\nSample encoding:")
print(f"  Input IDs shape: {encoded['input_ids'].shape}")
print(f"  Input IDs: {encoded['input_ids']}")
print(f"  Decoded: {tokenizer.decode(encoded['input_ids'][0])}")


Tokenizer loaded: bert-base-uncased
Vocab size: 30522

Sample encoding:
  Input IDs shape: torch.Size([1, 9])
  Input IDs: tensor([[ 101, 7592, 1010, 2023, 2003, 1037, 3231, 1012,  102]])
  Decoded: [CLS] hello, this is a test. [SEP]


## Step 8: Updating the Dataset with tokenization

The Dataset class has to be updated so that it tokenizes the data and converts character positions into token positions. This is critical for the QA task, because the model predicts token indices while SQuAD stores character offsets.
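
The character-to-token conversion can be illustrated without BERT. The sketch below assumes a toy whitespace tokenizer that reports `(token, start, end)` offsets, similar in spirit to an offset mapping; both function names are hypothetical.

```python
import re


def tokenize_with_offsets(text):
    """Toy whitespace tokenizer returning (token, start, end) with an exclusive end."""
    return [(m.group(), m.start(), m.end()) for m in re.finditer(r"\S+", text)]


def char_span_to_token_span(offsets, start_char, end_char):
    """Map the character span [start_char, end_char) to inclusive token indices."""
    start_token = end_token = None
    for i, (_, tok_start, tok_end) in enumerate(offsets):
        # The start token is the first token that contains the first answer character.
        if start_token is None and tok_start <= start_char < tok_end:
            start_token = i
        # The end token is the token that contains the last answer character.
        if tok_start < end_char <= tok_end:
            end_token = i
    return start_token, end_token


context = "The tower stands in Paris today"
offsets = tokenize_with_offsets(context)

answer = "Paris"
start_char = context.index(answer)
single = char_span_to_token_span(offsets, start_char, start_char + len(answer))

phrase = "stands in Paris"
phrase_start = context.index(phrase)
multi = char_span_to_token_span(offsets, phrase_start, phrase_start + len(phrase))
print(single, multi)
```

If either index stays `None`, the answer could not be aligned to tokens (for example because it was truncated away), and the caller has to decide on a fallback label.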


In [8]:
class QATokenizedDataset(Dataset):
    
    def __init__(self, contexts, questions, answers, tokenizer, max_length=384):
        self.contexts = contexts
        self.questions = questions
        self.answers = answers
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        self.answer_texts = []
        self.answer_starts = []
        self.answer_ends = []
        
        for answer_dict in answers:
            answer_text = answer_dict['text'][0]
            answer_start = answer_dict['answer_start'][0]
            answer_end = answer_start + len(answer_text)
            
            self.answer_texts.append(answer_text)
            self.answer_starts.append(answer_start)
            self.answer_ends.append(answer_end)
    
    def __len__(self):
        return len(self.contexts)
    
    def __getitem__(self, idx):
        question = self.questions[idx]
        context = self.contexts[idx]
        answer_start_char = self.answer_starts[idx]
        answer_end_char = self.answer_ends[idx]
        
        encoded = self.tokenizer(
            question,
            context,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=True,
            return_tensors='pt'
        )
        
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        offset_mapping = encoded['offset_mapping'].squeeze(0)
        
        # For a (question, context) pair the offsets of each sequence are
        # relative to that sequence, so context offsets can be compared
        # directly with the answer's character positions. sequence_ids()
        # is 1 for context tokens and None for special and padding tokens.
        sequence_ids = encoded.sequence_ids(0)
        context_tokens = [i for i, s in enumerate(sequence_ids) if s == 1]
        context_start_token = context_tokens[0] if context_tokens else 1

        start_pos = None
        end_pos = None
        offsets = offset_mapping.tolist()

        for i in context_tokens:
            start_char, end_char = offsets[i]
            # First token that contains the first character of the answer.
            if start_pos is None and start_char <= answer_start_char < end_char:
                start_pos = i
            # Token that contains the last character of the answer.
            if start_char < answer_end_char <= end_char:
                end_pos = i

        # Fallback when the answer was truncated away or could not be aligned:
        # point at the first tokens of the context.
        if start_pos is None or end_pos is None or start_pos > end_pos:
            start_pos = context_start_token
            end_pos = min(context_start_token + 1, len(input_ids) - 1)

        start_pos = min(start_pos, len(input_ids) - 1)
        end_pos = min(end_pos, len(input_ids) - 1)
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'start_positions': torch.tensor(start_pos, dtype=torch.long),
            'end_positions': torch.tensor(end_pos, dtype=torch.long),
            'answer_text': self.answer_texts[idx],
            'context': context,
            'question': question
        }


## Step 9: Creating the tokenized datasets

We create new datasets with tokenization. To speed things up, a subset of the data can be used.
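
Taking the first N examples is the simplest way to subsample, but neighbouring SQuAD examples tend to come from the same article, so a leading slice covers few topics. As an alternative, here is a small sketch of a seeded random subset; `sample_indices` is a hypothetical helper, and the totals are the split sizes printed earlier.

```python
import random


def sample_indices(total, k, seed=42):
    """Pick k distinct indices out of range(total), reproducibly for a given seed."""
    rng = random.Random(seed)
    return sorted(rng.sample(range(total), k))


indices = sample_indices(total=87599, k=10000)
print(len(indices), indices[:5])
```

The selected examples could then be gathered with list comprehensions such as `[train_contexts[i] for i in indices]`, applied identically to the questions and answers so the three lists stay aligned.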


In [9]:
train_size = 10000
val_size = 1000

train_tokenized_dataset = QATokenizedDataset(
    train_contexts[:train_size],
    train_questions[:train_size],
    train_answers[:train_size],
    tokenizer,
    max_length=384
)

val_tokenized_dataset = QATokenizedDataset(
    val_contexts[:val_size],
    val_questions[:val_size],
    val_answers[:val_size],
    tokenizer,
    max_length=384
)

print(f"Train tokenized dataset size: {len(train_tokenized_dataset)}")
print(f"Val tokenized dataset size: {len(val_tokenized_dataset)}")

sample = train_tokenized_dataset[0]
print(f"\nSample from tokenized dataset:")
print(f"Input IDs shape: {sample['input_ids'].shape}")
print(f"Start position: {sample['start_positions']}")
print(f"End position: {sample['end_positions']}")
print(f"Answer text: {sample['answer_text']}")


Train tokenized dataset size: 10000
Val tokenized dataset size: 1000

Sample from tokenized dataset:
Input IDs shape: torch.Size([384])
Start position: 130
End position: 137
Answer text: Saint Bernadette Soubirous


## Step 10: Building the BERT-based QA model

We build a base model that uses BERT to produce token embeddings and a single linear layer with two outputs per token: one logit for the answer start position and one for the answer end position.
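
What the two-output head computes can be shown in plain Python, without BERT. In this sketch `qa_head` is a hypothetical function, and the hidden vectors, weights and bias are made-up numbers with hidden size 2 instead of BERT's real hidden size.

```python
def qa_head(hidden_states, weight, bias):
    """Apply a 2-output linear layer to every token vector.

    hidden_states: n vectors of size d; weight: 2 rows of size d; bias: 2 values.
    Returns (start_logits, end_logits), each a list of length n.
    """
    start_logits, end_logits = [], []
    for h in hidden_states:
        start_logits.append(sum(w * x for w, x in zip(weight[0], h)) + bias[0])
        end_logits.append(sum(w * x for w, x in zip(weight[1], h)) + bias[1])
    return start_logits, end_logits


# Made-up values: three tokens, hidden size 2.
hidden_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weight = [[1.0, 2.0], [-1.0, 0.5]]
bias = [0.0, 1.0]

start_logits, end_logits = qa_head(hidden_states, weight, bias)
print(start_logits, end_logits)
```

Each token therefore gets one start score and one end score, and the sequence of start scores (or end scores) is treated as a classification over token positions.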


In [10]:
from transformers import BertModel
import torch.nn as nn
import pytorch_lightning as pl

class QABertModel(nn.Module):
    
    def __init__(self, model_name='bert-base-uncased'):
        super(QABertModel, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.hidden_size = self.bert.config.hidden_size
        self.qa_outputs = nn.Linear(self.hidden_size, 2)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        
        return start_logits, end_logits

print("Base QA model class created")


Base QA model class created


## Step 11: Lightning module for QA

We create a PyTorch Lightning module with training and validation steps and the F1 and Exact Match metrics.
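
As a worked example of the token-level F1 used for QA, the sketch below computes precision, recall and F1 for two already normalized token lists. `overlap_scores` is a hypothetical helper and the answers are made up; the overlap is counted as a multiset intersection so that repeated tokens are handled correctly.

```python
from collections import Counter


def overlap_scores(pred_tokens, true_tokens):
    """Return (precision, recall, f1) based on the multiset token overlap."""
    overlap = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if overlap == 0:
        return 0.0, 0.0, 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(true_tokens)
    return precision, recall, 2 * precision * recall / (precision + recall)


# Made-up answers, already lowercased and stripped of punctuation and articles.
pred = "royal mews in london".split()
true = "royal mews".split()

precision, recall, f1 = overlap_scores(pred, true)
print(precision, recall, round(f1, 4))
```

Here two of the four predicted tokens are correct (precision 0.5) and both reference tokens are found (recall 1.0), so F1 is 2/3, while Exact Match would be 0 because the normalized strings differ.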


In [11]:
import re
import string

def normalize_answer(s):
    """Lowercase, strip punctuation and English articles, collapse whitespace."""
    def remove_articles(text):
        # A word-boundary regex also removes articles at the start or end of
        # the string, which a plain " the " replacement would miss.
        return re.sub(r"\b(a|an|the)\b", " ", text)
    
    def white_space_fix(text):
        return " ".join(text.split())
    
    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)
    
    def lower(text):
        return text.lower()
    
    return white_space_fix(remove_articles(remove_punc(lower(s))))

def f1_score(prediction, ground_truth):
    """Token-level F1 between a predicted and a reference answer."""
    from collections import Counter
    
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    
    # A multiset intersection counts repeated tokens correctly,
    # which a plain set intersection does not.
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    
    if num_same == 0:
        return 0.0
    
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    
    return 2 * precision * recall / (precision + recall)

def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)

class QALightningModule(pl.LightningModule):
    
    def __init__(self, model_name='bert-base-uncased', lr=2e-5):
        super().__init__()
        self.save_hyperparameters()
        self.model = QABertModel(model_name)
        self.lr = lr
        self.tokenizer = tokenizer
        # Labels are clamped into the valid range in each step,
        # so no ignore_index is needed here.
        self.loss = nn.CrossEntropyLoss()
        
    def forward(self, input_ids, attention_mask):
        return self.model(input_ids, attention_mask)
    
    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        start_positions = batch['start_positions']
        end_positions = batch['end_positions']
        
        start_logits, end_logits = self(input_ids, attention_mask)
        
        # Guard against label indices that fall outside the sequence.
        seq_length = start_logits.size(1)
        start_positions = start_positions.clamp(0, seq_length - 1)
        end_positions = end_positions.clamp(0, seq_length - 1)
        
        start_loss = self.loss(start_logits, start_positions)
        end_loss = self.loss(end_logits, end_positions)
        loss = (start_loss + end_loss) / 2
        
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        start_positions = batch['start_positions']
        end_positions = batch['end_positions']
        answer_texts = batch['answer_text']
        
        start_logits, end_logits = self(input_ids, attention_mask)
        
        seq_length = start_logits.size(1)
        ignored_index = seq_length
        start_positions = start_positions.clamp(0, seq_length - 1)
        end_positions = end_positions.clamp(0, seq_length - 1)
        
        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        loss = (start_loss + end_loss) / 2
        
        start_preds = start_logits.argmax(dim=1)
        end_preds = end_logits.argmax(dim=1)
        
        batch_f1_scores = []
        batch_em_scores = []
        
        for i in range(len(answer_texts)):
            start_idx = start_preds[i].item()
            end_idx = end_preds[i].item()
            
            if start_idx > end_idx:
                predicted_text = ""
            else:
                token_ids = input_ids[i][start_idx:end_idx+1]
                predicted_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
            
            ground_truth = answer_texts[i]
            
            f1 = f1_score(predicted_text, ground_truth)
            em = exact_match_score(predicted_text, ground_truth)
            
            batch_f1_scores.append(f1)
            batch_em_scores.append(em)
        
        avg_f1 = np.mean(batch_f1_scores)
        avg_em = np.mean(batch_em_scores)
        
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_f1', avg_f1, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_em', avg_em, on_step=False, on_epoch=True, prog_bar=True)
        
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer

print("Lightning module created")


Lightning module created


## Step 12: Training the model with Lightning

We create the DataLoaders and start training with the PyTorch Lightning Trainer.
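
A quick way to sanity-check the loader sizes is to compute the expected number of batches by hand. The sketch below uses a hypothetical helper, `batches_per_epoch`, whose `drop_last` flag mirrors the `DataLoader` option of the same name. The counts logged for this step, 1250 training and 125 validation batches, correspond to 8 examples per batch for 10,000 and 1,000 examples.

```python
import math


def batches_per_epoch(num_examples, batch_size, drop_last=False):
    """Number of batches produced by one pass over the data."""
    if drop_last:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)


train_batches = batches_per_epoch(10_000, 8)
val_batches = batches_per_epoch(1_000, 8)
print(train_batches, val_batches)
```

With a batch size of 128 the same helper gives 79 training batches, so the batch count in the log is a useful hint about which batch size a run actually used.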


In [12]:
train_loader = DataLoader(
    train_tokenized_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

val_loader = DataLoader(
    val_tokenized_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

lightning_model = QALightningModule(model_name=model_name, lr=2e-5)

trainer = pl.Trainer(
    max_epochs=2,
    accelerator='auto',
    devices=1,
    enable_progress_bar=True,
    log_every_n_steps=50,
    gradient_clip_val=1.0
)

print("\nStarting training...")
trainer.fit(lightning_model, train_loader, val_loader)
print("\nTraining completed!")


Train batches: 1250
Val batches: 125


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA L40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]



Starting training...



  | Name  | Type        | Params | Mode 
----------------------------------------------
0 | model | QABertModel | 109 M  | train
----------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
437.935   Total estimated model params size (MB)
2         Modules in train mode
228       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/tam2511/venvs/train_py10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=77` in the `DataLoader` to improve performance.
/home/tam2511/venvs/train_py10/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 8. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/tam2511/venvs/train_py10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=77` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.



Training completed!


In [13]:
lightning_model.eval()
sample_batch = next(iter(val_loader))

with torch.no_grad():
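    # Move the inputs to whichever device the model currently lives on.
    input_ids = sample_batch['input_ids'][:3].to(lightning_model.device)
    attention_mask = sample_batch['attention_mask'][:3].to(lightning_model.device)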
    
    start_logits, end_logits = lightning_model(input_ids, attention_mask)
    start_preds = start_logits.argmax(dim=1)
    end_preds = end_logits.argmax(dim=1)
    
    for i in range(3):
        start_idx = start_preds[i].item()
        end_idx = end_preds[i].item()
        
        if start_idx > end_idx:
            predicted_text = ""
        else:
            token_ids = input_ids[i][start_idx:end_idx+1]
            predicted_text = tokenizer.decode(token_ids, skip_special_tokens=True)
        
        print(f"\nQuestion: {sample_batch['question'][i]}")
        print(f"Context: {sample_batch['context'][i][:150]}...")
        print(f"Predicted: {predicted_text}")
        print(f"Ground truth: {sample_batch['answer_text'][i]}")
        print(f"F1: {f1_score(predicted_text, sample_batch['answer_text'][i]):.4f}")
        print(f"EM: {exact_match_score(predicted_text, sample_batch['answer_text'][i])}")



Question: Which NFL team represented the AFC at Super Bowl 50?
Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football...
Predicted: american football conference ( afc ) champion denver broncos
Ground truth: Denver Broncos
F1: 0.4444
EM: False

Question: Which NFL team represented the NFC at Super Bowl 50?
Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football...
Predicted: 
Ground truth: Carolina Panthers
F1: 0.0000
EM: False

Question: Where did Super Bowl 50 take place?
Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football...
Predicted: levi ' s stadium in the san francisco bay area at santa clara, california
Ground truth: Santa Clara, California
F1: 0.4000
EM: False
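
The empty prediction above is what the decoding code returns when the independently chosen start index lies after the end index. One common remedy is to search for the best valid pair jointly. The following is a sketch under stated assumptions, not part of the notebook's pipeline: `best_span` and its `max_answer_tokens` limit are hypothetical, the logits are made up, and the search is not restricted to context tokens.

```python
def best_span(start_logits, end_logits, max_answer_tokens=30):
    """Return the (start, end) pair with the highest summed logit and start <= end."""
    best, best_score = (0, 0), float("-inf")
    for s, s_logit in enumerate(start_logits):
        # Only consider ends at or after the start, within the length limit.
        last = min(len(end_logits), s + max_answer_tokens)
        for e in range(s, last):
            score = s_logit + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best


# Made-up logits where the independent argmax gives start=3 and end=1.
start_logits = [0.1, 0.2, 0.3, 5.0]
end_logits = [0.0, 4.0, 1.0, 0.5]

naive = (start_logits.index(max(start_logits)), end_logits.index(max(end_logits)))
joint = best_span(start_logits, end_logits)
print(naive, joint)
```

The joint search always returns a non-empty span, at the cost of a loop over candidate pairs; the length limit also discourages overly long predictions such as the ones in the first and third examples.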
