In [None]:
import os
from datasets import load_dataset
from transformers import AutoTokenizer
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import wandb

In [None]:
# Pick the GPU when present, otherwise fall back to CPU, and make newly created tensors
# default to that same device (the training cell uses the same fallback rule).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)

In [None]:
# Load a slice of each MS MARCO v2.1 split to keep preprocessing and training tractable.
# (The config name is the only extra positional argument; a third positional argument
# would be interpreted by `load_dataset` as `data_dir`.)
DATA_FRACTION = "10%"
train_datasets = load_dataset('ms_marco', 'v2.1', split=f"train[:{DATA_FRACTION}]")
test_datasets = load_dataset('ms_marco', 'v2.1', split=f"test[:{DATA_FRACTION}]")
validation_datasets = load_dataset('ms_marco', 'v2.1', split=f"validation[:{DATA_FRACTION}]")

# Create DataFrame
def create_dataframe(datasets):
    """Flatten MS MARCO examples into one row per (query, passage) pair.

    Args:
        datasets: iterable of examples, each with a ``'query'`` value and a ``'passages'``
            mapping holding parallel ``'passage_text'`` and ``'is_selected'`` sequences.

    Returns:
        pandas.DataFrame with columns ``'queries'`` (the example's query, repeated for each
        of its passages), ``'passage_texts'`` and ``'labels'`` (the ``is_selected`` value).
        Rows keep the iteration order, so each example's passages are contiguous.
    """
    rows = []
    for example in datasets:
        query = example['query']
        passages = example['passages']
        rows.extend(
            (query, text, selected)
            for text, selected in zip(passages['passage_text'], passages['is_selected'])
        )
    return pd.DataFrame(rows, columns=['queries', 'passage_texts', 'labels'])

In [None]:
# Flatten every split into (query, passage, label) rows, in train / test / validation order.
train_df, test_df, validation_df = (
    create_dataframe(split_examples)
    for split_examples in (train_datasets, test_datasets, validation_datasets)
)

In [None]:
# Pretrained "xlm-roberta-base" tokenizer; the model built below reuses only its
# vocabulary size, its encoder weights are trained from scratch.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="xlm-roberta-base")

In [None]:
class MSMarcoDataset(Dataset):
    """Map-style dataset that tokenizes one (query, passage) pair per DataFrame row.

    Args:
        df: DataFrame with ``'queries'``, ``'passage_texts'`` and ``'labels'`` columns
            (as produced by ``create_dataframe``).
        tokenizer: callable tokenizer invoked as ``tokenizer(query, passage, ...)`` with
            ``padding='max_length'``, ``truncation=True`` and ``return_tensors='pt'``.
        max_length: ``max_length`` forwarded to the tokenizer.

    Each item is a dict with 1-D ``'input_ids'`` / ``'attention_mask'`` tensors (the
    tokenizer's batch dimension squeezed away) and a float32 scalar ``'label'``.
    """

    def __init__(self, df, tokenizer, max_length=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        encoding = self.tokenizer(
            row['queries'],
            row['passage_texts'],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(row['labels'], dtype=torch.float),
        }


In [None]:
# Wrap each split's DataFrame in a tokenizing Dataset; 512 matches the model's max_length.
train_dataset, test_dataset, validation_dataset = (
    MSMarcoDataset(frame, tokenizer, max_length=512)
    for frame in (train_df, test_df, validation_df)
)

# Number of (query, passage) pairs per batch
batch_size = 16

def custom_collate_fn(batch):
    """Collate a list of example dicts into a dict of batched values.

    ``None`` entries are dropped first. The keys of the first remaining example define
    the output keys. Each key's values are combined with ``torch.stack``, except
    ``'token_type_ids'``, which is kept as a plain Python list.

    Raises:
        IndexError: if every entry in ``batch`` is ``None`` (or ``batch`` is empty).
    """
    examples = [example for example in batch if example is not None]
    collated = {}
    for key in examples[0].keys():
        values = [example[key] for example in examples]
        collated[key] = values if key == 'token_type_ids' else torch.stack(values)
    return collated

# Shuffle training pairs: create_dataframe emits each query's passages contiguously, so an
# unshuffled loader feeds highly correlated batches. The shuffle RNG is created on `device`
# to match the notebook's global default device (same pattern as create_data_loader).
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator(device=device),
)
# Evaluation loaders keep row order so predictions line up with test_df / validation_df.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

