In [1]:
import numpy as np
import scipy.sparse
import pickle
import xgboost as xgb
import csv
from collections import defaultdict
import redis
import json
import time
import sklearn.utils
import psycopg2
from psycopg2.sql import Identifier, SQL
from datetime import datetime
import os

In [24]:
# Service credentials are read from the environment so that no secret is stored
# in the notebook. The passwords that used to be embedded here should be
# rotated. Required variables: REDIS_PASSWORD, PG_PASSWORD. Host, port and user
# can be overridden and otherwise keep the values this notebook always used.
conn = redis.Redis(password=os.environ['REDIS_PASSWORD'],
                   host=os.environ.get('REDIS_HOST', 'localhost'),
                   port=int(os.environ.get('REDIS_PORT', '7380')),
                   decode_responses=True)
_pg_params = dict(user=os.environ.get('PG_USER', 'zding'),
                  password=os.environ['PG_PASSWORD'],
                  host=os.environ.get('PG_HOST', 'localhost'),
                  port=int(os.environ.get('PG_PORT', '5432')))
# 'aida' is queried for wikipedia_links_2014 and men_ent_dict_emnlp2017.
conn_psql_kongzi2 = psycopg2.connect(database='aida', **_pg_params)
# 'zding' is queried for the per-dataset mention tables.
conn_psql = psycopg2.connect(database='zding', **_pg_params)

In [11]:
def get_feature_results(exp_id, d_t, data_source='aida_conll'):
    """Read the stored feature rows of one experiment/split from Redis.

    The rows are the entries of the Redis list
    ``result:::::<exp_id>:::::<data_source>:::::<d_t>``. Every entry is a JSON
    document and is decoded before being returned.

    Returns:
        list: one decoded JSON value per list entry, in list order.
    """
    redis_key = ':::::'.join(['result', str(exp_id), data_source, str(d_t)])
    return [json.loads(raw_row) for raw_row in conn.lrange(redis_key, 0, -1)]

def get_dataset_info(exp_id, d_t, data_source='aida_conll', can_size=50):
    """Return the cardinalities of four bookkeeping sets kept in Redis.

    All keys share the suffix ``<exp_id>:::::<data_source>:::::<d_t>``; the
    ``missed_g_candidate`` key additionally ends with ``:::::<can_size>``.

    Returns:
        tuple: sizes of ``valid_qry_ids``, ``no_candidate``,
        ``no_g_candidate_info`` and ``missed_g_candidate`` (in that order).
    """
    suffix = ':::::'.join([str(exp_id), data_source, str(d_t)])

    def _set_size(prefix, key_suffix=suffix):
        return conn.scard(prefix + ':::::' + key_suffix)

    valid = _set_size('valid_qry_ids')
    no_gold_info = _set_size('no_g_candidate_info')
    missed_gold = _set_size('missed_g_candidate', suffix + ':::::' + str(can_size))
    no_candidate = _set_size('no_candidate')
    return valid, no_candidate, no_gold_info, missed_gold

def fetch_all_features(exp_id, data_type, data_source='aida_conll'):
    """Load the feature rows of an experiment as a float64 matrix.

    The first element of every stored row is a string; only the text before
    its first ``', '`` (with leading ``'('`` removed) is kept as column 0.
    The remaining elements are copied unchanged.
    """
    rows = []
    for res in get_feature_results(exp_id, data_type, data_source):
        # presumably res[0] is the repr of a (query id, entity) tuple -- TODO confirm
        query_id = res[0].strip('(').split(', ')[0]
        rows.append([query_id, *res[1:-1], res[-1]])
    return np.array(rows, dtype=np.float64)

def fetch_all_features_delete_max_prior(exp_id, data_type, data_source='aida_conll'):
    """Load feature rows as a float64 matrix, back-filling element 2.

    Same as ``fetch_all_features`` except that whenever element 2 of a stored
    row equals 0 it is replaced by element 1 of that row.
    """
    rows = []
    for res in get_feature_results(exp_id, data_type, data_source):
        query_id = res[0].strip('(').split(', ')[0]
        third = res[1] if res[2] == 0 else res[2]
        rows.append([query_id, *res[1:2], third, *res[3:]])
    return np.array(rows, dtype=np.float64)

