In [1]:
from __future__ import absolute_import, division, print_function

import argparse
import glob
import json
import logging
import os
import pickle
import random
import re
import shutil
from collections import Counter

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler

# Prefer the TensorBoard writer that ships with PyTorch and fall back to
# tensorboardX only when that import is unavailable. Catching ImportError
# (rather than everything) keeps unrelated failures visible.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

from tqdm import trange
from tqdm.autonotebook import tqdm

from data_loader.hybrid_data_loaders import *
from data_loader.header_data_loaders import *
from data_loader.CT_Wiki_data_loaders import *
from model.configuration import TableConfig
from model.model import HybridTableMaskedLM, HybridTableCER, TableHeaderRanking, HybridTableCT
from model.transformers import BertTokenizer, WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup
from utils.util import *
from baselines.row_population.metric import average_precision,ndcg_at_k
from baselines.cell_filling.cell_filling import *
from model import metric



In [2]:
# Logger named after the running module / notebook.
logger = logging.getLogger(__name__)

# Task key -> (config class, model class, tokenizer class).
# Every task shares TableConfig and BertTokenizer; only the model class differs.
_TASK_MODELS = (
    ('CER', HybridTableCER),
    ('CF', HybridTableMaskedLM),
    ('HR', TableHeaderRanking),
    ('CT', HybridTableCT),
)
MODEL_CLASSES = {
    task: (TableConfig, model_cls, BertTokenizer)
    for task, model_cls in _TASK_MODELS
}

In [3]:
# Root folder of the preprocessed data; later cells read the entity vocabulary,
# "dev_tables.jsonl" and "CF_dev_data.json" from here.
data_dir = 'data/wikisql_entity/'

In [4]:
# Model configuration file shared by every task section below.
config_name = "configs/table-base-config.json"
# Use the GPU when one is visible; otherwise fall back to the CPU instead of
# failing on the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
entity_vocab = load_entity_vocab(data_dir, ignore_bad_title=True, min_ent_count=2)
# Reverse lookup: wiki id of an entity -> its key in entity_vocab.
entity_wikid2id = {entity_vocab[x]['wiki_id']: x for x in entity_vocab}

total number of entity: 368793
remove because of empty title: 5426
remove because count<2: 467625


# Viz: export entity embeddings

In [None]:
# Only the raw weights are needed in this section (the entity embedding matrix
# is exported below), so the state dict is read directly. No model is built
# here: `model_class` and `config` are not defined yet on a fresh kernel, and
# the weights were never loaded into that model anyway.
checkpoint_dir = "output/hybrid/model_v1_table_0.2_0.6_0.7_30000_1e-4_with_cand_0"
# map_location='cpu' lets the file be read no matter which device it was saved from.
checkpoint = torch.load(os.path.join(checkpoint_dir, 'pytorch_model.bin'), map_location='cpu')

In [None]:
# Shape of the entity embedding weight stored in the loaded state dict.
checkpoint['table.embeddings.ent_embeddings.weight'].shape

In [None]:
dump_loc = "output/hybrid/model_v1_table_0.2_0.6_0.7_30000_1e-4_with_cand_0"

# Read the wiki id (first tab-separated field) of every row after the header line.
with open("data/wikisql_entity/entity_vocab_with_type.tsv", 'r', encoding='utf8') as f:
    next(f)
    entity_vocab_with_type = [int(line.strip().split('\t')[0]) for line in f]

# Write one tab-separated embedding vector per line, in the same order as the ids above.
ent_embeddings = checkpoint['table.embeddings.ent_embeddings.weight']
with open(os.path.join(dump_loc, "entity_embedding_with_type.tsv"), "w") as f_e:
    for wiki_id in entity_vocab_with_type:
        vector = ent_embeddings[entity_wikid2id[wiki_id]].tolist()
        f_e.write('{}\n'.format('\t'.join(str(z) for z in vector)))

# CER

In [None]:
# Dev split of the hybrid table dataset. It is handed to CER_build_input below
# and provides the `entity_wikid2id` lookup used to filter entities.
dataset = WikiHybridTableDataset(
    data_dir,
    entity_vocab,
    max_cell=100,
    max_input_tok=350,
    max_input_ent=150,
    src="dev",
    max_length=[50, 10, 10],
    force_new=False,
    tokenizer=None,
    mode=0,
)

In [None]:
# Path and loaded weights get separate names so that `checkpoint` always means
# "state dict" in the cells that follow.
checkpoint_path = "./output/CER/hybrid/model_v1_table_0.2_0.6_0.7_30000_1e-4_with_cand_0_seed_1_10000/checkpoint-7500/pytorch_model.bin"

config_class, model_class, _ = MODEL_CLASSES['CER']
config = config_class.from_pretrained(config_name)
config.output_attentions = True

model = model_class(config, is_simple=True)
# Read the weights onto the CPU first so the file loads regardless of the
# device it was saved from; the model is moved to `device` right after.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint)
model.to(device)
model.eval()

In [None]:
# Wiki ids known to the dataset's entity lookup.
all_entity_set = set(dataset.entity_wikid2id.keys())
tables_ignored = 0
dev_result = {}

# Cached baseline output, indexed by table id in the scoring loop below.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load caches that were produced locally.
cached_baseline = "data/wikisql_entity/dev_result_CER.pkl"
with open(cached_baseline, "rb") as cache_file:
    cached_baseline_result = pickle.load(cache_file)

