<a href="https://colab.research.google.com/github/sibat119/papers-review-code-impl/blob/main/word2vec/word2vec_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [37]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.functional as F
import torch.nn.functional as F
import os
import json
import re
from collections import Counter
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

## **Data Preprocessing**

In [38]:
from torch.utils.data import Dataset
embedding = {}
class Word2vecDataset(Dataset):
    NEG_SIZE = 1e8
    def __init__(self, datapath, window_size):
        self.window_size = window_size
        self.sentences_count = 0
        self.negpos = 0
        self.negatives = []
        self.discards = []
        self.word2id = dict()
        self.id2word = dict()
        self.word_frequency = dict()
        self.input_file = open(datapath, encoding="utf8")


        self.corpus = self.get_corpus(datapath)
        self.corpus = self.corpus[:2560000]
        self.words = self.get_words(self.corpus)
        self.word_frequency = self.create_lookup_tables(self.words)
        self.neg_sampling()
        self.sub_sampling()



    def __len__(self):
        return len(self.words)
        # return 256000

    def __getitem__(self, idx):
        # for i in range(len(self.words)):
        tuple_list = []
        context_words = self.get_context_words(idx)
        # print(context_words)

        for word in context_words:
            if idx > len(self.words):
                print(idx, len(self.words))
            assert idx < len(self.words) 
            
            return ((self.word2id[self.words[idx]], self.word2id[word], self.getNegatives(1)[0]))
        # return tuple_list


    # staticmethod
    # def collate(batches):
    #     all_u = [u for batch in batches for u, _, _ in batch if len(batch) > 0]
    #     all_v = [v for batch in batches for _, v, _ in batch if len(batch) > 0]
    #     all_neg_v = [neg_v for batch in batches for _, _, neg_v in batch if len(batch) > 0]

    #     return torch.LongTensor(all_u), torch.LongTensor(all_v), torch.LongTensor(all_neg_v)
    
    def get_corpus(self, datapath):
        corpus = ''
        with open(datapath) as input_file:
            data_file = json.load(input_file)
            for x in data_file:
                self.sentences_count += 1
                corpus += (x["text"].lower().strip())
        print(len(corpus.strip()))
        return corpus
  
    def get_words(self, text):
        # Replace punctuation with tokens so we can use them in our model
        text = text.lower()
        text = text.replace('.', ' <PERIOD> ')
        text = text.replace(',', ' <COMMA> ')
        text = text.replace('"', ' <QUOTATION_MARK> ')
        text = text.replace(';', ' <SEMICOLON> ')
        text = text.replace('!', ' <EXCLAMATION_MARK> ')
        text = text.replace('?', ' <QUESTION_MARK> ')
        text = text.replace('(', ' <LEFT_PAREN> ')
        text = text.replace(')', ' <RIGHT_PAREN> ')
        text = text.replace('--', ' <HYPHENS> ')
        text = text.replace('?', ' <QUESTION_MARK> ')
        # text = text.replace('\n', ' <NEW_LINE> ')
        text = text.replace(':', ' <COLON> ')
        words = text.split()
        
        # Remove all words with  5 or fewer occurences
        word_counts = Counter(words)
        trimmed_words = [word for word in words if word_counts[word] > 2]
        print('words -> ', len(trimmed_words), 'corpus->', len(self.corpus))
        return trimmed_words

    def create_lookup_tables(self, words):
        word_counts = Counter(words)
        # sorting the words from most to least frequent in text occurrence
        sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
        # create int_to_vocab dictionaries
        int_to_vocab = {ii: word for ii, word in enumerate(sorted_vocab)}
        vocab_to_int = {word: ii for ii, word in int_to_vocab.items()}
        print(word_counts)
        
        self.word2id = vocab_to_int
        self.id2word = int_to_vocab
        print(self.word2id)

        return word_counts
    
    def get_context_words(self, idx):
        ''' Get a list of words in a window around an index. '''
        
        # R = np.random.randint(1, window_size+1)
        start = idx - self.window_size if (idx - self.window_size) > 0 else 0
        stop = idx + self.window_size
        target_words = self.words[start:idx] + self.words[idx+1:stop+1]
        
        return list(target_words)
        
    def neg_sampling(self):
        pow_frequency = np.array(list(self.word_frequency.values())) ** 0.5
        words_pow = sum(pow_frequency)
        ratio = pow_frequency / words_pow
        count = np.round(ratio * Word2vecDataset.NEG_SIZE)
        for wid, c in enumerate(count):
            self.negatives += [wid] * int(c)
        self.negatives = np.array(self.negatives)
        np.random.shuffle(self.negatives)

    def getNegatives(self, size):  # TODO check equality with target
        response = self.negatives[self.negpos:self.negpos + size]
        self.negpos = (self.negpos + size) % len(self.negatives)
        if len(response) != size:
            return np.concatenate((response, self.negatives[0:self.negpos]))
        return response

    def sub_sampling(self):
        t = 0.0001
        f = np.array(list(self.word_frequency.values())) / len(self.corpus)
        self.discards = np.sqrt(t / f) + (t / f)
        print(self.discards)

