# Lab 04 : Gated graph convolutional networks for chemical regression - exercise

Bresson, Laurent, Residual Gated Graph ConvNets, 2017  
https://arxiv.org/pdf/1711.07553v2.pdf

### Xavier Bresson

<br>
Notebook goals :<br>  
• Implement a two-layer MLP for regression <br> 
• Define the Mean Absolute Error (MAE) as regression loss <br> 


In [None]:
# Google Colaboratory setup: mounts Google Drive, moves into the lab folder and
# installs DGL. Nothing below the `if` runs unless `google.colab` is already loaded.
import sys, os
if 'google.colab' in sys.modules:
    # mount google drive under /content/gdrive
    from google.colab import drive
    drive.mount('/content/gdrive')
    # NOTE(review): hard-coded Drive location of the lab folder — adjust if the codes live elsewhere
    path_to_file = '/content/gdrive/My Drive/GML_May23_codes/codes/05_Graph_Convnets'
    print(path_to_file)
    # make "path_to_file" the current working directory so relative paths resolve from it
    os.chdir(path_to_file)
    !pwd
    !pip install dgl # Install DGL
    

In [None]:
# Libraries
import pickle
from utils import Dictionary, MoleculeDataset, MoleculeDGL, Molecule
from torch.utils.data import DataLoader
import dgl
import torch
import torch.nn as nn
import time


# Load molecular datasets

In [None]:
# Select dataset: both configurations are assigned in turn, so the last one (QM9) is
# the one in effect. Comment out the QM9 lines to fall back to ZINC.
dataset_name = 'ZINC'
data_folder_pytorch = '../08_Datasets/ZINC_pytorch/'
data_folder_dgl = '../08_Datasets/ZINC_dgl/'
dataset_name = 'QM9'
data_folder_pytorch = '../08_Datasets/QM9_pytorch/'
data_folder_dgl = '../08_Datasets/QM9_dgl/'

# The number of atom and bond types is the length of each pickled dictionary
with open(data_folder_pytorch + "atom_dict.pkl", "rb") as f:
    num_atom_type = len(pickle.load(f))
with open(data_folder_pytorch + "bond_dict.pkl", "rb") as f:
    num_bond_type = len(pickle.load(f))
print(num_atom_type)
print(num_bond_type)

# Load the DGL datasets and pull out the three splits
datasets_dgl = MoleculeDataset(dataset_name, data_folder_dgl)
trainset = datasets_dgl.train
valset = datasets_dgl.val
testset = datasets_dgl.test
print(len(trainset))
print(len(valset))
print(len(testset))

# Peek at a few samples of each split
idx = 0
print(trainset[:2])
print(valset[idx])
print(testset[idx])



# Generate a batch of graphs and test it

In [None]:
def _graph_size_norm(sizes):
    """Return a column of ``1 / sqrt(size)`` factors, one row per element.

    For every entry ``size`` of ``sizes`` the result contains ``size`` rows
    holding ``1 / sqrt(size)``; the per-graph pieces are concatenated in order.
    A size of 0 contributes zero rows instead of raising ``ZeroDivisionError``
    (e.g. a graph without any edge).

    Args:
        sizes: iterable of non-negative ints (node or edge count per graph).

    Returns:
        float32 tensor of shape ``(sum(sizes), 1)``.
    """
    pieces = [
        # The fill value is irrelevant when size == 0 because no row is created.
        torch.full((size, 1), (1.0 / size) if size else 0.0, dtype=torch.float32)
        for size in sizes
    ]
    return torch.cat(pieces).sqrt()


# collate function prepares a batch of graphs, labels and other graph features
def collate(samples):
    """Merge a list of ``(graph, label)`` pairs into one batch.

    Args:
        samples: list of pairs ``(graph, label)``; graphs are batched with
            ``dgl.batch`` and labels are stacked with ``torch.stack``.

    Returns:
        Tuple ``(batch_graphs, batch_labels, batch_norm_n, batch_norm_e)`` where
        ``batch_norm_n`` / ``batch_norm_e`` hold ``1 / sqrt(#nodes)`` /
        ``1 / sqrt(#edges)`` of the owning graph, one row per node / per edge.
    """
    # Input sample is a list of pairs (graph, label)
    graphs, labels = map(list, zip(*samples))
    batch_graphs = dgl.batch(graphs)    # batch of graphs
    batch_labels = torch.stack(labels)  # batch of labels (regression targets)

    # Normalization w.r.t. graph sizes
    batch_norm_n = _graph_size_norm(graph.number_of_nodes() for graph in graphs)
    batch_norm_e = _graph_size_norm(graph.number_of_edges() for graph in graphs)

    return batch_graphs, batch_labels, batch_norm_n, batch_norm_e


