# Tutorial 2: Firedrake + ML
#### Author: Nacime Bouziani

In [None]:
import os
import matplotlib.pyplot as plt
import sys

try:
    import firedrake
except ImportError:
    !wget "https://fem-on-colab.github.io/releases/firedrake-install-release-real.sh" -O "/tmp/firedrake-install.sh" && bash "/tmp/firedrake-install.sh"
    import firedrake

from firedrake import *
from firedrake.adjoint import *
from firedrake.ml.pytorch import *
from firedrake.pyplot import triplot, tripcolor, streamplot

continue_annotation()

try:
  import physics_driven_ml
except:
  !git clone https://github.com/NBoulle/physics-driven-ml.git /content/physics-driven-ml
  !pip install -e /content/physics-driven-ml
  sys.path.append("/content/physics-driven-ml")
  import physics_driven_ml

from physics_driven_ml.dataset_processing import StokesDataset

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn import Module, ModuleList, Sequential, Linear, ReLU

try:
  from torch_geometric.nn import MessagePassing
except:
  !pip install torch_geometric
  from torch_geometric.nn import MessagePassing

# Download mesh and dataset
!wget -P stokes_tutorial -c https://github.com/nbouziani/physics-driven-ml/raw/dev/data/datasets/meshes/stokes_cylinder.msh
!wget -P stokes_tutorial -c https://github.com/nbouziani/physics-driven-ml/raw/dev/data/datasets/stokes_tutorial/data.h5

In this tutorial, we employ a physics-driven ML approach that uses a Graph Neural Network (GNN) to study the flow around a circular cylinder, a well-known test case in computational fluid dynamics (CFD). We consider the Stokes equations, which are a simplified version of the Navier-Stokes equations. The Stokes problem is a widely studied linear, time-independent PDE problem.

We are interested in devising a GNN model $\psi$ to learn the following operator:

$$\psi : f ↦ sol$$

where $sol := (u, p)$ is the solution of the following Stokes problem parametrised by a source term $f$:

$$
\begin{equation}
\begin{aligned}
- \Delta u + \nabla p &= f \quad \text{ on } \Omega\\
\nabla \cdot u &= 0 \quad \text{ on } \Omega
\end{aligned}
\end{equation}
$$

with $\Omega$ the domain where the problem is posed, and where $u$ and $p$ refer to the velocity field and pressure, respectively. We further equip our PDE problem with boundary conditions.

## Physical problem: flow around a circular cylinder


In [None]:
# Load the mesh file downloaded into stokes_tutorial/
mesh = Mesh("stokes_tutorial/stokes_cylinder.msh")

# Mesh labels of the boundaries
inlet, circle = 1, 4
bottom_top = (3, 5)

# Draw the mesh on a single wide panel
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(15, 5))
triplot(mesh, axes=axes);

### Stokes problem

We can now define the Stokes problem using the Firedrake finite element software (Ham et al., 2023).

#### Define the PDE problem

In [None]:
# TODO: Define the function spaces, we are going to use piecewise quadratic elements for the velocity and piecewise linear elements for the pressure
# Z is used below as the space of the full solution and V as the space of the source term.
# NOTE(review): presumably V is the velocity space, W the pressure space and Z their mixed space -- confirm when filling in.
V = ...
W = ...
Z = ...

# Define source term
f = Function(V)

# TODO: Define the boundary conditions: u = (1, 0) at the inlet, u = (0, 0) at the top and bottom boundaries, and no-slip at the cylinder
# `bcs` is passed to the solve() call later in the notebook
g = ...
bcs = ...

# Set nullspace
# A constant basis is attached to the second entry of Z; `nullspace` is passed to solve() later
nullspace = MixedVectorSpaceBasis(Z, [Z.sub(0), VectorSpaceBasis(constant=True)])

# Define solution
up = Function(Z)

In [None]:
# Solve the PDE
# Function on Z that receives the result of the solve below
sol_exact = Function(Z)

# TODO: Define the trial and test functions
u, p = ...
v, q = ...

# TODO: Write down the weak form of the PDE
# `a` is the left-hand side and `L` the right-hand side of the problem passed to solve()
a = ...
L = ...

# Solve a == L into sol_exact, using the boundary conditions and nullspace defined in the previous cell
solve(a == L,
      sol_exact,
      bcs=bcs,
      nullspace=nullspace)