## **Define Model**

In [39]:
class SkipGramModel(nn.Module):

    def __init__(self, emb_size, emb_dimension):
        super(SkipGramModel, self).__init__()
        self.emb_size = emb_size
        self.emb_dimension = emb_dimension
        self.word_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)
        self.context_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)

        initrange = 1.0 / self.emb_dimension
        init.uniform_(self.word_embeddings.weight.data, -initrange, initrange)
        init.constant_(self.context_embeddings.weight.data, 0)

    def forward(self, pos_u, pos_v, neg_v):
        emb_u = self.word_embeddings(pos_u)
        emb_v = self.context_embeddings(pos_v)
        emb_neg_v = self.context_embeddings(neg_v)

        score = torch.sum(torch.mul(emb_u, emb_v), dim=0)
        score = torch.clamp(score, max=10, min=-10)
        score = -F.logsigmoid(score)

        neg_score = torch.sum(torch.mul(emb_u, emb_neg_v), dim=0)#emb_neg_v, emb_u.unsqueeze(1)).squeeze()
        neg_score = torch.clamp(neg_score, max=10, min=-10)
        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=0)

        return torch.mean(score + neg_score)

    def save_embedding(self, id2word, file_name):
        embedding = self.word_embeddings.weight.cpu().data.numpy()
        with open(file_name, 'w') as f:
            f.write('%d %d\n' % (len(id2word), self.emb_dimension))
            for wid, w in id2word.items():
                e = ' '.join(map(lambda x: str(x), embedding[wid]))
                f.write('%s %s\n' % (w, e))

## **Training Loop**

In [40]:
def train(skip_gram_model, dataloader, initial_lr=1e-5, iterations=3, device='cpu', id2word=None, output_file_name='embedding'):
  for iteration in range(iterations):

      print("\n\n\nIteration: " + str(iteration + 1))
      optimizer = optim.SparseAdam(skip_gram_model.parameters(), lr=initial_lr)
      scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(dataloader))

      running_loss = 0.0
      skip_gram_model.to(device)
      for i, sample_batched in enumerate(tqdm(dataloader)):
          # print(len(sample_batched), sample_batched)
          sample_batched = torch.tensor(sample_batched)

          for item in sample_batched:
              pos_u = item[0].to(device)
              pos_v = item[1].to(device)
              neg_v = item[2].to(device)

              scheduler.step()
              optimizer.zero_grad()
              loss = skip_gram_model.forward(pos_u, pos_v, neg_v)
              loss.backward()
              optimizer.step()

              running_loss = running_loss * 0.9 + loss.item() * 0.1
              if i > 0 and i % 500 == 0:
                  print(" Loss: " + str(running_loss))
      skip_gram_model.save_embedding(id2word, output_file_name)

## **Train Model**

In [41]:
dataset = Word2vecDataset('/content/wiki_clean.json', 2)

33536661
words ->  359969 corpus-> 2560000
[25.95213549  0.32759907  0.23147777 ... 72.         49.19863931
 33.77777778]


In [42]:
dataloader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=0, collate_fn=lambda x: x)
vocab_size = len(dataset.word2id)
print('vocabulary size:', vocab_size)
skip_gram_model = SkipGramModel(vocab_size, 50)
print(skip_gram_model)
print(dataset.corpus[:100])