In [None]:
# Score every dev table: the first `seed_num` linked entities of the first
# column are the seeds, the remaining ones are the targets, and the model ranks
# the cached baseline candidates.
seed_num = 1
results = {}
# Reset the skip counter here so that re-running this cell does not keep adding
# to the count left over from a previous run.
tables_ignored = 0
with open(os.path.join(data_dir, "dev_tables.jsonl"), 'r') as f:
    for line in tqdm(f):
        table = json.loads(line.strip())
        table_id = table.get("_id", "")
        pgEnt = table["pgId"]
        if pgEnt not in all_entity_set:
            pgEnt = -1  # page entity is outside the vocabulary
        pgTitle = table.get("pgTitle", "").lower()
        secTitle = table.get("sectionTitle", "").lower()
        caption = table.get("tableCaption", "").lower()
        headers = table.get("processed_tableHeaders", [])
        rows = table.get("tableData", {})
        entity_columns = table.get("entityColumn", [])
        headers = [headers[j] for j in entity_columns]
        entity_cells = np.array(table.get("entityCell", [[]]))

        # Linked cells of the first column as [surface text, entity id],
        # restricted to entities present in the vocabulary.
        core_entities = []
        for i in range(len(rows)):
            if entity_cells[i, 0] == 1:
                entity = rows[i][0]['surfaceLinks'][0]['target']['id']
                entity_text = rows[i][0]['text']
                core_entities.append([entity_text, entity])
        core_entities = [z for z in core_entities if z[1] in all_entity_set]
        if len(core_entities) < 5:
            tables_ignored += 1
            continue

        seed_entities = [z[1] for z in core_entities[:seed_num]]
        seed_entities_text = [z[0] for z in core_entities[:seed_num]]
        target_entities = set([z[1] for z in core_entities[seed_num:]])
        if len(target_entities) == 0:
            tables_ignored += 1
            continue

        # Only touch the cache for tables that are actually evaluated.
        seeds_1, _, _, pall, pee, pce, ple, cand_e, cand_c = cached_baseline_result[table_id]
        seed_entity_set = set(seed_entities)
        assert seeds_1 == seed_entity_set
        cand_e = set([z for z in cand_e if z in all_entity_set and z not in seed_entity_set])
        cand_c = set([z for z in cand_c if z in all_entity_set and z not in seed_entity_set])
        entity_cand_set = cand_e | cand_c
        # The list fixes the candidate order; model scores are matched to it by position.
        entity_cand = list(entity_cand_set)

        # Keep baseline scores only for surviving candidates (set lookup, not a list scan).
        pee = {k: v for k, v in pee.items() if k in entity_cand_set}
        pce = {k: v for k, v in pce.items() if k in entity_cand_set}
        ple = {k: v for k, v in ple.items() if k in entity_cand_set}
        pall = {k: v for k, v in pall.items() if k in entity_cand_set}

        model_inputs = CER_build_input(pgEnt, pgTitle, secTitle, caption, headers[0], seed_entities, seed_entities_text, entity_cand, dataset)
        input_tok, input_tok_type, input_tok_pos, input_mask, \
            input_ent, input_ent_text, input_ent_text_length, input_ent_type, \
            candidate_entity_set = [tensor.to(device) for tensor in model_inputs]

        with torch.no_grad():
            # NOTE(review): `input_mask` is passed twice, exactly as in the
            # original call - confirm against the model's forward signature.
            ent_outputs = model(input_tok, input_tok_type, input_tok_pos, input_mask,
                            input_ent, input_ent_text, input_ent_text_length, input_ent_type, input_mask,
                            candidate_entity_set, None, None)
        ent_prediction_scores = ent_outputs[0][0].tolist()
        p_neural = {entity: ent_prediction_scores[i] for i, entity in enumerate(entity_cand)}

        # The entry is stored only once it is complete, so a failed table never
        # leaves an empty placeholder behind.
        results[table_id] = {
            'pgTitle': pgTitle,
            'secTitle': secTitle,
            'caption': caption,
            'headers': headers,
            'cand_all': entity_cand,
            'cand_e': cand_e,
            'cand_c': cand_c,
            'seed_e': seed_entities,
            'target_e': target_entities,
            'p_neural': p_neural,
            'pee': pee,
            'pce': pce,
            'ple': ple,
            'pall': pall
        }

In [None]:
# Number of dev tables that passed the filters and were scored.
len(results)

In [None]:
# Table ids of all scored tables (prints every key).
results.keys()

In [None]:
# For each candidate source print: label, mean recall of the target entities,
# and mean number of candidates per table.
for label, cand_key in (('recall all', 'cand_all'), ('recall e', 'cand_e'), ('recall c', 'cand_c')):
    mean_recall = np.mean([len(set(x[cand_key]) & x['target_e']) / len(x['target_e']) for x in results.values()])
    mean_size = np.mean([len(set(x[cand_key])) for x in results.values()])
    print(label, mean_recall, mean_size)

In [None]:
def get_ap(scores, target_e):
    """Average precision of ranking the keys of `scores` by descending score.

    Args:
        scores: mapping from candidate to its score.
        target_e: container of relevant candidates (tested with `in`).

    Returns:
        Whatever `average_precision` returns for the 0/1 relevance list of the
        ranked candidates.
    """
    by_score_desc = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
    relevance = [int(candidate in target_e) for candidate, _ in by_score_desc]
    return average_precision(relevance)