In [None]:
# Helper function to plot
def plot_sol(w):
    """Plot a solution `w` whose `subfunctions` unpack into a (u, p) pair.

    Two figures are produced: a streamline plot of u, and a three-panel figure
    showing the two components of u followed by p. `plt.show()` is called at the end.
    """
    velocity, pressure = w.subfunctions

    # First figure: streamlines of u with their colorbar
    stream_fig, stream_ax = plt.subplots(1, 1, figsize=(15, 5))
    lines = streamplot(velocity, resolution=1/3, seed=0, axes=stream_ax)
    stream_fig.colorbar(lines, ax=stream_ax, fraction=0.046)
    stream_ax.set_title("u")

    # Second figure: one panel per scalar field, in the order u_1, u_2, p
    panels = (
        (velocity.sub(0), "$u_1$"),
        (velocity.sub(1), "$u_2$"),
        (pressure, "p"),
    )
    _, panel_axes = plt.subplots(3, 1, figsize=(15, 10))
    for panel_ax, (field, title) in zip(panel_axes, panels):
        colours = tripcolor(field, cmap="jet", axes=panel_ax)
        plt.colorbar(colours)
        panel_ax.set_title(title)

    plt.show()

In [None]:
# Plot the solution obtained from the Firedrake solve above
plot_sol(sol_exact)

#### Define the physical constraint

We want to incorporate prior physical knowledge into our machine learning model. To do so, we use the interface introduced in (Bouziani & Ham, 2023) to
incorporate a physical constraint, implemented in Firedrake, into the training loss, in a similar manner to PINNs. We define that constraint as the residual form associated with the PDE.

In [None]:
# Residual assembly
def assemble_residual(sol, f):
    """Assemble the residual of the PDE for the state `sol` and the source term `f`.

    Exercise template: the residual form and the assembly call are still to be written.
    """
    # Split the mixed function into its two components
    u, p = split(sol)
    F = ... # TODO: Define the residual
    return ... # TODO: Assemble the residual with the boundary conditions

In [None]:
# Define physics-driven constraint
# NOTE: these rebind the notebook-level names `sol` and `f` to fresh functions on Z and V
sol = Function(Z)
f = Function(V)
# NOTE(review): presumably the residual assembly is recorded on a tape of its own inside this context -- confirm against pyadjoint
with set_working_tape() as _:
    # Define PyTorch operator for assembling the residual of the PDE
    # F takes the assembled residual as functional and `sol` and `f` as controls
    F = ReducedFunctional(assemble_residual(sol, f), [Control(sol), Control(f)])
    # G is the result of passing F to torch_operator
    G = torch_operator(F)

## Message Passing Neural Network (MPNN)

