

## GraphSAGE 核心思想

The core idea of GraphSAGE: rather than learning a separate embedding for every node in a graph, GraphSAGE learns a mapping that generates an embedding for each node.


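The method proposed in the paper is called GraphSAGE, where SAGE stands for Sample and Aggregate. Instead of training a separate embedding vector for each vertex, it trains a set of aggregator functions that learn how to aggregate feature information from a vertex's local neighborhood. Each aggregator function aggregates information from a different number of hops, or search depth, away from the vertex. At test or inference time, the trained system applies the learned aggregator functions to generate embeddings for entirely unseen vertices.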
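
As a sketch of what one such aggregator computes, the mean-aggregator variant can be written as follows. The symbols ($h_v^{k}$ for the representation of vertex $v$ at depth $k$, $\mathcal{N}(v)$ for its sampled neighborhood, $W^{k}$ for the weights at depth $k$, $\sigma$ for a nonlinearity) follow the usual notation of the GraphSAGE paper and are not defined in this courseware, so treat them as assumptions. The recursion starts from the input features, $h_v^{0} = x_v$:

```latex
h_{\mathcal{N}(v)}^{k} = \operatorname{mean}\left(\left\{ h_u^{k-1} : u \in \mathcal{N}(v) \right\}\right),
\qquad
h_v^{k} = \sigma\left( W^{k} \cdot \operatorname{concat}\left( h_v^{k-1},\; h_{\mathcal{N}(v)}^{k} \right) \right)
```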

![](https://img-blog.csdnimg.cn/img_convert/beebbdcaf3b468efee9b9ffdb530c3dd.png)
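The figure above shows how the embedding of the red target node is generated. k is the search depth from the target node: k=1 covers the target node's direct neighbors, and k=2 covers the neighbors of those neighbors.
For the example in the figure:
- Step 1 is sampling: 3 nodes are sampled at k=1 and 5 nodes at k=2;
- Step 2 aggregates the information of the neighbor nodes to obtain the target node's embedding;
- Step 3 uses the aggregated information, that is, the target node's embedding, to predict whatever is to be predicted on the graph.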
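
The three steps can be sketched on a toy graph. This is a minimal, dependency-free illustration and not the implementation used later in this courseware: the graph, the scalar features, and the helper names `sample_neighbors` and `embed` are all made up for the example, each recursive call draws its own sample, and the learned weight matrices are replaced by a plain average.

```python
import random

# Hypothetical toy graph: node -> list of neighbors, plus one scalar feature per node.
NEIGHBORS = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2]}
FEATURES = {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0}


def sample_neighbors(node, num, rng):
    """Step 1: sample `num` neighbors (with replacement, for simplicity)."""
    return [rng.choice(NEIGHBORS[node]) for _ in range(num)]


def embed(node, sample_nums, rng):
    """Steps 1-2: recursively sample and mean-aggregate to depth len(sample_nums)."""
    if not sample_nums:
        return FEATURES[node]  # depth 0: the raw feature
    sampled = sample_neighbors(node, sample_nums[0], rng)
    neighbor_mean = sum(embed(n, sample_nums[1:], rng) for n in sampled) / len(sampled)
    # Combine the node's own lower-depth representation with the neighbor mean.
    # A trained model would apply learned weights and a nonlinearity here.
    return 0.5 * (embed(node, sample_nums[1:], rng) + neighbor_mean)


rng = random.Random(0)
# 3 neighbors at k=1 and 5 neighbors at k=2, as in the figure.
embedding = embed(0, [3, 5], rng)
# Step 3 would feed `embedding` into a downstream predictor.
print(embedding)
```

Nothing in `embed` is tied to a particular node id, which is the point of the approach: the same function can be applied to a node that was never seen before, as long as its features and neighbors are known.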



This courseware is compiled from the following links:

https://zhuanlan.zhihu.com/p/410407148

https://zhuanlan.zhihu.com/p/512929377

https://blog.csdn.net/weixin_44027006/article/details/116888648

https://www.heywhale.com/mw/project/608538b1c7cba5001752d619

## GraphSAGE source code

### Sampling

In [5]:
# -*- coding: utf-8 -*-
import numpy as np

def sampling(src_nodes, sample_num, neighbor_table):
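    """Sample a fixed number of neighbors for each source node.

    Sampling is without replacement when a node has at least `sample_num`
    neighbors; otherwise it falls back to sampling with replacement, so the
    result contains repeated nodes.

    Arguments:
        src_nodes {list, ndarray} -- list of source nodes
        sample_num {int} -- number of neighbors to sample per node
        neighbor_table {dict} -- mapping from a node to its neighbor nodes

    Returns:
        np.ndarray -- flattened array of the sampled neighbors
    """
    results = []
    for sid in src_nodes:
        # Earlier variant: always sample the node's neighbors with replacement.
        # res = np.random.choice(neighbor_table[sid], size=(sample_num,))
        # results.append(res)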


        if len(neighbor_table[sid]) >= sample_num:
            res = np.random.choice(neighbor_table[sid], size=(sample_num,),replace=False)
        else:
            res = np.random.choice(neighbor_table[sid], size=(sample_num,),replace=True)
        results.append(res)
    return np.asarray(results).flatten()


def multihop_sampling(src_nodes, sample_nums, neighbor_table):
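    """Multi-hop sampling starting from the source nodes.

    Arguments:
        src_nodes {list, np.ndarray} -- source node ids
        sample_nums {list of int} -- number of neighbors to sample at each hop
        neighbor_table {dict} -- mapping from a node to its neighbor nodes

    Returns:
        [list of ndarray] -- sampling result of each hop
    """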
    sampling_result = [src_nodes]
    # print("sampling result = ", sampling_result)
    # print("sample_nums = ", sample_nums)

    for k, hopk_num in enumerate(sample_nums):
        # print("sampling_result[k] = ", sampling_result[k])
        hopk_result = sampling(sampling_result[k], hopk_num, neighbor_table)
        sampling_result.append(hopk_result)
    return sampling_result
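
To make the layout of the returned list concrete, here is a small sketch of the same idea using only the standard library (`random.sample` and `random.choices` stand in for `np.random.choice`; the helper names and the toy `neighbor_table` are invented for illustration). With `n` source nodes and `sample_nums = [s1, s2]`, the hops have lengths `n`, `n*s1`, and `n*s1*s2`, and the neighbors of source `i` occupy the slice `[i*s1 : (i+1)*s1]` of the next hop:

```python
import random


def sample_flat(src_nodes, sample_num, neighbor_table, rng):
    """Sample `sample_num` neighbors per source node and flatten the result."""
    flat = []
    for sid in src_nodes:
        neighbors = neighbor_table[sid]
        if len(neighbors) >= sample_num:
            flat.extend(rng.sample(neighbors, sample_num))  # without replacement
        else:
            flat.extend(rng.choices(neighbors, k=sample_num))  # with replacement
    return flat


def multihop(src_nodes, sample_nums, neighbor_table, rng):
    """Chain the per-hop sampling; entry k holds the nodes at hop k."""
    hops = [list(src_nodes)]
    for num in sample_nums:
        hops.append(sample_flat(hops[-1], num, neighbor_table, rng))
    return hops


# Hypothetical toy graph.
neighbor_table = {0: [1, 2, 3], 1: [0], 2: [0, 3], 3: [0, 2]}
rng = random.Random(0)
hops = multihop([0, 1], [2, 3], neighbor_table, rng)
print([len(h) for h in hops])  # [2, 4, 12]
```

Node 1 has a single neighbor, so its two samples are necessarily duplicates; this is the fallback to sampling with replacement.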


### Aggregation and training

In [15]:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class NeighborAggregator(nn.Module):
    def __init__(self, input_dim, output_dim,
                    use_bias=False, aggr_method="mean"):
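        """Aggregate the neighbors of a node.

        Args:
            input_dim: dimension of the input features
            output_dim: dimension of the output features
            use_bias: whether to use a bias term (default: {False})
            aggr_method: neighbor aggregation method (default: {mean})
        """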
        super(NeighborAggregator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_bias = use_bias
        self.aggr_method = aggr_method
        self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(self.output_dim))
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight)
        if self.use_bias:
            init.zeros_(self.bias)

    def forward(self, neighbor_feature):
        if self.aggr_method == "mean":
            aggr_neighbor = neighbor_feature.mean(dim=1)
        elif self.aggr_method == "sum":
            aggr_neighbor = neighbor_feature.sum(dim=1)
        elif self.aggr_method == "max":
            # Tensor.max(dim=...) returns (values, indices); keep only the values.
            aggr_neighbor = neighbor_feature.max(dim=1).values
        else:
            raise ValueError("Unknown aggr type, expected sum, max, or mean, but got {}"
                                .format(self.aggr_method))

        neighbor_hidden = torch.matmul(aggr_neighbor, self.weight)
        if self.use_bias:
            neighbor_hidden += self.bias

        return neighbor_hidden

    def extra_repr(self):
        return 'in_features={}, out_features={}, aggr_method={}'.format(
            self.input_dim, self.output_dim, self.aggr_method)


class SageGCN(nn.Module):
    def __init__(self, input_dim, hidden_dim,
                    activation=F.relu,
                    aggr_neighbor_method="mean",
                    aggr_hidden_method="sum"):
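        """Definition of the SageGCN layer.

        Args:
            input_dim: dimension of the input features
            hidden_dim: dimension of the hidden features;
                when aggr_hidden_method=sum, the output dimension is hidden_dim
                when aggr_hidden_method=concat, the output dimension is hidden_dim*2
            activation: activation function
            aggr_neighbor_method: neighbor aggregation method, ["mean", "sum", "max"]
            aggr_hidden_method: how the node feature is updated, ["sum", "concat"]
        """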
        super(SageGCN, self).__init__()
        assert aggr_neighbor_method in ["mean", "sum", "max"]
        assert aggr_hidden_method in ["sum", "concat"]
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.aggr_neighbor_method = aggr_neighbor_method
        self.aggr_hidden_method = aggr_hidden_method
        self.activation = activation
        self.aggregator = NeighborAggregator(input_dim, hidden_dim,
                                                aggr_method=aggr_neighbor_method)
        self.weight = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight)

    def forward(self, src_node_features, neighbor_node_features):
        neighbor_hidden = self.aggregator(neighbor_node_features)
        self_hidden = torch.matmul(src_node_features, self.weight)

        if self.aggr_hidden_method == "sum":
            hidden = self_hidden + neighbor_hidden
        elif self.aggr_hidden_method == "concat":
            hidden = torch.cat([self_hidden, neighbor_hidden], dim=1)
        else:
            raise ValueError("Expected sum or concat, got {}"
                                .format(self.aggr_hidden_method))
        if self.activation:
            return self.activation(hidden)
        else:
            return hidden

    def extra_repr(self):
        output_dim = self.hidden_dim if self.aggr_hidden_method == "sum" else self.hidden_dim * 2
        return 'in_features={}, out_features={}, aggr_hidden_method={}'.format(
            self.input_dim, output_dim, self.aggr_hidden_method)


class GraphSage(nn.Module):
    def __init__(self, input_dim, hidden_dim,
                    num_neighbors_list):
        super(GraphSage, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_neighbors_list = num_neighbors_list
        self.num_layers = len(num_neighbors_list)
        self.gcn = nn.ModuleList()
        self.gcn.append(SageGCN(input_dim, hidden_dim[0]))
        for index in range(0, len(hidden_dim) - 2):
            self.gcn.append(SageGCN(hidden_dim[index], hidden_dim[index + 1]))
        self.gcn.append(SageGCN(hidden_dim[-2], hidden_dim[-1], activation=None))

    def forward(self, node_features_list):
        hidden = node_features_list

        for l in range(self.num_layers):
            # print(f"========= layer {l} =========")
            next_hidden = []
            gcn = self.gcn[l]
            for hop in range(self.num_layers - l):
            # print("self.num_layers - l " , self.num_layers - l-1)
            # for hop in range(self.num_layers - l-1,l,-1):
            #     print(f"======== hop {hop} ============ " )
                src_node_features = hidden[hop]
                src_node_num = len(src_node_features)
                # print(" src_node_num = ", src_node_features.shape)

                neighbor_node_features = hidden[hop + 1].view((src_node_num, self.num_neighbors_list[hop], -1))
                # print(" neighbor_node_features = ", neighbor_node_features.shape)

                h = gcn(src_node_features, neighbor_node_features)
                # print(" after gcn h = ", h.shape)

                next_hidden.append(h)
            hidden = next_hidden
            # print("hidden shape = ",len(hidden))
        return hidden[0]

    def extra_repr(self):
        return 'in_features={}, num_neighbors_list={}'.format(
            self.input_dim, self.num_neighbors_list
        )
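
The nested loop in `forward` is easiest to follow by tracking only the row counts. The sketch below is plain Python with no tensors; `trace_row_counts` is a hypothetical helper, and it assumes the input list comes from `multihop_sampling`, so hop `h+1` has `num_neighbors_list[h]` times as many rows as hop `h`. Layer `l` turns a list of `num_layers - l + 1` hop entries into one with `num_layers - l` entries, until only the source nodes remain:

```python
def trace_row_counts(num_src, num_neighbors_list):
    """Return the row counts of the hidden list before and after each layer."""
    # Row counts of the input list: [n, n*s1, n*s1*s2, ...].
    counts = [num_src]
    for s in num_neighbors_list:
        counts.append(counts[-1] * s)

    num_layers = len(num_neighbors_list)
    history = [counts]
    for l in range(num_layers):
        next_counts = []
        for hop in range(num_layers - l):
            src_rows = counts[hop]
            # hidden[hop + 1] is viewed as (src_rows, num_neighbors_list[hop], -1),
            # so its row count must factor exactly.
            assert counts[hop + 1] == src_rows * num_neighbors_list[hop]
            next_counts.append(src_rows)  # one output row per source node
        counts = next_counts
        history.append(counts)
    return history


for step, counts in enumerate(trace_row_counts(num_src=16, num_neighbors_list=[10, 10])):
    print(step, counts)
```

The final entry has a single element, which corresponds to `hidden[0]` returned by `forward`: one embedding per source node.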




## GraphSAGE implementation based on PyG

In [1]:
import os.path as osp
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from tqdm import tqdm

from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv

In [3]:
dataset = Planetoid(root=r"./data", name='Cora')
data = dataset[0]
num_nodes_list = torch.arange(data.num_nodes)
train_idx = num_nodes_list[data['train_mask']]

In [15]:
len(train_idx)

140

### NeighborSampler

https://blog.csdn.net/qq_40671063/article/details/126803861

In [4]:
train_loader = NeighborSampler(data.edge_index, node_idx=train_idx,
                               sizes=[15, 10, 5], batch_size=70,
                               shuffle=True, num_workers=12)
subgraph_loader = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1],
                                  batch_size=256, shuffle=False,
                                  num_workers=12)

In [16]:
len(train_loader)

2
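
The loader length follows from a ceiling division: 140 training nodes in batches of 70 give 2 batches. A small check in plain Python (`num_batches` is a hypothetical helper, and the node count of 2708 for Cora is a property of the dataset rather than something printed in this notebook):

```python
import math


def num_batches(num_nodes, batch_size):
    """Number of mini-batches when the last, smaller batch is kept."""
    return math.ceil(num_nodes / batch_size)


print(num_batches(140, 70))    # train_loader: 2
print(num_batches(2708, 256))  # subgraph_loader over all Cora nodes: 11
```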

In [23]:
for batch_size, n_id, adjs in train_loader:
    print(adjs)
    print(batch_size)
    print(n_id)
    break

[EdgeIndex(edge_index=tensor([[ 70,  71,  72,  ..., 537, 789, 889],
        [  0,   0,   0,  ..., 863, 863, 863]]), e_id=tensor([1734, 5180, 8373,  ..., 8329, 3680, 1412]), size=(1431, 864)), EdgeIndex(edge_index=tensor([[ 70,  71,  72,  ..., 199, 200, 863],
        [  0,   0,   0,  ..., 326, 326, 326]]), e_id=tensor([1734, 5180, 8373,  ..., 8660, 8862, 5993]), size=(864, 327)), EdgeIndex(edge_index=tensor([[ 70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
          20, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 115, 137,
         138, 139, 140, 141, 142, 143, 144, 145,  47, 146, 147, 148, 149, 150,
         151, 152, 153, 154, 155, 156, 157,  84, 158, 159,  83,  84,  10, 127,
         128, 160, 161, 162, 117, 163, 164, 

In [26]:
len(n_id)

1431

In [27]:
n_id

tensor([  86,   33,   53,  ..., 1934, 1432, 2610])
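
In the `NeighborSampler` output, `n_id` holds the global ids of every node involved in the batch, and the first `batch_size` entries are the target nodes; this is what `y[n_id[:batch_size]]` relies on in the training loop further down. The plain-Python sketch below mimics that convention with a handful of ids (`n_id`, `labels`, and `batch_size` here are toy values, and the labels are made up):

```python
# Toy stand-ins for the tensors returned by the loader.
batch_size = 3
n_id = [86, 33, 53, 429, 1336, 2034]  # targets first, then sampled neighbors
labels = {86: 2, 33: 0, 53: 5, 429: 1, 1336: 4, 2034: 3}  # hypothetical labels

target_ids = n_id[:batch_size]    # global ids of the nodes being predicted
neighbor_ids = n_id[batch_size:]  # nodes that only contribute features

# Local index i inside the batch corresponds to global id n_id[i].
local_to_global = dict(enumerate(n_id))

target_labels = [labels[g] for g in target_ids]
print(target_ids, target_labels)
```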

In [24]:
for i in n_id:
    print(i)

tensor(86)
tensor(33)
tensor(53)
tensor(56)
tensor(51)
tensor(95)
tensor(32)
tensor(112)
tensor(52)
tensor(83)
tensor(84)
tensor(117)
tensor(120)
tensor(138)
tensor(7)
tensor(43)
tensor(55)
tensor(136)
tensor(68)
tensor(72)
tensor(75)
tensor(102)
tensor(115)
tensor(14)
tensor(39)
tensor(46)
tensor(41)
tensor(8)
tensor(127)
tensor(31)
tensor(92)
tensor(131)
tensor(50)
tensor(106)
tensor(135)
tensor(78)
tensor(80)
tensor(139)
tensor(103)
tensor(62)
tensor(57)
tensor(99)
tensor(59)
tensor(65)
tensor(121)
tensor(3)
tensor(124)
tensor(60)
tensor(133)
tensor(89)
tensor(108)
tensor(100)
tensor(22)
tensor(82)
tensor(16)
tensor(0)
tensor(122)
tensor(54)
tensor(20)
tensor(9)
tensor(79)
tensor(113)
tensor(38)
tensor(134)
tensor(44)
tensor(119)
tensor(98)
tensor(88)
tensor(40)
tensor(18)
tensor(429)
tensor(1336)
tensor(2034)
tensor(2295)
tensor(286)
tensor(588)
tensor(698)
tensor(911)
tensor(1051)
tensor(2040)
tensor(2119)
tensor(2120)
tensor(2121)
tensor(1103)
tensor(1358)
tensor(1739)
tensor(412

### SAGEConv

In [6]:
class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.num_layers = num_layers

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adjs):
        # `train_loader` computes the k-hop neighborhood of a batch of nodes,
        # and returns, for each layer, a bipartite graph object, holding the
        # bipartite edges `edge_index`, the index `e_id` of the original edges,
        # and the size/shape `size` of the bipartite graph.
        # Target nodes are also included in the source nodes so that one can
        # easily apply skip-connections or add self-loops.
        for i, (edge_index, _, size) in enumerate(adjs):
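            # For every layer's bipartite graph, x_target = x[:size[1]].
            x_target = x[:size[1]]  # Target nodes are always placed first; there are size[1] of them.
            # Convolution over one bipartite graph. The convolution can be read as an
            # aggregation step, applied layer by layer from layer L down to layer 1.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:  # Apply ReLU and dropout on every layer except the last.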
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x.log_softmax(dim=-1)

    def inference(self, x_all):
        pbar = tqdm(total=x_all.size(0) * self.num_layers)
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        total_edges = 0
        for i in range(self.num_layers):  # num_layers layers in total
            xs = []
            # Each batch holds its target nodes plus all nodes reached by one hop (L=1) of sampling.
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                total_edges += edge_index.size(1)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                if i != self.num_layers - 1:
                    x = F.relu(x)
                xs.append(x.cpu())

                pbar.update(batch_size)

            x_all = torch.cat(xs, dim=0)

        pbar.close()

        return x_all
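
Because target nodes are always placed first, each layer in `forward` keeps only the leading `size[1]` rows as targets, and those rows become the sources of the next layer. The sketch below replays this with the two `size` tuples printed earlier, `(1431, 864)` and `(864, 327)`; the third tuple is cut off in the output, so `(327, 70)` is an assumption inferred from `batch_size=70`. `rows_per_layer` is a hypothetical helper:

```python
def rows_per_layer(sizes):
    """Track how many rows of x remain after each bipartite layer."""
    rows = sizes[0][0]  # all nodes in n_id
    trace = [rows]
    for num_src, num_dst in sizes:
        assert rows == num_src, "each layer's sources are the previous targets"
        rows = num_dst  # SAGEConv on (x, x_target) returns one row per target
        trace.append(rows)
    return trace


sizes = [(1431, 864), (864, 327), (327, 70)]  # last tuple is assumed
print(rows_per_layer(sizes))  # [1431, 864, 327, 70]
```

The last count equals the batch size, which is why the loss can be computed directly against the labels of the first `batch_size` entries of `n_id`.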

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SAGE(dataset.num_features, 64, dataset.num_classes, num_layers=3)
model = model.to(device)

x = data.x.to(device)
y = data.y.squeeze().to(device)
criterion = nn.CrossEntropyLoss().to(device)

In [11]:
def train(epoch):
    model.train()

    pbar = tqdm(total=train_idx.size(0))
    pbar.set_description(f'Epoch {epoch:02d}')

    total_loss = total_correct = 0

    for batch_size, n_id, adjs in train_loader:
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        adjs = [adj.to(device) for adj in adjs]

        optimizer.zero_grad()
        out = model(x[n_id], adjs)  # x[n_id]: features of all nodes reached by L-hop sampling from this batch's targets
        loss = criterion(out, y[n_id[:batch_size]])
        loss.backward()
        optimizer.step()

        total_loss += float(loss)
        total_correct += int(out.argmax(dim=-1).eq(y[n_id[:batch_size]]).sum())
        pbar.update(batch_size)

    pbar.close()

    loss = total_loss / len(train_loader)
    approx_acc = total_correct / train_idx.size(0)

    return loss, approx_acc

In [12]:
@torch.no_grad()
def test():
    model.eval()

    out = model.inference(x)

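    y_true = y.cpu().unsqueeze(-1)
    y_pred = out.argmax(dim=-1, keepdim=True)
    # Note: accuracy is computed over all nodes in the graph (including
    # training nodes), not only over the nodes in data.test_mask.
    correct = (y_pred == y_true).sum().item()
    test_acc = correct/data.num_nodes
    return test_acc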

In [13]:
test_accs = []
for run in range(1, 2):  # shortened to one run; an upper bound of 11 gives 10 runs
    print('')
    print(f'Run {run:02d}:')
    print('')

    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

    best_val_acc = final_test_acc = 0
    for epoch in range(1, 10):  # shortened to 9 epochs; an upper bound of 51 gives 50 epochs
        loss, acc = train(epoch)
        print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')


    test_acc = test()
    print(f'Test: {test_acc:.4f}')
    test_accs.append(test_acc)

test_acc = torch.tensor(test_accs)
print('============================')
print(f'Final Test: {test_acc.mean():.4f} ± {test_acc.std():.4f}')


Run 01:



Epoch 01: 100%|██████████| 140/140 [00:10<00:00, 13.54it/s]


Epoch 01, Loss: 1.9410, Approx. Train: 0.1429


Epoch 02: 100%|██████████| 140/140 [00:09<00:00, 14.54it/s]


Epoch 02, Loss: 1.8810, Approx. Train: 0.3643


Epoch 03: 100%|██████████| 140/140 [00:10<00:00, 13.43it/s]


Epoch 03, Loss: 1.7864, Approx. Train: 0.6500


Epoch 04: 100%|██████████| 140/140 [00:10<00:00, 13.48it/s]


Epoch 04, Loss: 1.6264, Approx. Train: 0.8286


Epoch 05: 100%|██████████| 140/140 [00:10<00:00, 13.36it/s]


Epoch 05, Loss: 1.4584, Approx. Train: 0.8214


Epoch 06: 100%|██████████| 140/140 [00:10<00:00, 12.99it/s]


Epoch 06, Loss: 1.2353, Approx. Train: 0.8143


Epoch 07: 100%|██████████| 140/140 [00:10<00:00, 13.00it/s]


Epoch 07, Loss: 0.9877, Approx. Train: 0.8857


Epoch 08: 100%|██████████| 140/140 [00:11<00:00, 12.43it/s]


Epoch 08, Loss: 0.7243, Approx. Train: 0.9214


Epoch 09: 100%|██████████| 140/140 [00:12<00:00, 11.52it/s]


Epoch 09, Loss: 0.4996, Approx. Train: 0.9286


Evaluating: 100%|██████████| 8124/8124 [00:32<00:00, 246.61it/s]

Test: 0.8002
Final Test: 0.8002 ± nan
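
The `nan` in the final line is expected with a single run: `Tensor.std()` applies Bessel's correction by default, dividing by `n - 1`, which is zero when there is only one accuracy value. A plain-Python sketch of the same computation (`mean_and_sample_std` is a hypothetical helper, not part of the notebook, and the three-value list is invented to show the multi-run case):

```python
import math


def mean_and_sample_std(values):
    """Mean and sample standard deviation (divides by n - 1)."""
    n = len(values)
    mean = sum(values) / n
    if n < 2:
        return mean, math.nan  # undefined with fewer than two samples
    var = sum((v - mean) ** 2 for v in values) / (n - 1)
    return mean, math.sqrt(var)


print(mean_and_sample_std([0.8002]))            # one run: std is nan
print(mean_and_sample_std([0.80, 0.82, 0.78]))  # hypothetical accuracies from three runs
```

Running more than one repetition in the outer loop is enough to get a finite standard deviation.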