In [None]:
# Mean average precision of the neural scores and of each cached baseline score.
table_results = list(results.values())
print('map neural', np.mean([get_ap(r['p_neural'], r['target_e']) for r in table_results]))

# Same neural scores, but every candidate outside cand_e is pushed to the bottom.
cand_e_only_scores = [
    {z: (score if z in r['cand_e'] else -10000) for z, score in r['p_neural'].items()}
    for r in table_results
]
print('map neural - only cand_e', np.mean([get_ap(scores, r['target_e']) for scores, r in zip(cand_e_only_scores, table_results)]))

for label, score_key in (('map ee', 'pee'), ('map le', 'ple'), ('map ce', 'pce'), ('map all', 'pall')):
    print(label, np.mean([get_ap(r[score_key], r['target_e']) for r in table_results]))

In [None]:
# Sweep the interpolation weight: blended score = w * neural + (1 - w) * pee.
for w in [0.999, 0.99, 0.9, 0.5, 0.1, 0.05, 0.06, 0.07, 0.08, 0.09, 0.01]:
    ensemble_aps = []
    for table_result in results.values():
        blended = {z: w*score+(1-w)*table_result['pee'][z] for z, score in table_result['p_neural'].items()}
        ensemble_aps.append(get_ap(blended, table_result['target_e']))
    print('map neural - ensemble {}'.format(w), np.mean(ensemble_aps))

In [None]:
# Collect tables worth a manual look: at least one target is among the
# candidates, yet the neural AP is below 0.4 or below the 'pee' baseline AP.
inspect_ids = []
for table_id, x in results.items():
    targets = x['target_e']
    recall = len(set(x['cand_all']) & targets) / len(targets)
    ap_neural = get_ap(x['p_neural'], targets)
    ap_ee = get_ap(x['pee'], targets)
    if recall == 0:
        continue
    if ap_neural < 0.4 or ap_neural < ap_ee:
        inspect_ids.append(table_id)
print(len(inspect_ids))

In [None]:
def inspect_result(result):
    """Print a readable summary of one scored table.

    Shows both APs, the table context, the seed and target entity titles, and
    the ten best candidates under the neural scores and under 'pee'. Target
    entities in the top-10 lists are wrapped in square brackets.

    Args:
        result: one value of `results` as built by the CER scoring loop.
    """
    targets = result['target_e']

    def title_of(wiki_id):
        # wiki id -> vocabulary key -> human readable title
        return entity_vocab[entity_wikid2id[wiki_id]]['wiki_title']

    def print_top10(label, scores):
        ranked = sorted(scores.items(), key=lambda z: z[1], reverse=True)
        print(label)
        print('; '.join([
            ('[%s:%f]' if e in targets else '%s:%.2f') % (title_of(e), score)
            for e, score in ranked[:10]]))

    ap_neural = get_ap(result['p_neural'], targets)
    ap_ee = get_ap(result['pee'], targets)
    print('ap_neural: {}\nap_ee: {}'.format(ap_neural, ap_ee))
    print('{} - {} - {}'.format(result['pgTitle'], result['secTitle'], result['caption']))
    print(result['headers'])
    print('seed:')
    print('; '.join([title_of(e) for e in result['seed_e']]))
    print('target:\n%s' % ('; '.join([title_of(z) for z in targets])))
    print_top10('neural:', result['p_neural'])
    print_top10('ee:', result['pee'])

In [None]:
inspect_result(results[inspect_ids[3]])
# How many flagged tables have one of these first headers.
print(sum(1 for table_id in inspect_ids if results[table_id]['headers'][0] in ['opponent', 'team 1', 'home team']))

In [None]:
inspect_result(results[inspect_ids[6]])
# How many flagged tables come from pages whose title contains this phrase.
print(sum(1 for table_id in inspect_ids if 'miss dominican republic' in results[table_id]['pgTitle']))

In [None]:
inspect_result(results[inspect_ids[32]])
# How many flagged tables have 'constituency' as their first header.
print(sum(1 for table_id in inspect_ids if results[table_id]['headers'][0]=='constituency'))

# CF (cell filling)

In [None]:
config_class, model_class, _ = MODEL_CLASSES['CF']
config = config_class.from_pretrained(config_name)
config.output_attentions = True

# Directory and loaded weights get separate names so that `checkpoint` always
# means "state dict".
checkpoint_dir = "output/hybrid/model_v1_table_0.2_0.6_0.7_30000_1e-4_with_cand_0"
model = model_class(config, is_simple=True)
# Read the weights onto the CPU first so the file loads regardless of the
# device it was saved from; the model is moved to `device` right after.
checkpoint = torch.load(os.path.join(checkpoint_dir, 'pytorch_model.bin'), map_location='cpu')
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
# Baseline helper; used below for candidate generation (get_cand_row) and
# for the baseline ranking (rank_cand_h).
CF = cell_filling(data_dir)

In [None]:
# Dev samples for the CF evaluation loop below.
dev_data_path = os.path.join(data_dir, "CF_dev_data.json")
with open(dev_data_path, 'r') as f:
    dev_data = json.load(f)

# Dev split of the hybrid table dataset; passed to CF_build_input below.
dataset = WikiHybridTableDataset(
    data_dir,
    entity_vocab,
    max_cell=100,
    max_input_tok=350,
    max_input_ent=150,
    src="dev",
    max_length=[50, 10, 10],
    force_new=False,
    tokenizer=None,
    mode=0,
)

