In [40]:
import numpy as np
import gensim
import nltk
from nltk.corpus import stopwords
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import math
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
word_model = gensim.models.KeyedVectors.load_word2vec_format('./GoogleNews-vectors-negative300.bin', binary=True)  
stop_words =  set(stopwords.words('english'))

In [114]:
def get_sentences(path):
    with open(path) as file:
        lines = file.read().replace('\n', '')
    
    return nltk.sent_tokenize(lines)

def filter_sentence(sentence):
    tokens = gensim.utils.simple_preprocess(sentence)#nltk.tokenize.word_tokenize(sentence)
    filtered = [w.lower() for w in tokens if not w in stop_words and w in word_model.vocab]
    return filtered

def get_sentence_embedding(sentence):
    vectors = np.zeros((len(sentence),300))
    for i, word in enumerate(sentence):
        embedding = word_model[word]
        assert np.all(np.isfinite(embedding))
        vectors[i] = embedding
    mean = np.mean(vectors, axis=0)
    return mean


In [26]:
def embed_text(path):
    sentences = get_sentences(path)
    embeddings = []
    for s in sentences:
        filtered = filter_sentence(s)
        if not len(filtered)==0:
            embedded = get_sentence_embedding(filtered)
            embeddings.append(embedded)
    
    embeddings_array = np.zeros((len(embeddings), 300))
    for i,e in enumerate(embeddings):
        embeddings_array[i] = e
    return embeddings_array

In [45]:
def reduce_embeddings(embeddings, components=150):
    
    covar =  PCA(n_components=components)
    reduced = covar.fit_transform(embeddings)
    return reduced

In [110]:
def mean_vector(vectors):
    #vectors are in a list
    assert not len(vectors)==0
    
    first = np.zeros(vectors[0].shape)
    for v in vectors:
        first+=v
    return first/(len(vectors))

def get_representatives(clusters, cluster_groups):
    
    representatives = []
    
    for i, cluster in enumerate(clusters):
        
        min_dist = np.inf
        index = -1
        for j, v in enumerate(cluster_groups[i]):
            dist = np.sum((v-cluster)**2)
            if dist < min_dist:
                min_dist=dist
                index = j
        representatives.append(cluster_groups[i][index])
    return representatives
    


In [111]:
def kmeans_clustering(vectors, num_clusters=10, iterations=10):
    
    #initialize clusters to random points
    clusters = vectors[np.random.choice(vectors.shape[0], num_clusters, replace=False)]
    new_clusters = clusters
    for j in range(iterations):
        clusters = new_clusters
        # list of vectors belonging to each cluster
        cluster_groups = [[] for cluster in clusters]
        for i in range(vectors.shape[0]):
            vector = vectors[i]

            min_dist = np.inf
            index = -1
            for i,cluster in enumerate(clusters):
                dist = np.sum( (vector-cluster)**2)
                if dist < min_dist:
                    min_dist = dist 
                    index = i
            
            # add to group of closest cluster
            cluster_groups[index].append(vector)
        print("Iteration:", j, [len(c) for c in cluster_groups])
        new_clusters = [mean_vector(group) for group in cluster_groups]
        
        diff = [new for i,new in enumerate(new_clusters) if not np.array_equal(new, clusters[i])]
        if len(diff)==0:
            break
    return clusters, cluster_groups

In [115]:
embeddings = reduce_embeddings(embed_text("bible.txt"), components=150)

In [106]:
clusters, groups = kmeans_clustering(embeddings, num_clusters=20, iterations=40)
reps = get_representatives(clusters, groups)

Iteration: 0 [275, 511, 306, 83, 19, 154, 58, 21, 7, 9, 143, 45, 198, 61, 1681, 168, 118, 152, 977, 181]
Iteration: 1 [325, 574, 384, 109, 42, 186, 170, 35, 13, 9, 189, 39, 213, 50, 1403, 137, 157, 186, 773, 173]
Iteration: 2 [263, 671, 404, 116, 44, 153, 214, 46, 33, 10, 303, 38, 289, 39, 1156, 96, 203, 206, 725, 158]
Iteration: 3 [186, 728, 429, 130, 50, 110, 212, 50, 46, 10, 388, 35, 320, 37, 1022, 78, 253, 232, 693, 158]
Iteration: 4 [144, 771, 444, 141, 54, 86, 212, 52, 55, 10, 413, 31, 337, 34, 961, 75, 289, 245, 649, 164]
Iteration: 5 [121, 782, 453, 156, 56, 76, 212, 52, 56, 10, 434, 32, 331, 32, 960, 75, 312, 242, 612, 163]
Iteration: 6 [109, 793, 452, 174, 53, 71, 209, 53, 60, 10, 449, 34, 324, 30, 981, 75, 320, 234, 575, 161]
Iteration: 7 [108, 811, 447, 182, 54, 67, 208, 55, 60, 10, 461, 34, 324, 27, 995, 75, 318, 219, 554, 158]
Iteration: 8 [109, 828, 445, 180, 54, 64, 207, 55, 58, 10, 467, 34, 320, 24, 1011, 75, 319, 211, 540, 156]
Iteration: 9 [109, 842, 447, 182, 54, 62

20
