In [1]:
import torch
import torchtext

In [2]:
embeddings = torchtext.vocab.GloVe(name="6B", dim=100)

In [3]:
class GloVe():
    def __init__(self, embeddings):
        self.embeddings = embeddings

    def get_vector(self, word):
        vec = self.embeddings.vectors[self.embeddings.stoi[word]]
        return vec 

    def get_closest_words(self, input, k=10):
        if type(input) == str:
            input_word_vec = self.get_vector(input)
        else:
            input_word_vec = input

        distances = [(word, torch.dist(input_word_vec, self.get_vector(word)).item()) 
                    for word in self.embeddings.stoi]

        return sorted(distances, key=lambda w: w[1])[:k]

    def get_analogous_words(self, word1, word2, word3, k=10):
        word1_vec = self.get_vector(word1)
        word2_vec = self.get_vector(word2)
        word3_vec = self.get_vector(word3)

        analogy_vec = word2_vec - word1_vec + word3_vec
        analogous_words = self.get_closest_words(analogy_vec)
        analogous_words = [(word, dist) for (word, dist) in analogous_words
                          if word not in [word1, word2, word3]]

        print(f"{word1} is to {word2} as {word3} is to..")
        print("")

        return analogous_words

In [4]:
glove = GloVe(embeddings)

In [5]:
what_vec = glove.get_vector("what")
what_vec, what_vec.shape

