In [1]:
from __future__ import print_function
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import re
import collections
import random
from time import time

from gensim.models import Word2Vec
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA, FastICA

import data_handler as dh
import semeval_data_helper as sdh
# plot settings: render matplotlib figures inline in the notebook and enlarge the
# default figure size for every later plot
% matplotlib inline
# print(plt.rcParams.keys())
plt.rcParams['figure.figsize'] = (16,9)

In [None]:
# reload(sdh)

In [None]:
# print(DH.max_seq_len)
# paths, targets = DH.readable_data(show_dep=True)
# for p, t in zip(paths, targets) :
#     t = t.split(", ")
#     print("%s (%s) %s" % (t[0], p, t[1]))
# print('<X>' in DH.vocab)

In [2]:
from relembed import RelEmbed

In [3]:
# Build the data handler over the SemEval SDP training file; valid_percent=10
# holds out 10% of the examples as the validation split.
# reload(dh)
DH = dh.DataHandler(
    'data/semeval_train_sdp_8000',
    valid_percent=10,
)

Creating Data objects...
Done creating Data objects
7999 total examples :: 7199 training : 800 valid (90:10 split)
Vocab size: 22683 Dep size: 50


In [None]:
%time
# load the pretrained word embeddings
fname = 'data/GoogleNews-vectors-negative300.bin'
word2vec = Word2Vec.load_word2vec_format(fname, binary=True)

In [None]:
# reload(sdh)
# Load the SemEval train/valid/test splits together with the label <-> id maps;
# the number of classes is the number of entries in the id -> label map.
train, valid, test, label2int, int2label = sdh.load_semeval_data()
num_classes = len(int2label)

In [None]:
# %%bash
# git pull

In [None]:
# Convert the semeval data to indices under the wiki vocab: first the 'sdps' of
# train/valid/test, then their 'targets'.
# NOTE: this overwrites the fields in place, so re-running the cell would convert
# already-converted data.
for field in ('sdps', 'targets'):
    for split in (train, valid, test):
        split[field] = DH.sentences_to_sequences(split[field])

In [None]:
# The longest dependency path over all three splits sets the handler's padding length.
max_seq_len = max(len(path)
                  for split in (train, valid, test)
                  for path in split['sdps'])
print(max_seq_len, DH.max_seq_len)  # new value vs. the handler's previous value
DH.max_seq_len = max_seq_len

In [None]:
# Every vocab row starts as a random uniform [-1, 1) vector; rows whose token has a
# pretrained google vector are then overwritten, so only OOV tokens stay random.
word_embeddings = np.random.uniform(
    low=-1., high=1., size=[DH.vocab_size, 300]).astype(np.float32)
num_found = 0
for idx, tok in enumerate(DH.vocab):
    if tok not in word2vec:
        continue
    word_embeddings[idx] = word2vec[tok]
    num_found += 1
print("%i / %i pretrained" % (num_found, DH.vocab_size))

In [None]:
# Model configuration: sequence length and vocab sizes come from the data handler,
# the number of predicted classes from the SemEval label map.
config = {
    'max_num_steps':DH.max_seq_len,
    'word_embed_size':150,
    'dep_embed_size':25,
    'vocab_size':DH.vocab_size,
    'dep_vocab_size':DH.dep_size,
    'num_predict_classes':num_classes,
    'pretrained_word_embeddings':None, #word_embeddings,
    'max_grad_norm':3.,
    'model_name':'drnn_wiki_semeval_w2v',
    'checkpoint_prefix':'checkpoints/',
    'summary_prefix':'tensor_summaries/'
}
# Start from a clean slate so re-running this cell does not build a second model on
# top of the previous one. Close the default session (if there is one) *before*
# resetting the graph, since TF1 does not support resetting while a session is active.
# NOTE(review): if RelEmbed owns a session that is not the default one, it has to be
# closed explicitly (e.g. via the previous `drnn`) — cannot confirm from here.
_default_sess = tf.get_default_session()
if _default_sess is not None:
    _default_sess.close()
try:
    tf.reset_default_graph()
except AssertionError as err:
    # TF1 raises AssertionError when called inside a nested graph context; report it
    # instead of silently swallowing every possible error.
    print("Could not reset the default graph: %s" % err)
drnn = RelEmbed(config)
print(drnn)

