In [1]:
import os
import json
from time import time
from collections import Counter

import numpy as np

import torch
from allennlp.data.vocabulary import Vocabulary

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
import matplotlib.pyplot as plt
from matplotlib import offsetbox
from sklearn import (manifold, datasets, decomposition, ensemble,
                     discriminant_analysis, random_projection, neighbors)
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN, KMeans

from bokeh.io import output_notebook, show
from bokeh.plotting import figure
from bokeh.models import HoverTool, PointDrawTool
from bokeh.models import ColumnDataSource
from bokeh import palettes
import bokeh.models as bmo

# Configure Bokeh so that subsequent show() calls render inline in the notebook output.
output_notebook()

## Functions

In [53]:
def plot_embedding(output_emb, output_desc, ref_emb=None, ref_desc=None, output_class=None):
    """Project embeddings to 2-D and show them as an interactive Bokeh scatter plot.

    The embeddings are standardised per feature and reduced to two dimensions
    with a seeded sparse random projection. When reference embeddings are
    given, they are scaled and projected *together* with the output
    embeddings so that both sets live in the same 2-D coordinate system.

    Parameters
    ----------
    output_emb : 2-D array-like, one row per item
        Embeddings to plot as circles.
    output_desc : sequence
        One label per row of ``output_emb``; drawn next to each point and
        shown in the hover tooltip.
    ref_emb : 2-D array-like, optional
        Reference embeddings with the same number of columns as ``output_emb``.
    ref_desc : sequence, optional
        One label per row of ``ref_emb``. Reference points are only drawn
        when both ``ref_emb`` and ``ref_desc`` are provided.
    output_class : sequence of numbers, optional
        One class id per output point, used to colour the circles. Colours
        are mapped linearly from 0 to ``max(output_class)``; values below 0
        are drawn in black.

    Returns
    -------
    None
        The figure is displayed with ``show``.
    """
    scaler = StandardScaler()
    projector = random_projection.SparseRandomProjection(
        n_components=2, random_state=42)

    if ref_emb is None:
        output_2d = projector.fit_transform(scaler.fit_transform(output_emb))
        ref_2d = None
    else:
        # Split on the number of embedding rows rather than the number of
        # labels, so a mismatched label list cannot shift the boundary.
        n_output = len(output_emb)
        joint = scaler.fit_transform(np.concatenate((output_emb, ref_emb), 0))
        joint = projector.fit_transform(joint)
        output_2d = joint[:n_output]
        # Keep the reference points in the shared projected space: scaling
        # them again on their own would move them relative to the outputs.
        ref_2d = joint[n_output:]

    show_ref = ref_2d is not None and ref_desc is not None

    glyph_kwargs = {}
    output_columns = {
        'x': output_2d[:, 0],
        'y': output_2d[:, 1],
        'desc': output_desc,
    }

    if output_class is not None:
        color_map = bmo.LinearColorMapper(
            high=max(output_class),
            low=0, low_color='black',
            palette=palettes.Turbo256)

        glyph_kwargs['color'] = {'field': 'labels', 'transform': color_map}
        output_columns['labels'] = output_class

    output_source = ColumnDataSource(data=output_columns)

    if show_ref:
        ref_source = ColumnDataSource(data={
            'x': ref_2d[:, 0],
            'y': ref_2d[:, 1],
            'desc': ref_desc,
        })

    hover = HoverTool(
            tooltips=[
                ("index", "$index"),
                ("desc", "@desc"),
            ])

    p = figure(plot_width=700, plot_height=600,
               tools=[hover,"pan,reset,box_zoom"],
               title="Mouse over the dots")

    point_o = p.circle('x', 'y', size=10, source=output_source,
                       **glyph_kwargs)
    text_o = p.text('x', 'y', text='desc', source=output_source,
                    x_offset=-10, text_font_size={'value': '8pt'})

    if show_ref:
        p.inverted_triangle('x', 'y', size=15, fill_color='green', alpha=0.5,
                            source=ref_source)
        p.text('x', 'y', text='desc', source=ref_source,
               x_offset=-10, text_font_size={'value': '8pt'})

    # Allow the output points (and their labels) to be dragged around.
    pointdraw = PointDrawTool(renderers=[point_o, text_o])
    p.add_tools(pointdraw)

    show(p)

In [71]:
def i2t(i):
    """Index to token: look up index `i` in the global vocabulary `v`."""
    token = v.get_token_from_index(i)
    return token
    
def t2i(t):
    """Token to index: look up token `t` in the global vocabulary `v`."""
    index = v.get_token_index(t)
    return index

def embed(i, em):
    """Return the entry of `em` for `i`.

    `i` may be an index, used as-is, or a token string, which is first
    converted to its index with `t2i`.
    """
    index = t2i(i) if isinstance(i, str) else i
    return em[index]

def embed_kp(kp, em):
    """Embed a keyphrase given as a string of words separated by single spaces.

    Every word is embedded with `embed` and the word vectors are reduced
    with an element-wise maximum over the words (max pooling).
    NOTE(review): the previous comment mentioned a mean, but the code has
    always taken the max; confirm which pooling is intended.
    """
    word_vectors = torch.stack([embed(word, em) for word in kp.split(' ')], 0)
    pooled, _ = word_vectors.max(0)
    return pooled

