In [1]:
from pathlib import Path

from gensim.utils import simple_preprocess

from data.marco import get_sentences
from model.word2vec import get_model, train_model

# Build the training corpus: fetch the raw sentences and tokenize each one with
# gensim's simple_preprocess, producing one token list per sentence.
raw_sentences = get_sentences()
tokenized_sentences = [simple_preprocess(sentence) for sentence in raw_sentences]

# Construct the word2vec model from the tokenized corpus and train it on the
# same corpus. NOTE(review): the hyperparameters (vector size, skip-gram vs.
# CBOW, epochs) live in model.word2vec and are not visible here; the file name
# below suggests 300-dimensional vectors -- confirm.
skipgram_model = get_model(tokenized_sentences)
train_model(skipgram_model, tokenized_sentences)

# Make sure the output directory exists before saving. Without this, a fresh
# checkout with no ./artifacts directory fails at save() only after the
# expensive training run has already finished.
model_path = "./artifacts/word2vec-300.bin"
Path(model_path).parent.mkdir(parents=True, exist_ok=True)
skipgram_model.save(model_path)

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
from gensim.models import Word2Vec

# Reload the persisted word2vec model from disk so the remaining cells can use
# it without retraining.
model_path = "./artifacts/word2vec-300.bin"
model = Word2Vec.load(model_path)

# Sanity check: the bare last expression makes the notebook display the
# embedding vector stored for the word "what".
model.wv.get_vector('what')