def trans_data(data):
    """Wrap a feature matrix in an ``xgb.DMatrix`` with ranking groups.

    Column 0 is dropped, the last column is used as the label and the columns
    in between are the features. A new group starts at every row whose label
    is 1; the last group runs to the end of the matrix.
    """
    labels = data[:, -1]
    features = data[:, 1:-1]
    group_starts = np.where(labels == 1)[0]
    # Start of the following group (or the row count for the last group).
    next_starts = np.append(np.delete(group_starts, 0), len(labels))
    dmatrix = xgb.DMatrix(data=features, label=labels)
    dmatrix.set_group(next_starts - group_starts)
    return dmatrix

def combine_features(original_feas, new_features):
    """Append the columns of ``new_features`` to ``original_feas`` per mention.

    Column 0 of both 2-D arrays is the mention id. For every run of rows in
    ``original_feas`` that shares a mention id, the first ``k`` rows of the run
    (``k`` = number of ``new_features`` rows with that id) are taken without
    their last column and extended with columns 1.. of the matching
    ``new_features`` rows. The last column of the result therefore comes from
    ``new_features``.

    Args:
        original_feas: 2-D array, rows of one mention stored contiguously.
        new_features: 2-D array with the same mention ids in column 0.

    Returns:
        numpy.ndarray: the combined rows, stacked in mention order.
    """
    men_id_feas_dict = defaultdict(list)
    print("Building idx for new features...")
    for fea in new_features:
        men_id_feas_dict[fea[0]].append(fea)

    comb_feas = []
    # None (not 0) marks "no mention seen yet": with a 0 sentinel a leading
    # mention whose id is 0 was silently dropped from the result.
    pre_men_id = None
    print("Combine original and new features...")
    for fea_idx, fea in enumerate(original_feas):
        men_id = fea[0]
        if pre_men_id is not None and pre_men_id == men_id:
            # Later rows of a mention were already consumed with its first row.
            continue
        pre_men_id = men_id
        new_rows = men_id_feas_dict[men_id]
        fea_size = len(new_rows)
        res = np.append(original_feas[fea_idx: fea_idx + fea_size, :-1],
                        np.array(new_rows)[:, 1:], axis=1)
        comb_feas.append(res)
    return np.concatenate(comb_feas, axis=0)

def evalerror(preds, dt, d_tal_size):
    """Score ranking predictions group by group.

    Groups are derived from the labels of ``dt``: a group starts at every row
    labelled 1 and ends before the next such row. A group counts as matched
    only when its first row has the strictly highest score in the group
    (a tie for the maximum is not a match).

    Args:
        preds: sequence of scores aligned with the rows of ``dt``.
        dt: object exposing ``get_label()`` (e.g. an ``xgb.DMatrix``).
        d_tal_size: denominator used for recall.

    Returns:
        tuple: ``(matched, precision, recall, f1)``. Any ratio whose
        denominator is zero is reported as ``0.0`` instead of raising.
    """
    d_l = dt.get_label()
    idxs = np.where(d_l == 1)[0]
    # Group sizes are the gaps between consecutive group starts; np.diff also
    # copes with "no group at all", where np.delete(idxs, 0) would raise.
    d_groups = np.diff(np.append(idxs, len(d_l)))
    matched = 0
    q_id = 0
    for size in d_groups:
        pre_res = preds[q_id: q_id + size]
        gold = preds[q_id]
        if gold == max(pre_res) and sum(1 for p in pre_res if p == gold) == 1:
            matched += 1
        q_id += size
    precision = float(matched) / len(d_groups) if len(d_groups) else 0.0
    recall = float(matched) / d_tal_size if d_tal_size else 0.0
    denom = precision + recall
    # With zero matches precision + recall is 0; F1 is then defined as 0.
    f1 = 2 * precision * recall / denom if denom else 0.0
    return matched, precision, recall, f1