def embed_kps(kps, em):
    """Build a tensor with one row per keyphrase found in `kps`.

    `kps` is an iterable of groups, each group being an iterable of
    keyphrase strings; the groups are flattened in order and every phrase
    is embedded with `embed_kp`.
    """
    rows = []
    for group in kps:
        for phrase in group:
            rows.append(embed_kp(phrase, em))
    return torch.stack(rows, 0)

def closest(token, em, n=10):
    """Return the `n` entries of `em` most cosine-similar to `token`.

    `token` may be a token string or an index (see `embed`). The result is
    a list of `(token_string, similarity)` pairs in decreasing similarity,
    where each similarity is a zero-dimensional tensor.
    """
    query = embed(token, em)
    similarities = torch.cosine_similarity(query, em, -1)
    top_sim, top_idx = similarities.topk(n)
    return [(i2t(index), sim) for sim, index in zip(top_sim, top_idx.tolist())]

## Load data

In [34]:
# Weight stored under '_target_embedder.weight' in the checkpoint, loaded onto the CPU.
# NOTE(review): torch.load unpickles the file, so only open checkpoints from a trusted source.
embeddings = torch.load('model_state_epoch_4.th', map_location='cpu')['_target_embedder.weight']
# Disabled alternative: read the weight stored under the source embedder key instead.
#sembeddings = torch.load('model_state_epoch_4.th', map_location='cpu')['_source_embedder.token_embedder_tokens.weight']
# Vocabulary used by i2t/t2i to convert between tokens and indices.
# NOTE(review): assumes vocabulary indices line up with the rows of `embeddings` -- confirm.
v = Vocabulary.from_files('vocabulary')

In [6]:
# Keyphrases to analyse, keyed by document id (each value is a list of
# groups of keyphrase strings, as consumed by embed_kps).
with open('KP20k.test.json', encoding='utf-8') as f:
    output = json.load(f)

# Root of the local ake-datasets checkout. Set the AKE_DATASETS_DIR environment
# variable to run the notebook on another machine; the default keeps the
# original location.
AKE_DATASETS_DIR = os.environ.get(
    'AKE_DATASETS_DIR', '/Users/ygorgallina/Documents/Prog/ake-datasets')

# Reference keyphrases for the same documents, in the same layout as `output`.
with open(os.path.join(AKE_DATASETS_DIR, 'datasets/KP20k/references/test.author.json'),
          encoding='utf-8') as f:
    ref = json.load(f)

In [66]:
# Key of the document to inspect; used to index both `output` and `ref`.
doc_id = '00001'

In [72]:
# Labels of the selected document's keyphrases, flattened group by group
# (the same order embed_kps uses for its rows).
output_lab = [phrase for group in output[doc_id] for phrase in group]
# One embedding row per label above.
output_emb = embed_kps(output[doc_id], embeddings)

In [73]:
# Labels of the selected document's reference keyphrases, flattened group by
# group (the same order embed_kps uses for its rows).
ref_lab = [phrase for group in ref[doc_id] for phrase in group]
# One embedding row per label above.
ref_emb = embed_kps(ref[doc_id], embeddings)

In [74]:
# Cluster the output keyphrase embeddings.
# Standardise every embedding dimension first, so `eps` is measured in scaled units.
output_emb_scaled = StandardScaler().fit_transform(output_emb)
# With min_samples=1 every sample counts as a core point, so DBSCAN should not
# emit the noise label (-1). `classes` is a plain list with one cluster id per row.
classes = DBSCAN(eps=4, min_samples=1).fit(output_emb_scaled).labels_.tolist()

In [75]:
# Cluster sizes: cluster id -> number of keyphrases assigned to it.
print(Counter(classes))
# Plot the output keyphrases coloured by cluster, with the reference keyphrases overlaid.
plot_embedding(output_emb, output_lab, ref_emb=ref_emb, ref_desc=ref_lab, output_class=classes)

Counter({2: 13, 0: 1, 1: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 11: 1, 12: 1, 13: 1, 14: 1, 15: 1, 16: 1, 17: 1, 18: 1})


In [40]:
# Group the keyphrase labels by the cluster id assigned to them.
clusters = {}
for position, cluster_id in enumerate(classes):
    clusters.setdefault(cluster_id, []).append(output_lab[position])
clusters

{0: ['classification'],
 1: ['sequential constraint optimization'],
 2: ['principal component analysis'],
 3: ['latent variable model'],
 4: ['generative models'],
 5: ['support vector machine'],
 6: ['latent variables'],
 7: ['handwritten digit recognition'],
 8: ['support vector machine learning',
  'support vector machine learning algorithm'],
 9: ['machine learning'],
 10: ['data-point'],
 11: ['constraint optimization'],
 12: ['perceptron learning'],
 13: ['satellite image recognition'],
 14: ['perceptron learning algorithm'],
 15: ['digit recognition'],
 16: ['multiple generative models'],
 17: ['image recognition'],
 18: ['pattern recognition'],
 19: ['learning algorithm'],
 20: ['satellite images'],
 21: ['learning'],
 22: ['support vector regression'],
 23: ['use'],
 24: ['standard perceptron learning'],
 25: ['principal component analysis ( pca )'],
 26: ['latent component analysis'],
 27: ['variable model'],
 28: ['image classification'],
 29: ['latent model'],
 30: ['model'