In [None]:
import time
import json
import torch
from collections import defaultdict
from gensim.models import Word2Vec
from sklearn.manifold import TSNE

%matplotlib inline
import matplotlib.pyplot as plt

# Helpers

In [None]:
# Matplotlib colour name used for each topic when do_plotting() draws its
# scatter points (looked up as topic2color[topic]).
topic2color = {
    "sport": "royalblue",
    "colours": "red",
    "countries": "yellow",
    "cities": "green",
    "food-drink": "chocolate",
    "names": "aqua",
    "family-rel": "darkviolet",
    "vehicles": "dimgrey",
    "astro": "deeppink",
    "religion": "black",
}

# JSON file that get_dictionaries() parses into the `entries` mapping.
ENTRIES_PATH = "/content/gdrive/MyDrive/Colab_Notebooks/entries.json"

In [None]:
def get_dictionaries(word2ind_path="data/word2ind.json", entries_path=None):
  """Load the two JSON lookup tables used by the plotting helpers.

  Args:
    word2ind_path: Path of the JSON file parsed into ``word2ind``. Defaults to
      the relative path the notebook has always read.
    entries_path: Path of the JSON file parsed into ``entries``. When ``None``
      the module-level ``ENTRIES_PATH`` constant is used.

  Returns:
    A ``(word2ind, entries)`` tuple holding the parsed content of each file.

  Raises:
    OSError: If either file cannot be opened.
    json.JSONDecodeError: If either file does not contain valid JSON.
  """
  if entries_path is None:
    # Resolved at call time so re-running the configuration cell takes effect.
    entries_path = ENTRIES_PATH

  # JSON text is UTF-8; be explicit instead of relying on the platform locale.
  with open(word2ind_path, 'r', encoding="utf-8") as json_file:
    word2ind = json.load(json_file)

  with open(entries_path, 'r', encoding="utf-8") as json_file:
    entries = json.load(json_file)

  return word2ind, entries

In [None]:
def do_plotting(data, topics=None, width=20, height=15):
  """Draw one annotated scatter plot of the points in ``data``, coloured by topic.

  Args:
    data: Mapping from topic name to a sequence of ``(label, x, y)`` tuples.
    topics: Optional collection of topic names to draw. ``None`` draws every
      topic present in ``data``.
    width: Figure width, passed to matplotlib as the first ``figsize`` value.
    height: Figure height, passed to matplotlib as the second ``figsize`` value.

  Raises:
    KeyError: If a drawn topic has no entry in ``topic2color``.
  """
  _fig, ax = plt.subplots(figsize=(width, height))

  drew_any = False
  for topic, points in data.items():
    if topics is not None and topic not in topics:
      continue

    labels = [point[0] for point in points]
    x_pos = [point[1] for point in points]
    y_pos = [point[2] for point in points]

    ax.scatter(x_pos, y_pos, c=topic2color[topic], alpha=.7, label=topic)
    drew_any = True

    # Write each point's label next to its marker.
    for label, x, y in zip(labels, x_pos, y_pos):
      ax.annotate(label, (x, y))

  # Build the legend once, after every topic has been added, rather than
  # rebuilding it on each loop iteration. Skipped when nothing was drawn so an
  # empty selection still produces no legend call.
  if drew_any:
    ax.legend()

In [None]:
def get_plot_data(vectors, topics, iterations=1500, metric="cosine", random_state=None):
  """Project ``vectors`` to two dimensions with t-SNE and group points by topic.

  Args:
    vectors: Array-like of embedding vectors, one row per element of ``topics``.
    topics: Sequence of ``(topic, word)`` pairs aligned with ``vectors``.
    iterations: Number of optimisation iterations given to ``TSNE``.
    metric: Distance metric given to ``TSNE``.
    random_state: Seed forwarded to ``TSNE``. ``None`` (the default) leaves the
      layout different from run to run; pass an int for a reproducible plot.

  Returns:
    A ``defaultdict(list)`` mapping each topic to a list of ``(word, x, y)``
    tuples, where ``x`` and ``y`` are the two t-SNE coordinates.
  """
  # Local import: only needed for the keyword-compatibility check below.
  import inspect

  # scikit-learn 1.5 renamed TSNE's ``n_iter`` argument to ``max_iter`` and
  # later releases dropped the old spelling, so use whichever name this
  # installation accepts.
  tsne_params = inspect.signature(TSNE).parameters
  iter_kwarg = "max_iter" if "max_iter" in tsne_params else "n_iter"
  tsne = TSNE(metric=metric, random_state=random_state, **{iter_kwarg: iterations})
  embed = tsne.fit_transform(vectors)

  plot_data = defaultdict(list)
  for ind, topic in enumerate(topics):
    plot_data[topic[0]].append((topic[1], embed[ind, 0], embed[ind, 1]))

  return plot_data