from collections import defaultdict
def evalerror_detail_log(preds, dt, d_tal_size):
    """Score ranking predictions and keep the per-group score slices.

    Groups start at every row of ``dt`` labelled 1. For each group, keyed by
    the index of its first row:

    * ``correct_results`` holds the scores when the first row reaches the
      group maximum (ties included);
    * ``duplicates_results`` holds the subset of those where the maximum is
      shared with another row (these are not counted as matches);
    * ``wrong_results`` holds the scores of all other groups.

    Args:
        preds: sequence of scores aligned with the rows of ``dt``.
        dt: object exposing ``get_label()`` (e.g. an ``xgb.DMatrix``).
        d_tal_size: denominator used for recall.

    Returns:
        tuple: ``(matched, precision, recall, f1, correct_results,
        wrong_results, duplicates_results)``. Ratios with a zero denominator
        are reported as ``0.0`` instead of raising.
    """
    d_l = dt.get_label()
    idxs = np.where(d_l == 1)[0]
    # Gaps between consecutive group starts; safe for an empty ``idxs`` too.
    d_groups = np.diff(np.append(idxs, len(d_l)))
    correct_results = {}
    wrong_results = {}
    duplicates_results = {}
    matched_ids = []
    q_id = 0
    for size in d_groups:
        pre_res = preds[q_id: q_id + size]
        gold = preds[q_id]
        if gold == max(pre_res):
            correct_results[q_id] = pre_res
            if sum(1 for p in pre_res if p == gold) == 1:
                matched_ids.append(q_id)
            else:
                duplicates_results[q_id] = pre_res
        else:
            wrong_results[q_id] = pre_res
        q_id += size
    precision = float(len(matched_ids)) / len(d_groups) if len(d_groups) else 0.0
    recall = float(len(matched_ids)) / d_tal_size if d_tal_size else 0.0
    denom = precision + recall
    # With zero matches precision + recall is 0; F1 is then defined as 0.
    f1 = 2 * precision * recall / denom if denom else 0.0
    return len(matched_ids), precision, recall, f1, correct_results, wrong_results, duplicates_results

from collections import defaultdict
from ast import literal_eval
def get_groups_results(preds, dt, res_features, top_k=None):
    """Collect the predicted (query id, entity) pair of every group.

    Groups start at every row of ``dt`` labelled 1. The prediction of a group
    is the row with the highest score; element 0 of the matching
    ``res_features`` row is parsed with ``literal_eval`` into a
    ``(query id, entity)`` pair. The pair is filed as correct when the first
    row of the group reaches the group maximum, otherwise as wrong.

    Args:
        preds: scores aligned with the rows of ``dt``.
        dt: object exposing ``get_label()``.
        res_features: raw feature rows aligned with ``preds``.
        top_k: when given, also collect the absolute row indices of the
            ``top_k`` best-scored rows of every group.

    Returns:
        tuple: ``(correct_groups, wrong_groups, top_k_indices)``;
        ``top_k_indices`` is ``None`` when ``top_k`` is ``None``.
    """
    labels = dt.get_label()
    starts = np.where(labels == 1)[0]
    sizes = np.append(np.delete(starts, 0), len(labels)) - starts
    correct_res_groups, wrong_res_groups = [], []
    top_k_indices = None if top_k is None else []
    offset = 0
    for size in sizes:
        scores = preds[offset: offset + size]
        best_row = res_features[offset + np.argmax(scores)]
        pred_q_id, pred_ent = literal_eval(best_row[0])

        bucket = correct_res_groups if preds[offset] == max(scores) else wrong_res_groups
        bucket.append([pred_q_id, pred_ent])

        if top_k is not None:
            # Stable descending sort: equal scores keep their original order.
            ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
            top_k_indices.extend(local + offset for local, _ in ranked[:top_k])

        offset += size

    return correct_res_groups, wrong_res_groups, top_k_indices

def fetch_inlinks_by_ent(ent):
    """Return the ``(_id,)`` rows of ``wikipedia_links_2014`` whose target is ``ent``.

    Args:
        ent: value compared against the ``target`` column.

    Returns:
        list: one 1-tuple per matching row (duplicates included).
    """
    sql = "SELECT _id FROM wikipedia_links_2014 WHERE target=%s;"
    # The context manager closes the cursor even if execute()/fetchall() raises.
    with conn_psql_kongzi2.cursor() as cur:
        cur.execute(sql, (ent,))
        return cur.fetchall()

