In [2]:
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json

# Load mappings
def load_mappings(vocab_file_path="vocabulary_86.txt"):
    """Read the vocabulary file and build word <-> index lookup tables.

    Line ``i`` of the file (whitespace-stripped) gets index ``i``; that index is
    used as the embedding-row index elsewhere in this file. Blank lines are kept
    as ``""`` so the numbering of later lines is not shifted.

    Args:
        vocab_file_path: path to a UTF-8 text file with one token per line.

    Returns:
        ``(word_index_mapping, index_word_mapping)``. If a token occurs more than
        once, ``word_index_mapping`` keeps its last index, while
        ``index_word_mapping`` still has an entry for every line.
    """
    with open(vocab_file_path, "r", encoding="utf-8") as f:
        vocab = [line.strip() for line in f]

    word_index_mapping = {word: i for i, word in enumerate(vocab)}
    # Build the inverse straight from the list: deriving it from
    # word_index_mapping would drop every index whose token reappears later,
    # and those indices can still come up when ranking embedding rows.
    index_word_mapping = dict(enumerate(vocab))

    return word_index_mapping, index_word_mapping

# Model definition (the same architecture must be used to load the checkpoint)
class Word2VecModel(torch.nn.Module):
    """Word2vec-style model: average context embeddings, apply dropout, project to vocab logits.

    ``network[0]`` is the embedding table; its rows are the word vectors that
    :meth:`get_triplets` compares.
    """

    def __init__(self, vocab_size, embedding_dim, dropout_rate):
        super().__init__()
        # Kept as a Sequential so parameter names (and therefore state_dict
        # keys such as "network.0.weight") match saved checkpoints.
        self.network = torch.nn.Sequential(
            torch.nn.Embedding(vocab_size, embedding_dim),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(embedding_dim, vocab_size)
        )

    def forward(self, context):
        """Return vocabulary-sized logits for a batch of context windows.

        ``context`` holds token indices. Embeddings are averaged over dim 1, so it
        is presumably shaped (batch, context_len) -- TODO confirm with the caller.
        """
        embedded = self.network[0](context).mean(dim=1)
        aggregated = self.network[1](embedded)
        out = self.network[2](aggregated)
        return out

    @staticmethod
    def _rank_neighbours(similarities, index, top_k=3):
        """Return ``(top_k most similar indices, least similar index)``, never including *index*."""
        order = np.argsort(similarities)  # ascending: least similar first
        others = order[order != index]
        similar = [int(i) for i in others[::-1][:top_k]]
        # With a single-entry vocabulary there is no other word to report.
        dissimilar = int(others[0]) if others.size else index
        return similar, dissimilar

    def get_triplets(self, word_to_index=None, index_to_word=None, num_words=None, top_k=3):
        """Print and return the nearest and farthest neighbours of vocabulary words.

        Each word is processed in mapping order, optionally limited to the first
        *num_words*. Every other vocabulary entry is ranked by cosine similarity
        of the embedding rows. The *top_k* most similar words and the single
        least similar word are reported. Similarities are printed unrounded.
        The full V x V similarity matrix is built, so memory grows
        quadratically with vocabulary size.

        Args:
            word_to_index: word -> embedding-row mapping. Defaults to the
                module-level ``word_index_mapping`` (created by ``__main__``).
            index_to_word: inverse mapping. Defaults to the module-level
                ``index_word_mapping``.
            num_words: process only the first ``num_words`` words; ``None`` = all.
            top_k: how many similar words to report per word.

        Returns:
            A list of ``[word, [[similar_word, similarity], ...], (dissimilar_word, similarity)]``.
        """
        # Fall back to the globals the original implementation always read.
        if word_to_index is None:
            word_to_index = word_index_mapping
        if index_to_word is None:
            index_to_word = index_word_mapping

        embeddings = self.network[0].weight.data.cpu().numpy()

        #Reference -> https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.cosine_similarity.html
        cos_similarity_mat = cosine_similarity(embeddings)

        triplets = []
        for word, index in list(word_to_index.items())[:num_words]:
            row = cos_similarity_mat[index]
            similar_indices, dissimilar_index = self._rank_neighbours(row, index, top_k)
            similar = [[index_to_word[i], row[i]] for i in similar_indices]
            dissimilar = (index_to_word[dissimilar_index], row[dissimilar_index])
            triplets.append([word, similar, dissimilar])

        for word, similar, dissimilar in triplets:
            print(word, "\n")
            print("similar words:")
            for similar_word, score in similar:
                print("word: ", similar_word, " ", "with similarity: ", score)
            print("Dissimilar:", dissimilar[0], dissimilar[1], "\n")

        return triplets


