In [1]:
import numpy as np
import torch
import sys
import os

# Training split location. Each line is parsed below as three tab-separated
# integer ids: head, relation, tail.
data_path = '/root/Low_Dimension_KGC/data/FB15k-237'
train_data_path = data_path + '/train'

entity_num = 14541
relation_num = 237  # added to a relation id to obtain the id used for its reversed edge

triples = []     # every edge as read, immediately followed by its reversed edge
query_dict = {}  # (head, relation) -> tails seen for that query, in file order
rt_dict = {}     # relation -> tails seen for that relation, one entry per edge (repeats kept)

with open(train_data_path) as fin:
    for line in fin:
        h, r, t = line.strip().split('\t')
        h, r, t = int(h), int(r), int(t)

        # Index the edge as read first, then the reversed edge stored under
        # relation id r + relation_num, so both directions can be queried.
        for head, rel, tail in ((h, r, t), (t, r + relation_num, h)):
            triples.append((head, rel, tail))
            rt_dict.setdefault(rel, []).append(tail)
            query_dict.setdefault((head, rel), []).append(tail)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def parse_qtdict(file_path):
    """
    Read a local text file of query records and convert it into a dictionary.

    The file is consumed in groups of three lines:
      1. ``head<TAB>relation`` (two integers),
      2. tab-separated integer entity ids, stored under ``"PT1_id"``,
      3. tab-separated integer entity ids, stored under ``"PT2_id"``.

    An id line that is empty (or whitespace only) is parsed as an empty list,
    and blank lines after the last record are ignored.

    Args:
        file_path (str): path to the text file.

    Returns:
        dict: ``{(head, relation): {"PT1_id": [entity ids], "PT2_id": [entity ids]}}``.
            If the same ``(head, relation)`` occurs more than once, the last
            occurrence wins.

    Raises:
        IndexError: if the final record has fewer than three lines, or a
            header line has fewer than two fields.
        ValueError: if a field cannot be converted to ``int``.
    """
    def _parse_ids(raw_line):
        # "".split("\t") returns [""], which int() rejects, so an empty id
        # line has to be recognised before splitting.
        stripped = raw_line.strip()
        if not stripped:
            return []
        return [int(token) for token in stripped.split("\t")]

    data_dict = {}

    with open(file_path, 'r') as file:
        lines = file.readlines()

    line_count = len(lines)

    # Each record occupies three consecutive lines.
    for start in range(0, line_count, 3):
        # A blank header followed only by blank lines is end-of-file padding,
        # not a record. The scan is lazy and only runs when the header is blank.
        if not lines[start].strip() and not any(
            lines[k].strip() for k in range(start, line_count)
        ):
            break

        # Line 1: the (head, relation) key.
        header_fields = lines[start].strip().split("\t")
        head, relation = int(header_fields[0]), int(header_fields[1])

        if start + 2 >= line_count:
            raise IndexError(
                f"incomplete record starting at line {start + 1} of {file_path}: "
                "expected a header line followed by two id lines"
            )

        # Lines 2 and 3: the PT1 and PT2 entity id lists.
        data_dict[(head, relation)] = {
            "PT1_id": _parse_ids(lines[start + 1]),
            "PT2_id": _parse_ids(lines[start + 2]),
        }

    return data_dict

# Load the per-query PT1/PT2 entity id lists keyed by (head, relation).
qt_dict_path = '/root/Low_Dimension_KGC/data/FB15k-237/qt_dict.txt'
qt_dict = parse_qtdict(qt_dict_path)


In [6]:
import numpy as np

# For each inspected query, check how many of its leading PT2 ids belong to
# the set "tails seen in training for this relation, minus tails seen in
# training for this exact (head, relation) query".
# Relies on qt_dict, rt_dict and query_dict built by earlier cells.

MAX_QUERIES = 1400     # number of qt_dict entries inspected before stopping
PT2_SAMPLE_SIZE = 50   # only the first PT2_SAMPLE_SIZE ids of each "PT2_id" list are checked

# Running totals over all inspected queries.
in_new_list_count = 0
not_in_new_list_count = 0
total_count = 0

time_count = 0

# Iterate over the queries in qt_dict.
for (head, relation), entity_id_data in qt_dict.items():
    # list1: every tail recorded for this relation in rt_dict.
    list1 = rt_dict.get(relation, [])

    # list2: tails recorded for this exact query in query_dict.
    list2 = query_dict.get((head, relation), [])

    list1_np = np.array(list1, dtype=np.int32)
    list2_np = np.array(list2, dtype=np.int32)

    # list1 minus list2. rt_dict lists hold one entry per training edge, so
    # entity ids repeat; assume_unique=True is only valid for duplicate-free
    # inputs, so the default (which de-duplicates first) is used instead.
    new_entity_list = np.setdiff1d(list1_np, list2_np)

    # Membership of the leading PT2 ids in that difference set.
    qt_sample_entity = entity_id_data['PT2_id'][:PT2_SAMPLE_SIZE]
    is_in_new_list = np.isin(qt_sample_entity, new_entity_list)

    in_new_list_count += np.sum(is_in_new_list)
    not_in_new_list_count += np.sum(~is_in_new_list)
    total_count += len(qt_sample_entity)

    time_count += 1
    if time_count >= MAX_QUERIES:
        break

# Percentages; fall back to 0 when no ids were inspected.
in_new_list_ratio = (in_new_list_count / total_count) * 100 if total_count > 0 else 0
not_in_new_list_ratio = (not_in_new_list_count / total_count) * 100 if total_count > 0 else 0

# Report the totals.
print(f"总数: {total_count}")
print(f"出现在新的 entity_id 列表中的数量: {in_new_list_count}, 占比: {in_new_list_ratio:.2f}%")
print(f"未出现在新的 entity_id 列表中的数量: {not_in_new_list_count}, 占比: {not_in_new_list_ratio:.2f}%")


总数: 70000
出现在新的 entity_id 列表中的数量: 50710, 占比: 72.44%
未出现在新的 entity_id 列表中的数量: 19290, 占比: 27.56%


In [5]:
import numpy as np
import torch
import sys
import os

# For every test triple, in both directions, count:
#   appear  - the (head, relation) query was seen in training (key of query_dict)
#   connect - additionally, the answer entity was seen in training as a tail
#             of that relation (member of rt_dict[relation])
# Relies on query_dict, rt_dict and relation_num built by the training-data cell.
data_path = '/root/Low_Dimension_KGC/data/FB15k-237'
test_data_path = data_path + '/test'

appear = 0
connect = 0

# rt_dict values are lists with one entry per training edge; converting them
# to sets once makes each membership test O(1) instead of a linear scan.
rt_sets = {rel: set(tails) for rel, tails in rt_dict.items()}

with open(test_data_path) as fin:
    for line in fin:
        h, r, t = line.strip().split('\t')
        h, r, t = int(h), int(r), int(t)

        # Forward query (h, r) -> t, then the reversed query under the
        # relation id r + relation_num.
        for query_head, query_rel, answer in ((h, r, t), (t, r + relation_num, h)):
            if (query_head, query_rel) in query_dict:
                if answer in rt_sets[query_rel]:
                    connect += 1
                appear += 1

print(connect, appear)
# Guard the ratio so an empty or fully unseen test split prints nan instead of
# raising ZeroDivisionError.
print(connect / appear if appear else float('nan'))

27848 34051
0.8178320754162873
