In [2]:
import sys
import pickle
import numpy as np
# DyNet options are passed through sys.argv and presumably read by dy.init(),
# so the memory option (--dynet_mem 5000) is injected before the import/init below.
# NOTE(review): re-running this cell appends the options to sys.argv again.
sys.argv.append('--dynet_mem')
sys.argv.append('5000')

import os
# Expose only GPU 3 to this process; this has to happen before `_gdynet`
# (presumably the GPU build of DyNet) is imported.
os.environ["CUDA_VISIBLE_DEVICES"]="3"
import _gdynet as dy
dy.init()


In [3]:
# Dataset version: selects mimic/<v>.p.
v = '2'
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('mimic/'+v+'.p',mode='rb') as fp:
    x,y,all_words,all_labels = pickle.load(fp)

# ICD9 code -> description, read from a tab-separated file.
# The file is opened with `with` so the handle is closed (the original leaked it).
desc_dict = {}
with open('mimic/2/ICD9_descriptions') as desc_file:
    for line in desc_file:
        if not line.strip():
            # Skip blank lines (e.g. a trailing newline), which would otherwise raise IndexError.
            continue
        line = line.strip().split('\t')
        desc_dict[line[0]] = line[1]

# The last TEST_SIZE documents are held out for evaluation/visualisation.
TEST_SIZE = 2282
train_set = list(zip(x[:-TEST_SIZE], y[:-TEST_SIZE]))
test_set = list(zip(x[-TEST_SIZE:], y[-TEST_SIZE:]))
VOCAB_SIZE = len(all_words)

In [4]:
# Model hyper-parameters: embedding size and GRU depth/width at the word and
# sentence level.
embeddings_size = 50
word_gru_layers = 1
word_gru_state_size = 50
sent_gru_layers = 1
sent_gru_state_size = 50

model = dy.Model()

# Word level (shared by all labels): embedding table, word GRU, and attention
# parameters used to score each word state as v * tanh(W1 * h).
embeddings = model.add_lookup_parameters((VOCAB_SIZE, embeddings_size))
word_gru_builder = dy.GRUBuilder(word_gru_layers, embeddings_size, word_gru_state_size, model)
word_attention_w1 = model.add_parameters((word_gru_state_size, word_gru_state_size))
# NOTE(review): leftover placeholder for an unused second attention matrix ("word_w2").
word_attention_v = model.add_parameters((1, word_gru_state_size))

# Sentence level: a GRU over the attention-pooled sentence vectors.
sent_gru_builder = dy.GRUBuilder(sent_gru_layers, word_gru_state_size, sent_gru_state_size, model)

# Per-label parameters, kept in lists parallel to all_labels: a sentence
# attention (W1, v) and a 2-way softmax classifier (W, b). Index 1 of the
# softmax is treated as "label present" by the prediction code.
sent_attention_w1s = []
sent_attention_vs = []
classifier_ws = []
classifier_bs = []

# NOTE(review): parameters are created interleaved per label. model.load()
# further down presumably restores parameters by creation order, so keep this
# order (and all shapes) identical to when the weights were saved.
for _ in all_labels:
    sent_attention_w1s.append(model.add_parameters((sent_gru_state_size, sent_gru_state_size)))
    # NOTE(review): leftover placeholder ("word_w2"), unused.
    sent_attention_vs.append(model.add_parameters((1, sent_gru_state_size)))
    
    classifier_ws.append(model.add_parameters((2, sent_gru_state_size)))
    classifier_bs.append(model.add_parameters((2)))
    

