In [1]:
# Torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Embedding(nn.Module):
    """
    Token-wise MLP used to embed node features and pairwise (edge) features.

    Every hidden block is ``LayerNorm -> Linear -> GELU`` acting on the last
    dimension of the input, so any leading dimensions are preserved.

    args:
        input_dim: int, size of the last dimension of the input tensor
        layers_dim: list, layer widths. Linear layers are created only for
            ``layers_dim[1:]``; ``layers_dim[0]`` does not create a layer (the
            first Linear maps ``input_dim`` to ``layers_dim[1]``).

    return:
          MLP whose output last dimension is ``layers_dim[-1]``; with fewer than
          two entries in ``layers_dim`` the module is the identity.
    """

    def __init__(self, input_dim, layers_dim):
        super(Embedding, self).__init__()

        self.layers_dim = layers_dim

        blocks = []
        width_in = input_dim
        for width_out in layers_dim[1:]:
            blocks += [nn.LayerNorm(width_in), nn.Linear(width_in, width_out), nn.GELU()]
            width_in = width_out
        self.layers = nn.ModuleList(blocks)

    def forward(self, input_):
        out = input_
        for module in self.layers:
            out = module(out)
        return out

class mlp(nn.Module):
    """
    Final classifier head applied after pooling the encoder outputs.

    Hidden blocks are ``LayerNorm -> Linear -> GELU``; they are followed by
    ``Linear(.., 1) -> Sigmoid``, so outputs lie in [0, 1].

    args:
        input_dim: int, size of the last dimension of the pooled input
        layers_dim: list, hidden widths. Hidden layers are created only for
            ``layers_dim[1:]``; ``layers_dim[0]`` does not create a layer. With
            fewer than two entries the head is just ``Linear(input_dim, 1) -> Sigmoid``.

    return:
          MLP network with one neuron output and Sigmoid activation
    """

    def __init__(self, input_dim, layers_dim):
        super(mlp, self).__init__()
        self.layers_dim = layers_dim
        self.layers = nn.ModuleList()

        for i in layers_dim[1:]:
            self.layers.append(nn.LayerNorm(input_dim))
            self.layers.append(nn.Linear(input_dim, i))
            self.layers.append(nn.GELU())
            input_dim = i
        # Use the running width, not the loop variable: `i` is unbound when
        # layers_dim has fewer than two entries (the loop body never runs).
        self.layers.append(nn.Linear(input_dim, 1))
        self.layers.append(nn.Sigmoid())

    def forward(self, input_):
        x = input_
        for layer in self.layers:
            x = layer(x)
        return x

class EEDGCNEncoder(nn.Module):
    """
    Stack of EdgeConvWithEdgeFeatures layers applied to layer-normalised node
    and edge features.

    Args:
        in_channels_n: int, node feature size of the input ``H``
        out_channels_n: int, node feature size produced by every layer
        in_channels_E: int, edge feature size of the input ``E`` (used only by
            the input LayerNorm ``nE``)
        out_channels_E: int, edge feature size handled by every layer
        k: int, nearest neighbours per node
        n_layers: int, number of EdgeConv layers
        pooling: str, neighbour aggregation ('avg', 'max' or 'sum')

    NOTE(review): ``E`` is passed to the first layer unchanged, and the layers'
    edge update is sized for ``out_channels_E`` edge features, so in practice
    ``in_channels_E`` has to equal ``out_channels_E``.
    """

    def __init__(self,in_channels_n,  out_channels_n,in_channels_E,out_channels_E, k,n_layers, pooling='avg'):
        super(EEDGCNEncoder,self).__init__()
        self.in_channels_n=in_channels_n
        self.out_channels_n=out_channels_n
        self.out_channels_E=out_channels_E
        self.in_channels_E=in_channels_E
        self.k=k
        self.n_layers=n_layers
        self.pooling=pooling
        # Layer 0 consumes the in_channels_n input features; every later layer
        # consumes the previous layer's out_channels_n output. (Identical to
        # building all layers with in_channels_n when the two sizes match.)
        self.layers = nn.ModuleList([
            EdgeConvWithEdgeFeatures(
                self.in_channels_n if i == 0 else self.out_channels_n,
                self.out_channels_n,
                self.out_channels_E,
                self.k,
                self.pooling,
            )
            for i in range(self.n_layers)
        ])
        self.nH = nn.LayerNorm(self.in_channels_n)
        self.nE = nn.LayerNorm(self.in_channels_E)

    def forward(self,H,E):
        """
        H: node features, last dimension in_channels_n
        E: edge features, last dimension in_channels_E
        Returns the (node, edge) outputs of the last layer (or the normalised
        inputs when n_layers == 0).
        """
        out_H = self.nH(H)
        out_E = self.nE(E)

        for layer in self.layers:
            out_H,out_E = layer(out_H,out_E)

        return out_H,out_E

