In [1]:
from collections import defaultdict
import networkx as nx
import os

In [2]:
# Abort early when the post-processed entity-type file (Wiki_ET.txt) is missing
PATH = '../wikidata/'
wiki_et_missing = not os.path.exists(PATH + "Wiki_ET.txt")
if wiki_et_missing:
    raise FileNotFoundError("Please first run postprocess.py in data_mining_scripts folder")

In [3]:
# load Wikidata and WiKC taxonomy
# Original Wikidata taxonomy as an adjacency dict keyed by the 3rd column,
# whose values collect the 1st column (presumably superclass -> {subclasses},
# by analogy with the WiKC loader -- confirm against the file).
oriWikiTaxonDown = defaultdict(set)
with open(PATH+'wiki_taxonomy.tsv', 'r') as ori:
    for fields in (row.strip().split('\t') for row in ori):
        # NOTE(review): rows with three or fewer tab-separated fields are
        # skipped, so usable rows must carry at least a 4th column -- confirm.
        if len(fields) <= 3:
            continue
        oriWikiTaxonDown[fields[2]].add(fields[0])
# Directed graph with one edge key -> member for every entry of the dict
oriWiki = nx.DiGraph(oriWikiTaxonDown)

# WiKC taxonomy: every row of wikc.tsv is "child<TAB>parent"; it is stored
# top-down (parent -> {children}) and turned into a directed graph.
cleanWikiTaxonDown = defaultdict(set)
with open('../wikc.tsv', 'r') as clean:
    edge_rows = (row.strip().split('\t') for row in clean)
    for child, parent in edge_rows:
        cleanWikiTaxonDown[parent].add(child)
wikc = nx.DiGraph(cleanWikiTaxonDown)

In [4]:
# Set of 'wd:'-prefixed ids taken from the 2nd comma-separated field of each
# row of ../wikipedia/enwiki (presumably the QID of an entity that has an
# English Wikipedia page -- confirm against the file).
with open(os.path.join('../wikipedia', 'enwiki'), 'r') as file:
    mapped_wiki_ents_en = {
        'wd:' + str(row.strip().split(',')[1])
        for row in file
    }

In [5]:
# Reload the (parent, child) edges listed in edges_del.tsv
if not os.path.exists('edges_del.tsv'):
    raise FileNotFoundError("Please first run clean.ipynb to get the edges_del.tsv file")
edges_del = set()
with open('edges_del.tsv', 'r') as file:
    # every row must hold exactly two tab-separated fields: parent, child
    for parent, child in (row.strip().split('\t') for row in file):
        edges_del.add((parent, child))
print("Reload Number of edges deleted: ", len(edges_del))

Reload Number of edges deleted:  10698


In [6]:
# useful functions
def get_ancestors(ori_graph, ancestors, cls, cur_graph, depth):
    """Collect the nearest predecessors of <cls> that are nodes of <cur_graph>.

    Walks upwards through ``ori_graph.predecessors``. A predecessor that is a
    node of <cur_graph> is recorded in the set <ancestors> as a
    ``(predecessor, depth)`` tuple and the walk stops on that branch; any
    other predecessor is explored recursively with ``depth + 1``.
    Edges ``(predecessor, cls)`` listed in the global ``edges_del`` are never
    followed.

    The function returns nothing: results are accumulated in <ancestors>.

    NOTE(review): there is no visited set, so a cycle made only of nodes that
    are absent from <cur_graph> would recurse until RecursionError.
    """
    for parent in ori_graph.predecessors(cls):
        # path should not include irrelevant (deleted) edges
        if (parent, cls) in edges_del:
            continue
        if cur_graph.has_node(parent):
            ancestors.add((parent, depth))
        else:
            get_ancestors(ori_graph, ancestors, parent, cur_graph, depth + 1)

def get_valid_first_ancestors(ori_graph, cls, cur_graph):
    """Return the set of ``(ancestor, depth)`` tuples gathered by get_ancestors.

    The search starts at <cls> with depth 0; see get_ancestors for how
    ancestors are selected.
    """
    found = set()
    get_ancestors(ori_graph, found, cls, cur_graph, 0)
    return found

