In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('png')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");

In [None]:
import jax
import equinox as eqx
import jax.random as jrandom
from jax import grad, vmap
import jax.numpy as jnp

class FourierEncoding(eqx.Module):
    B: jax.Array

    @property
    def num_fourier_features(self) -> int:
        return self.B.shape[0]

    @property
    def in_size(self) -> int:
        return self.B.shape[1]
    
    @property
    def out_size(self) -> int:
        return self.B.shape[0] * 2

    def __init__(self, 
                 in_size: int, 
                 num_fourier_features: int, 
                 key: jax.random.PRNGKey, 
                 sigma: float = 1.0):
        self.B = jax.random.normal(
            key, shape=(num_fourier_features, in_size),
            dtype=jax.numpy.float32) * sigma
    
    def __call__(self, x: jax.Array, **kwargs) -> jax.Array:
        return jax.numpy.concatenate(
            [jax.numpy.cos(jax.numpy.dot(self.B, x)),
             jax.numpy.sin(jax.numpy.dot(self.B, x))],
            axis=0)

$$
u_1(x,y) = \delta - \delta(1-x) + x(1-x)N_1(x,y;\theta),
$$
and,
$$
u_2(x,y) = x(1-x)N_2(x,y;\theta)
$$
where $N_1(x,y;\theta)$ and $N_2(x,y;\theta)$ are neural networks.

In [None]:
# The model that satisfies the boundary conditions
u_hat_1 = lambda x, y, DELTA, model1: DELTA - DELTA * (1.0 - x) + x * (1.0 - x) * model1(jnp.array([x, y]))
u_hat_2 = lambda x, y, model2: x * (1.0 - x) * model2(jnp.array([x, y]))

u_hat_1_x = grad(u_hat_1, 0)
u_hat_1_y = grad(u_hat_1, 1)

u_hat_2_x = grad(u_hat_2, 0)
u_hat_2_y = grad(u_hat_2, 1)

For this hyperelastic material, the stored energy $E_b$ in the body can be expressed in as:

$$
E_b[\mathbf{u}(\cdot)] = \int_{[0,1]^2}\left\{\frac{1}{2}(\sum_{i=1}^2\sum_{j=1}^2{F_{ij}^2} - 2)- \ln(\det(\mathbf{F})) + 50\ln(\det(\mathbf{F}))^2\right\} dxdy,
$$

with

$$
\mathbf{F} = \mathbf{I} + \nabla \mathbf{u},
$$

where $\mathbf{I}$ is an identity matrix.

In [None]:
F_matrix = lambda x,y,delta, model1, model2: jnp.array([[1.0 + u_hat_1_x(x, y, delta, model1), u_hat_1_y(x, y, delta, model1)], 
                                  [u_hat_2_x(x, y, model2), 1.0 + u_hat_2_y(x, y, model2)]])

pde_residual = vmap(lambda x,y,delta, model1, model2: (0.5 * (jnp.square(F_matrix(x,y,delta, model1, model2)).sum() -2) 
                                                  - jnp.log(jnp.linalg.det(F_matrix(x,y,delta, model1, model2))) 
                                                  + 50 * jnp.log(jnp.linalg.det(F_matrix(x,y,delta, model1, model2))) ** 2),
                                                  in_axes=(0,0,None, None, None))

pinn_loss = lambda model, x, y, delta:jnp.mean(jnp.square(pde_residual(x,y,delta, model[0], model[1])))

In [None]:
key = jax.random.PRNGKey(0)
key1, key2, key = jax.random.split(key, 3)
num_fourier_features = 100
width_size = 128
depth = 4

model1 = eqx.nn.Sequential([
    eqx.nn.Lambda(
        FourierEncoding(2, num_fourier_features, key1, sigma=6.0)),
    eqx.nn.Lambda(
        eqx.nn.MLP(num_fourier_features * 2, 1, width_size, depth, jnp.tanh, key=key2)),
    eqx.nn.Lambda(
        lambda y: y[0])])

key = jax.random.PRNGKey(1)
key1, key2, key = jax.random.split(key, 3)

