In [1]:
import json
import numpy as np
import collections
import copy
from os import listdir
from os.path import isfile, join

In [2]:
import findspark
findspark.init()
from pyspark import SparkContext
import pyspark
conf = pyspark.SparkConf().setAll([('spark.executor.memory', '8g'), ('spark.executor.cores', '2'),('spark.executor.instances','7'), ('spark.driver.memory','32g'), ('spark.driver.maxResultSize','10g')])
sc = SparkContext(conf=conf)

In [3]:
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType, StringType
from pyspark.sql.types import Row
from pyspark.sql import SparkSession
spark = SparkSession(sc)

In [4]:
def convert_ndarray_back(x):
    x['entityCell'] = np.array(x['entityCell'])
    return x
data_dir = "../../data/"
train_tables = sc.textFile(data_dir+"train_tables.jsonl").map(lambda x:convert_ndarray_back(json.loads(x.strip())))

In [45]:
def get_core_entity_caption_label(x):
    core_entities = set()
    for i,j in zip(*x['entityCell'].nonzero()):
        if j==0 and j in x['entityColumn']:
            core_entities.add(x['tableData'][i][j]['surfaceLinks'][0]['target']['id'])
    return list(core_entities), x["_id"], x['tableCaption'], x["processed_tableHeaders"][0]

In [16]:
from operator import add

In [46]:
table_rdd = train_tables.map(get_core_entity_caption_label)
entity_rdd = table_rdd.flatMap(lambda x:[(z,x[1],x[2],x[3]) for z in x[0]])

In [43]:
from pyspark.ml.feature import Tokenizer, StopWordsRemover

In [47]:
table_df = spark.createDataFrame(table_rdd,["entities","table_id","caption","header"])

In [54]:
caption_tokenizer = Tokenizer(inputCol="caption", outputCol="caption_term")
header_tokenizer = Tokenizer(inputCol="header", outputCol="header_term")
list_stopwords = StopWordsRemover.loadDefaultStopWords("english")
caption_remover = StopWordsRemover(inputCol="caption_term", outputCol="caption_term_cleaned")
header_remover = StopWordsRemover(inputCol="header_term", outputCol="header_term_cleaned")

In [87]:
list_stopwords