In [None]:
# For every dev sample: gather the candidates of each row, score the union of
# all candidates with the model, then rank each row's own candidates by the
# neural score and by the baseline.
results = []
for table_id, pgEnt, pgTitle, secTitle, caption, (h1, h2), data_sample in tqdm(dev_data):
    core_entities, core_entities_text, target_entities = [], [], []
    entity_cand = []          # one candidate dict per row
    all_entity_cand = set()   # union of candidate ids over all rows
    for (core_e, core_e_text), target_e in data_sample:
        assert target_e in entity_wikid2id
        core_entities.append(core_e)
        core_entities_text.append(core_e_text)
        target_entities.append(target_e)
        cands = CF.get_cand_row(core_e, h2)
        # drop candidates that are not in the entity vocabulary
        cands = {key: value for key, value in cands.items() if key in entity_wikid2id}
        entity_cand.append(cands)
        all_entity_cand |= set(cands.keys())
    # The list fixes the candidate order; prediction columns are matched to it by position.
    all_entity_cand = list(all_entity_cand)

    model_inputs = CF_build_input(pgEnt, pgTitle, secTitle, caption, [h1, h2], core_entities, core_entities_text, all_entity_cand, dataset)
    input_tok, input_tok_type, input_tok_pos, input_tok_mask, \
        input_ent, input_ent_text, input_ent_text_length, input_ent_type, input_ent_mask, \
        candidate_entity_set = [tensor.to(device) for tensor in model_inputs]

    with torch.no_grad():
        _, ent_outputs = model(input_tok, input_tok_type, input_tok_pos, input_tok_mask,
                        input_ent_text, input_ent_text_length, None,
                        input_ent, input_ent_type, input_ent_mask, candidate_entity_set)
    num_sample = len(target_entities)
    # Keep the positions after the first num_sample+1 entries, one score list per target row.
    ent_prediction_scores = ent_outputs[0][0,num_sample+1:].tolist()

    result = []
    for i, target_e in enumerate(target_entities):
        row_cands = entity_cand[i]
        predictions = ent_prediction_scores[i]
        if len(row_cands) == 0:
            continue
        scored = [[cand, predictions[j]] for j, cand in enumerate(all_entity_cand) if cand in row_cands]
        scored.sort(key=lambda z: z[1], reverse=True)
        sorted_cands = [z[0] for z in scored]
        base_sorted_cands = CF.rank_cand_h(h2, row_cands)
        result.append([target_e, row_cands, sorted_cands, base_sorted_cands])

    results.append({
            'pgTitle': pgTitle,
            'secTitle': secTitle,
            'caption': caption,
            'headers': [h1, h2],
            'result': result
        })

In [None]:
def get_precision(result, cutoffs=(1, 3, 5, 10)):
    """Candidate recall and hit rate at each cutoff for one table.

    Args:
        result: iterable of (target_e, cand, p_neural, p_base) entries, where
            `cand` is the candidate container of the row and `p_neural` /
            `p_base` are the candidates ranked by the model / the baseline.
        cutoffs: ranks at which hits are counted; the default reproduces the
            original p@1, p@3, p@5, p@10.

    Returns:
        (recall, neural, base). `recall` is the share of rows whose target is
        among the candidates. `neural` and `base` hold, per cutoff, the number
        of rows with the target in the top-k divided by the number of rows
        whose target is reachable (not by the number of rows). When no target
        is reachable the result is 0 and two all-zero lists.
    """
    recall = 0
    hits_neural = [0] * len(cutoffs)
    hits_base = [0] * len(cutoffs)
    for target_e, cand, p_neural, p_base in result:
        if target_e in cand:
            recall += 1
        for slot, k in enumerate(cutoffs):
            # Slicing instead of indexing [0] keeps an empty ranking from
            # raising IndexError; it simply counts as a miss.
            if target_e in p_neural[:k]:
                hits_neural[slot] += 1
            if target_e in p_base[:k]:
                hits_base[slot] += 1
    if recall == 0:
        return 0, [0] * len(cutoffs), [0] * len(cutoffs)
    return recall/len(result), [z/recall for z in hits_neural], [z/recall for z in hits_base]

In [None]:
# Per-table (recall, neural hit rates, baseline hit rates).
final_results = [get_precision(x['result']) for x in results]
print('recall', np.mean([x[0] for x in final_results]))

# Hit rates are averaged only over tables with non-zero recall.
reachable = [x for x in final_results if x[0] != 0]
for system, column in (('neural', 1), ('base', 2)):
    print(system)
    for slot, label in enumerate(['p@1', 'p@3', 'p@5', 'p@10']):
        print(label, np.mean([x[column][slot] for x in reachable]))

In [None]:
# Same report, restricted to tables whose second header does not contain 'team'.
non_team = [
    x for i, x in enumerate(final_results)
    if x[0] != 0 and 'team' not in results[i]['headers'][1]
]
for system, column in (('neural', 1), ('base', 2)):
    print(system)
    for slot, label in enumerate(['p@1', 'p@3', 'p@5', 'p@10']):
        print(label, np.mean([x[column][slot] for x in non_team]))

In [None]:
from collections import Counter

In [None]:
# Second headers of tables where the neural p@1 beats the baseline p@1.
better_headers = Counter(
    results[i]['headers'][1]
    for i, (table_recall, neural, base) in enumerate(final_results)
    if table_recall != 0 and neural[0] > base[0]
)