In [None]:
def run_validation_test(num_nearby=20):
    """Print the validation phrases closest to one randomly chosen validation phrase.

    Draws the validation batch from the global ``DH`` data handler, picks a random
    query phrase, zero-pads it to ``DH.max_seq_len`` rows of two ids, and asks the
    global ``drnn`` model (``validation_phrase_nearby``) for the distances and indices
    of nearby validation phrases. The query and each neighbour are printed as
    ``<target0> 'phrase' <target1>``.

    Args:
        num_nearby: maximum number of neighbours to print; clipped to the number of
            neighbours the model actually returns.
    """
    valid_phrases, valid_targets , _, valid_lens = DH.validation_batch()
    # randrange(n) is always < n; int(random.uniform(0, n)) can equal n because
    # uniform() may return its upper endpoint after float rounding.
    random_index = random.randrange(len(valid_lens))
    query_phrase = valid_phrases[random_index]
    query_len = valid_lens[random_index]
    query_target = valid_targets[random_index]
    # Each phrase step is a pair (x[0], x[1]); copy both into a zero-padded id matrix.
    padded_qp = np.zeros([DH.max_seq_len, 2], dtype=np.int32)
    padded_qp[:len(query_phrase), 0] = [x[0] for x in query_phrase]
    padded_qp[:len(query_phrase), 1] = [x[1] for x in query_phrase]
    dists, phrase_idx = drnn.validation_phrase_nearby(padded_qp, query_len, valid_phrases, valid_lens)
    print("="*80)
    print("Top %i closest phrases to <%s> '%s' <%s>" 
          % (num_nearby, DH.vocab_at(query_target[0]), 
             DH.sequence_to_sentence(query_phrase, query_len), 
             DH.vocab_at(query_target[1])))
    # Never index past the neighbours that were returned.
    for i in range(min(num_nearby, len(phrase_idx))):
        dist = dists[i]
        phrase = valid_phrases[phrase_idx[i]]
        len_ = valid_lens[phrase_idx[i]]
        target = valid_targets[phrase_idx[i]]
        print("%i: %0.3f : <%s> '%s' <%s>" 
              % (i, dist, DH.vocab_at(target[0]),
                 DH.sequence_to_sentence(phrase, len_),
                 DH.vocab_at(target[1])))
    print("="*80)

In [None]:
def time_left(num_epochs, num_steps, fit_time, nearby_time, start_time, nearby_mod):
    """Estimate the number of seconds remaining in a training run.

    The projected total is the per-step fit time over every step of every epoch, plus
    the nearby-check time for each of the ``total_steps / nearby_mod`` checks. The time
    already elapsed since ``start_time`` is subtracted from that projection.
    """
    total_steps = num_epochs * num_steps
    projected = fit_time * total_steps + nearby_time * (total_steps / float(nearby_mod))
    elapsed = time() - start_time
    return projected - elapsed

# Unsupervised Training

In [None]:
# Unsupervised training: fit the relation embedder on batches carrying `neg_per`
# negatives per example, periodically printing nearby validation phrases and the
# validation loss, and checkpointing every `save_interval` seconds.

# hyperparameters
num_epochs = 1
batch_size = 50
neg_per = 25        # negatives per example passed to DH.batches
num_nearby = 50     # neighbours printed by run_validation_test
nearby_mod = 50     # run the nearby/validation check every `nearby_mod` steps
sample_power = .75  # passed to DH.scale_vocab_dist (presumably the sampling exponent)
DH.scale_vocab_dist(sample_power)

# bookkeeping
num_steps = DH.num_steps(batch_size)
total_step = 1
save_interval = 30 * 60  # half hour in seconds
save_time = time()

# timing stuff
start = time()
fit_time = 0           # mean seconds per partial_unsup_fit call
nearby_time = 0        # mean seconds per nearby/validation check
num_nearby_checks = 0  # number of checks folded into nearby_time

