In [3]:
from random import shuffle
import csv
import matplotlib.pyplot as plt
import os
import sklearn.metrics
import tensorflow as tf
from src.utils import *

In [24]:
def _split_na_pairs(pair_keys):
    """Partition pair keys into ``(na_pairs, non_na_pairs)``.

    Each key is split on ``"#"`` and its third field is read as the relation
    id; keys whose third field is the string ``"0"`` go to the NA list, all
    others to the non-NA list. Input order is preserved within each list.

    Raises:
        IndexError: if a key has fewer than three ``"#"``-separated fields.
    """
    na_pairs = []
    non_na_pairs = []
    for key in pair_keys:
        if key.split("#")[2] != "0":
            non_na_pairs.append(key)
        else:
            na_pairs.append(key)
    return na_pairs, non_na_pairs


def _load_entity_pair_sentences(train_json_path):
    """Group the sentences of a training JSON file by ``"head#tail"`` key.

    The file must hold a JSON list of records with ``record['head']['word']``,
    ``record['tail']['word']`` and ``record['sentence']``.

    Returns:
        dict mapping ``head_word + "#" + tail_word`` to the list of sentences
        (in file order) that mention that pair.
    """
    import json

    sentences_by_pair = {}
    with open(train_json_path, "r") as f:
        for record in json.load(f):
            pair_key = record['head']['word'] + "#" + record['tail']['word']
            sentences_by_pair.setdefault(pair_key, []).append(record['sentence'])
    return sentences_by_pair


