In [17]:
from keras.models import Model
from keras.layers import Input, Dense, Reshape, concatenate, dot
from keras.layers.embeddings import Embedding
from keras.preprocessing.sequence import skipgrams
from keras.preprocessing import sequence

import urllib.request
import collections
import os
import zipfile
import pickle

import numpy as np
import tensorflow as tf
import itertools
import datetime
import sys

In [2]:
# load pkl file
def load_obj(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)
    
dictionary = load_obj('dictionary.pkl')
reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) 

In [20]:
time_window = 14
window_size = 100
window = np.zeros(window_size,dtype=int)
sampling_table = sequence.make_sampling_table(554,sampling_factor = 1e-05)
word_target, labels, word_context = [], [], []

for l in range(40):
    sentences = []
    print(l,datetime.datetime.now())
    loadata = np.load('slice_data/0317data'+str(l)+'.npz')
    x=loadata['InputX3D']
    for p in range(len(x)):
        
        patient_sentences = []
        for i in range(len(x[p]-time_window)):
            xx = np.multiply(np.minimum(np.sum(x[p,i:i+time_window,],0),1),np.arange(553)+1)
            if np.sum(xx) > 0:
                patient_sentences.append(xx[xx != 0].tolist())
        k = sorted(patient_sentences)
        for key in list(k for k,_ in itertools.groupby(k)):
            sentences.extend(key)
            sentences.extend(window)
    
    local_couples, local_labels = skipgrams(sentences, 553, window_size=window_size, sampling_table=sampling_table)
    local_word_target, local_word_context = zip(*local_couples)
    labels.extend(local_labels)
    word_target.extend(local_word_target)
    word_context.extend(local_word_context)
    
word_target = np.array(word_target, dtype="int32")
word_context = np.array(word_context, dtype="int32")

0 2018-06-11 22:29:06.002176
1 2018-06-11 22:33:46.630312
2 2018-06-11 22:38:38.123163
3 2018-06-11 22:43:35.201380
4 2018-06-11 22:48:33.851698
5 2018-06-11 22:53:30.058311
6 2018-06-11 22:58:50.262898
7 2018-06-11 23:04:05.637833
8 2018-06-11 23:09:34.540393
9 2018-06-11 23:15:15.046381
10 2018-06-11 23:21:14.058447
11 2018-06-11 23:27:05.498806
12 2018-06-11 23:32:44.101493
13 2018-06-11 23:39:06.088560
14 2018-06-11 23:45:17.341725
15 2018-06-11 23:51:57.090448
16 2018-06-11 23:58:43.435506
17 2018-06-12 00:05:51.585762
18 2018-06-12 00:13:03.665627
19 2018-06-12 00:20:59.391889
20 2018-06-12 00:28:12.679750
21 2018-06-12 00:35:51.798155
22 2018-06-12 00:43:41.048729
23 2018-06-12 00:51:01.780385
24 2018-06-12 00:59:12.913222
25 2018-06-12 01:07:12.122198
26 2018-06-12 01:15:27.835303
27 2018-06-12 01:24:34.573588
28 2018-06-12 01:32:22.534787
29 2018-06-12 01:41:09.960389
30 2018-06-12 01:50:11.899594
31 2018-06-12 01:59:09.960091
32 2018-06-12 02:07:36.834486
33 2018-06-12 02:16:

In [19]:
sys.getsizeof(labels)

273897072

In [None]:
print(datetime.datetime.now())
sampling_table = sequence.make_sampling_table(554,sampling_factor = 1e-05)
couples, labels = skipgrams(sentences, 553, window_size=window_size, sampling_table=sampling_table)
print(datetime.datetime.now())

2018-06-11 20:04:12.573589


In [120]:
word_target, word_context = zip(*couples)
word_target = np.array(word_target, dtype="int32")
word_context = np.array(word_context, dtype="int32")

In [21]:
valid_size = 16     # Random set of words to evaluate similarity on.
valid_window = 100  # Only pick dev samples in the head of the distribution.
valid_examples = np.random.choice(valid_window, valid_size, replace=False)
vocab_size = 553

In [22]:
vector_dim = 100
epochs = 5000000

# create some input variables
input_target = Input((1,))
input_context = Input((1,))

