In [11]:
from collections import defaultdict
from collections.abc import Set

import rdkit
from rdkit.Chem import MolFromSmiles

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import Sequential,Linear,ReLU, Module, Embedding, ModuleList, LayerNorm, Dropout
from torch.nn import Parameter
from torch.nn.functional import softmax

# Tensor dtypes shared by the dataset, the collate function and the models below.
float_type = torch.float32  # continuous features and adjacency matrices
categorical_type = torch.long  # integer indices of categorical values (fed to Embedding lookups)
mask_type = torch.float32  # node/edge padding masks
labels_type = torch.float32  # graph labels, as collated into a batch tensor

In [2]:
# Class describing a continuous (real-valued) variable
class ContinuousVariable:
    '''Descriptor for a continuous node or edge feature, identified by its name.

    Instances are hashable and compare equal when their names match, so they
    can be used as dictionary keys for per-node / per-edge feature values.
    '''

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f'<ContinuousVariable: {self.name}>'

    def __eq__(self, other):
        # The hash is the hash of the name, so a dict lookup against e.g. the
        # plain string name lands in the same bucket and triggers __eq__.
        # Defer to Python's default handling for foreign types instead of
        # raising AttributeError on `other.name`.
        if not isinstance(other, ContinuousVariable):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)
    
# Class describing a categorical variable and its value <-> index mapping
class CategoricalVariable:
    '''Descriptor for a categorical node or edge feature with a fixed set of values.

    Every allowed value is mapped to an integer index, namely its position in
    `self.values`. When `add_null_value` is true, `None` is prepended to the
    values so that a missing value has its own index (0).
    '''

    def __init__(self, name, values, add_null_value=True):
        '''
        name: identifier of the variable.
        values: iterable of the allowed (hashable) values.
        add_null_value: if true, prepend `None` as an explicit "missing" value.
        '''
        self.name = name
        self.has_null_value = add_null_value
        if self.has_null_value:
            self.null_value = None
            values = (None,) + tuple(values)
        # Materialise the values once and build the mapping from the stored
        # tuple; enumerating the raw argument again would yield nothing when
        # `values` is a one-shot iterable such as a generator.
        self.values = tuple(values)
        self.value_to_idx_mapping = {v: i for i, v in enumerate(self.values)}
        self.inv_value_to_idx_mapping = {i: v for v, i in self.value_to_idx_mapping.items()}

        if self.has_null_value:
            self.null_value_idx = self.value_to_idx_mapping[self.null_value]

    def get_null_idx(self):
        '''Return the index of the null value; raise RuntimeError if the variable has none.'''
        if self.has_null_value:
            return self.null_value_idx
        else:
            raise RuntimeError(f"Categorical variable {self.name} has no null value")

    def value_to_idx(self, value):
        '''Return the integer index of `value` (KeyError if it is not an allowed value).'''
        return self.value_to_idx_mapping[value]

    def idx_to_value(self, idx):
        '''Return the value stored at integer index `idx` (KeyError if unknown).'''
        return self.inv_value_to_idx_mapping[idx]

    def __len__(self):
        # Number of distinct indices, including the null value when present.
        return len(self.values)

    def __repr__(self):
        return f'<CategoricalVariable: {self.name}>'

    def __eq__(self, other):
        # Defer to Python's default handling for foreign types instead of
        # raising AttributeError on `other.name` / `other.values`.
        if not isinstance(other, CategoricalVariable):
            return NotImplemented
        return self.name == other.name and self.values == other.values

    def __hash__(self):
        return hash((self.name, self.values))


