<a href="https://colab.research.google.com/github/zmgy107/DGL-Learning-Notes/blob/main/Chapter_3_Building_GNN_Modules.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install  dgl -f https://data.dgl.ai/wheels/cu117/repo.html
%pip install  dglgo -f https://data.dgl.ai/wheels-test/repo.html

import dgl


DGL NN modules are the building blocks of GNN models. An NN module inherits from PyTorch’s NN Module, MXNet Gluon’s NN Block, or TensorFlow’s Keras Layer, depending on the DNN framework backend in use. In a DGL NN module, parameter registration in the construction function and tensor operations in the forward function work the same way as in the backend framework, so DGL code integrates seamlessly into backend framework code. The major difference lies in the message passing operations, which are unique to DGL.

This chapter uses SAGEConv with the PyTorch backend as an example to show how to build a custom DGL NN module.

# DGL NN Module Construction Function

The construction function performs the following steps:

1. Set options.

2. Register learnable parameters or submodules.

3. Reset parameters.

In [9]:
import torch.nn as nn
from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
  def __init__(self,
               in_feats,
               out_feats,
               aggregator_type,
               bias=True,
               norm=None,
               activation=None):
    super(SAGEConv, self).__init__()

    # Part 1: set options
    self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
    self._out_feats = out_feats
    self._aggre_type = aggregator_type
    self.norm = norm
    self.activation = activation

    # Part 2: register learnable parameters and submodules
    # supported aggregator types: mean, pool, lstm, gcn
    if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:
      raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
    if aggregator_type == 'pool':
      self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
    if aggregator_type == 'lstm':
      self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
    if aggregator_type in ['mean', 'pool', 'lstm']:
      self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
    self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
    self.reset_parameters()

  # Part 3: reset parameters
  def reset_parameters(self):
    """Reinitialize learnable parameters."""
    gain = nn.init.calculate_gain('relu')
    if self._aggre_type == 'pool':
      nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
    if self._aggre_type == 'lstm':
      self.lstm.reset_parameters()
    if self._aggre_type != 'gcn':
      # fc_self only exists for the non-gcn aggregators
      nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
    nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

Part 1: In the construction function, one first sets the data dimensions. For a general PyTorch module, the dimensions are usually the input, output, and hidden dimensions. For graph neural networks, the input dimension can be split into a source node dimension and a destination node dimension (using the expand_as_pair function).

Besides data dimensions, a typical option for a graph neural network is the aggregation type (self._aggre_type). __Aggregation type determines how messages on different edges are aggregated for a certain destination node.__ Commonly used aggregation types include mean, sum, max, and min. Some modules may apply a more complicated aggregation such as an LSTM.
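
To make the aggregation types concrete, the following framework-free sketch reduces the messages arriving at each destination node with mean, sum, max, and min. The toy edge list, the scalar features, and the helper name `aggregate` are assumptions made for illustration; DGL performs the equivalent reduction on tensors with its built-in reduce functions.

```python
from collections import defaultdict
from statistics import mean

# Hypothetical toy graph: edges are (src, dst) pairs, features are scalars.
edges = [(0, 2), (1, 2), (3, 2), (0, 1)]
feat = {0: 1.0, 1: 4.0, 2: 10.0, 3: 7.0}

REDUCERS = {"mean": mean, "sum": sum, "max": max, "min": min}

def aggregate(edges, feat, aggregator_type):
    """Reduce the messages on all in-edges of each destination node."""
    if aggregator_type not in REDUCERS:
        raise KeyError("Aggregator type {} not supported.".format(aggregator_type))
    mailbox = defaultdict(list)
    for src, dst in edges:
        # The message on an edge is simply a copy of the source feature.
        mailbox[dst].append(feat[src])
    reduce_fn = REDUCERS[aggregator_type]
    return {dst: reduce_fn(msgs) for dst, msgs in mailbox.items()}

mean_result = aggregate(edges, feat, "mean")
max_result = aggregate(edges, feat, "max")
```

Node 2 receives three messages and node 1 receives one; nodes 0 and 3 have no in-edges, so they do not appear in the result at all.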

norm here is a callable for feature normalization. In the GraphSAGE paper, this normalization can be $L_2$ normalization: $h_v=\frac{h_v}{\lVert h_v\rVert_2}$.
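
As a small illustration of that formula, the sketch below scales each row of a feature matrix to unit $L_2$ norm in plain Python. The function name `l2_normalize` and the `eps` guard are assumptions made for this example; with the PyTorch backend one would pass a tensor-based callable as norm instead.

```python
import math

def l2_normalize(rows, eps=1e-12):
    """Scale each feature vector h_v to unit L2 norm: h_v / ||h_v||_2."""
    out = []
    for h in rows:
        norm = math.sqrt(sum(x * x for x in h))
        # eps avoids division by zero for an all-zero feature vector.
        out.append([x / max(norm, eps) for x in h])
    return out

# The first row has norm 5; the second row is all-zero and stays all-zero.
normalized = l2_normalize([[3.0, 4.0], [0.0, 0.0]])
```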

Parts 2 and 3: Register parameters and submodules, then reset the parameters. In SAGEConv, the submodules vary according to the aggregation type. They are pure PyTorch nn modules such as nn.Linear and nn.LSTM. At the end of the construction function, weight initialization is applied by calling reset_parameters().

# DGL NN Module Forward Function

In an NN module, the forward() function does the actual message passing and computation. Compared with PyTorch’s NN modules, which usually take tensors as parameters, a DGL NN module takes an additional parameter, __dgl.DGLGraph__. The workload of the forward() function can be split into three parts:

*  Graph checking and graph type specification.

*  Message passing.

*  Feature update.

The rest of the section takes a deep dive into the forward() function of the SAGEConv example.

## Graph checking and graph type specification

forward() needs to handle many corner cases in the input that can lead to invalid values in computation and message passing. __One typical check in conv modules like GraphConv is to verify that the input graph has no 0-in-degree nodes.__ When a node has 0 in-degree, its mailbox is empty and the reduce function produces all-zero values. This may cause a silent regression in model performance. __However, in the SAGEConv module the aggregated representation is concatenated with the original node feature, so the output of forward() will not be all-zero.__ No such check is needed in this case.
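
A framework-free sketch of this corner case: the helper below finds the nodes with no incoming edges in a toy edge list. The edge list and the name `zero_in_degree_nodes` are assumptions made for illustration and are not part of DGL.

```python
def zero_in_degree_nodes(num_nodes, edges):
    """Return the IDs of nodes that never appear as an edge destination."""
    in_degree = [0] * num_nodes
    for _, dst in edges:
        in_degree[dst] += 1
    return [v for v, deg in enumerate(in_degree) if deg == 0]

# Hypothetical graph with 4 nodes; nodes 0 and 3 receive no messages.
edges = [(0, 1), (0, 2), (1, 2), (3, 2)]
isolated = zero_in_degree_nodes(4, edges)
```

A module without a self term could raise an error when this list is non-empty, or the caller could add self-loops to the graph beforehand.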

A DGL NN module should be reusable across different types of graph input, including homogeneous graphs, heterogeneous graphs, and subgraph blocks.

The math formulas for SAGEConv are:

$$h_{N(dst)}^{(l+1)}=aggregate(\{h_{src}^{(l)},\forall src\in N(dst)\})$$

$$h_{dst}^{(l+1)}=\sigma(W\cdot concat(h_{dst}^{(l)},h_{N(dst)}^{(l+1)})+b)$$

$$h_{dst}^{(l+1)}=norm(h_{dst}^{(l+1)})$$
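
The implementation does not need to materialize the concatenation. Assuming the weight matrix is partitioned column-wise as $W=[W_{self}\;W_{neigh}]$ (block names chosen here to match the fc_self and fc_neigh submodules registered in the construction function), the linear term of the second formula can be rewritten as a sum of two linear maps:

$$W\cdot concat(h_{dst}^{(l)},h_{N(dst)}^{(l+1)})+b=W_{self}\,h_{dst}^{(l)}+W_{neigh}\,h_{N(dst)}^{(l+1)}+b$$

This is why the module registers two nn.Linear layers and adds their outputs. When both layers carry a bias, $b$ corresponds to the sum of the two bias vectors.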

One needs to specify the source node feature feat_src and the destination node feature feat_dst according to the graph type. expand_as_pair() is a function that identifies the graph type and expands feat into feat_src and feat_dst. The details of this function are shown below.


In [None]:
from collections.abc import Mapping
from dgl import backend as F  # DGL's backend-agnostic tensor ops

def forward(self, graph, feat):
  with graph.local_scope():
    # Specify graph type then expand input feature according to graph type
    feat_src, feat_dst = expand_as_pair(feat, graph)

def expand_as_pair(input_, g=None):
  if isinstance(input_, tuple):
    # Bipartite graph case
    return input_
  # Author's note: this branch is not fully clear to me.
  # In a block, destination nodes come first, so narrow_row keeps the first rows.
  elif g is not None and g.is_block:
    # Subgraph block case
    if isinstance(input_, Mapping):
      input_dst = {
          k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
          for k, v in input_.items()}
    else:
      input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
    return input_, input_dst
  else:
    # Homogeneous graph case
    return input_, input_

For homogeneous whole graph training, source nodes and destination nodes are the same. They are all the nodes in the graph.

In the heterogeneous case, the graph can be split into several bipartite graphs, one for each relation. The relations are represented as (src_type, edge_type, dst_type). When expand_as_pair() identifies that the input feature feat is a tuple, it treats the graph as bipartite. The first element of the tuple is the source node feature and the second element is the destination node feature.

In mini-batch training, the computation is applied to a subgraph sampled from a set of destination nodes. This subgraph is called a block in DGL. In the block creation phase, dst_nodes are placed at the front of the node list, so one can find feat_dst with the index range [0:g.number_of_dst_nodes()].

After feat_src and feat_dst are determined, the computation for the three graph types above is the same.
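
The three cases can be sketched without DGL by using Python lists in place of tensors. `ToyBlock` and `expand_as_pair_sketch` are hypothetical names introduced for this illustration; the real function operates on backend tensors and DGL graph objects.

```python
class ToyBlock:
    """Hypothetical stand-in for a DGL block: destination nodes come first."""
    is_block = True

    def __init__(self, num_dst):
        self._num_dst = num_dst

    def number_of_dst_nodes(self):
        return self._num_dst

def expand_as_pair_sketch(input_, g=None):
    """Plain-Python analogue of expand_as_pair for list-based features."""
    if isinstance(input_, tuple):
        # Bipartite case: the caller already supplies (feat_src, feat_dst).
        return input_
    if g is not None and g.is_block:
        # Block case: destination features are the first num_dst rows.
        return input_, input_[:g.number_of_dst_nodes()]
    # Homogeneous case: every node is both source and destination.
    return input_, input_

feat = [[1.0], [2.0], [3.0], [4.0]]
whole = expand_as_pair_sketch(feat)
bipartite = expand_as_pair_sketch((feat, feat[:1]))
block = expand_as_pair_sketch(feat, ToyBlock(num_dst=2))
```

In the block case the source features keep all four rows while the destination features keep only the first two, which mirrors the index range described above.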

## Message passing and reducing

The following code performs the actual message passing and reduce computation. This part of the code varies from module to module. __Note that all the message passing in this code is implemented using the update_all() API and built-in message/reduce functions to fully utilize DGL’s performance optimization.__

In [None]:
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape

# This excerpt continues inside forward(), after expand_as_pair().
# The self term uses the destination node features.
h_self = feat_dst

# The 'lstm' branch is omitted in this excerpt, so 'lstm' reaches the KeyError below.
if self._aggre_type == 'mean':
  graph.srcdata['h'] = feat_src
  graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
  h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
  check_eq_shape(feat)
  graph.srcdata['h'] = feat_src
  graph.dstdata['h'] = feat_dst
  graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
  # divide by in-degree + 1 (the +1 accounts for the node itself)
  degs = graph.in_degrees().to(feat_dst)
  h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'pool':
  graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
  graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
  h_neigh = graph.dstdata['neigh']
else:
  raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

# GraphSAGE GCN does not require fc_self
if self._aggre_type == 'gcn':
  rst = self.fc_neigh(h_neigh)
else:
  rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
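
Written out as a formula, the gcn branch in the code above averages over the in-neighbors together with the destination node itself, where $d_{dst}$ denotes the in-degree of dst (a symbol introduced here for illustration):

$$h_{N(dst)}^{(l+1)}=\frac{h_{dst}^{(l)}+\sum_{src\in N(dst)}h_{src}^{(l)}}{d_{dst}+1}$$

For example, a destination node with feature 2 and two in-neighbors with features 1 and 3 gets $(2+1+3)/(2+1)=2$.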

## Update feature after reducing for output

In [None]:
# activation
if self.activation is not None:
  rst = self.activation(rst)
# normalization
if self.norm is not None:
  rst = self.norm(rst)
return rst

The last part of the forward() function updates the feature after the reduce function. Common update operations are applying an activation function and normalization, according to the options set in the object construction phase.

# Heterogeneous GraphConv Module

HeteroGraphConv is a module-level encapsulation for running DGL NN modules on heterogeneous graphs. The implementation logic is the same as that of the message-passing-level API multi_update_all(), including:

*  DGL NN module within each relation r

*  Reduction that merges the results on the same node type from multiple relations

This can be formulated as:

$$h_{dst}^{(l+1)}=\underset{r\in R,\,r_{dst}=dst}{AGG}(f_r(g_r,h_{r_{src}}^{(l)},h_{r_{dst}}^{(l)}))$$

where $f_r$ is the NN module for each relation $r$ and $AGG$ is the aggregation function.

In [None]:
import torch.nn as nn
# Internal DGL helper that maps a name such as 'sum' to an aggregation function
from dgl.nn.pytorch.hetero import get_aggregate_fn

class HeteroGraphConv(nn.Module):
  def __init__(self, mods, aggregate='sum'):
    super(HeteroGraphConv, self).__init__()
    # dictionary mods: maps each relation to an nn module and
    # sets the function that aggregates results on the same node
    # type from multiple relations
    self.mods = nn.ModuleDict(mods)
    if isinstance(aggregate, str):
      # An internal function to get common aggregation functions
      self.agg_fn = get_aggregate_fn(aggregate)
    else:
      self.agg_fn = aggregate

  def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
    if mod_args is None:
      mod_args = {}
    if mod_kwargs is None:
      mod_kwargs = {}
    # one list of per-relation outputs for each destination node type
    outputs = {nty: [] for nty in g.dsttypes}

Besides the input graph and input tensors, the forward() function takes two additional dictionary parameters, mod_args and mod_kwargs. These two dictionaries have the same keys as self.mods. They are used as customized parameters when calling the corresponding NN modules in self.mods for different types of relations.
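
The dispatch of these per-relation arguments can be sketched in plain Python. The functions `scale` and `shift`, the relation names, and the helper `call_per_relation` are hypothetical stand-ins for NN modules and are not part of DGL.

```python
def call_per_relation(mods, inputs, mod_args=None, mod_kwargs=None):
    """Call each relation's module with its own extra arguments."""
    mod_args = {} if mod_args is None else mod_args
    mod_kwargs = {} if mod_kwargs is None else mod_kwargs
    results = {}
    for etype, mod in mods.items():
        results[etype] = mod(
            inputs,
            *mod_args.get(etype, ()),      # extra positional arguments
            **mod_kwargs.get(etype, {}))   # extra keyword arguments
    return results

# Hypothetical per-relation "modules": plain functions instead of nn.Module.
def scale(x, factor=1.0):
    return [v * factor for v in x]

def shift(x, offset, clip=None):
    out = [v + offset for v in x]
    return out if clip is None else [min(v, clip) for v in out]

mods = {"follows": scale, "plays": shift}
results = call_per_relation(
    mods, [1.0, 2.0],
    mod_args={"plays": (10.0,)},
    mod_kwargs={"follows": {"factor": 2.0}, "plays": {"clip": 11.5}})
```

Relations that have no entry in either dictionary are simply called with the shared inputs, because the lookups fall back to an empty tuple and an empty dictionary.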

An output dictionary is created to hold the output tensors for each destination type nty. __Note that the value for each nty is a list, indicating that a single node type may get multiple outputs if more than one relation has nty as the destination type.__ HeteroGraphConv will perform a further aggregation on the lists.

In [None]:
if g.is_block:
  src_inputs = inputs
  # destination nodes come first in a block, so slice the leading rows
  dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
else:
  src_inputs = dst_inputs = inputs

for stype, etype, dtype in g.canonical_etypes:
  rel_graph = g[stype, etype, dtype]
  if rel_graph.num_edges() == 0:
    continue
  if stype not in src_inputs or dtype not in dst_inputs:
    continue
  dstdata = self.mods[etype](
      rel_graph,
      (src_inputs[stype], dst_inputs[dtype]),
      *mod_args.get(etype, ()),
      **mod_kwargs.get(etype, {}))
  outputs[dtype].append(dstdata)

The input g can be a heterogeneous graph or a subgraph block from a heterogeneous graph. As in an ordinary NN module, the forward() function needs to handle different input graph types separately.

Each relation is represented as a canonical_etype, which is (stype, etype, dtype). Using the canonical_etype as the key, one can extract a bipartite graph rel_graph. For a bipartite graph, the input feature is organized as a tuple (src_inputs[stype], dst_inputs[dtype]). The NN module for each relation is called and its output is saved. To avoid unnecessary calls, relations with no edges, or with no input features for their source or destination node type, are skipped.

In [None]:
rsts = {}
for nty, alist in outputs.items():
  # node types that received no output are left out of the result
  if len(alist) != 0:
    rsts[nty] = self.agg_fn(alist, nty)
return rsts

Finally, the results on the same destination node type from multiple relations are aggregated using self.agg_fn function. Examples can be found in the API Doc for HeteroGraphConv.
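
As a rough sketch of this last step, the example below sums per-relation outputs for each destination node type using plain lists. The node type names, the values, and the function `sum_agg` are assumptions made for illustration, and each output is simplified to a single feature vector rather than a tensor with one row per node.

```python
def sum_agg(alist, nty):
    """Element-wise sum of the per-relation outputs for node type nty."""
    # nty is unused here; it is kept so the call matches agg_fn(alist, nty).
    return [sum(vals) for vals in zip(*alist)]

# Hypothetical outputs: 'game' is the destination of two relations,
# 'user' of one, and 'dev' of none.
outputs = {
    "game": [[1.0, 2.0], [10.0, 20.0]],
    "user": [[5.0, 6.0]],
    "dev": [],
}

rsts = {}
for nty, alist in outputs.items():
    if len(alist) != 0:
        rsts[nty] = sum_agg(alist, nty)
```

Node types with an empty list, such as 'dev' here, do not appear in the result dictionary.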