In [6]:
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json

def load_mappings(vocab_file_path="../Data/vocabulary_86.txt"):
    """Build the word <-> index lookup tables from a vocabulary file.

    The file is read as UTF-8 with one token per line. Surrounding whitespace
    is stripped from each line and a token's index is its zero-based line
    number.

    Args:
        vocab_file_path: Path of the vocabulary file. Defaults to the path
            this function previously had hard-coded.

    Returns:
        A ``(word_index_mapping, index_word_mapping)`` tuple. The first maps
        each token to its index; the second is the inverse of the first.
    """
    with open(vocab_file_path, "r", encoding="utf-8") as f:
        vocab = [line.strip() for line in f]

    word_index_mapping = {word: index for index, word in enumerate(vocab)}
    # Inverted from the forward table (not from the raw list) so the two stay
    # exact inverses even if a token is repeated in the file.
    index_word_mapping = {
        index: word for word, index in word_index_mapping.items()
    }

    return word_index_mapping, index_word_mapping

class Word2VecModel(torch.nn.Module):
    """Embedding table followed by a linear projection onto the vocabulary.

    ``forward`` averages the embeddings of the context tokens and projects the
    average to one score per vocabulary entry.
    """

    # Number of nearest neighbours reported per word by ``get_triplets``.
    _NUM_SIMILAR = 3

    def __init__(self, vocab_size, embedding_dim, dropout_rate):
        """Create the embedding and output layers.

        Args:
            vocab_size: Number of rows in the embedding table and number of
                output scores.
            embedding_dim: Width of each embedding vector.
            dropout_rate: Accepted for interface compatibility; no layer in
                this model uses it.
        """
        super().__init__()
        # Layer positions are part of the checkpoint format: state_dict keys
        # are "network.0.*" and "network.1.*".
        self.network = torch.nn.Sequential(
            torch.nn.Embedding(vocab_size, embedding_dim),
            torch.nn.Linear(embedding_dim, vocab_size)
        )

    def forward(self, context):
        """Return vocabulary scores for the mean embedding of ``context``.

        The mean is taken over dim 1, so ``context`` is presumably a
        (batch, context_length) tensor of token indices -- confirm against
        the training code.
        """
        embedded = self.network[0](context).mean(dim=1)
        out = self.network[1](embedded)
        return out

    @staticmethod
    def _module_global(name):
        """Return the module-level variable ``name`` (legacy mapping source)."""
        try:
            return globals()[name]
        except KeyError:
            raise NameError(f"name '{name}' is not defined") from None

    def get_triplets(self, word_index_mapping=None, index_word_mapping=None):
        """Print the nearest and farthest neighbours of every vocabulary word.

        For each word in ``word_index_mapping`` this prints the three most
        cosine-similar other words and the single least similar word, each
        with its raw (unrounded) similarity score.

        Args:
            word_index_mapping: Token -> embedding-row index. When omitted,
                the module-level ``word_index_mapping`` is used, which is
                what this method always relied on.
            index_word_mapping: Embedding-row index -> token. When omitted,
                the module-level ``index_word_mapping`` is used.

        Raises:
            NameError: If a mapping is omitted and no module-level variable
                of that name exists.
        """
        if word_index_mapping is None:
            word_index_mapping = self._module_global("word_index_mapping")
        if index_word_mapping is None:
            index_word_mapping = self._module_global("index_word_mapping")

        embeddings = self.network[0].weight.detach().cpu().numpy()

        #Reference -> https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.cosine_similarity.html
        cos_similarity_mat = cosine_similarity(embeddings)

        triplets = []
        for word, index in word_index_mapping.items():
            row = cos_similarity_mat[index]
            # Sort once: descending order puts the best matches first and the
            # least similar entry last.
            descending = np.argsort(row)[::-1]

            # The word itself can occupy at most one of the leading slots, so
            # one extra candidate is enough to keep three other words.
            candidates = descending[: self._NUM_SIMILAR + 1]
            similar = [
                [index_word_mapping[i], row[i]]
                for i in candidates
                if i != index
            ][: self._NUM_SIMILAR]

            dissimilar_index = descending[-1]
            dissimilar = (
                index_word_mapping[dissimilar_index],
                row[dissimilar_index],
            )

            triplets.append([word, similar, dissimilar])

        for word, similar, dissimilar in triplets:
            print(word, "\n")
            print("similar words:")
            for neighbour, score in similar:
                print("word: ", neighbour, " ", "with similarity: ", score)
            print("Dissimilar:", dissimilar[0], dissimilar[1], "\n")