We want to build a Graph Neural Network that follows the Encode-Process-Decode architecture ([Battaglia et al., 2018](https://arxiv.org/abs/1806.01261)) and uses a Message Passing Neural Network as its processor. We first define the encoder and decoder, each as a single linear layer:

- Encoder:
$$
\begin{equation}
\begin{aligned}
E :\ &\mathbb{R}^{n} → \mathbb{R}^{l}\\
& x ↦ Wx + b
\end{aligned}
\end{equation}
$$

- Decoder:
$$
\begin{equation}
\begin{aligned}
D :\ &\mathbb{R}^{l} → \mathbb{R}^{m}\\
& x ↦ Wx + b
\end{aligned}
\end{equation}
$$
with $W$ and $b$ learnable parameters, and where $n$, $l$, and $m$ refer to the input, latent, and output dimensions, respectively.

In [None]:
class Encoder(Module):
    """Encoder stage of the model (exercise template: network and forward pass are TODOs)."""

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # TODO: Define the encoder network
        self.encoder = ...

    def forward(self, f):
        """Return the encoding of the input f."""
        # TODO: Define the forward pass
        return ...


class Decoder(Module):
    """Decoder stage of the model (exercise template: network and forward pass are TODOs)."""

    def __init__(self, latent_dim, output_dim):
        super().__init__()
        # TODO: Define the decoder network
        self.decoder = ...

    def forward(self, h):
        """Return the decoding of the latent feature vector h."""
        # TODO: Define the forward pass
        return ...

We now need to implement the MPNN corresponding to the processor. For this we implement a model with the following simple update rule:

$$
\begin{equation}
\begin{aligned}
  m^{n}_{ij} &= \phi_{e}(h^{n}_{i}, h^{n}_{j} - h^{n}_{i}) \\
  h^{n+1}_{i} &= \phi_{v}\left(h^{n}_{i}, \frac{1}{|N_{i}|}\sum\limits_{j \in N_{i}} m^{n}_{ij}\right)
\end{aligned}
\end{equation}
$$

where $\phi_{e}$ and $\phi_{v}$ are MLPs. To implement this, we use the PyTorch Geometric library (PyG). In a similar manner to Graph Networks, we compose several blocks of our architecture to form the processor.

In [None]:
class MPNN(MessagePassing):
    """Message-passing block with mean aggregation (exercise template: both MLPs and the message input are TODOs)."""

    def __init__(self, input_dim, latent_dim, output_dim):
        # Mean aggregation, which is permutation-invariant
        super().__init__(aggr="mean")
        # TODO: Define ϕe as a sequential model with two linear layers (2x input_dim -> latent_dim -> latent_dim) and ReLU activation
        self.message_mlp = Sequential(...)
        # TODO: Define ϕv as a sequential model with three linear layers (input_dim + latent_dim -> latent_dim -> latent_dim // 2 -> output_dim) and ReLU activations
        self.update_mlp = Sequential(...)

    def forward(self, h, edge_index):
        """Propagate the features h along the edges given by edge_index."""
        return self.propagate(edge_index, h=h)

    def message(self, h_i, h_j):
        """Return the messages m_{ij} built from the feature vectors h_{i} and h_{j}."""
        # TODO: Define the message function
        m = ...
        return self.message_mlp(m)

    def update(self, message, h):
        """Return the updated features from h concatenated with the aggregated message."""
        # Concatenate along the last axis before applying the update network
        return self.update_mlp(torch.cat((h, message), dim=-1))

In [None]:
# Do not change this block
class NeuralPDESolver(Module):
    """Model chaining an `Encoder`, `nlayers` `MPNN` blocks and a `Decoder`.

    `input_dim`, `latent_dim` and `output_dim` size the encoder and decoder;
    `num_features` and `latent_features` size every MPNN block.
    """
    def __init__(self, input_dim, latent_dim, output_dim, num_features, latent_features=2, nlayers=1):
        super(NeuralPDESolver, self).__init__()
        self.nlayers = nlayers
        # Encoder
        # The values printed here are the MPNN block dimensions (num_features, latent_features, num_features),
        # not the encoder/decoder dimensions
        print("input_dim: %s latent_dim: %s output_dim: %s" % (num_features, latent_features, num_features))
        self.encoder = Encoder(input_dim=input_dim, latent_dim=latent_dim)
        # Processor
        # nlayers MPNN blocks, each mapping num_features -> num_features
        self.processor_layers = ModuleList(modules=[MPNN(input_dim=num_features,
                                                         latent_dim=latent_features,
                                                         output_dim=num_features)
                                                    for _ in range(self.nlayers)])
        # Decoder
        self.decoder = Decoder(latent_dim=latent_dim, output_dim=output_dim)

    def forward(self, f, edge_index):
        """Encode `f`, apply every processor layer with `edge_index`, then decode."""
        # Encoding
        # A trailing axis is appended to the encoder output before it enters the processor
        h = self.encoder(f)[..., None]
        # Processing
        for layer in self.processor_layers:
            h = layer(h, edge_index)
        # Decoding
        # Index 0 of the trailing axis is taken before decoding
        sol = self.decoder(h[..., 0])
        return sol

In [None]:
# Build the solver: encoder input and latent sizes both equal V.dim(), decoder output size equals Z.dim()
model = NeuralPDESolver(
    input_dim=V.dim(),
    latent_dim=V.dim(),
    output_dim=Z.dim(),
    num_features=1,
    nlayers=4,
)
# Set double precision (default Firedrake type)
model.double()

In [None]:
# Build a single MPNN block on its own (input_dim=1, latent_dim=2, output_dim=1)
M = MPNN(input_dim=1, latent_dim=2, output_dim=1)

In [None]:
# Display the message network of the block M
M.message_mlp

## Training using the physical constraint

### Load dataset

In [None]:
# Load the dataset and wrap it in a loader yielding one sample per batch, without shuffling
train_dataset = StokesDataset(dataset="stokes_tutorial")
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=train_dataset.collate,
)

### Define evaluation

In [None]:
def evaluate(model, dataloader):
    """Evaluate the model on a given dataset.

    The model is switched to evaluation mode and the per-batch errors are
    averaged over the number of batches.

    Args:
        model: Model to evaluate; its `eval()` method is called first.
        dataloader: Iterable of batches exposing `f`, `u` and `edge_index`.

    Returns:
        The accumulated error divided by the number of batches.

    Raises:
        ValueError: If `dataloader` yields no batch, since the mean is then undefined.
    """

    model.eval()

    total_error = 0.0
    num_batches = 0
    for batch in dataloader:

        f, sol_exact = batch.f, batch.u
        edge_index = batch.edge_index

        # No gradients are needed for evaluation
        with torch.no_grad():
            # TODO: Evaluate the model and compute the mean squared error between the solution and the exact solution
            sol = ...
            total_error += ...
        num_batches += 1

    # Fail with a clear message instead of relying on an unbound loop variable
    if num_batches == 0:
        raise ValueError("Cannot evaluate the model on an empty dataloader.")

    return total_error / num_batches


### Training

We can now train our model. For this we use the physics-driven constraint implemented in Firedrake previously defined (Bouziani & Ham, 2023). We use the following loss:

$$
\begin{equation}
\mathcal{L} = \|sol - sol_{exact}\|_{\ell_{2}}^{2} + \alpha \|F(f, u)\|^{2}_{\ell_{2}}
\end{equation}
$$

with $sol = (u, p)$ and $sol_{exact}$ the predicted and exact solutions of the Stokes problem, respectively. $F$ is the residual form associated with our PDE problem.

In [None]:
# Define the loss function
def loss_fn(f, sol, sol_exact, alpha):
  """Return the training loss for one sample (exercise template).

  Per the formula given above in the notebook, the loss adds the squared l2 error
  between `sol` and `sol_exact` to `alpha` times the squared l2 norm of the PDE residual.
  """
  # Assemble residual
  residual = ...
  # Compute the loss
  loss = ...
  return loss

In [None]:
def training(loss_fn):
  """Train the notebook-level `model` on `train_dataloader` and return the last prediction.

  `loss_fn` is the loss callable to use in the TODO below. The returned `sol` is the
  prediction computed for the last batch of the last epoch.
  """
  # Set hyperparameters
  epochs = 80
  learning_rate = 5e-5
  train_steps = len(train_dataloader)
  # This initial value is never used in a comparison: the `epoch == 0` clause below overwrites it first
  best_error = 0.0
  alpha = 0.2

  # Set optimiser
  optimiser = optim.AdamW(model.parameters(), lr=learning_rate)

  # Training loop
  # Runs epochs + 1 passes over the data (epoch indices 0 to epochs inclusive)
  for epoch in range(epochs + 1):
      model.train()
      # Loop over dataset
      total_loss = 0.0
      for step_num, batch in enumerate(train_dataloader):
          model.zero_grad()

          # Retrieve data from batch
          f, sol_exact = batch.f, batch.u
          edge_index = batch.edge_index

          # TODO: Forward pass through the model
          sol = ...

          # TODO: Compute the loss
          loss = ...
          total_loss += loss.item()

          # Backpropagate
          loss.backward()
          # Optimiser step
          optimiser.step()

      # Compute error
      # The error is computed on the training dataloader itself
      test_error = evaluate(model, train_dataloader)
      print(f"Epoch: {epoch} : Training loss: {total_loss/train_steps} Error (l2): {test_error}")

      if test_error < best_error or epoch == 0:
          best_error = test_error
          # NOTE(review): this binds a second name to the same model object, no copy is made,
          # and `saved_model` is not used anywhere else in this function -- confirm whether a
          # snapshot of the best weights was intended
          saved_model = model

  print(f"\n Best error: {best_error:.3e}")
  return sol

In [None]:
# Run the training; `sol` is rebound to the value returned by training()
sol = training(loss_fn)

In [None]:
# Convert PyTorch tensor to Firedrake
# `sol` is the value returned by training(); Z is passed as the target function space
sol_fd = from_torch(sol, Z)
# Plot
# Same plotting helper as used for the exact solution earlier in the notebook
plot_sol(sol_fd)