* retype instances to WiKC

In [7]:
# instance -> set of WiKC classes it is typed to
inst2type = defaultdict(set)
with open(PATH+'Wiki_ET.txt', 'r') as file:
    for line in file:
        inst, rel, cls = line.strip().split('\t')
        # keep only instances found in mapped_wiki_ents_en
        if inst not in mapped_wiki_ents_en:
            continue
        if wikc.has_node(cls):
            # the direct class still exists in WiKC
            inst2type[inst].add(cls)
        else:
            # No direct class exists anymore: fall back to the closest
            # ancestors of the class that are still part of WiKC
            fir_ancestors = list(get_valid_first_ancestors(oriWiki, cls, wikc))
            if fir_ancestors:
                # only keep ancestors with the minimum depth
                min_depth = min(depth for _, depth in fir_ancestors)
                inst2type[inst].update(
                    ances for ances, depth in fir_ancestors if depth == min_depth
                )

In [8]:
from tqdm import tqdm
from itertools import permutations
# Remove transitive types for each instance: for every ordered pair of types,
# the first one is dropped when it is the root class or when oriWiki holds a
# path from it to the second one.
for inst in tqdm(inst2type.keys()):
    candidates = inst2type[inst]
    if len(candidates) == 1:
        continue
    valcls = set(candidates)
    for node1, node2 in permutations(list(candidates), 2):
        # skip pairs involving a type that was already discarded
        if not (node1 in valcls and node2 in valcls):
            continue
        # it's pointless to retype to the root class;
        # oriWiki -> some insts retype to top level classes being wrong
        if node1 == 'wd:Q35120' or nx.has_path(oriWiki, node1, node2):
            valcls.remove(node1)
    inst2type[inst] = set(valcls)

100%|██████████| 7204533/7204533 [05:03<00:00, 23720.45it/s] 


In [9]:
# Filter out types that belong to the top levels of WiKC, i.e. every class
# within 3 hops of the root (the root itself included).
root = 'wd:Q35120'
top3_cls = set(nx.single_source_shortest_path_length(wikc, source=root, cutoff=3))
for inst in tqdm(inst2type.keys()):
    # in-place removal; the set stored in inst2type is the one being shrunk
    inst2type[inst].difference_update(top3_cls)

100%|██████████| 7204533/7204533 [00:05<00:00, 1257552.80it/s]


In [10]:
# Keep only the instances that still have at least one type; each kept
# instance gets its own copy of the type set.
inst2cls = {
    inst: set(types)
    for inst, types in tqdm(inst2type.items())
    if len(types) > 0
}
# inst2type is no longer needed
del inst2type

100%|██████████| 7204533/7204533 [00:19<00:00, 372467.38it/s] 


* extract a subset of instances for extrinsic evaluation

In [13]:
# The cumulative stats is for instances of classes
def getSuperClasses(cls, classes, WikiTaxonomyUp):
    """Adds all superclasses of a class <cls> (including <cls>) to the set <classes>

    <WikiTaxonomyUp> maps a class to an iterable of its direct superclasses.
    The walk is iterative and remembers the classes it already expanded, so a
    superclass shared by several paths is expanded once and a cyclic taxonomy
    cannot cause unbounded recursion. Nothing is returned; <classes> is
    updated in place.
    """
    expanded = set()
    pending = [cls]
    while pending:
        current = pending.pop()
        if current in expanded:
            continue
        expanded.add(current)
        classes.add(current)
        # Make a check before because it may be a defaultdict,
        # which would create the key if it's not there
        if current in WikiTaxonomyUp:
            pending.extend(WikiTaxonomyUp[current])

def getAncestors(cls, WikiTaxonomyUp):
    """Returns a fresh set with <cls> and everything getSuperClasses adds for it"""
    collected = set()
    getSuperClasses(cls, collected, WikiTaxonomyUp)
    return collected

