In [1]:
import torch 
from torch import nn 
from torch.nn import Sequential, Linear, ReLU
from torch.nn import Embedding, Module, ModuleList

# Shared tensor dtypes. They are not referenced by the cells shown here, so
# their consumers (presumably dataset/featurizer code) live elsewhere — TODO confirm.
categorical_type = torch.long  # integer dtype, presumably for categorical ids
float_type = mask_type = labels_type = torch.float32  # float32 for continuous values, masks and labels

In [2]:
# Graph layer that aggregates neighbours using an adjacency list.
# This class updates node features only.

class AdjacencyListGraphLayer(Module):
    """Message-passing layer that updates node features from an adjacency list.

    For every node ``i`` listed in ``adjacency_list`` the layer computes::

        h_i = output_mlp( sum_{j in N(i)} neighbour_mlp(x_j) + center_mlp(x_i) )

    Edge features are not used; only node features are updated.

    Args:
        input_dim: size of the last dimension of ``node_features``.
        output_dim: size of the last dimension of the returned features.
        ffn_dim: hidden width of each two-layer MLP.
    """

    def __init__(self, input_dim, output_dim, ffn_dim):
        super().__init__()
        self.neighbour_mlp = Sequential(Linear(input_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))
        self.center_mlp = Sequential(Linear(input_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))
        # The output MLP consumes the aggregated messages, which already have
        # ``output_dim`` features, so its input width must be ``output_dim``
        # (previously ``input_dim``, which crashed whenever the two differed).
        self.output_mlp = Sequential(Linear(output_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))

    def forward(self, adjacency_list, node_features):
        """Aggregate neighbour messages for every node in ``adjacency_list``.

        Args:
            adjacency_list: mapping ``node_idx -> iterable of neighbour indices``;
                both are used to index the first dimension of ``node_features``.
            node_features: tensor whose first dimension indexes nodes and whose
                last dimension is ``input_dim`` (presumably ``(N, input_dim)``).

        Returns:
            Tensor of shape ``(len(adjacency_list), output_dim)``. Row ``k``
            corresponds to the ``k``-th key in ``adjacency_list``'s iteration
            order (not necessarily node index ``k``); nodes missing from
            ``adjacency_list`` get no row. A node with no neighbours receives
            only its centre term.
        """
        neighbour_messages = self.neighbour_mlp(node_features)
        center_messages = self.center_mlp(node_features)

        aggregated_neighbourhoods = []
        for node_idx, neighbours in adjacency_list.items():
            index = torch.as_tensor(list(neighbours), dtype=torch.long, device=neighbour_messages.device)
            # An empty index selects a (0, output_dim) tensor whose sum is zeros,
            # so isolated nodes no longer crash torch.stack.
            neighbour_sum = neighbour_messages[index].sum(dim=0)
            aggregated_neighbourhoods.append(neighbour_sum + center_messages[node_idx])

        if aggregated_neighbourhoods:
            aggregated = torch.stack(aggregated_neighbourhoods, dim=0)
        else:
            # torch.stack cannot stack an empty list; use an empty batch instead.
            aggregated = center_messages.new_zeros((0, center_messages.shape[-1]))
        return self.output_mlp(aggregated)

In [None]:
# Usage example for AdjacencyListGraphLayer, kept inside a string literal so
# this cell is a no-op. It relies on names that no executed code in the visible
# cells defines (d_model, ffn_dim, dataset, node_featurizer); define those
# before turning it back into code.
# NOTE(review): the comment says "get the first graph" but the code indexes
# dataset[1]; if the dataset is 0-indexed that is the second graph — confirm.
'''
# initialize the adjacency list graph layer 
adjacency_list_graph_layer = AdjacencyListGraphLayer(d_model, d_model, ffn_dim)


# get the first graph from dataset 
graph = dataset[1]
adjacency_list = graph['adjacency_list']
categorical_node_features = graph['categorical_node_features']
continuous_node_features = graph['continuous_node_features']

# get the node features 
node_features = node_featurizer(continuous_node_features, categorical_node_features)
node_features

# apply graph layer 
adjacency_list_graph_layer(adjacency_list, node_features)
'''

