In [1]:
# 深度优先搜索+限制条件
class Node:
    def __init__(self, idd, label):
        self.id = idd
        self.label =label
        self.neighbours = []

class Graph:
    def __init__(self):
        self.nodes = {}
    
    def add_node(self, idd, label):
        node = Node(idd, label)
        self.nodes[idd] = node

    def add_edge(self, start_id, end_id):
        start_node = self.nodes[start_id]
        end_node = self.nodes[end_id]
        start_node.neighbours.append(end_node)

def build_graph(node_file, relation_file):
    nodes = {}
    with open(node_file, 'r') as f:
        lines = f.readlines()
        for line in lines:
            idd, label = line.strip().split('\t')
            nodes[idd] = label
    relations = []
    with open(relation_file, 'r') as f:
        lines = f.readlines()
        for line in lines:
            start_id, end_id, _ = line.strip().split('\t')
            relations.append((start_id, end_id))
    graph = Graph()
    for idd, label in nodes.items():
        graph.add_node(idd, label)
    for start_id, end_id in relations:
        graph.add_edge(start_id, end_id)
    
    return graph

def dfs(graph, start_id, end_id):
    visited = set()
    stack = [(graph.nodes[start_id],[])]

    while stack:
        node, path = stack.pop()
        if node.id == end_id:
            return path + [node.id]
        if node.id not in visited:
            visited.add(node.id)
            for neighbour in node.neighbours:
                stack.append((neighbour, path+[node.id]))
    return None
    

In [2]:
import pandas as pd
data = pd.read_csv('KG/java/entity2id.txt',sep='\t',names=['id','label'])
java_dict = dict(zip(data['label'],data['id']))

In [17]:
graph = build_graph('KG/java/entity2id.txt', 'KG/java/train.txt')
start_id = java_dict[1]
end_id = java_dict[5]
path = dfs(graph, start_id, end_id)
path

['Java语言', 'Java应用', '网络应用']

In [18]:
'->'.join(graph.nodes[idd].label for idd in path)

'1->4->5'

In [19]:
def cal_p_r(rec_path, true_path):
    x = LCS(rec_path,true_path)
    r = x*1.0/len(true_path)
    p = x*1.0/len(rec_path)
    f = (2*p*r)/(p+r+0.0000001)
    return p, r, f
def LCS(text1, text2):
    m, n = len(text1), len(text2)
    dp = [0] * (n+1)

    for i in range(1, m+1):
        dp2 = [0] * (n+1)
        for j in range(1, n+1):
            if text1[i-1] == text2[j-1]:
                dp2[j] = dp[j-1] + 1
            else:
                dp2[j] = max(dp[j], dp2[j-1])
        dp = dp2
    
    return dp[n]

In [23]:
with open('KG/java/java_test_10.csv', 'r') as f:
    lines = f.readlines()[1:]
    num = len(lines)
    all_f = 0
    all_p = 0
    all_r = 0
    for line in lines:
        _,inter,_,_,target,true_path = line.strip().split('\t')
        start_id = java_dict[int(inter.strip('[').strip(']').split(',')[-1])]
        end_id = java_dict[int(target)]
        true_path = true_path.strip('[').strip(']').split(',')
        true_path = [int(x) for x in true_path]
        pred_path = dfs(graph, start_id, end_id)
        if pred_path:
            pred_path = [graph.nodes[idd].label for idd in path]
            #print(pred_path)
            #print(true_path)
            p,r,f = cal_p_r(pred_path,true_path)
            all_p += p
            all_r += r
            all_f += f
    print("precision: ",all_p/num," recall: ",all_r/num, " f1 score: ",all_f/num)

precision:  0.0  recall:  0.0  f1 score:  0.0