class Edgeupdate(nn.Module):
    """
    Linear update of dense edge features from the features of both endpoints.

        new_edge[..., :] = W([edge, node1, node2])   (concatenated on the last axis)

    Args:
        hidden_dim: int, feature size of node1 and node2
        dim_e: int, feature size of the input and of the output edge tensor
        dropout_ratio: float, stored as ``self.dropout``; forward does not apply it
    """

    def __init__(self, hidden_dim, dim_e, dropout_ratio=0.2):
        super(Edgeupdate, self).__init__()
        self.hidden_dim = hidden_dim
        self.dim_e = dim_e
        self.dropout = dropout_ratio
        fan_in = 2 * hidden_dim + dim_e
        self.W = nn.Linear(fan_in, dim_e)

    def forward(self, edge, node1, node2):
        """
        :param edge: [batch, node, node, dim_e]
        :param node1: [batch, node, node, dim]
        :param node2: [batch, node, node, dim]
        :return: updated edge features, [batch, node, node, dim_e]
        """
        stacked = torch.cat((edge, node1, node2), dim=-1)
        return self.W(stacked)

class EdgeConvWithEdgeFeatures(nn.Module):
    """
    EdgeConv-style layer: updates node features from their k nearest
    neighbours (Euclidean distance in feature space), then updates the dense
    edge-feature tensor from the new node features.

    Args:
        in_channels: int, node feature size D of the input ``x``
        out_channels_n: int, node feature size produced by this layer
        out_channels_E: int, edge feature size; the incoming edge tensor must
            already have this size, because the edge update maps
            ``out_channels_E`` -> ``out_channels_E``
        k: int, number of nearest neighbours (clamped to the number of points N);
            the neighbour set normally includes the point itself (distance 0)
        pooling: str, neighbour aggregation: 'avg', 'max' or 'sum'
    """

    def __init__(self, in_channels,  out_channels_n,out_channels_E, k, pooling='avg'):
        super(EdgeConvWithEdgeFeatures, self).__init__()
        self.k = k
        self.pooling=pooling
        self.in_channels = in_channels
        self.out_channels_n = out_channels_n
        self.out_channels_E = out_channels_E
        # NOTE(review): W is never used in forward; it is kept so state_dict keys
        # and the parameter-initialisation order stay unchanged.
        self.W = nn.Linear(self.in_channels, self.out_channels_E)
        # The edge update consumes this layer's *output* node features, so its
        # node width is out_channels_n (same as before when in == out).
        self.highway = Edgeupdate(self.out_channels_n, self.out_channels_E, dropout_ratio=0.2)
        self.mlp = nn.Sequential(
            nn.Linear(2 * self.in_channels, self.out_channels_n, bias=False),
            nn.BatchNorm1d(self.out_channels_n),
            nn.GELU()
        )

    def forward(self, x,weight_adj):
        """
        Args:
            x: node features, shape [B, N, D] with D == in_channels
            weight_adj: dense edge features, shape [B, N, N, out_channels_E]
        Returns:
            n_out: updated node features, shape [B, N, out_channels_n]
            edge_outputs: updated edge features, shape [B, N, N, out_channels_E]
        Raises:
            ValueError: if ``pooling`` is not 'avg', 'max' or 'sum'
        """
        B, N, D = x.size()
        k = min(self.k, N)  # topk raises if k exceeds the number of points

        # Step 1: indices of the k nearest points for every point.
        pairwise_dist = torch.cdist(x, x, p=2)  # [B, N, N]
        idx = pairwise_dist.topk(k=k, dim=-1, largest=False)[1]  # [B, N, k]

        # Step 2: neighbors[b, i, m] = x[b, idx[b, i, m]].
        # (Gathering from x.unsqueeze(2).expand(...) would return x[b, i] for every m,
        # i.e. the centre point itself, making the relative features identically zero.)
        batch_idx = torch.arange(B, device=x.device).view(B, 1, 1)
        neighbors = x[batch_idx, idx]  # [B, N, k, D]

        # Central point repeated for k neighbors: [B, N, k, D]
        central = x.unsqueeze(2).expand(-1, -1, k, -1)

        # Step 3: EdgeConv input [x_i, x_j - x_i].
        relative_features = neighbors - central  # [B, N, k, D]
        combined_features = torch.cat([central, relative_features], dim=-1)  # [B, N, k, 2*D]

        # Step 4: shared MLP over every (point, neighbour) pair, then aggregation over k.
        combined_features = self.mlp(combined_features.reshape(-1, 2 * D))  # [B*N*k, out_channels_n]
        combined_features = combined_features.view(B, N, k, -1)  # [B, N, k, out_channels_n]

        if self.pooling == 'avg':
            n_out = combined_features.mean(dim=2)
        elif self.pooling == 'max':
            n_out = combined_features.max(dim=2)[0]
        elif self.pooling == 'sum':
            n_out = combined_features.sum(dim=2)
        else:
            raise ValueError(
                f"Unknown pooling '{self.pooling}'; expected 'avg', 'max' or 'sum'"
            )

        # node_outputs1[b, i, j] = n_out[b, j]; node_outputs2[b, i, j] = n_out[b, i].
        # expand with -1 keeps the output width (out_channels_n) rather than the input D.
        node_outputs1 = n_out.unsqueeze(1).expand(-1, N, -1, -1)
        node_outputs2 = node_outputs1.permute(0, 2, 1, 3).contiguous()
        edge_outputs = self.highway(weight_adj,node_outputs1,node_outputs2)

        return n_out,edge_outputs
    
