In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import scatter

# -------------------
# 1. Graph Encoder
# -------------------
class GraphEncoder(nn.Module):
    """Lift raw node and edge features into a shared hidden space.

    Node features and edge features each pass through their own two-layer
    ReLU MLP ending in ``hidden_dim`` units. ``edge_index`` is accepted so the
    signature matches the other graph stages, but it is not used here.
    """

    def __init__(self, node_input_dim, edge_input_dim, hidden_dim):
        super().__init__()
        # Creation order (node first, then edge) is kept so random init matches.
        self.node_mlp = self._two_layer_mlp(node_input_dim, hidden_dim)
        self.edge_mlp = self._two_layer_mlp(edge_input_dim, hidden_dim)

    @staticmethod
    def _two_layer_mlp(in_features, hidden_dim):
        """Build Linear -> ReLU -> Linear mapping ``in_features`` to ``hidden_dim``."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x, edge_index, edge_attr):
        """Return ``(encoded_nodes, encoded_edges)``; ``edge_index`` is ignored."""
        encoded_nodes = self.node_mlp(x)
        encoded_edges = self.edge_mlp(edge_attr)
        return encoded_nodes, encoded_edges

# -------------------
# 2. Graph Neural Network (Message Passing)
# -------------------
class GraphNetwork(MessagePassing):
    """One round of edge-conditioned message passing with mean aggregation.

    For every edge, a message is computed by an MLP over the concatenation
    ``[x_i, x_j, edge_attr]``. Under PyG's default ``source_to_target`` flow,
    ``x_i`` is the target node and ``x_j`` the source node. Incoming messages
    are mean-aggregated per node. Each node is then updated by a second MLP
    over ``[x, aggregated_messages]``.
    """

    def __init__(self, hidden_dim):
        super().__init__(aggr='mean')
        message_width = 3 * hidden_dim  # x_i, x_j and edge_attr concatenated
        self.edge_mlp = nn.Sequential(
            nn.Linear(message_width, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        update_width = 2 * hidden_dim  # node features + aggregated messages
        self.node_mlp = nn.Sequential(
            nn.Linear(update_width, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x, edge_index, edge_attr):
        """Run message -> aggregate -> update and return new node features."""
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        # Parameter names are significant: PyG inspects them to lift x per edge.
        per_edge = torch.cat((x_i, x_j, edge_attr), dim=-1)
        return self.edge_mlp(per_edge)

    def update(self, aggr_out, x):
        combined = torch.cat((x, aggr_out), dim=-1)
        return self.node_mlp(combined)

# -------------------
# 3. Cluster Pooling (GRU summary of each given node cluster)
# -------------------
class GraphPooling(nn.Module):
    """Summarize each node cluster into a single vector with a GRU.

    The member nodes of a cluster are fed to the GRU as one sequence, in the
    order they are listed, so the result depends on that order. The GRU's
    final hidden state is the cluster summary.

    ``forward`` returns a tensor of shape ``(num_clusters, 1, hidden_dim)``:
    one ``(1, hidden_dim)`` final hidden state per cluster, stacked.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

    def forward(self, x, clusters):
        summaries = []
        for members in clusters:
            sequence = x[members][None]           # (1, cluster_size, hidden_dim)
            _, final_hidden = self.gru(sequence)  # (1, 1, hidden_dim)
            summaries.append(final_hidden[0])     # (1, hidden_dim)
        return torch.stack(summaries)

# -------------------
# 4. Transformer (Self-Attention)
# -------------------
class TransformerLayer(nn.Module):
    """Residual self-attention followed by a residual position-wise FFN.

    Inputs are batch-first, as configured on ``nn.MultiheadAttention``.
    The FFN expands to ``4 * hidden_dim`` with a ReLU in between. No layer
    normalization or dropout is applied.
    """

    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        inner_dim = 4 * hidden_dim
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, hidden_dim),
        )

    def forward(self, x):
        # MultiheadAttention returns (output, weights); only the output is used.
        attended = x + self.self_attention(x, x, x)[0]
        return attended + self.ffn(attended)