In [3]:
# graph dataset class: this returns graph attributes for set of graphs  
class GraphDataset(Dataset):
    '''Map-style dataset that turns graphs into dense feature tensors.

    Each graph is a `(nodes, edges)` pair of dictionaries:

    * `nodes` maps a node index to a `{variable: value}` dictionary. Node
      indices are used directly as tensor row indices, so they are expected to
      be the integers `0 .. len(nodes) - 1`.
    * `edges` maps an edge key to a `{variable: value}` dictionary. A key must
      unpack into two node indices `(u, v)`. Keys that are sets (e.g. a
      `frozenset`) are treated as undirected and written in both directions;
      any other key (e.g. a tuple) is treated as directed from `u` to `v`.
    '''

    def __init__(self, *, graphs, labels, node_variables, edge_variables, metadata=None):
        '''
        Create a new graph dataset.

        graphs: sequence of `(nodes, edges)` pairs (see the class docstring).
        labels: sequence with one label per graph.
        node_variables: ContinuousVariable / CategoricalVariable objects
            describing the node features; each is used as a key into the
            per-node feature dictionaries.
        edge_variables: same, for the per-edge feature dictionaries.
        metadata: optional sequence with one arbitrary entry per graph; it is
            passed through unchanged under the 'metadata' key.
        '''
        self.graphs = graphs
        self.labels = labels
        assert len(self.graphs) == len(self.labels), "The graphs and labels lists must be the same length"
        self.metadata = metadata
        if self.metadata is not None:
            assert len(self.metadata) == len(self.graphs), "The metadata list needs to be as long as the graphs"
        self.node_variables = node_variables
        self.edge_variables = edge_variables
        # Split the variables by kind once; the order within each list fixes
        # the feature (column) order of the tensors built below.
        self.categorical_node_variables = [var for var in self.node_variables if isinstance(var, CategoricalVariable)]
        self.continuous_node_variables = [var for var in self.node_variables if isinstance(var, ContinuousVariable)]
        self.categorical_edge_variables = [var for var in self.edge_variables if isinstance(var, CategoricalVariable)]
        self.continuous_edge_variables = [var for var in self.edge_variables if isinstance(var, ContinuousVariable)]

    def __len__(self):
        return len(self.graphs)

    def make_continuous_node_features(self, nodes):
        '''Return a (n_nodes, n_continuous) float tensor, or None without continuous node variables.'''
        if len(self.continuous_node_variables) == 0:
            return None
        n_nodes = len(nodes)
        n_features = len(self.continuous_node_variables)
        continuous_node_features = torch.zeros((n_nodes, n_features), dtype=float_type)
        for node_idx, features in nodes.items():
            node_features = torch.tensor([features[continuous_feature] for continuous_feature in self.continuous_node_variables], dtype=float_type)
            continuous_node_features[node_idx] = node_features
        return continuous_node_features

    def make_categorical_node_features(self, nodes):
        '''Return a (n_nodes, n_categorical) tensor of value indices, or None without categorical node variables.'''
        if len(self.categorical_node_variables) == 0:
            return None
        n_nodes = len(nodes)
        n_features = len(self.categorical_node_variables)
        categorical_node_features = torch.zeros((n_nodes, n_features), dtype=categorical_type)

        for node_idx, features in nodes.items():
            for i, categorical_variable in enumerate(self.categorical_node_variables):
                value = features[categorical_variable]
                value_index = categorical_variable.value_to_idx(value)
                categorical_node_features[node_idx, i] = value_index
        return categorical_node_features

    def make_continuous_edge_features(self, n_nodes, edges):
        '''Return a dense (n_nodes, n_nodes, n_continuous) float tensor, or None without continuous edge variables.

        Node pairs without an edge keep all-zero feature vectors.
        '''
        if len(self.continuous_edge_variables) == 0:
            return None
        n_features = len(self.continuous_edge_variables)
        continuous_edge_features = torch.zeros((n_nodes, n_nodes, n_features), dtype=float_type)
        for edge, features in edges.items():
            edge_features = torch.tensor([features[continuous_feature] for continuous_feature in self.continuous_edge_variables], dtype=float_type)
            u, v = edge
            continuous_edge_features[u, v] = edge_features
            if isinstance(edge, Set):
                # Unordered edge key: mirror the features to the other direction
                continuous_edge_features[v, u] = edge_features
        return continuous_edge_features

    def make_categorical_edge_features(self, n_nodes, edges):
        '''Return a dense (n_nodes, n_nodes, n_categorical) tensor of value indices, or None without categorical edge variables.

        Node pairs without an edge keep index 0 for every variable.
        '''
        if len(self.categorical_edge_variables) == 0:
            return None
        n_features = len(self.categorical_edge_variables)
        categorical_edge_features = torch.zeros((n_nodes, n_nodes, n_features), dtype=categorical_type)

        for edge, features in edges.items():
            u, v = edge
            for i, categorical_variable in enumerate(self.categorical_edge_variables):
                value = features[categorical_variable]
                value_index = categorical_variable.value_to_idx(value)
                categorical_edge_features[u, v, i] = value_index
                if isinstance(edge, Set):
                    # Unordered edge key: mirror the index to the other direction
                    categorical_edge_features[v, u, i] = value_index
        return categorical_edge_features

    def __getitem__(self, index):
        '''Build the record for graph `index`.

        Returns a dict with the keys 'nodes' (sorted node indices),
        'adjacency_matrix', 'adjacency_list', the four feature tensors
        ('categorical_node_features', 'continuous_node_features',
        'categorical_edge_features', 'continuous_edge_features'; each may be
        None), 'label', and 'metadata' when the dataset has metadata.
        '''
        nodes, edges = self.graphs[index]
        n_nodes = len(nodes)
        continuous_node_features = self.make_continuous_node_features(nodes)
        categorical_node_features = self.make_categorical_node_features(nodes)
        continuous_edge_features = self.make_continuous_edge_features(n_nodes, edges)
        categorical_edge_features = self.make_categorical_edge_features(n_nodes, edges)

        label = self.labels[index]
        nodes_idx = sorted(nodes.keys())

        # Build both adjacency representations in a single pass. The edge keys
        # are deliberately not sorted: the sorted list was never used, and
        # sorting fails when tuple and set keys are mixed in one graph.
        adjacency_matrix = torch.zeros((n_nodes, n_nodes), dtype=float_type)
        adjacency_list = defaultdict(list)
        for edge in edges:
            u, v = edge
            adjacency_matrix[u, v] = 1
            adjacency_list[u].append(v)
            if isinstance(edge, Set):
                # This edge is unordered, assume this is an undirected graph
                adjacency_matrix[v, u] = 1
                adjacency_list[v].append(u)

        data_record = {'nodes': nodes_idx,
                    'adjacency_matrix': adjacency_matrix,
                    'adjacency_list': adjacency_list,
                    'categorical_node_features': categorical_node_features,
                    'continuous_node_features': continuous_node_features,
                    'categorical_edge_features': categorical_edge_features,
                    'continuous_edge_features': continuous_edge_features,
                    'label': label}

        if self.metadata is not None:
            data_record['metadata'] = self.metadata[index]

        return data_record

    def get_node_variables(self):
        '''Return the node variables grouped by kind.'''
        return {'continuous': self.continuous_node_variables,
                'categorical': self.categorical_node_variables}

    def get_edge_variables(self):
        '''Return the edge variables grouped by kind.'''
        return {'continuous': self.continuous_edge_variables,
                'categorical': self.categorical_edge_variables}

