In [1]:
import torch
import os
from transformers import BigBirdForMultipleChoice, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import random
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#initialize model
model_name = 'google/bigbird-roberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Initialize the model
model = BigBirdForMultipleChoice.from_pretrained(model_name).to(device)

Some weights of BigBirdForMultipleChoice were not initialized from the model checkpoint at google/bigbird-roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
#Load data
import json
with open('data/wikihop/dev.json', 'r') as file:
    wikihop_dev = json.load(file)
with open('data/wikihop/train.json', 'r') as file:
    wikihop_train = json.load(file)

In [4]:
#Data Processing
#this method is to truncate the context size by relevance. 
import re

def extract_context_windows(text, candidates, window_size=45):
    # Pattern to match a word
    word_pattern = r'\b\w+\b'
    # Combine all candidates into a single regex pattern
    candidates_pattern = '|'.join(re.escape(candidate) for candidate in candidates)
    # Compile a case-insensitive regex pattern
    pattern = re.compile(candidates_pattern, re.IGNORECASE)
    
    # Initialize an empty list to hold all the windows
    windows = []

    # Find all matches of the pattern
    for match in pattern.finditer(text):
        start_pos = match.start()
        end_pos = match.end()

        # Find words around the candidate match
        words_before = re.findall(word_pattern, text[:start_pos])[-window_size:]
        words_after = re.findall(word_pattern, text[end_pos:])[:window_size]
        
        # Combine words before, the candidate, and words after into a window
        window = ' '.join(words_before + [match.group()] + words_after)
        windows.append(window)

    # Combine all windows into a new context
    new_context = ' '.join(windows)
    return new_context


class WikiHopDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Extract the question, candidates, and supports for the current index
        question = item['query']
        supports = ' '.join(item['supports'])
        candidates = item['candidates']
        correct_answer = item['answer']

        # Limit candidates to a max of 10
        if len(candidates) > 7:
            candidates = candidates[:7]

        # Ensure the correct answer is always included
        if correct_answer not in candidates:
            candidates.insert(0, correct_answer)  

        # Shuffle candidates
        random.shuffle(candidates)  

        # Extract only relevant supports
        full_context = extract_context_windows(supports, candidates)

        # Combine the question with the full context
        combined_context = question + " " + full_context

        # Tokenize the combined context (query + supports)
        context_max_len = int(self.max_length * 0.98)
        context_encoding = self.tokenizer.encode_plus(combined_context, 
                                                    add_special_tokens=True, 
                                                    max_length=context_max_len, 
                                                    padding='max_length',
                                                    truncation=True, return_tensors="pt")

        # Tokenize each candidate
        candidate_max_len = int(self.max_length * 0.02)
        candidates_encoding = [self.tokenizer.encode_plus(candidate, 
                                                        add_special_tokens=False, 
                                                        max_length=candidate_max_len, 
                                                        padding='max_length',
                                                        truncation=True, return_tensors="pt") 
                            for candidate in candidates]

        # Combine context with each candidate
        input_ids = torch.cat([context_encoding['input_ids'].repeat(len(candidates), 1), 
                            torch.stack([c['input_ids'].squeeze(0) for c in candidates_encoding])], dim=1)
        attention_mask = torch.cat([context_encoding['attention_mask'].repeat(len(candidates), 1), 
                                    torch.stack([c['attention_mask'].squeeze(0) for c in candidates_encoding])], dim=1)

        # Get the label (index of the correct answer)
        label = torch.tensor(candidates.index(correct_answer))

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': label
        }

from torch.nn.utils.rnn import pad_sequence

import torch
from torch.nn.utils.rnn import pad_sequence

def custom_collate_fn(batch):
    # Separate input_ids, attention_masks, token_type_ids, and labels
    all_input_ids = [item['input_ids'] for item in batch]
    all_attention_masks = [item['attention_mask'] for item in batch]
    all_labels = [item['labels'] for item in batch]

    # Find the maximum number of choices and maximum sequence length
    max_num_choices = max(input_ids.shape[0] for input_ids in all_input_ids)
    max_seq_len = max(input_ids.shape[1] for input_ids in all_input_ids)

    # Pad each choice in each batch item to the maximum sequence length for input_ids
    padded_input_ids = [pad_sequence(item, batch_first=True, padding_value=tokenizer.pad_token_id).view(-1, max_seq_len)[:max_num_choices] 
                        for item in all_input_ids]

    # Pad each batch item to have the same number of choices for input_ids
    padded_input_ids = pad_sequence(padded_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)

    # Repeat the same padding process for attention_mask
    padded_attention_masks = [pad_sequence(item, batch_first=True, padding_value=0).view(-1, max_seq_len)[:max_num_choices] 
                              for item in all_attention_masks]
    padded_attention_masks = pad_sequence(padded_attention_masks, batch_first=True, padding_value=0)


    # Pad labels to the batch size
    labels = torch.tensor(all_labels)

    return {
        'input_ids': padded_input_ids,
        'attention_mask': padded_attention_masks,
        'labels': labels
    }





