In [94]:
import json
from contextlib import closing

import numpy as np
import psycopg2
import spacy
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from gensim.test.utils import common_texts
from sklearn.cluster import KMeans
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

from bokeh.io import output_notebook, show
from bokeh.layouts import row, column
from bokeh.models import ColumnDataSource, OpenURL, TapTool
from bokeh.palettes import d3, brewer
from bokeh.plotting import figure, ColumnDataSource
# Direct Bokeh's show() output inline into this notebook instead of a separate HTML file.
output_notebook()

## Document loading

In [95]:
# Fetch project abstracts from the local IKON Postgres database.
# NOTE(review): the password path is relative to this notebook's directory, so the cell only
# runs when the notebook sits in its usual place next to the IKON-backend checkout.
MAX_DOCUMENTS = 10000  # cap on rows fetched; there is no ORDER BY, so which rows come back is up to Postgres

with open('../../../../IKON-backend/assets/IKON-backend-config/secrets/postgres_password') as password_file:
    password = password_file.read().strip()

# closing() guarantees the connection is released. psycopg2's own `with conn:` only
# commits/rolls back the transaction and leaves the connection open.
with closing(psycopg2.connect(dbname="ikon", user="ikonuser", password=password, port=5432, host='localhost')) as conn:
    with conn.cursor() as cur:  # the cursor is closed on leaving this block
        cur.execute("SELECT project_abstract FROM projects LIMIT %s;", (MAX_DOCUMENTS,))
        rows = cur.fetchall()

# Skip NULL / empty abstracts; they give the model nothing to embed.
texts = [row[0] for row in rows if row[0]]

In [96]:
# German spaCy pipeline, used here only to produce lemmas. The dependency parser and the
# named-entity recognizer do not contribute to lemmas, so they are disabled to avoid
# running them on every abstract.
# NOTE(review): 'de' is a spaCy 2.x shortcut link; spaCy 3 needs a full package name such
# as 'de_core_news_sm'.
nlp = spacy.load('de', disable=['parser', 'ner'])


def lemmatize(text: str) -> list:
    """Return the lemma of every spaCy token in ``text``, in token order.

    No tokens are filtered out, so punctuation is included and the result has one
    entry per token produced by ``nlp``.
    """
    return [token.lemma_ for token in nlp(text)]

## Training

In [97]:
# One TaggedDocument per abstract, tagged with its position in `texts`, so document vector i
# corresponds to texts[i].
documents = [TaggedDocument(lemmatize(doc), [i]) for i, doc in enumerate(texts)]

# Build the vocabulary only. Passing `documents` to the constructor would also train the
# model, and the explicit model.train(...) call that follows would then run a second
# full pass of `epochs` iterations over the corpus.
model = Doc2Vec(vector_size=20, window=2, min_count=1, workers=4, epochs=40)
model.build_vocab(documents)

In [98]:
# Train for model.epochs passes over all documents; %time reports the CPU and wall-clock cost.
# NOTE(review): if `model` was constructed with `documents`, gensim already trained it once at
# construction time and this call trains it again.
%time model.train(documents, total_examples=model.corpus_count, epochs=model.epochs)

CPU times: user 2min 8s, sys: 3.98 s, total: 2min 12s
Wall time: 51.5 s


## Embedding, clustering and 2-D projection

In [113]:
# Document vectors, one row per training tag. gensim 4 moved them to `model.dv.vectors`;
# fall back to the gensim 3 location so the cell runs on either version.
embedded = model.dv.vectors if hasattr(model, 'dv') else model.docvecs.vectors_docs

N_CLUSTERS = 5  # must be between 3 and 10 for the 2-D LDA projection and the Category10 palette
# Fixed random_state so cluster ids (and therefore plot colours) are reproducible across runs.
clusters = KMeans(n_clusters=N_CLUSTERS, random_state=0).fit(embedded)

# Supervised 2-D projection that best separates the k-means clusters; used only for plotting.
# LDA allows at most n_classes - 1 components, so 2 components need at least 3 clusters.
# `LDA` is the alias under which LinearDiscriminantAnalysis is imported.
lda = LDA(n_components=2)
plane = lda.fit(embedded, clusters.labels_).transform(embedded)

In [114]:
# Colour each projected document by its k-means cluster. The Category10 palette exists in
# sizes 3-10; size it from the fitted model so it cannot drift from the cluster count.
colours = d3['Category10'][clusters.n_clusters]
source = ColumnDataSource(data=dict(
    x=plane[:, 0],
    y=plane[:, 1],
    colours=np.array(colours)[clusters.labels_],
    labels=clusters.labels_
))

# Scatter plot of the LDA plane.
# NOTE(review): newer Bokeh releases deprecate/remove `plot_width`/`plot_height` (use
# `width`/`height`) and the `legend=` kwarg (use `legend_field=`); update these when upgrading.
scatter = figure(
    plot_width=800,
    plot_height=800,
    title='Project abstracts: Doc2Vec embeddings projected by LDA, coloured by k-means cluster',
    x_axis_label='LD1',
    y_axis_label='LD2',
    toolbar_location="below",
    tools='tap,pan,wheel_zoom,save',
)
scatter.scatter('x', 'y', size=10, color='colours', legend='labels', source=source)

show(scatter)

In [None]:
# Sanity check that every fetched abstract received exactly one cluster label, then show
# how many documents fall into each cluster (a bare `print` here only echoed the builtin).
cluster_sizes = np.bincount(clusters.labels_)
assert cluster_sizes.sum() == len(texts), "cluster labels do not cover every document"
cluster_sizes