In [1]:
import torch
from transformers import BigBirdForMultipleChoice, AdamW
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from torch.nn.utils.rnn import pad_sequence
import random
from torch.utils.data import random_split
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import random_split
import json
import os
import torch_optimizer as optim
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#initialize model
model_name = 'google/bigbird-roberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Initialize the model
model = BigBirdForMultipleChoice.from_pretrained(model_name).to(device)

Some weights of BigBirdForMultipleChoice were not initialized from the model checkpoint at google/bigbird-roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
#Load data
import json
with open('data/wikihop/train.json', 'r') as file:
    wikihop_data = json.load(file)

In [4]:
#this method is to truncate the context size by relevance. 
import re

def extract_context_windows(text, candidates, window_size=45):
    # Pattern to match a word
    word_pattern = r'\b\w+\b'
    # Combine all candidates into a single regex pattern
    candidates_pattern = '|'.join(re.escape(candidate) for candidate in candidates)
    # Compile a case-insensitive regex pattern
    pattern = re.compile(candidates_pattern, re.IGNORECASE)
    
    # Initialize an empty list to hold all the windows
    windows = []

    # Find all matches of the pattern
    for match in pattern.finditer(text):
        start_pos = match.start()
        end_pos = match.end()

        # Find words around the candidate match
        words_before = re.findall(word_pattern, text[:start_pos])[-window_size:]
        words_after = re.findall(word_pattern, text[end_pos:])[:window_size]
        
        # Combine words before, the candidate, and words after into a window
        window = ' '.join(words_before + [match.group()] + words_after)
        windows.append(window)

    # Combine all windows into a new context
    new_context = ' '.join(windows)
    return new_context

In [5]:
class WikiHopDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Extract the question, candidates, and supports for the current index
        question = item['query']
        supports = ' '.join(item['supports'])
        candidates = item['candidates']
        correct_answer = item['answer']

        # Limit candidates to a max of 10
        if len(candidates) > 7:
            candidates = candidates[:7]

        # Ensure the correct answer is always included
        if correct_answer not in candidates:
            candidates.insert(0, correct_answer)  

        # Shuffle candidates
        random.shuffle(candidates)  

        # Extract only relevant supports
        full_context = extract_context_windows(supports, candidates)

        # Combine the question with the full context
        combined_context = question + " " + full_context

        # Tokenize the combined context (query + supports)
        context_max_len = self.max_length - int(self.max_length * 0.1)
        context_encoding = self.tokenizer.encode_plus(combined_context, 
                                                    add_special_tokens=True, 
                                                    max_length=context_max_len, 
                                                    padding='max_length',
                                                    truncation=True, return_tensors="pt")

        # Tokenize each candidate
        candidate_max_len = int(self.max_length * 0.01)
        candidates_encoding = [self.tokenizer.encode_plus(candidate, 
                                                        add_special_tokens=False, 
                                                        max_length=candidate_max_len, 
                                                        padding='max_length',
                                                        truncation=True, return_tensors="pt") 
                            for candidate in candidates]

        # Combine context with each candidate
        input_ids = torch.cat([context_encoding['input_ids'].repeat(len(candidates), 1), 
                            torch.stack([c['input_ids'].squeeze(0) for c in candidates_encoding])], dim=1)
        attention_mask = torch.cat([context_encoding['attention_mask'].repeat(len(candidates), 1), 
                                    torch.stack([c['attention_mask'].squeeze(0) for c in candidates_encoding])], dim=1)

        # Get the label (index of the correct answer)
        label = torch.tensor(candidates.index(correct_answer))

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': label
        }


In [6]:
import torch
from torch.nn.utils.rnn import pad_sequence