['i',
 'me',
 'my',
 'myself',
 'we',
 'our',
 'ours',
 'ourselves',
 'you',
 'your',
 'yours',
 'yourself',
 'yourselves',
 'he',
 'him',
 'his',
 'himself',
 'she',
 'her',
 'hers',
 'herself',
 'it',
 'its',
 'itself',
 'they',
 'them',
 'their',
 'theirs',
 'themselves',
 'what',
 'which',
 'who',
 'whom',
 'this',
 'that',
 'these',
 'those',
 'am',
 'is',
 'are',
 'was',
 'were',
 'be',
 'been',
 'being',
 'have',
 'has',
 'had',
 'having',
 'do',
 'does',
 'did',
 'doing',
 'a',
 'an',
 'the',
 'and',
 'but',
 'if',
 'or',
 'because',
 'as',
 'until',
 'while',
 'of',
 'at',
 'by',
 'for',
 'with',
 'about',
 'against',
 'between',
 'into',
 'through',
 'during',
 'before',
 'after',
 'above',
 'below',
 'to',
 'from',
 'up',
 'down',
 'in',
 'out',
 'on',
 'off',
 'over',
 'under',
 'again',
 'further',
 'then',
 'once',
 'here',
 'there',
 'when',
 'where',
 'why',
 'how',
 'all',
 'any',
 'both',
 'each',
 'few',
 'more',
 'most',
 'other',
 'some',
 'such',
 'no',
 'nor',
 '

In [77]:
table_df_tokenizered = header_remover.transform(\
                            header_tokenizer.transform(\
                                caption_remover.transform(\
                                    caption_tokenizer.transform(table_df)))).select("entities","table_id","caption_term_cleaned","header_term_cleaned","header")

In [78]:
table_df_tokenizered.show()

+--------------------+----------+--------------------+--------------------+--------------------+
|            entities|  table_id|caption_term_cleaned| header_term_cleaned|              header|
+--------------------+----------+--------------------+--------------------+--------------------+
|          [27282555]|27281853-1|        [references]| [military, offices]|    military offices|
|   [450099, 1702543]|   27282-1|[main, office, ho...|            [office]|              office|
|  [23867939, 429187]|27282227-1|   [primate, poland]|[catholic, church...|catholic church t...|
|          [27283377]|27282555-1|        [references]| [military, offices]|    military offices|
|          [22583176]|27282731-3|   [external, links]|      [achievements]|        achievements|
|[2086865, 2172188...|27283077-1|     [qualification]|           [country]|             country|
|[1019331, 4019429...|27283077-2|            [venues]|        [gothenburg]|          gothenburg|
|           [4061083]|27283077

In [79]:
caption_term_freq = table_df_tokenizered.select("caption_term_cleaned").rdd \
                        .flatMap(lambda x:[(z,1) for z in x["caption_term_cleaned"]])\
                        .reduceByKey(add).collect()
header_term_freq = table_df_tokenizered.select("header_term_cleaned").rdd \
                        .flatMap(lambda x:[(z,1) for z in x["header_term_cleaned"]])\
                        .reduceByKey(add).collect()
header_freq = table_df_tokenizered.select("header").rdd \
                        .map(lambda x:(x["header"],1))\
                        .reduceByKey(add).collect()

In [80]:
len(header_freq)

20415

In [81]:
entity_df = table_df_tokenizered.select(F.explode("entities").alias("entity"), "table_id","caption_term_cleaned","header_term_cleaned","header")

In [82]:
entity_caption_term_freq = entity_df.select("entity", "caption_term_cleaned").rdd \
                                .flatMap(lambda x:[((x["entity"],z),1) for z in x["caption_term_cleaned"]])\
                                .reduceByKey(add)\
                                .map(lambda x:(x[0][0], [(x[0][1],x[1])]))\
                                .reduceByKey(add).collect()
entity_header_term_freq = entity_df.select("entity", "header_term_cleaned").rdd \
                                .flatMap(lambda x:[((x["entity"],z),1) for z in x["header_term_cleaned"]])\
                                .reduceByKey(add)\
                                .map(lambda x:(x[0][0], [(x[0][1],x[1])]))\
                                .reduceByKey(add).collect()
entity_header_freq = entity_df.select("entity", "header").rdd \
                                .map(lambda x:((x["entity"],x["header"]),1))\
                                .reduceByKey(add)\
                                .map(lambda x:(x[0][0], [(x[0][1],x[1])]))\
                                .reduceByKey(add).collect()

In [84]:
entity_tables = entity_df.select("entity","table_id")\
                    .groupBy("entity").agg(F.collect_list("table_id").alias("tables"))\
                    .rdd.map(lambda x:(x['entity'],x['tables'])).collect()

In [100]:
import pickle

In [105]:
with open("../../data/entity_tables.pkl","wb") as f:
    pickle.dump(entity_tables, f)

In [103]:
for e in entity_header_freq:
    entity_header_freq[e] = [sum([count for _,count in entity_header_freq[e]]),dict(entity_header_freq[e])]

with open("../../data/entity_header_freq.pkl","wb") as f:
    pickle.dump(entity_header_freq, f)

In [106]:
entity_header_term_freq = dict(entity_header_term_freq)
for e in entity_header_term_freq:
    entity_header_term_freq[e] = [sum([count for _,count in entity_header_term_freq[e]]),dict(entity_header_term_freq[e])]

with open("../../data/entity_header_term_freq.pkl","wb") as f:
    pickle.dump(entity_header_term_freq, f)

In [107]:
entity_caption_term_freq = dict(entity_caption_term_freq)
for e in entity_caption_term_freq:
    entity_caption_term_freq[e] = [sum([count for _,count in entity_caption_term_freq[e]]),dict(entity_caption_term_freq[e])]

with open("../../data/entity_caption_term_freq.pkl","wb") as f:
    pickle.dump(entity_caption_term_freq, f)

In [109]:
caption_term_freq = dict(caption_term_freq)
with open("../../data/caption_term_freq.pkl","wb") as f:
    pickle.dump([sum([count for _,count in caption_term_freq.items()]),caption_term_freq], f)
    
header_term_freq = dict(header_term_freq)
with open("../../data/header_term_freq.pkl","wb") as f:
    pickle.dump([sum([count for _,count in header_term_freq.items()]),header_term_freq], f)
    
header_freq = dict(header_freq)
with open("../../data/header_freq.pkl","wb") as f:
    pickle.dump([sum([count for _,count in header_freq.items()]),header_freq], f)

In [99]:
for e in entity_tables:
    if len(entity_tables[e]) != sum([count for _,count in entity_header_freq[e]]):
        print(e, len(entity_tables[e]), sum([count for _,count in entity_header_freq[e]]))
        break

In [108]:
caption_term_freq[0]

('references', 35319)

In [102]:
entity_header_freq[1677]

[7,
 {'titles in pretence': 1,
  'descendent': 2,
  'image': 1,
  'political offices': 1,
  'name of descendant': 1,
  'name': 1}]

In [113]:
entity_rdd.filter(lambda x:x[0]==5839439).take(10)

[(5839439, '39405618-3', 'Other', 'record name')]

In [3]:
from metric import *

In [40]:
with open("../../data/dev_result.pkl","rb") as f:
    dev_result = pickle.load(f)

In [7]:
def load_entity_vocab(data_dir, ignore_bad_title=True, min_ent_count=1):
    entity_vocab = {}
    bad_title = 0
    few_entity = 0
    with open(os.path.join(data_dir, 'entity_vocab.txt'), 'r', encoding="utf-8") as f:
        for line in f:
            _, entity_id, entity_title, entity_mid, count = line.strip().split('\t')
            if ignore_bad_title and entity_title == '':
                bad_title += 1
            elif int(count) < min_ent_count:
                few_entity += 1
            else:
                entity_vocab[len(entity_vocab)] = {
                    'wiki_id': int(entity_id),
                    'wiki_title': entity_title,
                    'mid': entity_mid,
                    'count': int(count)
                }
    print('total number of entity: %d\nremove because of empty title: %d\nremove because count<%d: %d'%(len(entity_vocab),bad_title,min_ent_count,few_entity))
    return entity_vocab

In [9]:
entity_vocab = load_entity_vocab("../../data", True, 2)
train_all_entities = set([x['wiki_id'] for _,x in entity_vocab.items()])

total number of entity: 368789
remove because of empty title: 5426
remove because count<2: 467625


In [76]:
dev_final = {}
for id,result in dev_result.items():
    _, target_entities, pneural, pall, pee, pce, ple, cand_e, cand_c = result
    target_entities = set(target_entities)
    cand_e = set([e for e in cand_e if e in train_all_entities])
    cand_c = set([e for e in cand_c if e in train_all_entities])
    cand_all = set([e for e in cand_c|cand_e if e in train_all_entities])
    recall_e = len(cand_e&target_entities)/len(target_entities)
    recall_c = len(cand_c&target_entities)/len(target_entities)
    recall_all = len(cand_all&target_entities)/len(target_entities)
    
    ranked_neural = sorted(pneural.items(),key=lambda z:z[1]+30*pee[z[0]],reverse=True)
    ranked_neural = [1 if z[0] in target_entities else 0 for z in ranked_neural if z[0] in train_all_entities]
    ap_neural = average_precision(ranked_neural)
    
    ranked_all = sorted(pall.items(),key=lambda z:100*pee[z[0]]+1*pce[z[0]]+0.5*ple[z[0]],reverse=True)
    ranked_all = [1 if z[0] in target_entities else 0 for z in ranked_all if z[0] in train_all_entities]
    ap_all = average_precision(ranked_all)
    
#     ranked_e = sorted(pee.items(),key=lambda z:z[1],reverse=True)
#     ranked_e = [1 if z[0] in target_entities else 0 for z in ranked_e if z[0] in train_all_entities]
#     assert len(ranked_e) == len(ranked_neural)
#     ap_e = average_precision(ranked_e)
    
#     ranked_c = sorted(pce.items(),key=lambda z:z[1],reverse=True)
#     ap_c = average_precision([1 if z[0] in target_entities else 0 for z in ranked_c if z[0] in train_all_entities])
    
#     ranked_l = sorted(ple.items(),key=lambda z:z[1],reverse=True)
#     ap_l = average_precision([1 if z[0] in target_entities else 0 for z in ranked_l if z[0] in train_all_entities])
    
    dev_final[id] = [recall_all,recall_e,recall_c,ap_neural,ap_all,ap_e,ap_c,ap_l]

for i in range(8):
    print(np.mean([z[i] for _,z in dev_final.items()]))

0.7994353789875439
0.7628692822631273
0.3988389334320421
0.5528707452445315
0.5310890647775968
0.43977591036414587
0.06322279546749321
0.11025661380426


In [39]:
dev_result['13591903-1'][2]

{22890502: -10.864021301269531,
 31750: -8.485289573669434,
 38963209: -11.514930725097656,
 38761483: -8.730037689208984,
 41826324: -7.295381546020508,
 34387991: -10.285350799560547,
 38770715: -7.950537204742432,
 38973476: -8.864492416381836,
 52882469: -6.8088507652282715,
 26667: -9.346258163452148,
 24680494: -11.487751960754395,
 37678127: -5.8102922439575195,
 271409: -8.295778274536133,
 171058: -10.087865829467773,
 1822771: -8.71945858001709,
 22240306: -5.999110698699951,
 38957122: -11.334463119506836,
 3261506: -10.804362297058105,
 4769866: -6.925134181976318,
 52882506: -7.622045040130615,
 1000530: -8.985393524169922,
 861274: -13.701908111572266,
 72796: -13.323681831359863,
 28260447: -7.540836334228516,
 17514: -10.122989654541016,
 18308206: -14.094464302062988,
 31853: -9.865266799926758,
 31026290: -6.690680027008057,
 38973556: -8.724361419677734,
 26748: -9.640986442565918,
 52930689: -6.52101469039917,
 242821: -9.655917167663574,
 25734: -9.333027839660645,

In [26]:
len([1 for z in dev_final if z[4]>=z[5]])

2111

In [36]:
[(i,z[3],z[4],z[5]) for i, z in dev_final.items() if z[4]>=z[5]]

[('27292980-1',
  0.056598001463750915,
  0.07847670450666021,
  0.03099193779270056),
 ('6677720-3', 0.09771898129625992, 0.3849797143275404, 0.2857327761868526),
 ('66818-2', 0.008466366077789217, 0.019806186058798112, 0.010479058925306118),
 ('670900-1', 0.833895174620261, 1.0, 1.0),
 ('15670619-3', 0.4966919443947785, 0.6409771156784415, 0.5340251014266074),
 ('15680241-2', 0.0, 0.0, 0.0),
 ('1568033-4', 0.06666666666666667, 0.05263157894736842, 0.02),
 ('15683270-1', 0.22857142857142856, 0.27976190476190477, 0.13370026373122348),
 ('1568649-3', 0.9008736213281666, 0.6682317283991925, 0.515588788316061),
 ('5639548-1', 0.24405637763365717, 0.13862030846748963, 0.09885448735580918),
 ('5645700-1', 0.0, 0.0, 0.0),
 ('5651527-4', 1.0, 1.0, 1.0),
 ('5652552-1', 0.29078054645937323, 0.1298340701345479, 0.08117577523359497),
 ('5653012-3', 0.11506905632284803, 0.056484878066629586, 0.04486438656104909),
 ('5653960-3', 0.43123959357587677, 0.04849157265237174, 0.023182540727918526),
 ('56