Skip to content

Latest commit

 

History

History
104 lines (89 loc) · 3.84 KB

graph_traverse_cn.md

File metadata and controls

104 lines (89 loc) · 3.84 KB

图遍历

1. 介绍

图遍历,在GNN里的语义有别于经典的图计算。主流深度学习算法的训练模式会按batch迭代。为了满足这种要求,数据要能够按batch访问,我们把这种数据的访问模式称为遍历。在GNN算法中,数据源为图,训练样本通常由图的顶点和边构成。图遍历是指为算法提供按batch获取顶点、边或子图的能力。

目前GL支持顶点和边的batch遍历。这种随机遍历可以是无放回的,也可以是有放回的。在无放回遍历中,每当一个epoch结束后都会触发gl.OutOfRangeError。被遍历的数据源是划分后的,即当前worker(以分布式TF为例)只遍历与其对应的Server上的数据。

2. 顶点遍历

2.1 用法

顶点的数据来源有3种:所有unique的顶点,所有边的源顶点,所有边的目的顶点。顶点遍历依托NodeSampler算子实现,Graph对象的node_sampler()接口返回一个NodeSampler对象,再调用该对象的get()接口返回Nodes格式的数据。

def node_sampler(type, batch_size=64, strategy="by_order", node_from=gl.NODE):
"""
Args:
  type(string):     当node_from为gl.NODE时,为顶点类型,否则为边类型;
  batch_size(int):  每次遍历的顶点数
  strategy(string): 可选值为"by_order"和"random",表示无放回遍历和随机遍历
      当为"by_order"时,若触底后不足batch_size,则返回实际数量,若实际数量为0,则触发gl.OutOfRangeError
  node_from:        数据来源,可选值为gl.NODE、gl.EDGE_SRC、gl.EDGE_DST;
Return:
  NodeSampler对象
"""
def NodeSampler.get():
"""
Return:
    Nodes对象,若非触底,预期ids的shape为[batch_size]
"""


通过Nodes对象获取具体的值,如id、weight、attribute等,参考API。在GSL中,顶点遍历参考g.V()

2.2 示例

id attributes
10001 0:0.1:0
10002 1:0.2:3
10003 3:0.3:4
sampler = g.node_sampler("user", batch_size=3, strategy="random")
for i in range(5):
    nodes = sampler.get()
    print(nodes.ids)
    print(nodes.int_attrs)
    print(nodes.float_attrs)

3. 边遍历

3.1 用法

边遍历依托EdgeSampler算子实现。Graph对象的edge_sampler()接口返回一个EdgeSampler对象,再调用该对象的get()接口返回Edges格式的数据。

def edge_sampler(edge_type, batch_size=64, strategy="by_order"):
"""
Args:
  edge_type(string): 边类型
  batch_size(int):   每次遍历的边数
  strategy(string):  可选值为"by_order"和"random",表示无放回遍历和随机遍历
      当为"by_order"时,若触底后不足batch_size,则返回实际数量,若实际数量为0,则触发gl.OutOfRangeError
Return:
  EdgeSampler对象
"""
def EdgeSampler.get():
"""
Return:
    Edges对象,若非触底,预期src_ids的shape为[batch_size]
"""


通过Edges对象获取具体的值,如id、weight、attribute等,参考API。在GSL中,边遍历参考g.E()

3.2 示例

src_id dst_id weight attributes
20001 30001 0.1 0.10,0.11,0.12,0.13,0.14,0.15,0.16,0.17,0.18,0.19
20001 30003 0.2 0.20,0.21,0.22,0.23,0.24,0.25,0.26,0.27,0.28,0.29
20003 30001 0.3 0.30,0.31,0.32,0.33,0.34,0.35,0.36,0.37,0.38,0.39
20004 30002 0.4 0.40,0.41,0.42,0.43,0.44,0.45,0.46,0.47,0.48,0.49
sampler = g.edge_sampler("buy", batch_size=3, strategy="random")
for i in range(5):
    edges = sampler.get()
    print(edges.src_ids)
    print(edges.src_ids)
    print(edges.weights)
    print(edges.float_attrs)