## Finding entity classes in embeddings

In this notebook we use word embeddings to find an entity class (here, countries) and to explore how the members of that class correlate with other terms.

In [None]:
%matplotlib inline
from sklearn import svm
from keras.utils import get_file
import os
import gensim
import numpy as np
import random
import requests
import geopandas as gpd
from IPython.core.pylabtools import figsize
figsize(12, 8)
import csv

In [None]:
# File name of the pretrained word2vec binary.
MODEL    = 'GoogleNews-vectors-negative300.bin'
# The data directory can be overridden with the WORD2VEC_DATA_DIR environment variable so
# the notebook is not tied to a single machine; the default is the original local path.
data_loc = os.environ.get('WORD2VEC_DATA_DIR', '/home/smithw/Downloads/deep_learning')  # WS: files not backed up here
zipped   = os.path.join(data_loc, MODEL + '.gz')  # path of the .gz copy
unzipped = os.path.join(data_loc, MODEL)          # path of the uncompressed binary
zipped, unzipped

In [None]:
model = gensim.models.KeyedVectors.load_word2vec_format(unzipped, binary=True)

In [None]:
type(model), type(model.vectors)

In [None]:
model.most_similar(positive=['Germany'])

In [None]:
model.most_similar(positive=['Annita_Kirsten'])

In [None]:
# Read the country table into a list of dicts (one per CSV row). The context manager
# closes the file handle as soon as the rows have been materialized; newline='' is the
# mode the csv module expects for files it parses.
with open('data/countries.csv', newline='') as countries_file:
    countries = list(csv.DictReader(countries_file))
len(countries), countries[:10]

In [None]:
len(model.key_to_index), len(model.index_to_key)

## TRAIN SVM ON COUNTRY NAMES vs NON-COUNTRY NAMES

In [None]:
len(countries), len(model)

In [None]:
# examples of country names for training
num_countries = 140
# Seed the global RNG so that Restart & Run All draws the same samples (and therefore
# produces the same shuffle and train/validation split) every time.
RANDOM_SEED = 0
random.seed(RANDOM_SEED)
positive = [x['name'] for x in random.sample(countries, num_countries)]
positive[:10]

In [None]:
# examples of not-country names
sample = 10000
# Draw random vocabulary tokens, then discard any that are known country names so that
# no country ends up labelled as a negative example. Out of ~3e6 tokens this rarely
# removes anything, so len(negative) is `sample` or marginally less.
known_countries = {c['name'] for c in countries}
negative = [w for w in random.sample(model.index_to_key, sample) if w not in known_countries]
negative[:10]

In [None]:
# Tag every word with its class (1 = country, 0 = anything else), shuffle the pairs,
# then split them into the embedding matrix X and the label vector y.
labelled = [(word, 1) for word in positive]
labelled.extend((word, 0) for word in negative)
random.shuffle(labelled)
X = np.asarray([model[word] for word, _ in labelled])
y = np.asarray([label for _, label in labelled])
X.shape, y.shape

In [None]:
labelled[:5]

In [None]:
sum(y) == num_countries  # add up all the truth: should be same as num_countries

In [None]:
TRAINING_FRACTION = 0.1
cut_off = int(TRAINING_FRACTION * len(labelled))
clf     = svm.SVC(kernel='linear')  # support-vector classification
clf.fit(X[:cut_off], y[:cut_off])  # training

In [None]:
#svm.SVC?, svm.SVC.fit?
#svm.SVC.predict?

In [None]:
# get the predictions from the validation set
res = clf.predict(X[cut_off:])

In [None]:
# Compare the predictions with the held-out labels and collect the disagreements.
# Each entry is ((word, label), prediction), because `labelled` holds (word, label) pairs.
missed = [
    (item, guess)
    for guess, actual, item in zip(res, y[cut_off:], labelled[cut_off:])
    if guess != actual
]
# Validation accuracy in percent, followed by the misclassified items.
100 - 100 * len(missed) / len(res), missed

In [None]:
def evaluate(truth, pred):  # WS creation
    """Print a 2x2 confusion table and the overall accuracy for binary labels.

    Args:
        truth: sequence of ground-truth labels, each 0 or 1 (list or 1-D numpy array).
        pred: sequence of predicted labels (0 or 1), paired element-wise with `truth`.

    The table is printed through `truth_table`, followed by the count and percentage
    of correct predictions. Returns None.
    """
    total = len(truth)  # len() accepts plain lists as well as 1-D arrays
    tab = {(1, 1): 'TRUE  POSITIVE', (1, 0): 'FALSE  NEGATIVE',
           (0, 1): 'FALSE POSITIVE', (0, 0): 'TRUE   NEGATIVE'}
    # One counter per (truth, prediction) cell of the table.
    score = dict.fromkeys(tab, 0)
    for t, p in zip(truth, pred):
        score[(int(t), int(p))] += 1
    truth_table(tab, score)
    correct = score[(0, 0)] + score[(1, 1)]
    # Guard the division so an empty evaluation set reports 0% instead of raising.
    percent = 100 * correct / total if total else 0.0
    print(f'\nTRUE NEG plus TRUE POS: {correct:5} ({percent:7.4f}%)')

In [None]:
def truth_table(name, values, width=55):  # WS
    """Print a two-row confusion table.

    `name` and `values` are mappings keyed by the pairs (1, 1), (1, 0), (0, 1) and
    (0, 0), giving each cell's label and count. `width` is the length of the
    underscore rule printed between the two rows.
    """
    def row(left, right):
        # Format one table row: "<label>: <count>   |   <label>: <count>".
        return f'{name[left]:14}: {values[left]:5}   |   {name[right]:14}: {values[right]:5}'

    print('\n' + row((1, 1), (1, 0)))
    print('_' * width)
    print('\n' + row((0, 1), (0, 0)) + '\n')