def cumulative_stats(oristats, topTaxonomyUp):
    """Cumulative statistics of classes

    Every count in <oristats> (class -> number) is added to the class itself
    and to each of its ancestors as returned by getAncestors.
    Returns a defaultdict(int) mapping class -> accumulated count.
    """
    totals = defaultdict(int)
    for counted_cls, amount in oristats.items():
        # getAncestors includes counted_cls itself
        for target in getAncestors(counted_cls, topTaxonomyUp):
            totals[target] += amount
    return totals

In [15]:
# check current cumulative number of instances for 'Person'
from collections import defaultdict
# number of instances directly typed to each class
cls_stats = defaultdict(int)
for types in inst2cls.values():
    for cls in types:
        cls_stats[cls] += 1
# child -> [parents] view of WiKC, obtained by reversing the top-down graph
wikc_up = nx.to_dict_of_lists(wikc.reverse())
cum_cls_insts = cumulative_stats(cls_stats, wikc_up)
print("Current cumulative number of instances for 'Person': ", cum_cls_insts['wd:Q215627'])
# too many instances for the 'Person' class -> imbalance

Current cumulative number of instances for 'Person':  2575432


In [16]:
# Invert inst2cls: class -> set of instances typed to it
cls_insts = defaultdict(set)
for inst, types in tqdm(inst2cls.items()):
    for cls in types:
        cls_insts[cls].add(inst)

100%|██████████| 5019581/5019581 [00:05<00:00, 901315.51it/s] 


In [19]:
# step1: limit each class to at most MAX_INSTS_PER_CLASS (100) instances
import random
MAX_INSTS_PER_CLASS = 100
for cls in cls_insts:
    if len(cls_insts[cls]) > MAX_INSTS_PER_CLASS:
        # random.sample() requires a sequence: sampling directly from a set is
        # deprecated since Python 3.9 and raises TypeError from 3.11 on.
        # Sorting gives the population a stable order, so the draw depends
        # only on the state of the random generator.
        population = sorted(cls_insts[cls])
        cls_insts[cls] = set(random.sample(population, MAX_INSTS_PER_CLASS))

In [18]:
# Invert the (down-sampled) cls_insts: instance -> set of classes
inst_cls = defaultdict(set)
for cls, members in tqdm(cls_insts.items()):
    for inst in members:
        inst_cls[inst].add(cls)

100%|██████████| 13613/13613 [00:00<00:00, 41759.26it/s]


In [20]:
# step2: select at random 100k instances overall
# (random.sample raises ValueError when fewer than 100000 instances exist)
sample_keys = random.sample(list(inst_cls), 100000)
# instance -> WiKC classes, restricted to the sampled instances
sampled_inst2type = dict((key, inst_cls[key]) for key in sample_keys)

In [21]:
# step3: classes for sampled instances in original Wikidata
ori_inst2type = defaultdict(set)
with open(PATH+'Wiki_ET.txt', 'r') as etreader:
    for inst, rel, cls in (row.strip().split('\t') for row in etreader):
        if inst in sampled_inst2type:
            ori_inst2type[inst].add(cls)

# same as before: remove transitive types for each instance
for inst in ori_inst2type.keys():
    candidates = ori_inst2type[inst]
    if len(candidates) == 1:
        continue
    valcls = set(candidates)
    for node1, node2 in permutations(list(candidates), 2):
        # skip pairs involving a type that was already discarded
        if not (node1 in valcls and node2 in valcls):
            continue
        # drop the root class, and any type from which oriWiki has a path
        # to another remaining type
        if node1 == 'wd:Q35120' or nx.has_path(oriWiki, node1, node2):
            valcls.remove(node1)
    ori_inst2type[inst] = set(valcls)

print("Number of instances: ", len(ori_inst2type))

Number of instances:  100000


In [24]:
# load labels and descriptions for instances
ent2label, ent2desc = {}, {}
# relation -> dictionary that receives the literal
literal_targets = {'rdfs:label': ent2label, 'schema:description': ent2desc}
with open(PATH+'Wiki_literals.txt', 'r') as file:
    for line in file:
        qid, rel, literal = line.strip().split('\t')
        target = literal_targets.get(rel)
        if target is not None:
            # drops the first character and the last four
            # (presumably the opening quote and a closing '"@en' -- confirm)
            target[qid] = literal[1:-4]

