In [51]:
import os, random
from datasets import load_dataset
from transformers import AutoTokenizer
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from tqdm import tqdm

In [52]:
torch.set_default_device('cuda')
device = torch.device('cuda')
%env WANDB_NOTEBOOK_NAME = "cross-encoder"

env: WANDB_NOTEBOOK_NAME="cross-encoder"


In [98]:



train_datasets = load_dataset('ms_marco', 'v2.1', "python", split="train[:2.5%]")
test_datasets = load_dataset('ms_marco', 'v2.1', "python", split="test[:2.5%]")
validation_datasets = load_dataset('ms_marco', 'v2.1', "python", split="validation[:1%]")

# Create DataFrame
def create_dataframe(datasets):
    data = []
    flag = False
    for example in datasets:
        queries = example['query']
        passage_texts = example['passages']['passage_text']
        labels = example['passages']['is_selected']
        
        # Get indices of passages with label 1 and 0
        label1_indices = [i for i, label in enumerate(labels) if label == 1]
        label0_indices = [i for i, label in enumerate(labels) if label == 0]
        
        # Randomly select one passage with label 0 if available
        if label1_indices:
            if flag:
                random_label0_index = random.choice(label0_indices)
            
                # Add the randomly selected passage with label 0 to the data
                data.append((queries, passage_texts[random_label0_index], labels[random_label0_index]))
                flag = False
            else:
                flag = True
        
        # Add passages with label 1 to the data
        if label1_indices:
            for index in label1_indices:
                data.append((queries, passage_texts[index], labels[index]))
    
    df = pd.DataFrame(data, columns=['queries', 'passage_texts', 'labels'])
    return df



Found cached dataset ms_marco (/home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1-data_dir=python/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84)
Found cached dataset ms_marco (/home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1-data_dir=python/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84)
Found cached dataset ms_marco (/home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1-data_dir=python/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84)


In [100]:
train_df = create_dataframe(train_datasets)
test_df = create_dataframe(test_datasets)
validation_df = create_dataframe(validation_datasets)

In [101]:
train_df["labels"].value_counts()

1    5200
0    2435
Name: labels, dtype: int64

In [56]:
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

In [57]:
class MSMarcoDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        query = self.df.iloc[idx]['queries']
        passage_text = self.df.iloc[idx]['passage_texts']
        label = self.df.iloc[idx]['labels']

        # Tokenize the query and passage_text
        encoded_pair = self.tokenizer(query, passage_text, 
                                      padding='max_length', 
                                      truncation=True, 
                                      max_length=self.max_length,
                                      return_tensors='pt')

        input_ids = encoded_pair['input_ids'].squeeze(0)
        attention_mask = encoded_pair['attention_mask'].squeeze(0)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': torch.tensor(label, dtype=torch.float)
        }


In [58]:
# Convert DataFrame to PyTorch Dataset
train_dataset = MSMarcoDataset(train_df, tokenizer, max_length=512)
test_dataset = MSMarcoDataset(test_df, tokenizer, max_length=512)
validation_dataset = MSMarcoDataset(validation_df, tokenizer, max_length=512)

# Set batch size
batch_size = 16

def custom_collate_fn(batch):
    batch = [elem for elem in batch if elem is not None]
    keys = batch[0].keys()
    batch_dict = {key: [] for key in keys}

    for example in batch:
        for key in keys:
            batch_dict[key].append(example[key])

    for key in keys:
        batch_dict[key] = torch.stack(batch_dict[key]) if key != 'token_type_ids' else batch_dict[key]

    return batch_dict

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

In [59]:
def create_data_loader(nq_dataset, batch_size):
    return DataLoader(nq_dataset, batch_size=batch_size, shuffle=True, generator=torch.Generator(device))

In [60]:
# the custom model

class CustomTransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, max_length, dropout=0.1):
        super(CustomTransformerEncoder, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = self.generate_positional_encoding(d_model, max_length)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=d_model*4, dropout=dropout),
            num_layers=num_layers
        )
        self.output_layer = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

    def generate_positional_encoding(self, d_model, max_length):
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        positional_encoding = torch.zeros(max_length, d_model)
        positional_encoding[:, 0::2] = torch.sin(position * div_term)
        positional_encoding[:, 1::2] = torch.cos(position * div_term)
        return positional_encoding.unsqueeze(0)

    def forward(self, input_ids, attention_mask):
        # Apply token embeddings
        x = self.embedding(input_ids)

        # Apply positional encoding
        x = x + self.positional_encoding[:, :x.size(1), :].to(x.device)

        # Apply the TransformerEncoder layers
        x = self.transformer(x.transpose(0, 1), src_key_padding_mask=attention_mask)
        
        # Take the first token's representation (CLS token) from the last layer
        x = x[0]

        # Apply the output layer
        x = self.output_layer(x)

        # Apply sigmoid activation function
        x = self.sigmoid(x)

        return x.squeeze(-1)


# Initialize the model
vocab_size = tokenizer.vocab_size
d_model = 768
nhead = 12
num_layers = 6
max_length = 512
dropout = 0.1

model = CustomTransformerEncoder(vocab_size, d_model, nhead, num_layers, max_length, dropout)


In [61]:
def save_checkpoint(model, optimizer, epoch, train_loss, validation_loss, model_path):
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "train_loss": train_loss,
        "validation_loss": validation_loss
    }
    torch.save(checkpoint, model_path)

def load_checkpoint(model, optimizer, model_path):
    if os.path.exists(model_path):
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        epoch = checkpoint["epoch"]
        train_loss = checkpoint["train_loss"]
        validation_loss = checkpoint["validation_loss"]
        print(f"Loaded model from {model_path}")
    else:
        print("No previous model found, training from scratch.")
        epoch = 0
        train_loss = None
        validation_loss = None

    return model, optimizer, epoch, train_loss, validation_loss


In [102]:
# Initialize wandb
wandb.login()
wandb.init(project="msmarco-cross-encoder", entity="raiden076")
wandb.watch(model)
config = wandb.config

# Define loss function and optimizer
loss_function = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training parameters
num_epochs = 5
best_validation_loss = float('inf')

# Set wandb config parameters
config.vocab_size = vocab_size
config.d_model = d_model
config.nhead = nhead
config.num_layers = num_layers
config.max_length = max_length
config.dropout = dropout
config.batch_size = batch_size
config.num_epochs = num_epochs
config.learning_rate = 1e-4
config.optimizer = "Adam"
config.loss_function = "BCELoss"

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Load the best model if available
model_path = "best_model.pt"
model, optimizer, starting_epoch, _, best_validation_loss = load_checkpoint(model, optimizer, model_path)
if best_validation_loss is None:
    best_validation_loss = float('inf')

# Train the model
for epoch in range(starting_epoch, num_epochs):
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_dataloader, desc="Training", leave=False):
        # Move tensors to the device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        # Convert attention_mask to boolean dtype
        attention_mask = attention_mask.bool()

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(input_ids, attention_mask)

        # Compute loss
        loss = loss_function(outputs, labels)

        # Backward pass
        loss.backward()

        # Update weights
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_dataloader)
    wandb.log({"train_loss": train_loss})
    
    # Evaluate on the validation set
    model.eval()
    validation_loss = 0.0
    with torch.no_grad():
        for batch in validation_dataloader:
            # Move tensors to the device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Convert attention_mask to boolean dtype
            attention_mask = attention_mask.bool()

            # Forward pass
            outputs = model(input_ids, attention_mask)

            # Compute loss
            loss = loss_function(outputs, labels)

            validation_loss += loss.item()

        validation_loss /= len(validation_dataloader)
        wandb.log({"validation_loss": validation_loss})

        # Save the best model
        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            save_checkpoint(model, optimizer, epoch + 1, train_loss, validation_loss, model_path)
            wandb.save(model_path)

wandb.finish()


Loaded model from best_model.pt


Training:  55%|█████▍    | 345/630 [04:42<03:52,  1.23it/s]

In [65]:
# Load the best model
best_model = CustomTransformerEncoder(vocab_size, d_model, nhead, num_layers, max_length, dropout)
best_optimizer = optim.Adam(best_model.parameters(), lr=1e-4)

