# Chapter 3: Building GNN Modules

A DGL NN module is a **building block** for GNN models, and every DGL NN module **inherits** from PyTorch's `nn.Module`.

This chapter uses `SAGEConv` with the PyTorch backend as an example to show how to build a custom DGL NN module.

## 3.1 DGL NN Module Construction Function

```python
import torch.nn as nn
from dgl.utils import expand_as_pair
```

In the construction function, one first needs to set the data dimensions. For graph neural networks, the **input dimension** can be split into the source node dimension and the destination node dimension.
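
For illustration, the splitting rule can be sketched in plain Python. The helper `split_in_feats` below is hypothetical; it mimics, assuming only an integer or a tuple of integers is passed, what `expand_as_pair` does with `in_feats` in the constructor: a single integer is reused for both source and destination nodes, while a pair is taken as-is (useful for bipartite graphs whose two node sets have different feature sizes).

```python
def split_in_feats(in_feats):
    """Hypothetical stand-in for expand_as_pair applied to a dimension.

    Returns a (source_dim, destination_dim) tuple.
    """
    if isinstance(in_feats, tuple):
        # Bipartite case: source and destination nodes have different sizes.
        src_dim, dst_dim = in_feats
        return src_dim, dst_dim
    # Homogeneous case: both sides share one feature size.
    return in_feats, in_feats


print(split_in_feats(16))       # (16, 16)
print(split_in_feats((32, 8)))  # (32, 8)
```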

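The aggregation type determines how the messages on different edges are aggregated for a given destination node. Commonly used aggregation types include `mean`, `sum`, `max`, and `min`; more complicated ones such as `lstm` can also be used. The `SAGEConv` implementation in this chapter accepts `mean`, `pool`, `lstm`, and `gcn`.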
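
A plain-Python sketch of what the fixed reducers do, assuming scalar node features and a graph given as a hypothetical list of `(src, dst)` edges (real DGL modules operate on feature tensors instead): each destination node collects the messages from its in-edges into a mailbox and reduces them to one value. `lstm` is left out because it learns the reduction rather than applying a fixed function.

```python
from collections import defaultdict


def aggregate(edges, src_feat, reducer):
    """Group messages by destination node, then reduce each group.

    edges: list of (src, dst) pairs; src_feat: dict mapping node -> float.
    """
    mailbox = defaultdict(list)
    for src, dst in edges:
        # The message on each edge is simply the source node's feature.
        mailbox[dst].append(src_feat[src])
    return {dst: reducer(msgs) for dst, msgs in mailbox.items()}


REDUCERS = {
    "mean": lambda msgs: sum(msgs) / len(msgs),
    "sum": sum,
    "max": max,
    "min": min,
}

edges = [(0, 2), (1, 2), (3, 2), (2, 0)]
feat = {0: 1.0, 1: 4.0, 2: 10.0, 3: 7.0}
for name, fn in REDUCERS.items():
    # Node 2 receives [1.0, 4.0, 7.0]; node 0 receives [10.0].
    print(name, aggregate(edges, feat, fn))
```

The choice of reducer only changes how a mailbox is collapsed; the message collection step is identical for all of them.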

`norm` is an optional callable used for feature normalization.
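
As an example, assuming L2 normalization (as used in the GraphSAGE paper) is the desired behavior and that the module applies `norm` to a `(num_nodes, num_feats)` tensor of node features, `norm` could be any function that returns a tensor of the same shape. The `l2_norm` helper below is hypothetical and built on `torch.nn.functional.normalize`; it could then be passed as `SAGEConv(..., norm=l2_norm)`.

```python
import torch
import torch.nn.functional as F


def l2_norm(h):
    # Hypothetical norm callable: rescale each node's feature vector
    # (one row of h) to unit L2 length.
    return F.normalize(h, p=2, dim=-1)


h = torch.tensor([[3.0, 4.0], [1.0, 0.0]])
out = l2_norm(h)
print(out)               # rows become [0.6, 0.8] and [1.0, 0.0]
print(out.norm(dim=-1))  # every row now has norm 1
```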

## 3.2 DGL NN Module Forward Function

In an NN module, the `forward()` function performs the actual message passing and computation. Whereas a PyTorch NN module usually takes only tensors as parameters, a DGL NN module takes an additional parameter: a `dgl.DGLGraph`.
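
To make the signature difference concrete, here is a minimal sketch that does not depend on DGL. `ToyGraph` is a hypothetical stand-in for `dgl.DGLGraph` (just an edge list and a node count), and `MeanNeighborConv` is a hypothetical module, not part of DGL: it still subclasses `nn.Module`, but its `forward()` receives the graph alongside the feature tensor and uses the graph's edges to average each node's in-neighbor features.

```python
from typing import NamedTuple

import torch
import torch.nn as nn


class ToyGraph(NamedTuple):
    """Hypothetical stand-in for dgl.DGLGraph: an edge list plus node count."""
    src: torch.Tensor  # source node ID of each edge
    dst: torch.Tensor  # destination node ID of each edge
    num_nodes: int


class MeanNeighborConv(nn.Module):
    """Hypothetical layer: average in-neighbor features, then apply a linear map."""

    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, graph, feat):
        # Unlike a plain nn.Module, the graph arrives as an extra argument.
        summed = torch.zeros(graph.num_nodes, feat.shape[1])
        summed.index_add_(0, graph.dst, feat[graph.src])
        # In-degree of every node; clamp avoids division by zero.
        deg = torch.zeros(graph.num_nodes).index_add_(
            0, graph.dst, torch.ones(len(graph.dst)))
        mean = summed / deg.clamp(min=1).unsqueeze(-1)
        return self.linear(mean)


g = ToyGraph(src=torch.tensor([0, 1, 2]), dst=torch.tensor([1, 2, 0]), num_nodes=3)
conv = MeanNeighborConv(4, 2)
out = conv(g, torch.randn(3, 4))
print(out.shape)  # torch.Size([3, 2])
```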

The work done by the `forward()` function can be split into three parts:
 - graph checking and graph type specification
 - message passing
 - feature update
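
The three parts can be sketched in plain Python on scalar features. Everything here is hypothetical and for illustration only: the graph is a list of `(src, dst)` edges, the zero-in-degree check is just one example of graph checking (chosen because the mean of an empty mailbox is undefined), and the update is a simple weighted sum with assumed weights `w_self` and `w_neigh`.

```python
def forward_sketch(edges, num_nodes, feat, w_self=1.0, w_neigh=1.0):
    """Hypothetical scalar-feature forward pass split into three stages."""
    # 1. Graph checking: every node needs at least one in-edge,
    #    otherwise its neighbor mean would be undefined.
    in_deg = [0] * num_nodes
    for _, dst in edges:
        in_deg[dst] += 1
    if any(d == 0 for d in in_deg):
        raise ValueError("graph contains nodes with zero in-degree")

    # 2. Message passing: mean of in-neighbor features per node.
    neigh_sum = [0.0] * num_nodes
    for src, dst in edges:
        neigh_sum[dst] += feat[src]
    h_neigh = [s / d for s, d in zip(neigh_sum, in_deg)]

    # 3. Feature update: combine each node's own feature with the aggregate.
    return [w_self * h + w_neigh * n for h, n in zip(feat, h_neigh)]


edges = [(0, 1), (1, 2), (2, 0), (0, 2)]
print(forward_sketch(edges, 3, [1.0, 2.0, 3.0]))  # [4.0, 3.0, 4.5]
```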

class SAGEConv(nn.Module):
    def __init__(self, in_feats, out_feats, aggregator_type, bias=True, norm=None, activation=None):
        super(SAGEConv, self).__init__()
        
        # in_feats may be one int or a (source, destination) pair of ints.
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
        
        if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
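        if aggregator_type in ['mean', 'pool', 'lstm']:
            # fc_self transforms each destination node's own features,
            # so its input size is the destination dimension.
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)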
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()
    
    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
        
    def forward(self, graph, feat):
        with graph.local_scope():
            feat_src, feat_dst = expand_as_pair(feat, graph)