In [4]:
# This layer updates node features using an adjacency matrix.
class AdjacencyMatrixGraphLayer(Module):
    """Message-passing layer that updates node features from an adjacency matrix.

    Computes ``output_mlp(A @ neighbour_mlp(X) + center_mlp(X))``: every node
    sums its neighbours' messages weighted by ``A[i, j]`` and adds its own
    centre term. ``torch.matmul`` broadcasts, so leading batch dimensions on
    ``A`` and ``X`` are supported.

    Args:
        input_dim: size of the last dimension of ``node_features``.
        output_dim: size of the last dimension of the returned features.
        ffn_dim: hidden width of each two-layer MLP.
    """

    def __init__(self, input_dim, output_dim, ffn_dim):
        super().__init__()
        self.neighbour_mlp = Sequential(Linear(input_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))
        self.center_mlp = Sequential(Linear(input_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))
        # Aggregated messages have ``output_dim`` features, so that is the
        # output MLP's input width (previously ``input_dim``, which crashed
        # whenever input_dim != output_dim).
        self.output_mlp = Sequential(Linear(output_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))

    def forward(self, adjacency_matrix, node_features):
        """Update node features by one round of weighted neighbour aggregation.

        Args:
            adjacency_matrix: tensor multiplied on the left of the node messages
                (presumably ``(N, N)`` or batched ``(..., N, N)``). Bool/integer
                matrices are cast to the message dtype.
            node_features: tensor whose last dimension is ``input_dim``
                (presumably ``(N, input_dim)`` or ``(..., N, input_dim)``).

        Returns:
            Tensor with the node layout of the matmul result and last dimension
            ``output_dim``.
        """
        neighbour_messages = self.neighbour_mlp(node_features)
        center_messages = self.center_mlp(node_features)

        # matmul requires matching dtypes; this is a no-op for a float32 matrix.
        weights = adjacency_matrix.to(neighbour_messages.dtype)
        aggregated_neighbourhoods = torch.matmul(weights, neighbour_messages) + center_messages

        return self.output_mlp(aggregated_neighbourhoods)

In [5]:
# Update node features using an adjacency list and edge features (the edge features themselves are not updated).
class AdjacencyListEdgeFeaturesGraphLayer(Module):
    """Message-passing layer that uses an adjacency list plus per-edge features.

    For every node ``i`` in ``adjacency_list``::

        h_i = output_mlp( sum_{j in N(i)} neighbour_edges_mlp([e_ij ; x_j]) + center_mlp(x_i) )

    where ``[e_ij ; x_j]`` concatenates the edge features with the neighbour's
    node features. Only node features are returned; edge features are read
    but not updated.

    Args:
        input_dim: size of the last dimension of ``node_features``.
        output_dim: size of the last dimension of the returned features.
        ffn_dim: hidden width of each two-layer MLP.
        edge_dim: size of the last dimension of ``edge_features``. Defaults to
            ``input_dim``, which reproduces the original ``2 * input_dim``
            input width of ``neighbour_edges_mlp``.
    """

    def __init__(self, input_dim, output_dim, ffn_dim, edge_dim=None):
        super().__init__()
        if edge_dim is None:
            edge_dim = input_dim
        self.neighbour_edges_mlp = Sequential(Linear(edge_dim + input_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))
        self.center_mlp = Sequential(Linear(input_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))
        # Aggregated messages have ``output_dim`` features, so that is the
        # output MLP's input width (previously ``input_dim``, which crashed
        # whenever input_dim != output_dim).
        self.output_mlp = Sequential(Linear(output_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))

    def forward(self, adjacency_list, node_features, edge_features):
        """Aggregate edge-conditioned neighbour messages for every listed node.

        Args:
            adjacency_list: mapping ``node_idx -> iterable of neighbour indices``.
            node_features: tensor indexed by node on its first dimension, last
                dimension ``input_dim`` (presumably ``(N, input_dim)``).
            edge_features: tensor indexable as ``edge_features[i, index_tensor]``,
                last dimension ``edge_dim`` — presumably dense ``(N, N, edge_dim)``;
                TODO confirm against the featurizer output.

        Returns:
            Tensor of shape ``(len(adjacency_list), output_dim)``; row ``k``
            corresponds to the ``k``-th key in ``adjacency_list``'s iteration
            order. A node with no neighbours receives only its centre term.
        """
        center_messages = self.center_mlp(node_features)

        aggregated_neighbourhoods = []
        for node_idx, neighbours in adjacency_list.items():
            index = torch.as_tensor(list(neighbours), dtype=torch.long, device=node_features.device)
            edges = edge_features[node_idx, index]     # edge features e_ij for each neighbour j
            neighbour_nodes = node_features[index]     # node features x_j for each neighbour j
            # One MLP call per node, batched over its neighbours, instead of one call per edge.
            messages = self.neighbour_edges_mlp(torch.cat([edges, neighbour_nodes], dim=-1))
            # With no neighbours, messages has zero rows and sums to zeros.
            aggregated_neighbourhoods.append(messages.sum(dim=0) + center_messages[node_idx])

        if aggregated_neighbourhoods:
            aggregated = torch.stack(aggregated_neighbourhoods, dim=0)
        else:
            # torch.stack cannot stack an empty list; use an empty batch instead.
            aggregated = center_messages.new_zeros((0, center_messages.shape[-1]))
        return self.output_mlp(aggregated)

