# Neural Constitutive Modeling

In [1]:
# [collapse: code] Colab Setup (Install Dependencies)

# Only run this if we are in Google Colab
if 'google.colab' in str(get_ipython()):
    print("Installing dependencies from pyproject.toml...")
    # This installs the repo itself (and its dependencies)
    !apt-get install gmsh 
    !apt-get install -qq xvfb libgl1-mesa-glx
    !pip install pyvista -qq
    !pip install -q "git+https://github.com/smec-ethz/tatva-docs.git"
    
    import pyvista as pv

    pv.global_theme.jupyter_backend = 'static'
    pv.global_theme.notebook = True
    pv.start_xvfb()
    
    print("Installation complete!")
else:
    import pyvista as pv
    pv.global_theme.jupyter_backend = 'client'



In this example, we will implement a neural constitutive model. A neural constitutive model uses neural networks to represent the relationship between stress and strain in materials. This approach allows for more flexible and accurate modeling of complex material behaviors compared to traditional constitutive models.


In [2]:
import jax

jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_enable_x64", True)

from functools import partial

import equinox as eqx
import jax.experimental.sparse as jsp
import jax.numpy as jnp
import numpy as np
import pyvista as pv
import scipy.sparse as sp
from jax import Array
from jax_autovmap import autovmap
from tatva import Mesh, Operator, element, sparse
from tatva_coloring import distance2_color_and_seeds

## Mesh 

We start by defining the mesh and material properties for our simulation.


In [3]:
# [collapse: all] Gmsh Geometry Creation and Mesh object creation

import gmsh
import meshio

cells_x = 4          
cells_y = 5          
side_length = 1.0   
wall_thickness = 0.15 
height = 1.0         
lc = 0.2         
fillet_radius = 0.6  # New: Radius for rounding the corners

gmsh.initialize()
gmsh.model.add("solid_honeycomb")
occ = gmsh.model.occ

h_hex = side_length * np.sqrt(3) 

dx = 1.5 * side_length + (wall_thickness * np.sqrt(3)/2)
dy = h_hex + wall_thickness

total_w = (cells_x- 1) * dx + 2 * side_length + 2 * wall_thickness
total_h = (cells_y) * dy + wall_thickness

# Create the main block
block = occ.addBox(0.1*total_w, 0.1*total_h, 0, total_w, total_h, height)

punches = []

for i in range(cells_x+1):
    for j in range(cells_y+1):
        cx = i * dx + side_length + wall_thickness
        offset_y = (dy / 2) if (i % 2 != 0) else 0
        cy = j * dy + (h_hex/2) + wall_thickness - offset_y
        
        pts = []
        for angle in range(0, 360, 60):
            rad = np.radians(angle)
            pts.append(occ.addPoint(cx + side_length * np.cos(rad), 
                                    cy + side_length * np.sin(rad), 0, lc))
        
        lines = [occ.addLine(pts[k], pts[(k+1)%6]) for k in range(6)]
        loop = occ.addCurveLoop(lines)
        surf = occ.addPlaneSurface([loop])
        
        # 2. Extrude to create the prism (Punch)
        # extrude returns list of (dim, tag)
        prism = occ.extrude([(2, surf)], 0, 0, height)
        vol_tag = prism[1][1] # Get the tag of the 3D volume
        
        # --- FILLET LOGIC START ---
        # We need to find the 6 vertical edges of this new prism to fillet them.
        occ.synchronize() # Must sync to query the new volume's edges
        
        # Get all curves (edges) associated with the volume
        # getBoundary returns list of (dim, tag)
        edges = gmsh.model.getBoundary([(3, vol_tag)], combined=False, oriented=False, recursive=True)
        
        vertical_edges = []
        for e in edges:
            dim, tag = e
            if dim == 1: # Ensure it is a curve
                xmin, ymin, zmin, xmax, ymax, zmax = gmsh.model.getBoundingBox(1, tag)
                
                # Check if edge is vertical (aligned with Z)
                z_len = abs(zmax - zmin)
                x_len = abs(xmax - xmin)
                y_len = abs(ymax - ymin)
                
                if z_len > 0.9 * height and x_len < 1e-4 and y_len < 1e-4:
                    vertical_edges.append(tag)
        
        if vertical_edges:
            filleted_punches = occ.fillet([vol_tag], vertical_edges, fillet_radius)
            
            punches.append((3, filleted_punches[0][1]))
        else:
            punches.append((3, vol_tag))
   
occ.synchronize()

out, _ = occ.cut([(3, block)], punches, removeTool=True)
occ.synchronize()