In [4]:
# This function collates individual graphs into padded mini-batches.
# The GraphDataset class yields one graph at a time; it cannot create mini-batches itself.

from collections.abc import Set

def _stack_padded_features(batch, key, max_nodes, n_node_dims, dtype):
    '''Zero-pad the per-graph tensors stored under `key` and stack them along a new batch axis.

    `n_node_dims` is the number of leading node axes of each tensor (1 for
    node features, 2 for edge features); the last axis holds the features.
    Returns None when no graph in the batch has a tensor under `key`.
    '''
    stacked = None
    for i, graph in enumerate(batch):
        features = graph[key]
        if features is None:
            continue
        if stacked is None:
            # Allocate lazily: the feature width is only known once a tensor is seen.
            shape = (len(batch),) + (max_nodes,) * n_node_dims + (features.shape[-1],)
            stacked = torch.zeros(shape, dtype=dtype)
        g_nodes = features.shape[0]
        stacked[(i,) + (slice(0, g_nodes),) * n_node_dims] = features
    return stacked


def collate_graph_batch(batch):
    '''Collate a batch of graph dictionaries produced by a GraphDataset.

    Every graph is zero-padded to the size of the largest graph in the batch.
    Returns a dict with the keys 'adjacency_matrices', the four stacked
    feature tensors ('categorical_node_features', 'continuous_node_features',
    'categorical_edge_features', 'continuous_edge_features'; each is None when
    no graph provides it), 'nodes_mask' and 'edge_mask' (1 for real nodes /
    node pairs, 0 for padding), 'labels', and 'metadata' (a list) when the
    graphs carry metadata.

    Raises ValueError if `batch` is empty.
    '''
    batch_size = len(batch)
    if batch_size == 0:
        raise ValueError('Cannot collate an empty batch of graphs')
    max_nodes = max(len(graph['nodes']) for graph in batch)

    adjacency_matrices = torch.zeros((batch_size, max_nodes, max_nodes), dtype=float_type)
    labels = torch.tensor([graph['label'] for graph in batch], dtype=labels_type)
    nodes_mask = torch.zeros((batch_size, max_nodes), dtype=mask_type)
    edge_mask = torch.zeros((batch_size, max_nodes, max_nodes), dtype=mask_type)

    for i, graph in enumerate(batch):
        adjacency_matrix = graph['adjacency_matrix']
        g_nodes = adjacency_matrix.shape[0]
        adjacency_matrices[i, :g_nodes, :g_nodes] = adjacency_matrix
        edge_mask[i, :g_nodes, :g_nodes] = 1
        nodes_mask[i, :g_nodes] = 1

    # Continuous stacks use float_type explicitly so they follow the notebook's
    # dtype configuration rather than torch's global default dtype.
    batch_record = {'adjacency_matrices': adjacency_matrices,
            'categorical_node_features': _stack_padded_features(
                batch, 'categorical_node_features', max_nodes, 1, categorical_type),
            'continuous_node_features': _stack_padded_features(
                batch, 'continuous_node_features', max_nodes, 1, float_type),
            'categorical_edge_features': _stack_padded_features(
                batch, 'categorical_edge_features', max_nodes, 2, categorical_type),
            'continuous_edge_features': _stack_padded_features(
                batch, 'continuous_edge_features', max_nodes, 2, float_type),
            'nodes_mask': nodes_mask,
            'edge_mask': edge_mask,
            'labels': labels}

    if any('metadata' in graph for graph in batch):
        batch_record['metadata'] = [g['metadata'] for g in batch]

    return batch_record

