In [63]:
data_path = "./AllFields_Resultats_MESH_1.xdmf"
from mesh_handler import *
import numpy as np
import torch
import meshio

In [64]:
meshes  = xdmf_to_meshes(data_path)

Loaded 80 timesteps from AllFields_Resultats_MESH_1.xdmf



In [69]:
mesh = meshes[0]
# Charger le maillage tétraédrique
def torch_input_features(mesh):
    # Extraire les éléments tétraédriques
    if "tetra" in mesh.cells_dict:
        tetrahedrons = mesh.cells_dict["tetra"]
    else:
        raise ValueError("Le maillage ne contient pas de tétraèdres.")

    # Construction des arêtes en reliant les sommets des tétraèdres
    edges = np.vstack([
        tetrahedrons[:, [0, 1]],  # Arête entre sommets 0 et 1
        tetrahedrons[:, [0, 2]],  # Arête entre sommets 0 et 2
        tetrahedrons[:, [0, 3]],  # Arête entre sommets 0 et 3
        tetrahedrons[:, [1, 2]],  # Arête entre sommets 1 et 2
        tetrahedrons[:, [1, 3]],  # Arête entre sommets 1 et 3
        tetrahedrons[:, [2, 3]]   # Arête entre sommets 2 et 3
    ])

    # Suppression des doublons et tri des indices pour éviter les arêtes en double
    edges = np.unique(np.sort(edges, axis=1), axis=0)

    # Convertir en format edge_index pour PyTorch Geometric (format [2, num_edges])
    edge_index = torch.tensor(edges.T, dtype=torch.int64)

    print("Edge index format:")
    print(edge_index)


    nodes_xyz = mesh.points
    nodes_v = mesh.point_data["Vitesse"]
    nodes_P = mesh.point_data["Pression"]
    tetras = mesh.cells_dict["tetra"]
    nodes_all = torch.tensor(np.hstack((nodes_xyz, nodes_v, np.expand_dims(nodes_P, axis=1))), dtype= torch.float)
    return nodes_all, edge_index


def torch_output_features(mesh):

    nodes_v = mesh.point_data["Vitesse"]
    nodes_P = mesh.point_data["Pression"]
    nodes_all = torch.tensor(np.hstack((nodes_v, np.expand_dims(nodes_P, axis=1))), dtype=torch.float)
    
    return nodes_all


In [72]:
nodes_all, edge_index = torch_features(mesh)
ground_truth = torch_output_features(meshes[1])


Edge index format:
tensor([[    0,     0,     0,  ..., 11422, 11430, 11435],
        [    1,    76,   726,  ..., 11425, 11434, 11443]])


In [97]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class CFDGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(CFDGNN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)
        
    def forward(self, x, edge_index):
        x_0 = x[:,3:]
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index) 
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = self.lin(x) + x_0
        return  x

# Initialisation
model = CFDGNN(in_channels=7, hidden_channels=16, out_channels=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)




In [100]:
for epoch in range(1000):
    optimizer.zero_grad()
    out = model(nodes_all, edge_index)

    loss = F.mse_loss(out, ground_truth)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch}, Loss: {loss.item()}')

In [102]:
loss

tensor(1058.9080, grad_fn=<MseLossBackward0>)