In [1]:
import jax.numpy as np
from matplotlib import pyplot as plt
from jax import grad, vmap, jit, partial, random
import numpy as onp
from jax.experimental.ode import odeint
from jax.experimental import optimizers
from jax.scipy.optimize import minimize
import jax
key = random.PRNGKey(0)
import pickle
from jax.lax import fori_loop, scan
from jax.config import config
import timeit
# config.update("jax_enable_x64", True)



# Neural ODE constitutive model

## Constitutive modeling

When modeling a hyperelastic material from a mechanical perspective, we would like to define a strain energy function:
$$\Psi(I_1, I_2, ...,I_n)$$
Where $I_i$ represent the invariants of the right Cauchy deformation tensor. One desired property of this function is convexity. This is important because it guarantees that the material always a non-negative stiffness. To satisfy convexity, we have 2 equivalent options:

- the second derivative (stiffness) is always non-negative.
- the first derivative monotonically increases with respect to the input.

A very commonly used form of $\Psi$ is based on an additive decomposition:
$$\Psi(I_1, I_2, ...,I_n) = \sum_i^n \Psi_i(I_i)$$
In this particular case, to satisfy convexity, we just need $\frac{\delta \Psi_i}{\delta I_i}$ to be monotonically increasing.


## Neural networks for constitutive modeling

In this context, we would like to approximate $\Psi$ with a neural network to avoid the need of choosing a particular form of the equation. And to satisfy the convexity requirement, we would like the approximate $\frac{\delta \Psi_i}{\delta I_i}$  with a neural network that monotonically increases. This can be satisfied with a neural ordinary differential equation [1], which has the form:

$$\boldsymbol{y} = \boldsymbol{x} + \int_0^T f(\boldsymbol{h}(t),t,\boldsymbol{\theta})dt$$

where $f(\boldsymbol{h}(t),t,\boldsymbol{\theta})$ is a neural network that represents the right hand side of an ODE. This mapping from $\boldsymbol{x}$ to $\boldsymbol{y}$ must monotonically increase or decrease following this reasoning:
- For the kind of neural network that we use for right hand side of the ODE, we can guarantee that there is a unique solution for every initial condition. This means, in our context, that for every input there is a unique output [1].
- But, we can also integrate the system backwards in time, meaning that we can use the output as an initial condition to go to the input. And from what we just said, there must be a unique input for every output. This means we can invert the system $y = g(x),\, x = g^{-1}(y)$. 
- If a function can be inverted, it must monotonically increase or decrease.
- Another way to interpret this, is to consider trajectories (solutions of the ODE) never cross each other. 

The idea is to approximate $\frac{\delta \Psi_i}{\delta I_i}$ with a neural ODE. But in general, any invertible architecture should work. 
$$\frac{\delta \Psi_i}{\delta I_i} = I_i + \int_0^T f(\boldsymbol{h}(t),t,\boldsymbol{\theta})dt$$
$$\boldsymbol{h}(0) = I_i$$

The hope is that data with steer the function to be monotonically increasing instead of decreasing.

If we want to compute the tangent stiffness matrix, which involves $\frac{\delta^2 \Psi_i}{\delta I_i^2}$, this can done efficiently with the adjoint method. 

In the rare case that would like to evaluate the strain energy function, we would have to do it numerically.






In [2]:
load_existing_model = False
model_name = 'P12AC1_bsxsy'
dataset_name = 'P12AC1_bsxsy'
#P1C1: n_offx = n_offy = 61, n_equi = n_strx = n_stry = 0
#S111S1: n_offx = n_offy = n_equi = 183, n_strx = n_stry = 0
#P12AC1: n_offx = 72, n_offy = 76, n_equi = 81, n_strx = 101, n_stry = 72
#P12BC2: n_offx = 76, n_offy = 76, n_equi = 95, n_strx = 101, n_stry = 84
#Unless you are working with multifidelity data, set all of the following to 0.
n_offx = 0
n_offy = 0
n_equi = 0
n_strx = 0
n_stry = 0
n_hf = n_offx + n_offy + n_equi + n_strx + n_stry

with open('training_data/' + dataset_name + '.npy', 'rb') as f:
    lamb, sigma_gt = np.load(f,allow_pickle=True)
n_data = lamb.shape[0]
weights = onp.ones([n_data,1])
weights[:n_hf] = 10 #Weight of high fidelity loss
#sigma_gt = np.concatenate((sigma_gt.transp,weights), axis=0)
sigma_gt = np.hstack((sigma_gt,weights))

