In [5]:
import torch 
from torch_geometric.datasets import QM9
from torch.nn.parameter import Parameter
import numpy as np
import math
import torch.nn.functional as F

In [7]:
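# Graph convolution layer: one convolution step that combines each node's own (root) features with messages aggregated from its neighbors, using two separate weight matrices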
class BasicGraphConvolutionLayer(torch.nn.Module):
    """A single graph-convolution step with separate root and neighbor weights.

    Given a node-feature matrix ``X`` and an adjacency matrix ``A`` the
    layer computes::

        A @ (X @ w2) + X @ w1 + bias

    ``w1`` transforms each node's own features; ``w2`` transforms the
    features that are then propagated along ``A``.

    Args:
        in_chanels: number of input features per node (columns of ``X``).
        out_channels: number of output features per node.
    """

    def __init__(self, in_chanels, out_channels):
        super().__init__()
        # The "chanels" spelling of the public attributes is preserved on
        # purpose: other code may read these names.
        self.in_chanels = in_chanels
        self.out_chanels = out_channels
        weight_shape = (in_chanels, out_channels)
        # Weights are drawn from U[0, 1) by torch.rand. w1 is created before
        # w2 so that a seeded run produces the same weights as before.
        self.w1 = Parameter(torch.rand(*weight_shape, dtype=torch.float32))
        self.w2 = Parameter(torch.rand(*weight_shape, dtype=torch.float32))
        self.bias = Parameter(torch.zeros(out_channels, dtype=torch.float32))

    def forward(self, X, A):
        """Apply the convolution.

        Args:
            X: 2-D node-feature matrix whose column count equals ``in_chanels``.
            A: 2-D adjacency matrix whose column count equals the number of
                rows of ``X``.

        Returns:
            2-D tensor with ``out_chanels`` columns: propagated neighbor
            messages plus the root update plus the bias.
        """
        self_term = torch.mm(X, self.w1)
        neighbor_term = torch.mm(A, torch.mm(X, self.w2))
        return neighbor_term + self_term + self.bias

In [None]:
# Global sum pooling: sum the node features of each graph into one feature row per graph
def global_sum_pool(X, batch_mat):
    """Sum node features into one row per graph.

    Args:
        X: node-feature matrix, one row per node.
        batch_mat: ``None``, a 1-D tensor, or a 2-D graph-selection matrix
            (one row per graph, one column per node), such as the output of
            ``get_batch_tensor``.

    Returns:
        If ``batch_mat`` is ``None`` or 1-D, all nodes are treated as one
        graph and the result has a single row (``X`` summed over dim 0).
        Otherwise the result is ``batch_mat @ X``, one row per graph.

    NOTE(review): a 1-D tensor is never used for grouping. A PyG-style batch
    index vector that covers several graphs is collapsed into a single row.
    """
    has_selection_matrix = batch_mat is not None and batch_mat.dim() != 1
    if has_selection_matrix:
        return torch.mm(batch_mat, X)
    return X.sum(dim=0, keepdim=True)

In [8]:
# This function takes the sizes of the graphs in a batch and returns a graph-selection mask.
# Multiplying this mask with the stacked node-feature matrix sums the features of each graph separately.
def get_batch_tensor(graph_sizes):
    """Build a graph-selection matrix for a batch of graphs stacked node-wise.

    Row ``i`` has ones in the columns that belong to the ``i``-th graph, where
    the graphs' nodes are laid out consecutively in ``graph_sizes`` order.
    ``selection @ X`` therefore sums the node features of each graph.

    Args:
        graph_sizes: iterable of per-graph node counts. Any iterable is
            accepted; it is materialized once.

    Returns:
        Tensor of shape ``(len(graph_sizes), sum(graph_sizes))`` with the
        default floating dtype, containing 0/1 entries. An empty input yields
        a ``(0, 0)`` tensor.
    """
    sizes = list(graph_sizes)
    batch_mat = torch.zeros([len(sizes), sum(sizes)])
    # A single running offset replaces the per-graph prefix sums, which were
    # O(num_graphs**2).
    start = 0
    for graph_idx, size in enumerate(sizes):
        stop = start + size
        batch_mat[graph_idx, start:stop] = 1
        start = stop
    return batch_mat

In [9]:
def collate_grap(batch):
    """Merge a list of graph samples into one block-diagonal batch.

    Each sample is indexed with the keys ``'A'`` (adjacency; its first
    dimension is taken as the node count), ``'X'`` (node features) and
    ``'y'`` (labels). Presumably ``A`` is square and ``X`` has one row per
    node; TODO confirm against the dataset code.

    Returns:
        dict with
        - ``'A'``: float32 block-diagonal adjacency of all graphs,
        - ``'X'``: node features concatenated along dim 0,
        - ``'y'``: labels concatenated along dim 0,
        - ``'batch'``: graph-selection matrix from ``get_batch_tensor``.
    """
    adjacency_list = [sample['A'] for sample in batch]
    node_counts = [adj.size(0) for adj in adjacency_list]
    num_nodes = sum(node_counts)
    membership = get_batch_tensor(node_counts)
    features = torch.cat([sample['X'] for sample in batch], dim=0)
    targets = torch.cat([sample['y'] for sample in batch], dim=0)

    # NOTE(review): allocated with torch.zeros' default device, so adjacencies
    # on another device are copied here. Confirm this is intended.
    block_adj = torch.zeros([num_nodes, num_nodes], dtype=torch.float32)
    offset = 0
    for adj in adjacency_list:
        end = offset + adj.shape[0]
        block_adj[offset:end, offset:end] = adj
        offset = end

    return {'A': block_adj, 'X': features, 'y': targets, 'batch': membership}

In [None]:
class NodeNetwork(torch.nn.Module):
    """Graph-level classifier.

    The network applies two graph convolutions (each followed by ReLU), then
    global sum pooling, then two linear layers and a softmax over classes.

    Args:
        input_features: number of input features per node.
        hidden_channels: width of both graph-convolution layers (default 32,
            the originally hard-coded value).
        fc_features: width of the hidden fully connected layer (default 16).
        num_classes: number of output classes (default 2).
    """

    def __init__(self, input_features, hidden_channels=32, fc_features=16, num_classes=2):
        super().__init__()
        self.conv_1 = BasicGraphConvolutionLayer(input_features, hidden_channels)
        self.conv_2 = BasicGraphConvolutionLayer(hidden_channels, hidden_channels)
        self.fc_1 = torch.nn.Linear(hidden_channels, fc_features)
        self.out_layer = torch.nn.Linear(fc_features, num_classes)

    def forward(self, X, A, batch_mat):
        """Compute per-graph class probabilities.

        The method must be named ``forward``: ``torch.nn.Module.__call__``
        dispatches to it. The original name ``forwrd`` made ``model(...)``
        raise NotImplementedError.

        Args:
            X: stacked node-feature matrix.
            A: (block-diagonal) adjacency matrix matching ``X``.
            batch_mat: graph-selection matrix or ``None``; see ``global_sum_pool``.

        Returns:
            Tensor with one row per graph and ``num_classes`` columns, softmax
            normalized along dim 1.
        """
        x = F.relu(self.conv_1(X, A))
        x = F.relu(self.conv_2(x, A))
        output = global_sum_pool(x, batch_mat)
        # NOTE(review): there is no non-linearity between fc_1 and out_layer,
        # so the two layers compose to a single affine map. Kept as-is to
        # preserve the model's behavior.
        output = self.fc_1(output)
        output = self.out_layer(output)
        # Returns probabilities, not logits. If training uses
        # torch.nn.CrossEntropyLoss (which applies log-softmax itself), softmax
        # would effectively be applied twice. TODO confirm against the training loop.
        return F.softmax(output, dim=1)

    # Backward-compatible alias for the original misspelled method name.
    forwrd = forward