class MultiHead_Self_Attention(nn.Module):
    """
    Multi-head self-attention with an additive pairwise interaction matrix U.

        scores = (Q K^T + U) / sqrt(head_dim), optionally masked, softmax over keys

    Args:
        embed_dim0: int, feature size of the input tokens (and of the output)
        embed_dim: int, total hidden size of Q, K and V; rounded up to the next
            multiple of ``num_heads`` when it is not divisible by it
        num_heads: int, number of attention heads
        masked: bool, if True, key tokens whose projected query vector sums to
            zero (per head) are excluded from attention

    NOTE(review): Q is a bias-free projection of the input, so the mask only
    detects padded tokens that reach this module as exact zero vectors; any
    normalisation/embedding with a bias applied beforehand will stop padding
    from being masked — confirm this is intended.

    forward returns:
          1- output, shape (batch, particle tokens, embed_dim0)
          2- attention weights, shape (batch, num_heads, particle tokens, particle tokens)
    """

    def __init__(self, embed_dim0, embed_dim, num_heads, masked=True):
        super(MultiHead_Self_Attention, self).__init__()
        self.masked = masked
        self.embed_dim0 = embed_dim0
        self.num_heads = num_heads
        # Round the hidden size up to a multiple of num_heads so it splits evenly across heads.
        self.embed_dim = embed_dim + (-embed_dim) % num_heads
        self.head_dim = self.embed_dim // num_heads

        # Initialize the linear layers
        self.q_linear = nn.Linear(self.embed_dim0, self.embed_dim, bias=False)
        self.k_linear = nn.Linear(self.embed_dim0, self.embed_dim, bias=False)
        self.v_linear = nn.Linear(self.embed_dim0, self.embed_dim, bias=False)
        self.out_linear = nn.Linear(self.embed_dim, self.embed_dim0, bias=False)

    def att_mask(self, input_):
        """
        Attention mask with 1 for kept (non-zero-sum) tokens and 0 for masked tokens.

        input_: tensor (..., n_particles, n_features); called with Q of shape
                (batch, num_heads, n_particles, head_dim)
        output: float mask (..., n_particles, n_particles) where entry [.., i, j]
                marks whether key j is kept
        """

        mask = (input_.sum(dim=-1) != 0).float()
        mask = mask[:, :, None, :]
        mask = mask.repeat(1, 1, mask.size(-1), 1)

        return mask

    def scaled_dot_product_attention(self, Q, K, V, U):
        """
        Computes softmax((Q K^T + U) / sqrt(d_k)) V.

        Q, K, V: (batch, num_heads, particle_tokens, head_dim)
        U: added to the raw scores; broadcastable to
           (batch, num_heads, particle_tokens, particle_tokens)

        Returns (output, attn_weights). A query row whose keys are all masked
        gets zero attention weights instead of NaN.
        """
        d_k = Q.size(-1)

        scores = (torch.matmul(Q, K.transpose(-2, -1)) + U) / (d_k ** 0.5)

        keep = None
        if self.masked:
            keep = self.att_mask(Q) != 0
            scores = scores.masked_fill(~keep, float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)

        if keep is not None:
            # softmax over an all -inf row yields NaN, which would propagate to the
            # output; give such rows zero weight instead.
            attn_weights = attn_weights.masked_fill(~keep.any(dim=-1, keepdim=True), 0.0)

        output = torch.matmul(attn_weights, V)
        return output, attn_weights

    def forward(self, query, U):
        batch_size = query.size(0)

        Q = self.q_linear(query)
        K = self.k_linear(query)
        V = self.v_linear(query)

        # Split heads: (batch, tokens, embed_dim) -> (batch, num_heads, tokens, head_dim)
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply scaled dot-product attention
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, U)

        # Concatenate heads back to (batch, tokens, embed_dim)
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.embed_dim)
        )

        # Output projection back to embed_dim0
        output = self.out_linear(attn_output)

        return output, attn_weights