In [None]:
evaluate(y[:cut_off], clf.predict(X[:cut_off]))  # training sample

In [None]:
evaluate(y[cut_off:], clf.predict(X[cut_off:]))  # validation sample

## LOOK AT SAMPLES OF ALL THINGS CALLED A COUNTRY IN A LARGER SET

In [None]:
model.vectors.shape

In [None]:
# note: if too many vectors are taken, RAM will overflow, given that the word2vec dbase is also in RAM
# the full 3000000 is too large to take at once
#all_predictions = clf.predict(model.vectors[:1000000]) # 1e6 takes 30s to run
batch_size = 100000  # takes 3 seconds for 100000
batch_predictions = clf.predict(model.vectors[:batch_size, :]) # WS mod

In [None]:
batch_predictions.shape, sum(batch_predictions)

In [None]:
type(model.index_to_key)

In [None]:
model.index_to_key[:10]

In [None]:
# Collect the words in the scored batch that the classifier labelled as countries.
# zip stops at the shorter input, so only the first len(batch_predictions) vocabulary
# entries are examined.
res = [word for word, pred in zip(model.index_to_key, batch_predictions) if pred]  # WS index_to_key replaces index2word
# Show up to 20 random hits; capping the sample size avoids a ValueError when fewer than
# 20 words were flagged. The country false-alarms show up mixed in with the true countries.
random.sample(res, min(20, len(res)))

In [None]:
len(res)

# START HERE

In [None]:
country_to_idx = {country['name']: idx for idx, country in enumerate(countries)}

In [None]:
countries[0]

In [None]:
country_names = [k['name'] for k in countries]
country_names.sort()
#country_names

In [None]:
country_vecs = np.asarray([model[c['name']] for c in countries])  # get vector for each country
country_vecs.shape

Quick sanity check to see what is similar to Canada:

In [None]:
dists = np.dot(country_vecs, country_vecs[country_to_idx['Canada']])  # (184, 300) dot (300,) => (184,)

In [None]:
country_vecs[0].shape, dists.shape

In [None]:
for idx in reversed(np.argsort(dists)[-10:]):
    print(countries[idx]['name'], dists[idx])

Ranking countries for a specific term:

In [None]:
def rank_countries(term, topn=10, field='name'):
    """Rank the countries by the dot product of their vectors with the vector of `term`.

    Args:
        term: vocabulary key to look up in `model`.
        topn: number of top-scoring countries to return. The selection is the slice
            `[-topn:]`, so `topn=0` returns every country.
        field: key of the country record to report (e.g. 'name', 'cc' or 'cc3').

    Returns:
        A list of (country[field], score) tuples in descending score order, or an
        empty list when `term` is not in `model`.
    """
    if term not in model:
        return []
    scores = np.dot(country_vecs, model[term])
    # argsort is ascending: keep the tail, then flip it so the best match comes first.
    best_first = np.argsort(scores)[-topn:][::-1]
    return [(countries[i][field], float(scores[i])) for i in best_first]

In [None]:
rank_countries('cricket', field='name')  # field also can be 'cc' or 'cc3'

Now let's visualize this on a world map:

In [None]:
# Load the 'naturalearth_lowres' dataset shipped with geopandas into a GeoDataFrame.
# NOTE(review): `gpd.datasets` was deprecated in GeoPandas 0.14 and removed in 1.0 —
# confirm the installed version still provides it, or read the shapefile from a local path.
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
world.head(20)

We can now plot some maps!

In [None]:
type(world)

In [None]:
#world['iso_a3'].map?

In [None]:
def map_term(term):
    """Colour a world map by how strongly each country's vector matches `term`.

    Adds a column named `term` to the global `world` GeoDataFrame holding each country's
    score from `rank_countries`, scaled so that the maximum is 1, and plots that column.
    Prints the number of scored countries. When `term` is unknown (no scores come back),
    nothing is added or plotted.
    """
    # topn=0 makes rank_countries return every country. The codes are upper-cased,
    # presumably to match the format of world['iso_a3'] — TODO confirm.
    d = {k.upper(): v for k, v in rank_countries(term, topn=0, field='cc3')}
    print(len(d))
    if len(d) > 0:  # WS added to handle empty results
        world[term] = world['iso_a3'].map(d)  # WS this creates a new column for the dataframe 'world'
        world[term] /= world[term].max()
        # Drop only the rows that have no score for this term; a bare dropna() would
        # also discard rows with a NaN in any unrelated column.
        world.dropna(subset=[term]).plot(term, cmap='OrRd')

In [None]:
map_term('United_States')

In [None]:
print(world.columns)

In [None]:
world.head(5)

In [None]:
world.dropna().plot('name', ) #cmap='OrRd')

In [None]:
world.dropna().plot('continent')

In [None]:
world.dropna().plot('gdp_md_est', cmap='prism')

In [None]:
#world.plot?

In [None]:
map_term('coffee')

In [None]:
print(world['coffee'].dropna())

In [None]:
map_term('cricket')

In [None]:
map_term('China')

In [None]:
map_term('vodka')

In [None]:
map_term('guns')

In [None]:
map_term('Panama_Canal')

In [None]:
map_term('Paris')

In [None]:
map_term('cancer')

In [None]:
map_term('accidents')

In [None]:
map_term('murder')

In [None]:
map_term('poor')

In [None]:
map_term('G7')

In [None]:
map_term('wealthy')

In [None]:
map_term('paradise')

In [None]:
map_term('WWII')