if out:
    final_vols = [v[1] for v in out]
    gmsh.model.addPhysicalGroup(3, final_vols, tag=100, name="Lattice_Material")

gmsh.option.setNumber("Mesh.Algorithm3D", 1) # Delaunay
gmsh.model.mesh.setSize(gmsh.model.getEntities(0), lc)
gmsh.model.mesh.generate(3)
gmsh.write("../meshes/honeycomb.msh")
gmsh.finalize()

_mesh = meshio.read("../meshes/honeycomb.msh")

points = _mesh.points
z_points = points[:, 2].copy()  # Store z-coordinates
y_points = points[:, 1].copy()  # Store y-coordinates
points[:, 1] = z_points 
points[:, 2] = y_points

if "tetra" in _mesh.cells_dict:
    tetra_elements = _mesh.cells_dict["tetra"]
    print(f"Successfully loaded {len(tetra_elements)} tetrahedrons.")
else:
    print("No tetrahedrons found. Ensure gmsh.model.mesh.generate(3) was called.")


mesh = Mesh(coords=points, elements=tetra_elements)

Info    : Meshing 1D...ence                                                                                  
Info    : [  0%] Meshing curve 13 (Line)
Info    : [ 10%] Meshing curve 19 (Line)
Info    : [ 10%] Meshing curve 20 (Line)
Info    : [ 10%] Meshing curve 21 (Line)
Info    : [ 10%] Meshing curve 31 (Line)
Info    : [ 10%] Meshing curve 36 (Line)
Info    : [ 10%] Meshing curve 37 (Line)
Info    : [ 10%] Meshing curve 38 (Line)
Info    : [ 10%] Meshing curve 39 (Line)
Info    : [ 10%] Meshing curve 46 (Line)
Info    : [ 10%] Meshing curve 48 (Line)
Info    : [ 10%] Meshing curve 49 (Line)
Info    : [ 10%] Meshing curve 54 (Line)
Info    : [ 10%] Meshing curve 55 (Line)
Info    : [ 10%] Meshing curve 56 (Line)
Info    : [ 10%] Meshing curve 57 (Line)
Info    : [ 10%] Meshing curve 64 (Line)
Info    : [ 10%] Meshing curve 66 (Line)
Info    : [ 10%] Meshing curve 67 (Line)
Info    : [ 10%] Meshing curve 72 (Line)
Info    : [ 10%] Meshing curve 73 (Line)
Info    : [ 10%] Meshing curv



In [None]:
grid = pv.UnstructuredGrid(
    np.hstack((np.full((mesh.elements.shape[0], 1), 4), mesh.elements)).flatten(),
    np.full(mesh.elements.shape[0], pv.CellType.TETRA),
    np.array(mesh.coords)
)

pl = pv.Plotter()
pl.add_mesh(grid, show_edges=True, color="lightgray",  smooth_shading=False)
pl.view_isometric()
pl.show()

![Geometry and Mesh](../assets/plots/honeycomb_mesh.png)

We use `Tetrahedral` elements for the mesh and below we define the `Operator` object.


In [5]:
tet = element.Tetrahedron4()
op = Operator(mesh, tet)


n_dofs_per_node = 3
n_nodes, n_dofs = mesh.coords.shape[0], mesh.coords.shape[0] * n_dofs_per_node 

## Defining the Neural Constitutive Model

The specific architecture employed for the neural strain energy density was a feed-forward Multi-Layer Perceptron (MLP). The network consisted of an input layer accepting the two scalar invariants $(I_1, J)$, followed by two hidden layers with 16 neurons each, and a final output layer producing the scalar energy value. To ensure that the second-order derivatives (Hessian) remained continuous and numerically stable, a \texttt{softplus} activation function was utilized across all hidden layers. This choice is critical as standard piecewise linear activations, such as \texttt{ReLU}, yield zero second derivatives almost everywhere, leading to immediate solver divergence.

$$
\psi_{\text{total}}(I_1, J) = \underbrace{\left[ \text{NN}(I_1, J; \theta) - \text{NN}(3, 1; \theta) \right]}_{\text{Shifted Neural Potential}} + \underbrace{\Psi_{\text{base}}(I_1, J)}_{\text{Stiffness Prior}}
$$

!!! note
    
    Note we use an untrained neural network and for actual purpose this should be replaced by a trained neural network

