<a href="https://colab.research.google.com/github/yawmid/3rd-Year-Project/blob/main/meshgraphnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Accessing data from Google Drive**

Organising my data in Google Drive was convenient for me.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## **Installing and importing relevant libraries**

In [None]:
pip install trimesh torch-geometric

*Some of the libraries, specifically torch-scatter and torch-sparse, only work with an older version of PyTorch. At the time of writing they had not yet been updated for torch 2.9.0, so the following code uninstalls torch 2.9.0 and reinstalls torch 2.8.0. First check which PyTorch and CUDA versions are installed, then verify that the combination is supported using this link https://data.pyg.org/whl/*

*If it is, you don't need to change the PyTorch version.*
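
*As a rough aid for that check, the sketch below builds the wheel index URL for a given PyTorch/CUDA pair. The helper name `pyg_wheel_index` is hypothetical, and the URL pattern is only assumed from the install commands used further down in this notebook, so confirm the result against the link above before relying on it.*

```python
def pyg_wheel_index(torch_version, cuda_version):
    """Build a data.pyg.org wheel index URL (pattern assumed from this notebook)."""
    # torch.__version__ may carry a local suffix, e.g. "2.8.0+cu128" -> "2.8.0"
    base = torch_version.split("+")[0]
    if cuda_version is None:
        tag = "cpu"  # torch.version.cuda is None on CPU-only builds
    else:
        tag = "cu" + cuda_version.replace(".", "")  # "12.8" -> "cu128"
    return f"https://data.pyg.org/whl/torch-{base}+{tag}.html"


print(pyg_wheel_index("2.8.0+cu128", "12.8"))
```

*In the notebook itself the two arguments would be `torch.__version__` and `torch.version.cuda`.*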

In [None]:
import torch
import sys

print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Version: {torch.version.cuda}")
print(f"Python Version: {sys.version.split()[0]}")

*If not, run the code below and change the install link accordingly.*

In [None]:
#pip uninstall torch torchvision torchaudio -y

In [None]:
#pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0

In [None]:
import os
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.8.0+cu128.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.8.0+cu128.html
!pip install torch-geometric
!pip install -q git+https://github.com/snap-stanford/deepsnap.git

In [None]:
import copy
import glob
import os
import os.path as osp
import random

import numpy as np
import pandas as pd
import trimesh
from scipy.spatial import cKDTree
from tqdm import tqdm, trange

import torch.nn as nn
import torch.optim as optim
from torch.nn import Linear, Sequential, LayerNorm, ReLU
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated
from torch_geometric.data import Data
from torch_geometric.data import Dataset
import torch_scatter

## **Preparing and loading datasets**

In [None]:
%cd '/content/drive/MyDrive/3rd_year_project/Training'

root_dir = os.getcwd()
dataset_dir = os.path.join(root_dir, 'Graphs/dataset.pt')
checkpoint_dir = os.path.join(root_dir, 'checkpoints')
postprocess_dir = os.path.join(root_dir, 'animations')

print(dataset_dir)

## **Building graphs**


This code loads the STL and WSS data. The STL is loaded with trimesh.load, and the vertices are then merged so that no vertex is counted more than once. After merging, the number of vertices should equal the number of nodes in the Ansys data.

**Geom** is the geometrical data for the artery. It has dimensions [num_nodes, 3]: the (x, y, z) coordinates of each node.

**Edge index** describes how the nodes are connected to each other; these connections are the edges of our graph.

**Edge attributes** include the relative displacement in x, y and z and the total length. It is important to encode these, because the physics of the fluid between adjacent nodes depends on the distance and the direction between them. Without this the model may struggle to learn the physics of the system.
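
As a minimal sketch of these three quantities, the example below builds the edge index and edge attributes for a toy mesh of two triangles. The vertex and face values are made up for illustration, and NumPy is used instead of torch only to keep the example light; the notebook's own function does the same with torch tensors.

```python
import numpy as np

# Toy mesh (hypothetical values): two triangles sharing the edge 1-2
vertices = np.array([[0.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0],
                     [1.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2],
                  [1, 3, 2]])

# Each face contributes its three sides, then every side is added in reverse
one_way = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
both_ways = np.concatenate([one_way, one_way[:, ::-1]])

# Remove duplicates (the shared side appears once per face) and transpose
edge_index = np.unique(both_ways, axis=0).T          # shape [2, num_edges]

src, dst = edge_index
rel = vertices[src] - vertices[dst]                  # [dx, dy, dz] per edge
dist = np.linalg.norm(rel, axis=1, keepdims=True)    # edge length
edge_attr = np.concatenate([rel, dist], axis=1)      # shape [num_edges, 4]

print(edge_index.shape, edge_attr.shape)
```

The two triangles have five distinct sides, so the graph ends up with ten directed edges.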

There is also an alignment procedure. When the WSS data was exported from Ansys it was ordered by node number, but there was no guarantee that this ordering matched the nodes from the STL. tree = cKDTree(target_pos) builds a spatial tree over the node positions in the CSV; every vertex in the STL is then matched to the spatially closest node in the CSV. After running it, the maximum distance between matched nodes turned out to be zero, so the data was already aligned to the STL, although this was not known beforehand.
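
The idea behind the alignment can be sketched on three points. This example uses a brute-force nearest-neighbour search in NumPy as a stand-in for cKDTree (the tree gives the same matches but scales far better), and all coordinates and WSS values are invented for illustration.

```python
import numpy as np

mesh_vertices = np.array([[0.0, 0.0, 0.0],
                          [1.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0]])

# The CSV holds the same points in a different row order (hypothetical data)
csv_positions = mesh_vertices[[2, 0, 1]]
csv_wss = np.array([30.0, 10.0, 20.0])

# Distance from every mesh vertex to every CSV row: shape [num_vertices, num_rows]
diff = mesh_vertices[:, None, :] - csv_positions[None, :, :]
dist_matrix = np.linalg.norm(diff, axis=2)

indices = dist_matrix.argmin(axis=1)     # closest CSV row for each mesh vertex
distances = dist_matrix.min(axis=1)      # how far away that row is

# Reorder the WSS values so that row i belongs to mesh vertex i
wss_aligned = csv_wss[indices]

print(indices, distances.max(), wss_aligned)
```

A maximum distance of zero, as here, means every vertex found an exact partner in the CSV.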

For each node we also have input features. The node type has to be defined as specified by the original research paper https://arxiv.org/abs/2010.03409. The other node features currently chosen are the input flow and the Reynolds number.

A lot of this code is based on the work of Isaac Ju, Robert Lupoiu, and Rayan Kanfar, as part of the Stanford CS224W course project.

The link to their medium post is provided here: https://medium.com/stanford-cs224w/learning-mesh-based-flow-simulations-on-graph-networks-44983679cf2d.

The link to their Colab notebook is provided here: https://colab.research.google.com/drive/1mZAWP6k9R0DE5NxPzF8yL2HpIUG3aoDC?usp=sharing#scrollTo=n33F2kSeJlQ3

Gemini 3 Pro was used for debugging and was extremely useful the couple of times I hit a wall.



In [None]:
def process_single_sample(stl_path, wss_path, flow_input):

    # --- 1. Load Geometry (STL) ---
    mesh = trimesh.load(stl_path, force = 'mesh')
    mesh.merge_vertices()

    geom = torch.tensor(mesh.vertices, dtype=torch.float)
    num_nodes = geom.shape[0]

    # Build Edges
    faces = mesh.faces
    edges = []
    for face in faces:
        edges.append([face[0], face[1]])
        edges.append([face[1], face[2]])
        edges.append([face[2], face[0]])
        # Bidirectional
        edges.append([face[1], face[0]])
        edges.append([face[2], face[1]])
        edges.append([face[0], face[2]])

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    edge_index = torch.unique(edge_index, dim=1)


    # --- 2. Calculate edge attributes ---
    # compute relative position (u_ij) and distance (norm)
    # row = source nodes, col = target nodes
    row, col = edge_index

    u_i = geom[row]
    u_j = geom[col]

    # Relative position vector (Direction)
    u_ij = u_i - u_j

    # Euclidean distance (Magnitude)
    u_ij_norm = torch.norm(u_ij, p=2, dim=1, keepdim=True)

    # Edge Attributes: [dx, dy, dz, distance]
    edge_attr = torch.cat((u_ij, u_ij_norm), dim=-1).type(torch.float)

    # --- 3. Process features (flow input + node type) ---
    flow_tensor = torch.tensor(flow_input, dtype=torch.float).unsqueeze(0).repeat(num_nodes, 1)

    node_type = torch.zeros((num_nodes, 1), dtype=torch.float) # 0 for Wall

    x = torch.cat([flow_tensor, node_type], dim=1)

    # --- 4. Load WSS ---
    df = pd.read_csv(wss_path, skiprows=5)
    df.columns = [c.strip() for c in df.columns]
    target_pos = df[['X [ m ]', 'Y [ m ]', 'Z [ m ]']].values
    wss_data = df['Wall Shear Z [ Pa ]'].values

    # === ALIGNMENT BLOCK ===
    # We find which row in the CSV corresponds to each vertex in our mesh
    tree = cKDTree(target_pos)
    distances, indices = tree.query(mesh.vertices)


    # Ensure the matching is accurate, if this warning comes up, check over your data
    if np.max(distances) > 1e-4:
        print(f"Warning: High alignment error ({np.max(distances)}). Check units (m vs mm).")

    # Reorder the WSS data to match the mesh
    wss_aligned = wss_data[indices]

    # =======================
    y = torch.tensor(wss_aligned, dtype=torch.float)
    if y.ndim == 1:
        y = y.unsqueeze(1)

    # --- 5. Build data object ---
    # Node features must be stored under `x`: get_stats and the model both read data.x
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, pos=geom)

    return data

*One instance to make sure the graph building works*

In [None]:
stl_path = '/content/drive/MyDrive/3rd_year_project/Training/STLs/CFD batch00001_wall.stl'
wss_path = '/content/drive/MyDrive/3rd_year_project/Training/WSS/CFD batch00001_wss.csv'


data = process_single_sample(stl_path, wss_path, flow_input=[1,0])
print(data)

In [None]:
import os
import glob
import torch
import pandas as pd
from tqdm import tqdm
# Ensure you have your imports: trimesh, numpy, etc.

def create_dataset_individual_files(stl_dir, wss_dir, flow_input, output_dir):
    """
    Processes STLs and saves them as individual .pt files in output_dir.
    """

    # --- 1. Setup Output Directory ---
    if not os.path.exists(output_dir):
        print(f"Creating output directory: {output_dir}")
        os.makedirs(output_dir)
    else:
        print(f"Output directory exists: {output_dir}")

    # --- 2. Load Flow Rates Table ---
    print(f"Loading flow rates from {flow_input}...")
    flow_df = pd.read_csv(flow_input)
    flow_map = dict(zip(flow_df['Filename'], flow_df['FlowRate_kg_s']))

    # --- 3. Find all STL files ---
    stl_files = sorted(glob.glob(os.path.join(stl_dir, "*.stl")))
    print(f"Found {len(stl_files)} STL files. Starting batch processing...")

    success_count = 0

    for stl_path in tqdm(stl_files):
        try:
            # A. Extract ID
            filename = os.path.basename(stl_path)
            file_id = os.path.splitext(filename)[0]
            file_id = file_id.replace("_wall", "")

            # B. Define Save Path
            # We save it as "Case123.pt"
            save_path = os.path.join(output_dir, f"{file_id}.pt")

            # C. Check if already done (Resume capability)
            if os.path.exists(save_path):
                # print(f"Skipping {file_id} (already exists).")
                continue

            # D. Find Flow Rate
            if file_id not in flow_map:
                print(f"Skipping {file_id}: ID not found in flow_rates.csv")
                continue
            flow_val = flow_map[file_id]

            # E. Find WSS file
            wss_filename = f"{file_id}_wss.csv"
            wss_path = os.path.join(wss_dir, wss_filename)

            if not os.path.exists(wss_path):
                print(f"Skipping {file_id}: WSS file not found")
                continue

            # F. Process Graph (Using your existing function)
            graph = process_single_sample(stl_path, wss_path, [flow_val])

            # --- G. SAVE INSTANTLY ---
            torch.save(graph, save_path)
            success_count += 1

        except Exception as e:
            print(f"FAILED to process {filename}: {str(e)}")
            continue

    print(f"Processing complete. Saved {success_count} new graphs to {output_dir}")

In [None]:
stl_path = '/content/drive/MyDrive/3rd_year_project/Training/STLs'
wss_path = '/content/drive/MyDrive/3rd_year_project/Training/WSS'
flow_rates = '/content/drive/MyDrive/3rd_year_project/Training/flow_rates.csv'
output_dir = '/content/drive/MyDrive/3rd_year_project/Training/Graphs'

create_dataset_individual_files(stl_path, wss_path, flow_rates, output_dir)

In [None]:
class GraphFolderDataset(Dataset):
    def __init__(self, root_dir, transform=None, pre_transform=None):
        self.root_dir = root_dir
        # Get list of all .pt files in the directory
        self.file_list = sorted(glob.glob(os.path.join(root_dir, '*.pt')))
        super(GraphFolderDataset, self).__init__(root_dir, transform, pre_transform)

    def len(self):
        return len(self.file_list)

    def get(self, idx):
        # Load the specific file corresponding to idx
        data = torch.load(self.file_list[idx],weights_only=False)
        return data

# **Normalising**

Normalising is common practice in machine learning. The input features and output targets are scaled to a comparable range, which prevents features with a large numerical range from dominating the learning process and stabilises training.
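
As a small illustration of how the statistics can be accumulated graph by graph, the sketch below keeps running sums and sums of squares and then applies Var = E[x²] − (E[x])². The two "graphs" and their values are made up, and NumPy stands in for torch.

```python
import numpy as np

# Two hypothetical graphs with different numbers of nodes, one feature each
graphs = [np.array([[1.0], [2.0], [3.0]]),
          np.array([[4.0], [5.0]])]

total = np.zeros(1)
total_sq = np.zeros(1)
count = 0
for feats in graphs:
    total += feats.sum(axis=0)            # running sum
    total_sq += (feats ** 2).sum(axis=0)  # running sum of squares
    count += feats.shape[0]               # running number of rows

mean = total / count
variance = np.maximum(total_sq / count - mean ** 2, 0.0)  # guard against rounding below zero
std = np.maximum(np.sqrt(variance), 1e-8)                 # guard against division by zero

normalised = (graphs[0] - mean) / std
restored = normalised * std + mean

print(mean, std)
```

The result matches the mean and (population) standard deviation of all five values taken together, and un-normalising recovers the original data.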


In [None]:
def normalize(to_normalize, mean_vec, std_vec):
    return (to_normalize - mean_vec) / std_vec

def unnormalize(to_unnormalize, mean_vec, std_vec):
    return to_unnormalize * std_vec + mean_vec

def get_stats(data_list):

    #mean and std of the node features are calculated
    mean_vec_x = torch.zeros(data_list[0].x.shape[1:])
    std_vec_x = torch.zeros(data_list[0].x.shape[1:])

    #mean and std of the edge features are calculated
    mean_vec_edge = torch.zeros(data_list[0].edge_attr.shape[1:])
    std_vec_edge = torch.zeros(data_list[0].edge_attr.shape[1:])

    mean_vec_y = torch.zeros(data_list[0].y.shape[1:])
    std_vec_y = torch.zeros(data_list[0].y.shape[1:])

    # Cap the number of accumulated rows so that very large datasets
    # do not cause memory issues

    max_accumulations = 10**8

    # Lower bound on the standard deviation, to avoid dividing by zero
    eps = torch.tensor(1e-8)

    num_accs_x = 0
    num_accs_edge = 0
    num_accs_y = 0

    for dp in data_list:
        # Accumulate sums and squared sums
        mean_vec_x += torch.sum(dp.x, dim=0)
        std_vec_x += torch.sum(dp.x**2, dim=0)
        num_accs_x += dp.x.shape[0]

        mean_vec_edge += torch.sum(dp.edge_attr, dim=0)
        std_vec_edge += torch.sum(dp.edge_attr**2, dim=0)
        num_accs_edge += dp.edge_attr.shape[0]

        mean_vec_y += torch.sum(dp.y, dim=0)
        std_vec_y += torch.sum(dp.y**2, dim=0)
        num_accs_y += dp.y.shape[0]

        if(num_accs_x>max_accumulations or num_accs_edge>max_accumulations or num_accs_y>max_accumulations):
            break

    # Calculate Mean & Std
    # The variance is clamped at zero: rounding in E[x^2] - E[x]^2 can give a
    # tiny negative value, and the square root of that would be NaN
    mean_vec_x = mean_vec_x / num_accs_x
    std_vec_x = torch.maximum(torch.sqrt(torch.clamp(std_vec_x / num_accs_x - mean_vec_x**2, min=0)), eps)

    mean_vec_edge = mean_vec_edge / num_accs_edge
    std_vec_edge = torch.maximum(torch.sqrt(torch.clamp(std_vec_edge / num_accs_edge - mean_vec_edge**2, min=0)), eps)

    mean_vec_y = mean_vec_y / num_accs_y
    std_vec_y = torch.maximum(torch.sqrt(torch.clamp(std_vec_y / num_accs_y - mean_vec_y**2, min=0)), eps)

    return [mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge, mean_vec_y, std_vec_y]


# **Encoder**

MeshGraphNets has three main components in its architecture: the encoder, the processor and the decoder.

The encoder takes the input data and maps it to a higher-dimensional latent vector that the rest of the network operates on; the original DeepMind paper uses a latent size of 128. Because both are mapped to the same latent size, the edges and nodes can have different input dimensions. For example, we are currently using [x, y, z, d] for the edge attributes (dimension = 4), whereas for the nodes we are using [node type, input flow, Reynolds number] (dimension = 3). There are two separate encoders, one for edges and one for nodes. The encoding is not specified by hand; it is learned by an MLP with ReLU activation.

Mathematically, the node encoding is simply:

$$
\mathbf{z}_{i} = \textrm{MLP} ( \mathbf{h}_{i}) \;      \forall i \in V
$$

and the edge encoding is:

$$
\mathbf{z}_{ij} = \textrm{MLP} ( \mathbf{h}_{ij}) \;      \forall (i, j) \in E
$$
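
To make the shapes concrete, here is a rough NumPy sketch of the two encoders with random, untrained weights. The helper names `make_encoder` and `layer_norm` are hypothetical, the layer norm has no learned scale or shift, and the input sizes (3 and 4) simply follow the example dimensions quoted above.

```python
import numpy as np

rng = np.random.default_rng(0)


def layer_norm(h, eps=1e-5):
    """Normalise each row to zero mean and unit variance (no learned scale/shift)."""
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mu) / np.sqrt(var + eps)


def make_encoder(in_dim, hidden_dim):
    """Return a Linear -> ReLU -> Linear -> LayerNorm encoder with random weights."""
    w1 = rng.normal(scale=0.1, size=(in_dim, hidden_dim))
    b1 = np.zeros(hidden_dim)
    w2 = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
    b2 = np.zeros(hidden_dim)

    def encode(h):
        hidden = np.maximum(h @ w1 + b1, 0.0)  # ReLU
        return layer_norm(hidden @ w2 + b2)

    return encode


node_encoder = make_encoder(in_dim=3, hidden_dim=128)
edge_encoder = make_encoder(in_dim=4, hidden_dim=128)

z_nodes = node_encoder(rng.normal(size=(5, 3)))   # 5 nodes, 3 raw features
z_edges = edge_encoder(rng.normal(size=(8, 4)))   # 8 edges, 4 raw features

print(z_nodes.shape, z_edges.shape)
```

Both outputs have 128 columns, which is what lets the processor concatenate node and edge embeddings freely.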

# **Processor**

The processor is the GNN message passing, aggregation, and update part of the architecture. It passes the graph, with the features generated by the encoder, through the GNN pipeline of message, aggregation, and update for the chosen number of layers.
The processing layers of MeshGraphNets are handled by a separate class, ProcessorLayer, which inherits from the PyG MessagePassing base class. The message is a learned MLP transformation, with a skip connection, of the edge's own embedding concatenated with the embeddings of the connecting nodes.

The aggregation is done in two steps:
 1) sum over the connected edges of each node
 2) another MLP transformation of the edge sum concatenated with the node's own embedding.


# **Decoder**


The decoder is a postprocessing step. It takes the updated node embeddings from the processor and maps them to the predicted (normalised) WSS using another, separately learned MLP.

In [None]:
class MeshGraphNet(torch.nn.Module):
    def __init__(self, input_dim_node, input_dim_edge, hidden_dim, output_dim, args, emb=False):
        """
        MeshGraphNet model. This model is built upon DeepMind's 2021 paper.
        The model consists of three parts: (1) preprocessing: encoder, (2) processor,
        (3) postprocessing: decoder. There are separate encoders for edges and nodes.
        Each processor layer updates the edge embeddings first and then the node
        embeddings. The decoder is only applied to nodes.

        input_dim_node: number of node features (flow inputs + node type)
        input_dim_edge: number of edge features (dx, dy, dz, distance)
        hidden_dim: latent size (128 in DeepMind's paper)
        output_dim: number of predicted quantities (1: wall shear stress)
        """
        super(MeshGraphNet, self).__init__()

        self.num_layers = args.num_layers

        # encoder convert raw inputs into higher dimensions.
        self.node_encoder = Sequential(Linear(input_dim_node , hidden_dim),
                              ReLU(),
                              Linear( hidden_dim, hidden_dim),
                              LayerNorm(hidden_dim))

        self.edge_encoder = Sequential(Linear( input_dim_edge , hidden_dim),
                              ReLU(),
                              Linear( hidden_dim, hidden_dim),
                              LayerNorm(hidden_dim)
                              )


        self.processor = nn.ModuleList()
        assert (self.num_layers >= 1), 'Number of message passing layers is not >=1'

        processor_layer=self.build_processor_model()
        for _ in range(self.num_layers):
            self.processor.append(processor_layer(hidden_dim,hidden_dim))


        # decoder: only for node vectors
        self.decoder = Sequential(Linear( hidden_dim , hidden_dim),
                              ReLU(),
                              Linear( hidden_dim, output_dim)
                              )


    def build_processor_model(self):
        return ProcessorLayer


    def forward(self,data,mean_vec_x,std_vec_x,mean_vec_edge,std_vec_edge):
        """
        The encoder maps the graph's node/edge features to latent node/edge embeddings.
        The output of the encoder is fed into the processor to generate updated embeddings,
        which the decoder then maps to the (normalised) prediction.
        """
        x, edge_index, edge_attr, wss = data.x, data.edge_index, data.edge_attr, data.y

        x = normalize(x,mean_vec_x,std_vec_x)
        edge_attr=normalize(edge_attr,mean_vec_edge,std_vec_edge)

        # Step 1: encode node/edge features into latent node/edge embeddings
        x = self.node_encoder(x) # output shape is the specified hidden dimension

        edge_attr = self.edge_encoder(edge_attr) # output shape is the specified hidden dimension

        # step 2: perform message passing with latent node/edge embeddings

        for i in range(self.num_layers):
            # Each ProcessorLayer already applies its own residual (skip) connections
            # and returns the updated embeddings, so they are not added to x again here
            x, edge_attr = self.processor[i](x, edge_index, edge_attr)


        # step 3: decode latent node embeddings into physical quantities of interest

        return self.decoder(x)

    def loss(self, pred, inputs, mean_vec_y, std_vec_y):

        # 1. Normalise the Ground Truth (y) to match the prediction scale
        target = (inputs.y - mean_vec_y) / std_vec_y

        # 2. Calculate Squared Error
        error = (target - pred) ** 2

        # 3. Root Mean Squared
        loss = torch.sqrt(torch.mean(error))

        return loss

# **Processor Layer**

1.   **Message passing**

Initiated by the propagate function, the message function most generally calculates messages, m, for edge u at layer l with function MSG given previous embeddings h_u:
$$m_u^{(l)}=MSG^{(l)}(h_u^{(l-1)})$$

Note that for MeshGraphNets, messages are calculated for edges and passed to nodes. This function thus takes edge embeddings and the adjacent node embeddings and concatenates them. These concatenated previous embeddings constitute h_u above. These are then put through an MLP (our MSG function) to give the final messages, m_u, which are passed to the aggregate function.

2.   **Aggregation**

Aggregation takes the updated edge embeddings and aggregates them over the connectivity matrix using sum reduction. Most generally, we have:

$$h_v^{(l)}=AGG^{(l)}(\{m_u^{(l)},u\in N(v)\})$$

For MeshGraphNets, the aggregation (AGG) for node v is a sum over the messages from its neighbours. However, there is also an additional aggregation step: aggregating with the self embedding. This is done outside of the aggregation function, in the forward function after the return of propagate:

$$h_v^{(l)}=\{h_v^{(l-1)},AGG^{(l)}(\{m_u^{(l)},u\in N(v)\})\}$$

3.   **Updating**

The node embeddings are finally updated by passing $h_v^{(l)}$ through the node MLP with a skip connection. This is most generally written as:

$$h_v^{(l)}=Processor(h_v^{(l)})$$

Where for us the Processor is an MLP.
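
The three steps above can be traced by hand on a tiny graph. In this sketch the learned MLPs are replaced by a plain sum over the concatenated inputs (`edge_fn` and `node_fn` are hypothetical stand-ins, not the real networks), embeddings are one-dimensional, and the sum is indexed by the first row of the edge index, mirroring the aggregate function in this notebook.

```python
import numpy as np

# Three nodes on a line, 0-1-2, with edges in both directions (toy graph)
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
h_nodes = np.array([[1.0], [2.0], [3.0]])
h_edges = np.array([[0.1], [0.2], [0.3], [0.4]])


def edge_fn(x):
    """Stand-in for the edge MLP: sums the concatenated inputs."""
    return x.sum(axis=1, keepdims=True)


def node_fn(x):
    """Stand-in for the node MLP: sums the concatenated inputs."""
    return x.sum(axis=1, keepdims=True)


first, second = edge_index

# 1. Message: concatenate both end-node embeddings with the edge embedding,
#    transform, and add the skip connection on the edge
edge_in = np.concatenate([h_nodes[second], h_nodes[first], h_edges], axis=1)
new_edges = edge_fn(edge_in) + h_edges

# 2. Aggregation: sum the updated edges attached to each node
agg = np.zeros_like(h_nodes)
np.add.at(agg, first, new_edges)

# 3. Update: concatenate with the node's own embedding, transform, add the skip
new_nodes = h_nodes + node_fn(np.concatenate([h_nodes, agg], axis=1))

print(new_edges.ravel(), agg.ravel(), new_nodes.ravel())
```

Node 1 has two attached edges, so its aggregate is the sum of two messages, while the end nodes each receive one.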

In [None]:
class ProcessorLayer(MessagePassing):
    def __init__(self, in_channels, out_channels,  **kwargs):
        """
        in_channels: dim of node embeddings [128], out_channels: dim of edge embeddings [128]

        """
        super(ProcessorLayer, self).__init__(  **kwargs )

        # Note that the node and edge encoders both have the same hidden dimension
        # size. This means that the input of the edge processor will always be
        # three times the specified hidden dimension
        # (input: adjacent node embeddings and self embeddings)
        self.edge_mlp = Sequential(Linear( 3* in_channels , out_channels),
                                   ReLU(),
                                   Linear( out_channels, out_channels),
                                   LayerNorm(out_channels))

        self.node_mlp = Sequential(Linear( 2* in_channels , out_channels),
                                   ReLU(),
                                   Linear( out_channels, out_channels),
                                   LayerNorm(out_channels))


        self.reset_parameters()

    def reset_parameters(self):
        """
        reset parameters for stacked MLP layers
        """
        self.edge_mlp[0].reset_parameters()
        self.edge_mlp[2].reset_parameters()

        self.node_mlp[0].reset_parameters()
        self.node_mlp[2].reset_parameters()

    def forward(self, x, edge_index, edge_attr, size = None):
        """
        Handles the pre- and post-processing of node features/embeddings and
        initiates message passing by calling the propagate function.

        Message passing and aggregation are handled by the propagate function;
        the node update is applied here, after propagate returns.

        x has shape [node_num , in_channels] (node embeddings)
        edge_index: [2, edge_num]
        edge_attr: [E, in_channels]

        """

        out, updated_edges = self.propagate(edge_index, x = x, edge_attr = edge_attr, size = size) # out has the shape of [node_num, out_channels]

        updated_nodes = torch.cat([x,out],dim=1)        # Complete the aggregation through self-aggregation

        updated_nodes = x + self.node_mlp(updated_nodes) # residual connection

        return updated_nodes, updated_edges

    def message(self, x_i, x_j, edge_attr):
        """
        target_node: x_i has the shape of [E, in_channels]
        source_node: x_j has the shape of [E, in_channels]
        edge: edge_attr has the shape of [E, out_channels]

        The messages are the concatenated embeddings passed through the edge MLP,
        with a skip connection on the edge embedding.
        """

        updated_edges=torch.cat([x_i, x_j, edge_attr], dim = 1) # updated_edges has the shape of [E, 3 * in_channels]
        updated_edges=self.edge_mlp(updated_edges)+edge_attr

        return updated_edges

    def aggregate(self, updated_edges, edge_index, dim_size = None):
        """
        Sums the updated edge embeddings (messages) over the edges attached to
        each node. The aggregation with the node's own embedding is completed
        in forward. The updated edges are returned as well so that forward can
        hand them on to the next layer.
        """

        # The axis along which to index number of nodes.
        node_dim = 0

        out = torch_scatter.scatter(updated_edges, edge_index[0, :], dim=node_dim, reduce = 'sum')

        return out, updated_edges

# **Building the optimiser**

This defines a function that lets us experiment with the optimiser and the learning-rate schedule. The Adam optimiser is used by default, as in the original DeepMind paper.

This can be altered afterwards to see if the performance of the network can be improved.
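
To get a feel for the two schedules offered alongside the optimisers, the sketch below evaluates the closed-form learning rates that PyTorch documents for StepLR and CosineAnnealingLR. The function names and the example numbers (step size 10, gamma 0.5, T_max 100) are illustrative assumptions, not settings used in this project.

```python
import math


def step_lr(base_lr, epoch, step_size, gamma):
    """Learning rate under a step schedule: multiply by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


def cosine_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Learning rate under cosine annealing from base_lr down to eta_min over t_max epochs."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


for epoch in (0, 25, 50, 100):
    print(epoch, step_lr(1e-4, epoch, step_size=10, gamma=0.5), cosine_lr(1e-4, epoch, t_max=100))
```

Note that the cosine formula divides by T_max, so that setting has to be a positive number of epochs for the schedule to make sense.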

In [None]:
def build_optimizer(args, params):
    weight_decay = args.weight_decay
    # Filter parameters that require gradients
    filter_fn = filter(lambda p : p.requires_grad, params)

    # torch.optim is referenced explicitly so that this cell does not depend on an alias
    if args.opt == 'adam':
        optimizer = torch.optim.Adam(filter_fn, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'sgd':
        optimizer = torch.optim.SGD(filter_fn, lr=args.lr, momentum=0.95, weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = torch.optim.RMSprop(filter_fn, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'adagrad':
        optimizer = torch.optim.Adagrad(filter_fn, lr=args.lr, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unknown optimiser: {args.opt}")

    if args.opt_scheduler == 'none':
        return None, optimizer
    elif args.opt_scheduler == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
    elif args.opt_scheduler == 'cos':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)
    else:
        raise ValueError(f"Unknown scheduler: {args.opt_scheduler}")

    return scheduler, optimizer

# **Main Training Loop**

In [None]:
def train(dataset, device, stats_list, args):
    # --- Setup ---
    # Convert stats to device
    stats_list = [s.to(device) for s in stats_list]
    [mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge, mean_vec_y, std_vec_y] = stats_list

    # Create DataLoaders
    # dataset[:n] slicing works if dataset is a list. If it's a PyG dataset object, use indexing carefully.
    train_loader = DataLoader(dataset[:args.train_size], batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(dataset[args.train_size:], batch_size=args.batch_size, shuffle=False)

    # Initialize Model
    num_node_features = dataset[0].x.shape[1]
    num_edge_features = dataset[0].edge_attr.shape[1]
    num_classes = 1 # WSS is scalar

    model = MeshGraphNet(num_node_features, num_edge_features, args.hidden_dim, num_classes, args).to(device)

    # Adam is used directly here; build_optimizer(args, model.parameters()) can be swapped in to try other optimisers
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)


    # Logging
    results_data = [] # Store dicts here, create DF at end
    model_name = f'model_nl{args.num_layers}_hd{args.hidden_dim}_lr{args.lr}'

    best_test_loss = np.inf
    best_model = None

    # --- Training Loop ---
    for epoch in trange(args.epochs, desc="Training", unit="Epochs"):
        total_loss = 0
        model.train()
        num_loops = 0

        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()

            # Forward
            pred = model(batch, mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge)

            # Loss (Model handles normalization of target internally)
            loss = model.loss(pred, batch, mean_vec_y, std_vec_y)

            # Backward
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_loops += 1

        avg_train_loss = total_loss / num_loops

        # --- Evaluation (Every 10 epochs) ---
        if epoch % 10 == 0:
            # We always run validation to check progress
            test_loss, wss_rmse = test(test_loader, device, model, stats_list, is_validation=True)

            # Log results
            results_data.append({
                'epoch': epoch,
                'train_loss': avg_train_loss,
                'test_loss': test_loss,
                'wss_rmse': wss_rmse
            })

            # Make sure the checkpoint directory exists before anything is written to it
            os.makedirs(args.checkpoint_dir, exist_ok=True)

            # Checkpoint Best Model
            if test_loss < best_test_loss:
                best_test_loss = test_loss
                best_model = copy.deepcopy(model)
                if args.save_best_model:
                    torch.save(best_model.state_dict(), os.path.join(args.checkpoint_dir, model_name + '.pt'))

            # Print Progress
            print(f" Ep {epoch}: Train Loss {avg_train_loss:.4f} | Test Loss {test_loss:.4f} | WSS RMSE {wss_rmse:.4f} Pa")

            # Save CSV log
            pd.DataFrame(results_data).to_csv(os.path.join(args.checkpoint_dir, model_name + '.csv'), index=False)

    return best_model, best_test_loss

def test(loader, device, test_model, stats_list, is_validation=False):
    '''
    Calculates test set losses and validation set errors (RMSE in Pascals).
    '''
    # Unpack stats
    [mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge, mean_vec_y, std_vec_y] = stats_list

    loss_accum = 0
    wss_rmse_accum = 0
    num_loops = 0

    test_model.eval() # Set model to evaluation mode

    for data in loader:
        data = data.to(device)

        with torch.no_grad():
            # 1. Forward Pass
            pred = test_model(data, mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge)

            # 2. Calculate Loss (Normalized Space) - used for learning curves
            # The model.loss function handles normalization of data.y internally
            batch_loss = test_model.loss(pred, data, mean_vec_y, std_vec_y)
            loss_accum += batch_loss.item()

            # 3. Calculate Validation Error (Physical Space - Pascals)
            if is_validation:
                # A. Unnormalize the predictions back to Pascals
                pred_physical = unnormalize(pred, mean_vec_y, std_vec_y)

                # B. Get the ground truth (data.y is already physical/raw)
                true_physical = data.y

                # C. Calculate RMSE directly

                error = (pred_physical - true_physical) ** 2
                batch_rmse = torch.sqrt(torch.mean(error))

                wss_rmse_accum += batch_rmse.item()

        num_loops += 1

    avg_test_loss = loss_accum / num_loops

    if is_validation:
        avg_wss_rmse = wss_rmse_accum / num_loops
    else:
        avg_wss_rmse = 0.0

    return avg_test_loss, avg_wss_rmse

*Specify parameters for model training*

In [None]:
class objectview(object):
    def __init__(self, d):
        self.__dict__ = d

for args in [
        {'model_type': 'meshgraphnet',
         'num_layers': 10,
         'batch_size': 16,
         'hidden_dim': 10,
         'epochs': 500,
         'opt': 'adam',
         'opt_scheduler': 'none',
         'opt_restart': 0,
         'weight_decay': 5e-4,
         'lr': 1e-4,
         'train_size': 45,
         'test_size': 10,
         'device':'cuda',
         'shuffle': True,
         'save_wss_val': True,
         'save_wss_model': True,
         'save_best_model': True,  # read by train() when checkpointing the best model
         'checkpoint_dir': './checkpoints/',
         'postprocess_dir': './plots/'},
    ]:
        args = objectview(args)

#To ensure reproducibility the best we can, here we control the sources of
#randomness by seeding the various random number generators used in this Colab

torch.manual_seed(5)  #Torch
random.seed(5)        #Python
np.random.seed(5)     #NumPy

*Load dataset*

In [None]:
dataset = GraphFolderDataset(output_dir)[:(args.train_size+args.test_size)]


*Get the statistics of the dataset and select the device*

In [None]:
stats_list = get_stats(dataset)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.device = device
print(device)

# **Training**

In [None]:
# train() returns the best model and its (normalised) test loss;
# the per-epoch losses are written to the CSV log in the checkpoint directory
best_model, best_test_loss = train(dataset, device, stats_list, args)

print("Min test set loss: {0}".format(best_test_loss))