In [None]:
def get_vectors_topics(weights, is_gen=False):
  """Collect one embedding vector and one ``(topic, word)`` pair per entry.

  The words and their topics come from ``get_dictionaries()``.

  Args:
    weights: Source of the vectors. With ``is_gen=False`` it is indexed by the
      integer stored in ``word2ind`` and each row must support
      ``.numpy().tolist()``. With ``is_gen=True`` it is looked up by word,
      through its ``.wv`` attribute when it has one, otherwise directly.
    is_gen: Select the word-keyed lookup instead of the index-based one.

  Returns:
    ``(vectors, topics)``: two parallel lists. ``vectors`` holds one vector per
    entry and ``topics`` holds the matching ``(topic, word)`` tuples.

  Raises:
    KeyError: If an entry word is missing from ``word2ind`` (index-based path).
  """
  word2ind, entries = get_dictionaries()

  # Use the object passed in by the caller. Reading a notebook-level global
  # here would return the same model's vectors regardless of the argument.
  keyed_vectors = getattr(weights, "wv", weights) if is_gen else None

  vectors = []
  topics = []

  for word, topic in entries.items():
    if is_gen:
      vectors.append(keyed_vectors[word])
    else:
      vectors.append(weights[word2ind[word]].numpy().tolist())

    topics.append((topic, word))

  return vectors, topics

# Skip-Gram 

In [None]:
# NOTE(review): torch.load deserialises the file's contents, so only load
# checkpoints that come from a trusted source.
# map_location="cpu" lets a checkpoint that was saved on a GPU load on a
# CPU-only runtime and keeps the result usable with the .numpy() call made in
# get_vectors_topics().
sg_weights = torch.load("../weights/sg300_ep15.pt", map_location="cpu")

sg_vec, sg_top = get_vectors_topics(sg_weights)

sg_plot_data = get_plot_data(sg_vec, sg_top)

In [None]:
# Every topic, on the default 20x15 figure.
do_plotting(data=sg_plot_data)

In [None]:
# Only three topics, on a smaller figure.
do_plotting(
    sg_plot_data,
    topics=["sport", "colours", "vehicles"],
    width=12,
    height=8,
)

# CBOW

In [None]:
# NOTE(review): torch.load deserialises the file's contents, so only load
# checkpoints that come from a trusted source.
# map_location="cpu" lets a checkpoint that was saved on a GPU load on a
# CPU-only runtime and keeps the result usable with the .numpy() call made in
# get_vectors_topics().
cbow_weights = torch.load("../weights/cbow300_ep15.pt", map_location="cpu")

cbow_vec, cbow_top = get_vectors_topics(cbow_weights)

cbow_plot_data = get_plot_data(cbow_vec, cbow_top)

In [None]:
# Every topic, on the default 20x15 figure.
do_plotting(data=cbow_plot_data)

In [None]:
# Only three topics, on a smaller figure.
do_plotting(
    cbow_plot_data,
    topics=["sport", "vehicles", "colours"],
    width=12,
    height=8,
)

# Gensim Skip-Gram

In [None]:
# Load the saved gensim skip-gram model and send it through the gensim
# (is_gen=True) path of get_vectors_topics().
gen_sg = Word2Vec.load("../weights/gen_sg300.txt")
gen_sg_vec, gen_sg_top = get_vectors_topics(gen_sg, is_gen=True)
gen_sg_plot_data = get_plot_data(vectors=gen_sg_vec, topics=gen_sg_top)

In [None]:
# Every topic, on the default 20x15 figure.
do_plotting(data=gen_sg_plot_data)

In [None]:
# Only three topics, on a smaller figure.
do_plotting(
    gen_sg_plot_data,
    topics=["colours", "sport", "vehicles"],
    width=12,
    height=8,
)

# Gensim CBOW

In [None]:
# Load the saved gensim CBOW model and send it through the gensim
# (is_gen=True) path of get_vectors_topics().
gen_cbow_model = Word2Vec.load("../weights/gen_cbow300.txt")
gen_cbow_vec, gen_cbow_top = get_vectors_topics(gen_cbow_model, is_gen=True)
gen_cbow_plot_data = get_plot_data(vectors=gen_cbow_vec, topics=gen_cbow_top)

In [None]:
# Every topic, on the default 20x15 figure.
do_plotting(data=gen_cbow_plot_data)

In [None]:
# Only three topics, on a smaller figure.
do_plotting(
    gen_cbow_plot_data,
    topics=["sport", "colours", "vehicles"],
    width=12,
    height=8,
)