###########################################
###########################################
###########################################
class TransformerLayer(nn.Module):
    """
    One Transformer encoder block with an additive attention bias U.

        x = query + dropout1(norm2(attention(norm1(query), U)))
        x = x + dropout2(feed_forward(x))

    Args:
        input_dim: int, token feature size (input and output)
        h_dim: int, hidden size of Q, K and V inside the attention
        expansion_factor: int, width multiplier of the feed-forward hidden layer
        n_heads: int, number of attention heads
        masked: bool, forwarded to the attention module

    TODO: the hidden dimension is fixed here; it may need to be adapted in the future.
    """

    def __init__(
        self,
        input_dim,
        h_dim=500,
        expansion_factor=4,
        n_heads=10,
        masked=True,
    ):
        super(TransformerLayer, self).__init__()

        self.input_dim = input_dim
        self.n_heads = n_heads
        self.masked = masked
        self.expansion_factor = expansion_factor
        self.h_dim = h_dim
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.attention = MultiHead_Self_Attention(
            self.input_dim, self.h_dim, self.n_heads, self.masked
        )
        self.norm1 = nn.LayerNorm(self.input_dim)
        self.norm2 = nn.LayerNorm(self.input_dim)
        self.feed_forward = nn.Sequential(
            nn.LayerNorm(self.input_dim),
            nn.Linear(self.input_dim, self.expansion_factor * self.input_dim),
            nn.GELU(),
            nn.LayerNorm(self.expansion_factor * self.input_dim),
            nn.Linear(self.expansion_factor * self.input_dim, self.input_dim),
        )

    def forward(self, query, U):
        attention_out, _ = self.attention(self.norm1(query), U)
        # Residual dropout acts on the sub-layer output only, so the identity
        # (skip) path is never dropped. In eval mode this equals the previous
        # dropout(sublayer + residual) formulation.
        x = query + self.dropout1(self.norm2(attention_out))
        x = x + self.dropout2(self.feed_forward(x))

        return x


###########################################
###########################################
###########################################
class TransformerEncoder(nn.Module):
    """
    Input embedding followed by a stack of TransformerLayer blocks.

    Args:
        input_dim: int, feature size of the input tokens
        embed_dim: list, widths of the input embedding MLP; embed_dim[-1] is the
            token size inside the Transformer layers (embed_dim[0] does not
            create a layer)
        h_dim: int, hidden size of Q, K and V in each attention
        num_layers: number of encoder layers
        expansion_factor: width multiplier of the feed-forward hidden layer
        n_heads: number of heads in multihead attention
        masked: bool, forwarded to every layer's attention

    Returns:
        out: output of the encoder, last dimension embed_dim[-1]
    """

    def __init__(
        self,
        input_dim,
        embed_dim=[512, 256, 128],
        h_dim=200,
        num_layers=2,
        expansion_factor=4,
        n_heads=10,
        masked=True,
    ):
        super(TransformerEncoder, self).__init__()
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.h_dim = h_dim
        self.masked = masked
        self.expansion_factor = expansion_factor
        self.n_heads = n_heads
        self.num_layers = num_layers
        self.embed = Embedding(self.input_dim, self.embed_dim)
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.embed_dim[-1],
                    self.h_dim,
                    self.expansion_factor,
                    self.n_heads,
                    self.masked,
                )
                for i in range(self.num_layers)
            ]
        )

    def forward(self, x, u):
        x_new = self.embed(x)
        # Parameter-free layer norm over the feature axis. A freshly built
        # nn.LayerNorm always has weight 1 / bias 0, so this is the same
        # normalisation, but it runs on x_new's device and does not allocate an
        # untracked (never trained, not in state_dict) module on every call.
        # NOTE(review): if a learnable norm was intended, register it in __init__
        # (this would add state_dict keys).
        out = F.layer_norm(x_new, x_new.shape[-1:])

        for layer in self.layers:
            out = layer(out, u)

        return out
    