# fetch an entity's outlinks with duplicates
def fetch_outlinks_by_ent(ent):
    """Return the ``(target,)`` rows of ``wikipedia_links_2014`` whose ``_id`` is ``ent``.

    Args:
        ent: value compared against the ``_id`` column.

    Returns:
        list: one 1-tuple per matching row (duplicates included).
    """
    sql = "SELECT target FROM wikipedia_links_2014 WHERE _id=%s;"
    # The context manager closes the cursor even if execute()/fetchall() raises.
    with conn_psql_kongzi2.cursor() as cur:
        cur.execute(sql, (ent,))
        return cur.fetchall()

def fetch_entity_by_mention_emnlp17(mention):
    """Return ``(entity, prior)`` rows for ``mention``, highest prior first.

    The mention is passed as a bound query parameter. The earlier version
    spliced it into an ``E'...'`` literal after escaping only single quotes,
    so backslashes in a mention were interpreted as escape sequences and the
    statement was open to SQL injection.

    Args:
        mention: exact value of the ``mention`` column to look up.

    Returns:
        list: ``(entity, prior)`` tuples from ``men_ent_dict_emnlp2017``.
    """
    sql = ("SELECT entity, prior FROM men_ent_dict_emnlp2017 "
           "WHERE men_ent_dict_emnlp2017.mention = %s ORDER BY prior DESC;")
    with conn_psql_kongzi2.cursor() as cur:
        cur.execute(sql, (mention,))
        return cur.fetchall()

def fetch_inlinks_redis(ent, link_type='inlinks'):
    """Return the JSON-decoded value of field ``ent`` in Redis hash ``link_type``.

    An empty list is returned when the field is missing or holds an empty value.
    """
    cached = conn.hmget(link_type, ent)[0]
    if not cached:
        return []
    return json.loads(cached)

def has_inlinks_redis(ent, link_type='inlinks'):
    """Tell whether field ``ent`` exists in the Redis hash ``link_type``."""
    return conn.hexists(link_type, ent)

def save_inlinks_redis(ent, inlinks, link_type='inlinks'):
    """Store ``inlinks`` as JSON under field ``ent`` of Redis hash ``link_type``."""
    payload = json.dumps(inlinks)
    conn.hset(link_type, ent, payload)
    
def fetch_outlinks_redis(ent, link_type='outlinks'):
    """Return the JSON-decoded value of field ``ent`` in Redis hash ``link_type``.

    An empty list is returned when the field is missing or holds an empty value.
    """
    cached = conn.hmget(link_type, ent)[0]
    if not cached:
        return []
    return json.loads(cached)

def has_outlinks_redis(ent, link_type='outlinks'):
    """Tell whether field ``ent`` exists in the Redis hash ``link_type``."""
    return conn.hexists(link_type, ent)

def save_outlinks_redis(ent, outlinks, link_type='outlinks'):
    """Store ``outlinks`` as JSON under field ``ent`` of Redis hash ``link_type``."""
    payload = json.dumps(outlinks)
    conn.hset(link_type, ent, payload)
    
def check_links_between_ents(ent_1, ent_2, bidirection=False):
    """Tell whether two entities link to each other according to their inlinks.

    Inlinks are read from the Redis cache; when an entity has no cached field
    at all they are fetched from PostgreSQL (with the
    ``en.wikipedia.org/wiki/`` prefix added for the query and stripped from
    the results) and written back to the cache.

    Args:
        ent_1, ent_2: entity names without the wiki URL prefix.
        bidirection: require a link in both directions instead of either one.

    Returns:
        bool: whether the requested link relation holds.
    """
    wiki_pre_str = 'en.wikipedia.org/wiki/'

    def _inlinks_of(ent):
        links = fetch_inlinks_redis(ent, link_type='inlinks')
        # A cached (possibly empty) entry is authoritative; only a missing
        # field triggers the database lookup.
        if links or has_inlinks_redis(ent):
            return links
        print("PostgreSQL: fetching inlinks for entity {}...".format(ent))
        db_rows = fetch_inlinks_by_ent(wiki_pre_str + ent)
        links = [row[0].replace(wiki_pre_str, '') for row in db_rows]
        print("Redis: caching inlinks for entity {}...".format(ent))
        save_inlinks_redis(ent, links)
        return links

    inlinks_ent_1 = _inlinks_of(ent_1)
    inlinks_ent_2 = _inlinks_of(ent_2)
    if bidirection:
        return ent_1 in inlinks_ent_2 and ent_2 in inlinks_ent_1
    return ent_1 in inlinks_ent_2 or ent_2 in inlinks_ent_1

