In [1]:
from data_util import load_embeddings, DMConfig, tfConfig
try:
    import _pickle as cPickle
except ImportError:
    import cPickle
import gzip
import os
import numpy as np
from sklearn import metrics
from defs import LBLS

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
config = DMConfig()

# Both vocabulary maps are required by later cells (tok2id for the embeddings,
# id2tok for printing examples). Open them unconditionally so that a missing
# file fails here, with the offending path in the error, instead of leaving
# the name undefined and surfacing later as an unrelated NameError.
# NOTE(review): unpickling can execute arbitrary code; only load vocabulary
# files that were produced by this project.
with gzip.open(os.path.abspath(config.tok2id_path)) as f:
    tok2id = cPickle.load(f)
with gzip.open(os.path.abspath(config.id2tok_path)) as f:
    id2tok = cPickle.load(f)

In [4]:
# Load the embeddings for the configured data set, keyed by the tok2id
# vocabulary; the call logs "Initialized embeddings." when it completes.
embeddings = load_embeddings(config, tok2id)

INFO:Initialized embeddings.


## Training

In [5]:
from tf_linear_classifier import TfLinearClassifier
from defs import PROJECT_DIR

In [6]:
# Checkpoint prefix and summaries directory, both rooted at PROJECT_DIR.
checkpoint_prefix = '{}/checkpoints/linear_model/linear-model'.format(PROJECT_DIR)
summary_root = '{}/summaries/'.format(PROJECT_DIR)

model = TfLinearClassifier(
    embedding=embeddings,
    ckpts_prefix=checkpoint_prefix,
    summaries_dir=summary_root,
)

# NOTE(review): restore_weights=True with max_iter=0 presumably restores the
# saved checkpoint without running further training steps -- confirm against
# TfLinearClassifier.fit. Kept as the last expression so the cell still
# displays the returned model.
model.fit(tfConfig, restore_weights=True, batches_to_eval=100, max_iter=0)

INFO:Initializing data manager...
INFO:took 0 s
INFO:Building model...


Instructions for updating:
Use the retry module or similar alternatives.


Instructions for updating:
Use the retry module or similar alternatives.
INFO:took 7 s


INFO:tensorflow:Restoring parameters from /home/tuckerleavitt/checkpoints/linear_model/linear-model-10000


INFO:Restoring parameters from /home/tuckerleavitt/checkpoints/linear_model/linear-model-10000
INFO:-- Restored model


<tf_linear_classifier.TfLinearClassifier at 0x7f3c22e83b70>

## Evaluation

In [7]:
def print_example(pred, label, feats, length, email_id):
    """Print one example: its id, predicted/true label names, and its tokens.

    Args:
        pred: index into LBLS for the predicted label.
        label: index into LBLS for the true label.
        feats: 2-D indexable feature array; column 0 of the first `length`
            rows is looked up in the notebook-level `id2tok` mapping.
        length: number of leading rows of `feats` to print.
        email_id: identifier shown in the header line.
    """
    trimmed = feats[:length]
    token_ids = trimmed[:, 0]

    header = "==== {} ====".format(email_id)
    print(header)

    label_line = "Predicted label: {}. True label: {}".format(LBLS[pred], LBLS[label])
    print(label_line)

    # Token lookup happens only after the header lines have been printed.
    text = " ".join(id2tok[tok_id] for tok_id in token_ids)
    print(text)
    print("------------------------------------")


def print_examples(preds, outputs, inputs, lens, email_ids):
    """Print every example by zipping the five parallel sequences together.

    Each zipped tuple (prediction, true label, features, length, email id) is
    forwarded positionally to `print_example`. Iteration stops at the
    shortest of the five sequences.
    """
    for example in zip(preds, outputs, inputs, lens, email_ids):
        pred, label, feats, length, email_id = example
        print_example(pred, label, feats, length, email_id)

In [9]:
# Evaluate on the 'test' data set (batch_lim=100) without a summary writer.
outputs, preds, word_ids, email_ids = model.evaluate_tfdata(
    model.sess,
    dataset_name='test',
    batch_lim=100,
    writer=None,
)

# Convert to arrays so predictions and labels can be compared element-wise.
preds = np.asarray(preds)
outputs = np.array(outputs)
correct = (preds == outputs)

scores = metrics.precision_recall_fscore_support(outputs, preds, average='binary')
prec, rec, f1, _ = scores

accuracy_line = " Predicted {} of {} examples correctly".format(np.sum(correct), correct.shape[0])
print(accuracy_line)
metrics_line = " Precision: {:1.3f}, Recall: {:1.3f}, F1: {:1.3f}".format(prec, rec, f1)
print(metrics_line)

INFO:Ran 100 batches


 Predicted 2186 of 3200 examples correctly
 Precision: 0.535, Recall: 0.428, F1: 0.476


In [12]:
# Run the model's predict() and print each returned example with its
# predicted and true label.
# The labels passed to print_examples must be the ones returned by this same
# predict() call. Passing the `preds`/`outputs` arrays left over from the
# test-set evaluation would zip labels from one run against inputs and email
# ids from another.
# NOTE(review): assumes batch_preds/batch_outputs are integer label indices
# usable with LBLS, like the evaluation results -- confirm against
# TfLinearClassifier.predict.
batch_preds, inputs, lens, batch_outputs, email_ids = model.predict()
print_examples(batch_preds, batch_outputs, inputs, lens, email_ids)

==== b'<4250772.1075857358369.JavaMail.evans@thyme>' ====
Predicted label: NORESPONSE. True label: NORESPONSE
<s> following discussions with you and doug , attached is a draft parking transaction agreement for your review and , if acceptable , for UUUNKKK to the counterparty . please call me with any questions . -- david </s>
------------------------------------
==== b'<25189453.1075851750174.JavaMail.evans@thyme>' ====
Predicted label: NORESPONSE. True label: NORESPONSE
<s> we have received an executed financial master agreement : type of contract : isda master agreement ( UUUNKKK border ) effective date : october NNNUMMM , NNNUMMM enron entity : enron canada corp. counterparty : UUUNKKK corporation limited , by and through its division , UUUNKKK paper company transactions covered : approved for all products with the exception of : foreign exchange confirming entity : enron canada corp. governing law : ontario copies will be distributed . stephanie UUUNKKK </s>
-----------------------