In [32]:
import pandas as pd
import numpy as np
import re
from dataLoader import *
from utils import *
import argparse
import yaml
import os

In [25]:
def category_correspondence(selected_terms, categories, ename2eid, eid2DocProb, eidDocPair2Prob):
    """Accumulate ``get_nmi`` over every (known category, selected term) pair.

    ``get_nmi`` comes from the star-imported helpers and is not visible here
    (presumably a normalized-mutual-information score -- confirm in ``utils``).

    Args:
        selected_terms: iterable of entity names; each one must be a key of
            ``ename2eid`` (an unknown term raises ``KeyError``).
        categories: iterable of category names; names that are not keys of
            ``ename2eid`` are silently skipped.
        ename2eid: mapping from entity name to entity id.
        eid2DocProb: passed through unchanged to ``get_nmi``.
        eidDocPair2Prob: passed through unchanged to ``get_nmi``.

    Returns:
        The sum of all pairwise scores, or ``0`` when no category is known.
    """
    total = 0
    # Resolve only the categories that actually exist in the entity vocabulary.
    known_category_eids = (ename2eid[name] for name in categories if name in ename2eid)
    for category_eid in known_category_eids:
        # Sum per category first, then accumulate, so the floating-point
        # addition order stays fixed.
        per_category = sum(
            get_nmi(category_eid, ename2eid[term], eid2DocProb, eidDocPair2Prob)
            for term in selected_terms
        )
        total += per_category
    return total

In [44]:
def _normalize_category(raw_name):
    """Convert one ground-truth category line to a lowercase, ASCII, underscore-joined name."""
    name = raw_name.strip().lower()
    name = re.sub(r'[^\x00-\x7F]+', ' ', name)
    name = name.replace("-", " ")
    return "_".join(name.split())


def _read_terms(path):
    """Read a ranked term list (one term per line, blank lines skipped) as plain strings."""
    terms = []
    # utf-8 matches the default encoding pandas.read_csv used for these files.
    with open(path, 'r', encoding='utf-8') as fh:
        for line in fh:
            term = line.strip()
            if term:
                terms.append(term)
    return terms


def result(config_file,
           gt_file='../data/groundtruths/category_correspondence/arxivcs_categories.txt',
           ks=(10, 20, 30, 40, 50, 100, 200, 500)):
    """Print category-correspondence scores of six ranked term lists at several cutoffs.

    For every ``k`` in ``ks`` one line is printed containing six comma-separated
    scores (rounded to 4 decimals), in the fixed order
    ``rf, lo, fl, kl_rf, mm, kl_mm`` -- the stems of the ``<name>.txt`` files
    read from the result folder.

    Args:
        config_file: path to a YAML file providing ``dataset.domain_path`` and
            ``dataset.result_folder``.
        gt_file: text file with one ground-truth category per line. Defaults to
            the arXiv-CS category list this notebook was written for.
        ks: cutoffs; each score uses only the first ``k`` terms of a list.

    Returns:
        None. The scores are only printed.
    """
    with open(config_file, 'r') as ymlfile:
        # NOTE(review): FullLoader is not hardened against untrusted input;
        # switch to yaml.safe_load if configs can come from outside the project.
        config = yaml.load(ymlfile, Loader=yaml.FullLoader)

    domain_path = config['dataset']['domain_path']
    result_path = os.path.join(domain_path, config['dataset']['result_folder'])

    # os.path.join keeps the paths valid whether or not the configured
    # folders end with a path separator.
    _, ename2eid = loadEidToEntityMap(os.path.join(domain_path, 'intermediate/entity2id.txt'))
    eid2DocProb = loadEid2DocFeature(os.path.join(domain_path, 'intermediate/eid2DocProb.txt'))
    eidDocPair2Prob = loadEidDocPairFeature(os.path.join(domain_path, 'intermediate/eidDocPair2prob.txt'))

    with open(gt_file, 'r') as f:
        cs_cates = [_normalize_category(line) for line in f]

    # Read the lists as plain text instead of pandas.read_csv(sep='\n'):
    # recent pandas rejects a newline separator, and CSV parsing would turn
    # terms such as "null" or "nan" into float NaN.
    method_names = ('rf', 'lo', 'fl', 'kl_rf', 'mm', 'kl_mm')
    ranked_terms = [_read_terms(os.path.join(result_path, name + '.txt')) for name in method_names]

    for k in ks:
        scores = [
            category_correspondence(terms[:k], cs_cates, ename2eid, eid2DocProb, eidDocPair2Prob)
            for terms in ranked_terms
        ]
        print(", ".join(str(round(score, 4)) for score in scores))

In [None]:
# Score the candidate set configured in configs/arxivcs_ap.yaml;
# result() prints one line of scores per cutoff k.
config_file = "configs/arxivcs_ap.yaml" # candidate keywords: AutoPhrase-extracted keywords
result(config_file)

Loading: ../data/arxiv/cs/all_ap/intermediate/entity2id.txt: 100%|██████████| 93148/93148 [00:00<00:00, 702473.24it/s]
Loading: ../data/arxiv/cs/all_ap/intermediate/eid2DocProb.txt: 100%|██████████| 93148/93148 [00:00<00:00, 727896.24it/s]
Loading: ../data/arxiv/cs/all_ap/intermediate/eidDocPair2prob.txt:  44%|████▍     | 5973648/13653903 [00:19<00:18, 409761.40it/s]

In [None]:
# Score the candidate set configured in configs/arxivcs_sp.yaml;
# result() prints one line of scores per cutoff k.
config_file = "configs/arxivcs_sp.yaml" # candidate keywords: Springer
result(config_file)

In [45]:
# Score the candidate set configured in configs/arxivcs_am.yaml;
# result() prints one line of scores per cutoff k.
config_file = "configs/arxivcs_am.yaml" # candidate keywords: AMiner
result(config_file)

Loading: ../data/arxiv/cs/all_am/intermediate/entity2id.txt: 100%|██████████| 44938/44938 [00:00<00:00, 677034.27it/s]
Loading: ../data/arxiv/cs/all_am/intermediate/eid2DocProb.txt: 100%|██████████| 44938/44938 [00:00<00:00, 736842.73it/s]
Loading: ../data/arxiv/cs/all_am/intermediate/eidDocPair2prob.txt: 100%|██████████| 14735109/14735109 [00:52<00:00, 279165.55it/s]


1.102, 2.1205, 1.1073, 1.1109, 2.1205, 2.1645
2.1782, 2.305, 2.1699, 2.1956, 3.3048, 3.3079
2.2433, 3.432, 2.2508, 2.2795, 3.4394, 4.4522
2.3159, 3.5312, 2.3012, 2.3944, 5.5693, 5.57
2.4617, 5.6519, 3.4657, 2.4675, 5.6463, 5.6829
3.8432, 8.1523, 3.8038, 3.8429, 9.1788, 8.1679
6.5778, 12.1533, 6.532, 6.5431, 12.2007, 12.1013
13.6675, 19.5232, 13.4622, 12.4942, 19.558, 20.5255