vocabulary size: 14127
SkipGramModel(
  (word_embeddings): Embedding(14127, 50, sparse=True)
  (context_embeddings): Embedding(14127, 50, sparse=True)
)
m137 was a state trunkline highway in the us state of michigan that served as a spur route to the in


In [43]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [44]:
train(skip_gram_model=skip_gram_model, dataloader=dataloader, device=device, id2word=dataset.id2word)




Iteration: 1


 36%|███▌      | 500/1407 [08:02<14:38,  1.03it/s]

 Loss: 1.3862923565643839
 Loss: 1.3862923189822864
 Loss: 1.3862920586607486
 Loss: 1.3862928853340415
 Loss: 1.3862914597309353
 Loss: 1.3862934549436023
 Loss: 1.3862948095606313
 Loss: 1.3862965770786893
 Loss: 1.3862985969983839
 Loss: 1.386299365884361
 Loss: 1.3862963385519065
 Loss: 1.3862969637337337
 Loss: 1.3862966919323514
 Loss: 1.3862968287808337
 Loss: 1.386296391660807
 Loss: 1.3862974406851867
 Loss: 1.3862963224864189
 Loss: 1.3862959479167627
 Loss: 1.386296409506312
 Loss: 1.3862960858393112
 Loss: 1.3862952461762785
 Loss: 1.3862957540980183
 Loss: 1.3862956390229944
 Loss: 1.386295702348478
 Loss: 1.3862966653320141
 Loss: 1.3862970313381804
 Loss: 1.3862966454879928
 Loss: 1.3862972757389982
 Loss: 1.3862969369743026
 Loss: 1.386298348699846
 Loss: 1.3862963290764434
 Loss: 1.38629851684751
 Loss: 1.3862968022744226
 Loss: 1.3862940551448197
 Loss: 1.3862935377605259
 Loss: 1.386293441663459
 Loss: 1.3862945591899232
 Loss: 1.386293204619808
 Loss: 1.386292271608

 36%|███▌      | 501/1407 [08:03<15:43,  1.04s/it]

 Loss: 1.3862959488401458
 Loss: 1.3862961838397068
 Loss: 1.3862960615533009
 Loss: 1.3862989674905613
 Loss: 1.386300986787648
 Loss: 1.3863012544342617
 Loss: 1.3862971680190033
 Loss: 1.3862998917841194
 Loss: 1.386300197405512
 Loss: 1.386298767771925
 Loss: 1.3862985301434445
 Loss: 1.3862975652592882
 Loss: 1.38629815121688
 Loss: 1.3862978917974014
 Loss: 1.3862999232963722
 Loss: 1.3862988548597097
 Loss: 1.3862988230991722
 Loss: 1.3862981031008088
 Loss: 1.386297681599932
 Loss: 1.3862972784072851
 Loss: 1.3862969632176187
 Loss: 1.3862974424863719
 Loss: 1.3862956207726775
 Loss: 1.3862952448488217
 Loss: 1.3862949422801383
 Loss: 1.386295790535645
 Loss: 1.3862936452589358
 Loss: 1.3862952192630105
 Loss: 1.3862954080109955
 Loss: 1.3862932056193198
 Loss: 1.3862930354480127
 Loss: 1.3862944320146007
 Loss: 1.3862951882455137
 Loss: 1.3862939019000577
 Loss: 1.3862928157147212
 Loss: 1.3862916831758418
 Loss: 1.3862914029884457
 Loss: 1.3862909839267838


 71%|███████   | 1000/1407 [16:05<06:28,  1.05it/s]

 Loss: 1.3859841873391772
 Loss: 1.3860141798982892
 Loss: 1.3860466329869514
 Loss: 1.3860791309431393
 Loss: 1.3861343190451145
 Loss: 1.3861457936805202
 Loss: 1.3861603647030931
 Loss: 1.3861522474489398
 Loss: 1.386173802489202
 Loss: 1.3861954431600818
 Loss: 1.3861753899634586
 Loss: 1.386190649161998
 Loss: 1.3862045374127598
 Loss: 1.3862162142943477
 Loss: 1.3861774662093345
 Loss: 1.3861881666442175
 Loss: 1.3862020766491072
 Loss: 1.3862118776817063
 Loss: 1.386221843020225
 Loss: 1.3862285468483906
 Loss: 1.3862350928936846
 Loss: 1.3862456692595286
 Loss: 1.3862510871892275
 Loss: 1.3862463550572188
 Loss: 1.3860538885120681
 Loss: 1.386075599651706
 Loss: 1.3862168285201535
 Loss: 1.386227800811876
 Loss: 1.386235315530493
 Loss: 1.3862297406157797
 Loss: 1.3862367990935696
 Loss: 1.3862518897645044
 Loss: 1.3862600235038132
 Loss: 1.3862652696275528
 Loss: 1.3862647220883204
 Loss: 1.3862650518471094
 Loss: 1.386261188225814
 Loss: 1.3862576036782874
 Loss: 1.3862490966

 71%|███████   | 1001/1407 [16:06<06:54,  1.02s/it]

 Loss: 1.3861465844776095
 Loss: 1.3861598247229332
 Loss: 1.3861179179494925
 Loss: 1.3860792004953635
 Loss: 1.386099405636562
 Loss: 1.3860035069735406
 Loss: 1.3860296006159387
 Loss: 1.386024725004113
 Loss: 1.3860432012952055
 Loss: 1.3860684964725392
 Loss: 1.3860907137694076
 Loss: 1.3861123067410692
 Loss: 1.3861410625820074
 Loss: 1.386141145948593
 Loss: 1.3861637753761034
 Loss: 1.386177573429008
 Loss: 1.3864079181788502
 Loss: 1.3863972304259067
 Loss: 1.386387730657547
 Loss: 1.3863771304662431
 Loss: 1.3862236569978659
 Loss: 1.3862298456422566
 Loss: 1.3855460996263096
 Loss: 1.3856153233199897
 Loss: 1.3857959994688258
 Loss: 1.385819347710725
 Loss: 1.385864846716508
 Loss: 1.3859197194667323