In [None]:
# All headers where the neural model wins, most frequent first.
better_headers.most_common()

In [None]:
# Second headers of tables where the neural p@1 is below the baseline p@1.
worse_headers = Counter(
    results[i]['headers'][1]
    for i, (table_recall, neural, base) in enumerate(final_results)
    if table_recall != 0 and neural[0] < base[0]
)

In [None]:
# All headers where the baseline wins, most frequent first.
worse_headers.most_common()

In [None]:
# Header pairs of tables with non-zero recall whose neural p@1 is below 0.7.
error_headers = Counter(
    ' | '.join(results[i]['headers'])
    for i, (table_recall, neural, _base) in enumerate(final_results)
    if table_recall != 0 and neural[0] < 0.7
)

In [None]:
# Ten most frequent header pairs among the low-precision tables.
error_headers.most_common(10)

In [None]:
# Header pairs of tables with zero recall (no target among the candidates).
miss_headers = Counter(
    ' | '.join(results[i]['headers'])
    for i, table_scores in enumerate(final_results)
    if table_scores[0] == 0
)

In [None]:
# Ten most frequent header pairs among the zero-recall tables.
miss_headers.most_common(10)

# Attribute Recommendation

In [None]:
config_class, model_class, _ = MODEL_CLASSES['HR']
config = config_class.from_pretrained(config_name)
config.output_attentions = True

train_dataset = WikiHeaderDataset(data_dir,max_input_tok=350, src="train", max_length = [50, 10], force_new=False, tokenizer = None)
eval_dataset = WikiHeaderDataset(data_dir,max_input_tok=350, src="dev", max_length = [50, 10], force_new=False, tokenizer = None)
# The config has to know the header vocabulary size before the model is built.
config.__dict__['header_vocab_size'] = len(eval_dataset.header_vocab)

# Directory and loaded weights get separate names so that `checkpoint` always
# means "state dict".
checkpoint_dir = "output/HR/hybrid/model_v1_table_0.2_0.4_0.7_30000_1e-4_with_cand_0_seed_0/checkpoint-130000/"
# Alternative checkpoint directory: "output/HR/bert_seed_0/"
model = model_class(config, is_simple=True)
# Read the weights onto the CPU first so the file loads regardless of the
# device it was saved from; the model is moved to `device` right after.
checkpoint = torch.load(os.path.join(checkpoint_dir, 'pytorch_model.bin'), map_location='cpu')
model.load_state_dict(checkpoint)
model.to(device)
model.eval()

In [None]:
# Sequential sampling keeps predictions aligned with eval_dataset indices.
eval_batch_size = 64
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = WikiHeaderLoader(
    eval_dataset,
    sampler=eval_sampler,
    batch_size=eval_batch_size,
    is_train=False,
    seed=0,
)

In [None]:
# Collect the header prediction scores of every dev sample, in dataset order.
results = []
for batch in tqdm(eval_dataloader, desc="Evaluating"):
    # The first batch element is not used; everything else goes to the device.
    _, *batch_tensors = batch
    input_tok, input_tok_type, input_tok_pos, \
        input_header, input_header_type, \
        input_mask, seed_header, target_header = [tensor.to(device) for tensor in batch_tensors]
    with torch.no_grad():
        header_outputs = model(input_tok, input_tok_type, input_tok_pos,
                        input_header, input_header_type, input_mask,
                        seed_header, target_header)
    header_loss = header_outputs[0]
    header_prediction_scores = header_outputs[1]
    results.extend(header_prediction_scores.tolist())

In [None]:
def get_ap(scores, target_e):
    """Average precision of ranking indices by descending score.

    NOTE: this redefines the `get_ap` of the CER section, which takes a dict
    of scores; from here on `scores` is a flat sequence indexed by id.

    Args:
        scores: sequence of scores; position i is the score of id i.
        target_e: iterable of relevant ids.

    Returns:
        Whatever `average_precision` returns for the 0/1 relevance list of the
        ranked ids.
    """
    relevant = set(target_e)
    order = np.argsort(scores)[::-1]
    return average_precision([int(idx in relevant) for idx in order])

In [None]:
# AP per dev sample; the targets are element 5 of the sample without its first entry.
maps = [get_ap(x, eval_dataset[i][5][1:]) for i, x in tqdm(enumerate(results))]
print(np.mean(maps))

In [None]:
errors = [i for i,ap in enumerate(maps) if ap<0.5]

In [None]:
# Top-10 predicted headers vs. the target headers of dev sample 1.
top10_predicted = np.argsort(results[1])[::-1][:10]
display([eval_dataset.header_vocab[x] for x in top10_predicted])
display([eval_dataset.header_vocab[x] for x in eval_dataset[1][5][1:]])

In [None]:
# Decode element 1 of dev sample 0 back to text with the dataset's tokenizer.
eval_dataset.tokenizer.decode(eval_dataset[0][1])

In [None]:
def inspect(i):
    """Print the decoded input, the AP, the top-10 predicted headers and the
    target headers of dev sample `i`."""
    vocab = eval_dataset.header_vocab
    top10_predicted = np.argsort(results[i])[::-1][:10]
    print(eval_dataset.tokenizer.decode(eval_dataset[i][1]))
    print(maps[i])
    print('; '.join([vocab[x] for x in top10_predicted]))
    print('; '.join([vocab[x] for x in eval_dataset[i][5][1:]]))