In [5]:
# Embedder: embeds categorical variables (in both node and edge features)
from torch.nn import Module
from torch.nn import Embedding
from torch.nn import Module, ModuleList

class Embedder(Module):
    '''Embeds every categorical variable with its own table and sums the embeddings.

    The last axis of the input indexes the variables, in the same order as
    `categorical_variables`; the output replaces that axis with one of size
    `embedding_dim`.
    '''

    def __init__(self, categorical_variables, embedding_dim):
        super().__init__()
        self.categorical_variables = categorical_variables

        def make_table(variable):
            # Missing values are common: a variable with a null value gets that
            # index registered as the padding index of its embedding table.
            if not variable.has_null_value:
                return Embedding(len(variable), embedding_dim)
            return Embedding(len(variable), embedding_dim, padding_idx=variable.get_null_idx())

        self.embeddings = ModuleList([make_table(variable) for variable in categorical_variables])

    def forward(self, categorical_features):
        # Look up each variable's column of indices in its own table.
        per_variable = [table(categorical_features[..., position])
                        for position, table in enumerate(self.embeddings)]
        # The per-variable embeddings are summed; concatenation would be an alternative.
        return torch.stack(per_variable, dim=0).sum(dim=0)

In [6]:
# Feature combiner: combines categorical and continuous features

class FeatureCombiner(Module):
    '''Joins embedded categorical features and continuous features along the last axis.'''

    def __init__(self, categorical_variables, embedding_dim):
        super().__init__()
        self.categorical_variables = categorical_variables
        self.embedder = Embedder(categorical_variables, embedding_dim)

    def forward(self, continuous_features, categorical_features, ):
        # Either kind of feature may be absent (None); embedded categorical
        # features come first in the concatenation when both are present.
        embedded = None if categorical_features is None else self.embedder(categorical_features)
        available = [part for part in (embedded, continuous_features) if part is not None]
        if not available:
            raise RuntimeError('No features to combine')
        return torch.cat(available, dim=-1)

In [7]:
# graph prediction head: this class is used to do graph level predictions at the end 

class GraphPredictionHeadConfig:
    '''Keyword-only settings container for the graph-level prediction head.

    d_model: model width (stored only; not read by the head in this notebook).
    ffn_dim: hidden width of the head's feed-forward predictor.
    pooling_type: how node features are pooled per graph, 'sum' or 'mean'.
    '''

    def __init__(self, *, d_model, ffn_dim, pooling_type='sum'):
        self.d_model, self.ffn_dim, self.pooling_type = d_model, ffn_dim, pooling_type