100%|██████████| 1407/1407 [22:38<00:00,  1.04it/s]





Iteration: 2


 36%|███▌      | 500/1407 [08:02<14:39,  1.03it/s]

 Loss: 1.386276417653048
 Loss: 1.3863118055584587
 Loss: 1.386307879965534
 Loss: 1.3863049787411363
 Loss: 1.3863027014251894
 Loss: 1.3862990544363571
 Loss: 1.3862974529973906
 Loss: 1.3862984674136856
 Loss: 1.386297318067642
 Loss: 1.386297738009535
 Loss: 1.3862916071300293
 Loss: 1.3862953161374856
 Loss: 1.3862967707374212
 Loss: 1.386297030835615
 Loss: 1.3862957867287993
 Loss: 1.386297051218456
 Loss: 1.3862945891386027
 Loss: 1.386295496550121
 Loss: 1.3862915925326211
 Loss: 1.386288996828401
 Loss: 1.386290809177879
 Loss: 1.3862900680275472
 Loss: 1.386290044722412
 Loss: 1.3862934212125426
 Loss: 1.386294147393443
 Loss: 1.3862933346819917
 Loss: 1.3862932350509203
 Loss: 1.3862946235781464
 Loss: 1.386295551387568
 Loss: 1.3862939545465407
 Loss: 1.3862860800879804
 Loss: 1.3862873138836869
 Loss: 1.3862844665514094
 Loss: 1.3862796747386452
 Loss: 1.3862869373291729
 Loss: 1.3862869648334382
 Loss: 1.386284593480557
 Loss: 1.3862815890351503
 Loss: 1.3862842017685981

 36%|███▌      | 501/1407 [08:04<15:41,  1.04s/it]

 Loss: 1.3862905687830263
 Loss: 1.386290590769775
 Loss: 1.3862851746142453
 Loss: 1.3862886804873238
 Loss: 1.3862874846340258
 Loss: 1.3862854070080257
 Loss: 1.3862853133630397
 Loss: 1.3862848118500388
 Loss: 1.38628468235342
 Loss: 1.3862861274481562
 Loss: 1.3862783800483418
 Loss: 1.3862788579691057
 Loss: 1.3862786205257718
 Loss: 1.386279503552235
 Loss: 1.3862813473178002
 Loss: 1.3862815404325473
 Loss: 1.386281583105601
 Loss: 1.386282813604245
 Loss: 1.3862855422993625
 Loss: 1.3862856139391773
 Loss: 1.38628442671747
 Loss: 1.3862861596362381
 Loss: 1.3862793746128608
 Loss: 1.3862816485048768
 Loss: 1.3862816565288398
 Loss: 1.3862760609137976
 Loss: 1.3862778436316219
 Loss: 1.386279614970669
 Loss: 1.386280589287506
 Loss: 1.3862776514753938
 Loss: 1.3862758657513774
 Loss: 1.386268107400422
 Loss: 1.386268015181498
 Loss: 1.3862657148916806


 71%|███████   | 1000/1407 [16:06<06:35,  1.03it/s]

 Loss: 1.3837439962876463
 Loss: 1.3839912964689098
 Loss: 1.3842519540000584
 Loss: 1.384445466256913
 Loss: 1.3847382524521141
 Loss: 1.3848474793652585
 Loss: 1.3849904989915989
 Loss: 1.3850740363345655
 Loss: 1.3850769671928387
 Loss: 1.3852159684716019
 Loss: 1.3851381158070284
 Loss: 1.385237969330238
 Loss: 1.3854758358341037
 Loss: 1.3855579510040503
 Loss: 1.385168547753163
 Loss: 1.385275478950442
 Loss: 1.3853730998557519
 Loss: 1.3854686953534225
 Loss: 1.3855522517481036
 Loss: 1.3856192032204795
 Loss: 1.3856854438519535
 Loss: 1.385771536803489
 Loss: 1.3858254885461139
 Loss: 1.3858083965587207
 Loss: 1.3846761657781232
 Loss: 1.3847836977827694
 Loss: 1.3852459126640688
 Loss: 1.3853610575731992
 Loss: 1.38545449559716
 Loss: 1.3854918359353625
 Loss: 1.3855468403072193
 Loss: 1.3856173250768513
 Loss: 1.385692491563612
 Loss: 1.3856885681442501
 Loss: 1.3857279524110628
 Loss: 1.3857736264082379
 Loss: 1.3857527918588448
 Loss: 1.3855053973470326
 Loss: 1.38553720643

 71%|███████   | 1001/1407 [16:07<07:04,  1.05s/it]

 Loss: 1.3849874026630342
 Loss: 1.3851195294011254
 Loss: 1.3851203666641074
 Loss: 1.385235549197831
 Loss: 1.3853341351624475
 Loss: 1.3854371318825613
 Loss: 1.3855190285690306
 Loss: 1.3851976879222105
 Loss: 1.385291023950241
 Loss: 1.3853982721869307
 Loss: 1.3854444177541874
 Loss: 1.3855313078993925
 Loss: 1.385291303673479
 Loss: 1.3850789834372041
 Loss: 1.3851856442670676
 Loss: 1.3845717238528121
 Loss: 1.3845320338436296
 Loss: 1.384581309058815
 Loss: 1.3846938921498209
 Loss: 1.3848532956975952
 Loss: 1.3849696506981481
 Loss: 1.3850980928472665
 Loss: 1.3851454673050632
 Loss: 1.3852721826290002
 Loss: 1.3853672006179312
 Loss: 1.3854346923633891
 Loss: 1.3860765086952875
 Loss: 1.3861662436137228
 Loss: 1.3861887355395819
 Loss: 1.3861969858183263
 Loss: 1.3852058782971322
 Loss: 1.3853090525981564
 Loss: 1.3801206041559677
 Loss: 1.3807686647044213
 Loss: 1.3816395950555733
 Loss: 1.3819180922722756
 Loss: 1.3823438462927287
 Loss: 1.3827846506817054