In [None]:
def create_data_loader(nq_dataset, batch_size):
    """Return a shuffling DataLoader over ``nq_dataset``.

    The shuffle generator is created on the module-level ``device``, presumably so the
    sampler's random permutation matches the global default device this notebook sets.
    """
    sampler_rng = torch.Generator(device=device)
    return DataLoader(
        nq_dataset,
        batch_size=batch_size,
        shuffle=True,
        generator=sampler_rng,
    )

In [None]:
# the custom model

class CustomTransformerEncoder(nn.Module):
    """Cross-encoder that scores a tokenized (query, passage) pair with a from-scratch Transformer.

    Token embeddings plus fixed sinusoidal positional encodings are passed through a stack of
    ``nn.TransformerEncoderLayer`` modules. The hidden state at position 0 is projected to one
    logit and squashed with a sigmoid, giving one relevance probability per input sequence.

    Args:
        vocab_size: number of rows in the token-embedding table.
        d_model: embedding / hidden size (must be divisible by ``nhead``).
        nhead: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        max_length: longest sequence the positional-encoding table covers.
        dropout: dropout probability inside each encoder layer.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, max_length, dropout=0.1):
        super(CustomTransformerEncoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        # A non-persistent buffer follows the module through .to(device) but stays out of
        # state_dict, so checkpoints saved without it still load with strict=True.
        self.register_buffer(
            'positional_encoding',
            self.generate_positional_encoding(d_model, max_length),
            persistent=False,
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=d_model*4, dropout=dropout),
            num_layers=num_layers
        )
        self.output_layer = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

    def generate_positional_encoding(self, d_model, max_length):
        """Return a ``(1, max_length, d_model)`` sinusoidal positional-encoding table.

        Even columns hold ``sin`` and odd columns ``cos`` of ``position * div_term``; an odd
        ``d_model`` is supported (it simply has one fewer cosine column).
        """
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        angles = position * div_term  # (max_length, ceil(d_model / 2))
        positional_encoding = torch.zeros(max_length, d_model)
        positional_encoding[:, 0::2] = torch.sin(angles)
        positional_encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return positional_encoding.unsqueeze(0)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Score a batch of token sequences.

        Args:
            input_ids: integer token ids of shape ``(batch, seq_len)``, ``seq_len <= max_length``.
            attention_mask: same shape; truthy (1 / True) for real tokens, falsy for padding,
                as returned by the tokenizer. Integer or bool dtype.
            token_type_ids: accepted for call-site compatibility and ignored.

        Returns:
            Tensor of shape ``(batch,)`` with sigmoid probabilities.
        """
        # Apply token embeddings
        x = self.embedding(input_ids)

        # Apply positional encoding (the buffer already lives on the module's device)
        x = x + self.positional_encoding[:, :x.size(1), :]

        # PyTorch's src_key_padding_mask uses True = "ignore this position", the opposite of
        # the tokenizer's attention_mask, so it must be inverted.
        padding_mask = ~attention_mask.bool()

        # nn.TransformerEncoder here is sequence-first: (seq_len, batch, d_model)
        x = self.transformer(x.transpose(0, 1), src_key_padding_mask=padding_mask)

        # Take the representation at position 0 from the last layer
        x = x[0]

        # Project to a single logit and squash to a probability
        x = self.output_layer(x)
        x = self.sigmoid(x)

        return x.squeeze(-1)