for epoch in range(num_epochs):
    offset = 0  # if epoch else 400
    DH.shuffle_data()
    for step, batch in enumerate(DH.batches(batch_size, offset=offset, neg_per=neg_per)):
        if not step:
            step = offset
        t0 = time()
        loss = drnn.partial_unsup_fit(*batch)
        # incremental mean over the `total_step` fits done so far (exact from step 1)
        fit_time += (time() - t0 - fit_time) / total_step
        if step % 10 == 0:
            m, s = divmod(time() - start, 60)
            h, m = divmod(m, 60)
            left = time_left(num_epochs, num_steps, fit_time, nearby_time, start, nearby_mod)
            ml, sl = divmod(left, 60)
            hl, ml = divmod(ml, 60)
            # examples plus their negatives processed per second
            pps = batch_size * (neg_per + 1) / fit_time if fit_time > 0 else float('nan')
            print("(%i:%i:%i) step %i/%i, epoch %i Training Loss = %1.5f :: %0.3f phrases/sec :: (%i:%i:%i) hours left" 
                  % (h, m, s, step, num_steps, epoch, loss, pps, hl, ml, sl))
        if (total_step - 1) % nearby_mod == 0:  # do one right away so we get a good timing estimate
            t0 = time()
            run_validation_test(num_nearby)  # check out the nearby phrases in the validation set
            valid_loss = drnn.validation_loss(*DH.validation_batch())
            print("Validation loss: %0.4f" % valid_loss)
            # average over the checks actually run, not over every training step
            num_nearby_checks += 1
            nearby_time += (time() - t0 - nearby_time) / num_nearby_checks

        if (time() - save_time) > save_interval:
            print("Saving model...")
            drnn.checkpoint()
            save_time = time()
        total_step += 1
drnn.checkpoint()

In [None]:
# save a checkpoint of the current model on demand (same call the training loops use
# when they print "Saving model...")
drnn.checkpoint()

In [None]:
# # test the embeddings

# ### VALID ###
# # valid_phrases, valid_targets, _, valid_lens = DH.validation_batch()
# # phrase_embeds, target_embeds = drnn.embed_phrases_and_targets(valid_phrases, valid_targets, valid_lens)
# # phrase_labels, target_labels = DH.readable_data(valid=True)

# ### TRAIN ###
# train_phrases, train_targets, _, train_lens = DH.batches(500, neg_per=0, offset=0).next()
# phrase_embeds, target_embeds = drnn.embed_phrases_and_targets(train_phrases, train_targets, train_lens)
# phrase_labels, target_labels = DH.readable_data(show_dep=False, valid=False)
        
# phrase_embeds /= np.sqrt(np.sum(phrase_embeds**2, 1, keepdims=True))
# target_embeds /= np.sqrt(np.sum(target_embeds**2, 1, keepdims=True))

In [None]:
# ### JOINT ###
# start = 0
# stride = 40
# end = start + stride

# lowd = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
# # lowd = PCA(n_components=2)

# joint_embeds = np.vstack([phrase_embeds[start:end], target_embeds[start:end]])
# joint_2d = lowd.fit_transform(joint_embeds)
# phrase_2d, target_2d = joint_2d[:stride], joint_2d[stride:]

# fig, ax = plt.subplots(figsize=(20,16))
# for i, label in enumerate(phrase_labels[start:end]):
#     label = "%i: %s" % (i, label)
#     x, y = phrase_2d[i,:]
#     ax.scatter(x, y, color='b')
#     ax.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points',
#                    ha='right', va='bottom')
# for i, label in enumerate(target_labels[start:end]):
#     label = "%i: %s" % (i, label)
#     x, y = target_2d[i,:]
#     ax.scatter(x, y, color='r')
#     ax.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points',
#                    ha='right', va='bottom')

In [None]:
# ### PHRASE ONLY ###
# start = 0
# stride = 50
# end = start + stride

# lowd = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
# # lowd = PCA(n_components=2)

# phrase_2d = lowd.fit_transform(phrase_embeds[start:end])

# fig, ax = plt.subplots(figsize=(20,16))
# for i, label in enumerate(phrase_labels[start:end]):
#     label = "%i: %s" % (i, label)
#     x, y = phrase_2d[i,:]
#     ax.scatter(x, y, color='b')
#     ax.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points',
#                    ha='right', va='bottom')

In [None]:
# ### TARGET ONLY ###
# start = 0
# stride = 35
# end = start + stride

# lowd = TSNE(perplexity=20, n_components=2, init='pca', n_iter=5000)
# # lowd = PCA(n_components=2)

# target_2d = lowd.fit_transform(target_embeds[start:end])

# fig, ax = plt.subplots(figsize=(20,16))
# for i, label in enumerate(target_labels[start:end]):
#     label = "%i: %s" % (i, label)
#     x, y = target_2d[i,:]
#     ax.scatter(x, y, color='r')
#     ax.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points',
#                    ha='right', va='bottom')

In [None]:
### TW2V demo ###
# Project rows [start, end) of the word-embedding matrix to 2-D with t-SNE, mark
# each point with its vocab token, and write the figure to word2vec_demo.png.
start = 200
stride = 100
end = start + stride