def get_links_by_ent(ent, link_type='inlinks'):
    """Return the inlinks or outlinks of ``ent``, using Redis as a cache.

    When the entity has no cached field, the links are fetched from
    PostgreSQL (wiki URL prefix added for the query, stripped from the
    results) and stored in Redis before being returned.

    Args:
        ent: entity name without the wiki URL prefix.
        link_type: ``'inlinks'`` or ``'outlinks'``.

    Returns:
        list of link names, or ``None`` for any other ``link_type``.
    """
    wiki_pre_str = 'en.wikipedia.org/wiki/'

    if link_type == 'inlinks':
        links = fetch_inlinks_redis(ent, link_type='inlinks')
        if links or has_inlinks_redis(ent):
            return links
        db_rows = fetch_inlinks_by_ent(wiki_pre_str + ent)
        links = [row[0].replace(wiki_pre_str, '') for row in db_rows]
        print("Redis: caching inlinks for entity {}...".format(ent))
        save_inlinks_redis(ent, links)
        return links

    if link_type == 'outlinks':
        links = fetch_outlinks_redis(ent)
        if links or has_outlinks_redis(ent):
            return links
        db_rows = fetch_outlinks_by_ent(wiki_pre_str + ent)
        links = [row[0].replace(wiki_pre_str, '') for row in db_rows]
        print("Redis: caching outlinks for entity {}...".format(ent))
        save_outlinks_redis(ent, links)
        return links

    return None
    
def fetch_ents_by_doc_redis(doc_id):
    """Return the JSON-decoded entry of ``doc_id`` in hash ``doc-predicted-ents-coref-new``.

    An empty list is returned when the field is missing or holds an empty value.
    """
    stored = conn.hmget('doc-predicted-ents-coref-new', doc_id)[0]
    if not stored:
        return []
    return json.loads(stored)

## Normalized Google Distance
import math
def ngd_similarity(ents_s, ents_t, index_size = 6274625):
    """Similarity ``1 - NGD`` computed from two collections of links.

    With ``a``/``b`` the larger/smaller number of distinct links and ``c`` the
    number of shared links, the result is
    ``1 - (log a - log c) / (log index_size - log b)``.
    ``0`` is returned when either collection is empty or nothing is shared.
    """
    links_s, links_t = set(ents_s), set(ents_t)
    sizes = (len(links_s), len(links_t))
    min_links, max_links = min(sizes), max(sizes)
    com_links = len(links_s & links_t)
    if not (min_links and max_links and com_links):
        return 0
    return 1 - (math.log(max_links) - math.log(com_links))/ (math.log(index_size) - math.log(min_links))
    
# PMI
def pmi_similarity(ents_s, ents_t, index_size = 6274625, normalize=False):
    """Co-occurrence ratio ``p(s, t) / (p(s) * p(t))`` of two link collections.

    Probabilities are the number of distinct links (or of shared links for
    the joint term) divided by ``index_size``. No logarithm is applied.

    Args:
        ents_s, ents_t: iterables of links; duplicates are ignored.
        index_size: total number of indexed items.
        normalize: additionally divide the ratio by ``min(1/p(s), 1/p(t))``.

    Returns:
        The ratio, or ``0`` when either collection is empty or nothing is
        shared.
    """
    ent_sets_s = set(ents_s)
    ent_sets_t = set(ents_t)
    p_s = len(ent_sets_s) / index_size
    p_t = len(ent_sets_t) / index_size
    p_c = len(ent_sets_s & ent_sets_t) / index_size
    # The debug print of (p_s, p_t, p_c) was removed: it wrote one line per
    # call, flooding the output when used inside feature loops.
    if not (p_s and p_t and p_c):
        return 0
    ratio = p_c / (p_s * p_t)
    return ratio / min(1 / p_s, 1 / p_t) if normalize else ratio

In [36]:
# Directory (relative to the working directory) that holds the pickled models.
model_dir_path = './new_models_14_Aug'


def save_model(model, name):
    """Pickle ``model`` to ``<model_dir_path>/<name>.mdl``.

    The directory is created first when it does not exist yet.
    """
    directory_missing = not os.path.exists(model_dir_path)
    if directory_missing:
        os.makedirs(model_dir_path)
    target = os.path.join(model_dir_path, '%s.mdl' % name)
    with open(target, 'wb') as out_file:
        pickle.dump(model, out_file)
        