# Model hyperparameters; vocabulary size comes from the tokenizer.
vocab_size = tokenizer.vocab_size
d_model, nhead, num_layers = 768, 12, 6
max_length = 512  # same cap used when tokenizing the datasets
dropout = 0.1

model = CustomTransformerEncoder(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    max_length=max_length,
    dropout=dropout,
)


In [None]:
def load_best_model_if_available(model, model_path):
    """Load weights from ``model_path`` into ``model`` when that file exists.

    The checkpoint is deserialized onto the CPU first (``map_location='cpu'``), so a file
    saved from a GPU run also loads on a CPU-only machine. ``load_state_dict`` then copies
    the tensors onto whatever device ``model``'s parameters already occupy.

    Args:
        model: ``nn.Module`` whose ``state_dict`` layout matches the checkpoint.
        model_path: path to a file written by ``torch.save(model.state_dict(), ...)``.

    Returns:
        The same ``model`` object (loaded in place, or untouched if no file exists).
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        print(f"Loaded model from {model_path}")
    else:
        print("No previous model found, training from scratch.")
    return model

In [None]:
# Training hyperparameters
learning_rate = 1e-4
num_epochs = 3

# Initialize wandb, recording every hyperparameter on the run itself
wandb.login()
wandb.init(
    project="msmarco-cross-encoder",
    entity="raiden076",
    config={
        "vocab_size": vocab_size,
        "d_model": d_model,
        "nhead": nhead,
        "num_layers": num_layers,
        "max_length": max_length,
        "dropout": dropout,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "learning_rate": learning_rate,
        "optimizer": "Adam",
        "loss_function": "BCELoss",
    },
)
config = wandb.config

# Set device (CPU fallback when CUDA is unavailable)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Load the best model if available, remembering whether we resumed from a checkpoint
model_path = "best_model.pt"
resumed_from_checkpoint = os.path.exists(model_path)
model = load_best_model_if_available(model, model_path)
wandb.watch(model)

# Define loss function and optimizer (built after the model is on its final device)
loss_function = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def evaluate_loss(eval_model, dataloader):
    """Return the mean per-batch ``loss_function`` value of ``eval_model`` over ``dataloader``.

    Switches ``eval_model`` to eval mode and runs without gradient tracking.
    """
    eval_model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            # The model's padding-mask handling expects a boolean mask
            attention_mask = batch['attention_mask'].to(device).bool()
            labels = batch['label'].to(device)
            outputs = eval_model(input_ids, attention_mask)
            total_loss += loss_function(outputs, labels).item()
    return total_loss / len(dataloader)


# A restored checkpoint's validation loss is the bar to beat; starting from inf would let a
# worse first epoch overwrite the better model already on disk.
best_validation_loss = (
    evaluate_loss(model, validation_dataloader) if resumed_from_checkpoint else float('inf')
)

# Train the model
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for batch in train_dataloader:
        # Move tensors to the device; the mask is converted to bool for the model
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device).bool()
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_dataloader)
    validation_loss = evaluate_loss(model, validation_dataloader)

    # One wandb step per epoch with both losses
    wandb.log({"epoch": epoch, "train_loss": train_loss, "validation_loss": validation_loss})

    # Save the best model
    if validation_loss < best_validation_loss:
        best_validation_loss = validation_loss
        torch.save(model.state_dict(), model_path)
        wandb.save(model_path)

wandb.finish()


In [None]:
# Rebuild the architecture and restore the checkpoint with the lowest validation loss.
# map_location lets a checkpoint saved on another device load onto the current `device`.
best_model = CustomTransformerEncoder(vocab_size, d_model, nhead, num_layers, max_length, dropout)
best_model.load_state_dict(torch.load("best_model.pt", map_location=device))
best_model.to(device)
best_model.eval()

In [None]:
def recall_at_k(predictions, labels, k):
    """Mean recall@k over queries.

    Args:
        predictions: score tensor of shape ``(num_queries, num_candidates)``, or 1-D for a
            single query. Pad ragged rows with ``-inf`` so padding never outranks a real
            candidate.
        labels: tensor of the same shape; entries ``> 0`` mark relevant candidates.
        k: cutoff (``>= 1``). Values larger than ``num_candidates`` are clamped.

    Returns:
        float: for each query with at least one relevant candidate, the fraction of its
        relevant candidates found among its ``k`` highest-scoring ones, averaged over those
        queries. ``nan`` when no query has a relevant candidate.

    Raises:
        ValueError: if ``k < 1`` or ``predictions`` and ``labels`` differ in shape.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if predictions.dim() == 1:
        predictions = predictions.unsqueeze(0)
        labels = labels.unsqueeze(0)
    if predictions.shape != labels.shape:
        raise ValueError(
            f"predictions shape {tuple(predictions.shape)} != labels shape {tuple(labels.shape)}"
        )

    k = min(k, predictions.size(1))
    relevant = labels > 0
    top_k_indices = torch.topk(predictions, k, dim=1).indices
    # Number of relevant candidates that made each query's top-k
    hits = relevant.gather(1, top_k_indices).sum(dim=1).float()
    num_relevant = relevant.sum(dim=1).float()

    has_relevant = num_relevant > 0
    if not bool(has_relevant.any()):
        return float('nan')
    return (hits[has_relevant] / num_relevant[has_relevant]).mean().item()