class IAFormer(nn.Module):
    """
    Transformer with a pairwise interaction matrix, combined with an
    edge-feature GNN.

    forward pipeline:
      1. Edge features are embedded by ``U_embeding`` and node features by ``embed``.
      2. ``GNNencoder`` updates both; its edge output is projected by ``nW`` to
         ``n_heads`` values per particle pair and used as the attention bias U.
      3. ``encoder`` (which has its own input embedding) runs on the raw node features.
      4. Both node outputs are pooled over the feature axis (one value per
         particle), summed, and passed to the ``mlp`` head.

    Args:
       f_dim: int, number of the feature tokens
       n_particles: int, number of the particle tokens; the head's input size is
           n_particles, so it must match the particle axis of the inputs
       U_features: int, number of features in the pairwise interaction matrix
       k: int, nearest neighbours per particle in the GNN
       n_Transformer: int, number of Transformer layers
       n_GNNLayers: int, number of GNN layers
       h_dim: int, hidden dim of Q, K and V
       expansion_factor: int, expansion of the internal MLP in the Transformer layers
       n_heads: int, number of attention heads
       masked: boolean, use the attention mask
       pooling: str, pooling kind: 'avg', 'max' or 'sum'
       embed_dim: list, widths of the node-feature embedding MLP
           (the first entry does not create a layer)
       U_dim: list, widths of the pairwise-feature embedding MLP (the first entry
           does not create a layer); U_dim[-1] is projected to n_heads by ``nW``
       mlp_f_dim: list, widths of the final MLP (the first entry does not create a layer)

    return:
            forward returns a (batch, 1) tensor from the sigmoid head.
    """

    def __init__(
        self,
        f_dim,
        n_particles,
        U_features,
        k=7,
        n_Transformer=2,
        n_GNNLayers = 4,
        h_dim=200,
        expansion_factor=4,
        n_heads=10,
        masked=True,
        pooling='avg',
        embed_dim=[128,512,128],
        U_dim = [128,64,64,10],
        mlp_f_dim=[128,64]):
        super(IAFormer, self).__init__()

        self.f_dim=f_dim
        self.n_particles=n_particles
        self.U_features = U_features
        self.k = k
        self.n_Transformer=n_Transformer
        self.n_GNNLayers =n_GNNLayers
        self.n_heads= n_heads
        self.masked = masked
        self.expansion_factor = expansion_factor
        self.h_dim = h_dim
        self.pooling=pooling
        self.embed_dim=embed_dim
        self.mlp_f_dim = mlp_f_dim
        self.U_dim = U_dim
        self.mlp = mlp(self.n_particles,self.mlp_f_dim)
        self.U_embeding =Embedding(self.U_features,self.U_dim)
        self.embed = Embedding(self.f_dim,self.embed_dim)
        self.encoder = TransformerEncoder(self.f_dim,embed_dim=self.embed_dim,h_dim=self.h_dim, num_layers=self.n_Transformer,
               expansion_factor=self.expansion_factor, n_heads=self.n_heads,masked=self.masked)
        self.GNNencoder = EEDGCNEncoder(self.embed_dim[-1],self.embed_dim[-1],self.U_dim[-1],self.U_dim[-1],self.k,self.n_GNNLayers,self.pooling)
        self.nW = nn.Linear(self.U_dim[-1],self.n_heads)
        # NOTE(review): nH and nE are not used in forward; kept for state_dict compatibility.
        self.nH = nn.LayerNorm(self.f_dim)
        self.nE = nn.LayerNorm(self.U_features)

    def forward(self,input_T,input_E):
        '''
        input_T: dim (batch, particle tokens, feature tokens)
        input_E: dim (batch, particle tokens, particle tokens, pairwise features)

        Returns a (batch, 1) tensor.
        Raises ValueError if self.pooling is not 'avg', 'max' or 'sum'.
        '''
        # Fail fast with a clear message instead of an UnboundLocalError after
        # all the expensive sub-networks have run.
        if self.pooling not in ('avg', 'max', 'sum'):
            raise ValueError(
                f"Unknown pooling '{self.pooling}'; expected 'avg', 'max' or 'sum'"
            )

        inp_E = self.U_embeding(input_E)
        inp_T = self.embed(input_T)
        out_H,out_E = self.GNNencoder(inp_T,inp_E)

        # (batch, N, N, n_heads) -> (batch, n_heads, N, N): one bias matrix per head.
        inp_E_T= torch.permute(self.nW(out_E),(0,-1,1,2))

        Transformer_out = self.encoder(input_T,inp_E_T)

        # Pool over dim 2 (the feature axis): one value per particle, which is
        # why the head takes n_particles inputs.
        Transformer_output = self._pool(Transformer_out)
        out_H_ = self._pool(out_H)

        output_c = Transformer_output + out_H_
        output = self.mlp(output_c)

        return output

    def _pool(self, t):
        """Reduce dim 2 of ``t`` with the configured pooling ('avg', 'max' or 'sum')."""
        if self.pooling == 'avg':
            return t.mean(dim=2)
        if self.pooling == 'max':
            return t.max(dim=2)[0]
        if self.pooling == 'sum':
            return t.sum(dim=2)
        raise ValueError(
            f"Unknown pooling '{self.pooling}'; expected 'avg', 'max' or 'sum'"
        )
    
    

