In [None]:
!pip install datasets
!pip install tqdm
!pip install sentencepiece
!pip install word2vec
!pip install gensim

In [None]:
import pandas as pd
import numpy as np
import random
from datasets import list_datasets
from datasets import load_dataset
import time
import sentencepiece as sp
import gensim
from gensim.models import Word2Vec
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# NOTE(review): `datasets_list` is not referenced anywhere else in the visible
# notebook, so this call looks exploratory — confirm before relying on it.
datasets_list = list_datasets()

# get MS Marco
# Loads the 'v1.1' configuration of 'ms_marco'; later cells index the result by
# split name ('train', 'validation', 'test').
ms_df_dict = load_dataset('ms_marco', 'v1.1')

In [None]:
# Build one flat corpus: every query plus every passage text, across all splits.
texts = []
for split in ('train', 'validation', 'test'):
    split_data = ms_df_dict[split]
    texts.extend(split_data['query'])
    # Each 'passages' entry holds an iterable of passage strings under 'passage_text'.
    for passage_dict in split_data['passages']:
        texts.extend(passage_dict['passage_text'])

# Quick look at the first few collected strings.
print(texts[:3])

In [None]:
# Train SentencePiece tokenizer
model_name = 'ms_marco_spm'
vocab_size = 30000

# Unigram model trained on the in-memory corpus, streamed through an iterator.
# NOTE(review): because `model_prefix` is set, the trainer presumably writes
# `ms_marco_spm.model` / `ms_marco_spm.vocab` to the working directory and the
# value bound to `spm_model` may be None — confirm against the SentencePiece docs.
spm_model = sp.SentencePieceTrainer.train(
    model_type='unigram',
    character_coverage=1.0,
    vocab_size=vocab_size,
    model_prefix=model_name,
    sentence_iterator=iter(texts),
)

In [None]:
# Path of the SentencePiece model file.
# `SentencePieceTrainer.train(model_prefix=...)` writes `<model_prefix>.model`
# to disk itself and returns None, so there is no trainer object to call
# `.save()` on — the file only needs to be loaded.
model_name = 'ms_marco_spm'
spm_model_path = f"{model_name}.model"

# Load the trained SentencePiece model
sp_model = sp.SentencePieceProcessor()
sp_model.load(spm_model_path)

In [None]:
# Word2Vec hyperparameters
embedding_size = 300  # dimensionality of each learned piece vector
window = 5
min_count = 1
workers = 4

# Split every text into SentencePiece pieces (strings, not ids) so that
# Word2Vec learns one vector per piece.
tokenized_texts = list(map(sp_model.encode_as_pieces, texts))

# Train Word2Vec on the piece sequences
word2vec_model = Word2Vec(
    sentences=tokenized_texts,
    vector_size=embedding_size,
    window=window,
    min_count=min_count,
    workers=workers,
)

# Persist the trained model so later cells can reload it from disk
word2vec_model_path = 'ms_marco_word2vec.model'
word2vec_model.save(word2vec_model_path)

In [None]:
import torch
import torch.nn as nn

# Load the trained Word2Vec model
word2vec_model = Word2Vec.load('ms_marco_word2vec.model')
keyed_vectors = word2vec_model.wv

# The datasets in this notebook feed SentencePiece ids (`sp_model.encode`) into
# the embedding layer, whereas `wv.vectors` is ordered by gensim's own
# vocabulary index. Using `wv.vectors` directly would therefore look up the
# wrong vector for every id (and can index past the end when gensim saw fewer
# pieces than the tokenizer defines). Build a matrix whose row i is the vector
# of SentencePiece piece i instead.
num_pieces = sp_model.get_piece_size()
embedding_matrix = np.zeros((num_pieces, keyed_vectors.vector_size), dtype=np.float32)
for piece_id in range(num_pieces):
    piece = sp_model.id_to_piece(piece_id)
    # Pieces Word2Vec never saw during training keep an all-zero row.
    if piece in keyed_vectors.key_to_index:
        embedding_matrix[piece_id] = keyed_vectors[piece]

# Create a frozen embedding layer from the re-indexed Word2Vec weights
embedding_layer = nn.Embedding.from_pretrained(torch.from_numpy(embedding_matrix), freeze=True)
print("Embedding layer shape:", embedding_layer.weight.shape)


In [None]:
from tqdm import tqdm