In [None]:
# Usage example for AdjacencyListEdgeFeaturesGraphLayer, kept inside a string
# literal so this cell is a no-op. It relies on names that no executed code in
# the visible cells defines (d_model, ffn_dim, dataset, FeatureCombiner,
# node_embedding_dim, edge_embedding_dim).
# NOTE(review): the layer's neighbour MLP takes edge features concatenated with
# node features; the featurizers' output widths must match the widths the
# layer was constructed for — confirm against FeatureCombiner's output size.
'''
adjacency_list_edge_features_graph_layer = AdjacencyListEdgeFeaturesGraphLayer(d_model, d_model, ffn_dim)

# node features 
node_featurizer = FeatureCombiner(dataset.categorical_node_variables, node_embedding_dim)
# get a graph 
graph = dataset[1]
adjacency_list = graph['adjacency_list']
categorical_node_features = graph['categorical_node_features']
continuous_node_features = graph['continuous_node_features']
# combine cts and cat node features 
node_features = node_featurizer(continuous_node_features, categorical_node_features)

# get the edge features 
edge_featurizer = FeatureCombiner(dataset.categorical_edge_variables, edge_embedding_dim)
categorical_edge_features = graph['categorical_edge_features']
continuous_edge_features = graph['continuous_edge_features']
edge_features = edge_featurizer(continuous_edge_features, categorical_edge_features)

# graph layer (using adjacency list)
adjacency_list_edge_features_graph_layer(adjacency_list, node_features, edge_features)
'''

In [7]:
class AdjacencyMatrixEdgeFeaturesGraphLayer(Module):
    """Dense message-passing layer that combines edge and neighbour node features.

    For node ``i`` the message from node ``j`` is
    ``neighbour_edges_mlp(edge_features[i, j] + node_features[j])``. It is
    weighted by ``adjacency_matrix[i, j]``, summed over ``j``, and added to
    ``center_mlp(node_features[i])`` before ``output_mlp``.

    Edge and node features are combined by element-wise addition (not
    concatenation), so both must have ``input_dim`` features.
    NOTE(review): AdjacencyListEdgeFeaturesGraphLayer concatenates edge and
    neighbour features instead, so the two layers are not interchangeable.

    Args:
        input_dim: size of the last dimension of node and edge features.
        output_dim: size of the last dimension of the returned features.
        ffn_dim: hidden width of each two-layer MLP.
    """

    def __init__(self, input_dim, output_dim, ffn_dim):
        super().__init__()
        self.neighbour_edges_mlp = Sequential(Linear(input_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))
        self.center_mlp = Sequential(Linear(input_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))
        # Aggregated messages have ``output_dim`` features, so that is the
        # output MLP's input width (previously ``input_dim``, which crashed
        # whenever input_dim != output_dim).
        self.output_mlp = Sequential(Linear(output_dim, ffn_dim), ReLU(), Linear(ffn_dim, output_dim))

    def forward(self, adjacency_matrix, node_features, edge_features):
        """Update node features from dense edge features and an adjacency mask.

        Shapes (leading batch dimensions ``...`` are optional):
            adjacency_matrix: ``(..., N, N)``, multiplied into the messages.
            node_features:    ``(..., N, d_in)``.
            edge_features:    ``(..., N, N, d_in)``.

        Returns:
            Tensor of shape ``(..., N, output_dim)``.
        """
        center_messages = self.center_mlp(node_features)                         # (..., N, d_out)
        # unsqueeze(-3) broadcasts node j's features along every row i, so
        # entry [i, j] is edge[i, j] + node[j] (element-wise add; d_e == d_n == d_in).
        edge_and_node_features = edge_features + node_features.unsqueeze(dim=-3)  # (..., N, N, d_in)
        messages = self.neighbour_edges_mlp(edge_and_node_features)               # (..., N, N, d_out)
        masked_messages = messages * adjacency_matrix.unsqueeze(dim=-1)           # (..., N, N, d_out) * (..., N, N, 1)
        reduced_neighbourhoods = masked_messages.sum(dim=-2)                      # sum over neighbours j -> (..., N, d_out)
        aggregated_neighbourhoods = reduced_neighbourhoods + center_messages      # (..., N, d_out)
        return self.output_mlp(aggregated_neighbourhoods)                         # (..., N, d_out)