In [5]:
def get_probs(doc):
    """Run the hierarchical attention GRU on one document.

    Starts a fresh computation graph and encodes each sentence with the word
    GRU and word attention. It then runs the sentence GRU over the sentence
    vectors. Finally, for every label, it applies that label's sentence
    attention and binary classifier.

    Args:
        doc: sequence of sentences; each sentence is a sequence of word ids
            used to index ``embeddings``.

    Returns:
        A list with one DyNet expression per entry of ``all_labels``. Each is a
        2-way softmax; callers treat index 1 as "label present".
    """
    dy.renew_cg()

    def attend(outputs, w, v):
        # Additive attention: score_t = v * tanh(W h_t), softmax over t,
        # then the attention-weighted sum of the hidden states.
        weights = dy.softmax(dy.concatenate([v * dy.tanh(w * o) for o in outputs]))
        return dy.esum([vector * weight for vector, weight in zip(outputs, weights)])

    # Add the shared word-attention parameters to the graph once per document.
    # The original re-added them for every sentence, which bloated the graph.
    word_w = dy.parameter(word_attention_w1)
    word_v = dy.parameter(word_attention_v)

    encoded_sents = []
    for sent in doc:
        embedded_sent = [embeddings[word] for word in sent]
        states = word_gru_builder.initial_state().add_inputs(embedded_sent)
        encoded_sents.append(attend([s.output() for s in states], word_w, word_v))

    states = sent_gru_builder.initial_state().add_inputs(encoded_sents)
    sent_outputs = [s.output() for s in states]

    all_probs = []
    for i in range(len(all_labels)):
        sent_vector = attend(sent_outputs,
                             dy.parameter(sent_attention_w1s[i]),
                             dy.parameter(sent_attention_vs[i]))
        w = dy.parameter(classifier_ws[i])
        b = dy.parameter(classifier_bs[i])
        all_probs.append(dy.softmax(w * sent_vector + b))
    return all_probs


In [6]:
def train(model, train_set, epochs = 20):
    def get_loss(probs, trues):
        return dy.esum([-dy.log(dy.pick(prob, true)) for prob, true in zip(probs, trues)])
    
    trainer = dy.AdamTrainer(model)
    for e in range(epochs):
        losses = []
        for i, training_example in enumerate(train_set):
            doc, labels = training_example

            loss = get_loss(get_probs(doc), labels)
            losses.append(loss.value())
            loss.backward()
            trainer.update()

            # Accumulate average losses over training to plot
            if i%(int(len(train_set)/100)) == 0:
                print('!',np.mean(losses), end='')
                losses = []
        print('epoch', e, 'done!')


In [None]:
# Train for 2 epochs.
# NOTE(review): this cell is unexecuted in the saved notebook (In [None]); the
# weights used later are loaded from disk with model.load instead.
train(model, train_set, 2)

In [7]:
def eval(model, test_set):
    all_preds = []
    for test_example in test_set:
        doc, labels = test_example
        labels = set()
        
        probs = get_probs(doc)
        for i, prob in enumerate(probs):
            prob = prob.value()
            if prob[1] > prob[0]:
                labels.add(i)
        all_preds.append(labels)
    
    gold_y = []
    for test_example in test_set:
        doc, g = test_example
        labels = set()
        for i, l in enumerate(g):
            if l > 0.5:
                labels.add(i)
        gold_y.append(labels)
    
    tp, fp, fn =0., 0., 0.      
    for pred, gold in zip(all_preds, gold_y):
        tp += len(pred.intersection(gold))
        fp += len(pred-gold)
        fn += len(gold-pred)
    prec = tp/(tp+fp)
    print(prec)
    recal = tp/(tp+fn)
    print(recal)
    f = 2*(prec*recal)/(prec+recal)
    print(f)
    
# Print micro-averaged precision, recall and F1 on the held-out documents.
eval(model, test_set)

KeyboardInterrupt: 