# Generate a batch of graphs
batch_size = 10
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=collate)
# Fetch only the first batch: list(train_loader)[0] would collate every batch of
# the training set before discarding all but one.
batch_graphs, batch_labels, batch_norm_n, batch_norm_e = next(iter(train_loader))
print(batch_graphs)
print(batch_labels)
print('batch_norm_n:',batch_norm_n.size())
print('batch_norm_e:',batch_norm_e.size())
# Node and edge input features are stored on the batched graph under the key 'feat'
batch_x = batch_graphs.ndata['feat']
print('batch_x:',batch_x.size())
batch_e = batch_graphs.edata['feat']
print('batch_e:',batch_e.size())


# Design the class of GatedGCN networks with DGL

Node and edge update equations:  
\begin{eqnarray}
h_i^{\ell+1} &=& h_i^{\ell} + \text{ReLU} \left( A^\ell h_i^{\ell} +  \sum_{j\sim i} \eta(e_{ij}^{\ell}) \odot B^\ell h_j^{\ell} \right), \quad \eta(e_{ij}^{\ell}) = \frac{\sigma(e_{ij}^{\ell})}{\sum_{j'\sim i} \sigma(e_{ij'}^{\ell}) + \varepsilon} \\
e_{ij}^{\ell+1} &=& e^\ell_{ij} + \text{ReLU} \Big( C^\ell e_{ij}^{\ell} + D^\ell h^{\ell+1}_i + E^\ell h^{\ell+1}_j  \Big)
\end{eqnarray}

In [None]:
# Define a two-layer MLP for regression 
# Exercise skeleton: the three "### YOUR CODE HERE" placeholders must be filled in
# before this cell can run. GatedGCN_net builds it as MLP_layer(hidden_dim, hidden_dim)
# and applies it to the per-graph mean of the node features.
class MLP_layer(nn.Module): 
    
    # input_dim: size of the incoming feature vector; hidden_dim: size of the hidden layer
    def __init__(self, input_dim, hidden_dim): 
        super(MLP_layer, self).__init__()
        # You may use "nn.Linear"
        # presumably linear1: input_dim -> hidden_dim and linear2: hidden_dim -> 1
        # (one regression value per graph) — confirm against the label shape
        self.linear1 = ### YOUR CODE HERE
        self.linear2 = ### YOUR CODE HERE
        
    # x: input features; returns the regression output y
    def forward(self, x):
        # presumably linear1, a non-linearity, then linear2 — left to the exercise
        y = ### YOUR CODE HERE
        return y