class GraphPredictionHead(Module):
    '''Pools the node features of each graph and predicts from the pooled vector.

    input_dim: size of the node feature vectors.
    output_dim: size of the prediction per graph.
    config: object with `ffn_dim` (hidden width of the predictor) and
        `pooling_type` ('sum' or 'mean').
    '''

    def __init__(self, input_dim, output_dim, config):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.config = config
        self.predictor = Sequential(Linear(self.input_dim, self.config.ffn_dim),
                                    ReLU(),
                                    Linear(self.config.ffn_dim, self.output_dim))

    def forward(self, node_features, node_mask):
        '''
        node_features: [B, N, d_n] node features (N includes padding).
        node_mask: [B, N], 1 for real nodes and 0 for padded ones.

        Returns the [B, d_o] predictions. Raises ValueError for an unsupported
        pooling type.
        '''
        pooling_type = self.config.pooling_type
        if pooling_type not in ('sum', 'mean'):
            raise ValueError(f'Unsupported pooling type {self.config.pooling_type}')

        # Padded rows are not guaranteed to be zero (e.g. after a LayerNorm
        # with a bias), so mask them out before pooling.
        masked_features = node_features * node_mask.unsqueeze(dim=-1)  # [B,N,d_n] * [B,N,1]
        pooled_nodes = masked_features.sum(dim=-2)                     # [B,N,d_n] -> [B,d_n]

        if pooling_type == 'mean':
            # keepdim gives [B,1] so the division broadcasts over the feature
            # axis; clamping avoids 0/0 for a graph without any real node.
            node_counts = node_mask.sum(dim=-1, keepdim=True).clamp(min=1)
            pooled_nodes = pooled_nodes / node_counts                  # [B,d_n] / [B,1]

        prediction = self.predictor(pooled_nodes)  # [B,d_n] -> [B,d_o]
        return prediction

In [8]:
# transformer layer 
import math 

class BasicTransformerConfig:
    '''Keyword-only settings container for the transformer layers and encoder.

    d_model: width of the node feature vectors.
    n_layers: number of stacked layers.
    ffn_dim: hidden width of the feed-forward blocks.
    head_dim: width of the attention projections.
    layer_normalization, dropout_rate, residual_connections: regularisation
        and architecture switches stored for the layers to read.
    '''

    def __init__(
        self,
        *,
        d_model: int,
        n_layers: int,
        ffn_dim: int,
        head_dim: int,
        layer_normalization: bool = True,
        dropout_rate: float = 0.1,
        residual_connections: bool = True,
    ):
        # Sizes
        self.d_model = d_model
        self.n_layers = n_layers
        self.ffn_dim = ffn_dim
        self.head_dim = head_dim
        # Switches
        self.layer_normalization = layer_normalization
        self.dropout_rate = dropout_rate
        self.residual_connections = residual_connections
    
