In [2]:
import numpy as np
import pandas as pd
from scipy import spatial
from itertools import zip_longest
from IPython.display import HTML



In [3]:
# Flip USE_FULL to True to read the glove.42B.300d file instead of the
# top_50000 file (presumably a truncated vocabulary -- confirm against the data).
USE_FULL = False
if USE_FULL:
    INPUT_GLOVE = "./glove.42B.300d.txt"
else:
    INPUT_GLOVE = "./top_50000.txt"


In [4]:
# Build the word -> vector lookup used by every scoring function below.
# Each non-blank line of the input is read as a token followed by its
# whitespace-separated float components.
embeddings = {}
# Decode as UTF-8 explicitly instead of relying on the platform default
# encoding, which varies by OS/locale and can fail on non-ASCII tokens.
with open(INPUT_GLOVE, "r", encoding="utf-8") as f:
    for line in f:
        values = line.split()
        if not values:
            # Blank line (e.g. a stray newline at end of file): values[0]
            # would raise IndexError, so skip it.
            continue
        word = values[0]
        # Remaining fields are the vector components, stored as float32.
        embeddings[word] = np.asarray(values[1:], "float32")


In [5]:
def distance(word, reference):
    """Return the cosine distance between the embeddings of two words.

    Both ``word`` and ``reference`` are looked up in the module-level
    ``embeddings`` dict; a missing key raises ``KeyError``.
    """
    word_vec = embeddings[word]
    reference_vec = embeddings[reference]
    return spatial.distance.cosine(word_vec, reference_vec)

def closest_words(reference):
    """Return every vocabulary word, ordered from nearest to farthest.

    Ordering is by ``distance(word, reference)`` ascending; ``reference``
    itself is part of the result when it is in ``embeddings``.
    """
    ranked = list(embeddings.keys())
    ranked.sort(key=lambda candidate: distance(candidate, reference))
    return ranked


In [32]:

def goodness(word, answers, bad, answer_weight=4.0):
    """Score ``word`` as a clue: far from ``bad`` words, close to ``answers``.

    The score is the summed distance to every word in ``bad`` minus
    ``answer_weight`` times the summed distance to every word in ``answers``,
    so higher is better.

    Args:
        word: Candidate word; must be a key of ``embeddings``.
        answers: Collection of words the candidate should be close to.
        bad: Collection of words the candidate should be far from.
        answer_weight: Multiplier applied to the summed answer distance.

    Returns:
        The numeric score, or the sentinel ``-999`` when ``word`` is itself
        one of ``answers`` or ``bad``.
    """
    # Two membership tests instead of `answers | bad`: no throwaway union set
    # is built on every call, and list/tuple inputs work as well as sets.
    if word in answers or word in bad:
        return -999
    bad_total = sum(distance(word, b) for b in bad)
    answer_total = sum(distance(word, a) for a in answers)
    return bad_total - answer_weight * answer_total

def minimax(word, answers, bad):
    """Worst-case clue score: nearest bad word versus farthest answer.

    Returns the smallest distance from ``word`` to any word in ``bad`` minus
    the largest distance from ``word`` to any word in ``answers``; higher is
    better.

    Args:
        word: Candidate word; must be a key of ``embeddings``.
        answers: Collection of words the candidate should be close to.
        bad: Collection of words the candidate should be far from.

    Returns:
        The numeric score, or the sentinel ``-999`` when ``word`` is itself
        one of ``answers`` or ``bad``.
    """
    # Two membership tests instead of `answers | bad`: no throwaway union set
    # is built on every call, and list/tuple inputs work as well as sets.
    if word in answers or word in bad:
        return -999
    # An empty group contributes 0.0 rather than raising ValueError from
    # min()/max() on an empty sequence; this mirrors how an empty sum
    # contributes nothing in goodness().
    nearest_bad = min((distance(word, b) for b in bad), default=0.0)
    farthest_answer = max((distance(word, a) for a in answers), default=0.0)
    return nearest_bad - farthest_answer

def candidates(answers, bad, size=100, pool_size=250):
    """Rank vocabulary words as clue candidates for ``answers`` versus ``bad``.

    Two passes: the whole vocabulary is ordered by ``goodness`` (best first)
    and cut to a shortlist, then the shortlist is re-ranked by ``minimax``.

    Args:
        answers: Collection of words the clue should be close to.
        bad: Collection of words the clue should be far from.
        size: Number of formatted results to return (``None`` returns the
            whole shortlist).
        pool_size: Shortlist length kept after the ``goodness`` pass.

    Returns:
        A list of strings ``"<rank>. <word> (<score>)"``. The score shown is
        the *negated* minimax value (the sort key), so lower is better.
    """
    # The shortlist used to be fixed at 250, which silently capped the result
    # whenever size > 250; make sure it is at least as large as the request.
    pool = pool_size if size is None else max(pool_size, size)
    shortlist = sorted(
        embeddings.keys(),
        key=lambda w: -goodness(w, answers, bad),
    )[:pool]
    # Tuples sort by negated minimax first, then alphabetically on ties.
    scored = sorted((-minimax(w, answers, bad), w) for w in shortlist)[:size]
    return [f"{i}. {w} ({s:.2f})" for i, (s, w) in enumerate(scored)]

In [33]:
def grouper(n, iterable, fillvalue=None):
    """Yield consecutive ``n``-tuples from ``iterable``.

    The final tuple is padded with ``fillvalue`` when the input length is not
    a multiple of ``n``.
    """
    # The same iterator object repeated n times: each output tuple pulls the
    # next n items from it.
    shared = iter(iterable)
    return zip_longest(*(shared,) * n, fillvalue=fillvalue)

def tabulate(data, width=10):
    """Render ``data`` as an HTML table with ``width`` cells per row.

    Items are laid out row by row via ``grouper``; a short final row is
    padded with ``None``. The table has no index column and no header row.
    """
    rows = list(grouper(width, data))
    frame = pd.DataFrame(rows)
    markup = frame.to_html(index=False, header=False)
    return HTML(markup)

In [34]:
# Words the clue should be close to.
answers = {token for token in "coffee cowboy sack code".lower().split()}
# Words the clue should stay away from.
bad = {token for token in "port violet olympus farm spike stock hotel scuba diver".lower().split()}

# Show the ten best candidates, five per table row.
tabulate(candidates(answers, bad, size=10), width=5)


0,1,2,3,4
0. swag (-0.06),1. sloppy (-0.05),2. dirty (-0.03),3. mess (0.01),4. monkey (0.02)
5. lazy (0.02),6. shirt (0.02),7. joke (0.03),8. damn (0.03),9. hate (0.03)