(tensor([-1.5180e-01,  3.8409e-01,  8.9340e-01, -4.2421e-01, -9.2161e-01,
          3.7988e-02, -3.2026e-01,  3.4119e-03,  2.2101e-01, -2.2045e-01,
          1.6661e-01,  2.1956e-01,  2.5325e-01, -2.9267e-01,  1.0171e-01,
         -7.5491e-02, -6.0406e-02,  2.8194e-01, -5.8519e-01,  4.8271e-01,
          1.7504e-02, -1.2086e-01, -1.0990e-01, -6.9554e-01,  1.5600e-01,
          7.0558e-02, -1.5058e-01, -8.1811e-01, -1.8535e-01, -3.6863e-01,
          3.1650e-02,  7.6616e-01,  8.4041e-02,  2.6928e-03, -2.7440e-01,
          2.1815e-01, -3.5157e-02,  3.2569e-01,  1.0032e-01, -6.0932e-01,
         -7.0316e-01,  1.8299e-01,  3.3134e-01, -1.2416e-01, -9.0542e-01,
         -3.9157e-02,  4.4719e-01, -5.7338e-01, -4.0172e-01, -8.2234e-01,
          5.5740e-01,  1.5101e-01,  2.4598e-01,  1.0113e+00, -4.6626e-01,
         -2.7133e+00,  4.3273e-01, -1.6314e-01,  1.5828e+00,  5.5081e-01,
         -2.4738e-01,  1.4184e+00, -1.6867e-02, -1.9368e-01,  1.0090e+00,
         -5.9864e-02,  9.1853e-01,  4.

In [6]:
the_vec = glove.get_vector("the")
the_vec, the_vec.shape

(tensor([-0.0382, -0.2449,  0.7281, -0.3996,  0.0832,  0.0440, -0.3914,  0.3344,
         -0.5755,  0.0875,  0.2879, -0.0673,  0.3091, -0.2638, -0.1323, -0.2076,
          0.3340, -0.3385, -0.3174, -0.4834,  0.1464, -0.3730,  0.3458,  0.0520,
          0.4495, -0.4697,  0.0263, -0.5415, -0.1552, -0.1411, -0.0397,  0.2828,
          0.1439,  0.2346, -0.3102,  0.0862,  0.2040,  0.5262,  0.1716, -0.0824,
         -0.7179, -0.4153,  0.2033, -0.1276,  0.4137,  0.5519,  0.5791, -0.3348,
         -0.3656, -0.5486, -0.0629,  0.2658,  0.3020,  0.9977, -0.8048, -3.0243,
          0.0125, -0.3694,  2.2167,  0.7220, -0.2498,  0.9214,  0.0345,  0.4674,
          1.1079, -0.1936, -0.0746,  0.2335, -0.0521, -0.2204,  0.0572, -0.1581,
         -0.3080, -0.4162,  0.3797,  0.1501, -0.5321, -0.2055, -1.2526,  0.0716,
          0.7056,  0.4974, -0.4206,  0.2615, -1.5380, -0.3022, -0.0734, -0.2831,
          0.3710, -0.2522,  0.0162, -0.0171, -0.3898,  0.8742, -0.7257, -0.5106,
         -0.5203, -0.1459,  

In [7]:
glove.get_closest_words("rocket")

[('rocket', 0.0),
 ('rockets', 4.294834613800049),
 ('launcher', 4.528061389923096),
 ('propelled', 4.573644638061523),
 ('launching', 4.586248397827148),
 ('launch', 4.654458045959473),
 ('firing', 4.665835857391357),
 ('fired', 4.6791582107543945),
 ('launchers', 4.699558258056641),
 ('missiles', 4.822944641113281)]

In [8]:
glove.get_closest_words("dogs")

[('dogs', 0.0),
 ('dog', 3.2425272464752197),
 ('cats', 3.528623342514038),
 ('cat', 4.055587291717529),
 ('pets', 4.109102249145508),
 ('animals', 4.179422378540039),
 ('horses', 4.338571071624756),
 ('pigs', 4.477031707763672),
 ('sniffing', 4.527379989624023),
 ('puppies', 4.53402042388916)]

In [9]:
glove.get_closest_words("space")

[('space', 0.0),
 ('spaces', 4.672468185424805),
 ('nasa', 4.70680046081543),
 ('earth', 4.920621871948242),
 ('shuttle', 4.9815287590026855),
 ('spaceship', 5.055478572845459),
 ('spacecraft', 5.101546764373779),
 ('module', 5.113476753234863),
 ('discovery', 5.203618049621582),
 ('orbit', 5.274632930755615)]

In [10]:
glove.get_closest_words("cartoon")

[('cartoon', 0.0),
 ('cartoons', 3.426957607269287),
 ('animated', 3.5315918922424316),
 ('parody', 4.147957801818848),
 ('spoof', 4.408638000488281),
 ('comic', 4.498523235321045),
 ('caricature', 4.643454074859619),
 ('live-action', 4.821722507476807),
 ('animation', 4.872594833374023),
 ('poster', 4.894413471221924)]

In [11]:
glove.get_analogous_words("man", "king", "woman")

man is to king as woman is to..



[('queen', 4.08107852935791),
 ('monarch', 4.642907619476318),
 ('throne', 4.905500888824463),
 ('elizabeth', 4.921558380126953),
 ('prince', 4.981146812438965),
 ('daughter', 4.985714912414551),
 ('mother', 5.064087390899658),
 ('cousin', 5.077497482299805),
 ('princess', 5.078685760498047)]

In [12]:
glove.get_analogous_words("dog", "puppy", "cat")

dog is to puppy as cat is to..



[('kitten', 3.814647674560547),
 ('puppies', 4.0254998207092285),
 ('kittens', 4.157486915588379),
 ('pterodactyl', 4.188157558441162),
 ('scaredy', 4.194512844085693),
 ('tigress', 4.203792572021484),
 ('tabby', 4.257164478302002),
 ('pup', 4.304572582244873)]

In [13]:
glove.get_analogous_words("japan", "tokyo", "india")

japan is to tokyo as india is to..



[('delhi', 3.4406583309173584),
 ('mumbai', 4.064070701599121),
 ('bombay', 4.229851722717285),
 ('lahore', 4.608885288238525),
 ('karachi', 4.626933574676514),
 ('dhaka', 4.73362922668457),
 ('calcutta', 4.911397457122803),
 ('islamabad', 4.915811538696289),
 ('colombo', 5.079125881195068)]

In [14]:
glove.get_analogous_words("sky", "birds", "ocean")

sky is to birds as ocean is to..



[('mammals', 6.280848979949951),
 ('migratory', 6.450166702270508),
 ('animals', 6.53190803527832),
 ('species', 6.55869197845459),
 ('fish', 6.747007369995117),
 ('whales', 6.748803615570068),
 ('reptiles', 6.835613250732422),
 ('amphibians', 7.035011291503906),
 ('insects', 7.082273006439209)]

In [15]:
glove.get_analogous_words("cars", "roads", "trains")

cars is to roads as trains is to..



[('rail', 6.112190246582031),
 ('highways', 6.130516052246094),
 ('routes', 6.1638360023498535),
 ('railway', 6.392805099487305),
 ('connecting', 6.472668170928955),
 ('route', 6.508240222930908),
 ('bridges', 6.801179885864258),
 ('transit', 6.902473449707031)]

In [16]:
glove.get_analogous_words("humans", "earth", "aliens")

humans is to earth as aliens is to..



[('alien', 6.089966297149658),
 ('spaceship', 6.189017295837402),
 ('planet', 6.470367431640625),
 ('shadows', 6.541500091552734),
 ('voyager', 6.677394866943359),
 ('mars', 6.795182228088379),
 ('space', 6.814069747924805),
 ('sky', 6.867318153381348)]