class BasicTransformerLayer(Module):
    '''Single-head transformer layer over the (padded) nodes of each graph.

    Every real node attends to every real node of the same graph; the
    adjacency matrix and the edge features are accepted but ignored here so
    that subclasses can override `attention_function` to use them.

    config: object with `d_model`, `ffn_dim`, `head_dim`, `dropout_rate`,
        `residual_connections` and (optionally) `layer_normalization`.
    '''

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.input_dim = config.d_model
        self.output_dim = config.d_model
        self.ffn_dim = config.ffn_dim
        self.head_dim = config.head_dim
        # Transformers typically don't use MLPs here, they use linear layers
        self.neighbour_transform = Linear(self.input_dim, self.head_dim, bias=False)
        self.center_transform = Linear(self.input_dim, self.head_dim, bias=False)
        # The norm modules are always created (stable state_dict) but only
        # applied when the config asks for layer normalization.
        self.use_layer_normalization = getattr(config, 'layer_normalization', True)
        self.attention_norm = LayerNorm(self.input_dim)
        self.output_transform = Sequential(Linear(self.input_dim, self.ffn_dim),
                                            ReLU(),
                                            Linear(self.ffn_dim, self.output_dim))
        self.output_norm = LayerNorm(self.output_dim)
        self.dropout = Dropout(self.config.dropout_rate)
        # Scaled dot-product attention divides by the square root of the
        # dimension the dot product runs over, i.e. the projection width.
        self.scaling_factor = math.sqrt(self.head_dim)

    def attention_function(self, adjacency_matrix, center_node_features,
                            neighbour_node_features, edge_features, node_mask, edge_mask):
        '''Return the (B,N,N) attention matrix; rows and columns of padded nodes are zero.'''
        # attn_logit: (B,N,d_h) @ (B,d_h,N) -> (B,N,N)
        attention_logits = torch.matmul(center_node_features, neighbour_node_features.transpose(-1, -2)) / self.scaling_factor
        node_mask_2d = node_mask.unsqueeze(dim=-2) * node_mask.unsqueeze(dim=-1)  # (B,1,N) * (B,N,1) -> (B,N,N)
        fill_mask = (1 - node_mask_2d).to(torch.bool)  # True where either node is padding
        attention_logits = attention_logits.masked_fill(fill_mask, float('-inf'))
        attention_matrix = softmax(attention_logits, dim=-1)  # (B,N,N)
        # Rows of padded nodes are all -inf and come out of the softmax as NaN;
        # overwrite every masked entry with zero.
        attention_matrix = attention_matrix.masked_fill(fill_mask, 0.)
        return attention_matrix

    def forward(self, adjacency_matrix, node_features, edge_features, node_mask, edge_mask):
        '''
        node_features: (B,N,d_n) node features; node_mask: (B,N) with 1 for
        real nodes. Returns the updated (B,N,d_n) node features, with the rows
        of padded nodes set to zero.
        '''
        node_mask_column = node_mask.unsqueeze(dim=-1)  # (B,N,1)
        center_node_features = self.center_transform(node_features)        # (B,N,d_h)
        neighbour_node_features = self.neighbour_transform(node_features)  # (B,N,d_h)

        attention_matrix = self.attention_function(adjacency_matrix, center_node_features, neighbour_node_features,
                                                    edge_features, node_mask, edge_mask)  # (B,N,N)
        # The transformer doesn't transform the node features at this stage
        aggregated_neighbourhoods = torch.matmul(attention_matrix, node_features)  # (B,N,N) @ (B,N,d_n) -> (B,N,d_n)
        masked_features = aggregated_neighbourhoods * node_mask_column

        masked_features = self.dropout(masked_features)
        if self.use_layer_normalization:
            masked_features = self.attention_norm(masked_features)

        if self.config.residual_connections:
            masked_features = masked_features + node_features  # (B,N,d_n) + (B,N,d_n) -> (B,N,d_n)

        updated_node_features = self.output_transform(masked_features)  # (B,N,d_o)

        # Mask again
        updated_node_features = updated_node_features * node_mask_column

        # Followed by a dropout and normalization
        updated_node_features = self.dropout(updated_node_features)
        if self.use_layer_normalization:
            updated_node_features = self.output_norm(updated_node_features)

        # And the residual connection from the input to the output MLP
        if self.config.residual_connections:
            updated_node_features = updated_node_features + masked_features

        # LayerNorm maps an all-zero row to its bias, so padded rows can be
        # non-zero at this point; zero them so padding never leaks downstream.
        return updated_node_features * node_mask_column

In [9]:

class BasicTransformerEncoder(torch.nn.Module):
    '''Featurizes a collated graph batch and runs it through a stack of transformer layers.

    config: settings object with `d_model` and `n_layers`; it is also passed
        to every layer constructor.
    continuous_* / categorical_* variables: the variable lists of the dataset;
        None is treated as "no variables of that kind".
    layer_type: the layer class to instantiate `config.n_layers` times.
    '''

    def __init__(self, *, config: 'BasicTransformerConfig',
                continuous_node_variables=None,
                categorical_node_variables=None,
                continuous_edge_variables=None,
                categorical_edge_variables=None,
                layer_type: 'type[BasicTransformerLayer]'):
        super().__init__()
        self.config = config
        self.cofig = config  # alias kept for code that relied on the original misspelled attribute
        self.layer_type = layer_type
        # The defaults are None; normalise to empty lists so len() and iteration work.
        self.continuous_node_variables = continuous_node_variables if continuous_node_variables is not None else []
        self.categorical_node_variables = categorical_node_variables if categorical_node_variables is not None else []
        self.continuous_edge_variables = continuous_edge_variables if continuous_edge_variables is not None else []
        self.categorical_edge_variables = categorical_edge_variables if categorical_edge_variables is not None else []

        # Continuous features are concatenated to the embeddings, so the
        # embedding width is chosen to make the concatenation d_model wide.
        self.categorical_node_embeddings_dim = config.d_model - len(self.continuous_node_variables)
        self.categorical_edge_embeddings_dim = config.d_model - len(self.continuous_edge_variables)

        self.node_featurizer = FeatureCombiner(self.categorical_node_variables,
                                            self.categorical_node_embeddings_dim)
        self.edge_featurizer = FeatureCombiner(self.categorical_edge_variables,
                                            self.categorical_edge_embeddings_dim)

        self.graph_layers = ModuleList([layer_type(config) for _ in range(config.n_layers)])

    def forward(self, batch):
        '''Encode a batch dictionary as produced by `collate_graph_batch`; returns the (B,N,d_model) node states.'''
        # The collate function stores the node mask under 'nodes_mask'; the
        # singular spelling is accepted as a fallback.
        node_mask = batch['nodes_mask'] if 'nodes_mask' in batch else batch['node_mask']

        continuous_node_features = batch['continuous_node_features']
        categorical_node_features = batch['categorical_node_features']
        node_features = self.node_featurizer(continuous_node_features, categorical_node_features)
        masked_node_features = node_features * node_mask.unsqueeze(dim=-1)  # (B,N,d) * (B,N,1)

        edge_mask = batch['edge_mask']
        continuous_edge_features = batch['continuous_edge_features']
        categorical_edge_features = batch['categorical_edge_features']
        if continuous_edge_features is None and categorical_edge_features is None:
            # Graphs without edge features are valid input for layers that
            # ignore them; pass None instead of failing in the featurizer.
            masked_edge_features = None
        else:
            edge_features = self.edge_featurizer(continuous_edge_features, categorical_edge_features)
            masked_edge_features = edge_features * edge_mask.unsqueeze(dim=-1)  # (B,N,N,d) * (B,N,N,1)

        adjacency_matrix = batch['adjacency_matrices']
        memory_state = masked_node_features
        for graph_layer in self.graph_layers:
            memory_state = graph_layer(adjacency_matrix, memory_state, masked_edge_features, node_mask, edge_mask)

        return memory_state