100%|██████████| 1407/1407 [22:38<00:00,  1.04it/s]





Iteration: 3


 36%|███▌      | 500/1407 [08:01<14:32,  1.04it/s]

 Loss: 1.3862690745127284
 Loss: 1.3862692432104424
 Loss: 1.3862693235128112
 Loss: 1.3862680129571845
 Loss: 1.3862668096152626
 Loss: 1.3862469630653576
 Loss: 1.3862500104798303
 Loss: 1.3862555188083732
 Loss: 1.386276092720993
 Loss: 1.3862789689835617
 Loss: 1.386274309695069
 Loss: 1.3859675870004033
 Loss: 1.3860028635557953
 Loss: 1.386034827032369
 Loss: 1.3860536997902526
 Loss: 1.3860769318391204
 Loss: 1.3860951107903707
 Loss: 1.3861166574505914
 Loss: 1.3861345593286707
 Loss: 1.3861427078384
 Loss: 1.3861514839295601
 Loss: 1.3861632924763012
 Loss: 1.386178044809787
 Loss: 1.3861901059751705
 Loss: 1.3862009610240158
 Loss: 1.3862104683075396
 Loss: 1.3862187149185583
 Loss: 1.3862200810365657
 Loss: 1.386231276439379
 Loss: 1.3862368938744816
 Loss: 1.3862308392602878
 Loss: 1.386245131165863
 Loss: 1.386243045035971
 Loss: 1.3862328467106577
 Loss: 1.3862597528698226
 Loss: 1.3862647995593114
 Loss: 1.3862663494266834
 Loss: 1.3862623798892884
 Loss: 1.3862701083462

 36%|███▌      | 501/1407 [08:03<15:42,  1.04s/it]

 Loss: 1.386292522931264
 Loss: 1.3862929455496367
 Loss: 1.3862888674787428
 Loss: 1.3862899775074493
 Loss: 1.3862969966024075
 Loss: 1.3862957916816991
 Loss: 1.386288210346781
 Loss: 1.386285082633331
 Loss: 1.3862871552720974
 Loss: 1.3862857543124536
 Loss: 1.386285077574293
 Loss: 1.3862792590639952
 Loss: 1.3862806027575103
 Loss: 1.3862829088071378
 Loss: 1.3862854610889608
 Loss: 1.3862840030499806
 Loss: 1.3862877452887752
 Loss: 1.3863362697825723
 Loss: 1.3863292897998596
 Loss: 1.3863276808195686
 Loss: 1.38632242996097
 Loss: 1.3863656978482044
 Loss: 1.3863581354028616
 Loss: 1.3863522351926536
 Loss: 1.386402238113818
 Loss: 1.3863872784702218
 Loss: 1.3863689987356873
 Loss: 1.3863652546848724
 Loss: 1.3863558768909459
 Loss: 1.386334943742867
 Loss: 1.3863343429308972
 Loss: 1.3863334564931846
 Loss: 1.386335674694269
 Loss: 1.3863218877653085
 Loss: 1.3863137949055258
 Loss: 1.386301552225276
 Loss: 1.3863038852534808
 Loss: 1.386308917527388


 71%|███████   | 1000/1407 [16:01<06:26,  1.05it/s]

 Loss: 1.378914282828651
 Loss: 1.379622202614023
 Loss: 1.3803089930108847
 Loss: 1.3808156436823305
 Loss: 1.381538479281291
 Loss: 1.38193023987367
 Loss: 1.3823680590488396
 Loss: 1.3825442532507672
 Loss: 1.3835112696714487
 Loss: 1.3837793033564645
 Loss: 1.383495309464147
 Loss: 1.3837711857366652
 Loss: 1.3839511078543745
 Loss: 1.3841830255342082
 Loss: 1.3830385948794386
 Loss: 1.3838471840838167
 Loss: 1.3840704921799516
 Loss: 1.3843010572121395
 Loss: 1.3845008886628616
 Loss: 1.3846828827357234
 Loss: 1.3848413368251271
 Loss: 1.3850292450356199
 Loss: 1.3851606684477071
 Loss: 1.385065040369416
 Loss: 1.3819238555844886
 Loss: 1.382194466508889
 Loss: 1.3830843452669357
 Loss: 1.3832021073154068
 Loss: 1.3835112615512124
 Loss: 1.38366146958646
 Loss: 1.3838699824521854
 Loss: 1.3841570526578946
 Loss: 1.3843728342848052
 Loss: 1.3842875396487442
 Loss: 1.3844518633434768
 Loss: 1.384613618709111
 Loss: 1.3846631277717327
 Loss: 1.3848422016904214
 Loss: 1.38488246615577

 71%|███████   | 1001/1407 [16:02<06:56,  1.03s/it]

 Loss: 1.3856829622183588
 Loss: 1.3857261972541526
 Loss: 1.3857822868449912
 Loss: 1.3856922793290103
 Loss: 1.3857624537856357
 Loss: 1.3850331789652386
 Loss: 1.3843897300720107
 Loss: 1.3845455036544703
 Loss: 1.3828550981864842
 Loss: 1.3825802571224135
 Loss: 1.3826882630569008
 Loss: 1.3829058340175986
 Loss: 1.3832399187371767
 Loss: 1.383453238417414
 Loss: 1.3837363497105604
 Loss: 1.3837404408175378
 Loss: 1.384105958770391
 Loss: 1.3846009834683033
 Loss: 1.3846789238520945
 Loss: 1.3855745587648587
 Loss: 1.3856485540182861
 Loss: 1.3857507337193018
 Loss: 1.3857947256321006
 Loss: 1.3829454481143617
 Loss: 1.383263769704598
 Loss: 1.3687785115466993
 Loss: 1.3704671066963505
 Loss: 1.3725497962800115
 Loss: 1.373599562802981
 Loss: 1.3748907867900169
 Loss: 1.3761241516916425