In [None]:
# Look at one of the low-AP samples collected in `errors`.
inspect(errors[23])

In [None]:
# Export the rows of `model.cls.weight` and the matching header names as two
# line-aligned TSV files.
dump_loc = "output/HR/hybrid/model_v1_table_0.2_0.4_0.7_30000_1e-4_with_cand_0_seed_0/"
header_weights = model.cls.weight.data
with open(os.path.join(dump_loc, "header_embedding.tsv"), "w") as f_e, open(os.path.join(dump_loc, "header_names.tsv"), "w", encoding='utf8') as f_n:
    for i, name in eval_dataset.header_vocab.items():
        f_n.write('{}\n'.format(name))
        vector = header_weights[i].tolist()
        f_e.write('{}\n'.format('\t'.join(str(z) for z in vector)))

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer
# The identity analyzer makes the vectorizer treat each sample's element 1 as
# an already-tokenized sequence instead of splitting a string.
tfidf = TfidfVectorizer(analyzer=lambda x:x, token_pattern=None)
train_inputs = [x[1] for x in train_dataset]
train_tfidf = tfidf.fit_transform(train_inputs)

In [None]:
# Sanity check: vectorize one dev sample with the fitted TF-IDF model.
tfidf.transform([eval_dataset[0][1]])

In [None]:
from sklearn.neighbors import NearestNeighbors
# Cosine nearest-neighbour index over the training TF-IDF vectors; the number
# of neighbours is overridden per query in the baseline loops below.
neigh = NearestNeighbors(n_neighbors=1,metric='cosine')
neigh.fit(train_tfidf)

In [None]:
from tqdm.autonotebook import tqdm
# kNN baseline: rank headers by how often they occur among the k nearest
# training samples; headers never seen among the neighbours follow in id order.
for k in [1,3,5,10]:
    print(k)
    maps_base, recalls = [], []
    for x in tqdm(eval_dataset):
        neighbor_ids = neigh.kneighbors(tfidf.transform([x[1]]), k, return_distance=False).reshape([-1])
        header_count = Counter()
        for n in neighbor_ids:
            header_count.update(train_dataset[n][5][1:])
        target_e = set(x[5][1:])
        retrieved_targets = [z for z in header_count if z in target_e]
        recalls.append(len(retrieved_targets)/len(target_e))
        seen_relevance = [1 if z in target_e else 0 for z,_ in header_count.most_common()]
        unseen_relevance = [1 if z in target_e else 0 for z in range(config.header_vocab_size) if z not in header_count]
        maps_base.append(average_precision(seen_relevance+unseen_relevance))
    print(np.mean(maps_base))
    print(np.mean(recalls))

In [None]:
from tqdm.autonotebook import tqdm
# Same kNN baseline as the previous cell, evaluated for larger neighbourhoods.
for k in [15,30,50,100]:
    print(k)
    maps_base, recalls = [], []
    for x in tqdm(eval_dataset):
        neighbor_ids = neigh.kneighbors(tfidf.transform([x[1]]), k, return_distance=False).reshape([-1])
        header_count = Counter()
        for n in neighbor_ids:
            header_count.update(train_dataset[n][5][1:])
        target_e = set(x[5][1:])
        retrieved_targets = [z for z in header_count if z in target_e]
        recalls.append(len(retrieved_targets)/len(target_e))
        seen_relevance = [1 if z in target_e else 0 for z,_ in header_count.most_common()]
        unseen_relevance = [1 if z in target_e else 0 for z in range(config.header_vocab_size) if z not in header_count]
        maps_base.append(average_precision(seen_relevance+unseen_relevance))
    print(np.mean(maps_base))
    print(np.mean(recalls))

In [None]:
# Size of the header vocabulary (set from eval_dataset.header_vocab above).
config.header_vocab_size

# CT

In [5]:
# Type vocabulary and dev split for the CT evaluation below.
type_vocab = load_type_vocab("./data/wikisql_entity")
eval_dataset = WikiCTDataset(
    data_dir,
    entity_vocab,
    type_vocab,
    max_input_tok=500,
    src="dev",
    max_length=[50, 10, 10],
    force_new=False,
    tokenizer=None,
)

try loading preprocessed data from data/wikisql_entity/procressed_WikiCT/dev.pickle


In [6]:
def average_precision(output, relevance_labels):
    """Average precision along the last dimension, computed without gradients.

    NOTE: this shadows the `average_precision` imported from
    baselines.row_population.metric at the top of the notebook.

    Args:
        output: score tensor; items are ranked by descending score along the
            last dimension.
        relevance_labels: tensor of the same shape with non-zero entries
            marking relevant items.

    Returns:
        Tensor without the last dimension holding the mean precision over the
        ranks of the relevant items; rows without any relevant item yield 0.
    """
    with torch.no_grad():
        order = torch.argsort(output, dim=-1, descending=True)
        hits = torch.gather(relevance_labels, -1, order).float()
        ranks = torch.arange(start=1, end=hits.shape[-1]+1, device=hits.device)[None, :]
        # precision at each rank, kept only at ranks that hold a relevant item
        precision_at_hits = (torch.cumsum(hits, dim=-1) / ranks) * hits
        num_relevant = torch.sum(hits, dim=-1)
        # avoid 0/0 for rows without relevant items; their numerator is 0 anyway
        num_relevant[num_relevant==0] = 1
        result = torch.sum(precision_at_hits, dim=-1)/num_relevant

    return result

