In [1]:
import os, math, random, collections, time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
try:
    from tqdm import tqdm
except ImportError:
    # tqdm is optional: fall back to a pass-through wrapper (no progress bar).
    tqdm = lambda x, **k: x

SEED = 42
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Dataset location. Set the OPENBG500_DATA_DIR environment variable to point at the data on
# another machine instead of editing this machine-specific default.
DATA_DIR = os.environ.get(
    "OPENBG500_DATA_DIR",
    r"D:\workspace\***\小论文-知识图谱\data\OpenBG500",
)
TRAIN = os.path.join(DATA_DIR, "OpenBG500_train.tsv")
DEV   = os.path.join(DATA_DIR, "OpenBG500_dev.tsv")
TEST  = os.path.join(DATA_DIR, "OpenBG500_answer.tsv")

EMB_DIM = 200       # complex embedding size: real and imaginary tables each have EMB_DIM columns
MARGIN  = 6.0       # hinge margin of the ranking loss
LR      = 1e-3      # Adam learning rate
BATCH   = 1024      # positive triples per training batch
NEG_K   = 25        # corrupted triples generated per positive
EPOCHS  = 30        # upper bound on training epochs (early stopping may end sooner)
EVAL_BLOCK = 20000  # candidate entities scored per chunk during evaluation

# Seed every RNG the notebook uses (Python, torch CPU, all CUDA devices).
random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)