In [3]:
# A generic function to compute the stress given the princial stretches
# this will be used later for the NN
@partial(jit, static_argnums=(4))
def sigma(lamb1, lamb2, lamb3, p, NN, params):
    I1_params, I2_params, Iv_params, Iw_params, J1_params, J2_params, J3_params, J4_params, J5_params, J6_params, \
        I_weights, theta, Psi1_bias, Psi2_bias = params
    
    a = 1/(1+np.exp(-I_weights))
    
    v0 = np.array([ np.cos(theta), np.sin(theta), 0])
    w0 = np.array([-np.sin(theta), np.cos(theta), 0])
    F = np.array([[lamb1, 0, 0],
                  [0, lamb2, 0],
                  [0, 0, lamb3]])
    v = np.dot(F, v0)
    w = np.dot(F, w0)
    b = np.dot(F, F.T)
    I1 = np.trace(b)
    b2 = np.einsum('ij,jk->ik', b, b)
    I2 = 0.5*(np.trace(b)**2 - np.trace(b2)) # again assuming is diagonal
    Iv = np.sum(b*np.outer(v0,v0)) 
    Iw = np.sum(b*np.outer(w0,w0)) 
    
    I1 = I1-3
    I2 = I2-3
    Iv = Iv-1
    Iw = Iw-1
    J1 = a[0]*I1+(1-a[0])*I2
    J2 = a[1]*I1+(1-a[1])*Iv
    J3 = a[2]*I1+(1-a[2])*Iw
    J4 = a[3]*I2+(1-a[3])*Iv
    J5 = a[4]*I2+(1-a[4])*Iw
    J6 = a[5]*Iv+(1-a[5])*Iw
    
    #Iv, Iw, J1, J2, J3, J4, J5, J6 = 0.,0.,0.,0.,0.,0.,0.,0.
    
    Psi1 = NN(I1,  I1_params)
    Psi2 = NN(I2,  I2_params)
    Psiv = NN(Iv,  Iv_params)
    Psiw = NN(Iw,  Iw_params)
    Phi1 = NN(J1,  J1_params)
    Phi2 = NN(J2,  J2_params)
    Phi3 = NN(J3,  J3_params)
    Phi4 = NN(J4,  J4_params)
    Phi5 = NN(J5,  J5_params)
    Phi6 = NN(J6,  J6_params)
    
    Psiv = np.max([Psiv, 0])
    Psiw = np.max([Psiw, 0])
    Phi1 = np.max([Phi1, 0])
    Phi2 = np.max([Phi2, 0])
    Phi3 = np.max([Phi3, 0])
    Phi4 = np.max([Phi4, 0])
    Phi5 = np.max([Phi5, 0])
    Phi6 = np.max([Phi6, 0])
    
    Psi1 = Psi1 +     a[0]*Phi1 +     a[1]*Phi2 +     a[2]*Phi3 + np.exp(Psi1_bias)
    Psi2 = Psi2 + (1-a[0])*Phi1 +     a[3]*Phi4 +     a[4]*Phi5 + np.exp(Psi2_bias)
    Psiv = Psiv + (1-a[1])*Phi2 + (1-a[3])*Phi4 +     a[5]*Phi6
    Psiw = Psiw + (1-a[2])*Phi3 + (1-a[4])*Phi5 + (1-a[5])*Phi6
    
    return -p*np.eye(3) + 2*Psi1*b + 2*Psi2*((I1+3)*b - b**2) + 2*Psiv*np.outer(v,v) + 2*Psiw*np.outer(w,w)

In [4]:
# a wrapper to compute biaxial stresses easily
@partial(jit, static_argnums=(1))
def sigma_biaxial(lamb, NN, params):
    # incompressibility
    lamb3 = 1/(lamb[0]*lamb[1])
    # use \sigma_33 = 0 to compute pressure
    p = sigma(lamb[0], lamb[1], lamb3, 0, NN, params)[2,2]
    # return \sigma_11 and \sigma_22
    return sigma(lamb[0], lamb[1], lamb3, p, NN, params)[[0,1],[0,1]]
# in jax we do everything for one value and then vectorize with vmap
sigma_biaxial_vmap = vmap(sigma_biaxial, in_axes=(0,None, None), out_axes=0)

In [5]:
# Fully connected neural network code
def init_params(layers, key):
    Ws = []
    for i in range(len(layers) - 1):
        std_glorot = np.sqrt(2/(layers[i] + layers[i + 1]))
        key, subkey = random.split(key)
        Ws.append(random.normal(subkey, (layers[i], layers[i + 1]))*std_glorot)
    return Ws

@jit
def forward_pass(H, Ws):
    N_layers = len(Ws)
    for i in range(N_layers - 1):
        H = np.matmul(H, Ws[i])
        H = np.tanh(H)
        Y = np.matmul(H, Ws[-1])
    return Y

@partial(jit, static_argnums=(0,))
def step(loss, i, opt_state, X_batch, Y_batch):
    params = get_params(opt_state)
    g = grad(loss)(params, X_batch, Y_batch)
    return opt_update(i, g, opt_state)