def load_model(name):
    """Unpickle and return the model stored at ``<model_dir_path>/<name>.mdl``."""
    source = os.path.join(model_dir_path, '%s.mdl' % name)
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load model files produced by a trusted run of this project.
    with open(source, 'rb') as in_file:
        model = pickle.load(in_file)
    return model
    
def get_total_mentions(data_source, data_type) -> int:
    """Count the rows of table ``data_source`` with the given ``type`` whose annotation is not 'NIL'.

    The table name is quoted as an SQL identifier; ``data_type`` is passed as
    a bound parameter.
    """
    query = SQL("select count(*) from {} where annotation != 'NIL' and type=%s").format(Identifier(data_source))
    with conn_psql.cursor() as cur:
        cur.execute(query, (data_type,))
        row = cur.fetchone()
        return row[0]
    
def process(process_name ,test_set, test_total, 
            n_estimators, max_depths, test_filter=None,
            eval_func=evalerror_detail_log):
    """Evaluate every saved ``<n_estimators>_<max_depth>_<process_name>`` model.

    Args:
        process_name: suffix of the stored model names.
        test_set: feature matrix accepted by ``trans_data``.
        test_total: passed to ``eval_func`` as its third argument.
        n_estimators, max_depths: iterables; every combination is evaluated.
        test_filter: optional index/mask applied to ``test_set`` first.
        eval_func: callable ``(preds, dmatrix, total)`` returning a sequence;
            elements 0 and 2 are printed as ``corr_num`` and
            ``acc_validation``.
    """
    eval_rows = test_set if test_filter is None else test_set[test_filter]
    dtest_xgboost = trans_data(eval_rows)

    for num_round in n_estimators:
        for dep in max_depths:
            model_name = '%d_%d_%s' % (num_round, dep, process_name)
            print(datetime.now(), 'Loading model: %s' % model_name)
            bst = load_model(model_name)

            print(datetime.now(), 'Start evaluation')
            scores = eval_func(bst.predict(dtest_xgboost), dtest_xgboost, test_total)
            summary = "n_estimators: {}, max_depth: {}, acc_validation: {}, corr_num: {}".format(
                num_round, dep, scores[2], scores[0])
            print(summary)
            print(datetime.now(), 'Evaluation finished')
            

def fetch_q_ids_docs(data_source):
    """Map ``id`` to ``doc_id`` for every non-'NIL' row of table ``data_source``.

    The table name is quoted with ``psycopg2.sql.Identifier`` (as in
    ``get_total_mentions``) instead of being interpolated with ``%``, and the
    cursor is closed even when the query fails.

    Args:
        data_source: plain (unqualified) table name.

    Returns:
        dict: ``{id: doc_id}``.
    """
    sql = SQL("SELECT id, doc_id FROM {} WHERE annotation != 'NIL';").format(Identifier(data_source))
    with conn_psql.cursor() as cur:
        cur.execute(sql)
        return dict(cur.fetchall())
            
def save_local_model_predictions(model, d_test, raw_test, docs_dict, top_k=None):
    """Predict with ``model`` and store the predicted entities per document in Redis.

    Each group's predicted ``[query id, entity]`` pair is grouped by
    ``docs_dict[query id]`` and written as JSON to the hash
    ``doc-predicted-ents-coref-new-for-review`` (field = document id).

    Args:
        model: object with ``predict(DMatrix)``.
        d_test: feature matrix accepted by ``trans_data``.
        raw_test: raw feature rows aligned with ``d_test``.
        docs_dict: mapping from query id to document id.
        top_k: when given, the absolute row indices of the ``top_k``
            best-scored rows per group are stored under field ``test`` of
            hash ``doc-predicted-ents-top-k``.
    """
    d_test_xgboost = trans_data(d_test)
    preds_test = model.predict(d_test_xgboost)

    correct_test, wrong_test, top_k_indices_test = get_groups_results(preds_test, d_test_xgboost, raw_test, top_k)
    res_all_test = correct_test + wrong_test

    print('Number of groups:', len(res_all_test))

    doc_id_q_ent_lists_dict = defaultdict(list)
    for q_ent in res_all_test:
        doc_id_q_ent_lists_dict[docs_dict[q_ent[0]]].append(q_ent)

    for key, vals in doc_id_q_ent_lists_dict.items():
        conn.hset('doc-predicted-ents-coref-new-for-review', key, json.dumps(vals))

    if top_k is not None:
        # Fixed: this referenced the undefined name ``top_k_indices_train``.
        # Indices are converted to plain ints so the stored repr stays a list
        # of integer literals regardless of the NumPy scalar repr.
        conn.hset('doc-predicted-ents-top-k', 'test',
                  repr([int(idx) for idx in top_k_indices_test]))

