In [21]:
# === 0. 环境 & 配置 ===
import os, math, random, collections, time
from typing import Dict, Set, Tuple, List
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
try:
    from tqdm import tqdm
except ImportError:
    # tqdm is optional. The fallback returns the iterable unchanged and
    # ignores keyword arguments such as desc=..., so loops still run.
    # Only ImportError is caught so that real failures are not hidden.
    tqdm = lambda x, **k: x

SEED = 42
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Dataset location. Set the OPENBG500_DATA_DIR environment variable to point
# at another copy of the data; otherwise the original local path is used.
DATA_DIR = os.environ.get(
    "OPENBG500_DATA_DIR",
    r"D:\workspace\***\小论文-知识图谱\data\OpenBG500",
)
TRAIN = os.path.join(DATA_DIR, "OpenBG500_train.tsv")
DEV   = os.path.join(DATA_DIR, "OpenBG500_dev.tsv")   # validation split
# The competition file OpenBG500_test.tsv has no tail-entity column and cannot
# be used for evaluation, so the answer file serves as the test split.
TEST  = os.path.join(DATA_DIR, "OpenBG500_answer.tsv")

EMB_DIM = 200     # embedding size; lower to 100 if GPU memory is tight
P_NORM  = 1       # 1 -> L1 distance (TransE-L1), 2 -> L2 distance
MARGIN  = 12.0    # ranking margin, similar to gamma in DGL-KE
LR      = 1e-3
BATCH   = 1024
NEG_K   = 25      # negative samples drawn per positive triple
EPOCHS  = 30
EARLY_STOP_PATIENCE = 3  # stop after this many epochs without a dev-MRR gain
EVAL_BLOCK = 20000       # entities scored per chunk during evaluation (bounds memory)