In [10]:
class GraphPredictionNeuralNetwork(Module):
    '''Chains a graph encoder with a graph-level prediction head.'''

    def __init__(self, encoder, prediction_head):
        super().__init__()
        self.encoder, self.prediction_head = encoder, prediction_head

    def forward(self, batch):
        # Encode first, then hand the result to the head together with the
        # batch's 'nodes_mask' entry.
        node_states = self.encoder(batch)
        return self.prediction_head(node_states, batch['nodes_mask'])

In [12]:
# Adding graph structure to the transformer

class AdjacencyTransformerLayer(BasicTransformerLayer):
    '''Transformer layer whose attention logits are biased by the adjacency matrix.

    A single learnable scalar (`adjacency_weight`) scales the adjacency
    matrix before it is added to the dot-product logits.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.adjacency_weight = Parameter(torch.tensor(1.))

    # This must be a method named `attention_function`: that is the name the
    # inherited forward() calls, so any other name (or a module-level
    # function) would leave the adjacency bias unused.
    def attention_function(self, adjacency_matrix, center_node_features,
                           neighbour_node_features, edge_features, node_mask, edge_mask):
        '''Return the (B,N,N) attention matrix; rows and columns of padded nodes are zero.'''
        # Transformer dot product (attention weights)
        # (B,N,d_h) @ (B,d_h,N) -> (B,N,N)
        dot_product_logits = torch.matmul(center_node_features,
                                          neighbour_node_features.transpose(-1, -2)) / self.scaling_factor
        # Graph structure contribution
        adjacency_logits = self.adjacency_weight * adjacency_matrix
        # Transformer + graph attention logits
        attention_logits = dot_product_logits + adjacency_logits

        nodemask_2d = node_mask.unsqueeze(dim=-2) * node_mask.unsqueeze(dim=-1)  # (B,1,N) * (B,N,1) -> (B,N,N)
        fill_mask = (1 - nodemask_2d).to(torch.bool)  # True where either node is padding
        attention_logits = attention_logits.masked_fill(fill_mask, float('-inf'))
        attention_matrix = softmax(attention_logits, dim=-1)

        # Fully masked rows come out of the softmax as NaN; zero every masked entry.
        attention_matrix = attention_matrix.masked_fill(fill_mask, 0.)
        return attention_matrix

In [22]:
# edge features in attention: here edge features (in addition to node features) are used for the attention weights 

class EdgeAttributesAttentionTransformerLayer(BasicTransformerLayer):
    '''Transformer layer whose attention scores depend on node AND edge features.

    For every node pair the transformed center features, the transformed
    neighbour features and the edge features are summed, and a small MLP maps
    the sum to a scalar attention logit.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # We're going to sum our center and neighbour vectors to the edge feature
        # vector, so have to be mindful of the dimensionality
        self.center_transform = Linear(self.input_dim, self.input_dim)
        self.neighbour_transform = Linear(self.input_dim, self.input_dim)
        self.attention_score_function = Sequential(Linear(self.input_dim, self.config.ffn_dim), ReLU(),
                                                   Linear(self.config.ffn_dim, 1))

    # This must be a method named `attention_function`: that is the name
    # forward() calls, so any other name (or a module-level function) would
    # leave the edge-aware scoring unused.
    def attention_function(self, adjacency_matrix, center_node_features, neighbour_node_features,
                           edge_features, node_mask, edge_mask):
        '''Return the (B,N,N) attention matrix; rows and columns of padded nodes are zero.'''
        # (B,N,N,d) + (B,N,1,d) -> (B,N,N,d): add the center node to every pair in its row
        attention_score_input = edge_features + center_node_features.unsqueeze(dim=-2)
        # (B,N,N,d) + (B,1,N,d) -> (B,N,N,d): add the neighbour node to every pair in its column
        attention_score_input = attention_score_input + neighbour_node_features.unsqueeze(dim=-3)
        # The score MLP maps (B,N,N,d) -> (B,N,N,1); drop the singleton axis -> (B,N,N)
        attention_logits = self.attention_score_function(attention_score_input).squeeze(dim=-1)
        # (B,N) -> (B,1,N) * (B,N,1) -> (B,N,N)
        nodemask_2d = node_mask.unsqueeze(dim=-2) * node_mask.unsqueeze(dim=-1)
        fill_mask = (1 - nodemask_2d).to(torch.bool)  # True where either node is padding
        attention_logits = attention_logits.masked_fill(fill_mask, float('-inf'))  # (B,N,N)
        attention_matrix = softmax(attention_logits, dim=-1)                       # (B,N,N)
        # Fully masked rows come out of the softmax as NaN; zero every masked entry.
        attention_matrix = attention_matrix.masked_fill(fill_mask, 0.)
        return attention_matrix

