In [None]:
import numpy as np
import torch
import sys
import os

# Training split: one "head<TAB>relation<TAB>tail" record of integer ids per line.
data_path = '/root/Low_Dimension_KGC/data/FB15k-237'
train_data_path = data_path + '/train'

entity_num = 14541
relation_num = 237

triples = []     # every training triple followed by its inverse triple
query_dict = {}  # (entity, relation) -> list of answer entities

with open(train_data_path) as fin:
    for line in fin:
        h, r, t = line.strip().split('\t')
        head, rel, tail = int(h), int(r), int(t)
        # Inverse relations are numbered after the forward ones.
        inv_rel = rel + relation_num

        triples.append((head, rel, tail))
        triples.append((tail, inv_rel, head))

        query_dict.setdefault((head, rel), []).append(tail)
        query_dict.setdefault((tail, inv_rel), []).append(head)

print('所有的三元组共有：', len(triples), '个')
print('不同的query共有：', len(query_dict), '个')


所有的三元组共有： 544230 个
不同的query共有： 149689 个


In [2]:
"""
Keep one triple for every distinct query that occurs in the training set;
each query only needs to appear once.
"""

import numpy as np
import torch
import sys
import os

data_path = '/root/Low_Dimension_KGC/data/FB15k-237'
train_data_path = data_path + '/train'

entity_num = 14541
relation_num = 237

query_aware_triples = []  # first triple encountered for each distinct query
query_dict = {}           # (entity, relation) -> True once that query was emitted

with open(train_data_path) as fin:
    for line in fin:
        h, r, t = map(int, line.strip().split('\t'))
        # Each record yields a forward query (h, r) and an inverse query
        # (t, r + relation_num); only the first triple per query is kept.
        for candidate in ((h, r, t), (t, r + relation_num, h)):
            query = candidate[:2]
            if query in query_dict:
                continue
            query_dict[query] = True
            query_aware_triples.append(candidate)

output_file_path = os.path.join(data_path, 'single_query.txt')

# One tab-separated "entity relation answer" record per line.
with open(output_file_path, 'w') as fout:
    fout.writelines(
        f"{head}\t{rel}\t{tail}\n" for head, rel, tail in query_aware_triples
    )

print(f"Unique Query have been successfully written to {output_file_path}")




Query-aware triples have been successfully written to /root/Low_Dimension_KGC/data/FB15k-237/single_query.txt


In [None]:
"""
Generate the data used for relation-aware negative sampling.
"""

import numpy as np
import torch
import sys
import os

data_path = '/root/Low_Dimension_KGC/data/FB15k-237'
train_data_path = data_path + '/train'

entity_num = 14541
relation_num = 237

triples = []     # every training triple followed by its inverse triple
query_dict = {}  # (entity, relation) -> list of answer entities
rt_dict = {}     # relation -> list of entities seen in the answer position

with open(train_data_path) as fin:
    for line in fin:
        h, r, t = map(int, line.strip().split('\t'))
        inv_r = r + relation_num  # id of the inverse relation

        # Register the forward direction first, then the inverse direction.
        for head, rel, tail in ((h, r, t), (t, inv_r, h)):
            triples.append((head, rel, tail))
            rt_dict.setdefault(rel, []).append(tail)
            query_dict.setdefault((head, rel), []).append(tail)


output_path = os.path.join(data_path, 'rt_dict.txt')  # destination file

# (relation, tails) pairs in ascending order of relation id.
sorted_rt_dict = [(rel, rt_dict[rel]) for rel in sorted(rt_dict)]

# One line per relation: the relation id, then each of its tails, tab-separated.
with open(output_path, 'w') as fout:
    for rel, rel_tails in sorted_rt_dict:
        fout.write(str(rel) + '\t' + '\t'.join(map(str, rel_tails)) + '\n')

# Length of each relation's tail list, in the same order as the file.
# len() counts list entries, so a tail stored several times is counted each time.
tail_counts = [len(rel_tails) for _, rel_tails in sorted_rt_dict]

# Summary statistics over the per-relation tail counts.
tail_counts_np = np.array(tail_counts)
total_relations = len(tail_counts)
stats = {
    "total_relations": total_relations,
    "min_tails": tail_counts_np.min(),
    "max_tails": tail_counts_np.max(),
    "mean_tails": tail_counts_np.mean(),
    "median_tails": np.median(tail_counts_np)
}

# Report the statistics.
print(f"Total relations: {stats['total_relations']}")
print(f"Min tails per relation: {stats['min_tails']}")
print(f"Max tails per relation: {stats['max_tails']}")
print(f"Mean tails per relation: {stats['mean_tails']:.2f}")
print(f"Median tails per relation: {stats['median_tails']:.2f}")

print(f"rt_dict has been written to {output_path}")


