# TransE的pytorch实现

TransE不是神经网络模型，它无法理解句子的含义，因而得分不高，随便跑跑mrr大概在42左右，调下参大概有45-50。

TransE的优势在于不吃算力，仅靠CPU就能运行，这里我写的GPU版本速度更快。而且代码简单。

又试了一下，我已经复现不出自己的得分了，太玄学了，大概策略就是先norm取1得到最优模型，加载该模型并norm改为2训练得到最优模型，再加载该模型norm取3得到最优模型，再反复调整后最优分数就是50左右；
在算力允许的情况下，batchsize尽可能大，比如10万-150万（就是全部）；然后就是embedding维度我感觉影响不大，100维包含的信息足够丰富了，试过256和512，貌似大一点有时要好一些，不太确定，太大了会比较费算力；学习率粗调可以在0.01-0.001之间，微调就在0.001-0.0001之间吧。

提供给大家学习，代码也可以在我的github下载：https://github.com/renqi1/TransE_Pytorch_OpenBG500


In [1]:
import torch
from torch import nn
from torch.utils import data
import numpy as np
import tqdm

## 构建数据集

In [2]:
# 训练集和验证集
class TripleDataset(data.Dataset):
    def __init__(self, ent2id, rel2id, triple_data_list):
        self.ent2id = ent2id
        self.rel2id = rel2id
        self.data = triple_data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        head, relation, tail = self.data[index]
        head_id = self.ent2id[head]
        relation_id = self.rel2id[relation]
        tail_id = self.ent2id[tail]
        return head_id, relation_id, tail_id

# 测试集    
class TestDataset(data.Dataset):
    def __init__(self, ent2id, rel2id, test_data_list):
        self.ent2id = ent2id
        self.rel2id = rel2id
        self.data = test_data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        head, relation = self.data[index]
        head_id = self.ent2id[head]
        relation_id = self.rel2id[relation]
        return head_id, relation_id

## TransE模型

In [3]:
class TransE(nn.Module):

    def __init__(self, entity_num, relation_num, norm=1, dim=100):
        super(TransE, self).__init__()
        self.norm = norm
        self.dim = dim
        self.entity_num = entity_num
        self.entities_emb = self._init_emb(entity_num)
        self.relations_emb = self._init_emb(relation_num)

    def _init_emb(self, num_embeddings):
        embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=self.dim)
        uniform_range = 6 / np.sqrt(self.dim)
        embedding.weight.data.uniform_(-uniform_range, uniform_range)
        embedding.weight.data = torch.div(embedding.weight.data, embedding.weight.data.norm(p=2, dim=1, keepdim=True))
        return embedding

    def forward(self, positive_triplets: torch.LongTensor, negative_triplets: torch.LongTensor):
        positive_distances = self._distance(positive_triplets.cuda())
        negative_distances = self._distance(negative_triplets.cuda())
        return positive_distances, negative_distances

    def _distance(self, triplets):
        heads = self.entities_emb(triplets[:, 0])
        relations = self.relations_emb(triplets[:, 1])
        tails = self.entities_emb(triplets[:, 2])
        return (heads + relations - tails).norm(p=self.norm, dim=1)

    def link_predict(self, head, relation, tail=None, k=10):
        # h_add_r: [batch size, embed size] -> [batch size, 1, embed size] -> [batch size, entity num, embed size]
        h_add_r = self.entities_emb(head) + self.relations_emb(relation)
        h_add_r = torch.unsqueeze(h_add_r, dim=1)
        h_add_r = h_add_r.expand(h_add_r.shape[0], self.entity_num, self.dim)
        # embed_tail: [batch size, embed size] -> [batch size, entity num, embed size]
        embed_tail = self.entities_emb.weight.data.expand(h_add_r.shape[0], self.entity_num, self.dim)
        # values: [batch size, k] scores, the smaller, the better
        # indices: [batch size, k] indices of entities ranked by scores
        values, indices = torch.topk(torch.norm(h_add_r - embed_tail, dim=2), k=self.entity_num, dim=1, largest=False)
        if tail is not None:
            tail = tail.view(-1, 1)
            rank_num = torch.eq(indices, tail).nonzero().permute(1, 0)[1]+1
            rank_num[rank_num > 9] = 10000
            mrr = torch.sum(1/rank_num)
            hits_1_num = torch.sum(torch.eq(indices[:, :1], tail)).item()
            hits_3_num = torch.sum(torch.eq(indices[:, :3], tail)).item()
            hits_10_num = torch.sum(torch.eq(indices[:, :10], tail)).item()
            return mrr, hits_1_num, hits_3_num, hits_10_num     # 返回一个batchsize, mrr的和，hit@k的和
        return indices[:, :k]

    def evaluate(self, data_loader, dev_num=5000.0):
        mrr_sum = hits_1_nums = hits_3_nums = hits_10_nums = 0
        for heads, relations, tails in tqdm.tqdm(data_loader):
            mrr_sum_batch, hits_1_num, hits_3_num, hits_10_num = self.link_predict(heads.cuda(), relations.cuda(), tails.cuda())
            mrr_sum += mrr_sum_batch
            hits_1_nums += hits_1_num
            hits_3_nums += hits_3_num
            hits_10_nums += hits_10_num
        return mrr_sum/dev_num, hits_1_nums/dev_num, hits_3_nums/dev_num, hits_10_nums/dev_num