def train(loss, X, Y, opt_state, key, nIter = 10000, batch_size = 10):
    train_loss = []
    val_loss = []
    for it in range(nIter):
        key, subkey = random.split(key)
        idx_batch = random.choice(subkey, X.shape[0], shape = (batch_size,), replace = False)
        opt_state = step(loss, it, opt_state, X[idx_batch], Y[idx_batch])         
        if (it+1)% 10000 == 0:
            params = get_params(opt_state)
            train_loss_value = loss(params, X, Y)
            train_loss.append(train_loss_value)
            to_print = "it %i, train loss = %e" % (it+1, train_loss_value)
            print(to_print)
    return get_params(opt_state), train_loss, val_loss

In [6]:
@jit
def ODEforward_pass(y0, params):
    f = lambda y, t: forward_pass(np.array([y]),params) # fake time argument for ODEint
    return odeint(f, y0, np.array([0.0,1.0]))[-1] # integrate between 0 and 1 and return the results at 1

def Eulerforward_pass(y0, params, steps = 10):
    body_func = lambda y, i: (y + forward_pass(np.array([y]), params)[0], None)
    out, _ = scan(body_func, y0, None, length = steps)
    return out

@jit
def loss(params, lamb, sigma_gt):
    sigma_pr = sigma_biaxial_vmap(lamb, Eulerforward_pass, params)
    J_weights = params[10]
    #sigma_gt[:,2] contains the weights of multifidelity data. i.e. sigma_gt[:,2] = 1 for low fidelity and 10 for high fidelity.
    dummy1 = sigma_pr[:,0]
    dummy2 = sigma_gt[:,0]
    dummy3 = sigma_gt[:,2]
    
    #transform to log space:
    # dummy1 = np.log(dummy1)
    # dummy2 = np.log(dummy2)
    loss_MSE = np.average((dummy1 - dummy2)**2*dummy3)
                          
    dummy1 = sigma_pr[:,1]
    dummy2 = sigma_gt[:,1]
    # dummy1 = np.log(dummy1)
    # dummy2 = np.log(dummy2)
    loss_MSE+= np.average((dummy1 - dummy2)**2*dummy3)
    
#    loss_L1 = np.array(1.0e-5*np.sum(np.abs(J_weights)), dtype = np.float64)
#    loss_L1 = 1.0e-5*np.sum(np.abs(J_weights))
    return  loss_MSE# + loss_L1

layers = [1, 5, 5, 1]

W1_params = init_params(layers, key)
W2_params = init_params(layers, key)
W4v_params = init_params(layers, key)
W4w_params = init_params(layers, key)
J1_params = init_params(layers, key)
J2_params = init_params(layers, key)
J3_params = init_params(layers, key)
J4_params = init_params(layers, key)
J5_params = init_params(layers, key)
J6_params = init_params(layers, key)
I_weights = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
theta = 1.0
Psi1_bias = -3.0
Psi2_bias = -3.0

params = (W1_params, W2_params, W4v_params, W4w_params, J1_params, J2_params, J3_params, J4_params, J5_params, J6_params, I_weights, \
          theta, Psi1_bias, Psi2_bias)


opt_init, opt_update, get_params = optimizers.adam(1.e-4)
opt_state = opt_init(params)

In [7]:
#%%timeit -r 0
if load_existing_model == False:
    if lamb.shape[0] > 100:
        batch_size = 100
    else:
        batch_size = lamb.shape[0]
    
    params, train_loss, val_loss = train(loss,lamb, sigma_gt, opt_state, key, nIter = 200000, batch_size = batch_size)
    with open('savednet/' + model_name + '.npy', 'wb') as f:
        pickle.dump(params, f)
else:
    with open('savednet/' + model_name + '.npy', 'rb') as f:
        params = pickle.load(f)



it 10000, train loss = 1.068761e-02
it 20000, train loss = 2.305864e-03
it 30000, train loss = 6.219599e-04
it 40000, train loss = 5.151139e-04
it 50000, train loss = 4.814814e-04
it 60000, train loss = 4.625857e-04
it 70000, train loss = 4.621571e-04
it 80000, train loss = 4.395556e-04
it 90000, train loss = 4.256108e-04
it 100000, train loss = 4.051035e-04
it 110000, train loss = 3.769968e-04
it 120000, train loss = 3.651644e-04
it 130000, train loss = 3.059523e-04
it 140000, train loss = 3.023825e-04
it 150000, train loss = 2.986836e-04
it 160000, train loss = 2.947662e-04
it 170000, train loss = 2.941077e-04
it 180000, train loss = 2.953334e-04
it 190000, train loss = 2.981724e-04
it 200000, train loss = 2.899026e-04


## Hyperparameter study

| Network Architecture | Training Loss |
|----------------------|---------------|
| 1x5x1                | 5.52E-01      |
| 1x5x5x1              | 2.23E-02      |
| 1x5x5x5x1            | 5.66E-02      |

| Network Architecture | Training Loss |
|----------------------|---------------|
| 1x3x3x1              | 2.45E-02      |
| 1x5x5x1              | 2.23E-02      |
| 1x8x8x1              | 7.27E-02      |