In [24]:

class EdgeAttributesTransformerLayer(EdgeAttributesAttentionTransformerLayer):
    '''Transformer layer that aggregates edge-aware pair features instead of plain node features.

    For every node pair the edge features, the transformed center features and
    the transformed neighbour features are summed; each node then takes the
    attention-weighted sum of its pair features.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, adjacency_matrix, node_features, edge_features, node_mask, edge_mask):
        '''
        node_features: (B,N,d_n); edge_features: (B,N,N,d_e) with d_e == d_n;
        node_mask: (B,N) with 1 for real nodes. Returns the updated (B,N,d_n)
        node features, with the rows of padded nodes set to zero.
        '''
        node_mask_column = node_mask.unsqueeze(dim=-1)  # (B,N,1)
        center_node_features = self.center_transform(node_features)
        # Neighbours get their own projection (not the center one).
        neighbour_node_features = self.neighbour_transform(node_features)
        attention_matrix = self.attention_function(adjacency_matrix, center_node_features, neighbour_node_features,
                                                    edge_features, node_mask, edge_mask)

        # Here we use combined features (edge + center_node + neighbour_node) in the transformer
        context_dependent_features = edge_features + center_node_features.unsqueeze(dim=-2)  # (B,N,N,d_e) + (B,N,1,d_n) -> (B,N,N,d_n), assumes d_n == d_e
        context_dependent_features = context_dependent_features + neighbour_node_features.unsqueeze(dim=-3)  # (B,N,N,d_n) + (B,1,N,d_n) -> (B,N,N,d_n)
        attended_features = context_dependent_features * attention_matrix.unsqueeze(dim=-1)  # (B,N,N,d_n) * (B,N,N,1) -> (B,N,N,d_n)
        aggregated_neighbourhoods = attended_features.sum(dim=-2)  # (B,N,N,d_n) -> (B,N,d_n)
        masked_features = aggregated_neighbourhoods * node_mask_column  # (B,N,d_n) * (B,N,1) -> (B,N,d_n)

        masked_features = self.dropout(masked_features)
        masked_features = self.attention_norm(masked_features)

        if self.config.residual_connections:
            masked_features = masked_features + node_features

        updated_node_features = self.output_transform(masked_features)
        updated_node_features = updated_node_features * node_mask_column
        # NOTE(review): unlike BasicTransformerLayer.forward, no dropout /
        # output_norm is applied after the feed-forward block - confirm this
        # is intentional.

        if self.config.residual_connections:
            updated_node_features = updated_node_features + masked_features

        # attention_norm maps an all-zero row to its bias, so the residual can
        # make padded rows non-zero; zero them so padding never leaks downstream.
        return updated_node_features * node_mask_column