# class of GatedGCN layer
class GatedGCN_layer(nn.Module):
    """One residual gated graph convolution layer (Bresson & Laurent, 2017).

    With i the destination node and j the source node of an edge, ``forward``
    computes::

        e_hat_ij = C e_ij + D h_i + E h_j
        eta_ij   = sigmoid(e_hat_ij) / (sum_j' sigmoid(e_hat_ij') + eps)
        h_hat_i  = A h_i + sum_j eta_ij * B h_j
        h_out    = h_in + ReLU(BatchNorm(snorm_n * h_hat))
        e_out    = e_in + ReLU(BatchNorm(snorm_e * e_hat))

    The residual additions add the layer input to the layer output, so the
    input and output feature sizes have to be compatible.

    Args:
        input_dim: feature size of the incoming node and edge features.
        output_dim: feature size produced by the linear maps A..E.
        eps: small constant added to the gate normaliser so that the division
            stays finite when every incoming gate underflows to zero.
    """

    def __init__(self, input_dim, output_dim, eps=1e-6):
        super().__init__()
        self.A = nn.Linear(input_dim, output_dim, bias=True)
        self.B = nn.Linear(input_dim, output_dim, bias=True)
        self.C = nn.Linear(input_dim, output_dim, bias=True)
        self.D = nn.Linear(input_dim, output_dim, bias=True)
        self.E = nn.Linear(input_dim, output_dim, bias=True)
        self.bn_node_h = nn.BatchNorm1d(output_dim)
        self.bn_node_e = nn.BatchNorm1d(output_dim)
        self.eps = eps

    # Step 1 of message-passing with DGL:
    #   Node feature and edge features are passed along edges (src/j => dst/i)
    def message_func(self, edges):
        """Build the message sent along each edge and store the new edge value.

        Reads ``'Bh'`` / ``'Eh'`` from the source nodes, ``'Dh'`` from the
        destination nodes and ``'Ce'`` from the edges; writes the updated edge
        feature to ``edges.data['e']``.

        Returns:
            dict with ``'Bhj'`` and ``'eij'``, delivered to the mailbox of dst/i.
        """
        Bhj = edges.src['Bh']  # Bhj with j/src
        eij = edges.data['Ce'] + edges.dst['Dh'] + edges.src['Eh']  # Ceij + Dhi + Ehj with dst/i, src/j
        edges.data['e'] = eij  # update edge feature value
        return {'Bhj': Bhj, 'eij': eij}  # send message={Bhj, eij} to node dst/i

    # Step 2 of message-passing with DGL:
    #   Reduce function collects all messages={Bhj, eij} sent to node dst/i with Step 1
    def reduce_func(self, nodes):
        """Aggregate the incoming messages of each node with normalised gates.

        Mailbox tensors are reduced over ``dim=1``, the axis that enumerates the
        messages received by a node.

        Returns:
            dict with the updated node feature under ``'h'``.
        """
        Ahi = nodes.data['Ah']
        Bhj = nodes.mailbox['Bhj']
        e = nodes.mailbox['eij']
        sigmaij = torch.sigmoid(e)  # sigma_ij = sigmoid(e_ij)
        # eps keeps the ratio finite (0 instead of 0/0 = NaN) when all gates are ~0
        gate_norm = torch.sum(sigmaij, dim=1) + self.eps
        h = Ahi + torch.sum(sigmaij * Bhj, dim=1) / gate_norm  # hi = Ahi + sum_j eta_ij * Bhj
        return {'h': h}  # return update node feature hi

    def forward(self, g, h, e, snorm_n, snorm_e):
        """Apply the layer to graph ``g``.

        Args:
            g: DGL graph; its ``ndata`` / ``edata`` are used as scratch storage
                (keys ``'h'``, ``'Ah'``, ``'Bh'``, ``'Dh'``, ``'Eh'``, ``'e'``, ``'Ce'``).
            h: node features, one row per node of ``g``.
            e: edge features, one row per edge of ``g``.
            snorm_n: multiplicative node normalisation, broadcast against ``h``.
            snorm_e: multiplicative edge normalisation, broadcast against ``e``.

        Returns:
            Tuple ``(h, e)`` of updated node and edge features.
        """
        h_in = h  # residual connection
        e_in = e  # residual connection

        g.ndata['h'] = h
        g.ndata['Ah'] = self.A(h)  # linear transformation
        g.ndata['Bh'] = self.B(h)  # linear transformation
        g.ndata['Dh'] = self.D(h)  # linear transformation
        g.ndata['Eh'] = self.E(h)  # linear transformation
        g.edata['e'] = e
        g.edata['Ce'] = self.C(e)  # linear transformation

        g.update_all(self.message_func, self.reduce_func)  # update the node and edge features with DGL

        h = g.ndata['h']  # collect the node output of graph convolution
        e = g.edata['e']  # collect the edge output of graph convolution

        h = h * snorm_n  # normalize activation w.r.t. graph node size
        e = e * snorm_e  # normalize activation w.r.t. graph edge size

        h = self.bn_node_h(h)  # batch normalization
        e = self.bn_node_e(e)  # batch normalization

        h = torch.relu(h)  # non-linear activation
        e = torch.relu(e)  # non-linear activation

        h = h_in + h  # residual connection
        e = e_in + e  # residual connection

        return h, e
    
    
# Graph-level regression network: embeds atom/bond types, applies L GatedGCN layers,
# averages the node features of each graph and feeds the result to the MLP head.
# Exercise skeleton: `loss` still contains a "### YOUR CODE HERE" placeholder.
class GatedGCN_net(nn.Module):
    
    # net_parameters: dict with keys 'input_dim', 'hidden_dim' and 'L' (number of layers).
    # The embedding sizes come from the notebook-level num_atom_type / num_bond_type.
    def __init__(self, net_parameters):
        super(GatedGCN_net, self).__init__()
        input_dim = net_parameters['input_dim'] # read but not used below
        hidden_dim = net_parameters['hidden_dim']
        L = net_parameters['L']
        self.embedding_h = nn.Embedding(num_atom_type, hidden_dim)
        self.embedding_e = nn.Embedding(num_bond_type, hidden_dim)
        self.GatedGCN_layers = nn.ModuleList([ GatedGCN_layer(hidden_dim, hidden_dim) for _ in range(L) ]) 
        self.MLP_layer = MLP_layer(hidden_dim, hidden_dim)
        
    # g: (batched) DGL graph; h / e: node / edge inputs passed to the embedding tables
    # (presumably integer atom and bond type indices — confirm against the dataset);
    # snorm_n / snorm_e: normalisation factors forwarded to every GatedGCN layer.
    # Returns the output of the MLP head applied to the per-graph mean node feature.
    def forward(self, g, h, e, snorm_n, snorm_e):
        
        # input embedding
        h = self.embedding_h(h)
        e = self.embedding_e(e)
        
        # graph convnet layers
        for GGCN_layer in self.GatedGCN_layers:
            h,e = GGCN_layer(g,h,e,snorm_n,snorm_e)
        
        # graph readout followed by the MLP regression head
        g.ndata['h'] = h
        y = dgl.mean_nodes(g,'h') # DGL readout: mean of the node features of each graph in the batch
        y = self.MLP_layer(y)
        
        return y    
    
    # y_scores: network predictions; y_labels: regression targets. Returns the loss value.
    def loss(self, y_scores, y_labels):
        # Define the Mean Absolute Error (MAE) as regression loss 
        loss = ### YOUR CODE HERE
        return loss        
    
    # Returns a new Adam optimizer over this network's parameters with learning rate lr.
    def update(self, lr):       
        update = torch.optim.Adam( self.parameters(), lr=lr )
        return update