In [34]:
# Query id -> document id lookups for the two test corpora.
docs_dict_msnbc, docs_dict_aquaint = (
    fetch_q_ids_docs(corpus) for corpus in ('msnbc_new', 'aquaint_new')
)

In [26]:
# MSNBC test split: context features (matrix and raw rows), coref features and
# the number of non-NIL mentions.
d_msnbc_ctx = fetch_all_features(exp_id='basic_fea_ctx', data_type='test', data_source='msnbc_new')
d_msnbc_ctx_raw = get_feature_results(exp_id='basic_fea_ctx', d_t='test', data_source='msnbc_new')
d_msnbc_coref = fetch_all_features(exp_id='basic_fea_coref', data_type='test', data_source='msnbc_new')
d_msnbc_total = get_total_mentions(data_source='msnbc_new', data_type='test')

# Same four artefacts for the AQUAINT test split.
d_aquaint_ctx = fetch_all_features(exp_id='basic_fea_ctx', data_type='test', data_source='aquaint_new')
d_aquaint_ctx_raw = get_feature_results(exp_id='basic_fea_ctx', d_t='test', data_source='aquaint_new')
d_aquaint_coref = fetch_all_features(exp_id='basic_fea_coref', data_type='test', data_source='aquaint_new')
d_aquaint_total = get_total_mentions(data_source='aquaint_new', data_type='test')

print(d_msnbc_ctx.shape, d_msnbc_coref.shape, d_msnbc_total, sep='\n')
print()
print(d_aquaint_ctx.shape, d_aquaint_coref.shape, d_aquaint_total, sep='\n')

(10316, 21)
(10316, 13)
739

(7410, 21)
(7410, 13)
727


In [6]:
# Extend the context features with the coref features for both corpora.
d_msnbc_ctx_coref = combine_features(original_feas=d_msnbc_ctx, new_features=d_msnbc_coref)
d_aquaint_ctx_coref = combine_features(original_feas=d_aquaint_ctx, new_features=d_aquaint_coref)

print(d_msnbc_ctx_coref.shape, d_aquaint_ctx_coref.shape, sep='\n')

Building idx for new features...
Combine original and new features...
Building idx for new features...
Combine original and new features...
(10316, 32)
(7410, 32)


In [18]:
# MSNBC: evaluate the context-only and the context+coref models.
process(process_name='ctx', test_set=d_msnbc_ctx, test_total=d_msnbc_total,
        n_estimators=[4900], max_depths=[6])
process(process_name='ctx_coref', test_set=d_msnbc_ctx_coref, test_total=d_msnbc_total,
        n_estimators=[4900], max_depths=[6])

2019-08-19 14:49:08.349804 Loading model: 4900_6_ctx
2019-08-19 14:49:08.370359 Start evaluation
n_estimators: 4900, max_depth: 6, acc_validation: 0.6278755074424899, corr_num: 464
2019-08-19 14:49:08.662315 Evaluation finished
2019-08-19 14:49:08.667954 Loading model: 4900_6_ctx_coref
2019-08-19 14:49:08.693380 Start evaluation
n_estimators: 4900, max_depth: 6, acc_validation: 0.7117726657645467, corr_num: 526
2019-08-19 14:49:09.003979 Evaluation finished


In [19]:
# AQUAINT: evaluate the context-only and the context+coref models.
process(process_name='ctx', test_set=d_aquaint_ctx, test_total=d_aquaint_total,
        n_estimators=[4900], max_depths=[6])
process(process_name='ctx_coref', test_set=d_aquaint_ctx_coref, test_total=d_aquaint_total,
        n_estimators=[4900], max_depths=[6])