In [5]:
wikiHop_test_dataset = WikiHopDataset(wikihop_dev[:100], tokenizer, max_length=2750)

test_loader = DataLoader(wikiHop_test_dataset, batch_size=4, collate_fn=custom_collate_fn, pin_memory=True)


In [6]:
def test(model, test_loader, device):
    model.to(device)
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            if loss.dim() > 0:
                loss = loss.mean()
            test_loss += loss.item()

            # Calculate accuracy
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=1)
            # print(f"predictions: {predictions}")
            # print(f"labels: {labels}")
            correct_predictions += (predictions == labels).sum().item()
            total_predictions += labels.size(0)

    avg_test_loss = test_loss / len(test_loader)
    accuracy = correct_predictions / total_predictions

    return avg_test_loss, accuracy




In [7]:
import torch
import os
from transformers import LongformerForMultipleChoice

#load model
checkpoint = torch.load('models/BigBird/checkpoint_epoch_3_step_6300.pt', map_location='cpu')

model = BigBirdForMultipleChoice.from_pretrained('google/bigbird-roberta-base', state_dict=checkpoint)

Some weights of BigBirdForMultipleChoice were not initialized from the model checkpoint at google/bigbird-roberta-base and are newly initialized: ['encoder.layer.2.attention.self.query.weight', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.3.output.dense.bias', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.0.attention.output.LayerNorm.bi

In [8]:


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Call the test method
avg_test_loss, test_accuracy = test(model, test_loader, device)
print(f'Average test loss: {avg_test_loss}')
print(f'Test accuracy: {test_accuracy}')


100%|██████████| 25/25 [01:06<00:00,  2.65s/it]


Average test loss: 2.120200538635254
Test accuracy: 0.08


Test 1, :500 rows, checkpoint_epoch_1_step_400.pt, training: 1000 rows, candidates: 7
Average test loss: 2.129120014190674
Test accuracy: 0.11

Test 2, :500 rows, checkpoint_epoch_1_step_400.pt, training: 1000 rows, candidates: 4
Average test loss: 
Test accuracy: 0.18

Test 3, :1000 rows, checkpoint_epoch_1_step_5400.pt, training: all rows, candidates: 4
Average test loss: 1.612755440711975
Test accuracy: 0.132

Test 4, :1000 rows, checkpoint_epoch_1_step_6000.pt, training: all rows, candidates: 7
Average test loss: 2.0350936398506163
Test accuracy: 0.262
Test 4, :1000 rows, checkpoint_epoch_1_step_6000.pt, training: all rows, candidates: 4
Average test loss: 1.616220552444458
Test accuracy: 0.174

Test 5, :1000 rows, checkpoint_epoch_1_step_7200.pt, training: all rows, candidates: 7
Average test loss: 2.0654408631324768
Test accuracy: 0.291

Test 5, :1000 rows, checkpoint_epoch_1_step_7200.pt, training: all rows, candidates: 4
Average test loss: 1.609327109336853
Test accuracy: 0.174

Test 6, :1000 rows, checkpoint_epoch_1_step_8300.pt, training: all rows, candidates: 7
Average test loss: 2.0652692041397094
Test accuracy: 0.21

In [9]:
# text = "combined_context: country sms braunschweig The North German Confederation was a confederation of 22 previously independent states of northern Germany with nearly 30 million inhabitants It was the first modern German nation state and previously independent states of northern Germany with nearly 30 million inhabitants It was the first modern German nation state and the basis for the later German Empire 18711918 when several south German states such as Bavaria joined Weimar Republic is an unofficial historical designation for the German state between 1919 and 1933 She was laid down in 1901 and commissioned in October 1904 at a cost of 23 983 000 marks She was named after the then Duchy of Brunswick German Braunschweig Her sister ships were Elsass Hessen Preussen and Lothringen The ship served in the II Squadron of the German fleet after commissioning though State of Brunswick was a state of the German Reich in the time of the Weimar Republic It was formed after the abolition of the Duchy of Brunswick in the course of the German Revolution of 191819 Its capital was Braunschweig Brunswick The Congress of Vienna German Wiener Kongress was a conference of Revolution of 191819 Its capital was Braunschweig Brunswick The Congress of Vienna German Wiener Kongress was a conference of ambassadors of European states chaired by Austria n statesman Klemens von Metternich and held in Vienna from November 1814 to June 1815 though the delegates had arrived and were already negotiating by remain at peace The leaders were conservatives with little use for republicanism or revolution both of which threatened to upset the status quo in Europe France lost all its recent conquests while Prussia Austria and Russia made major territorial gains Prussia added smaller German states in the west Swedish Pomerania and little use for republicanism or revolution both of which threatened to upset the status quo in Europe France lost all its recent conquests while Prussia Austria and Russia made major territorial gains Prussia added smaller German states in the west Swedish Pomerania and 60 of the Kingdom of Saxony Austria gained Prussia Austria and Russia made major territorial gains Prussia added smaller German states in the west Swedish Pomerania and 60 of the Kingdom of Saxony Austria gained Venice and much of northern Italy Russia gained parts of Poland The new Kingdom of the Netherlands had been created just months before and and much of northern Italy Russia gained parts of Poland The new Kingdom of the Netherlands had been created just months before and included formerly Austria n territory that in 1830 became Belgium The immediate background was Napoleonic France s defeat and surrender in May 1814 which brought an end to Kingdom of the Netherlands had been created just months before and included formerly Austrian territory that in 1830 became Belgium The immediate background was Napoleonic France s defeat and surrender in May 1814 which brought an end to twenty five years of nearly continuous war Negotiations continued despite the outbreak of years of nearly continuous war Negotiations continued despite the outbreak of fighting triggered by Napoleon s dramatic return from exile and resumption of power in France during the Hundred Days of MarchJuly 1815 The Congress s Final Act was signed nine days before his final defeat at Waterloo on 18 June Hundred Days of MarchJuly 1815 The Congress s Final Act was signed nine days before his final defeat at Waterloo on 18 June 1815 The German Confederation was an association of 39 German states in Central Europe created by the Congress of Vienna in 1815 to coordinate the economies of separate German weak and ineffective as well as an obstacle to the creation of a German nation state It collapsed due to the rivalry between Prussia and Austria warfare the 1848 revolution and the inability of the multiple members to compromise The Royal Navy RN is the United Kingdom s naval warfare force by the English kings from the early medieval period the first major maritime engagements were fought in the Hundred Years War against the kingdom of France The modern Royal Navy traces its origins to the early 16th century the oldest of the UK s armed services it is known as the Evangelical Lutheran Church in Brunswick It is also home to the Jägermeister distillery and houses a campus of the Ostfalia University of Applied Sciences The German Empire officially was the historical German nation state that existed from the unification of Germany in 1871 to the abdication of Kaiser Wilhelm II in November in Washington D C from November 1921 to February 1922 and it was signed by the governments of the United Kingdom the United States Japan France and Italy It limited the construction of battleships battlecruisers and aircraft carriers by the signatories The numbers of other categories of warships including cruisers destroyers categories of warships including cruisers destroyers and submarines were not limited by the treaty but those ships were limited to 10 000 tons displacement The Duchy of Brunswick was a historical German state Its capital was the city of Brunswick Braunschweig It was established as the successor state of the Principality of Brunswick Brunswick Wolfenbüttel by the Congress of Vienna in 1815 In the course of the 19th century history of Germany the duchy was part of the German Confederation the North German Confederation and from 1871 the German Empire It was disestablished after the end of World War I its territory incorporated into the Congress of Vienna in 1815 In the course of the 19th century history of Germany the duchy was part of the German Confederation the North German Confederation and from 1871 the German Empire It was disestablished after the end of World War I its territory incorporated into the Weimar Republic as the the course of the 19th century history of Germany the duchy was part of the German Confederation the North German Confederation and from 1871 the German Empire It was disestablished after the end of World War I its territory incorporated into the Weimar Republic as the Free State of Brunswick World War"

In [10]:
text = " ". join(wikihop_dev[1000]['candidates'][1])
# Tokenize the text
tokens = tokenizer.tokenize(text)

# Count the number of tokens
num_tokens = len(tokens)

print("Number of tokens:", num_tokens)

Number of tokens: 16