# Instantiate one network (testing): build it from a small configuration and
# print the module tree.
net_parameters = dict(input_dim=1, hidden_dim=128, L=4)
net = GatedGCN_net(net_parameters)
print(net)


# Train the network

In [None]:
def run_one_epoch(net, data_loader, train=True, optimizer=None):
    """Run one pass over ``data_loader`` and return the mean per-batch loss.

    Args:
        net: network exposing ``__call__(graphs, x, e, snorm_n, snorm_e)`` and
            ``loss(scores, labels)``.
        data_loader: iterable of ``(batch_graphs, batch_labels, batch_snorm_n,
            batch_snorm_e)`` tuples; node and edge inputs are read from
            ``batch_graphs.ndata['feat']`` and ``batch_graphs.edata['feat']``.
        train: if True, switch the network to training mode and take one
            optimizer step per batch; otherwise switch to evaluation mode and
            only compute the loss.
        optimizer: optimizer used when ``train`` is True. If omitted, the
            module-level variable named ``optimizer`` is used, as before.

    Returns:
        Average of the batch losses as a float, or ``nan`` when the loader
        yields no batch.

    Raises:
        NameError: if ``train`` is True and no optimizer is available.
    """
    if train:
        if optimizer is None:
            # Backward compatible fallback on the notebook-level optimizer.
            try:
                optimizer = globals()['optimizer']
            except KeyError:
                raise NameError("name 'optimizer' is not defined") from None
        net.train()  # during training
    else:
        net.eval()  # during inference/test

    total_loss = 0.0
    num_batches = 0
    for batch_graphs, batch_labels, batch_snorm_n, batch_snorm_e in data_loader:
        batch_x = batch_graphs.ndata['feat']
        batch_e = batch_graphs.edata['feat']
        # Call the module itself rather than .forward so that module hooks run.
        batch_scores = net(batch_graphs, batch_x, batch_e, batch_snorm_n, batch_snorm_e)
        loss = net.loss(batch_scores, batch_labels)
        if train:  # during training, run backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # An empty loader has no defined mean loss; do not crash, do not pretend it is 0.
        return float('nan')
    return total_loss / num_batches


# dataset loaders
# NOTE(review): these loaders use datasets_dgl.collate, not the collate function
# defined earlier in this notebook — confirm it returns the same 4-tuple.
batch_size = 100
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=datasets_dgl.collate)
# Evaluation loaders are not shuffled: run_one_epoch averages per-batch losses, so with
# a smaller last batch a reshuffle would change the reported loss from epoch to epoch.
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, collate_fn=datasets_dgl.collate)
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False, collate_fn=datasets_dgl.collate)

# Instantiate one network
net_parameters = {}
net_parameters['input_dim'] = 1
net_parameters['hidden_dim'] = 100
net_parameters['L'] = 4
net = GatedGCN_net(net_parameters)

# optimizer (run_one_epoch picks it up through this module-level name)
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)

# training loop: one training pass, then gradient-free evaluation on test and validation
for epoch in range(50):
    start = time.time()
    epoch_train_loss = run_one_epoch(net, train_loader, True)
    with torch.no_grad():
        epoch_test_loss = run_one_epoch(net, test_loader, False)
        epoch_val_loss = run_one_epoch(net, val_loader, False)
    # the reported time covers training plus both evaluation passes
    print('Epoch {}, time {:.4f}, train_loss: {:.4f}, test_loss: {:.4f}, val_loss: {:.4f}'.format(epoch, time.time()-start, epoch_train_loss, epoch_test_loss, epoch_val_loss))
   