# -------------------
# 5. Decoder (Node-Level Message Passing and Prediction)
# -------------------
class GraphDecoder(nn.Module):
    """Turn processed node embeddings into per-node predictions.

    Applies one more ``GraphNetwork`` message-passing step, then a two-layer
    ReLU MLP head that projects every node to ``output_dim`` values.
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.gnn = GraphNetwork(hidden_dim)
        self.output_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x, edge_index, edge_attr):
        refined = self.gnn(x, edge_index, edge_attr)
        return self.output_mlp(refined)

# -------------------
# 6. Full Mesh Transformer Architecture
# -------------------
class MeshTransformer(nn.Module):
    """Encode-process-decode model for mesh graphs with a cluster-level transformer.

    Stages:
      1. ``GraphEncoder`` lifts raw node/edge features to ``hidden_dim``.
      2. One ``GraphNetwork`` message-passing step mixes local neighbourhoods.
      3. ``GraphPooling`` summarizes each node cluster into one vector.
      4. Stacked ``TransformerLayer`` blocks attend *across* the clusters.
      5. The attended cluster vectors are broadcast back to their member
         nodes and added to the node embeddings. Nodes listed in several
         clusters get the mean. ``GraphDecoder`` then predicts per node.

    Args:
        node_input_dim: Number of raw features per node.
        edge_input_dim: Number of raw features per edge.
        hidden_dim: Width of every hidden representation.
        output_dim: Number of predicted values per node.
        num_heads: Attention heads per transformer layer (must divide ``hidden_dim``).
        num_transformer_layers: Number of stacked transformer layers.
    """

    def __init__(self, node_input_dim, edge_input_dim, hidden_dim, output_dim, num_heads, num_transformer_layers):
        super(MeshTransformer, self).__init__()
        self.encoder = GraphEncoder(node_input_dim, edge_input_dim, hidden_dim)
        self.gnn = GraphNetwork(hidden_dim)
        self.pooling = GraphPooling(hidden_dim)
        self.transformer_layers = nn.ModuleList([
            TransformerLayer(hidden_dim, num_heads) for _ in range(num_transformer_layers)
        ])
        self.decoder = GraphDecoder(hidden_dim, output_dim)

    def forward(self, x, edge_index, edge_attr, clusters):
        """Predict per-node outputs.

        Args:
            x: Node features, ``(num_nodes, node_input_dim)``.
            edge_index: Connectivity in PyG's ``(2, num_edges)`` layout.
            edge_attr: Edge features, ``(num_edges, edge_input_dim)``.
            clusters: Sequence of 1-D index tensors (or boolean node masks),
                one per cluster, selecting rows of ``x``.

        Returns:
            Tensor of shape ``(num_nodes, output_dim)``.
        """
        # Encode graph and run one round of local message passing.
        x, edge_attr = self.encoder(x, edge_index, edge_attr)
        x = self.gnn(x, edge_index, edge_attr)

        # Pool clusters. GraphPooling yields one vector per cluster (with a
        # singleton middle axis). Lay them out as ONE batch-first sequence of
        # length num_clusters, so attention mixes information between
        # clusters instead of running over num_clusters length-1 sequences.
        cluster_embeddings = self.pooling(x, clusters)
        cluster_seq = cluster_embeddings.reshape(1, -1, x.size(-1))

        # Apply transformer layers across clusters.
        for layer in self.transformer_layers:
            cluster_seq = layer(cluster_seq)

        # Upsample: inject the attended cluster context into member nodes.
        # Without this the transformer output would be unused, and pooling and
        # attention would neither affect predictions nor receive gradients.
        x = x + self._broadcast_to_nodes(cluster_seq.squeeze(0), clusters, x.size(0))

        # Decode back to node-level predictions.
        return self.decoder(x, edge_index, edge_attr)

    @staticmethod
    def _broadcast_to_nodes(cluster_embeddings, clusters, num_nodes):
        """Scatter ``(num_clusters, hidden)`` vectors onto ``num_nodes`` nodes.

        Nodes in several clusters receive the mean of those clusters' vectors.
        Nodes in no cluster receive zeros.
        """
        device = cluster_embeddings.device
        node_index = []
        cluster_index = []
        for cluster_id, members in enumerate(clusters):
            members = torch.as_tensor(members, device=device)
            if members.dtype == torch.bool:  # boolean node mask -> indices
                members = members.nonzero(as_tuple=True)[0]
            members = members.long().reshape(-1)
            node_index.append(members)
            cluster_index.append(torch.full_like(members, cluster_id))
        node_index = torch.cat(node_index)
        cluster_index = torch.cat(cluster_index)
        return scatter(cluster_embeddings[cluster_index], node_index, dim=0,
                       dim_size=num_nodes, reduce='mean')


In [2]:
import torch

# Seed so the random demo graph (and the model's init) is reproducible.
torch.manual_seed(0)

# Define the dimensions
node_input_dim = 3    # Example: position, velocity, pressure
edge_input_dim = 2    # Example: relative position and distance
hidden_dim = 16
output_dim = 2        # Example: future velocity and pressure
num_heads = 4         # must divide hidden_dim
num_transformer_layers = 2

# Instantiate the model
model = MeshTransformer(
    node_input_dim, edge_input_dim, hidden_dim, output_dim, num_heads, num_transformer_layers
)

# Create dummy input data
num_nodes = 100
num_edges = 300
num_clusters = 10

x = torch.rand((num_nodes, node_input_dim))  # Node features
edge_index = torch.randint(0, num_nodes, (2, num_edges))  # Edge index
edge_attr = torch.rand((num_edges, edge_input_dim))  # Edge features
# Partition the nodes into disjoint clusters: every node lands in exactly one
# cluster and no index repeats. Independent torch.randint draws could
# duplicate nodes within a cluster and leave others uncovered.
clusters = list(torch.randperm(num_nodes).tensor_split(num_clusters))

# Forward pass through the model (inference only, so no autograd graph).
model.eval()
with torch.no_grad():
    output = model(x, edge_index, edge_attr, clusters)

# Check the output
print("Output shape:", output.shape)
assert output.shape == (num_nodes, output_dim), output.shape


Output shape: torch.Size([100, 2])