def custom_collate_fn(batch):
    # Separate input_ids, attention_masks, token_type_ids, and labels
    all_input_ids = [item['input_ids'] for item in batch]
    all_attention_masks = [item['attention_mask'] for item in batch]
    all_labels = [item['labels'] for item in batch]

    # Find the maximum number of choices and maximum sequence length
    max_num_choices = max(input_ids.shape[0] for input_ids in all_input_ids)
    max_seq_len = max(input_ids.shape[1] for input_ids in all_input_ids)

    # Pad each choice in each batch item to the maximum sequence length for input_ids
    padded_input_ids = [pad_sequence(item, batch_first=True, padding_value=tokenizer.pad_token_id).view(-1, max_seq_len)[:max_num_choices] 
                        for item in all_input_ids]

    # Pad each batch item to have the same number of choices for input_ids
    padded_input_ids = pad_sequence(padded_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)

    # Repeat the same padding process for attention_mask
    padded_attention_masks = [pad_sequence(item, batch_first=True, padding_value=0).view(-1, max_seq_len)[:max_num_choices] 
                              for item in all_attention_masks]
    padded_attention_masks = pad_sequence(padded_attention_masks, batch_first=True, padding_value=0)


    # Pad labels to the batch size
    labels = torch.tensor(all_labels)

    return {
        'input_ids': padded_input_ids,
        'attention_mask': padded_attention_masks,
        'labels': labels
    }



In [7]:
#prepare data
wikihop_data = wikihop_data
train_size = int(0.8 * len(wikihop_data))  # 80% of dataset for training
val_size = len(wikihop_data) - train_size  # Remaining 20% for validation
train_data, val_data = random_split(wikihop_data, [train_size, val_size])

In [8]:
wikiHop_train_dataset = WikiHopDataset(train_data, tokenizer, max_length=2500)
wikiHop_val_dataset = WikiHopDataset(val_data, tokenizer, max_length=2500)

train_loader = DataLoader(wikiHop_train_dataset, batch_size=4, collate_fn=custom_collate_fn, pin_memory=True)
val_loader = DataLoader(wikiHop_val_dataset, batch_size=4, collate_fn=custom_collate_fn, pin_memory=True)

In [9]:
#test batch 

# Fetch a batch from the DataLoader
batch = next(iter(train_loader))

# Extract input_ids, attention_mask, and labels from the batch
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']

# Print shapes
print(f"Shape of input_ids: {input_ids.shape}")
print(f"Shape of attention_mask: {attention_mask.shape}")
print(f"Shape of labels: {labels.shape}")

Shape of input_ids: torch.Size([4, 8, 2275])
Shape of attention_mask: torch.Size([4, 8, 2275])
Shape of labels: torch.Size([4])


In [10]:
# Assuming your DataLoader is named `data_loader` and the BigBird model's max token size is `max_length`
max_length = 1024
# Fetch a batch from the DataLoader
batch = next(iter(train_loader))

# Extract input_ids from the batch
input_ids = batch['input_ids']
batch_size, num_choices, seq_length = input_ids.shape

# Iterate over each item and choice in the batch
for i in range(batch_size):
    for j in range(num_choices):
        token_length = (input_ids[i, j] != tokenizer.pad_token_id).sum()
        print(f"Token length for item {i}, choice {j}: {token_length}")

        # Check if token length exceeds the maximum allowed length
        if token_length > max_length:
            print(f"Warning: Token length for item {i}, choice {j} exceeds the maximum length of {max_length}")


Token length for item 0, choice 0: 598
Token length for item 0, choice 1: 593
Token length for item 0, choice 2: 597
Token length for item 0, choice 3: 593
Token length for item 0, choice 4: 592
Token length for item 0, choice 5: 0
Token length for item 0, choice 6: 0
Token length for item 0, choice 7: 0
Token length for item 1, choice 0: 2252
Token length for item 1, choice 1: 2251
Token length for item 1, choice 2: 2251
Token length for item 1, choice 3: 2252
Token length for item 1, choice 4: 2257
Token length for item 1, choice 5: 2251
Token length for item 1, choice 6: 2254
Token length for item 1, choice 7: 2251
Token length for item 2, choice 0: 526
Token length for item 2, choice 1: 528
Token length for item 2, choice 2: 525
Token length for item 2, choice 3: 526
Token length for item 2, choice 4: 525
Token length for item 2, choice 5: 0
Token length for item 2, choice 6: 0
Token length for item 2, choice 7: 0
Token length for item 3, choice 0: 2251
Token length for item 3, cho