In [22]:
# Result containers, keyed by input mode and filled by the evaluation loop below.
# The former `map = {}` was dropped: it shadowed the builtin `map` and was
# never used as a dict (the loop only ever bound a tensor to that name).
per_type_accuracy = {}
per_type_precision = {}
per_type_recall = {}
per_type_f1 = {}
precision = {}
recall = {}
f1 = {}

In [23]:
from tqdm.autonotebook import tqdm
# One fine-tuned checkpoint per input mode: checkpoints[mode] is evaluated
# with config.mode = mode.
checkpoints = [
    "output/CT/0/model_v1_table_0.2_0.4_0.7_30000_1e-4_with_cand_0/checkpoint-20000/pytorch_model.bin",
    "output/CT/1/model_v1_table_0.2_0.4_0.7_30000_1e-4_with_cand_0/checkpoint-20000/pytorch_model.bin",
    "output/CT/2/model_v1_table_0.2_0.4_0.7_30000_1e-4_with_cand_0/checkpoint-20000/pytorch_model.bin",
    "output/CT/3/model_v1_table_0.2_0.4_0.7_30000_1e-4_with_cand_0/checkpoint-20000/pytorch_model.bin",
    "output/CT/4/model_v1_table_0.2_0.4_0.7_30000_1e-4_with_cand_0/checkpoint-20000/pytorch_model.bin"
]
for mode in [0,1,2,3,4]:
    print(mode)
    config_class, model_class, _ = MODEL_CLASSES['CT']
    config = config_class.from_pretrained(config_name)
    config.class_num = len(type_vocab)
    config.mode = mode
    model = model_class(config, is_simple=True)
    # Read the weights onto the CPU first so the file loads regardless of the
    # device it was saved from; the model is moved to `device` right after.
    checkpoint = torch.load(checkpoints[mode], map_location='cpu')
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    eval_batch_size = 20
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = CTLoader(eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size, is_train=False)
    eval_loss = 0.0
    eval_map = 0.0
    nb_eval_steps = 0
    eval_targets = []
    eval_prediction_scores = []
    eval_pred = []
    eval_mask = []
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        # The first element is the table id; everything else is a tensor.
        table_id, *batch_tensors = batch
        input_tok, input_tok_type, input_tok_pos, input_tok_mask, \
            input_ent_text, input_ent_text_length, input_ent, input_ent_type, input_ent_mask, \
            column_entity_mask, column_header_mask, labels_mask, labels = [tensor.to(device) for tensor in batch_tensors]
        # Drop the inputs a mode does not use and trim the remaining mask to match.
        if mode == 1:
            # no token inputs: keep only the mask columns after the token part
            input_ent_mask = input_ent_mask[:,:,input_tok_mask.shape[1]:]
            input_tok = None
            input_tok_type = None
            input_tok_pos = None
            input_tok_mask = None
        elif mode == 2:
            # no entity inputs
            input_tok_mask = input_tok_mask[:,:,:input_tok_mask.shape[1]]
            input_ent_text = None
            input_ent_text_length = None
            input_ent = None
            input_ent_type = None
            input_ent_mask = None
        elif mode == 3:
            # everything except `input_ent`
            input_ent = None
        elif mode == 4:
            # neither token inputs nor `input_ent`
            input_ent_mask = input_ent_mask[:,:,input_tok_mask.shape[1]:]
            input_tok = None
            input_tok_type = None
            input_tok_pos = None
            input_tok_mask = None
            input_ent = None
        with torch.no_grad():
            outputs = model(input_tok, input_tok_type, input_tok_pos, input_tok_mask,\
                input_ent_text, input_ent_text_length, input_ent, input_ent_type, input_ent_mask, column_entity_mask, column_header_mask, labels_mask, labels)
            loss = outputs[0]
            prediction_scores = outputs[1]
            ap = metric.average_precision(prediction_scores.view(-1, config.class_num), labels.view((-1, config.class_num)))
            # Masked mean AP of this batch. Named `batch_map` so that neither
            # the builtin `map` nor any module-level `map` is overwritten.
            batch_map = (ap*labels_mask.view(-1)).sum()/labels_mask.sum()
            eval_loss += loss.mean().item()
            eval_map += batch_map.item()
            eval_targets.extend(labels.view(-1, config.class_num).tolist())
            eval_prediction_scores.extend(prediction_scores.view(-1, config.class_num).tolist())
            eval_pred.extend((torch.sigmoid(prediction_scores.view(-1, config.class_num))>0.5).tolist())
            eval_mask.extend(labels_mask.view(-1).tolist())
        nb_eval_steps += 1
    # Mean of the per-batch MAP values.
    print(eval_map/nb_eval_steps)
    eval_targets = np.array(eval_targets)
    eval_prediction_scores = np.array(eval_prediction_scores)
    eval_mask = np.array(eval_mask)
    # Rank (0 = best) of every type within each row.
    eval_prediction_ranks = np.argsort(np.argsort(-eval_prediction_scores))
    eval_pred = np.array(eval_pred)
    eval_tp = eval_mask[:,np.newaxis]*eval_pred*eval_targets
    # The fill value must be passed as `nan=`: the second positional parameter
    # of np.nan_to_num is `copy`, so the former `np.nan_to_num(x, 1)` replaced
    # NaN with 0.0 instead of the intended 1. Types that are never predicted
    # (precision) or never occur (recall) now get 1 as written.
    eval_precision = np.sum(eval_tp,axis=0)/np.sum(eval_mask[:,np.newaxis]*eval_pred,axis=0)
    eval_precision = np.nan_to_num(eval_precision, nan=1.0)
    eval_recall = np.sum(eval_tp,axis=0)/np.sum(eval_mask[:,np.newaxis]*eval_targets,axis=0)
    eval_recall = np.nan_to_num(eval_recall, nan=1.0)
    eval_f1 = 2*eval_precision*eval_recall/(eval_precision+eval_recall)
    eval_f1 = np.nan_to_num(eval_f1, nan=0.0)
    per_type_instance_num = np.sum(eval_mask[:,np.newaxis]*eval_targets,axis=0)
    # An instance of a type counts as correct when that type is ranked within
    # the top-n of its row, n being the number of gold types of the row.
    per_type_correct_instance_num = np.sum(eval_mask[:,np.newaxis]*(eval_prediction_ranks<eval_targets.sum(axis=1)[:,np.newaxis])*eval_targets,axis=0)
    per_type_accuracy[mode] = per_type_correct_instance_num/per_type_instance_num
    per_type_precision[mode] = eval_precision
    per_type_recall[mode] = eval_recall
    per_type_f1[mode] = eval_f1
    # Micro-averaged scores over all types.
    precision[mode] = np.sum(eval_tp)/np.sum(eval_mask[:,np.newaxis]*eval_pred)
    recall[mode] = np.sum(eval_tp)/np.sum(eval_mask[:,np.newaxis]*eval_targets)
    f1[mode] = 2*precision[mode]*recall[mode]/(precision[mode]+recall[mode])