array([ 0.10019273, -0.33374676, -0.02642345,  0.41534606,  0.14714254,
       -0.06856732,  0.43219516, -0.10808956,  0.19691965,  0.07631785,
        0.10408203,  0.07161479,  0.07158351, -0.17790368,  0.21030354,
        0.13951635,  0.36059147, -0.30308104,  0.10031758, -0.07560526,
       -0.0139124 , -0.12335718,  0.31065795, -0.09804939,  0.10784882,
        0.14239235, -0.0718799 , -0.18125203,  0.07639895, -0.10468737,
        0.20515351,  0.19570537, -0.24486354,  0.02445136,  0.00948732,
       -0.13417399,  0.21692474,  0.08664848,  0.20024656,  0.06315239,
        0.00064163,  0.02665208,  0.01468457, -0.25618294, -0.24128504,
        0.13493638,  0.0933916 ,  0.16827904, -0.20723066, -0.2389881 ,
        0.08144969,  0.01782364,  0.18259524,  0.11810031, -0.30440283,
        0.16046287,  0.04261621,  0.04558365,  0.22207616,  0.30472726,
       -0.14382133, -0.281999  , -0.01817146, -0.33375046,  0.13382258,
       -0.27748376, -0.06398495, -0.21484396,  0.11975351, -0.34

In [2]:
from torch.nn.utils.rnn import pad_sequence

def custom_collate(batch):
    """Collate (query, relevant_docs, irrelevant_docs) samples into padded batches.

    Args:
        batch: sequence of 3-tuples, one per sample. The first element is the
            query embedding tensor; the second and third are iterables of
            tensors for the relevant and irrelevant documents.
            NOTE(review): assumes every tensor is already an embedding tensor
            (e.g. tokens x embedding_dim) -- confirm against QueryDocumentDataset.

    Returns:
        A tuple ``(query_emb_tensor, padded_relevant, padded_irrelevant)`` of
        zero-padded tensors whose first dimension indexes the samples in the batch.
    """

    def _pad_nested(doc_groups):
        # First pad the tensors inside each sample to a common length, then pad
        # the per-sample results against each other along their first dimension.
        per_sample = []
        for group in doc_groups:
            per_sample.append(pad_sequence(group, batch_first=True))
        return pad_sequence(per_sample, batch_first=True)

    # Transpose the list of samples into three parallel tuples.
    queries, relevant_groups, irrelevant_groups = zip(*batch)

    relevant_batch = _pad_nested(relevant_groups)
    irrelevant_batch = _pad_nested(irrelevant_groups)

    # Queries need only a single level of padding.
    query_batch = pad_sequence(queries, batch_first=True)

    return query_batch, relevant_batch, irrelevant_batch

In [8]:
from model.twotowermodel import DocumentTower, QueryTower
from data.dataset import QueryDocumentDataset
from data.marco import get_training_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils.loss_function import triplet_loss_function
import torch

# Hyperparameters and bookkeeping settings, gathered in one place for tuning.
EMBEDDING_DIM = 300
HIDDEN_DIM = 128
LEARNING_RATE = 0.001
BATCH_SIZE = 64
TRIPLET_MARGIN = 0.3
LOG_EVERY = 10            # print the loss every N batches
CHECKPOINT_EVERY = 1000   # write both state dicts every N batches
DOCUMENT_CHECKPOINT_PATH = 'document_model_state_dict.pth'
QUERY_CHECKPOINT_PATH = 'query_model_state_dict.pth'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device', device)

# Initialize both towers on the selected device and optimize them jointly.
document_model = DocumentTower(embedding_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM).to(device)
query_model = QueryTower(embedding_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM).to(device)

optimizer = Adam(
    list(document_model.parameters()) + list(query_model.parameters()),
    lr=LEARNING_RATE,
)

# `model` is the word2vec model loaded in an earlier cell; the dataset receives
# it as its embedding model.
dataset_instance = QueryDocumentDataset(data=get_training_dataset(), embedding_model=model)
# NOTE(review): shuffle=False iterates the data in dataset order on every run;
# consider shuffle=True for training unless the ordering is deliberate.
dataloader = DataLoader(
    dataset_instance,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=custom_collate,
)


def _save_checkpoints():
    """Write the current state dicts of both towers to their checkpoint files."""
    torch.save(document_model.state_dict(), DOCUMENT_CHECKPOINT_PATH)
    torch.save(query_model.state_dict(), QUERY_CHECKPOINT_PATH)


counter = 0
print('starting training')

# Single pass over the dataloader (one epoch).
for query_emb, relevant_doc_emb, irrelevant_doc_emb in dataloader:
    # Move the already-embedded, padded batch to the training device.
    query_emb = query_emb.to(device)
    relevant_doc_emb = relevant_doc_emb.to(device)
    irrelevant_doc_emb = irrelevant_doc_emb.to(device)

    # Forward pass: the document tower encodes the relevant and irrelevant
    # documents in one call; the query tower encodes the queries.
    relevant_doc_encoding, irrelevant_doc_encoding = document_model(relevant_doc_emb, irrelevant_doc_emb)
    query_encoding = query_model(query_emb)

    # Compute triplet loss
    loss = triplet_loss_function(
        query_encoding,
        relevant_doc_encoding,
        irrelevant_doc_encoding,
        margin=TRIPLET_MARGIN,
    )

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    counter += 1

    if counter % LOG_EVERY == 0:
        print(f"Triplet loss: {loss.item()}")

    if counter % CHECKPOINT_EVERY == 0:
        _save_checkpoints()

# Persist the final weights. Without this, every update made after the last
# multiple of CHECKPOINT_EVERY is lost, and a run shorter than CHECKPOINT_EVERY
# batches saves nothing. Skipped when the last batch already triggered a save,
# or when no batch was processed at all.
if counter % CHECKPOINT_EVERY != 0:
    _save_checkpoints()


device cuda
starting training
Triplet loss: 0.30319374799728394
Triplet loss: 0.29807257652282715
Triplet loss: 0.29265469312667847
Triplet loss: 0.3056081235408783
Triplet loss: 0.2967092990875244
Triplet loss: 0.3004317283630371
Triplet loss: 0.29340994358062744
Triplet loss: 0.30496418476104736
Triplet loss: 0.3018951714038849
Triplet loss: 0.3007829487323761
Triplet loss: 0.2992069125175476
Triplet loss: 0.2943861484527588
Triplet loss: 0.2912207245826721
Triplet loss: 0.29734379053115845
Triplet loss: 0.29820194840431213
Triplet loss: 0.29773378372192383
Triplet loss: 0.3020175099372864
Triplet loss: 0.28082746267318726
Triplet loss: 0.3164953291416168
Triplet loss: 0.3053825795650482
Triplet loss: 0.28426867723464966
Triplet loss: 0.3031502962112427
Triplet loss: 0.3004280924797058
Triplet loss: 0.3054913282394409
Triplet loss: 0.29094356298446655
Triplet loss: 0.2966930866241455
Triplet loss: 0.29426509141921997
Triplet loss: 0.29119622707366943
Triplet loss: 0.30320054292678833