embedding = Embedding(553, vector_dim, input_length=1, name='embedding')
target = embedding(input_target)
target = Reshape((vector_dim, 1))(target)
context = embedding(input_context)
context = Reshape((vector_dim, 1))(context)

# setup a cosine similarity operation which will be output in a secondary model
similarity = dot([target, context], normalize=True, axes=0)

# now perform the dot product operation to get a similarity measure
dot_product = dot([target, context], normalize=False, axes=1)
dot_product = Reshape((1,))(dot_product)
# add the sigmoid output layer
output = Dense(1, activation='sigmoid')(dot_product)
# create the primary training model
model = Model(input=[input_target, input_context], output=output)
model.compile(loss='binary_crossentropy', optimizer='rmsprop')

# create a secondary validation model to run our similarity checks during training
validation_model = Model(input=[input_target, input_context], output=similarity)




In [23]:
class SimilarityCallback:
    def run_sim(self):
        for i in range(valid_size):
            valid_word = reverse_dictionary[valid_examples[i]]
            top_k = 8  # number of nearest neighbors
            sim = self._get_sim(valid_examples[i])
            nearest = (-sim).argsort()[1:top_k + 1]
            log_str = 'Nearest to %s:' % valid_word
            for k in range(top_k):
                close_word = reverse_dictionary[nearest[k]]
                log_str = '%s %s,' % (log_str, close_word)
            print(log_str)

    @staticmethod
    def _get_sim(valid_word_idx):
        sim = np.zeros((vocab_size,))
        in_arr1 = np.zeros((1,))
        in_arr2 = np.zeros((1,))
        in_arr1[0,] = valid_word_idx
        for i in range(vocab_size):
            in_arr2[0,] = i
            out = validation_model.predict_on_batch([in_arr1, in_arr2])
            sim[i] = out
        return sim
sim_cb = SimilarityCallback()

In [24]:
arr_1 = np.zeros((1,))
arr_2 = np.zeros((1,))
arr_3 = np.zeros((1,))
for cnt in range(epochs):
    idx = np.random.randint(0, len(labels)-1)
    arr_1[0,] = word_target[idx]
    arr_2[0,] = word_context[idx]
    arr_3[0,] = labels[idx]
    loss = model.train_on_batch([arr_1, arr_2], arr_3)
    if cnt % 10000 == 0:
        print("Iteration {}, loss={}".format(cnt, loss))
    if cnt % 100000 == 0:
        print(datetime.datetime.now())
     #   sim_cb.run_sim()

Iteration 0, loss=0.7138327956199646
2018-06-12 03:22:36.791663
Iteration 10000, loss=0.5671801567077637
Iteration 20000, loss=3.672576665878296
Iteration 30000, loss=0.0007565499399788678
Iteration 40000, loss=2.4368042945861816
Iteration 50000, loss=0.20076459646224976
Iteration 60000, loss=0.04349581152200699
Iteration 70000, loss=0.15330541133880615
Iteration 80000, loss=0.8756828308105469
Iteration 90000, loss=14.332947731018066
Iteration 100000, loss=0.1488679051399231
2018-06-12 03:25:59.701746
Iteration 110000, loss=1.4145479202270508
Iteration 120000, loss=1.4139857292175293
Iteration 130000, loss=0.2694445252418518
Iteration 140000, loss=0.2684411406517029
Iteration 150000, loss=0.10923115164041519
Iteration 160000, loss=0.00037677225191146135
Iteration 170000, loss=0.0016777870478108525
Iteration 180000, loss=0.26924002170562744
Iteration 190000, loss=0.1768321394920349
Iteration 200000, loss=0.3044373393058777
2018-06-12 03:29:20.094175
Iteration 210000, loss=0.285367667675