model2 = eqx.nn.Sequential([
    eqx.nn.Lambda(
        FourierEncoding(2, num_fourier_features, key1, sigma=6.0)),
    eqx.nn.Lambda(
        eqx.nn.MLP(num_fourier_features * 2, 1, width_size, depth, jnp.tanh, key=key2)),
    eqx.nn.Lambda(
        lambda y: y[0])])

In [None]:
# remember that we need a way to filter out the parameters of the Fourier encoding
import jax.tree_util as jtu
filter_spec1 = jtu.tree_map(lambda _: True, model1)
filter_spec1 = eqx.tree_at(
    lambda tree: (tree[0].fn.B,),
    filter_spec1,
    replace=(False,))

filter_spec2 = jtu.tree_map(lambda _: True, model2)
filter_spec2 = eqx.tree_at(
    lambda tree: (tree[0].fn.B,),
    filter_spec2,
    replace=(False,))

In [None]:
def train_pinn(
        loss,
        model1,model2,
        key,
        optimizer,
        filter_spec1, filter_spec2,
        delta=0.1,
        Lx=1.0,
        Ly=1.0,
        num_collocation_residual=512,
        num_iter=10_000,
        freq=1,
    ):

    fourier_mlp = (model1, model2)

    # this is new
    def new_loss(diff_model, static_model, x, y):
        comb_model = eqx.combine(diff_model, static_model)
        return loss(comb_model, x, y, delta)

    @eqx.filter_jit
    def step(opt_state, model, xs, ys):
        # added this line
        diff_model, static_model = eqx.partition(model, (filter_spec1,filter_spec2))
        # changed the loss to the new loss
        value, grads = eqx.filter_value_and_grad(new_loss)(diff_model, static_model, xs, ys)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, value
    
    opt_state = optimizer.init(eqx.filter(fourier_mlp, eqx.is_inexact_array))
    
    losses = []
    for i in range(num_iter):
        key1, key2, key = jrandom.split(key, 3)
        xb = jrandom.uniform(key1, (num_collocation_residual,), maxval=Lx)
        yb = jrandom.uniform(key2, (num_collocation_residual,), maxval=Ly)
        fourier_mlp, opt_state, value = step(opt_state, fourier_mlp, xb, yb)
        if value == jnp.nan:
            break
        if i % freq == 0:
            losses.append(value)
            print(f"Step {i}, residual loss {value:.3e}")
    return fourier_mlp, losses

In [None]:
import optax
key, subkey = jax.random.split(key)
optimizer = optax.adam(1e-3)
trained_model, losses = train_pinn(
    pinn_loss, model1, model2, key, optimizer, filter_spec1,filter_spec2,
    num_collocation_residual=32, num_iter=2_000, freq=100, Lx=1.0, Ly=1.0)

In [None]:
import numpy as np
fig, ax = plt.subplots()
ax.plot(losses, '-x', label="MLP+Fourier")
# set log scale for y axis
ax.set_yscale('log')
ax.set_xlabel("Iterations x 100")
ax.set_ylabel("Loss")
plt.legend(loc="best", frameon=False)
sns.despine(trim=True);

In [None]:
x = jnp.linspace(0, 1, 100)
y = jnp.linspace(0, 1, 100)
X, Y = jnp.meshgrid(x, y)

v_u_hat_1 = vmap(u_hat_1, in_axes=(0, 0, None, None))
u_pred_1 = v_u_hat_1(X.flatten(), Y.flatten(), 0.1, trained_model[0]).reshape(X.shape)

v_u_hat_2 = vmap(u_hat_2, in_axes=(0, 0, None))
u_pred_2 = v_u_hat_2(X.flatten(), Y.flatten(), trained_model[1]).reshape(Y.shape)

plt.subplot(1, 2, 1)
plt.title('Displacement Field: u1')
plt.contourf(x, y, u_pred_1, cmap='viridis')
plt.colorbar()
plt.xlabel('x')
plt.ylabel('y')

plt.subplot(1, 2, 2)
plt.title('Displacement Field: u2')
plt.contourf(x, y, u_pred_2, cmap='viridis')
plt.colorbar()
plt.xlabel('x')
plt.ylabel('y')

plt.tight_layout()
plt.show()

## Part B

