### 一.原理

链接预测是预测边的存在性，注意它与边预测任务很不同，边预测是去预测已存在的边的属性。但通过**负采样**的技巧，我们可以将链接预测的问题转换为边预测的问题，可以看作：   

$$
链接预测=负采样+边预测
$$  

我们可以这样理解，如果两节点间存在边，那么我们定义该边上的属性为1，如果不存在边那么我们定义该边上的属性为0，所以我们将链接预测问题就转换为了边上的1/0预测问题，如果俩节点上的预测值靠近1，我们就可以认为它们之间存在一条边，如果预测值靠近0，就认为它们之间不存在边。但是“不存在的边”的量往往很大，这需要考虑任意两两之间的连接，所以我们采用负采样，从所有不存在的边中随机采样部分出来训练，如下示例图：   

![avatar](./pic/链接预测.jpg)   

对于上面的假设，我们可以类似于logistic任务，使用交叉熵损失函数：   
$$
L=-log\sigma(y_{u,v})-\sum\{[1-log(y_{u,k})]\mid k\in P(u)\}
$$  

这里，$u,v$是存在连接的点，$P(u)$是对$u$的负采样点的集合，$y_{u,v}$类似于上一节输入两向量，输出一个标量的函数，比如做内积，而$\sigma(\cdot)$是sigmoid函数，将输出约束在(0,1)之间，除了交叉熵，我们还可以选择其他函数，[参考>>](https://docs.dgl.ai/guide_cn/training-link.html)

### 二.实现

这里需要实现的内容其实相比上一节主要多了两部分内容：   
（1）第一部分是多了负采样；   
（2）另一部分是需要修改损失函数的定义

In [1]:
import numpy as np
import torch
import dgl
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

Using backend: pytorch


In [2]:
#1.生成异构图
n_users = 1000
n_items = 500
n_follows = 3000
n_clicks = 5000
n_dislikes = 500
n_hetero_features = 10
n_user_classes = 5
n_max_clicks = 10

follow_src = np.random.randint(0, n_users, n_follows)
follow_dst = np.random.randint(0, n_users, n_follows)
click_src = np.random.randint(0, n_users, n_clicks)
click_dst = np.random.randint(0, n_items, n_clicks)
dislike_src = np.random.randint(0, n_users, n_dislikes)
dislike_dst = np.random.randint(0, n_items, n_dislikes)

hetero_graph = dgl.heterograph({
    ('user', 'follow', 'user'): (follow_src, follow_dst),
    ('user', 'followed-by', 'user'): (follow_dst, follow_src),
    ('user', 'click', 'item'): (click_src, click_dst),
    ('item', 'clicked-by', 'user'): (click_dst, click_src),
    ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
    ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})

hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)
hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)
hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))
hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()
# 在user类型的节点和click类型的边上随机生成训练集的掩码
hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)
hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)

In [3]:
#2.定义模型
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        # 实例化HeteroGraphConv，in_feats是输入特征的维度，out_feats是输出特征的维度，aggregate是聚合函数的类型
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # 输入是节点的特征字典
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h


class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        # h是从5.1节中对异构图的每种类型的边所计算的节点表示
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']
        
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, etype):
        h = self.sage(g, x)
        return self.pred(g, h, etype), self.pred(neg_g, h, etype)

In [4]:
#3.定义负采样函数，将负样本采样为另外一张图
def construct_negative_graph(graph, k, etype):
    utype, _, vtype = etype
    src, dst = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})

In [5]:
#4.定义损失函数
def compute_loss(pos_score, neg_score):
    # 间隔损失
    n_edges = pos_score.shape[0]
    return (1 - pos_score.unsqueeze(1) + neg_score.view(n_edges, -1)).clamp(min=0).mean()

In [6]:
#5.训练模型
k = 5
model = Model(10, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
node_features = {'user': user_feats, 'item': item_feats}
opt = torch.optim.Adam(model.parameters())
for epoch in range(5):
    negative_graph = construct_negative_graph(hetero_graph, k, ('user', 'click', 'item'))#这里只对"click"关系进行预测
    pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ('user', 'click', 'item'))
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

1.327864170074463
1.3069148063659668
1.277003288269043
1.2585861682891846
1.246128797531128