Iteration 1770000, loss=6.327513694763184
Iteration 1780000, loss=1.192093321833454e-07
Iteration 1790000, loss=2.7557571229408495e-06
Iteration 1800000, loss=0.2772718667984009
2018-06-12 04:23:32.913686
Iteration 1810000, loss=0.09646811336278915
Iteration 1820000, loss=0.00014606586773879826
Iteration 1830000, loss=1.6344839334487915
Iteration 1840000, loss=8.398317337036133
Iteration 1850000, loss=0.08305524289608002
Iteration 1860000, loss=0.1779700368642807
Iteration 1870000, loss=0.08573295176029205
Iteration 1880000, loss=1.192093321833454e-07
Iteration 1890000, loss=1.0320583581924438
Iteration 1900000, loss=0.0005526742897927761
2018-06-12 04:27:02.749920
Iteration 1910000, loss=1.192093321833454e-07
Iteration 1920000, loss=0.011460645124316216
Iteration 1930000, loss=4.204106330871582
Iteration 1940000, loss=0.03797717019915581
Iteration 1950000, loss=6.0679394664475694e-05
Iteration 1960000, loss=14.55609130859375
Iteration 1970000, loss=0.0932709202170372
Iteration 1980000

Iteration 3510000, loss=2.526118755340576
Iteration 3520000, loss=1.4572294276149478e-05
Iteration 3530000, loss=1.192093321833454e-07
Iteration 3540000, loss=9.179154403682332e-06
Iteration 3550000, loss=0.003248857334256172
Iteration 3560000, loss=0.00031461307662539184
Iteration 3570000, loss=0.40845948457717896
Iteration 3580000, loss=0.005628319922834635
Iteration 3590000, loss=0.6214299201965332
Iteration 3600000, loss=15.942384719848633
2018-06-12 05:24:46.177815
Iteration 3610000, loss=5.545942783355713
Iteration 3620000, loss=0.0006693457835353911
Iteration 3630000, loss=0.008198600262403488
Iteration 3640000, loss=1.192093321833454e-07
Iteration 3650000, loss=11.99709415435791
Iteration 3660000, loss=0.00031464279163628817
Iteration 3670000, loss=0.008647171780467033
Iteration 3680000, loss=1.0000001537946446e-07
Iteration 3690000, loss=13.377436637878418
Iteration 3700000, loss=0.5926105976104736
2018-06-12 05:28:01.391837
Iteration 3710000, loss=1.311302526119107e-06
Iterat

In [25]:
model.save('skipgram1.h5')

In [139]:
# 60000 patients' data, 500000 epochs, 100 vector dimension
#model.save('skipgram.h5')
from keras.models import load_model
model = load_model('skipgram.h5')
a1 = model.layers[2].get_weights()[0]
a1

array([[-0.00770117,  0.00796707,  0.02193469, ...,  0.02107811,
         0.03270494,  0.04702425],
       [ 1.275105  ,  1.24389184,  1.26958609, ...,  1.2521956 ,
        -1.19423676,  1.1561631 ],
       [-1.85333097, -1.7782414 , -1.83254027, ..., -1.82702792,
         1.85137808, -1.85270727],
       ..., 
       [-5.20829821, -4.9914012 , -5.27518845, ..., -5.30610037,
         5.27674961, -5.27119541],
       [ 0.2422433 ,  0.14661992,  0.20882246, ...,  0.16810475,
        -0.21586379,  0.24917722],
       [ 0.04168575,  0.14598228,  0.07426101, ...,  0.14002229,
        -0.07735883,  0.10723677]], dtype=float32)

In [26]:
# 400000 patients' data, 5000000 epochs, 100 vector dimension
#model.save('skipgram1.h5')
from keras.models import load_model
model = load_model('skipgram.h5')
a1 = model.layers[2].get_weights()[0]
a1

array([[-0.00770117,  0.00796707,  0.02193469, ...,  0.02107811,
         0.03270494,  0.04702425],
       [ 1.275105  ,  1.24389184,  1.26958609, ...,  1.2521956 ,
        -1.19423676,  1.1561631 ],
       [-1.85333097, -1.7782414 , -1.83254027, ..., -1.82702792,
         1.85137808, -1.85270727],
       ..., 
       [-5.20829821, -4.9914012 , -5.27518845, ..., -5.30610037,
         5.27674961, -5.27119541],
       [ 0.2422433 ,  0.14661992,  0.20882246, ...,  0.16810475,
        -0.21586379,  0.24917722],
       [ 0.04168575,  0.14598228,  0.07426101, ...,  0.14002229,
        -0.07735883,  0.10723677]], dtype=float32)