In [6]:
class NeuralMaterial(eqx.Module):
    layers: list
    mu_init: float
    lmbda_init: float

    def __init__(self, key, mu=500.0, lmbda=1000.0):
        self.mu_init = mu
        self.lmbda_init = lmbda

        keys = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(2, 12, key=keys[0]),
            jax.nn.softplus,
            eqx.nn.Linear(12, 12, key=keys[1]),
            jax.nn.softplus,
            eqx.nn.Linear(12, 1, key=keys[2]),
        ]

    def __call__(self, x: Array) -> Array:

        y = x
        for layer in self.layers:
            y = layer(y)
        psi_nn = y[0]

        ref_invariants = jnp.array([3.0, 1.0])
        y_ref = ref_invariants
        for layer in self.layers:
            y_ref = layer(y_ref)
        psi_0 = y_ref[0]

        I1, J = x[0], x[1]
        psi_base = (
            (self.mu_init / 2) * (I1 - 3)
            - self.mu_init * jnp.log(J)
            + (self.lmbda_init / 2) * (jnp.log(J)) ** 2
        )

        # Total Energy = (NN - Offset) + Base
        return (psi_nn - psi_0) + psi_base

Now, we define the neural network architecture and the total strain energy density function based on the neural network defined above.


In [8]:
key = jax.random.PRNGKey(42)

E = 1e4
nu = 0.3
mu = E / 2 / (1 + nu)
lmbda = E * nu / (1 - 2 * nu) / (1 + nu)

nn_material = NeuralMaterial(key, mu=mu, lmbda=lmbda)


@autovmap(grad_u=2)
def neural_strain_energy(grad_u, model):
    I = jnp.eye(3)
    F = I + grad_u
    C = F.T @ F

    I1 = jnp.trace(C)
    J = jnp.linalg.det(F)

    invariants = jnp.array([I1, J])
    return model(invariants)


@eqx.filter_jit
def total_neural_energy(u_flat: Array, model) -> float:
    u = u_flat.reshape(-1, n_dofs_per_node)
    u_grad = op.grad(u)
    energy_density = neural_strain_energy(u_grad, model)
    return op.integrate(energy_density)

To check if the total energy at 0 deformation is zero, we can evaluate the total strain energy density function at the reference configuration where $I_1 = 3$ and $J = 1$. This ensures that the neural network's contribution is shifted appropriately, and the stiffness prior is also evaluated at this point.


## Applying Boundary Conditions and Loads


In [9]:
z_min, z_max = jnp.min(mesh.coords[:, 2]), jnp.max(mesh.coords[:, 2])


top_nodes = jnp.where(jnp.isclose(mesh.coords[:, 2], z_max))[0]
bottom_nodes = jnp.where(jnp.isclose(mesh.coords[:, 2], z_min))[0]

zero_dofs = jnp.concatenate(
    [
        3 * bottom_nodes,
        3 * bottom_nodes + 1,
        3 * bottom_nodes + 2,
        3 * top_nodes,
        3 * top_nodes + 1,
    ]
)
applied_dofs = 3 * top_nodes + 2  # Apply displacement in z-direction

fixed_dofs = jnp.concatenate([applied_dofs, zero_dofs])

prescribed_values = jnp.zeros(n_dofs).at[applied_dofs].set(0.4)


## Using Coloring to compute Sparse Hessians

We will solve the problem using direct linear solver. To this end, we will need to perform sparse differentiation using the `sparse` module of `tatva` and coloring approach from `tatva_coloring.`




In [10]:
sparsity_pattern = sparse.create_sparsity_pattern(mesh, n_dofs_per_node=n_dofs_per_node)
sparsity_pattern_csr = sp.csr_matrix(
    (
        sparsity_pattern.data,
        (sparsity_pattern.indices[:, 0], sparsity_pattern.indices[:, 1]),
    )
)
indptr = sparsity_pattern_csr.indptr
indices = sparsity_pattern_csr.indices
colors = distance2_color_and_seeds(
    row_ptr=sparsity_pattern_csr.indptr,
    col_idx=sparsity_pattern_csr.indices,
    n_dofs=n_dofs,
)[0]

energy_fn = eqx.Partial(total_neural_energy, model=nn_material)
gradient_fn = jax.jacrev(energy_fn)

K_sparse_fn = sparse.jacfwd_with_batch(
    gradient=gradient_fn,
    row_ptr=jnp.array(sparsity_pattern_csr.indptr),
    col_indices=jnp.array(sparsity_pattern_csr.indices),
    colors=jnp.array(colors),
    color_batch_size=mesh.elements.shape[0],
)

zero_indices, one_indices = sparse.get_bc_indices(sparsity_pattern, fixed_dofs)

