In [2]:
import torch 
from torch import nn 
from torch.nn import Sequential, Linear, ReLU
from torch.nn import Embedding, Module, ModuleList

# Shared dtypes used when building batched tensors in this notebook.
# Floating-point tensors:
float_type = torch.float32         # adjacency matrices
mask_type = torch.float32          # 0/1 padding masks (float so they can scale features)
labels_type = torch.float32        # graph-level labels
# Integer tensors:
categorical_type = torch.long      # categorical node/edge feature codes

In [4]:
from collections.abc import Set

# get the node and edge features (stacked) from mini-batches of graphs 
def collate_graph_batch(batch):
    """Pad and stack a list of graph records into one dense batch.

    Every graph is zero-padded to the largest node count in the batch
    (``max(len(graph['nodes']))``); the returned masks mark the real entries.

    Args:
        batch: non-empty sequence of dict-like graph records. Each record
            must provide:
              * ``'nodes'``: sized collection of the graph's nodes,
              * ``'label'``: scalar label,
              * ``'adjacency_matrix'``: square ``[n, n]`` tensor
                (assumed to satisfy ``n == len(graph['nodes'])``),
              * ``'continuous_node_features'`` / ``'categorical_node_features'``:
                ``[n, f]`` tensor or ``None``,
              * ``'continuous_edge_features'`` / ``'categorical_edge_features'``:
                ``[n, n, f]`` tensor or ``None``.
            ``'metadata'`` is optional.

    Returns:
        dict with keys ``'adjacency_matrices'`` ``[B, N, N]``,
        the four stacked feature entries (``None`` when no graph in the batch
        carries that feature), ``'nodes_mask'`` ``[B, N]``, ``'edge_mask'``
        ``[B, N, N]`` and ``'labels'`` ``[B]``. If at least one graph has
        ``'metadata'``, a ``'metadata'`` list aligned with ``batch`` is added,
        holding ``None`` for graphs without it.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    batch_size = len(batch)
    if batch_size == 0:
        raise ValueError("collate_graph_batch() needs at least one graph in the batch")

    max_nodes = max(len(graph['nodes']) for graph in batch)

    adjacency_matrices = torch.zeros((batch_size, max_nodes, max_nodes), dtype=float_type)
    nodes_mask = torch.zeros((batch_size, max_nodes), dtype=mask_type)
    edges_mask = torch.zeros((batch_size, max_nodes, max_nodes), dtype=mask_type)
    labels = torch.tensor([graph['label'] for graph in batch], dtype=labels_type)

    for i, graph in enumerate(batch):
        adjacency_matrix = graph['adjacency_matrix']
        num_nodes = adjacency_matrix.shape[0]
        adjacency_matrices[i, :num_nodes, :num_nodes] = adjacency_matrix
        # Only the top-left num_nodes block is real data; the rest is padding.
        nodes_mask[i, :num_nodes] = 1
        edges_mask[i, :num_nodes, :num_nodes] = 1

    # Built once, after the loop (key order kept as before).
    batch_record = {
        'adjacency_matrices': adjacency_matrices,
        'categorical_node_features': _stack_padded_features(
            batch, 'categorical_node_features', max_nodes, categorical_type),
        'continuous_node_features': _stack_padded_features(
            batch, 'continuous_node_features', max_nodes, float_type),
        'categorical_edge_features': _stack_padded_features(
            batch, 'categorical_edge_features', max_nodes, categorical_type, pairwise=True),
        'continuous_edge_features': _stack_padded_features(
            batch, 'continuous_edge_features', max_nodes, float_type, pairwise=True),
        'nodes_mask': nodes_mask,
        'edge_mask': edges_mask,
        'labels': labels,
    }

    if any('metadata' in graph for graph in batch):
        # Keep the list aligned with the batch even when only some graphs have metadata.
        batch_record['metadata'] = [g['metadata'] if 'metadata' in g else None for g in batch]

    return batch_record


def _stack_padded_features(batch, key, max_nodes, dtype, pairwise=False):
    """Stack ``graph[key]`` over the batch into one zero-padded tensor, or return None if absent everywhere.

    Node features (``pairwise=False``) are padded to ``[B, max_nodes, f]``;
    edge features (``pairwise=True``) to ``[B, max_nodes, max_nodes, f]``.
    Graphs whose entry is ``None`` keep an all-zero slice.
    """
    stacked = None
    for i, graph in enumerate(batch):
        features = graph[key]
        if features is None:
            continue
        if stacked is None:
            # Allocate lazily: the feature width is only known from the first real tensor.
            node_dims = (max_nodes, max_nodes) if pairwise else (max_nodes,)
            stacked = torch.zeros((len(batch), *node_dims, features.shape[-1]), dtype=dtype)
        num_nodes = features.shape[0]
        if pairwise:
            stacked[i, :num_nodes, :num_nodes] = features
        else:
            stacked[i, :num_nodes] = features
    return stacked

In [None]:
# Next, apply the embedder and FeatureCombiner (from the previous notebook) to the batched node and edge data.

'''
# masked node features

node_featurizer = FeatureCombiner(dataset.categorical_node_variables, node_embedding_dim)
categorical_node_features = example_batch['categorical_node_features']
continuous_node_features = example_batch['continuous_node_features']
node_features = node_featurizer(continuous_node_features, categorical_node_features)
masked_node_features = node_features * example_batch['nodes_mask'].unsqueeze(dim=-1)

# masked edge features 

edge_featurizer = FeatureCombiner(dataset.categorical_edge_variables, edge_embedding_dim)
categorical_edge_features = example_batch['categorical_edge_features']
continuous_edge_features = example_batch['continuous_edge_features']
edge_features = edge_featurizer(continuous_edge_features, categorical_edge_features)
masked_edge_features = edge_features * example_batch['edge_mask'].unsqueeze(dim=-1)

'''

In [5]:
# batched graph layer

class BasicGraphLayer(Module):
    """One message-passing layer over a dense, zero-padded batch of graphs.

    For every node ``i`` the layer computes::

        out_i = output_mlp(center_mlp(x_i)
                           + sum_j A_ij * neighbour_edges_mlp(e_ij + x_j))

    where ``x`` are node features, ``e`` edge features and ``A`` the adjacency
    matrix (row ``i`` is the receiving node, column ``j`` the neighbour).
    Rows belonging to padded nodes are zeroed in the output.

    Args:
        input_dim: size of the incoming node and edge features. The two are
            added together, so they must share this size.
        output_dim: size of the returned node features.
        ffn_dim: hidden width of each two-layer MLP.
    """

    def __init__(self, input_dim, output_dim, ffn_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.ffn_dim = ffn_dim
        self.neighbour_edges_mlp = Sequential(Linear(input_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))
        self.center_mlp = Sequential(Linear(input_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))
        # output_mlp consumes the aggregated features, which are output_dim wide,
        # so its first layer takes output_dim (not input_dim).
        self.output_mlp = Sequential(Linear(output_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))

    def forward(self, adjacency_matrix, node_features, edge_features, node_mask, edge_mask):
        """Update node features from each node's connected neighbours.

        Args:
            adjacency_matrix: ``[B, N, N]``; entry ``[b, i, j]`` weights the
                message from node ``j`` to node ``i``.
            node_features: ``[B, N, d_in]``.
            edge_features: ``[B, N, N, d_in]``.
            node_mask: ``[B, N]``, 1 for real nodes and 0 for padding.
            edge_mask: ``[B, N, N]``, 1 for real node pairs and 0 for padding.

        Returns:
            ``[B, N, d_out]`` tensor of updated node features, zero at padded nodes.
        """
        center_updated_node_features = self.center_mlp(node_features)  # [B,N,d_in] -> [B,N,d_out]

        # Add neighbour j's features to edge (i, j): insert the broadcast axis at
        # dim 1 so that node_features aligns with the *column* index.
        edge_and_node_features = edge_features + node_features.unsqueeze(dim=1)  # [B,N,N,d_in] + [B,1,N,d_in]
        neighbourhood = self.neighbour_edges_mlp(edge_and_node_features)  # [B,N,N,d_in] -> [B,N,N,d_out]

        # Keep only messages along existing edges, then drop padded node pairs.
        masked_edge_and_node_features = neighbourhood * adjacency_matrix.unsqueeze(dim=-1)  # [B,N,N,d_out] * [B,N,N,1]
        masked_edge_and_node_features = masked_edge_and_node_features * edge_mask.unsqueeze(dim=-1)

        # Sum the messages over the neighbour axis j.
        reduced_neighbourhoods = masked_edge_and_node_features.sum(dim=-2)  # [B,N,N,d_out] -> [B,N,d_out]

        # Combine with the node's own transformed features and apply the output MLP.
        aggregated_neighbourhoods = reduced_neighbourhoods + center_updated_node_features  # [B,N,d_out]
        updated_node_features = self.output_mlp(aggregated_neighbourhoods)  # [B,N,d_out]

        # The MLP biases make padded rows non-zero, so mask them out again.
        masked_updated_features = updated_node_features * node_mask.unsqueeze(dim=-1)  # [B,N,d_out] * [B,N,1]

        return masked_updated_features

In [None]:
'''
batched_graph_layer = BasicGraphLayer(d_model, d_model, ffn_dim)

adjacency_matrix = example_batch['adjacency_matrices']
nodes_mask = example_batch['nodes_mask']
edge_mask = example_batch['edge_mask']
batched_graph_layer(adjacency_matrix, masked_node_features, masked_edge_features, nodes_mask, edge_mask)

'''