lowd = TSNE(perplexity=20, n_components=2, init='pca', n_iter=5000)
# lowd = PCA(n_components=2)
target_2d = lowd.fit_transform(word_embeddings[start:end])

fig, ax = plt.subplots(figsize=(28,16))
shown_tokens = DH.vocab[start:end]
for row, token in enumerate(shown_tokens):
    point = target_2d[row, :]
    x, y = point
    label = "%s" % token
    ax.scatter(x, y, color='b')
    ax.annotate(label,
                xy=(x, y),
                xytext=(5, 2),
                textcoords='offset points',
                ha='right',
                va='bottom')

plt.savefig('word2vec_demo.png', dpi=200)

# Test out semeval data

In [None]:
# Bundle each split's parallel fields into per-example tuples. list() makes them
# reusable: later cells iterate over zip_train and shuffle it in place, which a lazy
# Python 3 zip object cannot do (on Python 2 zip already returns a list).
zip_train = list(zip(train['raws'], train['sents'], train['sdps'], train['targets'], train['labels']))
zip_valid = list(zip(valid['raws'], valid['sents'], valid['sdps'], valid['targets'], valid['labels']))
zip_test = list(zip(test['raws'], test['sents'], test['sdps'], test['targets']))

In [None]:
# Show the first six training examples: the raw sentence, the dependency path with
# dependency labels :: the target pair, and the relation label.
num_shown = 6
for i, example in enumerate(zip_train):
    if i >= num_shown:
        break
    raw, _, sdp, target, label = example
    print(raw)
    path_str = DH.sequence_to_sentence(sdp, show_dep=True)
    target_str = DH.sequence_to_sentence(target)
    print("%s :: %s" % (path_str, target_str))
    print(int2label[label])
    print("=" * 80)


# Supervised Training

In [None]:
# Supervised training of the relation classifier on the SemEval training split,
# reporting validation cross-entropy and macro P/R/F1 after every epoch.
# NOTE(review): calls confusion_matrix, which is defined in a later cell — run that
# cell first on a fresh kernel.
batch_size = 50
num_steps = len(train['labels']) // batch_size
num_epochs = 25
display_mod = 10
valid_mod = 50
print("Num steps %i" %num_steps)

# The validation data never changes, so build its batch once. shuffle=False keeps its
# rows aligned with valid['labels'], which confusion_matrix compares predictions to
# (NOTE(review): assumes classification_batch may shuffle by default, as the explicit
# shuffle=False on the training batches suggests — confirm).
valid_batch = DH.classification_batch(len(valid['labels']), valid['sdps'], valid['targets'],
                                      valid['labels'], shuffle=False)
label_set = set(train['labels'])

# materialize so it can be shuffled in place (zip is lazy on Python 3)
zip_train = list(zip_train)

start = time()
for epoch in range(num_epochs):
    random.shuffle(zip_train)  # shuffling should only happen once per epoch
    _, _, sdps, targets, labels = zip(*zip_train)
    for step in range(num_steps):
        class_batch = DH.classification_batch(batch_size, sdps, targets, labels, 
                                              offset=step, shuffle=False)
        xent = drnn.partial_class_fit(*class_batch)
        if step % display_mod == 0:
            m, s = divmod(time() - start, 60)
            h, m = divmod(m, 60)
            print("(%i:%i:%i) s %i/%i, e %i avg class xent loss = %0.4f" % (h,m,s, step, num_steps, epoch, xent))
        if step % valid_mod == 0:
            valid_xent = drnn.validation_class_loss(*valid_batch)
            m, s = divmod(time() - start, 60)
            h, m = divmod(m, 60)
            print("="*80)
            print("(%i:%i:%i) s %i/%i, e %i validation avg class xent loss = %0.4f" % (h,m,s, step, num_steps, epoch, valid_xent))
            print("="*80)
    # Same predict call as the other training cells: with return_probs=True it yields
    # (predictions, probabilities); only the predictions are scored.
    preds, dists = drnn.predict(valid_batch[0], valid_batch[1], valid_batch[3], return_probs=True)
    cm, stats = confusion_matrix(preds, valid['labels'], label_set)
    print("Macro P: %2.4f, R: %3.4f, F1: %0.4f" % (stats['macro_precision'], stats['macro_recall'], stats['macro_f1']))
drnn.checkpoint()
print("Done")

# Unsupervised 10 then Supervised 25