!!! note

    One can use the matrix-free solver by computing the Jacobian-vector product by simply using `jax.jvp` on the `gradient_fn`.

## Defining Newton Solver

We will use a newton sparse solver

In [11]:
# [collapse: code] Newton Solver with Sparse Linear Solve

@eqx.filter_jit
def newton_sparse_solver(
    u,
    fext,
    gradient,
    hessian_sparse,
    fixed_dofs,
    zero_indices,
    one_indices,
    indptr,
    indices,
):
    fint = gradient(u)

    norm_res = 1.0

    tol = 1e-8
    max_iter = 10

    def solver(u, n):
        def true_func(u):
            fint = gradient(u)
            residual = fext - fint
            residual = residual.at[fixed_dofs].set(0.0)

            K_sparse = hessian_sparse(u)
            K_data_lifted = K_sparse.data.at[zero_indices].set(0)
            K_data_lifted = K_data_lifted.at[one_indices].set(1)

            du = jsp.linalg.spsolve(
                K_data_lifted, indices=indices, indptr=indptr, b=residual
            )

            u = u.at[:].add(du)
            return u

        def false_func(u):
            return u

        fint = gradient(u)
        residual = fext - fint
        residual = residual.at[fixed_dofs].set(0.0)
        norm_res = jnp.linalg.norm(residual)

        jax.debug.print("residual={}", norm_res)

        return jax.lax.cond(norm_res > tol, true_func, false_func, u), n

    u, xs = jax.lax.scan(solver, init=u, xs=jnp.arange(0, max_iter))

    fint = gradient(u)
    residual = fext - fint
    residual = residual.at[fixed_dofs].set(0.0)
    norm_res = jnp.linalg.norm(residual)

    return u, norm_res

## Solving the System


In [12]:
# [output: hide]

u_prev = jnp.zeros(n_dofs)

fext = jnp.zeros(n_dofs)

n_steps = 5
applied_displacement = prescribed_values / n_steps  # displacement increment

for i in range(n_steps):
    u_prev = u_prev.at[fixed_dofs].add(applied_displacement[fixed_dofs])

    u_new, rnorm = newton_sparse_solver(
        u_prev,
        fext,
        gradient_fn,
        K_sparse_fn,
        fixed_dofs,
        zero_indices,
        one_indices,
        indptr,
        indices,
    )

    u_prev = u_new

    print(f"Iteration {i}: Residual Norm = {rnorm:.4e}")

u_sol = u_prev.reshape(n_nodes, n_dofs_per_node)

residual=1371.570589894706
residual=1402.2690751026425
residual=516.3211491024836
residual=167.34643670888582
residual=26.77462596049994
residual=0.9441458517834659
residual=0.0013813731895309456
residual=3.1511882734558528e-09
residual=3.1511882734558528e-09
residual=3.1511882734558528e-09
Iteration 0: Residual Norm = 3.1512e-09
residual=1371.0780886756152
residual=1375.7943623328345
residual=501.66823707291917
residual=160.38678791963835
residual=24.9232784576777
residual=0.8278604173933287
residual=0.0010720930028279227
residual=1.9272968215794568e-09
residual=1.9272968215794568e-09
residual=1.9272968215794568e-09
Iteration 1: Residual Norm = 1.9273e-09
residual=1370.550620706396
residual=1349.8788478066529
residual=487.2624762079373
residual=153.5881601643318
residual=23.16588771120022
residual=0.7238842678098589
residual=0.0008293137477897302
residual=1.1781636224379997e-09
residual=1.1781636224379997e-09
residual=1.1781636224379997e-09
Iteration 2: Residual Norm = 1.1782e-09
resi

## Visualization

We will now visualize the deformation of metamaterial.

In [None]:
# [collapse: code] Visualization of Deformed Configuration

sargs = dict(
    title=r"Displacement Magnitude",
    height=0.08,
    width=0.2,
    vertical=False,
    position_x=0.1,
    position_y=0.2,
    title_font_size=20,
    label_font_size=16,
    color="black",
    font_family="arial",
)


pl = pv.Plotter()
grid["u"] = np.array(u_sol)
warped = grid.warp_by_vector("u", factor=4.0)
warped = warped.cell_data_to_point_data()
pl.add_mesh(
    warped,
    show_edges=False,
    scalars="u",
    component=0,
    cmap="managua",
    line_width=0.1,
    scalar_bar_args=sargs,
)
pl.view_vector([-0.55, -0.65, 0.5])
pl.show()

![Deformed Honeycomb](../assets/plots/neural_constitutive_deformed.png)