In [14]:
import numpy as np
import torch
import sys
import os

data_path = '/root/Low_Dimension_KGC/data/FB15k-237'
train_data_path = os.path.join(data_path, 'train')

# Dataset sizes for FB15k-237. Inverse edges are stored with relation id
# r + relation_num, so the augmented graph uses relation ids [0, 2 * relation_num).
entity_num = 14541
relation_num = 237

# triples:    every training triple (h, r, t) plus its inverse (t, r + relation_num, h).
# rt_dict:    relation id -> list of tail ids, one entry per training triple
#             (duplicates are kept, so len() is the number of triples of that relation,
#             which can exceed entity_num).
# query_dict: (head, relation) query -> list of tail ids that answer it in training.
triples = []
query_dict = {}
rt_dict = {}

with open(train_data_path, encoding='utf-8') as fin:
    for raw_line in fin:
        record = raw_line.strip()
        if not record:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        h, r, t = map(int, record.split('\t'))
        r_inv = r + relation_num

        triples.append((h, r, t))
        triples.append((t, r_inv, h))

        rt_dict.setdefault(r, []).append(t)
        rt_dict.setdefault(r_inv, []).append(h)

        query_dict.setdefault((h, r), []).append(t)
        query_dict.setdefault((t, r_inv), []).append(h)

if not rt_dict:
    # Fail early with a clear message instead of writing an empty file and then
    # crashing inside numpy's min() on a zero-size array.
    raise ValueError(f"No training triples found in {train_data_path}")

output_path = os.path.join(data_path, 'rt_dict.txt')  # output file path

# Relations in ascending id order.
sorted_rt_dict = sorted(rt_dict.items(), key=lambda item: item[0])

# Number of tail entries (= training triples) per relation, in sorted-relation order.
tail_counts = []

with open(output_path, 'w', encoding='utf-8') as fout:
    for r, tails in sorted_rt_dict:
        # One line per relation: "<r>\t<tail_1>\t<tail_2>...".
        fout.write(f"{r}\t" + '\t'.join(map(str, tails)) + '\n')
        tail_counts.append(len(tails))

# Summary statistics of the per-relation tail counts.
tail_counts_np = np.array(tail_counts)
total_relations = len(tail_counts)
stats = {
    "total_relations": total_relations,
    "min_tails": tail_counts_np.min(),
    "max_tails": tail_counts_np.max(),
    "mean_tails": tail_counts_np.mean(),
    "median_tails": np.median(tail_counts_np)
}

print(f"Total relations: {stats['total_relations']}")
print(f"Min tails per relation: {stats['min_tails']}")
print(f"Max tails per relation: {stats['max_tails']}")
print(f"Mean tails per relation: {stats['mean_tails']:.2f}")
print(f"Median tails per relation: {stats['median_tails']:.2f}")

print(f"rt_dict has been written to {output_path}")


Total relations: 474
Min tails per relation: 37
Max tails per relation: 15989
Mean tails per relation: 1148.16
Median tails per relation: 373.00
rt_dict has been written to /root/Low_Dimension_KGC/data/FB15k-237/rt_dict.txt


In [15]:
import numpy as np

# Per-triple test ranking results; each non-blank line is "h r t rank" (rank is 1-based).
test_file_path = '/root/Low_Dimension_KGC/models/FB15k-237_2000/test_detail_result.txt'  # replace with the actual path

# relation id -> {"ranks": [...], "tails_count": number of training triples of r}
results = {}