def better_data(
    chkpt_path,
    dataset_name = "nyt2",
    encoder = "pcnn",
    selector = "att",
    l2_lambd = 0.0,
    batch_size = 64):
    """Score the NA training pairs with a saved model and dump confident hits to CSV.

    Restores the checkpoint ``saved/models/<chkpt_path>``, runs the model over
    the training entity pairs labelled NA (relation id 0), keeps the pairs
    whose NA score is above 0.5, sorts them by descending score and writes
    them in chunks of 1000 rows to
    ``saved/aucs/<chkpt_path>_NA_<chunk>__preds.csv`` (50 files, possibly empty).

    Args:
        chkpt_path: checkpoint name below ``saved/models``; also used as the
            prefix of the output CSV file names.
        dataset_name: sub-directory of ``data`` holding the dataset.
        encoder: encoder name forwarded to ``Model``.
        selector: selector name forwarded to ``Model``; attention weights are
            exported only when this equals ``"att"``.
        l2_lambd: currently unused; kept so existing callers keep working.
        batch_size: number of entity pairs per ``model.test_batch`` call.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    import json

    data_dir = os.path.join("data", dataset_name)
    preprocessed_data_dir = get_preprocessed_dir(data_dir)
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(preprocessed_data_dir, exist_ok=True)

    #preprocessor_batch(data_dir)

    from src.model import Model
    tf.reset_default_graph()
    _, word_vec_mat = load_word_vec(data_dir)
    max_classes = 53
    model = Model(word_vec_mat, encoder = encoder, selector=selector, no_of_classes=max_classes)
    print("Setting max class size to : ", max_classes)
    train_data = load_data(preprocessed_data_dir, max_classes)
    dev_data = load_data(preprocessed_data_dir, max_classes, "test")

    # The last element of each loaded data tuple is the pair -> bag-location
    # mapping; only its keys (the pair identifiers) are needed here.
    pairs = list(train_data[-1].keys())
    dev_pairs = list(dev_data[-1].keys())

    # Load model from checkpoint
    model.mloader(os.path.join("saved", "models", chkpt_path))
    model.reset_optimizer()

    # Store NA and nonNA pairs from train, dev.
    trainNa, trainNonNa = _split_na_pairs(pairs)
    devNa, devNonNa = _split_na_pairs(dev_pairs)
    trainNA_pairs_dict = get_dev_pairs_dict(trainNa)
    print("Pairs in train dataset NA and non NA : ", len(trainNa), len(trainNonNa))
    print("Pairs in dev dataset NA and non NA : ", len(devNa), len(devNonNa))

    entPairSents = _load_entity_pair_sentences(os.path.join(data_dir, "train.json"))
    with open(os.path.join(data_dir, "idToRel.json"), "r") as f:
        idToRel = json.load(f)

    def dump_best_data(pairs, pairs_dict, name = "NonNA"):
        """Run the model over ``pairs`` and write confident NA predictions to CSV.

        Args:
            pairs: pair keys to score, batched in the given order.
            pairs_dict: mapping from ``"head#tail"`` to that pair's gold
                relation ids; it must cover every pair in ``pairs``.
            name: label embedded in the output file names.
        """
        na_rel_id = 0          # relation id treated as "NA" throughout this function
        score_threshold = 0.5  # minimum NA score for a pair to be exported
        max_pos_index = 239    # position-embedding tables have 240 rows (see model variable listing)
        chunk_size = 1000      # rows per output CSV
        n_chunks = 50          # number of CSV files written; results past n_chunks * chunk_size are dropped

        # NOTE(review): integer division drops the trailing partial batch, so up
        # to batch_size - 1 pairs are never scored — confirm this is intended.
        n_batches = len(pairs) // batch_size
        test_res = []
        print(n_batches)
        print("Running model on data ...", end = " ")
        for batch_idx in range(n_batches):
            if batch_idx % 100 == 0:
                print(batch_idx , end = " ")
            batch_keys = pairs[batch_idx * batch_size : (batch_idx + 1) * batch_size]
            words, pos1, pos2, inst_rels, masks, lengths, \
                rels, scope = batch_maker(train_data, batch_keys)
            # Clamp relative positions into the valid embedding index range.
            pos1[pos1 > max_pos_index] = max_pos_index
            pos2[pos2 > max_pos_index] = max_pos_index
            pos1[pos1 < 0] = 0
            pos2[pos2 < 0] = 0
            output, atts = model.test_batch(words, pos1, pos2, inst_rels,
                masks, lengths, rels, scope)
            # A distinct index name avoids shadowing the batch counter.
            for row, key in enumerate(batch_keys):
                entPair = "#".join(key.split("#")[:2])
                entPairRels = pairs_dict[entPair]
                # Keep only pairs that are truly NA and confidently predicted as NA.
                if na_rel_id in entPairRels and output[row][na_rel_id] > score_threshold:
                    test_res.append({"entPair" : entPair,
                        "score" : output[row][na_rel_id],
                        "correct" : 1,
                        "pred" : na_rel_id,
                        "actual" : entPairRels,
                        # '==' instead of 'is': identity of equal strings is not guaranteed.
                        "atts" : atts[row][na_rel_id] if selector == "att" else []})

        sorted_test_result = sorted(test_res, key=lambda x: x['score'], reverse = True)

        out_dir = os.path.join("saved", "aucs")
        os.makedirs(out_dir, exist_ok=True)
        for chunk_idx in range(n_chunks):
            chunk = sorted_test_result[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size]
            out_name = chkpt_path + "_" + name + "_" + str(chunk_idx) + "__preds.csv"
            with open(os.path.join(out_dir, out_name), "w", newline="") as f:
                preds = []
                for item in chunk:
                    sents = "\n".join(entPairSents[item["entPair"]])
                    preds.append([item["entPair"], sents, idToRel[str(item["pred"])], " ",
                                  "\n".join(idToRel[str(k)] for k in item["actual"]), str(item["correct"]),
                                  " ".join(str(x) for x in item["atts"]), str(item["score"])])
                csv.writer(f).writerows(preds)

    dump_best_data(trainNa, trainNA_pairs_dict, "NA")


In [25]:
# Export confident NA predictions for this saved checkpoint; better_data
# resolves the name under saved/models/ and uses its defaults for the rest.
better_data(chkpt_path="pcnn_att_nyt2_none_53_n_0_0.3455")

Creating model with encoder and selector :  pcnn att
121 (?, 120, 230)
121 (?, 120, 230)
(?, 690)
Created model with no bootstrapping, bs val : 0.0
<tf.Variable 'word_embedding/word_embedding:0' shape=(114042, 50) dtype=float32_ref>
<tf.Variable 'word_embedding/unk_word_embedding:0' shape=(1, 50) dtype=float32_ref>
<tf.Variable 'word_embedding/start_word_embedding:0' shape=(1, 50) dtype=float32_ref>
<tf.Variable 'word_embedding/end_word_embedding:0' shape=(1, 50) dtype=float32_ref>
<tf.Variable 'pos_embedding/real_pos1_embedding:0' shape=(240, 5) dtype=float32_ref>
<tf.Variable 'pos_embedding/real_pos2_embedding:0' shape=(240, 5) dtype=float32_ref>
<tf.Variable 'pcnn/conv1d/kernel:0' shape=(3, 60, 230) dtype=float32_ref>
<tf.Variable 'pcnn/conv1d/bias:0' shape=(230,) dtype=float32_ref>
<tf.Variable 'attention/logit/relation_matrix:0' shape=(53, 690) dtype=float32_ref>
<tf.Variable 'attention/logit/bias:0' shape=(53,) dtype=float32_ref>
<tf.Variable 'beta1_power:0' shape=() dtype=float3

KeyboardInterrupt: 