In [8]:
def get_important_words_and_sents(doc, i):
    """Return the word- and sentence-level attention of label ``i`` on ``doc``.

    Args:
        doc: sequence of sentences of word ids.
        i: label index into ``all_labels``.

    Returns:
        ``(sent_weights, words_attention_weights)``:

        - ``sent_weights`` is the softmax expression over sentences for label ``i``.
        - ``words_attention_weights[k]`` is a one-element list holding the
          softmax expression over the words of sentence ``k``. The word
          attention is shared by all labels.
    """
    dy.renew_cg()

    # Add the shared word-attention parameters to the graph once per call
    # instead of once per sentence.
    word_w = dy.parameter(word_attention_w1)
    word_v = dy.parameter(word_attention_v)

    encoded_sents = []
    words_attention_weights = []
    for sent in doc:
        embedded_sent = [embeddings[word] for word in sent]
        states = word_gru_builder.initial_state().add_inputs(embedded_sent)
        rnn_outputs = [s.output() for s in states]

        attention_weights = dy.softmax(dy.concatenate([word_v * dy.tanh(word_w * o) for o in rnn_outputs]))
        # Kept as a one-element list per sentence because print_words reads [k][0].
        words_attention_weights.append([attention_weights])
        encoded_sents.append(dy.esum(
            [vector * weight for vector, weight in zip(rnn_outputs, attention_weights)]))

    states = sent_gru_builder.initial_state().add_inputs(encoded_sents)
    rnn_outputs = [s.output() for s in states]

    w = dy.parameter(sent_attention_w1s[i])
    v = dy.parameter(sent_attention_vs[i])
    attention_weights = dy.softmax(dy.concatenate([v * dy.tanh(w * o) for o in rnn_outputs]))

    return attention_weights, words_attention_weights

def print_words(x, sent_weights, words_attention_weights):
    """Print the most-attended sentence of document ``x`` and its most-attended word.

    ``sent_weights`` is a sentence-level attention expression.
    ``words_attention_weights[k][0]`` is the word-level attention expression
    for sentence ``k``.
    """
    top_sent = np.argmax(sent_weights.npvalue())
    sentence = x[top_sent]
    top_word = np.argmax(words_attention_weights[top_sent][0].npvalue())
    print('best sent:', ' '.join(all_words[w] for w in sentence))
    print('best word:', all_words[sentence[top_word]])

def analyze_example(x, y):
    """Print a per-document error analysis.

    The report has two parts:
    - The TP / FP / FN label ids and their ICD9 descriptions.
    - For every label that is predicted or gold: whether it was found, whether
      it is gold, and the sentence/word the attention focuses on.
    """
    gold = {idx for idx, val in enumerate(y) if val > 0.5}

    predicted = set()
    for idx, expr in enumerate(get_probs(x)):
        scores = expr.value()
        if scores[1] > scores[0]:
            predicted.add(idx)

    sections = (
        ('######TP######', predicted & gold),
        ('######FP######', predicted - gold),
        ('######FN######', gold - predicted),
    )
    for title, label_ids in sections:
        print(title)
        print(label_ids)
        print([desc_dict[all_labels[label]] for label in label_ids])
    print('---------')

    for label in predicted | gold:
        print('label:', desc_dict[all_labels[label]])
        print('found:', label in predicted)
        print('gold:', label in gold)
        print_words(x, *get_important_words_and_sents(x, label))
        print('---------')


In [12]:
# Restore previously trained weights into the parameters created above.
# NOTE(review): this presumably requires the same parameter creation order and
# shapes as at save time. The load API also differs between DyNet versions — confirm.
model.load('hagru_MIMIC2')

In [15]:
# Word ids and label ids seen in training. The HTML visualisation underlines
# anything outside these sets.
# The loop variables are deliberately not named x/y: the original loop
# clobbered the global dataset arrays `x` and `y` with the last training example.
train_words = set()
train_labels = set()
for train_doc, train_y in train_set:
    for sent in train_doc:
        train_words.update(sent)
    # Use the same positivity threshold (> 0.5) as the gold labels elsewhere.
    train_labels.update(label for label, val in enumerate(train_y) if val > 0.5)

In [None]:
def get_sent_importance(doc):
    """Return the word-level attention weights for every sentence of ``doc``.

    Despite its name, this does not score sentences. Element ``k`` of the
    result is a one-element list holding the numpy array of word-attention
    weights for sentence ``k``; gen_sents_html reads it as ``scores[0]``.
    """
    dy.renew_cg()

    # Add the shared word-attention parameters to the graph once per call.
    word_w = dy.parameter(word_attention_w1)
    word_v = dy.parameter(word_attention_v)

    words_attention_weights = []
    for sent in doc:
        embedded_sent = [embeddings[word] for word in sent]
        states = word_gru_builder.initial_state().add_inputs(embedded_sent)
        rnn_outputs = [s.output() for s in states]

        attention_weights = dy.softmax(dy.concatenate([word_v * dy.tanh(word_w * o) for o in rnn_outputs]))
        words_attention_weights.append([attention_weights.npvalue()])
    return words_attention_weights