Solve the problem for $\delta=0.5$ using the same architecture as above.
It will likely fail to train.
If yes, then use the solution of $\delta=0.1$ as the initial guess for $\delta=0.2$, and then use the solution of $\delta=0.2$ as the initial guess for $\delta=0.3$, and so on, until you reach $\delta=0.5$.
This is called transfer learning.

At the end, plot the final displacement field for $\delta=0.5$.

In [None]:
# Solve the problem for $\delta=0.5$ using the same architecture as above.
key, subkey = jax.random.split(key)
optimizer = optax.adam(1e-3)
trained_model_B, losses_B = train_pinn(
    pinn_loss, model1, model2, key, optimizer, filter_spec1, filter_spec2,
    num_collocation_residual=32, num_iter=1_000, freq=100, Lx=1.0, Ly=1.0, delta=0.5)

In [None]:
import numpy as np
fig, ax = plt.subplots()
ax.plot(losses_B, '-x', label="MLP+Fourier")
# set log scale for y axis
ax.set_title('$\delta = 0.5$')
ax.set_xlabel("Iterations x 100")
ax.set_ylabel("Loss")
plt.legend(loc="best", frameon=False)
sns.despine(trim=True);

In [None]:
x = jnp.linspace(0, 1, 100)
y = jnp.linspace(0, 1, 100)
X, Y = jnp.meshgrid(x, y)

v_u_hat_1 = vmap(u_hat_1, in_axes=(0, 0, None, None))
u_pred_1 = v_u_hat_1(X.flatten(), Y.flatten(), 0.5, trained_model_B[0]).reshape(X.shape)

v_u_hat_2 = vmap(u_hat_2, in_axes=(0, 0, None))
u_pred_2 = v_u_hat_2(X.flatten(), Y.flatten(), trained_model_B[1]).reshape(Y.shape)

plt.subplot(1, 2, 1)
plt.title('Displacement Field: u1')
plt.contourf(x, y, u_pred_1, cmap='viridis')
plt.colorbar()
plt.xlabel('x')
plt.ylabel('y')

plt.subplot(1, 2, 2)
plt.title('Displacement Field: u2')
plt.contourf(x, y, u_pred_2, cmap='viridis')
plt.colorbar()
plt.xlabel('x')
plt.ylabel('y')

plt.tight_layout()
plt.show()

## Part C

Solve the parametric problem for $\delta \in [0,0.5]$. That is, build a neural network that takes $\delta$ as input and outputs the displacement field. To do this:
+ Modify the loss function to:

$$
\mathcal{L} = \int_0^{0.5} \int_{[0,1]^2} \left\{\frac{1}{2}(\sum_{i}\sum_{j}{F_{ij}^2} - 2)- \ln(\det(\mathbf{F})) + 50\ln(\det(\mathbf{F}))^2\right\} dxdy d\delta.
$$

+ Modify the neural networks to take $\delta$ as input, say $N_1(x,y;\delta;\theta)$ and $N_2(x,y;\delta;\theta)$. Your field will be $\mathbf{u}(x,y;\delta;\theta)$.
Use the following architecture for the neural networks:

$$
N_1(x,y;\delta) = \sum_{i=1}^n b_{1,i}(\delta)t_{1,i}(x,y).
$$

Here, $n$ is your choice (start with $n=10$), $b_{1,i}$ is a neural network that takes $\delta$ as input and outputs a scalar, and $t_{1,i}(x,y)$ is a multi-layer perceptron with 3 hidden layers, each with 128 units, and tanh activations, and Fourier features at the beginning. The same applies to $N_2(x,y;\delta)$. This representation resembles an expansion in terms of basis functions.
The same architecture appears in DeepONet.

Plot the $x$ and $y$ displacement at $x=0.5, y=0.5$ as a function of $\delta$.

In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");

In [2]:
import jax
import equinox as eqx

