# Prototype implementation of cross encoder using MS-MARCO dataset for semantic search similarity.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# import numpy as np
# import random

import pandas as pd
from sklearn.model_selection import train_test_split


In [None]:
class CrossEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.1):
        super(CrossEncoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=True)
        self.attention = nn.MultiheadAttention(hidden_size * 2, 8)
        self.fc = nn.Linear(hidden_size * 4, output_size)

    def forward(self, query, document):
        query_emb = self.embedding(query)
        doc_emb = self.embedding(document)
        
        query_out, _ = self.lstm(query_emb)
        doc_out, _ = self.lstm(doc_emb)

        attn_out, _ = self.attention(query_out.transpose(0, 1), doc_out.transpose(0, 1), doc_out.transpose(0, 1))
        attn_out = attn_out.transpose(0, 1)

        concat_out = torch.cat((query_out, attn_out), dim=-1)
        logits = self.fc(concat_out)

        similarity = torch.sigmoid(torch.max(logits, dim=1)[0])

        return similarity


In [4]:
class MSMarcoDataset(Dataset):
    def __init__(self, queries, documents, tokenizer, max_seq_len):
        self.queries = queries
        self.documents = documents
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        query = self.queries[idx]
        document = self.documents[idx]

        query_tokens = self.tokenizer(query)[:self.max_seq_len]
        document_tokens = self.tokenizer(document)[:self.max_seq_len]

        return query_tokens, document_tokens


In [None]:
filepath = '/home/ubuntu/projects/cross-encoder-2/data/fulldocs.tsv'
data = pd.read_csv(filepath, sep='\t', header=None, names=['query_id', 'query', 'passage_id', 'document'])

query_data = data['query'].tolist()
document_data = data['document'].tolist()

query_train, query_val, document_train, document_val = train_test_split(query_data, document_data, test_size=0.2, random_state=42)


In [None]:
def simple_tokenizer(text):
    tokens = [ord(char) if ord(char) < 128 else 0 for char in text]
    return torch.tensor(tokens, dtype=torch.long)


In [None]:
max_seq_len = 128

train_dataset = MSMarcoDataset(query_train, document_train, simple_tokenizer, max_seq_len)
val_dataset = MSMarcoDataset(query_val, document_val, simple_tokenizer, max_seq_len)


In [None]:
batch_size = 32

def collate_fn(batch):
    query_batch, doc_batch = zip(*batch)
    query_batch = nn.utils.rnn.pad_sequence(query_batch, batch_first=True, padding_value=0)
    doc_batch = nn.utils.rnn.pad_sequence(doc_batch, batch_first=True, padding_value=0)
    return query_batch, doc_batch

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)


In [None]:
input_size = 128
hidden_size = 512
num_layers = 2
output_size = 1
device = torch.device('cuda')
model = CrossEncoder(input_size, hidden_size, num_layers, output_size)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()


In [1]:
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch in train_loader:
        query_batch, document_batch = zip(*batch)
        query_batch = nn.utils.rnn.pad_sequence(query_batch, batch_first=True).to(device)
        document_batch = nn.utils.rnn.pad_sequence(document_batch, batch_first=True).to(device)

        optimizer.zero_grad()
        similarity = model(query_batch, document_batch)

        loss = criterion(similarity, torch.ones(len(batch), dtype=torch.float)).to(device)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader)}')


IndentationError: unexpected indent (3546601399.py, line 21)