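Before running this notebook, make sure the following files are in the same directory as the notebook:
1. $\texttt{questions-words.txt}$ (the Google word-analogy test set), which can be downloaded from https://github.com/nicholas-leonard/word2vec/blob/master/questions-words.txt (that link is an HTML page; use its *Raw* button to get the plain-text file)
2. $\texttt{GoogleNews-vectors-negative300.bin}$ (the pretrained GoogleNews word2vec vectors in binary word2vec format), which can be downloaded from https://www.kaggle.com/datasets/leadbest/googlenewsvectorsnegative300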

In [1]:
!pip install transformers
!pip install gensim

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [2]:
### Standard libraries
import operator
from tqdm import tqdm
import numpy as np
import pandas as pd

### For word2vec
import gensim
from gensim import models
from gensim.models import Word2Vec

In [5]:
# Load the pretrained GoogleNews word2vec vectors (binary word2vec format) into a
# gensim KeyedVectors object; `w2v_analogy` reads this global `model`.
model = gensim.models.KeyedVectors.load_word2vec_format(
    'GoogleNews-vectors-negative300.bin',
    binary=True,
)

def w2v_analogy(a, b, c, d, kv=None, topn=200000):
    """Evaluate one analogy question "a : b :: c : d" with word2vec vectors.

    The predicted vector is ``f = c - a + b``, built from the vectors returned
    by ``kv[word]`` as-is (no explicit normalisation here).

    Parameters
    ----------
    a, b, c, d : str
        The four words of the analogy; ``d`` is the expected answer.
    kv : KeyedVectors-like, optional
        Embedding lookup supporting ``in``, ``[]`` and ``similar_by_vector``.
        Defaults to the notebook-level ``model``.
    topn : int, optional
        Number of nearest neighbours of ``f`` to retrieve before filtering.
        ``d`` can only be ranked if it appears within these ``topn`` results.

    Returns
    -------
    tuple or str
        ``(cos_sim, rank_of_d, top_10)`` where ``cos_sim`` is the cosine
        similarity between ``d`` and ``f``, ``rank_of_d`` is the 1-based rank
        of ``d`` among the filtered neighbours (``None`` if it is not among
        them), and ``top_10`` is the first ten filtered ``(word, similarity)``
        pairs. If any of the four words is missing from the vocabulary, the
        string ``'At least one word is not in the vocabulary list'`` is
        returned instead (callers compare against this exact string).
    """
    if kv is None:
        kv = model

    if not all(word in kv for word in (a, b, c, d)):
        return 'At least one word is not in the vocabulary list'

    a_emb, b_emb, c_emb, d_emb = (kv[word] for word in (a, b, c, d))
    f_emb = c_emb - a_emb + b_emb

    # Cosine similarity of d and f
    cos_sim = np.sum(d_emb * f_emb) / np.sqrt(np.sum(d_emb ** 2) * np.sum(f_emb ** 2))

    # Nearest neighbours of f, keeping only all-lower-case words without `_`
    # and excluding the three query words themselves.
    query_words = {a, b, c}
    most_sim = [
        (word, sim)
        for word, sim in kv.similar_by_vector(f_emb, topn=topn, restrict_vocab=None)
        if word == word.lower() and '_' not in word and word not in query_words
    ]

    # 1-based rank of d among the filtered neighbours; None when d was not
    # retrieved (explicit search instead of a bare `except:` around an IndexError).
    rank_of_d = next(
        (rank for rank, (word, _) in enumerate(most_sim, start=1) if word == d),
        None,
    )
    top_10 = most_sim[:10]

    return (cos_sim, rank_of_d, top_10)

In [6]:
# Evaluate every analogy question in questions-words.txt with `w2v_analogy`.
# NOTE: slow — the recorded run took about 2 hours for 19,558 lines.
all_in_vocab = []   # (a, b, c, d) for questions whose four words are all in the vocabulary
result = []         # matching (cosine, rank, top_10) tuples returned by `w2v_analogy`
categories = []     # section name each kept question belongs to

category = 'none'

with open('questions-words.txt', encoding='utf-8') as file:
    for line in tqdm(file):
        line = line.strip()  # also removes a trailing '\r' from CRLF files
        if not line:
            continue  # skip blank lines instead of failing to unpack them
        if line.startswith(':'):
            # Section header such as ": capital-common-countries"
            category = line[1:].strip()
            continue
        a, b, c, d = line.lower().split()
        temp = w2v_analogy(a, b, c, d)
        if temp != 'At least one word is not in the vocabulary list':
            all_in_vocab.append((a, b, c, d))
            result.append(temp)
            categories.append(category)

# Build the results table in one constructor call rather than column by column.
w2v_res = pd.DataFrame({
    'task': all_in_vocab,
    'cosine': [x[0] for x in result],
    'rank': [x[1] for x in result],
    'top_10': [x[2] for x in result],
    'category': categories,
})

19558it [2:06:06,  2.58it/s]


In [8]:
# Preview the analogy results table (pandas elides the middle rows of long frames).
w2v_res

Unnamed: 0,task,cosine,rank,top_10,category
0,"(athens, greece, baghdad, iraq)",0.459147,9.0,"[(saddam, 0.48523029685020447), (afghanistan, ...",capital-common-countries
1,"(athens, greece, bangkok, thailand)",0.568460,2.0,"[(europe, 0.5729362368583679), (thailand, 0.56...",capital-common-countries
2,"(athens, greece, beijing, china)",0.238600,19956.0,"[(europe, 0.5613299608230591), (poland, 0.5535...",capital-common-countries
3,"(athens, greece, berlin, germany)",0.548939,2.0,"[(german, 0.5611334443092346), (germany, 0.548...",capital-common-countries
4,"(athens, greece, cairo, egypt)",0.574773,1.0,"[(egypt, 0.5747732520103455), (malta, 0.527613...",capital-common-countries
...,...,...,...,...,...
13700,"(write, writes, talk, talks)",0.359388,83.0,"[(talked, 0.5118362903594971), (discusses, 0.4...",gram9-plural-verbs
13701,"(write, writes, think, thinks)",0.588373,1.0,"[(thinks, 0.5883731842041016), (opines, 0.5764...",gram9-plural-verbs
13702,"(write, writes, vanish, vanishes)",0.519149,3.0,"[(disappear, 0.6582940220832825), (disappears,...",gram9-plural-verbs
13703,"(write, writes, walk, walks)",0.555293,1.0,"[(walks, 0.5552926659584045), (walking, 0.5402...",gram9-plural-verbs


In [7]:
# Persist the results (the index is written too). The tuple `task` and list
# `top_10` columns are written as their string representations, so they come
# back as plain strings if the CSV is re-read with pandas.read_csv.
w2v_res.to_csv(path_or_buf="w2v_res.csv")