class FourierEncoding(eqx.Module):
    B: jax.Array

    @property
    def num_fourier_features(self) -> int:
        return self.B.shape[0]

    @property
    def in_size(self) -> int:
        return self.B.shape[1]
    
    @property
    def out_size(self) -> int:
        return self.B.shape[0] * 2

    def __init__(self, 
                 in_size: int, 
                 num_fourier_features: int, 
                 key: jax.random.PRNGKey, 
                 sigma: float = 1.0):
        self.B = jax.random.normal(
            key, shape=(num_fourier_features, in_size),
            dtype=jax.numpy.float32) * sigma
    
    def __call__(self, x: jax.Array, **kwargs) -> jax.Array:
        return jax.numpy.concatenate(
            [jax.numpy.cos(jax.numpy.dot(self.B, x)),
             jax.numpy.sin(jax.numpy.dot(self.B, x))],
            axis=0)

In [3]:
import equinox as eqx
from jax import grad, vmap
import jax
import jax.numpy as jnp
import jax.random as jrandom

class ParametricModel(eqx.Module):
    """This model captures a simple structure made out of branches and trunks."""
    branch: list  # These are the b's
    trunk: list   # These are the phi's

    def __init__(self, branch_width=8, branch_depth=4, m=2, trunk_width=128, trunk_depth=4, trunk_num_fourier_features=100, key1=0, key2=1):
        key1=jax.random.PRNGKey(key1)
        key2=jax.random.PRNGKey(key2)
        # self.branch = branch
        self.branch = [eqx.nn.MLP('scalar', 'scalar', branch_width, branch_depth, jax.nn.tanh, key=k) for k in jrandom.split(key1, m)]
        # self.trunk = trunk
        self.trunk = [eqx.nn.Sequential([
            FourierEncoding(2, trunk_num_fourier_features, key=k),
            eqx.nn.MLP(trunk_num_fourier_features * 2, 'scalar', trunk_width, trunk_depth, jax.nn.tanh, key=k)]) for k in jrandom.split(key2, m)]
        

    def __call__(self, x, y, xi, **kwargs):
        res = 0.0
        for b, t in zip(self.branch, self.trunk):
            res += b(xi) * t(jnp.array([x,y]))
        return res

# The model that satisfies the boundary conditions
u_hat_1 = lambda x, y, DELTA, model1: DELTA - DELTA * (1.0 - x) + x * (1.0 - x) * model1(x, y, DELTA)
u_hat_2 = lambda x, y, DELTA, model2: x * (1.0 - x) * model2(x, y, DELTA)

u_hat_1_x = grad(u_hat_1, 0)
u_hat_1_y = grad(u_hat_1, 1)

u_hat_2_x = grad(u_hat_2, 0)
u_hat_2_y = grad(u_hat_2, 1)

In [4]:
M = 4
model1 = ParametricModel(key1=1234, key2=5678, m=M)
model2 = ParametricModel(key1=3241, key2=3465, m=M)

In [5]:
# remember that we need a way to filter out the parameters of the Fourier encoding
import jax.tree_util as jtu
filter_spec1 = jtu.tree_map(lambda _: True, model1)
for l in range(M):
# print(filter_spec1.trunk[0].layers[0].B)
    filter_spec1 = eqx.tree_at(
        lambda tree: (tree.trunk[l].layers[0].B,),
        filter_spec1,
        replace=(False,))

filter_spec2 = jtu.tree_map(lambda _: True, model2)
for l in range(M):
    filter_spec2 = eqx.tree_at(
        lambda tree: (tree.trunk[l].layers[0].B,),
        filter_spec2,
        replace=(False,))

In [6]:
F_matrix = lambda x,y,delta, model1, model2: jnp.array([[1.0 + u_hat_1_x(x, y, delta, model1), u_hat_1_y(x, y, delta, model1)], 
                                  [u_hat_2_x(x, y, delta, model2), 1.0 + u_hat_2_y(x, y, delta, model2)]])

pde_residual = vmap(lambda x,y,delta, model1, model2: (0.5 * (jnp.square(F_matrix(x,y,delta, model1, model2)).sum() -2) 
                                                  - jnp.log(jnp.linalg.det(F_matrix(x,y,delta, model1, model2))) 
                                                  + 50 * jnp.log(jnp.linalg.det(F_matrix(x,y,delta, model1, model2))) ** 2),
                                                  in_axes=(0,0,0, None, None))