def load_model(filepath, vocab_size, embedding_dim, dropout_rate):
    """Rebuild a ``Word2VecModel`` and restore its weights from a checkpoint.

    Args:
        filepath: Path of a checkpoint holding the model's ``state_dict``.
        vocab_size: Vocabulary size the checkpoint was trained with.
        embedding_dim: Embedding width the checkpoint was trained with.
        dropout_rate: Forwarded to the ``Word2VecModel`` constructor.

    Returns:
        The model with the loaded weights, switched to evaluation mode.
    """
    model = Word2VecModel(vocab_size, embedding_dim, dropout_rate)
    # map_location: lets a checkpoint saved on a GPU load on a CPU-only host.
    # weights_only: the file is fed straight to load_state_dict, so restrict
    # unpickling to tensors and plain containers instead of arbitrary objects.
    state_dict = torch.load(filepath, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

def get_triplet_for_word(model, word, word_index_mapping, index_word_mapping):
    """Print the three most similar words and the least similar word.

    Similarity is the cosine similarity between rows of the model's embedding
    table. Scores are printed unrounded.

    Args:
        model: Object whose ``network[0].weight`` is the embedding table.
        word: Token to look up.
        word_index_mapping: Token -> embedding-row index.
        index_word_mapping: Embedding-row index -> token.

    Returns:
        None. If ``word`` is not in ``word_index_mapping`` a message is
        printed and nothing else happens.
    """
    if word not in word_index_mapping:
        print(f"{word} not found in vocabulary")
        return

    top_k = 3
    index_of_word = word_index_mapping[word]
    embeddings = model.network[0].weight.detach().cpu().numpy()

    # Only this word's similarities are needed, so compare its single row
    # against the table instead of building the full vocab x vocab matrix.
    query = embeddings[index_of_word:index_of_word + 1]
    similarities = cosine_similarity(query, embeddings)[0]

    # Sort once: best matches first, least similar entry last.
    descending = np.argsort(similarities)[::-1]

    # The word itself can take at most one of the leading slots, so one extra
    # candidate is enough to keep top_k other words.
    similar = [
        [index_word_mapping[i], similarities[i]]
        for i in descending[:top_k + 1]
        if i != index_of_word
    ][:top_k]

    dissimilar_index = descending[-1]
    dissimilar = [
        index_word_mapping[dissimilar_index],
        similarities[dissimilar_index],
    ]

    print("Similar words:")
    for neighbour, score in similar:
        print(f"{neighbour} with similarity {score}")
    print(f"Dissimilar: {dissimilar[0]} with similarity {dissimilar[1]}\n")

if __name__ == "__main__":
    # Checkpoint location and the hyper-parameters handed to load_model.
    model_path = "word2vec_checkpoint.pth"
    vocab_size, embedding_dim, dropout_rate = 14000, 300, 0

    # These two names stay at module level: Word2VecModel.get_triplets
    # reads them from there.
    word_index_mapping, index_word_mapping = load_mappings()

    model = load_model(
        filepath=model_path,
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        dropout_rate=dropout_rate,
    )

    # Print the neighbour report for every word in the vocabulary.
    model.get_triplets()


  model.load_state_dict(torch.load(filepath))


[PAD] 

similar words:
word:  earning   with similarity:  0.6760272
word:  pure   with similarity:  0.6417333
word:  thinner   with similarity:  0.63405806
Dissimilar: lamb -0.45545483 

[UNK] 

similar words:
word:  ##aday   with similarity:  0.21323165
word:  victi   with similarity:  0.19443172
word:  riv   with similarity:  0.18993014
Dissimilar: ##rywh -0.2097566 

##a 

similar words:
word:  transport   with similarity:  0.79665136
word:  darren   with similarity:  0.77197444
word:  burn   with similarity:  0.7147912
Dissimilar: expression -0.45093632 

##b 

similar words:
word:  ##gardl   with similarity:  0.25526077
word:  fri   with similarity:  0.23432463
word:  brown   with similarity:  0.1972457
Dissimilar: khali -0.22845101 

##c 

similar words:
word:  press   with similarity:  0.22496365
word:  ##caus   with similarity:  0.22254768
word:  ##aidhai   with similarity:  0.21799582
Dissimilar: wants -0.22612415 

##d 

similar words:
word:  argument   with similarity:  0.64