0


HBox(children=(IntProgress(value=0, description='Evaluating', max=232, style=ProgressStyle(description_width='…


0.9902757586076342




1


HBox(children=(IntProgress(value=0, description='Evaluating', max=232, style=ProgressStyle(description_width='…


0.9831231909579244
2


HBox(children=(IntProgress(value=0, description='Evaluating', max=232, style=ProgressStyle(description_width='…


0.964737164049313
3


HBox(children=(IntProgress(value=0, description='Evaluating', max=232, style=ProgressStyle(description_width='…


0.9820393208285858
4


HBox(children=(IntProgress(value=0, description='Evaluating', max=232, style=ProgressStyle(description_width='…


0.9482440655601436


In [12]:
# One line per type: name, number of instances, then the accuracy of modes 0-4.
for t, i in type_vocab.items():
    accuracy_by_mode = tuple(per_type_accuracy[m][i] for m in range(5))
    print('%s %.3f %.3f %.3f %.3f %.3f  %.3f' % ((t, per_type_instance_num[i]) + accuracy_by_mode))
    print()

music.group_member 16.000 0.875 0.750 0.750 0.875  0.125

people.family_member 6.000 0.667 0.667 0.667 0.833  0.500

theater.theater_genre 2.000 1.000 1.000 1.000 1.000  0.500

book.written_work 2.000 0.000 0.500 0.000 0.000  0.000

soccer.football_league_season 20.000 0.950 0.950 0.850 1.000  0.900

cricket.cricket_stadium 8.000 1.000 1.000 1.000 1.000  1.000

location.location 2480.000 0.998 0.992 0.987 0.995  0.971

education.educational_institution 132.000 0.992 0.977 0.917 0.962  0.939

royalty.noble_person 25.000 0.960 0.960 0.960 0.960  0.960

tv.tv_producer 5.000 0.600 0.600 0.400 0.600  0.400

time.event 1349.000 0.993 0.993 0.938 0.985  0.979

film.music_contributor 13.000 0.692 0.615 0.846 0.692  0.385

film.film_genre 17.000 0.941 0.941 0.941 1.000  0.941

royalty.kingdom 3.000 1.000 1.000 0.667 1.000  0.333

location.capital_of_administrative_division 51.000 0.725 0.588 0.549 0.627  0.255

biology.organism 1.000 0.000 1.000 0.000 0.000  0.000

music.musical_group 9.000 1.0

In [30]:
# One CSV row per type: name, instance count, then (f1, precision, recall) for modes 0-4.
with open('output/CT/dev_per_type_result.csv', 'w') as f:
    for t, i in type_vocab.items():
        fields = ['%s,%d' % (t, per_type_instance_num[i])]
        fields.extend(
            '%.3f,%.3f,%.3f' % (per_type_f1[j][i], per_type_precision[j][i], per_type_recall[j][i])
            for j in range(5)
        )
        f.write(','.join(fields) + '\n')

In [None]:
# Prediction scores of the last mode evaluated by the loop above (whole array).
eval_prediction_scores

In [24]:
# Micro-averaged F1 per input mode.
f1

{0: 0.9459584972991469,
 1: 0.9291703659859919,
 2: 0.8933839637797575,
 3: 0.9254248876733736,
 4: 0.8715658479062649}

In [25]:
# Micro-averaged precision per input mode.
precision

{0: 0.9461666340700176,
 1: 0.93782048090805,
 2: 0.8899180447117017,
 3: 0.9247474501000439,
 4: 0.8954939162713962}

In [26]:
# Micro-averaged recall per input mode.
recall

{0: 0.945750452079566,
 1: 0.9206783637163384,
 2: 0.8968769854845804,
 3: 0.9261033185083818,
 4: 0.8488832412883046}