# Simple PINN for an elastic plate with an elliptical hole 

## Geometry

We want to model a quarter of a plate with an elliptical hole. The domain itself is represented by collocation points, the boundaries are represented by uniformly sampled points along the perimeter.

In [1]:
import os

import numpy as np
import plotly.figure_factory as ff
import plotly.graph_objects as go
import torch
from plotly.express.colors import sequential
from scipy.stats import qmc
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

import wandb
from global_constants import LBD, MU, N1, N2, L, R

# Track notebook
os.environ["WANDB_NOTEBOOK_NAME"] = "./torch.ipynb"

# Elliptical axis in x direction
Rx = 0.14
Ry = R**2 / Rx
# Edge samples
N = 25
# Number of collocation points
M = 500
# Epochs
EPOCHS = 25000
# Batch size
BATCH_SIZE = 100
# Learning rate
LR = 0.001
# Scheduler step width
STEP = 1000
# Gamma factor of scheduler
GAMMA = 0.8
# Number of hidden neurons
HN = 40
# Number of hidden layers
LAYERS = 4
# Weight of PDE loss
W_PDE = 1.0
# Weight of Dirichlet loss
W_DIR = 1.0
# Weight of Neumann loss
W_NEU = 1.0
# Weight of hole loss
W_HOLE = 1.0
# Weight update factor
ALPHA = 0.9
# Variance
SIGMA = 1.0
# Fourier features
FEATURES = 40
# precision
DTYPE = torch.float32


wandb.init(
    project="pinn_hole_plate",
    entity="ddped",
    name=f"G:{GAMMA}|LR:{LR}|S:{SIGMA}|F:{FEATURES}|HN:{HN}",
    save_code = True,
    config={
        "Rx": Rx,
        "N_edge": N,
        "M_collocation": M,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LR,
        "scheduler_step": STEP,
        "scheduler_gamma": GAMMA,
        "hidden_neurons": HN,
        "n_layers": LAYERS,
        "alpha": ALPHA,
        "sigma": SIGMA,
        "fourier_features": FEATURES,
    },
)


def sample_domain():
    # Create collocation points
    points = qmc.LatinHypercube(d=2).random(M)
    mask = (((points[:, 0] ** 2) / (Rx**2)) + ((points[:, 1] ** 2) / (Ry**2))) > 1
    collocation = torch.tensor(points[mask], dtype=DTYPE)

    # Top boundary
    x_top = L * torch.tensor(qmc.LatinHypercube(d=1).random(N), dtype=DTYPE)
    y_top = L * torch.ones((N, 1))
    top = torch.column_stack([x_top, y_top])

    # Right boundary
    x_right = L * torch.ones((N, 1))
    y_right = L * torch.tensor(qmc.LatinHypercube(d=1).random(N), dtype=DTYPE)
    right = torch.column_stack([x_right, y_right])

    # Bottom boundary
    NN = int(N * (L - Rx) / L)
    rand_samp = qmc.LatinHypercube(d=1).random(NN)
    x_bottom = Rx + (L - Rx) * torch.tensor(rand_samp, dtype=DTYPE)
    y_bottom = torch.zeros((NN, 1))
    bottom = torch.column_stack([x_bottom, y_bottom])

    # Left boundary
    NN = int(N * (L - Ry) / L)
    rand_samp = qmc.LatinHypercube(d=1).random(NN)
    x_left = torch.zeros((NN, 1))
    y_left = Ry + (L - Ry) * torch.tensor(rand_samp, dtype=DTYPE)
    left = torch.column_stack([x_left, y_left])

    # Hole boundary
    rand_samp = qmc.LatinHypercube(d=1).random(int(N * np.pi * Rx / L)).ravel()
    phi = 0.5 * np.pi * torch.tensor(rand_samp, dtype=DTYPE)
    x_hole = Rx * torch.cos(phi)
    y_hole = Ry * torch.sin(phi)
    n_hole = torch.stack([-Ry * torch.cos(phi), -Rx * torch.sin(phi)]).T
    n_hole = n_hole / torch.linalg.norm(n_hole, axis=1)[:, None]
    hole = torch.column_stack([x_hole, y_hole])

    return collocation, top, bottom, left, right, hole, n_hole


