In [1]:
from datasets import load_dataset, load_metric
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer
import re
import torch

In [2]:
# Load SQuAD and expose its validation split under the name 'test'.
# SQuAD is public, so no Hugging Face access token is requested: `token=True`
# raises when no cached login exists, which breaks Restart & Run All on a
# fresh machine. (Pass `token=...` explicitly only for gated/private datasets.)
dataset = load_dataset('squad', trust_remote_code=True)
dataset['test'] = dataset.pop('validation')
# Keep only questions with at least one reference answer. This is presumably a
# no-op for SQuAD v1.1, but guards against unanswerable rows if the dataset is
# swapped (e.g. for squad_v2).
dataset['test'] = dataset['test'].filter(lambda example: len(example['answers']['text']) > 0)

In [4]:
# Tokenizer for the Falcon-7B-Instruct checkpoint; the chunked dataset below uses
# it to split long contexts into token windows.
TOKENIZER_CHECKPOINT = 'tiiuae/falcon-7b-instruct'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

In [11]:
class NoChunkDataset(Dataset):
    """QA dataset that formats each row's full (un-chunked) context into a prompt.

    Args:
        ds: iterable of rows with 'context', 'question' and ['answers']['text'] fields.
        prompt: format string with ``{context}`` and ``{question}`` placeholders.

    Each item is a ``(formatted_prompt, answer_texts)`` tuple.
    """

    def __init__(self, ds, prompt):
        def _to_sample(row):
            filled_prompt = prompt.format(context=row['context'], question=row['question'])
            return filled_prompt, row['answers']['text']

        self.samples = [_to_sample(row) for row in tqdm(ds)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # `idx` may be an int or a slice; both are delegated to the underlying list.
        return self.samples[idx]

In [12]:
# Baseline dataset: every 'test' row, prompted with its full (un-chunked) context.
QA_PROMPT_TEMPLATE = 'Context: {context}\n\nQuestion: {question}\n\nAnswer: '
formatted_dataset = NoChunkDataset(ds=dataset['test'], prompt=QA_PROMPT_TEMPLATE)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10570/10570 [00:01<00:00, 7022.17it/s]


In [18]:
# Preview the first two (prompt, answers) samples; slicing works because
# __getitem__ indexes a plain list.
formatted_dataset[:2]

[('Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.\n\nQuestion: Which NFL team represented the AFC at Super Bowl 50?\n\nAnswer: ',
  ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']),
 ('Context: Super Bowl 50 was an American football game to determine the

In [19]:
def collate_qa_batch(samples):
    """Group a list of ``(prompt, answers)`` samples into ``(prompts, answer_lists)``.

    PyTorch's default collate function recurses into each sample's answer list.
    It transposes the answers by position across the batch, so answers from
    different questions get mixed together. It also raises a RuntimeError when
    two samples have a different number of reference answers. Zipping the
    samples keeps each question's answer list intact.

    Returns:
        A tuple ``(prompts, answer_lists)``, each a tuple with one entry per sample.
    """
    prompts, answer_lists = zip(*samples)
    return prompts, answer_lists


dataloader = DataLoader(formatted_dataset, batch_size=2, shuffle=False, collate_fn=collate_qa_batch)

In [24]:
# Inspect the raw fields of the first row of the 'test' split (SQuAD validation, renamed above).
dataset['test'][0]

{'id': '56be4db0acb8001400a502ed',
 'title': 'Super_Bowl_50',
 'context': 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.',
 'question': 'Which NFL team represented the NFC at Super Bowl 50?',
 'answers': {'text': ['Carolina Panthers',
   'Carolina Panthers',
   'Caroli

In [23]:
# Peek at the prompts of the first batch only.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    print(first_batch[0])

('Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.\n\nQuestion: Which NFL team represented the AFC at Super Bowl 50?\n\nAnswer: ', 'Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 s

In [None]:
class TechQA(Dataset):
    """QA dataset whose prompts use the first context window containing every answer.

    Each row's context is tokenized into windows of at most ``max_length``
    tokens, with consecutive windows overlapping by ``stride`` tokens. The first
    window whose decoded text contains every reference answer (exact,
    case-sensitive substring match) is formatted into ``prompt``. Rows with no
    such window, or with no reference answers at all, are skipped and counted in
    ``self.num_dropped``.

    Args:
        ds: iterable of rows with 'context', 'question' and ['answers']['text'] fields.
        prompt: format string with ``{context}`` and ``{question}`` placeholders.
        chunk_tokenizer: tokenizer used for windowing. Defaults to the notebook's
            global ``tokenizer``.
        max_length: maximum number of tokens per context window.
        stride: number of tokens shared by consecutive windows.

    Each item is a ``(formatted_prompt, answer_texts)`` tuple.
    """

    def __init__(self, ds, prompt, chunk_tokenizer=None, max_length=1024, stride=50):
        tok = chunk_tokenizer if chunk_tokenizer is not None else tokenizer
        self.samples = []
        self.num_dropped = 0  # rows skipped because no window held every answer
        for row in tqdm(ds):
            true_spans = row['answers']['text']
            encoded = tok(row['context'], add_special_tokens=False, truncation=True,
                          max_length=max_length, stride=stride, return_overflowing_tokens=True)
            # Decode lazily: windows after the first match are never decoded.
            decoded_chunks = (tok.decode(chunk, clean_up_tokenization_spaces=False)
                              for chunk in encoded['input_ids'])
            chunk_text = self._first_chunk_with_all_answers(decoded_chunks, true_spans)
            if chunk_text is None:
                self.num_dropped += 1
                continue
            self.samples.append((prompt.format(context=chunk_text, question=row['question']), true_spans))

    @staticmethod
    def _first_chunk_with_all_answers(decoded_chunks, answers):
        """Return the first chunk containing every string in `answers`, else None.

        An empty `answers` list never matches, so unanswerable rows are skipped.
        """
        if not answers:
            return None
        return next((chunk for chunk in decoded_chunks if all(ans in chunk for ans in answers)), None)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
    
formatted_dataset = TechQA(dataset['test'], 'Context: {context}\n\nQuestion: {question}\n\nAnswer: ')

# Sanity check: count chunked prompts in which every reference answer occurs
# (literal, case-insensitive match anywhere in the prompt).
c = sum(
    all(re.search(re.escape(ans), expanded_prompt, re.IGNORECASE) for ans in true_answers)
    for expanded_prompt, true_answers in formatted_dataset
)
c

In [None]:
# Display the loaded dataset object (its splits after the 'validation' -> 'test' rename).
dataset

In [None]:
# Walk the chunked samples in order and count those whose prompt contains every
# reference answer (literal, case-insensitive). Stop at the first sample that
# fails and print only that one. Passing samples are not printed, because dumping
# thousands of full-length prompts buries the failure being looked for.
c = 0
for expanded_prompt, true_answers in formatted_dataset:
    if all(re.search(re.escape(ans), expanded_prompt, re.IGNORECASE) for ans in true_answers):
        c += 1
    else:
        print(expanded_prompt, true_answers)
        break
c

In [None]:
# Baseline on the raw (un-chunked) contexts: count rows whose context contains
# every reference answer (literal, case-insensitive match).
c = sum(
    all(re.search(re.escape(answer), row['context'], re.IGNORECASE)
        for answer in row['answers']['text'])
    for row in dataset['test']
)
c