In [None]:
# Unsupervised training for `num_epochs` epochs (validation macro P/R/F1 after each),
# followed by supervised training of the classifier over the full training split.
# NOTE(review): calls confusion_matrix, which is defined in a later cell — run that
# cell first on a fresh kernel.

# hyperparameters
num_epochs = 5
batch_size = 50
neg_per = 25
num_nearby = 20
nearby_mod = 50
sample_power = .75
DH.scale_vocab_dist(sample_power)

# bookkeeping
num_steps = DH.num_steps(batch_size)
total_step = 1
save_interval = 30 * 60  # half hour in seconds
save_time = time()

# timing stuff
start = time()
fit_time = 0           # mean seconds per partial_unsup_fit call
nearby_time = 0        # mean seconds per nearby/validation check
num_nearby_checks = 0  # number of checks folded into nearby_time

# The validation data never changes, so build its batch once. shuffle=False keeps its
# rows aligned with valid['labels'], which confusion_matrix compares predictions to
# (NOTE(review): assumes classification_batch may shuffle by default — confirm).
valid_batch = DH.classification_batch(len(valid['labels']), valid['sdps'], valid['targets'],
                                      valid['labels'], shuffle=False)
label_set = set(train['labels'])

for epoch in range(num_epochs):
    offset = 0  # if epoch else 400
    DH.shuffle_data()
    for step, batch in enumerate(DH.batches(batch_size, offset=offset, neg_per=neg_per)):
        if not step:
            step = offset
        t0 = time()
        loss = drnn.partial_unsup_fit(*batch)
        # incremental mean over the `total_step` fits done so far (exact from step 1)
        fit_time += (time() - t0 - fit_time) / total_step
        if step % 10 == 0:
            m, s = divmod(time() - start, 60)
            h, m = divmod(m, 60)
            left = time_left(num_epochs, num_steps, fit_time, nearby_time, start, nearby_mod)
            ml, sl = divmod(left, 60)
            hl, ml = divmod(ml, 60)
            pps = batch_size * (neg_per + 1) / fit_time if fit_time > 0 else float('nan')
            print("(%i:%i:%i) step %i/%i, epoch %i Training Loss = %1.5f :: %0.3f phrases/sec :: (%i:%i:%i) hours left" 
                  % (h, m, s, step, num_steps, epoch, loss, pps, hl, ml, sl))
        if (total_step - 1) % nearby_mod == 0:  # do one right away so we get a good timing estimate
            t0 = time()
            run_validation_test(num_nearby)  # check out the nearby phrases in the validation set
            valid_loss = drnn.validation_loss(*DH.validation_batch())
            print("Validation loss: %0.4f" % valid_loss)
            # average over the checks actually run, not over every training step
            num_nearby_checks += 1
            nearby_time += (time() - t0 - nearby_time) / num_nearby_checks

        if (time() - save_time) > save_interval:
            print("Saving model...")
            drnn.checkpoint()
            save_time = time()
        total_step += 1
    preds, dists = drnn.predict(valid_batch[0], valid_batch[1], valid_batch[3], return_probs=True)
    cm, stats = confusion_matrix(preds, valid['labels'], label_set)
    print("Macro P: %2.4f, R: %3.4f, F1: %0.4f" % (stats['macro_precision'], stats['macro_recall'], stats['macro_f1']))
drnn.checkpoint()

# supervised phase
batch_size = 50
num_steps = len(train['labels']) // batch_size
num_epochs = 25
display_mod = 10
valid_mod = 50
print(num_steps)

start = time()
for epoch in range(num_epochs):
    # run every training batch (this loop was left at a debugging range(10))
    for step in range(num_steps):
        inputs, targets, labels, lens = DH.classification_batch(
            batch_size, train['sdps'], train['targets'], train['labels'], offset=step)
        # shuffle the rows within the batch; list() because zip is lazy on Python 3
        class_rows = list(zip(inputs, targets, labels, lens))
        random.shuffle(class_rows)
        xent = drnn.partial_class_fit(*zip(*class_rows))
        if step % display_mod == 0:
            m, s = divmod(time() - start, 60)
            h, m = divmod(m, 60)
            print("(%i:%i:%i) s %i/%i, e %i avg class xent loss = %0.4f" % (h,m,s, step, num_steps, epoch, xent))
        if step % valid_mod == 0:
            valid_xent = drnn.validation_class_loss(*valid_batch)
            m, s = divmod(time() - start, 60)
            h, m = divmod(m, 60)
            print("="*80)
            print("(%i:%i:%i) s %i/%i, e %i validation avg class xent loss = %0.4f" % (h,m,s, step, num_steps, epoch, valid_xent))
            print("="*80)
    preds, dists = drnn.predict(valid_batch[0], valid_batch[1], valid_batch[3], return_probs=True)
    cm, stats = confusion_matrix(preds, valid['labels'], label_set)
    print("Macro P: %2.4f, R: %3.4f, F1: %0.4f" % (stats['macro_precision'], stats['macro_recall'], stats['macro_f1']))