collo, top, bottom, left, right, hole, n_hole = sample_domain()

# Visualize geometry
with torch.no_grad():
    mode = "markers"
    gray = dict(color="#C9C5BC")
    green = dict(color="#006561")
    black = dict(color="black")
    fig = ff.create_quiver(
        hole[:, 0], hole[:, 1], n_hole[:, 0], n_hole[:, 1], marker=black
    )
    fig.add_trace(go.Scatter(x=collo[:, 0], y=collo[:, 1], mode=mode, marker=gray))
    fig.add_trace(go.Scatter(x=top[:, 0], y=top[:, 1], mode=mode, marker=black))
    fig.add_trace(go.Scatter(x=bottom[:, 0], y=bottom[:, 1], mode=mode, marker=green))
    fig.add_trace(go.Scatter(x=left[:, 0], y=left[:, 1], mode=mode, marker=green))
    fig.add_trace(go.Scatter(x=right[:, 0], y=right[:, 1], mode=mode, marker=black))
    fig.add_trace(go.Scatter(x=hole[:, 0], y=hole[:, 1], mode=mode, marker=black))
    fig.layout.yaxis.scaleanchor = "x"
    fig.update_layout(
        template="none",
        width=400,
        height=400,
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=False,
    )
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    fig.show()

[34m[1mwandb[0m: Currently logged in as: [33mmeyer-nils[0m ([33mddped[0m). Use [1m`wandb login --relogin`[0m to force relogin


## The ANN model that approximates the displacement field

An ANN might be considered as a generic function approximator. In this case, it should approximated the function $u: \mathcal{R}^2 \rightarrow \mathcal{R}^2$ with five hidden layers having 20 neurons each.

In [2]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Input layer
        self.layers = torch.nn.ModuleList([torch.nn.Linear(2 * FEATURES, HN).type(DTYPE)])
        # Hidden layers
        for _ in range(LAYERS - 1):
            self.layers.append(torch.nn.Linear(HN, HN).type(DTYPE))
        # Output layer
        self.output_layer = torch.nn.Linear(HN, 2).type(DTYPE)

        # Initialize weights with Glorot scheme
        for layer in self.layers:
            torch.nn.init.xavier_uniform_(layer.weight)

        # Sample B from normal distribution
        self.B = torch.normal(0.0, SIGMA, size=(2, FEATURES), dtype=DTYPE)

    def forward(self, x):
        # Random Fourier feature embedding
        x = torch.cat([torch.sin(2 * np.pi * x @ self.B), torch.cos(2*np.pi*x @ self.B)], dim=-1)
        for layer in self.layers:
            x = torch.tanh(layer(x))
        return self.output_layer(x)


net = Net()

## The physics

We want to solve linear elasticity on the domain, which means ultimately that we want to minimize the residual of the following PDE 
$$\frac{\partial \sigma_{11}}{\partial x_1} + \frac{\partial \sigma_{12}}{\partial x_2} = 0$$
$$\frac{\partial \sigma_{21}}{\partial x_1} + \frac{\partial \sigma_{22}}{\partial x_2} = 0$$
with stress 
$$ \sigma_{ij} = 2\mu \varepsilon_{ij} + \frac{2\lambda\mu}{2\mu+\lambda} \varepsilon_{kk} \delta_{ij} $$
and strain 
$$ \varepsilon_{ij} = \frac{1}{2} \left( \frac{\partial u_i}{\partial x_j} +  \frac{\partial u_j}{\partial x_i}\right).$$

In [3]:
def epsilon(x):
    # Compute deformation gradient
    dudx = torch.func.jacrev(net)(x)
    return 0.5 * (dudx + dudx.T)


def sigma(x):
    # Compute (small deformation) strain
    eps = epsilon(x)
    # Compute linear elastic strain (assuming plane stress)
    return 2.0 * MU * eps + (2 * LBD * MU) / (2 * MU + LBD) * torch.trace(
        eps
    ) * torch.eye(2)


def pde_residual(x):
    # Compute stress gradient
    dsdx = torch.func.jacrev(sigma)(x)
    # Momentum balance in x direction
    residual_x = dsdx[0, 0, 0] + dsdx[0, 1, 1]
    # Momentum balance in y direction
    residual_y = dsdx[1, 0, 0] + dsdx[1, 1, 1]
    return residual_x, residual_y

## Boundary conditions

Left: 

$$ u_1 = 0$$

Bottom: 

$$ u_2 = 0$$

Top: 

$$ \sigma \cdot n = N_2 n$$

Right: 

$$ \sigma \cdot n = N_1 n$$

In [4]:
mse = torch.nn.MSELoss()


def compute_physics_losses(collocation, top, bottom, left, right, hole, n_hole):
    # pde
    all_points = torch.cat([collocation, top, bottom, left, right, hole])
    res_x, res_y = torch.vmap(pde_residual)(all_points)
    zeros = torch.zeros_like(res_x)
    pde_error = mse((res_x), zeros) + mse((res_y), zeros)

    # left boundary
    pred_left = net(left)
    bc_left = torch.zeros_like(pred_left[:, 0])
    left_error = mse(pred_left[:, 0], bc_left)

    pred_stress_left = torch.vmap(sigma)(left)
    pred_s_left_xy = pred_stress_left[:, 0, 1]
    s_left_xy = torch.zeros_like(pred_s_left_xy)
    left_symm_error = mse(pred_s_left_xy, s_left_xy)

    # bottom boundary
    pred_bottom = net(bottom)
    bc_bottom = torch.zeros_like(pred_bottom[:, 1])
    bottom_error = mse(pred_bottom[:, 1], bc_bottom)

    pred_stress_bottom = torch.vmap(sigma)(bottom)
    pred_s_bottom_xy = pred_stress_bottom[:, 0, 1]
    s_bottom_xy = torch.zeros_like(pred_s_bottom_xy)
    bottom_symm_error = mse(pred_s_bottom_xy, s_bottom_xy)

    # top boundary
    pred_stress_top = torch.vmap(sigma)(top)
    pred_s_top_yy = pred_stress_top[:, 1, 1]
    pred_s_top_xy = pred_stress_top[:, 0, 1]
    s_top_yy = N2 * torch.ones_like(pred_s_top_yy)
    s_top_xy = torch.zeros_like(pred_s_top_xy)
    top_error = mse(pred_s_top_yy, s_top_yy) + mse(pred_s_top_xy, s_top_xy)

    # right boundary
    pred_stress_right = torch.vmap(sigma)(right)
    pred_s_right_xx = pred_stress_right[:, 0, 0]
    pred_s_right_xy = pred_stress_right[:, 0, 1]
    s_right_xx = N1 * torch.ones_like(pred_s_right_xx)
    s_right_xy = torch.zeros_like(pred_s_right_xy)
    right_error = mse(pred_s_right_xx, s_right_xx) + mse(pred_s_right_xy, s_right_xy)

    # hole boundary
    stress_hole = torch.vmap(sigma)(hole)
    traction = torch.einsum("...ij,...j->...i", stress_hole, n_hole)
    zeros = torch.zeros_like(traction[:, 0])
    hole_error = mse(traction[:, 0], zeros) + mse(traction[:, 1], zeros)

    return (left_error, left_symm_error, right_error, bottom_error, bottom_symm_error, top_error, hole_error, pde_error)

## Training 

In [5]:
def compute_gradient_norm(loss):
    grads = torch.autograd.grad(loss, net.parameters(), allow_unused=True)
    return sum(0 if grad is None else torch.linalg.norm(grad) for grad in grads)


def update_weight(weight, grad_sum, norm):
    new_weight = grad_sum / norm
    return ALPHA * weight + (1 - ALPHA) * new_weight

In [6]:
optimizer = torch.optim.Adam(net.parameters(), lr=LR)
scheduler = StepLR(optimizer, step_size=STEP, gamma=GAMMA)

for epoch in tqdm(range(EPOCHS)):
    # Sample domain randomly
    collocation, top, bottom, left, right, hole, n_hole = sample_domain()


    for collo in torch.chunk(collocation, int(M / BATCH_SIZE)):
        optimizer.zero_grad()

        # Compute physics losses
        left_l, left_sl, right_l, bottom_l, bottom_sl, top_l, hole_l, pde_l = compute_physics_losses(
            collo, top, bottom, left, right, hole, n_hole
        )

        # Aggregate losses
        dirichlet_losses = left_l + bottom_l
        neumann_losses = right_l + top_l + left_sl + bottom_sl
        loss = (
            W_DIR * dirichlet_losses
            + W_NEU * neumann_losses
            + W_HOLE * hole_l
            + W_PDE * pde_l
        )

        # Make optimization step after batch
        loss.backward(retain_graph=True)
        optimizer.step()

    # Rebalance weights every 100 epochs
    if epoch % 100 == 0:
        left_l, left_sl, right_l, bottom_l, bottom_sl, top_l, hole_l, pde_l = compute_physics_losses(
            collocation, top, bottom, left, right, hole, n_hole
        )
        grad_dir = compute_gradient_norm(left_l + bottom_l)
        grad_neu = compute_gradient_norm(right_l + top_l + left_sl + bottom_sl)
        grad_hole = compute_gradient_norm(hole_l)
        grad_pde = compute_gradient_norm(pde_l)
        grad_sum = grad_dir + grad_neu + grad_hole + grad_pde
        W_DIR = update_weight(W_DIR, grad_sum, grad_dir)
        W_NEU = update_weight(W_NEU, grad_sum, grad_neu)
        W_HOLE = update_weight(W_HOLE, grad_sum, grad_hole)
        W_PDE = update_weight(W_PDE, grad_sum, grad_pde)

    # Make scheduler step after full epoch
    scheduler.step()

    # append loss to history (=for plotting)
    with torch.no_grad():
        wandb.log(
            {
                "loss": loss,
                "left_loss": left_l,
                "left_symm_loss": left_sl,
                "right_loss": right_l,
                "bottom_loss": bottom_l,
                "bottom_symm_loss": bottom_sl,
                "top_loss": top_l,
                "hole_loss": hole_l,
                "pde_loss": pde_l,
                "W_DIR": W_DIR,
                "W_NEU": W_NEU,
                "W_HOLE": W_HOLE,
                "W_PDE": W_PDE,
            }
        )

100%|██████████| 25000/25000 [43:28<00:00,  9.58it/s]


## Visualization of results

In [7]:
stress_hole = torch.vmap(sigma)(hole)
data_hole = np.loadtxt(f"data/hole_Rx={Rx}.csv", delimiter=",")
data_hole = data_hole[data_hole[:, 0].argsort()]

with torch.no_grad():
    fig = go.Figure()
    m1 = dict(color="blue")
    m2 = dict(color="orange")
    fig.add_trace(
        go.Scatter(
            x=hole[:, 0],
            y=stress_hole[:, 0, 0],
            marker=m1,
            mode="markers",
            name="σ_xx (PINN)",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=data_hole[:, 0],
            y=data_hole[:, 2],
            marker=m1,
            mode="lines",
            name="σ_xx (FEM)",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=hole[:, 0],
            y=stress_hole[:, 1, 1],
            marker=m2,
            mode="markers",
            name="σ_yy (PINN)",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=data_hole[:, 0],
            y=data_hole[:, 3],
            marker=m2,
            mode="lines",
            name="σ_yy (FEM)",
        )
    )
    fig.update_layout(
        template="none", width=600, height=400, title="Stress at hole", showlegend=True
    )
    wandb.log({f"hole_stress": fig})
    fig.show()

In [8]:
# Create a validation domain different from the training domain
val_x, val_y = np.meshgrid(np.linspace(0, L, 25), np.linspace(0, L, 25))
val_domain = np.vstack([val_x.ravel(), val_y.ravel()]).T
mask = (
    ((val_domain[:, 0] ** 2) / (Rx**2)) + ((val_domain[:, 1] ** 2) / (Ry**2))
) > 1
val = torch.tensor(val_domain[mask], dtype=DTYPE, requires_grad=True)

# Compute model predictions on the validation domain
disp = net(val)
def_val = val + disp
stress = torch.vmap(sigma)(val)
mises = torch.sqrt(
    stress[:, 0, 0] ** 2
    + stress[:, 1, 1] ** 2
    - stress[:, 0, 0] * stress[:, 1, 1]
    + 3 * stress[:, 0, 1] ** 2
)
# print([loss.item() for loss in compute_les(val)])


@torch.no_grad()
def make_plot(x, y, variable, title, cmap=sequential.Viridis, size=8.0):
    fig = go.Figure()

    # Plot boundaries
    m = dict(color="black")
    fig.add_trace(go.Scatter(x=top[:, 0], y=top[:, 1], mode="lines", marker=m))
    fig.add_trace(go.Scatter(x=bottom[:, 0], y=bottom[:, 1], mode="lines", marker=m))
    fig.add_trace(go.Scatter(x=left[:, 0], y=left[:, 1], mode="lines", marker=m))
    fig.add_trace(go.Scatter(x=right[:, 0], y=right[:, 1], mode="lines", marker=m))
    fig.add_trace(go.Scatter(x=hole[:, 0], y=hole[:, 1], mode="lines", marker=m))

    # Plot variable values
    m = dict(color=variable, colorscale=cmap, size=size, colorbar=dict(thickness=10))
    fig.add_trace(go.Scatter(x=x, y=y, marker=m, mode="markers"))

    # plot settings
    fig.layout.yaxis.scaleanchor = "x"
    fig.update_layout(
        template="none", width=400, height=400, title=title, showlegend=False
    )
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    wandb.log({f"chart_{title}": fig})
    fig.show()


# Load reference data
data_input = torch.as_tensor(
    np.loadtxt(f"data/inputs_Rx={Rx}.csv", delimiter=","), dtype=DTYPE
)
data_output = torch.as_tensor(
    np.loadtxt(f"data/outputs_Rx={Rx}.csv", delimiter=","), dtype=DTYPE
)


# Compute data error
s_data = data_output[:, 0:3]
s_pred = torch.vmap(sigma)(data_input[:, 0:2])
ds_xx = s_data[:, 0] - s_pred[:, 0, 0]
ds_yy = s_data[:, 1] - s_pred[:, 1, 1]
ds_xy = s_data[:, 2] - s_pred[:, 0, 1]


# Plot stress errors
cmap = sequential.RdBu_r
make_plot(*data_input[:, 0:2].T, ds_xx, "Stress error xx", size=2.0, cmap=cmap)
make_plot(*data_input[:, 0:2].T, ds_yy, "Stress error yy", size=2.0, cmap=cmap)
make_plot(*data_input[:, 0:2].T, ds_xy, "Stress error xy", size=2.0, cmap=cmap)

# Plot stresses
make_plot(*def_val.T, stress[:, 0, 0], "Stress xx")
make_plot(*def_val.T, stress[:, 0, 1], "Stress xy")
make_plot(*def_val.T, stress[:, 1, 1], "Stress yy")
make_plot(*def_val.T, mises, "Mises stress")

# Plot displacements
make_plot(*def_val.T, disp[:, 0], "Displacement in x", cmap=sequential.Inferno)
make_plot(*def_val.T, disp[:, 1], "Displacement in y", cmap=sequential.Inferno)

# Finish tracking
wandb.finish()

VBox(children=(Label(value='10.880 MB of 10.880 MB uploaded (1.934 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
W_DIR,▂▇█▆▄▃▂▁▁▂▁▁▁▁▁▁▂▁▁▁▂▂▂▁▂▂▂▂▂▃▃▄▄▄▄▄▅▄▃▄
W_HOLE,▁▃▅██▇▆▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▂▂▂▂
W_NEU,▁▁▂▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▆██▇▇▇▇▆▆▇
W_PDE,▃█▇▃▂▂▂▃▂▃▃▄▄▄▄▄▄▃▆▆▅▅▄▅▅▅▅▅▄▄▄▂▁▁▂▃▂▃▃▃
bottom_loss,█▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
bottom_symm_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
hole_loss,█▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
left_loss,█▃▃▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
left_symm_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
W_DIR,1302208.75
W_HOLE,143.45709
W_NEU,240.26729
W_PDE,1.01296
bottom_loss,0.0
bottom_symm_loss,1e-05
hole_loss,3e-05
left_loss,0.0
left_symm_loss,0.0
loss,0.06233