In [25]:
# Report how many entities received a label / a description
for caption, mapping in (
    ("Number of instances with labels: ", ent2label),
    ("Number of instances with descriptions: ", ent2desc),
):
    print(caption, len(mapping))

Number of instances with labels:  38400123
Number of instances with descriptions:  38400123


In [22]:
from collections import deque
def get_parents_with_hops(graph, node):
    '''
    Breadth-first walk over the predecessors of <node>.

    @param graph: object exposing ``predecessors(n)`` (e.g. a networkx DiGraph)
    @param node: qid of wikidata entity
    @return: defaultdict(list) of the form {hop_count: [ancestors]}; every
             ancestor appears exactly once, under the smallest number of
             hops needed to reach it from <node>. <node> itself is never
             listed.
    '''
    # dictionary to store parents of node and their hop count
    parents_with_hops = defaultdict(list) # {hop_count: [parents]}
    # queue to perform BFS
    queue = deque([(node, 0)])
    # Nodes are marked as soon as they are discovered (not when they are
    # dequeued): otherwise an ancestor reachable through several children is
    # appended several times, possibly under a non-minimal hop count.
    visited = {node}

    while queue:
        current_node, hop_count = queue.popleft()

        # For each predecessor (parent) of the current node
        for parent in graph.predecessors(current_node):
            if parent in visited:
                continue
            visited.add(parent)
            parents_with_hops[hop_count + 1].append(parent)
            queue.append((parent, hop_count + 1))

    return parents_with_hops

In [23]:
def get_parents_with_depth(graph, node, cls_depth):
    """Map every superclass of <node> that has a known depth to that depth.

    Walks upwards through ``graph.predecessors``. A predecessor missing from
    <cls_depth> is ignored and the walk does not continue above it.

    @param graph: object exposing ``predecessors(n)`` (e.g. a networkx DiGraph)
    @param node: class whose superclasses are collected (not included itself,
                 unless it is reachable from itself through a cycle)
    @param cls_depth: mapping class -> depth
    @return: dict superclass -> depth, in depth-first discovery order
    """
    # Retrieve the distances for all superclasses of the given node
    superclasses_with_distances = {}
    def collect_superclasses(node):
        for predecessor in graph.predecessors(node):
            if predecessor not in cls_depth:
                continue
            # Already collected: its own superclasses were (or are being)
            # handled. Skipping it avoids re-walking shared superclasses once
            # per path and stops infinite recursion on cyclic graphs.
            if predecessor in superclasses_with_distances:
                continue
            superclasses_with_distances[predecessor] = cls_depth[predecessor]
            collect_superclasses(predecessor)

    collect_superclasses(node)

    return superclasses_with_distances

* generate dataset

In [26]:
# Depth of a WiKC class = shortest-path length from the root class
cls_depth = nx.single_source_shortest_path_length(wikc, root)
with open('wikc_eval.txt', 'w') as writer:
    for inst in tqdm(sampled_inst2type):
        types = sampled_inst2type[inst]
        # direct types plus all their superclasses, each mapped to its depth
        parents_depth = {}
        for cls in types:
            parents_depth[cls] = cls_depth[cls]
            parents_depth.update(get_parents_with_depth(wikc, cls, cls_depth))

        # one tab-separated row per (instance, class):
        # qid, 'label', 'description', class, depth
        # NOTE(review): ent2label / ent2desc lookups raise KeyError for an
        # instance that has no label or no description.
        writer.writelines(
            f"{inst}\t'{ent2label[inst]}'\t'{ent2desc[inst]}'\t{cls}\t{depth}\n"
            for cls, depth in parents_depth.items()
        )

100%|██████████| 100000/100000 [00:02<00:00, 44929.47it/s]