In [11]:
#checkpoint save method
#model, optimizer, scaler, scheduler, epoch, step, checkpoint_filepath
def save_checkpoint(model, optimizer, scaler, scheduler, epoch, step, filepath):
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    checkpoint = {
        'epoch': epoch,
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }
    torch.save(checkpoint, filepath)


In [12]:
# Validation method
def validate(model, val_loader, device):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # No need to use autocast in validation as it is beneficial during backward pass which doesn't occur in validation
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, 
                            labels=labels)
            loss = outputs.loss  
            if loss.dim() > 0:
                loss = loss.mean()          
            val_loss += loss.item()

    return val_loss / len(val_loader)

In [13]:
if torch.cuda.device_count() > 1:
    print(f"{torch.cuda.device_count()} GPUs available. Using Data Parallel.")
    model = torch.nn.DataParallel(model)

#training and validate
checkpoint_interval = 100 

#save_path
save_path = "models/BigBird"

# Training hyperparameters
num_epochs = 5
learning_rate = 1e-4  # You can experiment with this value
adam_epsilon = 1e-8
num_training_steps = len(train_loader) * num_epochs
num_warmup_steps = num_training_steps // 10  # 10% of training steps as warm-up

# Prepare the model and optimizers
model.to(device)
base_optimizer = optim.RAdam(model.parameters(), lr=learning_rate, eps=adam_epsilon)
optimizer = optim.Lookahead(base_optimizer, k=5, alpha=0.5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
scaler = torch.cuda.amp.GradScaler()

# Training loop
accumulation_steps = 4  # for example, accumulate gradients over 4 forward passes

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    optimizer.zero_grad()  # Move zero_grad() outside the inner loop

    for step, batch in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}")):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        with torch.cuda.amp.autocast(enabled=True):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            
            loss = outputs.loss / accumulation_steps  # Scale loss
            if loss.dim() > 0:
                loss = loss.mean()   

        scaler.scale(loss).backward()

        if (step + 1) % accumulation_steps == 0 or step + 1 == len(train_loader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()  # Zero the gradients after optimizer step

            scheduler.step()  # Update the learning rate

        train_loss += loss.item() * accumulation_steps  # Unscaled loss for logging

        if (step + 1) % checkpoint_interval == 0:
            checkpoint_filename = f"checkpoint_epoch_{epoch+1}_step_{step+1}.pt"
            checkpoint_filepath = os.path.join(save_path, checkpoint_filename)
            # Save checkpoint function needs to be updated to handle the new optimizer
            save_checkpoint(model, optimizer, scaler, scheduler, epoch, step, checkpoint_filepath)

    average_train_loss = train_loss / len(train_loader)
    val_loss = validate(model, val_loader, device)  
    print(f"Epoch {epoch+1}/{num_epochs} completed. Train Loss: {average_train_loss:.4f}, Val Loss: {val_loss:.4f}")
    torch.cuda.empty_cache()

final_model_path = os.path.join(save_path, "BigBird_10000_rows_model.pt")
torch.save(model.state_dict(), final_model_path)
print(f"Training complete. Final model saved to {final_model_path}")


8 GPUs available. Using Data Parallel.
Epoch 1/5 completed. Train Loss: 1.3592, Val Loss: 0.6534
Epoch 2/5 completed. Train Loss: 0.5682, Val Loss: 0.5422


Training Epoch 1/5:   0%|          | 0/8748 [00:00<?, ?it/s]2023-11-12 19:02:13.060244: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-11-12 19:02:13.895651: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-11-12 19:02:13.895725: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
Training Epoch 1/5: 100%|██████████| 8748/8748 [7:12:10<00:00,  2.96s/it]   
Training Epoch 2/5: 100%|████