# Load the checkpoint
best_model, best_optimizer, _, _, _ = load_checkpoint(best_model, best_optimizer, "best_model.pt")

# Move the best model to the device
best_model.to(device)


OutOfMemoryError: CUDA out of memory. Tried to allocate 734.00 MiB (GPU 0; 14.62 GiB total capacity; 13.17 GiB already allocated; 29.38 MiB free; 14.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [66]:
def recall_at_k(predictions, labels, k):
    top_k_predictions = torch.topk(predictions, k).indices
    top_k_labels = torch.topk(labels, k).indices
    recall = sum([1 for i in range(k) if top_k_predictions[i] in top_k_labels]) / len(labels)
    return recall


In [67]:
model.eval()
test_loss = 0.0
recall_at_1 = 0.0
recall_at_3 = 0.0
recall_at_10 = 0.0
num_test_batches = 0

with torch.no_grad():
    for batch in test_dataloader:
        # Move tensors to the device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # token_type_ids = batch['token_type_ids'].to(device) if batch['token_type_ids'] is not None else None
        labels = batch['label'].to(device)
        attention_mask = attention_mask.bool()

        # Forward pass
        outputs = model(input_ids, attention_mask)

        # Compute loss
        loss = loss_function(outputs, labels)

        test_loss += loss.item()

        # Compute recall
        recall_at_1 += recall_at_k(outputs, labels, 1)
        recall_at_3 += recall_at_k(outputs, labels, 3)
        recall_at_10 += recall_at_k(outputs, labels, 10)

        num_test_batches += 1

test_loss /= num_test_batches
recall_at_1 /= num_test_batches
recall_at_3 /= num_test_batches
recall_at_10 /= num_test_batches

print(f"Test loss: {test_loss}")
print(f"Recall@1: {recall_at_1}")
print(f"Recall@3: {recall_at_3}")
print(f"Recall@10: {recall_at_10}")

ZeroDivisionError: float division by zero

In [78]:
def infer(query, candidate_passages, tokenizer, model, device, max_length=512):
    model.eval()
    
    encoded_pairs = [
        tokenizer(query, passage, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
        for passage in candidate_passages
    ]
    
    scores = []
    with torch.no_grad():
        for encoded_pair in encoded_pairs:
            input_ids = encoded_pair['input_ids'].to(device)
            attention_mask = encoded_pair['attention_mask'].to(device)
            attention_mask = attention_mask.bool()
            
            score = model(input_ids, attention_mask)
            scores.append(score.item())
    
    ranked_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranked_passages = [candidate_passages[i] for i in ranked_indices]
    ranked_scores = [scores[i] for i in ranked_indices]
    
    return ranked_passages, ranked_scores


In [79]:
def pred(query, candidate_passages):
    ranked_passages, ranked_scores = infer(query, candidate_passages, tokenizer, best_model, device)

    for i, (passage, score) in enumerate(zip(ranked_passages, ranked_scores)):
        print(f"{i+1}. {passage} (score: {score:.4f})")

In [80]:
query = "What is the capital of France?"
candidate_passages = [
    "The capital of France is Paris.",
    "France is a country in Europe.",
    "The Eiffel Tower is in Paris, the capital of France.",
    "Germany is the neighboring country of France."
]
pred(query, candidate_passages)


1. The Eiffel Tower is in Paris, the capital of France. (score: 0.4705)
2. Germany is the neighboring country of France. (score: 0.4701)
3. The capital of France is Paris. (score: 0.4699)
4. France is a country in Europe. (score: 0.4699)


In [91]:

query = "What do bees make?"
candidate_passages = [
    "Bees make honey.",
    "Ants are small insects that live in colonies.",
    "Butterflies are insects with colorful wings.",
    "Birds can fly and lay eggs in nests."
]






pred(query, candidate_passages)

1. Butterflies are insects with colorful wings. (score: 0.4705)
2. Birds can fly and lay eggs in nests. (score: 0.4705)
3. Ants are small insects that live in colonies. (score: 0.4704)
4. Bees make honey. (score: 0.4697)