100%|██████████| 1407/1407 [22:28<00:00,  1.04it/s]


In [45]:
embedding

{}

## **Evaluation**

Load the saved model from the previous step and find out top 10 similar words for

a. “Coffee”

b. “Pasta”

c. “Tuna”

d. “Cookies”

In [100]:
from scipy.spatial import distance

Load Embedding


In [88]:
# f = open('/content/embedding', 'r')
embedding = skip_gram_model.word_embeddings.weight.cpu().data.numpy()
emb_dict = {}
for wid, w in dataset.id2word.items():
    emb_dict[w] = embedding[wid]
embedding.shape

(14127, 50)

In [125]:
def get_top_10_similar_words(word, k, embed):
    top_words = []
    for wid, w in dataset.id2word.items():
        if w == word:
            continue
        val = embedding[wid]
        dist = distance.cosine(embed, val)
        
        top_words.append({'word': w, 'dist': dist})
    top_embeddings = sorted(top_words, key=lambda i: i["dist"], reverse=False)
    top_k = top_embeddings[:k]
    
    return top_k

# Coffee

In [126]:
#coffee not present in the corpus
test = 'tea'
word_id = dataset.word2id[test]
test_embed = embedding[word_id]
top_10 = get_top_10_similar_words(test, 10, test_embed)
top_10