# Load model from .pth file
def load_model(filepath, vocab_size, embedding_dim, dropout_rate, map_location="cpu"):
    """Build a Word2VecModel and load a saved state_dict into it.

    Args:
        filepath: path to a checkpoint produced by ``torch.save(model.state_dict(), ...)``.
        vocab_size, embedding_dim: must match the checkpoint's parameter shapes,
            otherwise ``load_state_dict`` raises.
        dropout_rate: dropout probability (inactive after ``eval()``).
        map_location: device to deserialize tensors onto. The default ``"cpu"``
            lets GPU-saved checkpoints load on CPU-only machines. The model
            itself stays on CPU either way, because ``load_state_dict`` copies
            the values into its existing parameters.

    Returns:
        The model in eval mode.
    """
    model = Word2VecModel(vocab_size, embedding_dim, dropout_rate)
    # weights_only=True restricts unpickling to tensors/containers, which is all
    # a state_dict needs. It also avoids torch's FutureWarning about the
    # upcoming default change.
    state_dict = torch.load(filepath, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Generate triplets for a word
def get_triplet_for_word(model, word, word_index_mapping, index_word_mapping, top_k=3):
    """Print (and return) the nearest and farthest neighbours of one word.

    Args:
        model: a model whose ``network[0]`` is the embedding table.
        word: the query token.
        word_index_mapping: word -> embedding-row mapping.
        index_word_mapping: embedding-row -> word mapping.
        top_k: how many most-similar words to report.

    Returns:
        ``(similar, dissimilar)`` where ``similar`` is ``[[word, similarity], ...]``
        and ``dissimilar`` is ``[word, similarity]``. If *word* is not in the
        vocabulary, a message is printed and ``None`` is returned.
    """
    if word not in word_index_mapping:
        print(f"{word} not found in vocabulary")
        return None

    index_of_word = word_index_mapping[word]
    embeddings = model.network[0].weight.data.cpu().numpy()
    # Only the query's row is needed, so compare that single vector against all
    # embeddings instead of building the full V x V similarity matrix.
    similarities = cosine_similarity(embeddings[index_of_word:index_of_word + 1], embeddings)[0]

    order = np.argsort(similarities)  # ascending: least similar first
    others = order[order != index_of_word]  # never report the word as its own neighbour
    similar = [[index_word_mapping[int(i)], similarities[int(i)]] for i in others[::-1][:top_k]]
    # With a single-entry vocabulary there is no other word; fall back to the word itself.
    dissimilar_index = int(others[0]) if others.size else index_of_word
    dissimilar = [index_word_mapping[dissimilar_index], similarities[dissimilar_index]]

    print(f"{word}\nSimilar words:")
    for similar_word, score in similar:
        print(f"{similar_word} with similarity {score:.4f}")
    print(f"Dissimilar: {dissimilar[0]} with similarity {dissimilar[1]:.4f}\n")
    return similar, dissimilar

if __name__ == "__main__":
    # Checkpoint location and architecture settings. vocab_size and
    # embedding_dim must match the shapes stored in the checkpoint.
    model_path = "word2vec_checkpoint.pth"
    vocab_size, embedding_dim, dropout_rate = 14000, 600, 0.3

    # get_triplets() reads these two module-level names, so keep them global.
    word_index_mapping, index_word_mapping = load_mappings()
    model = load_model(model_path, vocab_size, embedding_dim, dropout_rate)

    model.get_triplets()


  model.load_state_dict(torch.load(filepath))


[PAD] 

similar words:
word:  earning   with similarity:  0.7073668
word:  thinner   with similarity:  0.65173614
word:  pure   with similarity:  0.6193519
Dissimilar: walmart -0.43760082 

[UNK] 

similar words:
word:  unstable   with similarity:  0.15642601
word:  midt   with similarity:  0.14526747
word:  ji   with similarity:  0.14168407
Dissimilar: ##mporari -0.17963327 

##a 

similar words:
word:  jump   with similarity:  0.7806083
word:  kill   with similarity:  0.7777572
word:  talent   with similarity:  0.746179
Dissimilar: east -0.5089629 

##b 

similar words:
word:  walma   with similarity:  0.14685746
word:  env   with similarity:  0.14439406
word:  mir   with similarity:  0.14374566
Dissimilar: challe -0.15020862 

##c 

similar words:
word:  ##frig   with similarity:  0.14521769
word:  woul   with similarity:  0.1388793
word:  ##rch   with similarity:  0.13325688
Dissimilar: ##tubborn -0.15085664 

##d 

similar words:
word:  headphones   with similarity:  0.61998135
wo