def generate_triples(split, max_neg_tries=100):
    """Build (query, positive passage, negative passage) triples for one split.

    For every row of ``ms_df_dict[split]`` that has at least one passage, one of
    the row's own passages is drawn at random as the positive, and one passage
    from a different, randomly chosen row is drawn as the negative. Sampling
    uses the global ``random`` module state.

    Args:
        split: Key into ``ms_df_dict`` (e.g. ``'train'``).
        max_neg_tries: Maximum number of attempts to find another row with a
            non-empty passage list. A row for which no negative is found is
            skipped; this also keeps a one-row split from looping forever.

    Returns:
        Tuple ``(queries, pos_docs, neg_docs)`` of three equally long lists.
    """
    # Resolve the split once instead of re-indexing `ms_df_dict` on every access.
    dataset = ms_df_dict[split]
    queries, pos_docs, neg_docs = [], [], []

    num_samples = len(dataset)
    print("num_samples", num_samples)
    for idx in tqdm(range(num_samples)):
        row = dataset[idx]
        own_passages = row['passages']['passage_text']
        if not own_passages:
            continue
        pos_doc = random.choice(own_passages)

        # Draw the negative from a different row; retry when the draw lands on
        # the current row or on a row without passages.
        neg_doc = None
        for _ in range(max_neg_tries):
            neg_idx = random.randint(0, num_samples - 1)
            if neg_idx == idx:
                continue
            neg_passages = dataset[neg_idx]['passages']['passage_text']
            if neg_passages:
                neg_doc = random.choice(neg_passages)
                break
        if neg_doc is None:
            # No usable negative found; drop this query rather than crash.
            continue

        queries.append(row['query'])
        pos_docs.append(pos_doc)
        neg_docs.append(neg_doc)

    return queries, pos_docs, neg_docs

# Build (query, positive passage, negative passage) triples for every split.
# NOTE: sampling draws from the global `random` state, which is not seeded in
# the visible notebook, so the generated triples differ between runs.
train_queries, train_pos_docs, train_neg_docs = generate_triples('train')
test_queries, test_pos_docs, test_neg_docs = generate_triples('test')
validation_queries, validation_pos_docs, validation_neg_docs = generate_triples('validation')

# Completion marker for this long-running cell.
print("yay")

In [None]:
# Sanity check: show the first generated training triple.
for label, sample in (
    ("train_queries[0]", train_queries[0]),
    ("train_pos_docs[0]", train_pos_docs[0]),
    ("train_neg_docs[0]", train_neg_docs[0]),
):
    print(label, sample)

In [None]:
class MSMarcoDataset(Dataset):
    """Dataset over parallel lists of queries, positive docs and negative docs.

    Texts are stored as given and tokenized on access with ``tokenizer.encode``.
    """

    def __init__(self, queries, pos_docs, neg_docs, tokenizer):
        self.tokenizer = tokenizer
        self.queries, self.pos_docs, self.neg_docs = queries, pos_docs, neg_docs

    def __len__(self):
        # One item per query; the three lists are indexed in parallel.
        return len(self.queries)

    def __getitem__(self, idx):
        """Return the encoded (query, positive doc, negative doc) triple at ``idx``."""
        raw_triple = (self.queries[idx], self.pos_docs[idx], self.neg_docs[idx])
        query_tokens, pos_doc_tokens, neg_doc_tokens = map(self.tokenizer.encode, raw_triple)
        return query_tokens, pos_doc_tokens, neg_doc_tokens

def collate_fn(batch):
    """Collate a list of (query, pos_doc, neg_doc) token lists into padded tensors.

    Each of the three fields is converted to ``LongTensor``s and stacked
    batch-first with ``pad_sequence``, which pads shorter sequences on the
    right with its default padding value (0).

    Returns:
        Tuple ``(queries, pos_docs, neg_docs)`` of 2-D tensors.
    """
    def _pad(token_lists):
        return pad_sequence([torch.LongTensor(ids) for ids in token_lists], batch_first=True)

    query_col, pos_col, neg_col = zip(*batch)
    return _pad(query_col), _pad(pos_col), _pad(neg_col)

# Wrap the training triples; the raw strings are encoded with `sp_model` on each access.
train_dataset = MSMarcoDataset(train_queries, train_pos_docs, train_neg_docs, sp_model)
# Shuffled mini-batches of 32; `collate_fn` pads each field within the batch.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# Sanity check: show the tokenizer output for the first training triple.
query_tokens, pos_doc_tokens, neg_doc_tokens = train_dataset[0]
print("Query tokens:", query_tokens)
print("Positive document tokens:", pos_doc_tokens)
print("Negative document tokens:", neg_doc_tokens)