In [4]:
# Answer-count statistics for the training queries.
# Expects `query_dict` to map each query to a list of answers (len() is taken
# of every value), so it must be the list-valued dict, not a membership map.
from collections import defaultdict

# Histogram: number of answers -> how many queries have exactly that many.
answer_count_distribution = defaultdict(int)
for answers in query_dict.values():
    answer_count_distribution[len(answers)] += 1

# Both summary counts are read off the histogram, so query_dict is scanned once.
single_answer_count = answer_count_distribution.get(1, 0)  # exactly one answer
multiple_answers_count = sum(  # more than one answer
    query_count
    for answer_count, query_count in answer_count_distribution.items()
    if answer_count > 1
)

# Report the summary counts.
print(f"训练集中只有一个答案的query有: {single_answer_count}个")
print(f"训练集中有多个答案的query有: {multiple_answers_count}个")

# Report the full distribution in ascending order of answer count.
print("Answer count distribution:")
for answer_count, query_count in sorted(answer_count_distribution.items()):
    print(f"Queries with {answer_count} answer(s): {query_count}")



训练集中只有答案的query有: 88518个
训练集中有多个答案的query有: 61171个
Answer count distribution:
Queries with 1 answer(s): 88518
Queries with 2 answer(s): 20787
Queries with 3 answer(s): 11502
Queries with 4 answer(s): 7169
Queries with 5 answer(s): 4394
Queries with 6 answer(s): 3129
Queries with 7 answer(s): 2235
Queries with 8 answer(s): 1734
Queries with 9 answer(s): 1427
Queries with 10 answer(s): 1100
Queries with 11 answer(s): 943
Queries with 12 answer(s): 761
Queries with 13 answer(s): 602
Queries with 14 answer(s): 613
Queries with 15 answer(s): 456
Queries with 16 answer(s): 379
Queries with 17 answer(s): 343
Queries with 18 answer(s): 262
Queries with 19 answer(s): 234
Queries with 20 answer(s): 225
Queries with 21 answer(s): 199
Queries with 22 answer(s): 159
Queries with 23 answer(s): 169
Queries with 24 answer(s): 140
Queries with 25 answer(s): 87
Queries with 26 answer(s): 104
Queries with 27 answer(s): 83
Queries with 28 answer(s): 85
Queries with 29 answer(s): 64
Queries with 30 answer(s)

In [8]:
import copy

data_path = '/root/Low_Dimension_KGC/data/FB15k-237'
valid_data_path = data_path + '/valid'
test_data_path = data_path + '/test'

entity_num = 14541
relation_num = 237

# Work on deep copies so the train-only `triples` / `query_dict` stay unchanged
# while the valid and test records are added on top.
# NOTE(review): both names must already exist in the kernel, with `query_dict`
# holding lists of answers -- run the cell that builds them first.
all_triples = copy.deepcopy(triples)
all_query_dict = copy.deepcopy(query_dict)

# The valid and test files are parsed exactly the same way, so a single loop
# handles both splits (valid first, then test).
for split_path in (valid_data_path, test_data_path):
    with open(split_path) as fin:
        for line in fin:
            h, r, t = (int(field) for field in line.strip().split('\t'))
            inv_r = r + relation_num  # id of the inverse relation

            all_triples.append((h, r, t))
            all_triples.append((t, inv_r, h))

            all_query_dict.setdefault((h, r), []).append(t)
            all_query_dict.setdefault((t, inv_r), []).append(h)

In [None]:
# Tally how each query's answer count differs between `query_dict` (train only)
# and `all_query_dict` (train plus valid and test).
from collections import defaultdict

# (answer count in query_dict, answer count in all_query_dict) -> number of queries
answer_changes = defaultdict(int)

# Query sets of the two dictionaries.
all_queries = set(all_query_dict.keys())
train_queries = set(query_dict.keys())

# Count the (train, all) answer-count pair of every query.
for query in all_queries:
    # A query missing from query_dict contributes a training count of 0.
    train_answers = query_dict.get(query, [])
    all_answers = all_query_dict.get(query, [])

    train_answer_count = len(train_answers)
    all_answer_count = len(all_answers)

    # Record the change (train -> all).
    change = (train_answer_count, all_answer_count)
    answer_changes[change] += 1

# Sort by the (train_count, all_count) pair: ascending training count first,
# ties broken by the combined count.
sorted_changes = sorted(answer_changes.items(), key=lambda x: x[0])

# Report the result.
print("Query answer count changes (train -> all):")
for (train_count, all_count), count in sorted_changes:
    print(f"From {train_count} to {all_count}: {count} queries")


