In [8]:
import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

In [9]:
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
model.to(device)
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

In [10]:
squad_dataset = load_dataset('squad_v2')

In [11]:
squad_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 130319
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 11873
    })
})

In [12]:
def encode(example):
    encoded = tokenizer.encode_plus(example["context"], example["question"], truncation=True, padding="max_length", max_length=512)
    return encoded

encoded_squad = squad_dataset.map(encode)

In [17]:
res = []
batch_size = 16  # Adjust based on your system's memory capacity

for i in tqdm(range(0, len(encoded_squad['validation']['input_ids'])//100, batch_size)):
    inputs = encoded_squad['validation']['input_ids'][i:i+batch_size]
    sentence_embedding = encoded_squad['validation']['token_type_ids'][i:i+batch_size]
    tokens = [tokenizer.convert_ids_to_tokens(input) for input in inputs]

    # Convert lists to tensors
    inputs = torch.tensor(inputs)
    sentence_embedding = torch.tensor(sentence_embedding)

    # Ensure tensors are on the same device as the model
    inputs = inputs.to(device)
    sentence_embedding = sentence_embedding.to(device)

    with torch.no_grad():  # Deactivate gradients for the following code block
        outputs = model(input_ids=inputs, token_type_ids=sentence_embedding)
        start_scores, end_scores = outputs['start_logits'], outputs['end_logits']

    for i, (input, start_score, end_score) in enumerate(zip(inputs, start_scores, end_scores)):
        start_index = torch.argmax(start_score)
        end_index = torch.argmax(end_score)
        answer = ' '.join(tokens[i][start_index:end_index+1])
        corrected_answer = ''
        for word in answer.split():
            if word[0:2] == '##':
                corrected_answer += word[2:]
            else:
                corrected_answer += ' ' + word
        res.append(corrected_answer)

100%|██████████| 8/8 [00:39<00:00,  4.88s/it]


In [18]:
res

[' [SEP]',
 ' [SEP]',
 ' [SEP]',
 ' [SEP]',
 ' [SEP]',
 ' [SEP]',
 ' normandy , a region in france . they were descended from norse ( " norman " comes from " norseman " ) raiders and pirates from denmark , iceland and norway who , under their leader rollo , agreed to swear fealty to king charles iii of west francia . through generations of assimilation and mixing with the native frankish and roman - gaulish populations , their descendants would gradually merge with the carolingian - based cultures of west francia . the distinct cultural and ethnic identity of the normans emerged initially in the first half of the 10th century , and it continued to evolve over the succeeding centuries . [SEP]',
 ' [SEP]',
 ' [SEP]',
 '',
 ' [SEP]',
 ' [PAD]',
 '',
 ' [SEP]',
 '',
 ' [SEP]',
 ' [PAD]',
 ' [CLS]',
 ' [CLS]',
 ' [CLS]',
 ' [CLS]',
 ' [SEP]',
 ' [PAD]',
 ' [SEP]',
 ' [SEP]',
 ' treaty of saint - clair - sur - epte between king charles iii of west francia and the famed viking ruler rollo , a

In [77]:
res = []
for i in tqdm(range(len(encoded_squad['validation']['input_ids']))):
    inputs = encoded_squad['validation']['input_ids'][i]
    sentence_embedding = encoded_squad['validation']['token_type_ids']
    tokens = tokenizer.convert_ids_to_tokens(inputs)
    res = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))
    start_scores, end_scores = res['start_logits'], res['end_logits']
    start_index = torch.argmax(start_scores)
    end_index = torch.argmax(end_scores)
    answer = ' '.join(tokens[start_index:end_index+1])
    corrected_answer = ''
    for word in answer.split():
        if word[0:2] == '##':
            corrected_answer += word[2:]
        else:
            corrected_answer += ' ' + word
    res.append(corrected_answer)

  0%|          | 0/11873 [00:00<?, ?it/s]

In [51]:
encoded_squad['validation']['token_type_ids']

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

In [9]:
encoded_squad.set_format('torch')

In [35]:
from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
    inputs = [item['input_ids'] for item in batch]
    sentence_embedding = [item['token_type_ids'] for item in batch]
    inputs = pad_sequence(inputs, batch_first=True)
    sentence_embedding = pad_sequence(sentence_embedding, batch_first=True)
    return {'input_ids': inputs, 'token_type_ids': sentence_embedding}

val_data_loader = DataLoader(encoded_squad["validation"], batch_size=16, collate_fn=collate_fn)

In [33]:
results = []

for elem in enumerate(encoded_squad['validation']):
    inputs = elem[1]['input_ids'].to(device)
    inputs = inputs.unsqueeze(0)
    sentence_embedding = elem[1]['token_type_ids'].to(device)
    with torch.no_grad():
        outputs = model(input_ids=inputs, token_type_ids=sentence_embedding)
    results.append(outputs)

KeyboardInterrupt: 

In [36]:
results = []

for batch in tqdm(val_data_loader):
    inputs = batch['input_ids'].to(device)
    sentence_embedding = batch['token_type_ids'].to(device)
    with torch.no_grad():
        outputs = model(input_ids=inputs, token_type_ids=sentence_embedding)
    results.append(outputs)

100%|██████████| 743/743 [24:57<00:00,  2.02s/it]


In [39]:
answers = []

for i, e in tqdm(enumerate(results)):
    start_scores, end_scores = e['start_logits'], e['end_logits']
    start_index = torch.argmax(start_scores)
    end_index = torch.argmax(end_scores)
    print(encoded_squad['validation'][i])
    answer = ' '.join(encoded_squad['validation'][i]['input_ids'][start_index:end_index+1])
    corrected_answer = ''
    for word in answer.split():
        if word[0:2] == '##':
            corrected_answer += word[2:]
        else:
            corrected_answer += ' ' + word
    answers.append(corrected_answer)

2it [00:00,  5.97it/s]

{'id': '56ddde6b9a695914005b9628', 'title': 'Normans', 'context': 'The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.', 'question': 'In what country is Normandy located?', 'answers': {'text': ['France', 'France', 'France', 'France'], 'answer_start': tensor([159, 159, 159, 159])}, 'input_ids': tensor([  101,




TypeError: sequence item 0: expected str instance, Tensor found

In [None]:
res = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

In [15]:
res = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

In [16]:
start_scores, end_scores = res['start_logits'], res['end_logits']

In [18]:
start_index = torch.argmax(start_scores)

end_index = torch.argmax(end_scores)

answer = ' '.join(tokens[start_index:end_index+1])

In [19]:
corrected_answer = ''

for word in answer.split():
    
    #If it's a subword token
    if word[0:2] == '##':
        corrected_answer += word[2:]
    else:
        corrected_answer += ' ' + word

print(corrected_answer)

 the scientific study of algorithms and statistical models