def get_hex_color(score, scale=1./4):
    """Map an attention weight to a grey HTML colour ``#RRGGBB``.

    Higher scores give darker text. With the default ``scale``, a score of 0
    maps to ``#BFBFBF`` (light grey) and a score of 1 maps to ``#000000`` (black).

    Args:
        score: attention weight, nominally in [0, 1]. It is clamped to that
            range; the original's ``[-2:]`` slice produced wrapped or garbage
            hex for out-of-range values.
        scale: fraction of full darkness applied even to a zero score
            (expected in [0, 1]).

    Returns:
        A string such as ``'#5F5F5F'``.
    """
    score = min(max(float(score), 0.0), 1.0)
    decimal = int(255 * (1 - ((1 - scale) * score + scale)))
    return '#' + ('%02X' % decimal) * 3

def pp_word(word):
    """Return the HTML text for word id ``word``.

    The word is HTML-escaped, since raw note text can contain ``<`` or ``&``
    that would break the page. Ids never seen in training are wrapped in
    ``<u>`` (underlined).
    """
    from html import escape  # local import: the notebook rebinds the global name `html`

    text = escape(all_words[word])
    if word in train_words:
        return text
    return '<u>' + text + '</u>'

def gen_sents_html(x):
    """Build the ``<tr>`` rows of the sentence table for document ``x``.

    Each word is coloured by its word-level attention weight. Sentences with
    10 words or fewer are emitted with ``display:none``.
    """
    rows = ["\n"]
    for idx, (weights, sent) in enumerate(zip(get_sent_importance(x), x)):
        opening = "<tr><td id=\"sent" if len(sent) > 10 else "<tr style=\"display:none;\"><td id=\"sent"
        rows.append(opening + str(idx) + "\"></td><td>")
        rows.extend(
            "<font color=\"" + get_hex_color(s) + "\">" + pp_word(w) + " </font>"
            for s, w in zip(weights[0], sent)
        )
        rows.append("</td></tr>\n")
    return "".join(rows)

def get_labels_types(x, y):
    """Split the label ids of one document into (true positives, false positives, false negatives).

    Gold labels are the indices with ``y[i] > 0.5``. Predicted labels are those
    whose classifier gives class 1 a higher probability than class 0.
    """
    gold = {idx for idx, val in enumerate(y) if val > 0.5}

    predicted = set()
    for idx, expr in enumerate(get_probs(x)):
        scores = expr.value()
        if scores[1] > scores[0]:
            predicted.add(idx)

    hits = list(predicted & gold)
    false_alarms = list(predicted - gold)
    misses = list(gold - predicted)
    return hits, false_alarms, misses

def get_sents_importance(doc, i):
    """Return label ``i``'s sentence-level attention weights over ``doc``.

    Args:
        doc: sequence of sentences of word ids.
        i: label index into ``all_labels``.

    Returns:
        ``.value()`` of the sentence softmax: a list of floats with one entry
        per sentence. It is presumably a bare float when ``doc`` has a single
        sentence — confirm.
    """
    dy.renew_cg()

    # Add the shared word-attention parameters to the graph once per call.
    word_w = dy.parameter(word_attention_w1)
    word_v = dy.parameter(word_attention_v)

    encoded_sents = []
    for sent in doc:
        embedded_sent = [embeddings[word] for word in sent]
        states = word_gru_builder.initial_state().add_inputs(embedded_sent)
        rnn_outputs = [s.output() for s in states]

        weights = dy.softmax(dy.concatenate([word_v * dy.tanh(word_w * o) for o in rnn_outputs]))
        encoded_sents.append(dy.esum(
            [vector * weight for vector, weight in zip(rnn_outputs, weights)]))

    states = sent_gru_builder.initial_state().add_inputs(encoded_sents)
    sent_outputs = [s.output() for s in states]

    w = dy.parameter(sent_attention_w1s[i])
    v = dy.parameter(sent_attention_vs[i])
    sent_weights = dy.softmax(dy.concatenate([v * dy.tanh(w * o) for o in sent_outputs]))
    return sent_weights.value()