drnn.checkpoint()
print("Done")

# Alternating

In [None]:


# Alternating training: each of the 20 cycles runs one unsupervised epoch (followed by
# a validation macro P/R/F1 report) and then a few supervised epochs over the training
# split, checkpointing after each phase.
# NOTE(review): calls confusion_matrix, which is defined in a later cell — run that
# cell first on a fresh kernel.

# The validation data never changes, so build its batch once. shuffle=False keeps its
# rows aligned with valid['labels'], which confusion_matrix compares predictions to
# (NOTE(review): assumes classification_batch may shuffle by default — confirm).
valid_batch = DH.classification_batch(len(valid['labels']), valid['sdps'], valid['targets'],
                                      valid['labels'], shuffle=False)
label_set = set(train['labels'])

for cycle in range(20):
    # ---- unsupervised phase ----
    # hyperparameters
    num_epochs = 1
    batch_size = 50
    neg_per = 25
    num_nearby = 20
    nearby_mod = 50
    sample_power = .75
    DH.scale_vocab_dist(sample_power)

    # bookkeeping
    num_steps = DH.num_steps(batch_size)
    total_step = 1
    save_interval = 30 * 60  # half hour in seconds
    save_time = time()

    # timing stuff
    start = time()
    fit_time = 0           # mean seconds per partial_unsup_fit call
    nearby_time = 0        # mean seconds per nearby/validation check
    num_nearby_checks = 0  # number of checks folded into nearby_time
    for epoch in range(num_epochs):
        offset = 0  # if epoch else 400
        DH.shuffle_data()
        for step, batch in enumerate(DH.batches(batch_size, offset=offset, neg_per=neg_per)):
            if not step:
                step = offset
            t0 = time()
            loss = drnn.partial_unsup_fit(*batch)
            # incremental mean over the `total_step` fits of this cycle
            fit_time += (time() - t0 - fit_time) / total_step
            if step % 10 == 0:
                m, s = divmod(time() - start, 60)
                h, m = divmod(m, 60)
                left = time_left(num_epochs, num_steps, fit_time, nearby_time, start, nearby_mod)
                ml, sl = divmod(left, 60)
                hl, ml = divmod(ml, 60)
                pps = batch_size * (neg_per + 1) / fit_time if fit_time > 0 else float('nan')
                print("(%i:%i:%i) step %i/%i, epoch %i Training Loss = %1.5f :: %0.3f phrases/sec :: (%i:%i:%i) hours left" 
                      % (h, m, s, step, num_steps, epoch, loss, pps, hl, ml, sl))
            if (total_step - 1) % nearby_mod == 0:  # do one right away so we get a good timing estimate
                t0 = time()
                run_validation_test(num_nearby)  # check out the nearby phrases in the validation set
                valid_loss = drnn.validation_loss(*DH.validation_batch())
                print("Validation loss: %0.4f" % valid_loss)
                # average over the checks actually run, not over every training step
                num_nearby_checks += 1
                nearby_time += (time() - t0 - nearby_time) / num_nearby_checks

            if (time() - save_time) > save_interval:
                print("Saving model...")
                drnn.checkpoint()
                save_time = time()
            total_step += 1
        preds, dists = drnn.predict(valid_batch[0], valid_batch[1], valid_batch[3], return_probs=True)
        cm, stats = confusion_matrix(preds, valid['labels'], label_set)
        print("Macro P: %2.4f, R: %3.4f, F1: %0.4f" % (stats['macro_precision'], stats['macro_recall'], stats['macro_f1']))
    drnn.checkpoint()

    # ---- supervised phase ----
    batch_size = 50
    num_steps = len(train['labels']) // batch_size
    num_epochs = 3  # supervised epochs per cycle (the loop always ran 3 passes)
    display_mod = 10
    valid_mod = 50
    print(num_steps)

    start = time()
    for class_epoch in range(num_epochs):
        for class_step in range(num_steps):
            inputs, targets, labels, lens = DH.classification_batch(
                batch_size, train['sdps'], train['targets'], train['labels'], offset=class_step)
            # shuffle the rows within the batch; list() because zip is lazy on Python 3
            class_rows = list(zip(inputs, targets, labels, lens))
            random.shuffle(class_rows)
            xent = drnn.partial_class_fit(*zip(*class_rows))
            # class_step (not the unsupervised loop's `step`) drives the display and
            # validation cadence of this loop
            if class_step % display_mod == 0:
                m, s = divmod(time() - start, 60)
                h, m = divmod(m, 60)
                print("(%i:%i:%i) s %i/%i, e %i avg class xent loss = %0.4f" % (h,m,s, class_step, num_steps, class_epoch, xent))
            if class_step % valid_mod == 0:
                valid_xent = drnn.validation_class_loss(*valid_batch)
                m, s = divmod(time() - start, 60)
                h, m = divmod(m, 60)
                print("="*80)
                print("(%i:%i:%i) s %i/%i, e %i validation avg class xent loss = %0.4f" % (h,m,s, class_step, num_steps, class_epoch, valid_xent))
                print("="*80)
        preds, dists = drnn.predict(valid_batch[0], valid_batch[1], valid_batch[3], return_probs=True)
        cm, stats = confusion_matrix(preds, valid['labels'], label_set)
        print("Macro P: %2.4f, R: %3.4f, F1: %0.4f" % (stats['macro_precision'], stats['macro_recall'], stats['macro_f1']))
    drnn.checkpoint()
    print("Done")