pinn_loss = lambda model, x, y, delta:jnp.mean(jnp.square(pde_residual(x,y,delta, model[0], model[1])))

In [7]:
def train_pinn(
        loss,
        model1,model2,
        key,
        optimizer,
        filter_spec1, filter_spec2,
        Lx=1.0,
        Ly=1.0,
        num_collocation_residual=512,
        num_xis = 16,
        num_iter=10_000,
        freq=1,
    ):

    fourier_mlp = (model1, model2)

    # this is new
    def new_loss(diff_model, static_model, x, y, xis):
        comb_model = eqx.combine(diff_model, static_model)
        return loss(comb_model, x, y, xis)

    @eqx.filter_jit
    def step(opt_state, model, xs, ys, xis):
        # added this line
        diff_model, static_model = eqx.partition(model, (filter_spec1,filter_spec2))
        # changed the loss to the new loss
        value, grads = eqx.filter_value_and_grad(new_loss)(diff_model, static_model, xs, ys, xis)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, value
    
    opt_state = optimizer.init(eqx.filter(fourier_mlp, eqx.is_inexact_array))
    
    losses = []
    for i in range(num_iter):
        key1, key2, key3, key = jrandom.split(key, 4)
        xb = jrandom.uniform(key1, (num_collocation_residual,), maxval=Lx)
        yb = jrandom.uniform(key2, (num_collocation_residual,), maxval=Ly)
        xis = jrandom.uniform(key3, (num_xis,))
        fourier_mlp, opt_state, value = step(opt_state, fourier_mlp, xb, yb, xis)
        if value == jnp.nan:
            break
        if i % freq == 0:
            losses.append(value)
            print(f"Step {i}, residual loss {value:.3e}")
    return fourier_mlp, losses

In [8]:
import optax
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
optimizer = optax.adam(1e-3)
trained_model, losses = train_pinn(
    pinn_loss, model1, model2, key, optimizer, filter_spec1,filter_spec2,
    num_collocation_residual=32, num_iter=2_000, freq=100, Lx=1.0, Ly=1.0, num_xis=32)

Step 0, residual loss 1.217e+02
Step 100, residual loss 2.741e+00
Step 200, residual loss 7.437e-01
Step 300, residual loss 7.720e-01
Step 400, residual loss 1.373e+00
Step 500, residual loss 7.494e-01
Step 600, residual loss 6.145e-01


Exception ignored in: <function _xla_gc_callback at 0x132af9800>
Traceback (most recent call last):
  File "/Users/rohan/anaconda3/lib/python3.11/site-packages/jax/_src/lib/__init__.py", line 98, in _xla_gc_callback
    def _xla_gc_callback(*args):
    
KeyboardInterrupt: 


Step 700, residual loss 8.213e-01


In [None]:
import numpy as np
fig, ax = plt.subplots()
ax.plot(losses, '-o', label="MLP+Fourier")
# set log scale for y axis
ax.set_xlabel("Iterations x 100")
ax.set_ylabel("Loss")
plt.legend(loc="best", frameon=False)
sns.despine(trim=True);

In [None]:
x = 0.5
y = 0.5
delta = np.linspace(0,0.5,100)

# u_hat_1 = lambda x, y, DELTA, model1: DELTA - DELTA * (1.0 - x) + x * (1.0 - x) * model1(x, y, DELTA)
# u_hat_2 = lambda x, y, DELTA, model2: x * (1.0 - x) * model2(x, y, DELTA)

v_u_hat_1 = vmap(u_hat_1, in_axes=(None, None, 0, None))
u_pred_1 = v_u_hat_1(x, y, delta, trained_model[0])

v_u_hat_2 = vmap(u_hat_2, in_axes=(None, None, 0, None))
u_pred_2 = v_u_hat_2(x, y, delta, trained_model[1])

plt.subplot(1, 2, 1)
plt.title('Displacement Field: u1')
plt.plot(delta, u_pred_1)
plt.xlabel('$\delta$')
# plt.ylabel('y')

plt.subplot(1, 2, 2)
plt.title('Displacement Field: u2')
plt.plot(delta, u_pred_2)
plt.xlabel('$\delta$')
# plt.ylabel('y')

plt.tight_layout()
plt.show()