Query answer count changes (train -> all):
From 0 to 1: 11834 queries
From 0 to 2: 366 queries
From 0 to 3: 32 queries
From 0 to 4: 1 queries
From 1 to 1: 83728 queries
From 1 to 2: 4237 queries
From 1 to 3: 475 queries
From 1 to 4: 65 queries
From 1 to 5: 13 queries
From 2 to 2: 17307 queries
From 2 to 3: 2854 queries
From 2 to 4: 507 queries
From 2 to 5: 98 queries
From 2 to 6: 19 queries
From 2 to 7: 1 queries
From 2 to 8: 1 queries
From 3 to 3: 8770 queries
From 3 to 4: 2047 queries
From 3 to 5: 525 queries
From 3 to 6: 131 queries
From 3 to 7: 26 queries
From 3 to 8: 3 queries
From 4 to 4: 5082 queries
From 4 to 5: 1463 queries
From 4 to 6: 472 queries
From 4 to 7: 119 queries
From 4 to 8: 26 queries
From 4 to 9: 5 queries
From 4 to 10: 1 queries
From 4 to 11: 1 queries
From 5 to 5: 2684 queries
From 5 to 6: 1151 queries
From 5 to 7: 397 queries
From 5 to 8: 112 queries
From 5 to 9: 35 queries
From 5 to 10: 13 queries
From 5 to 11: 1 queries
From 5 to 12: 1 queries
From 6 to 6: 17

In [6]:
def read_file(file_path):
    """Load a per-triple rank file into a ``{(head, relation, tail): rank}`` dict.

    Every non-blank line must hold exactly four whitespace-separated fields:
    ``head relation tail rank``. ``head``, ``relation`` and ``tail`` are kept
    as strings; ``rank`` is converted to ``int``. When the same triple occurs
    on several lines, the smallest rank is kept.

    Args:
        file_path: Path of the text file to read.

    Returns:
        dict mapping ``(head, relation, tail)`` string tuples to the minimum
        integer rank recorded for that triple.

    Raises:
        ValueError: If a non-blank line does not have exactly four fields or
            its rank field is not an integer.
    """
    triplets = {}
    with open(file_path, 'r') as file:
        for line in file:
            fields = line.split()
            if not fields:
                # Blank or whitespace-only line (e.g. a trailing newline at the
                # end of the file): nothing to parse, so skip it.
                continue
            head, relation, tail, rank = fields
            rank = int(rank)
            key = (head, relation, tail)
            if key not in triplets or rank < triplets[key]:
                triplets[key] = rank
    return triplets

def merge_and_calculate_metrics(file1, file2, ks=(1, 3, 10, 50, 100)):
    """Merge two rank files by best rank per triple and compute MR, MRR and Hits@K.

    Both files are loaded with ``read_file``. A triple found in only one file
    keeps that file's rank; a triple found in both keeps the smaller rank.
    The metrics are then computed over the merged triples.

    Args:
        file1: Path of the first rank file.
        file2: Path of the second rank file.
        ks: Cut-offs for Hits@K. Defaults to ``(1, 3, 10, 50, 100)``.

    Returns:
        ``(mr, mrr, hits)`` where ``mr`` is the mean rank, ``mrr`` the mean
        reciprocal rank and ``hits`` a dict mapping each ``k`` in ``ks`` to the
        fraction of merged triples whose rank is ``<= k``.

    Raises:
        ZeroDivisionError: If the merged result is empty or a rank equals 0.
    """
    triplets1 = read_file(file1)
    triplets2 = read_file(file2)

    # Start from the first file, then let the second file lower a rank or
    # contribute triples the first file does not contain.
    final_triplets = dict(triplets1)
    for key, rank in triplets2.items():
        final_triplets[key] = min(final_triplets.get(key, rank), rank)

    # MR, MRR and Hits@K over the merged ranks.
    ranks = list(final_triplets.values())
    total = len(ranks)
    mr = sum(ranks) / total
    mrr = sum(1.0 / rank for rank in ranks) / total
    hits = {k: sum(1 for rank in ranks if rank <= k) / total for k in ks}

    return mr, mrr, hits

# Merge two per-triple rank files and report the combined metrics.
file1 = '/root/Low_Dimension_KGC/models/FB15k-237_1005/test_detail_result_1.txt'
file2 = '/root/Low_Dimension_KGC/models/FB15k-237_1005/test_detail_result_2.txt'
# Alternative pair of result files, kept for reference:
# file1 = '/root/Low_Dimension_KGC/models/LorentzKG_FB15k-237_40/test_detail_result.txt'
# file2 = '/root/Low_Dimension_KGC/models/RotatE_FB15k-237_102/test_detail_result.txt'
mr, mrr, hits = merge_and_calculate_metrics(file1, file2)

# Print MR and MRR, then Hits@K in ascending order of K.
print(f'MR: {mr}')
print(f'MRR: {mrr}')
for k in sorted(hits):
    print(f'HIT@{k}: {hits[k]:.4f}')



MR: 198.76710153425194
MRR: 0.39385005402396805
HIT@1: 0.3111
HIT@3: 0.4263
HIT@10: 0.5568
HIT@50: 0.7197
HIT@100: 0.7806
