In [1]:
# -*- coding:utf-8 -*-
import os, time, re, torch, math
import numpy as np
from collections import defaultdict
from tqdm import tqdm

In [2]:
e2i, r2i = {}, {}
with open('datasets/umls/entities.txt') as e:
    for line in e:
        e2i[line.strip()] = len(e2i)

i = 0
with open('datasets/umls/relations.txt') as r:
    for line in r:
        r2i[line.strip()] = i
        r2i['inv_' + line.strip()] = i+1
        r2i[line.strip() + '_inv'] = i+1
        i += 2
r2i['no_jump'] = i
r2i['no_jump_inv'] = i+1
r2i['inv_no_jump'] = i+1
i += 2

neigs = [[] for _ in range(i)]
with open('datasets/umls/all.txt') as f:
    for line in f:
        h,r,t = line.strip().split()
        h,r,r_inv,t = e2i[h],r2i[r],r2i['inv_'+r],e2i[t]
        neigs[r].append((h,t))
        neigs[r_inv].append((t,h))
for i in range(len(e2i)):
    neigs[r2i['no_jump']].append((i,i))
    neigs[r2i['no_jump_inv']].append((i,i))

av_ents = {}
with open('datasets/umls/train.txt') as f:
    for line in f:
        h,a,v = line.strip().split("\t")
        h,a,v = e2i[h], r2i[a], e2i[v]
        try:
            av_ents[(a,v)][h] = 1
        except KeyError:
            av_ents[(a,v)] = torch.zeros(len(e2i))
for key in av_ents.keys():
    av_ents[key] = av_ents[key].cuda().to_sparse()

cnt = max([len(lis) for lis in neigs])
ones = torch.ones(cnt)
neigs = [torch.LongTensor(lis).T for lis in neigs]
neigs = [torch.sparse.LongTensor(pos, ones[:pos.size(1)], torch.Size([len(e2i),len(e2i)])).coalesce().cpu() for pos in neigs]
# neigs = [torch.sparse.LongTensor(pos, ones[:pos.size(1)], torch.Size([len(e2i),len(e2i)])).to_dense().cuda() for pos in neigs]


In [3]:
filt = []
with open('exps/umls-3-rule/rulesXXXX') as f:
    for line in f:
        tmp = line.strip().split('<-')
        cnt = int(tmp[0].split('-')[0])
        wei = float(tmp[0].split('-')[1])
        rela = tmp[1].strip()
        filt.append((line, cnt, wei))

# for UMLS and Family-gender
filt = sorted(filt, key=lambda x: -(x[2] * ((x[1] > 80) + 1)))
# for FB15K237
# filt = sorted(filt, key=lambda x: -(x[2] * ((x[1] > 300) - (x[1] < 20) + 5)))

lines = []
for i in range(400):
    lines.append(filt[i][0])


In [4]:
# filter始、终节点
# dic = defaultdict(list)
inv = re.compile('_inv_inv')
quot = re.compile("'")
ones = torch.ones(len(e2i)).cuda().to_sparse()
rules = []
# with open('exps/18rr-3-rule/rulesXXXX') as f:
if True:
    lines = f.readlines()
    lines = [quot.sub('',inv.sub('', line)) for line in lines]
    for index,line in enumerate(tqdm(lines, desc='计算confidence')):
        w,rela,rule = line.strip().split('<-')
        rela = r2i[inv.sub('', rela+'_inv')]
        rule = [p.strip() for p in rule.strip().split('^')[:-1]]
        for i in range(len(rule)):
            if '(' in rule[i]:
                limits = (rule[i].split('(')[1]).split('&')[:-1]
                rule[i] = [rule[i].split('(')[0].strip(), ones.clone()]
                for limit in limits:
                    limit_ents = ((limit.split('[')[1]).split(']')[0]).split(', ')
                    limit = limit.split('-')[0]
                    for limit_ent in limit_ents:
                        rule[i][1] *= av_ents[(r2i[limit], e2i[limit_ent])]
            else:
                rule[i] = (rule[i].strip(), None)
        
        for i in range(len(rule)):
            if i == 0:
                path = neigs[r2i[rule[i][0]]].clone()
            else:
                path = torch.sparse.mm(path, neigs[r2i[rule[i][0]]])
            if rule[i][1] != None:
                path = path.to_dense() * rule[i][1].to_dense()
                path = path.to_sparse()
        
        cnt1 = torch.sum(path.coalesce().values().bool())
        cnt2 = torch.sum(path.to_dense().bool()[neigs[rela].indices().cpu().numpy()])
        
        rules.append(  ((cnt2/cnt1).item() if cnt1!=0 else 0, str(index)+line)  )

        path = cnt1 = cnt2 = None
        torch.cuda.empty_cache()


计算confidence: 100%|██████████| 400/400 [3:12<00:00, 2.08it/s]


In [5]:
cons = []
for i in range(400):
    cons.append(rules[i][0])

for i in [50,100,200]:
    print('Top@{}:{:.4f}'.format(i, sum(cons[:i])/i), end='   ')
    print(f'Top@{i}:{sum(cons[:i])/i}  ', end='')

Top@50:0.5462   Top@100:0.3864   Top@200:0.4241   

In [6]:
confidence = np.array([x[0] for x in rules])
print([sum(confidence < x) for x in [0.3,0.6,1.1]])

[244, 30, 161]