# Seed every RNG the notebook uses so runs are repeatable.
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [22]:
# === 1. 数据读取与编码 ===
def read_triples(path: str):
    """Read a tab-separated, UTF-8 triple file.

    Each line is split on tabs after removing the trailing newline. Lines
    with fewer than three fields are skipped; for the rest, the first three
    fields are kept as a tuple.

    Returns:
        A list of 3-tuples of strings, in file order.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        split_rows = (raw.rstrip('\n').split('\t') for raw in handle)
        # Consume the generator while the file is still open.
        return [tuple(cols[:3]) for cols in split_rows if len(cols) >= 3]

train_triples = read_triples(TRAIN)
dev_triples = read_triples(DEV)
test_triples = read_triples(TEST)

# Build one vocabulary over all three splits so ids agree everywhere.
ents, rels = set(), set()
for h, r, t in [*train_triples, *dev_triples, *test_triples]:
    ents.update((h, t))
    rels.add(r)

# Ids follow sorted order of the names, so the mapping is deterministic.
ent2id = dict(zip(sorted(ents), range(len(ents))))
rel2id = dict(zip(sorted(rels), range(len(rels))))
id2ent = {idx: name for name, idx in ent2id.items()}  # reverse lookup, e.g. for visualisation

def encode(triples):
    """Turn string triples into a tensor of integer ids.

    Heads and tails are looked up in the module-level ``ent2id`` and
    relations in ``rel2id``.

    Returns:
        An int64 tensor with one row ``[head_id, rel_id, tail_id]`` per
        triple. An empty input yields a tensor of shape ``(0, 3)``.
    """
    import numpy as np
    if not len(triples):
        # Keep the result two-dimensional so a later vstack still lines up.
        return torch.empty((0, 3), dtype=torch.long)
    id_rows = [(ent2id[h], rel2id[r], ent2id[t]) for h, r, t in triples]
    return torch.from_numpy(np.asarray(id_rows, dtype='int64'))

train_ids, dev_ids, test_ids = (
    encode(split) for split in (train_triples, dev_triples, test_triples)
)

n_ent, n_rel = len(ent2id), len(rel2id)
# Displayed as the cell output: vocabulary sizes and split sizes.
(n_ent, n_rel, len(train_ids), len(dev_ids), len(test_ids))

(249746, 500, 1242550, 5000, 5000)

In [23]:
# === 2. 构建 Filter 字典（Filtered 评测要用） ===
def build_filter_dict(all_ids: torch.Tensor):
    """Index the known triples for filtered evaluation.

    Args:
        all_ids: tensor whose rows unpack as ``(head, relation, tail)`` ids.

    Returns:
        ``(hr2t, rt2h)``: two ``defaultdict(set)`` objects mapping
        ``(head, relation)`` to the set of tails seen with it, and
        ``(relation, tail)`` to the set of heads seen with it.
    """
    tails_of: Dict[Tuple[int, int], Set[int]] = collections.defaultdict(set)
    heads_of: Dict[Tuple[int, int], Set[int]] = collections.defaultdict(set)
    # tolist() converts once to plain Python ints, which are hashable dict keys.
    for head, rel, tail in all_ids.tolist():
        tails_of[head, rel].add(tail)
        heads_of[rel, tail].add(head)
    return tails_of, heads_of

# Filtering uses every known triple: train, dev and test stacked row-wise.
hr2t, rt2h = build_filter_dict(
    torch.cat((train_ids, dev_ids, test_ids), dim=0)
)

In [16]:
# === 3. 数据集与负采样（随机打头或打尾，二选一） ===
class KGDataset(Dataset):
    """Dataset wrapper around a tensor of triples; item ``i`` is row ``i``."""

    def __init__(self, triples: torch.Tensor):
        self.data = triples

    def __len__(self):
        # Number of rows along the first dimension.
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx]

def negative_sample(batch_pos: torch.Tensor, n_ent: int, num_negs=NEG_K, device=DEVICE):
    """Corrupt each positive triple ``num_negs`` times.

    One coin flip per positive decides which side is corrupted, so all
    negatives of a given positive replace the head, or all replace the tail.
    Replacement entities are drawn uniformly from ``[0, n_ent)`` and are not
    filtered, so a negative may coincide with a true triple.

    Args:
        batch_pos: ``[B, 3]`` tensor of ``(head, relation, tail)`` ids.
        n_ent: number of entities to sample from.
        num_negs: negatives generated per positive.
        device: device on which sampling happens and the result lives.

    Returns:
        ``[B * num_negs, 3]`` tensor on ``device``; the negatives of
        positive ``i`` occupy rows ``i * num_negs`` to ``(i + 1) * num_negs - 1``.
    """
    # The mask and the random entities are created on `device`; move the
    # positives there as well so the indexed assignments below never mix
    # devices. This is a no-op when batch_pos is already on `device`.
    batch_pos = batch_pos.to(device)
    B = batch_pos.size(0)
    corrupt_head = torch.rand(B, device=device) < 0.5
    # repeat() already materialises a fresh tensor, so no extra clone is needed.
    neg = batch_pos.unsqueeze(1).repeat(1, num_negs, 1)
    rand_ents = torch.randint(0, n_ent, (B, num_negs), device=device)
    neg[corrupt_head, :, 0] = rand_ents[corrupt_head]      # replace heads
    neg[~corrupt_head, :, 2] = rand_ents[~corrupt_head]    # replace tails
    return neg.view(-1, 3)

In [17]:
# === 4. 模型：TransE ===
class TransE(nn.Module):
    """TransE embedding model: a triple is scored by ``-||h + r - t||``.

    Args:
        n_ent: number of entities.
        n_rel: number of relations.
        dim: embedding dimension.
        p: ``1`` selects the L1 distance; any other value selects L2.
        margin: stored on the model for use by the training loss.
    """

    def __init__(self, n_ent, n_rel, dim=200, p=P_NORM, margin=MARGIN):
        super().__init__()
        self.ent = nn.Embedding(n_ent, dim)
        self.rel = nn.Embedding(n_rel, dim)
        self.p = p
        self.margin = margin
        # Uniform init in [-6/sqrt(dim), 6/sqrt(dim)], entities first.
        bound = 6.0 / math.sqrt(dim)
        for table in (self.ent, self.rel):
            nn.init.uniform_(table.weight, -bound, bound)
        # Relation vectors start with unit L2 norm.
        self.rel.weight.data = nn.functional.normalize(self.rel.weight.data, p=2, dim=1)

    def score(self, h, r, t):
        """Return the negated distance for id tensors ``h, r, t`` (higher is better)."""
        residual = self.ent(h) + self.rel(r) - self.ent(t)
        if self.p == 1:
            return -residual.abs().sum(-1)
        return -torch.linalg.vector_norm(residual, ord=2, dim=-1)

# Build the model on the target device, then the optimizer over its parameters.
model = TransE(n_ent, n_rel, dim=EMB_DIM, p=P_NORM, margin=MARGIN)
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [18]:
# === 5. 训练循环（Margin Ranking Loss） ===
def train_one_epoch(model, loader):
    """Run one pass over ``loader`` and return the mean batch loss.

    For each batch of positives, negatives are drawn with
    ``negative_sample``, the negative scores of each positive are averaged,
    and the hinge ``relu(margin + neg - pos)`` is minimised. Updates go
    through the module-level ``optimizer``.
    """
    model.train()
    running = 0.0
    for pos in tqdm(loader, desc="train"):
        pos = pos.to(DEVICE)
        neg = negative_sample(pos, n_ent, NEG_K, DEVICE)

        # Score positives and negatives column-wise: (head, relation, tail).
        pos_s = model.score(*pos.unbind(dim=1))
        neg_s = model.score(*neg.unbind(dim=1))
        # One row per positive, then average over its negatives.
        neg_s = neg_s.view(pos.size(0), -1).mean(-1)

        # Margin ranking loss: max(0, margin + neg - pos).
        loss = torch.relu(model.margin + neg_s - pos_s).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running += loss.item()
    return running / len(loader)

In [19]:
# === 6. 评测：Filtered MRR / Hits@K（分块计算，防爆显存） ===
@torch.no_grad()
def evaluate_filtered(model, triples: torch.Tensor, hr2t, rt2h,
                      k_list=(1,3,10), block=EVAL_BLOCK, device=DEVICE):
    """Filtered link-prediction evaluation (MRR and Hits@k).

    Every triple is ranked twice: once predicting the tail from
    ``(head, relation)`` and once predicting the head from
    ``(relation, tail)``. The rank of the gold entity is one plus the number
    of entities, over the whole entity table, that score strictly higher
    than it. Other entities known to be correct answers are excluded from
    that count (the "filtered" setting). Ties are resolved in favour of the
    gold entity because the comparison is strict.

    Args:
        model: object exposing ``ent`` and ``rel`` embedding tables. Its
            ``p`` attribute selects the distance (``1`` -> L1, otherwise L2);
            the module-level ``P_NORM`` is used if the attribute is missing.
        triples: ``[N, 3]`` tensor of ``(head, relation, tail)`` ids.
        hr2t: mapping ``(head, relation) -> set of known tails``.
        rt2h: mapping ``(relation, tail) -> set of known heads``.
        k_list: cut-offs for Hits@k.
        block: number of candidate entities scored per chunk. Each chunk
            materialises a ``[batch, block, dim]`` difference tensor, so this
            bounds peak memory.
        device: device on which scoring is performed.

    Returns:
        ``(mrr, hits)`` where ``hits`` maps each ``k`` to Hits@k, averaged
        over head and tail predictions together. If ``triples`` is empty,
        both are NaN.

    Raises:
        ValueError: if ``block`` is not positive.
    """
    if block <= 0:
        raise ValueError("block must be a positive integer")

    model.eval()
    ent_emb = model.ent.weight.to(device)
    rel_emb = model.rel.weight.to(device)
    num_entities = ent_emb.size(0)
    # Use the model's own norm so evaluation matches how it scores triples.
    p_norm = getattr(model, "p", P_NORM)

    def _neg_dist(diff):
        # Negated distance along the last dim; larger means a better match.
        if p_norm == 1:
            return -diff.abs().sum(-1)
        return -torch.linalg.vector_norm(diff, ord=2, dim=-1)

    def _filtered_rank(query, gold, known):
        # Rank of gold[i] among all entities for query[i], ignoring known[i].
        n_query = query.size(0)
        gold_score = _neg_dist(query - ent_emb[gold]).unsqueeze(1)  # [B, 1]

        # (row, entity) pairs that must not count against the gold entity:
        # the gold itself plus every other known-true answer. Built once per
        # batch and reused for every chunk.
        rows = list(range(n_query))
        cols = gold.tolist()
        for i, answers in enumerate(known):
            rows.extend([i] * len(answers))
            cols.extend(answers)
        rows = torch.tensor(rows, dtype=torch.long, device=device)
        cols = torch.tensor(cols, dtype=torch.long, device=device)

        # The count must be accumulated over every chunk: comparing only
        # within the chunk that holds the gold entity would ignore most
        # candidates and understate the rank.
        better = torch.zeros(n_query, dtype=torch.long, device=device)
        for start in range(0, num_entities, block):
            end = min(start + block, num_entities)
            cand = ent_emb[start:end]                                   # [M, D]
            scores = _neg_dist(query.unsqueeze(1) - cand.unsqueeze(0))  # [B, M]
            here = (cols >= start) & (cols < end)
            scores[rows[here], cols[here] - start] = float('-inf')
            better += (scores > gold_score).sum(dim=1)
        return (better + 1).float()

    ranks_h = []; ranks_t = []
    eval_bs = 128  # evaluation batch size
    for i in tqdm(range(0, len(triples), eval_bs), desc="eval"):
        batch = triples[i:i + eval_bs].to(device)
        h, r, t = batch[:, 0], batch[:, 1], batch[:, 2]
        # Convert ids to Python ints once per batch for the dictionary lookups.
        h_ids, r_ids, t_ids = h.tolist(), r.tolist(), t.tolist()
        known_t = [hr2t.get(key, ()) for key in zip(h_ids, r_ids)]
        known_h = [rt2h.get(key, ()) for key in zip(r_ids, t_ids)]

        h_e, r_e, t_e = ent_emb[h], rel_emb[r], ent_emb[t]
        # Tail query: t should be near h + r. Head query: h should be near t - r.
        ranks_t.append(_filtered_rank(h_e + r_e, t, known_t))
        ranks_h.append(_filtered_rank(t_e - r_e, h, known_h))

    if not ranks_t:
        # Nothing to evaluate: the metrics are undefined.
        nan = float('nan')
        return nan, {k: nan for k in k_list}

    ranks = torch.cat(ranks_h + ranks_t, dim=0)
    mrr = (1.0 / ranks).mean().item()
    hits = {k: (ranks <= k).float().mean().item() for k in k_list}
    return mrr, hits

In [None]:
# === 7. 开始训练 + 早停，期间用 dev 做验证 ===
train_loader = DataLoader(
    KGDataset(train_ids),
    batch_size=BATCH,
    shuffle=True,
    drop_last=False,
)

# Early-stopping bookkeeping: best dev MRR so far, the weights that achieved
# it, and the number of consecutive epochs without improvement.
best_dev_mrr, best_state, no_improve = -1, None, 0

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    loss = train_one_epoch(model, train_loader)
    dev_mrr, dev_hits = evaluate_filtered(model, dev_ids, hr2t, rt2h, (1,3,10), block=EVAL_BLOCK)
    t1 = time.time()  # the reported time covers training plus dev evaluation
    print(f"[Epoch {epoch:02d}] loss={loss:.4f}  dev MRR={dev_mrr:.4f} "
          f"H@1={dev_hits[1]:.4f} H@3={dev_hits[3]:.4f} H@10={dev_hits[10]:.4f}  ({t1-t0:.1f}s)")

    if dev_mrr > best_dev_mrr:
        # New best: snapshot the weights on CPU and reset the patience counter.
        best_dev_mrr = dev_mrr
        best_state = {
            name: tensor.cpu().clone()
            for name, tensor in model.state_dict().items()
        }
        no_improve = 0
        continue

    no_improve += 1
    if no_improve >= EARLY_STOP_PATIENCE:
        print("Early stopping.")
        break

# Restore the best checkpoint (if any) before scoring the test split.
if best_state is not None:
    model.load_state_dict(best_state)

test_mrr, test_hits = evaluate_filtered(model, test_ids, hr2t, rt2h, (1,3,10), block=EVAL_BLOCK)
print(f"\n[TEST] MRR={test_mrr:.4f}  Hits@1={test_hits[1]:.4f}  Hits@3={test_hits[3]:.4f}  Hits@10={test_hits[10]:.4f}")