[{'dist': 0.37572556734085083, 'word': 'pictorial'},
 {'dist': 0.4872002601623535, 'word': '85'},
 {'dist': 0.5208940804004669, 'word': 'lou'},
 {'dist': 0.5293445885181427, 'word': 'clearing'},
 {'dist': 0.5346536040306091, 'word': 'bloom'},
 {'dist': 0.5530212819576263, 'word': 'deurne'},
 {'dist': 0.5571599006652832, 'word': 'mayfair'},
 {'dist': 0.5580795705318451, 'word': 'listening'},
 {'dist': 0.5597977936267853, 'word': 'sylvia'},
 {'dist': 0.5689918398857117, 'word': '200405'}]

In [None]:
dataset.word2id

# Pasta

In [120]:
test = 'italy'
word_id = dataset.word2id[test]
test_embed = embedding[word_id]
top_10 = get_top_10_similar_words(test, 10, test_embed)
top_10

[{'dist': 0.4481458067893982, 'word': 'limit'},
 {'dist': 0.47057926654815674, 'word': 'options'},
 {'dist': 0.4995225667953491, 'word': 'commentaries'},
 {'dist': 0.5103249549865723, 'word': 'nottingham'},
 {'dist': 0.5182115435600281, 'word': 'contain'},
 {'dist': 0.5187094211578369, 'word': '1969'},
 {'dist': 0.5320057272911072, 'word': '1980'},
 {'dist': 0.5321478247642517, 'word': 'shinchosha'},
 {'dist': 0.5356270372867584, 'word': 'depth'},
 {'dist': 0.5488672852516174, 'word': 'milone'}]

# Tuna

In [121]:
test = 'village'
word_id = dataset.word2id[test]
test_embed = embedding[word_id]
top_10 = get_top_10_similar_words(test, 10, test_embed)
top_10