In [27]:
# Depth of an original Wikidata class = shortest-path length from the root class
ori_cls_depth = nx.single_source_shortest_path_length(oriWiki, root)
with open('wikidata_eval.txt', 'w') as writer:
    for inst in tqdm(ori_inst2type):
        types = ori_inst2type[inst]
        # direct types plus all their superclasses, each mapped to its depth
        parents_depth = {}
        for cls in types:
            parents_depth[cls] = ori_cls_depth[cls]
            parents_depth.update(get_parents_with_depth(oriWiki, cls, ori_cls_depth))

        # one tab-separated row per (instance, class):
        # qid, 'label', 'description', class, depth
        # NOTE(review): ent2label / ent2desc lookups raise KeyError for an
        # instance that has no label or no description.
        writer.writelines(
            f"{inst}\t'{ent2label[inst]}'\t'{ent2desc[inst]}'\t{cls}\t{depth}\n"
            for cls, depth in parents_depth.items()
        )

100%|██████████| 100000/100000 [00:21<00:00, 4699.35it/s]


* a mini-test dataset: only 1k samples

In [28]:
# Draw 1000 instances at random for the mini test set
sample_keys = random.sample(list(inst_cls), 1000)
# instance -> WiKC classes, restricted to the 1k sample
mini_sampled_inst2type = dict((key, inst_cls[key]) for key in sample_keys)

In [29]:
# Original Wikidata classes of the instances in the 1k sample
mini_ori_inst2type = defaultdict(set)
with open(PATH+'Wiki_ET.txt', 'r') as etreader:
    for inst, rel, cls in (row.strip().split('\t') for row in etreader):
        if inst in mini_sampled_inst2type:
            mini_ori_inst2type[inst].add(cls)

# same as before: remove transitive types for each instance
for inst in mini_ori_inst2type.keys():
    candidates = mini_ori_inst2type[inst]
    if len(candidates) == 1:
        continue
    valcls = set(candidates)
    for node1, node2 in permutations(list(candidates), 2):
        # skip pairs involving a type that was already discarded
        if not (node1 in valcls and node2 in valcls):
            continue
        # drop the root class, and any type from which oriWiki has a path
        # to another remaining type
        if node1 == 'wd:Q35120' or nx.has_path(oriWiki, node1, node2):
            valcls.remove(node1)
    mini_ori_inst2type[inst] = set(valcls)

print("Number of instances: ", len(mini_ori_inst2type))

Number of instances:  1000


In [None]:
# Depth of a WiKC class = shortest-path length from the root class
cls_depth = nx.single_source_shortest_path_length(wikc, 'wd:Q35120')
with open('wikc_eval_1k.txt', 'w') as writer:
    for inst in tqdm(mini_sampled_inst2type):
        types = mini_sampled_inst2type[inst]
        # direct types plus all their superclasses, each mapped to its depth
        parents_depth = {}
        for cls in types:
            parents_depth[cls] = cls_depth[cls]
            parents_depth.update(get_parents_with_depth(wikc, cls, cls_depth))

        # one tab-separated row per (instance, class):
        # qid, 'label', 'description', class, depth
        # NOTE(review): ent2label / ent2desc lookups raise KeyError for an
        # instance that has no label or no description.
        writer.writelines(
            f"{inst}\t'{ent2label[inst]}'\t'{ent2desc[inst]}'\t{cls}\t{depth}\n"
            for cls, depth in parents_depth.items()
        )



# Depth of an original Wikidata class = shortest-path length from the root class
ori_cls_depth = nx.single_source_shortest_path_length(oriWiki, 'wd:Q35120')
with open('wikidata_eval_1k.txt', 'w') as writer:
    for inst in tqdm(mini_ori_inst2type):
        types = mini_ori_inst2type[inst]
        # direct types plus all their superclasses, each mapped to its depth
        parents_depth = {}
        for cls in types:
            parents_depth[cls] = ori_cls_depth[cls]
            parents_depth.update(get_parents_with_depth(oriWiki, cls, ori_cls_depth))

        # one tab-separated row per (instance, class):
        # qid, 'label', 'description', class, depth
        # NOTE(review): ent2label / ent2desc lookups raise KeyError for an
        # instance that has no label or no description.
        writer.writelines(
            f"{inst}\t'{ent2label[inst]}'\t'{ent2desc[inst]}'\t{cls}\t{depth}\n"
            for cls, depth in parents_depth.items()
        )