图遍历,在GNN里的语义有别于经典的图计算。主流深度学习算法的训练模式会按batch迭代。为了满足这种要求,数据要能够按batch访问,我们把这种数据的访问模式称为遍历。在GNN算法中,数据源为图,训练样本通常由图的顶点和边构成。图遍历是指为算法提供按batch获取顶点、边或子图的能力。
目前GL支持顶点和边的batch遍历。这种随机遍历可以是无放回的,也可以是有放回的。在无放回遍历中,每当一个epoch结束后都会触发gl.OutOfRangeError
。被遍历的数据源是划分后的,即当前worker(以分布式TF为例)只遍历与其对应的Server上的数据。
顶点的数据来源有3种:所有unique的顶点,所有边的源顶点,所有边的目的顶点。顶点遍历依托NodeSampler
算子实现,Graph对象的node_sampler()
接口返回一个NodeSampler
对象,再调用该对象的get()
接口返回Nodes
格式的数据。
def node_sampler(type, batch_size=64, strategy="by_order", node_from=gl.NODE):
"""
Args:
type(string): 当node_from为gl.NODE时,为顶点类型,否则为边类型;
batch_size(int): 每次遍历的顶点数
strategy(string): 可选值为"by_order"和"random",表示无放回遍历和随机遍历
当为"by_order"时,若触底后不足batch_size,则返回实际数量,若实际数量为0,则触发gl.OutOfRangeError
node_from: 数据来源,可选值为gl.NODE、gl.EDGE_SRC、gl.EDGE_DST;
Return:
NodeSampler对象
"""
def NodeSampler.get():
"""
Return:
Nodes对象,若非触底,预期ids的shape为[batch_size]
"""
通过Nodes
对象获取具体的值,如id、weight、attribute等,参考API。在GSL中,顶点遍历参考g.V()
。
id | attributes |
---|---|
10001 | 0:0.1:0 |
10002 | 1:0.2:3 |
10003 | 3:0.3:4 |
sampler = g.node_sampler("user", batch_size=3, strategy="random")
for i in range(5):
nodes = sampler.get()
print(nodes.ids)
print(nodes.int_attrs)
print(nodes.float_attrs)
边遍历依托EdgeSampler
算子实现。Graph对象的edge_sampler()
接口返回一个EdgeSampler
对象,再调用该对象的get()
接口返回Edges
格式的数据。
def edge_sampler(edge_type, batch_size=64, strategy="by_order"):
"""
Args:
edge_type(string): 边类型
batch_size(int): 每次遍历的边数
strategy(string): 可选值为"by_order"和"random",表示无放回遍历和随机遍历
当为"by_order"时,若触底后不足batch_size,则返回实际数量,若实际数量为0,则触发gl.OutOfRangeError
Return:
EdgeSampler对象
"""
def EdgeSampler.get():
"""
Return:
Edges对象,若非触底,预期src_ids的shape为[batch_size]
"""
通过Edges
对象获取具体的值,如id、weight、attribute等,参考API。在GSL中,边遍历参考g.E()
。
src_id | dst_id | weight | attributes |
---|---|---|---|
20001 | 30001 | 0.1 | 0.10,0.11,0.12,0.13,0.14,0.15,0.16,0.17,0.18,0.19 |
20001 | 30003 | 0.2 | 0.20,0.21,0.22,0.23,0.24,0.25,0.26,0.27,0.28,0.29 |
20003 | 30001 | 0.3 | 0.30,0.31,0.32,0.33,0.34,0.35,0.36,0.37,0.38,0.39 |
20004 | 30002 | 0.4 | 0.40,0.41,0.42,0.43,0.44,0.45,0.46,0.47,0.48,0.49 |
sampler = g.edge_sampler("buy", batch_size=3, strategy="random")
for i in range(5):
edges = sampler.get()
print(edges.src_ids)
print(edges.src_ids)
print(edges.weights)
print(edges.float_attrs)