[{'dist': 0.455991268157959, 'word': 'phd'},
 {'dist': 0.48788273334503174, 'word': 'morrison'},
 {'dist': 0.5122575461864471, 'word': 'vienna'},
 {'dist': 0.5160517394542694, 'word': 'vol'},
 {'dist': 0.5255751609802246, 'word': 'web'},
 {'dist': 0.5306389033794403, 'word': 'extra'},
 {'dist': 0.5364818274974823, 'word': 'venezuelan'},
 {'dist': 0.5387088656425476, 'word': 'atherton'},
 {'dist': 0.5406548082828522, 'word': 'celebrate'},
 {'dist': 0.5421563684940338, 'word': 'sold'}]

# Cookies

In [122]:
test = 'festival'
word_id = dataset.word2id[test]
test_embed = embedding[word_id]
top_10 = get_top_10_similar_words(test, 10, test_embed)
top_10

[{'dist': 0.48736655712127686, 'word': 'tonight'},
 {'dist': 0.49895763397216797, 'word': 'tough'},
 {'dist': 0.5054886937141418, 'word': 'saves'},
 {'dist': 0.5142286121845245, 'word': 'fushimi'},
 {'dist': 0.5163343548774719, 'word': 'violation'},
 {'dist': 0.5235106348991394, 'word': 'fertilizer'},
 {'dist': 0.5324481725692749, 'word': '1955'},
 {'dist': 0.5498812794685364, 'word': 'flights'},
 {'dist': 0.5544025897979736, 'word': '141'},
 {'dist': 0.5547012388706207, 'word': 'categoryafrican'}]

## **Glove word analytics**

Load Glove 300d vector file. Available at: https://nlp.stanford.edu/projects/glove/ and solve below analogies

i. Spain is to Sapnish as Germany is to ____

ii. Japan is to Tokyo as France is to ____

iii. Woman is to Man as Queen is to ____

iv. Australia is to Hotdog as Italy is to ____

Use cosine similarity between the word vectors to solve the analogies.

In [52]:
!wget http://nlp.stanford.edu/data/glove.6B.zip

--2021-09-21 03:17:48--  http://nlp.stanford.edu/data/glove.6B.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.stanford.edu/data/glove.6B.zip [following]
--2021-09-21 03:17:49--  https://nlp.stanford.edu/data/glove.6B.zip
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]
--2021-09-21 03:17:49--  http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862182613 (822M) [application/zip]
Saving to: ‘glove.6B.zip’


2021-0

In [53]:
!unzip glove*.zip

Archive:  glove.6B.zip
  inflating: glove.6B.50d.txt        
  inflating: glove.6B.100d.txt       
  inflating: glove.6B.200d.txt       
  inflating: glove.6B.300d.txt       


In [54]:
!ls
!pwd

embedding	   glove.6B.200d.txt  glove.6B.50d.txt	sample_data
glove.6B.100d.txt  glove.6B.300d.txt  glove.6B.zip	wiki_clean.json
/content


In [64]:
print('Indexing word vectors.')

embeddings_index = {}
words = []
f = open('glove.6B.100d.txt', encoding='utf-8')
for line in f:
    values = line.split()
    word = values[0]
    words.append(word)
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()

print('Found %s word vectors.' % len(embeddings_index))


Indexing word vectors.
Found 400000 word vectors.


In [83]:
print(embeddings_index)

KeyError: ignored

In [69]:
def get_analogy(token_a, token_b, token_c, embed):
    vec1 = embed[token_a]
    vec2 = embed[token_b]
    vec3 = embed[token_c]
    x = vec2 - vec1 + vec3
    max_cosine_sim = 0
    for w in words:
        if w in [token_a, token_b, token_c]:
            continue
        cosine_sim = 1- distance.cosine(x, embed[w])
        
        if cosine_sim > max_cosine_sim:
            max_cosine_sim = cosine_sim
            best_word = w
            
    return best_word

In [70]:
get_analogy('spain', 'spanish', 'germany', embeddings_index)

'german'

In [71]:
get_analogy('japan', 'tokyo', 'france', embeddings_index)

'paris'

In [72]:
get_analogy('woman', 'man', 'queen', embeddings_index)

'king'

In [73]:
get_analogy('australia', 'hotdog', 'italy', embeddings_index)

'trattoria'