In [3]:
# Seed the global torch RNG so these inputs (and any later random ops in this
# kernel, e.g. model initialisation) are reproducible on Restart & Run All.
torch.manual_seed(0)
x = torch.rand(5, 100, 11)  # node features: (batch, particles, features)
x_edge_features = torch.rand(5, 100, 100, 4)  # pairwise features: (batch, particles, particles, features)

In [5]:
# Model hyper-parameters
D = 11                        # node features per particle
N = 100                       # particles per event
f = 4                         # pairwise (edge) features per particle pair
n_Transformer = 8             # Transformer layers
n_GNN = 3                     # GNN (EdgeConv) layers
k = 7                         # nearest neighbours used by each GNN layer
expansion_factor = 4          # width multiplier of the MLP inside each Transformer layer
n_heads = 15                  # attention heads
masked = True                 # apply the attention padding mask
pooling = 'avg'               # pooling type: 'max', 'avg' or 'sum'
# The embedding / final-MLP builders only create Linear layers from list entries
# [1:], so the first value of embed_dim, U_dim and mlp_f_dim does not add a layer.
embed_dim = [256, 128, 64]    # node-feature embedding widths
h_dim = embed_dim[-1]         # total Q/K/V width; 64 is rounded up to 75 to split over 15 heads
U_dim = [512, 256, 128, 64]   # pairwise-feature embedding widths
mlp_f_dim = [512, 128, 64]    # widths of the final MLP

model = IAFormer(
    f_dim=D,
    n_particles=N,
    U_features=f,
    k=k,
    n_Transformer=n_Transformer,
    n_GNNLayers=n_GNN,
    h_dim=h_dim,
    expansion_factor=expansion_factor,
    n_heads=n_heads,
    masked=masked,
    pooling=pooling,
    embed_dim=embed_dim,
    U_dim=U_dim,
    mlp_f_dim=mlp_f_dim,
)
model

IAFormer(
  (mlp): mlp(
    (layers): ModuleList(
      (0): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
      (1): Linear(in_features=100, out_features=128, bias=True)
      (2): GELU(approximate='none')
      (3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (4): Linear(in_features=128, out_features=64, bias=True)
      (5): GELU(approximate='none')
      (6): Linear(in_features=64, out_features=1, bias=True)
      (7): Sigmoid()
    )
  )
  (U_embeding): Embedding(
    (layers): ModuleList(
      (0): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
      (1): Linear(in_features=4, out_features=256, bias=True)
      (2): GELU(approximate='none')
      (3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): GELU(approximate='none')
      (6): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (7): Linear(in_features=128, out_features=64, bias=True)
      (8): GELU(approx

In [6]:
model(input_T=x, input_E=x_edge_features)  # forward pass on the random batch -> (batch, 1) scores

tensor([[0.4507],
        [0.4717],
        [0.6448],
        [0.5296],
        [0.5503]], grad_fn=<SigmoidBackward0>)