In [2]:
def read_triples(path: str):
    """Load (head, relation, tail) string triples from a UTF-8 tab-separated file.

    Each line is split on tabs after removing its trailing newline. Lines with fewer than
    three fields are skipped; fields beyond the third are ignored.

    Args:
        path: path of the TSV file.

    Returns:
        List of (head, relation, tail) string tuples, in file order.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        split_lines = (raw.rstrip('\n').split('\t') for raw in handle)
        return [(cols[0], cols[1], cols[2]) for cols in split_lines if len(cols) >= 3]

train_triples = read_triples(TRAIN)
dev_triples   = read_triples(DEV)
test_triples  = read_triples(TEST)

# Vocabulary spans all three splits so every dev/test id has an embedding row.
ents = set()
rels = set()
for split in (train_triples, dev_triples, test_triples):
    for h, r, t in split:
        ents.update((h, t))
        rels.add(r)
# Ids follow sorted string order, so the assignment is deterministic across runs.
ent2id = dict(zip(sorted(ents), range(len(ents))))
rel2id = dict(zip(sorted(rels), range(len(rels))))


def encode(triples):
    """Convert (head, relation, tail) string triples into a long tensor of ids, one row per triple."""
    id_rows = [(ent2id[h], rel2id[r], ent2id[t]) for h, r, t in triples]
    return torch.tensor(id_rows, dtype=torch.long)


train_ids, dev_ids, test_ids = encode(train_triples), encode(dev_triples), encode(test_triples)

n_ent, n_rel = len(ent2id), len(rel2id)
print("Entities:", n_ent, "Relations:", n_rel, "Train:", len(train_ids), "Dev:", len(dev_ids), "Test:", len(test_ids))

Entities: 249746 Relations: 500 Train: 1242550 Dev: 5000 Test: 5000


In [3]:
def build_filter_dict(all_ids: torch.Tensor):
    """Index every known triple for filtered-rank evaluation.

    Args:
        all_ids: tensor whose rows are (head, relation, tail) id triples.

    Returns:
        (hr2t, rt2h): two ``defaultdict(set)`` maps where ``hr2t[(h, r)]`` holds every tail
        and ``rt2h[(r, t)]`` every head seen together in ``all_ids``.
    """
    tails_by_hr = collections.defaultdict(set)
    heads_by_rt = collections.defaultdict(set)
    for row in all_ids.tolist():
        head, rel, tail = row
        tails_by_hr[(head, rel)].add(tail)
        heads_by_rt[(rel, tail)].add(head)
    return tails_by_hr, heads_by_rt

# Filter index over every known triple (train + dev + test) for the "filtered" ranking protocol.
hr2t, rt2h = build_filter_dict(
    torch.vstack((train_ids, dev_ids, test_ids))
)

In [4]:
class KGDataset(Dataset):
    """Map-style dataset over a tensor of id triples; item ``i`` is row ``i`` of the tensor."""

    def __init__(self, triples_tensor: torch.Tensor):
        self.data = triples_tensor

    def __len__(self):
        return int(self.data.shape[0])

    def __getitem__(self, idx):
        return self.data[idx]

def negative_sample(batch_pos, n_ent, num_negs=NEG_K, device=DEVICE):
    """Create ``num_negs`` corruptions of each positive triple.

    For every row a single coin flip (shared by all of that row's negatives) decides whether
    the head (column 0) or the tail (column 2) is replaced by a uniformly random entity id
    in ``[0, n_ent)``. Corruptions are not filtered, so one may coincide with a true triple.

    Args:
        batch_pos: [B, 3] id triples.
        n_ent: number of entities to sample replacements from.
        num_negs: corruptions per positive.
        device: device on which the random draws are made.

    Returns:
        [B * num_negs, 3] tensor; rows ``b*num_negs`` .. ``(b+1)*num_negs - 1`` corrupt
        ``batch_pos[b]``.
    """
    batch_size = batch_pos.size(0)
    swap_head = torch.rand(batch_size, device=device) < 0.5
    # repeat() already allocates fresh storage, so batch_pos itself is never modified.
    corrupted = batch_pos.unsqueeze(1).repeat(1, num_negs, 1)
    replacement = torch.randint(0, n_ent, (batch_size, num_negs), device=device)
    swap_tail = swap_head.logical_not()
    corrupted[swap_head, :, 0] = replacement[swap_head]
    corrupted[swap_tail, :, 2] = replacement[swap_tail]
    return corrupted.reshape(-1, 3)

In [5]:
class ComplEx(nn.Module):
    """ComplEx knowledge-graph embedding model.

    Entities and relations are complex vectors, each stored as a real-part and an
    imaginary-part ``nn.Embedding``. A triple (h, r, t) scores Re(sum(h * r * conj(t)));
    larger means more plausible.

    Args:
        n_ent: number of entities.
        n_rel: number of relations.
        dim: complex dimension (columns per embedding table).
        margin: stored as ``self.margin`` for the training loss; ``score`` does not use it.
    """

    def __init__(self, n_ent, n_rel, dim=200, margin=6.0):
        super().__init__()
        self.dim = dim
        self.margin = margin
        # All four tables are created first, then initialised, in a fixed order so that the
        # parameter order and the random stream are unchanged.
        self.ent_re = nn.Embedding(n_ent, dim)
        self.ent_im = nn.Embedding(n_ent, dim)
        self.rel_re = nn.Embedding(n_rel, dim)
        self.rel_im = nn.Embedding(n_rel, dim)
        init_bound = 0.1
        for table in (self.ent_re, self.ent_im, self.rel_re, self.rel_im):
            nn.init.uniform_(table.weight, -init_bound, init_bound)

    def score(self, h, r, t):
        """Return Re(<h, r, conj(t)>) per triple (higher is better) for 1-D id tensors h, r, t."""
        a, b = self.ent_re(h), self.ent_im(h)
        c, d = self.rel_re(r), self.rel_im(r)
        tail_re, tail_im = self.ent_re(t), self.ent_im(t)
        # h∘r: (a + ib)(c + id) = (ac - bd) + i(ad + bc)
        prod_re = a * c - b * d
        prod_im = a * d + b * c
        # Real part of (h∘r)·conj(t), summed over the embedding dimension.
        return torch.sum(prod_re * tail_re + prod_im * tail_im, dim=1)

# Build the model on the configured device and optimise all of its embedding tables with Adam.
model = ComplEx(n_ent, n_rel, dim=EMB_DIM, margin=MARGIN).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)


In [6]:
def train_one_epoch(model, loader, optimizer):
    """Run one training pass with a pairwise margin ranking loss.

    Each positive triple is paired with NEG_K corruptions from ``negative_sample``. The hinge
    ``max(0, margin + s(neg) - s(pos))`` is applied to every (positive, negative) pair and
    then averaged. The previous version hinged on the *mean* negative score: once the average
    cleared the margin, hard negatives hiding among many easy ones stopped producing gradient.

    Args:
        model: module exposing ``score(h, r, t)`` and a ``margin`` attribute.
        loader: iterable of [B, 3] id-triple batches.
        optimizer: optimizer over the model's parameters.

    Returns:
        Mean loss per batch, or 0.0 if the loader produced no batches.
    """
    model.train()
    total, n_batches = 0.0, 0
    for pos in tqdm(loader, desc="train"):
        pos = pos.to(DEVICE)
        neg = negative_sample(pos, n_ent, NEG_K, DEVICE)

        pos_s = model.score(pos[:, 0], pos[:, 1], pos[:, 2])                        # [B]
        neg_s = model.score(neg[:, 0], neg[:, 1], neg[:, 2]).view(pos.size(0), -1)  # [B, NEG_K]

        # Per-pair hinge: every negative must trail its own positive by at least `margin`.
        loss = torch.relu(model.margin + neg_s - pos_s.unsqueeze(1)).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
        n_batches += 1
    return total / n_batches if n_batches else 0.0

In [7]:
@torch.no_grad()
def evaluate_filtered(model, triples, hr2t, rt2h, k_list=(1,3,10), block=EVAL_BLOCK, device=DEVICE,
                      batch_size=64):
    """Filtered link-prediction evaluation (MRR, Hits@k) for a ComplEx model.

    Each triple (h, r, t) is ranked twice: tail prediction (h, r, ?) and head prediction
    (?, r, t). The rank is 1 plus the number of candidate entities, over the WHOLE entity set,
    whose score is strictly greater than the gold score. Candidates that form another known
    true triple (per ``hr2t`` / ``rt2h``) are excluded, and so is the gold entity itself.
    Strict ``>`` resolves ties in favour of the gold entity (optimistic ranking).

    The previous version had three problems:
      * it counted competitors only inside the entity chunk that holds the gold entity, which
        badly inflated the metrics;
      * it masked filtered entities with -1e-9 (about 0), which still outranks a negative gold
        score;
      * it round-tripped the embedding tables through host memory for every batch.

    Args:
        model: ComplEx module with ent_re / ent_im / rel_re / rel_im embeddings, on ``device``.
        triples: [N, 3] id triples to evaluate.
        hr2t: mapping (h, r) -> set of known true tails.
        rt2h: mapping (r, t) -> set of known true heads.
        k_list: cut-offs for Hits@k.
        block: number of candidate entities scored per chunk (bounds the [batch, block] matrix).
        device: device the scoring runs on.
        batch_size: number of query triples scored together.

    Returns:
        (mrr, hits): MRR and {k: Hits@k}, averaged over head and tail predictions;
        (0.0, {k: 0.0}) when ``triples`` is empty.

    Raises:
        ValueError: if ``block`` is not positive.
    """
    if block <= 0:
        raise ValueError(f"block must be positive, got {block}")
    model.eval()
    if len(triples) == 0:
        return 0.0, {k: 0.0 for k in k_list}

    # `.to(device)` is a no-op when the model already lives on `device`: no copy is made.
    ent_re = model.ent_re.weight.detach().to(device)
    ent_im = model.ent_im.weight.detach().to(device)
    n_cand = ent_re.size(0)

    def _exclusion_index(gold_ids, known_sets):
        """Flattened (row, entity) pairs removed from ranking: known answers plus the gold itself."""
        rows, cols = [], []
        for row, (gold, known) in enumerate(zip(gold_ids, known_sets)):
            excluded = set(known)
            excluded.add(gold)
            rows.extend([row] * len(excluded))
            cols.extend(excluded)
        return (torch.tensor(rows, dtype=torch.long, device=device),
                torch.tensor(cols, dtype=torch.long, device=device))

    def _filtered_ranks(q_re, q_im, gold_ids, known_sets):
        """Filtered 1-based rank of each gold entity when every entity e scores
        q_re·e_re + q_im·e_im, for [B, D] query vectors q_re / q_im."""
        gold = torch.tensor(gold_ids, dtype=torch.long, device=device)
        gold_score = (q_re * ent_re[gold] + q_im * ent_im[gold]).sum(dim=1)   # [B]
        ex_rows, ex_cols = _exclusion_index(gold_ids, known_sets)
        n_higher = torch.zeros(q_re.size(0), dtype=torch.long, device=device)
        # Accumulate competitors across ALL chunks, not just the chunk holding the gold entity.
        for start in range(0, n_cand, block):
            end = min(start + block, n_cand)
            scores = q_re @ ent_re[start:end].T + q_im @ ent_im[start:end].T   # [B, M]
            sel = (ex_cols >= start) & (ex_cols < end)
            # -inf (not a tiny constant) so excluded entities can never outrank the gold.
            scores[ex_rows[sel], ex_cols[sel] - start] = float('-inf')
            n_higher += (scores > gold_score.unsqueeze(1)).sum(dim=1)
        return n_higher.float() + 1

    all_ranks = []
    for i in tqdm(range(0, len(triples), batch_size), desc="eval"):
        batch = triples[i:i + batch_size].to(device)
        h, r, t = batch[:, 0], batch[:, 1], batch[:, 2]
        # Python ints for the filter-dict lookups (one device sync per batch, not per element).
        h_ids, r_ids, t_ids = h.tolist(), r.tolist(), t.tolist()

        h_re, h_im = model.ent_re(h), model.ent_im(h)
        r_re, r_im = model.rel_re(r), model.rel_im(r)
        t_re, t_im = model.ent_re(t), model.ent_im(t)

        # Tail prediction (h, r, ?): score(e) = Re((h∘r)·conj(e)) = hr_re·e_re + hr_im·e_im.
        hr_re = h_re * r_re - h_im * r_im
        hr_im = h_re * r_im + h_im * r_re
        known_tails = [hr2t.get((hh, rr), ()) for hh, rr in zip(h_ids, r_ids)]
        all_ranks.append(_filtered_ranks(hr_re, hr_im, t_ids, known_tails))

        # Head prediction (?, r, t): with w = r∘conj(t), score(e) = Re(e·w) = e_re·w_re - e_im·w_im.
        w_re = r_re * t_re + r_im * t_im
        w_im = r_im * t_re - r_re * t_im
        known_heads = [rt2h.get((rr, tt), ()) for rr, tt in zip(r_ids, t_ids)]
        all_ranks.append(_filtered_ranks(w_re, -w_im, h_ids, known_heads))

    ranks = torch.cat(all_ranks, dim=0)
    mrr = (1.0 / ranks).mean().item()
    hits = {k: (ranks <= k).float().mean().item() for k in k_list}
    return mrr, hits

In [None]:
train_loader = DataLoader(KGDataset(train_ids), batch_size=BATCH, shuffle=True)

# Early stopping on dev MRR: keep a CPU snapshot of the best weights and stop after
# PATIENCE consecutive epochs without improvement.
PATIENCE = 3
best_dev_mrr = -1
best_state = None
no_improve = 0

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()
    loss = train_one_epoch(model, train_loader, optimizer)
    dev_mrr, dev_hits = evaluate_filtered(
        model, dev_ids, hr2t, rt2h, (1, 3, 10), block=EVAL_BLOCK, device=DEVICE
    )
    dt = time.time() - t0
    print(f"[Epoch {epoch:02d}] loss={loss:.4f}  dev MRR={dev_mrr:.4f}  "
          + "  ".join(f"Hits@{k}={dev_hits[k]:.4f}" for k in (1, 3, 10))
          + f"  ({dt:.1f}s)")

    if not (dev_mrr > best_dev_mrr):
        no_improve += 1
        if no_improve >= PATIENCE:
            print("Early stopping.")
            break
        continue

    best_dev_mrr, no_improve = dev_mrr, 0
    best_state = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}

# Restore the best-on-dev weights (if any epoch improved) before testing.
if best_state is not None:
    model.load_state_dict(best_state)

In [None]:
# Final filtered evaluation on the test split, using whatever weights `model` currently holds
# (the best-on-dev snapshot when the training cell restored one).
test_mrr, test_hits = evaluate_filtered(
    model, test_ids, hr2t, rt2h, (1, 3, 10), block=EVAL_BLOCK, device=DEVICE
)
print(f"[TEST] MRR={test_mrr:.4f}  " + "  ".join(f"Hits@{k}={test_hits[k]:.4f}" for k in (1, 3, 10)))