In [None]:
# check out predictions on the validation split
# valid_batch = DH.classification_batch(len(train['labels']), train['sdps'], train['targets'], train['labels'])

# shuffle=False keeps the batch rows aligned with valid['labels'], which the scoring
# cell compares these predictions against (NOTE(review): assumes classification_batch
# may shuffle by default, as the explicit shuffle=False on training batches suggests).
valid_batch = DH.classification_batch(len(valid['labels']), valid['sdps'], valid['targets'],
                                      valid['labels'], shuffle=False)
preds, dists = drnn.predict(valid_batch[0], valid_batch[1], valid_batch[3], return_probs=True)

In [None]:
# for i, p in enumerate(preds):
#     print("%i, pred: %s, true: %s" %(i, int2label[p], int2label[valid['labels'][i]]))
#     target = DH.sequence_to_sentence(valid['targets'][i]).split(' ')
#     sdp = DH.sequence_to_sentence(valid['sdps'][i], show_dep=True)
#     print('<%s> "%s" <%s>' % (target[0], sdp, target[1]))
#     print(valid['raws'][i])
#     print(valid['comments'][i])
#     print("="*80)

In [None]:
def confusion_matrix(preds, labels, label_set):
    """Build a confusion matrix plus per-class and macro-averaged P, R and F1.

    Args:
        preds: iterable of predicted class ids.
        labels: iterable of true class ids, aligned element-wise with ``preds``.
        label_set: collection of classes; only its size is used, so class ids must be
            the integers ``0 .. len(label_set) - 1``.

    Returns:
        ``(matrix, stats)``. ``matrix[p, l]`` counts examples predicted as ``p`` whose
        true class is ``l`` (rows are predictions, columns are truths). ``stats`` has:

        - ``'micro_precision'``, ``'micro_recall'``, ``'micro_f1'``: per-class arrays.
          Despite the prefix these are per-class scores, not micro-averages; the key
          names are kept because plot_confusion_matrix reads them.
        - ``'macro_precision'``, ``'macro_recall'``: means of the per-class arrays.
        - ``'macro_f1'``: harmonic mean of macro precision and recall, or 0 when both
          are 0 (previously this divided 0/0 and produced NaN).
    """
    size = len(label_set)
    matrix = np.zeros([size, size])
    for p, l in zip(preds, labels):
        matrix[p, l] += 1

    tp = np.diag(matrix)
    predicted = matrix.sum(axis=1)  # tp + fp per class (row totals)
    actual = matrix.sum(axis=0)     # tp + fn per class (column totals)
    # Classes that were never predicted / never present score 0 instead of NaN.
    with np.errstate(divide='ignore', invalid='ignore'):
        class_precision = np.where(predicted > 0, tp / predicted, 0.)
        class_recall = np.where(actual > 0, tp / actual, 0.)
        pr_sum = class_precision + class_recall
        class_f1 = np.where(pr_sum > 0, 2 * class_precision * class_recall / pr_sum, 0.)

    avg_precision = np.mean(class_precision)
    avg_recall = np.mean(class_recall)
    macro_denom = avg_precision + avg_recall
    macro_f1 = (2 * avg_precision * avg_recall) / macro_denom if macro_denom else 0.
    stats = {'micro_precision': class_precision,
             'micro_recall': class_recall,
             'micro_f1': class_f1,
             'macro_precision': avg_precision,
             'macro_recall': avg_recall,
             'macro_f1': macro_f1}
    return matrix, stats