In [None]:
# Evaluate the checkpoint with the lowest validation loss, not the last-epoch `model`.
evaluation_model = best_model
evaluation_model.eval()
test_loss = 0.0
num_test_batches = 0
score_chunks = []
label_chunks = []

with torch.no_grad():
    for batch in test_dataloader:
        # MSMarcoDataset items carry only input_ids / attention_mask / label
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device).bool()
        labels = batch['label'].to(device)

        outputs = evaluation_model(input_ids, attention_mask)
        test_loss += loss_function(outputs, labels).item()
        num_test_batches += 1

        score_chunks.append(outputs)
        label_chunks.append(labels)

assert num_test_batches > 0, "test_dataloader produced no batches"
test_loss /= num_test_batches

# Recall@k is a per-query metric, but batches cut across queries. test_dataloader does not
# shuffle, so the concatenated scores line up with test_df rows, and create_dataframe writes
# each example's passages contiguously: a run of identical consecutive queries is one group.
# NOTE(review): two adjacent examples with identical query text would be merged into one group.
# NOTE(review): recall is undefined (nan) if this split has no positive labels; check
# test_df['labels'].sum() > 0, or evaluate on validation_df instead.
all_scores = torch.cat(score_chunks)
all_labels = torch.cat(label_chunks)
query_run_ids = (test_df['queries'] != test_df['queries'].shift()).cumsum()
group_sizes = test_df.groupby(query_run_ids).size().tolist()
assert sum(group_sizes) == all_scores.numel(), "scores do not line up with test_df rows"

# Pad to (num_queries, max_passages): -inf scores never outrank real passages and
# 0 labels add no relevant items.
scores_by_query = nn.utils.rnn.pad_sequence(
    list(torch.split(all_scores, group_sizes)), batch_first=True, padding_value=float('-inf')
)
labels_by_query = nn.utils.rnn.pad_sequence(
    list(torch.split(all_labels, group_sizes)), batch_first=True, padding_value=0.0
)

recall_at_1 = recall_at_k(scores_by_query, labels_by_query, 1)
recall_at_3 = recall_at_k(scores_by_query, labels_by_query, 3)
recall_at_10 = recall_at_k(scores_by_query, labels_by_query, 10)

print(f"Test loss: {test_loss}")
print(f"Recall@1: {recall_at_1}")
print(f"Recall@3: {recall_at_3}")
print(f"Recall@10: {recall_at_10}")