# Collect the ranks of every test triple, grouped by relation.
with open(test_file_path, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        h, r, t, rank = map(int, line.split())
        if rank < 1:
            # A rank of 0 would make 1.0 / rank infinite and silently corrupt MRR.
            raise ValueError(f"Ranks must be 1-based, got {rank} in line: {line.strip()!r}")
        if r not in results:
            results[r] = {"ranks": [], "tails_count": len(rt_dict.get(r, []))}
        results[r]["ranks"].append(rank)

# MRR and Hits@{1,3,10} for each relation.
for r in results:
    ranks = np.array(results[r]["ranks"])
    results[r]["MRR"] = np.mean(1.0 / ranks)
    results[r]["HIT@1"] = np.mean(ranks <= 1)
    results[r]["HIT@3"] = np.mean(ranks <= 3)
    results[r]["HIT@10"] = np.mean(ranks <= 10)

# Relations ordered by ascending training-triple count.
sorted_results = sorted(results.items(), key=lambda x: x[1]["tails_count"])

# Print the per-relation table. The loop variable is deliberately not named
# `stats`, so the summary dict built when the training data was loaded is not clobbered.
print(f"{'Relation':<10} {'MRR':<10} {'HIT@1':<10} {'HIT@3':<10} {'HIT@10':<10} {'Tail Count':<10}")
for r, rel_metrics in sorted_results:
    print(f"{r:<10} {rel_metrics['MRR']:<10.4f} {rel_metrics['HIT@1']:<10.4f} {rel_metrics['HIT@3']:<10.4f} {rel_metrics['HIT@10']:<10.4f} {rel_metrics['tails_count']:<10}")


Relation   MRR        HIT@1      HIT@3      HIT@10     Tail Count
236        1.0000     1.0000     1.0000     1.0000     37        
473        0.0131     0.0000     0.0000     0.0000     37        
209        0.2198     0.1429     0.2857     0.2857     90        
446        0.6870     0.5714     0.7143     0.8571     90        
44         0.1420     0.0909     0.1818     0.1818     93        
281        0.8030     0.6364     1.0000     1.0000     93        
153        0.0651     0.0000     0.1000     0.1000     99        
390        0.0669     0.0000     0.1000     0.2000     99        
207        0.2648     0.0000     0.5556     0.7778     100       
444        0.0804     0.0000     0.0000     0.2222     100       
220        0.8333     0.6667     1.0000     1.0000     100       
457        0.7333     0.6667     0.6667     1.0000     100       
206        0.1347     0.0000     0.1667     0.3333     100       
443        0.0682     0.0000     0.1667     0.1667     100       
69        

In [22]:
# (head, relation) query -> ranks of its test tails plus training-graph statistics:
#   relation_tails: number of training triples with relation r (entries in rt_dict[r])
#   query_tails:    number of training tails already known for (h, r); 0 if unseen
detail_result = {}

# Collect the ranks of every test triple, grouped by its (h, r) query.
with open(test_file_path, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        h, r, t, rank = map(int, line.split())
        query_key = (h, r)
        if query_key not in detail_result:
            detail_result[query_key] = {
                "ranks": [],
                "relation_tails": len(rt_dict.get(r, [])),
                "query_tails": len(query_dict.get(query_key, [])),
            }
        detail_result[query_key]["ranks"].append(rank)

# Per-query metrics, averaged over all test tails of that (h, r) query.
for query_key, entry in detail_result.items():
    ranks = np.array(entry["ranks"])
    entry["MRR"] = np.mean(1.0 / ranks)
    entry["HIT@1"] = np.mean(ranks <= 1)
    entry["HIT@3"] = np.mean(ranks <= 3)
    entry["HIT@10"] = np.mean(ranks <= 10)
    entry["query_num"] = len(ranks)

# Keep only queries whose (h, r) also occurs in training, then sort ascending by
# how many *other* training triples share the relation (relation_tails - query_tails).
filtered_results = [
    (key, value) for key, value in detail_result.items()
    if value['query_tails'] > 0
]
sorted_results = sorted(
    filtered_results,
    key=lambda x: x[1]['relation_tails'] - x[1]['query_tails']
)

output_file = '/root/sample_negative_case_study.txt'  # output file path

# Print the sorted table (including query_num) once, and write the same rows to
# output_file. Previously the table was also printed by a separate, duplicated loop.
with open(output_file, 'w', encoding='utf-8') as fout:
    header = f"{'Head':<10} {'Relation':<10} {'MRR':<10} {'HIT@1':<10} {'HIT@3':<10} {'HIT@10':<10} {'relation_tails':<12} {'query_tails':<12} {'Tail-Query Diff':<15} {'Query Num':<12}\n"
    print(header.strip())
    fout.write(header)

    for (h, r), query_stats in sorted_results:
        tail_query_diff = query_stats['relation_tails'] - query_stats['query_tails']
        line = f"{h:<10} {r:<10} {query_stats['MRR']:<10.4f} {query_stats['HIT@1']:<10.4f} {query_stats['HIT@3']:<10.4f} {query_stats['HIT@10']:<10.4f} {query_stats['relation_tails']:<12} {query_stats['query_tails']:<12} {tail_query_diff:<15} {query_stats['query_num']:<12}\n"
        print(line.strip())
        fout.write(line)

print(f"Detailed results written to {output_file}")


FileNotFoundError: [Errno 2] No such file or directory: '/root/Low_Dimension_KGC/models/FB15k-237_2000/test_detail_result.txt'