In [None]:
# Validation counterpart of the training dataset/dataloader: same tokenizer and
# collate function, but iterated in a fixed order (shuffle=False).
validation_dataset = MSMarcoDataset(validation_queries, validation_pos_docs, validation_neg_docs, sp_model)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

class EncoderRNN(nn.Module):
    """GRU encoder that summarizes a batch-first sequence as one vector per row."""

    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)

    def forward(self, input, lengths=None):
        """Encode ``input`` of shape (batch, seq_len, input_size).

        Args:
            input: Embedded sequences, batch first.
            lengths: Optional per-row count of real (non-padding) timesteps, as
                a tensor or sequence of ints. When omitted, the output at the
                final timestep is returned; for right-padded batches that
                state has also consumed the padding steps.

        Returns:
            Tensor of shape (batch, hidden_size).
        """
        output, _ = self.gru(input)
        if lengths is None:
            return output[:, -1, :]
        # Pick, for every row, the GRU output at its last real timestep.
        lengths = torch.as_tensor(lengths, dtype=torch.long, device=output.device)
        last_step = (lengths - 1).clamp(min=0)
        batch_index = torch.arange(output.size(0), device=output.device)
        return output[batch_index, last_step]


class TwoTowerModel(nn.Module):
    """Two-tower encoder: a shared embedding layer feeding separate query and
    document GRU encoders."""

    def __init__(self, embedding_layer, hidden_size):
        super(TwoTowerModel, self).__init__()
        self.embedding = embedding_layer
        self.query_encoder = EncoderRNN(embedding_layer.embedding_dim, hidden_size)
        self.doc_encoder = EncoderRNN(embedding_layer.embedding_dim, hidden_size)

    def forward(self, query, doc, query_lengths=None, doc_lengths=None):
        """Encode a batch of queries and a batch of documents.

        Args:
            query: Token-id tensor of shape (batch, query_len).
            doc: Token-id tensor of shape (batch, doc_len).
            query_lengths: Optional unpadded lengths of the queries.
            doc_lengths: Optional unpadded lengths of the documents.

        Returns:
            Tuple ``(query_enc, doc_enc)``, each of shape (batch, hidden_size).
        """
        query_enc = self.query_encoder(self.embedding(query), query_lengths)
        doc_enc = self.doc_encoder(self.embedding(doc), doc_lengths)
        return query_enc, doc_enc

# Width of the GRU hidden state, i.e. the size of each tower's output vector.
hidden_size = 128
# Both towers read from `embedding_layer`; each tower owns its own GRU encoder.
model = TwoTowerModel(embedding_layer, hidden_size)

def distance_function(query_enc, doc_enc):
    """Return the row-wise cosine similarity between two batches of encodings.

    Despite the name, the result is a similarity (larger means closer), not a
    distance.
    """
    similarity = nn.functional.cosine_similarity(query_enc, doc_enc, dim=1)
    return similarity

def triplet_loss(query_enc, pos_doc_enc, neg_doc_enc, margin=1.0):
    """Mean hinge loss over a batch of (query, positive, negative) encodings.

    Scores come from ``distance_function``. A triple contributes zero loss once
    the positive score exceeds the negative score by at least ``margin``.
    """
    pos_score = distance_function(query_enc, pos_doc_enc)
    neg_score = distance_function(query_enc, neg_doc_enc)
    per_triple = nn.functional.relu(margin - pos_score + neg_score)
    return per_triple.mean()

optimizer = optim.Adam(model.parameters())
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


def _triplet_batch_loss(batch):
    """Move one (query, pos_doc, neg_doc) batch to `device` and return its triplet loss.

    The query is encoded once. The negative document goes through the document
    tower directly instead of a second full `model(...)` call whose query
    encoding would be thrown away.
    """
    query, pos_doc, neg_doc = (tensor.to(device) for tensor in batch)
    query_enc, pos_doc_enc = model(query, pos_doc)
    neg_doc_enc = model.doc_encoder(model.embedding(neg_doc))
    return triplet_loss(query_enc, pos_doc_enc, neg_doc_enc)


for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    total_train_loss = 0

    for batch in train_dataloader:
        optimizer.zero_grad()
        loss = _triplet_batch_loss(batch)
        total_train_loss += loss.item()
        loss.backward()
        optimizer.step()

    avg_train_loss = total_train_loss / len(train_dataloader)

    # ---- validation pass (no gradient tracking, no parameter updates) ----
    model.eval()
    total_val_loss = 0

    with torch.no_grad():
        for batch in validation_dataloader:
            total_val_loss += _triplet_batch_loss(batch).item()

    avg_val_loss = total_val_loss / len(validation_dataloader)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")