2019-08-19 14:49:24.595869 Loading model: 4900_6_ctx
2019-08-19 14:49:24.616475 Start evaluation
n_estimators: 4900, max_depth: 6, acc_validation: 0.53232462173315, corr_num: 387
2019-08-19 14:49:24.859625 Evaluation finished
2019-08-19 14:49:24.865911 Loading model: 4900_6_ctx_coref
2019-08-19 14:49:24.884611 Start evaluation
n_estimators: 4900, max_depth: 6, acc_validation: 0.5281980742778541, corr_num: 384
2019-08-19 14:49:25.100430 Evaluation finished


In [37]:
# Store the context+coref model's MSNBC predictions per document in Redis.
model = load_model(name='4900_6_ctx_coref')
save_local_model_predictions(model=model, d_test=d_msnbc_ctx_coref, raw_test=d_msnbc_ctx_raw,
                             docs_dict=docs_dict_msnbc, top_k=None)

Number of groups: 543


In [40]:
# MSNBC: load the coherence features and append them to context+coref.
d_msnbc_coh = fetch_all_features(exp_id='basic_fea_coh', data_type='test', data_source='msnbc_new')
print(d_msnbc_coh.shape)
d_msnbc_ctx_coref_coh = combine_features(original_feas=d_msnbc_ctx_coref, new_features=d_msnbc_coh)
print(d_msnbc_ctx_coref_coh.shape)

(10316, 24)
Building idx for new features...
Combine original and new features...
(10316, 54)


In [42]:
# MSNBC: evaluate the two models trained on context+coref+coherence features.
process(process_name='ctx_coref_coh', test_set=d_msnbc_ctx_coref_coh, test_total=d_msnbc_total,
        n_estimators=[4900], max_depths=[6])
process(process_name='ctx_coref_coh_global3', test_set=d_msnbc_ctx_coref_coh, test_total=d_msnbc_total,
        n_estimators=[4900], max_depths=[6])

2019-08-19 17:45:32.544743 Loading model: 4900_6_ctx_coref_coh
2019-08-19 17:45:32.574501 Start evaluation
n_estimators: 4900, max_depth: 6, acc_validation: 0.7090663058186739, corr_num: 524
2019-08-19 17:45:32.933873 Evaluation finished
2019-08-19 17:45:32.941460 Loading model: 4900_6_ctx_coref_coh_global3
2019-08-19 17:45:32.969372 Start evaluation
n_estimators: 4900, max_depth: 6, acc_validation: 0.7036535859269283, corr_num: 520
2019-08-19 17:45:33.331261 Evaluation finished


In [43]:
# Store the context+coref model's AQUAINT predictions per document in Redis.
model = load_model(name='4900_6_ctx_coref')
save_local_model_predictions(model=model, d_test=d_aquaint_ctx_coref, raw_test=d_aquaint_ctx_raw,
                             docs_dict=docs_dict_aquaint, top_k=None)

Number of groups: 422


In [None]:
# AQUAINT: load the coherence features and append them to context+coref.
# Fixes relative to the earlier draft of this cell:
#   * the data source is 'aquaint_new' (it was misspelled 'acquaint_new');
#   * the base matrix is d_aquaint_ctx_coref (d_acquaint_ctx_coref was never defined);
#   * the AQUAINT coherence features are combined, not the MSNBC ones;
#   * the shape printed is that of the matrix just built.
d_acquaint_coh = fetch_all_features('basic_fea_coh', 'test', 'aquaint_new')
print(d_acquaint_coh.shape)
d_acquaint_ctx_coref_coh = combine_features(d_aquaint_ctx_coref, d_acquaint_coh)
print(d_acquaint_ctx_coref_coh.shape)

In [None]:
# AQUAINT: evaluate the two context+coref+coherence models. This cell was a
# verbatim copy of the MSNBC evaluation, so the AQUAINT matrix built in the
# preceding cell was never scored.
# NOTE(review): confirm that AQUAINT was the intended target of this cell.
process('ctx_coref_coh', d_acquaint_ctx_coref_coh, d_aquaint_total, n_estimators=[4900], max_depths=[6])
process('ctx_coref_coh_global3', d_acquaint_ctx_coref_coh, d_aquaint_total, n_estimators=[4900], max_depths=[6])