# Score the validation predictions against the gold validation labels.
label_set = set(train['labels'])
gold_labels = valid['labels']
cm, stats = confusion_matrix(preds, gold_labels, label_set)
macro_f1 = stats['macro_f1']
print("Macro F1: %0.4f" % macro_f1)

In [None]:
def plot_confusion_matrix(cm, label_names, save_name=None, title='Normed Confusion matrix', cmap=plt.cm.Blues, stats=None):
    """Draw a row-normalized confusion matrix annotated with the raw counts.

    Args:
        cm: count matrix as returned by confusion_matrix (rows are predicted labels,
            columns are true labels).
        label_names: class names in class-id order; used for both axes.
        save_name: if truthy, the figure is saved to ``save_name + '.pdf'``.
        title: figure title.
        cmap: matplotlib colormap for the normalized cells.
        stats: optional stats dict from confusion_matrix. When given, each y tick also
            shows that class's P/R/F1 and the x-axis label shows macro P/R/F1.
    """
    label_names = list(label_names)  # accept dict views and other iterables
    fig, ax = plt.subplots(figsize=(20,20))

    # Normalize each row (predicted class) by its total; empty rows become 0, not NaN.
    with np.errstate(divide='ignore', invalid='ignore'):
        cm_normalized = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_normalized[np.isnan(cm_normalized)] = 0.0

    # Write the raw count into every cell whose normalized value is non-zero.
    counts = cm.astype('int')
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            norm = cm_normalized[row, col]
            if norm > 0.0:
                color = 'white' if norm > .5 else 'black'
                ax.text(col, row, "%i" % counts[row, col], va='center', ha='center', color=color)

    # actual plot
    im = ax.imshow(cm_normalized, interpolation='nearest', origin='upper', cmap=cmap)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Major ticks label the cells; minor ticks offset by half a cell draw the grid.
    tick_marks = np.arange(len(label_names))
    tick_marks_offset = tick_marks - .5
    ax.set_xticks(tick_marks, minor=False)
    ax.set_yticks(tick_marks, minor=False)
    ax.set_xticks(tick_marks_offset, minor=True)
    ax.set_yticks(tick_marks_offset, minor=True)
    ax.grid(which='minor')
    if stats:
        # include per-class precision, recall, and f1 in the y tick labels
        y_labels = ["%s\nP:%0.2f, R:%0.2f, F1:%0.2f"
                    % (name,
                       stats['micro_precision'][i],
                       stats['micro_recall'][i],
                       stats['micro_f1'][i])
                    for i, name in enumerate(label_names)]
    else:
        # plain class names (this branch used to leave the y labels undefined)
        y_labels = label_names
    ax.set_xticklabels(label_names, rotation=75, horizontalalignment='left', x=1)
    ax.xaxis.tick_top()
    ax.set_yticklabels(y_labels)

    # other stuff
    fig.tight_layout()
    ax.set_ylabel('Predicted Labels', fontsize=16)
    if stats:
        # include macro scores in the x-axis label
        x_label = ("True Labels\n Macro P:%0.2f, R:%0.2f, F1:%0.2f" 
                   % (stats['macro_precision'], stats['macro_recall'], stats['macro_f1']))
    else:
        x_label = "True Label"
    ax.set_xlabel(x_label, fontsize=16)
    ax.set_title(title, y=1.12, fontsize=20)
    if save_name:
        fig.savefig(save_name+'.pdf')
        
# Ask for the output file name (an empty answer skips saving, since
# plot_confusion_matrix only saves when save_name is truthy). `raw_input` exists only
# on Python 2, so fall back to `input` on Python 3.
try:
    read_line = raw_input
except NameError:
    read_line = input
save_name = read_line("Enter save name: ")
# Tick labels must follow the class-id order used for the matrix rows/columns;
# dict.values() order is arbitrary on Python 2, so order the names by id explicitly.
label_names = [int2label[k] for k in sorted(int2label)]
plot_confusion_matrix(cm, label_names, save_name=save_name, stats=stats)