def pp_label(label):
    """Return the HTML-escaped ICD9 description of label id ``label``.

    Labels never positive in training are wrapped in ``<u>`` (underlined).
    """
    from html import escape  # local import: the notebook rebinds the global name `html`

    text = escape(desc_dict[all_labels[label]])
    if label in train_labels:
        return text
    return '<u>' + text + '</u>'
    
def _label_buttons_html(labels, label_to_id):
    """Render one button per label; clicking it calls ``myFunction(<position>)``."""
    return "".join(
        "<button class=\"btn btn-primary\" type=\"button\" onclick=\"myFunction("
        + str(label_to_id[label]) + ")\">"
        + pp_label(label) + "</button>\n"
        for label in labels
    )


def gen_labels_html(x, y):
    """Build the label-related HTML fragments for document ``x`` with gold vector ``y``.

    Labels are ordered TP, then FP, then FN. A button's ``myFunction`` argument
    is the label's position in that order, which is also its index in the
    names and weights lists.

    Returns:
        ``(tps_html, fps_html, fns_html, label_names_html, weights_list_html)``.
        The last two are Python ``str()`` reprs of lists followed by a newline.
        NOTE(review): they are presumably consumed as JS literals by the template.
    """
    tp, fp, fn = get_labels_types(x, y)
    ordered = tp + fp + fn
    label_to_id = {label: i for i, label in enumerate(ordered)}

    # One helper replaces the three copy-pasted button loops.
    tps_html = _label_buttons_html(tp, label_to_id)
    fps_html = _label_buttons_html(fp, label_to_id)
    fns_html = _label_buttons_html(fn, label_to_id)

    label_names_html = str([desc_dict[all_labels[label]] for label in ordered]) + '\n'
    weights_list_html = str([get_sents_importance(x, label) for label in ordered]) + '\n'

    return tps_html, fps_html, fns_html, label_names_html, weights_list_html
    
def gen_visualize_html(x, y, i):
    """Fill ``visualizer_html_template`` for one document and return the page HTML.

    Args:
        x: the document (sentences of word ids).
        y: its gold label vector.
        i: one-based page number. The caller saves the page as ``{i-1}.html``,
            so the previous/next links point to ``{i-2}.html`` and ``{i}.html``.
            NOTE(review): the first page therefore links to ``-1.html``, and the
            last page links to a file that is never written.
    """
    # Read the template inside `with` so the file is closed; the original
    # leaked the handle on every page.
    with open('visualizer_html_template') as template_file:
        html = ' '.join(template_file.readlines())

    sents_html = gen_sents_html(x)
    tps_html, fps_html, fns_html, label_names_html, weights_list_html = gen_labels_html(x, y)

    html = html.replace('###WEIGHTS###', weights_list_html)
    html = html.replace('###LABEL_NAMES###', label_names_html)
    html = html.replace('###SENTS###', sents_html)
    html = html.replace('###TPS###', tps_html)
    html = html.replace('###FPS###', fps_html)
    html = html.replace('###FNS###', fns_html)
    html = html.replace('###PREV###', str(i-2)+'.html')
    html = html.replace('###NEXT###', str(i)+'.html')
    return html

# Write one visualisation page per test document to html/<index>.html.
# Create the output directory up front; the original crashed if it did not exist.
os.makedirs('html', exist_ok=True)
for i, item in enumerate(test_set):
    # `page`, not `html`, so the stdlib module name is not shadowed globally.
    page = gen_visualize_html(item[0], item[1], i+1)
    path = 'html/'+str(i)+'.html'
    print(path)
    with open(path, 'w',encoding='utf-8') as fp:
        fp.write(page)
        

html/0.html
html/1.html
html/2.html
html/3.html
html/4.html
html/5.html
html/6.html
html/7.html
html/8.html
html/9.html
html/10.html
html/11.html
html/12.html