## 设置参数

In [4]:
# batchsize增大，得分略有上升
train_batch_size = 100000
dev_batch_size = 20  # 显存不够就调小
test_batch_size = 20
epochs = 40
margin = 1
print_frequency = 5  # 每多少step输出一次信息
validation = True  # 是否验证，验证比较费时
dev_interval = 5  # 每多少轮验证一次，微调设小一点，会保存最佳权重
best_mrr = 0
learning_rate = 0.001  # 学习率建议粗调0.01-0.001，精调0.001-0.0001
distance_norm = 3  # 论文是L1距离效果不好，取2或3效果好
embedding_dim = 100  # 维度增大可能会有提升，我感觉没用，100维包含的信息足够丰富

## 加载数据集

In [5]:
with open('OpenBG500_entity2text.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    lines = [line.strip('\n').split('\t') for line in dat]
ent2id = {line[0]: i for i, line in enumerate(lines)}
id2ent = {i: line[0] for i, line in enumerate(lines)}
with open('OpenBG500_relation2text.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    lines = [line.strip().split('\t') for line in dat]
rel2id = {line[0]: i for i, line in enumerate(lines)}
with open('OpenBG500_train.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    train = [line.strip('\n').split('\t') for line in dat]
with open('OpenBG500_dev.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    dev = [line.strip('\n').split('\t') for line in dat]
with open('OpenBG500_test.tsv', 'r', encoding='utf-8') as fp:
    test = fp.readlines()
    test = [line.strip('\n').split('\t') for line in test]
# 构建数据集
train_dataset = TripleDataset(ent2id, rel2id, train)
dev_dataset = TripleDataset(ent2id, rel2id, dev)
train_data_loader = data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
dev_data_loader = data.DataLoader(dev_dataset, batch_size=dev_batch_size)
test_dataset = TestDataset(ent2id, rel2id, test)
test_data_loader = data.DataLoader(test_dataset, batch_size=test_batch_size)

## 训练和验证

In [6]:
# 构建模型
model = TransE(len(ent2id), len(rel2id), norm=distance_norm, dim=embedding_dim).cuda()
# model.load_state_dict(torch.load('transE_best.pth'))
# 优化器adam
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# 损失函数， 对于本例，loss=max(0, (pd-nd)+1)， 负样本距离越小，正样本距离越大越好
criterion = nn.MarginRankingLoss(margin=margin, reduction='mean')
# 开始训练
print('start training...')
for epoch in range(epochs):
    all_loss = 0
    for i, (local_heads, local_relations, local_tails) in enumerate(train_data_loader):

        positive_triples = torch.stack((local_heads, local_relations, local_tails), dim=1).cuda()

        # 生成负样本
        head_or_tail = torch.randint(high=2, size=local_heads.size())
        random_entities = torch.randint(high=len(ent2id), size=local_heads.size())
        broken_heads = torch.where(head_or_tail == 1, random_entities, local_heads)
        broken_tails = torch.where(head_or_tail == 0, random_entities, local_tails)
        negative_triples = torch.stack((broken_heads, local_relations, broken_tails), dim=1).cuda()

        # # 生成负样本, 只打乱tail
        # random_entities = torch.randint(high=len(ent2id), size=local_heads.size())
        # negative_triples = torch.stack((random_entities, local_relations, random_entities), dim=1).cuda()

        optimizer.zero_grad()
        pd, nd = model(positive_triples, negative_triples)
        # pd要尽可能小， nd要尽可能大
        loss = criterion(pd, nd, torch.tensor([-1], dtype=torch.long).cuda())
        loss.backward()
        all_loss += loss.data
        optimizer.step()
        if i % print_frequency == 0:
            print(
                f"epoch:{epoch}/{epochs}, step:{i}/{len(train_data_loader)}, loss={loss.item()}, avg_loss={all_loss / (i + 1)}")
    print(f"epoch:{epoch}/{epochs}, all_loss={all_loss}")

    # 验证
    if validation and (epoch + 1) % dev_interval == 0:
        print('testing...')
        improve = ''
        mrr, hits1, hits3, hits10 = model.evaluate(dev_data_loader)
        if mrr >= best_mrr:
            best_mrr = mrr
            improve = '*'
            torch.save(model.state_dict(), 'transE_best.pth')
        torch.save(model.state_dict(), 'transE_latest.pth')
        print(f'mrr: {mrr}, hit@1: {hits1}, hit@3: {hits3}, hit@10: {hits10}  {improve}')
    if not validation:
        torch.save(model.state_dict(), 'transE_latest.pth')

start training...
epoch:0/40, step:0/13, loss=1.0004600286483765, avg_loss=1.0004600286483765
epoch:0/40, step:5/13, loss=0.9924210906028748, avg_loss=0.9964226484298706
epoch:0/40, step:10/13, loss=0.984215259552002, avg_loss=0.9923558831214905
epoch:0/40, all_loss=12.878626823425293
epoch:1/40, step:0/13, loss=0.9761244058609009, avg_loss=0.9761244058609009
epoch:1/40, step:5/13, loss=0.9668070077896118, avg_loss=0.9715768694877625
epoch:1/40, step:10/13, loss=0.9587149620056152, avg_loss=0.9672104716300964
epoch:1/40, all_loss=12.5520658493042
epoch:2/40, step:0/13, loss=0.9526633620262146, avg_loss=0.9526633620262146
epoch:2/40, step:5/13, loss=0.9433406591415405, avg_loss=0.94773268699646
epoch:2/40, step:10/13, loss=0.9345195293426514, avg_loss=0.943464994430542
epoch:2/40, all_loss=12.2422513961792
epoch:3/40, step:0/13, loss=0.9284957647323608, avg_loss=0.9284957647323608
epoch:3/40, step:5/13, loss=0.9203886985778809, avg_loss=0.9242886304855347
epoch:3/40, step:10/13, loss=0.

  0%|          | 1/250 [00:00<00:46,  5.30it/s]

epoch:4/40, all_loss=11.642279624938965
testing...


100%|██████████| 250/250 [00:06<00:00, 39.07it/s]


mrr: 0.026399999856948853, hit@1: 0.0264, hit@3: 0.386, hit@10: 0.4986  *
epoch:5/40, step:0/13, loss=0.8833775520324707, avg_loss=0.8833775520324707
epoch:5/40, step:5/13, loss=0.8748413920402527, avg_loss=0.8788034319877625
epoch:5/40, step:10/13, loss=0.8668171763420105, avg_loss=0.8747958540916443
epoch:5/40, all_loss=11.350102424621582
epoch:6/40, step:0/13, loss=0.8607047200202942, avg_loss=0.8607047200202942
epoch:6/40, step:5/13, loss=0.8524377942085266, avg_loss=0.8565347194671631
epoch:6/40, step:10/13, loss=0.8449050784111023, avg_loss=0.8524928092956543
epoch:6/40, all_loss=11.0615234375
epoch:7/40, step:0/13, loss=0.8379879593849182, avg_loss=0.8379879593849182
epoch:7/40, step:5/13, loss=0.8299741744995117, avg_loss=0.833824872970581
epoch:7/40, step:10/13, loss=0.8220023512840271, avg_loss=0.8298749923706055
epoch:7/40, all_loss=10.767749786376953
epoch:8/40, step:0/13, loss=0.8162221908569336, avg_loss=0.8162221908569336
epoch:8/40, step:5/13, loss=0.8068988919258118, a

  2%|▏         | 4/250 [00:00<00:06, 36.73it/s]

epoch:9/40, all_loss=10.168965339660645
testing...


100%|██████████| 250/250 [00:06<00:00, 40.03it/s]


mrr: 0.31839999556541443, hit@1: 0.3184, hit@3: 0.5502, hit@10: 0.7048  *
epoch:10/40, step:0/13, loss=0.7709746956825256, avg_loss=0.7709746956825256
epoch:10/40, step:5/13, loss=0.7604309320449829, avg_loss=0.765308141708374
epoch:10/40, step:10/13, loss=0.7523741722106934, avg_loss=0.7608752250671387
epoch:10/40, all_loss=9.868080139160156
epoch:11/40, step:0/13, loss=0.7459456324577332, avg_loss=0.7459456324577332
epoch:11/40, step:5/13, loss=0.7371506094932556, avg_loss=0.7412559390068054
epoch:11/40, step:10/13, loss=0.7281594276428223, avg_loss=0.7366960644721985
epoch:11/40, all_loss=9.551908493041992
epoch:12/40, step:0/13, loss=0.7223385572433472, avg_loss=0.7223385572433472
epoch:12/40, step:5/13, loss=0.712777316570282, avg_loss=0.7171534895896912
epoch:12/40, step:10/13, loss=0.7030938863754272, avg_loss=0.7125251889228821
epoch:12/40, all_loss=9.23984432220459
epoch:13/40, step:0/13, loss=0.6953514814376831, avg_loss=0.6953514814376831
epoch:13/40, step:5/13, loss=0.68605

  2%|▏         | 4/250 [00:00<00:06, 36.55it/s]

epoch:14/40, all_loss=8.590277671813965
testing...


100%|██████████| 250/250 [00:06<00:00, 40.10it/s]


mrr: 0.3951999843120575, hit@1: 0.3952, hit@3: 0.5966, hit@10: 0.7438  *
epoch:15/40, step:0/13, loss=0.6451003551483154, avg_loss=0.6451003551483154
epoch:15/40, step:5/13, loss=0.6357419490814209, avg_loss=0.641369104385376
epoch:15/40, step:10/13, loss=0.6278892159461975, avg_loss=0.6366351246833801
epoch:15/40, all_loss=8.252645492553711
epoch:16/40, step:0/13, loss=0.6209930181503296, avg_loss=0.6209930181503296
epoch:16/40, step:5/13, loss=0.6099186539649963, avg_loss=0.6163473129272461
epoch:16/40, step:10/13, loss=0.6017064452171326, avg_loss=0.6112903952598572
epoch:16/40, all_loss=7.92219877243042
epoch:17/40, step:0/13, loss=0.5952631831169128, avg_loss=0.5952631831169128
epoch:17/40, step:5/13, loss=0.5852000713348389, avg_loss=0.5905779004096985
epoch:17/40, step:10/13, loss=0.5738667249679565, avg_loss=0.585422158241272
epoch:17/40, all_loss=7.585738658905029
epoch:18/40, step:0/13, loss=0.5691049695014954, avg_loss=0.5691049695014954
epoch:18/40, step:5/13, loss=0.560958

  2%|▏         | 5/250 [00:00<00:06, 40.19it/s]

epoch:19/40, all_loss=6.947169303894043
testing...


100%|██████████| 250/250 [00:06<00:00, 39.86it/s]


mrr: 0.4195999801158905, hit@1: 0.4196, hit@3: 0.613, hit@10: 0.7562  *
epoch:20/40, step:0/13, loss=0.5202771425247192, avg_loss=0.5202771425247192
epoch:20/40, step:5/13, loss=0.5125833749771118, avg_loss=0.5165330767631531
epoch:20/40, step:10/13, loss=0.5044776797294617, avg_loss=0.51270592212677
epoch:20/40, all_loss=6.646373271942139
epoch:21/40, step:0/13, loss=0.4988667964935303, avg_loss=0.4988667964935303
epoch:21/40, step:5/13, loss=0.49106523394584656, avg_loss=0.49422797560691833
epoch:21/40, step:10/13, loss=0.4829302132129669, avg_loss=0.4908514618873596
epoch:21/40, all_loss=6.362898349761963
epoch:22/40, step:0/13, loss=0.47766733169555664, avg_loss=0.47766733169555664
epoch:22/40, step:5/13, loss=0.4700329303741455, avg_loss=0.4742520749568939
epoch:22/40, step:10/13, loss=0.4629727900028229, avg_loss=0.47101855278015137
epoch:22/40, all_loss=6.105311393737793
epoch:23/40, step:0/13, loss=0.4578325152397156, avg_loss=0.4578325152397156
epoch:23/40, step:5/13, loss=0.4

  2%|▏         | 4/250 [00:00<00:06, 36.51it/s]

epoch:24/40, all_loss=5.637955188751221
testing...


100%|██████████| 250/250 [00:06<00:00, 40.03it/s]


mrr: 0.42980000376701355, hit@1: 0.4298, hit@3: 0.6226, hit@10: 0.7646  *
epoch:25/40, step:0/13, loss=0.4240056276321411, avg_loss=0.4240056276321411
epoch:25/40, step:5/13, loss=0.4187479615211487, avg_loss=0.4217613935470581
epoch:25/40, step:10/13, loss=0.4119781255722046, avg_loss=0.41844654083251953
epoch:25/40, all_loss=5.423857688903809
epoch:26/40, step:0/13, loss=0.40720611810684204, avg_loss=0.40720611810684204
epoch:26/40, step:5/13, loss=0.40435922145843506, avg_loss=0.40535134077072144
epoch:26/40, step:10/13, loss=0.3975265622138977, avg_loss=0.40281590819358826
epoch:26/40, all_loss=5.228335857391357
epoch:27/40, step:0/13, loss=0.3926262855529785, avg_loss=0.3926262855529785
epoch:27/40, step:5/13, loss=0.38815367221832275, avg_loss=0.39130398631095886
epoch:27/40, step:10/13, loss=0.3843480348587036, avg_loss=0.38882681727409363
epoch:27/40, all_loss=5.042997360229492
epoch:28/40, step:0/13, loss=0.3785274922847748, avg_loss=0.3785274922847748
epoch:28/40, step:5/13, 

  2%|▏         | 4/250 [00:00<00:06, 36.75it/s]

epoch:29/40, all_loss=4.700049877166748
testing...


100%|██████████| 250/250 [00:06<00:00, 40.12it/s]


mrr: 0.43199998140335083, hit@1: 0.432, hit@3: 0.6212, hit@10: 0.7682  *
epoch:30/40, step:0/13, loss=0.3555942177772522, avg_loss=0.3555942177772522
epoch:30/40, step:5/13, loss=0.351566344499588, avg_loss=0.35308966040611267
epoch:30/40, step:10/13, loss=0.34627529978752136, avg_loss=0.3509579598903656
epoch:30/40, all_loss=4.551311492919922
epoch:31/40, step:0/13, loss=0.3436848819255829, avg_loss=0.3436848819255829
epoch:31/40, step:5/13, loss=0.3413189649581909, avg_loss=0.34217900037765503
epoch:31/40, step:10/13, loss=0.33666178584098816, avg_loss=0.340436190366745
epoch:31/40, all_loss=4.416213035583496
epoch:32/40, step:0/13, loss=0.3317122459411621, avg_loss=0.3317122459411621
epoch:32/40, step:5/13, loss=0.33024853467941284, avg_loss=0.33117160201072693
epoch:32/40, step:10/13, loss=0.325892835855484, avg_loss=0.32948869466781616
epoch:32/40, all_loss=4.27821683883667
epoch:33/40, step:0/13, loss=0.32548654079437256, avg_loss=0.32548654079437256
epoch:33/40, step:5/13, loss=

  2%|▏         | 4/250 [00:00<00:06, 36.87it/s]

epoch:34/40, all_loss=4.0526909828186035
testing...


100%|██████████| 250/250 [00:06<00:00, 40.10it/s]


mrr: 0.4235999882221222, hit@1: 0.4236, hit@3: 0.6152, hit@10: 0.7704  
epoch:35/40, step:0/13, loss=0.3064205050468445, avg_loss=0.3064205050468445
epoch:35/40, step:5/13, loss=0.3014841079711914, avg_loss=0.3046500086784363
epoch:35/40, step:10/13, loss=0.30210477113723755, avg_loss=0.30372706055641174
epoch:35/40, all_loss=3.9455783367156982
epoch:36/40, step:0/13, loss=0.29986220598220825, avg_loss=0.29986220598220825
epoch:36/40, step:5/13, loss=0.2972918450832367, avg_loss=0.2990383207798004
epoch:36/40, step:10/13, loss=0.2935571074485779, avg_loss=0.2972831428050995
epoch:36/40, all_loss=3.8579840660095215
epoch:37/40, step:0/13, loss=0.29156604409217834, avg_loss=0.29156604409217834
epoch:37/40, step:5/13, loss=0.2910342216491699, avg_loss=0.2913801372051239
epoch:37/40, step:10/13, loss=0.2894535958766937, avg_loss=0.2902997136116028
epoch:37/40, all_loss=3.7649052143096924
epoch:38/40, step:0/13, loss=0.2863129675388336, avg_loss=0.2863129675388336
epoch:38/40, step:5/13, lo

  2%|▏         | 5/250 [00:00<00:06, 40.53it/s]

epoch:39/40, all_loss=3.5971519947052
testing...


100%|██████████| 250/250 [00:06<00:00, 40.31it/s]


mrr: 0.4203999936580658, hit@1: 0.4204, hit@3: 0.6144, hit@10: 0.7728  


## 预测

In [7]:
predict_all = []
model.load_state_dict(torch.load('transE_best.pth'))
for heads, relations in tqdm.tqdm(test_data_loader):
    # 预测的id,结果为tensor(batch_size*10)
    predict_id = model.link_predict(heads.cuda(), relations.cuda())
    # 结果取到cpu并转为一行的list以便迭代
    predict_list = predict_id.cpu().numpy().reshape(1,-1).squeeze(0).tolist()
    # id转为实体
    predict_ent = map(lambda x: id2ent[x], predict_list)
    # 保存结果
    predict_all.extend(predict_ent)
print('prediction finished !')

100%|██████████| 250/250 [00:05<00:00, 42.09it/s]

prediction finished !





## 写入文件并保存

In [8]:
# 写入文件，按提交要求
with open('submission.tsv', 'w', encoding='utf-8') as f:
    for i in range(len(test)):
        # 直接writelines没有空格分隔，手工加分割符，得按提交格式来
        list = [x + '\t' for x in test[i]] + [x + '\n' if i == 9 else x + '\t' for i, x in enumerate(predict_all[i*10:i*10+10])]
        f.writelines(list)
print('file saved !')

file saved !
