In [34]:
import os

os.environ["JAX_PLATFORM"] = "cpu"
os.environ["JAX_CACHE_DIR"] = "/cluster/scratch/mpundir/jax-cache"
os.environ["PLOT_LIB_PATH"] = "/cluster/home/mpundir/dev"

import jax

jax.config.update("jax_enable_x64", True)  # use double-precision
jax.config.update("jax_compilation_cache_dir", os.environ["JAX_CACHE_DIR"])
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_platforms", os.environ["JAX_PLATFORM"])


import jax.numpy as jnp
from jax import grad, jacfwd, vmap, lax

In [4]:
# -----------------------------
# MATERIAL MODEL (History-dependent Neo-Hookean with dummy history update)
# -----------------------------
def neo_hookean_energy(ue, Xe, history, mu=1.0, lam=1.0):
    F = compute_deformation_gradient(ue, Xe)
    J = jnp.linalg.det(F)
    I1 = jnp.trace(F.T @ F)
    psi = 0.5 * mu * (I1 - 3) - mu * jnp.log(J) + 0.5 * lam * (jnp.log(J)) ** 2
    return psi


def update_history(history, ue, Xe):
    # Dummy history update (no real state)
    return history


# -----------------------------
# FEM UTILITIES
# -----------------------------
def compute_deformation_gradient(ue, Xe):
    # Simplified for 2D quad elements
    dNdxi = jnp.array(
        [[-0.25, 0.25, 0.25, -0.25], [-0.25, -0.25, 0.25, 0.25]]
    )  # Shape fn derivatives at center
    J = Xe.T @ dNdxi.T
    dNdx = jnp.linalg.solve(J.T, dNdxi)
    F = ue.reshape(-1, 2).T @ dNdx.T
    return F


# -----------------------------
# SCATTERING UTILITIES
# -----------------------------
def scatter_element_vector(f_elem, element_dofs, total_dofs):
    return lax.scatter_add(
        jnp.zeros(total_dofs),
        jnp.expand_dims(element_dofs, -1),
        f_elem,
        dimension_numbers=lax.ScatterDimensionNumbers(
            update_window_dims=(0,),
            inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,),
        ),
    )


def scatter_element_stiffness(Ke, element_dofs, total_dofs):
    idx = jnp.repeat(element_dofs[:, None], element_dofs.size, axis=1)
    jdx = jnp.repeat(element_dofs[None, :], element_dofs.size, axis=0)
    return lax.scatter_add(
        jnp.zeros((total_dofs, total_dofs)),
        jnp.stack([idx, jdx], -1),
        Ke,
        dimension_numbers=lax.ScatterDimensionNumbers(
            update_window_dims=(0, 1),
            inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0, 1),
        ),
    )


# -----------------------------
# FEM SOLVER CLASS
# -----------------------------
class FEMSolver:
    def __init__(self, X, elements, element_energy_fn):
        self.X = X
        self.elements = elements
        self.element_energy_fn = element_energy_fn
        self.num_nodes = X.shape[0]
        self.dim = X.shape[1]
        self.total_dofs = self.num_nodes * self.dim

    def get_element_dofs(self, idx):
        if self.dim == 2:
            return jnp.ravel(jnp.stack([2 * idx, 2 * idx + 1], axis=1))
        else:
            return jnp.ravel(jnp.stack([3 * idx, 3 * idx + 1, 3 * idx + 2], axis=1))

    def compute_residual(self, u_flat, history):
        def element_residual(e_idx):
            idx = self.elements[e_idx]
            Xe = self.X[idx]
            dofs = self.get_element_dofs(idx)
            ue = u_flat[dofs]
            fe = grad(self.element_energy_fn)(ue, Xe, history[e_idx])
            return dofs, fe

        dofs_all, f_elems = vmap(element_residual)(jnp.arange(self.elements.shape[0]))
        return jnp.sum(
            vmap(scatter_element_vector, in_axes=(0, 0, None))(
                f_elems, dofs_all, self.total_dofs
            ),
            axis=0,
        )

    def compute_stiffness(self, u_flat, history):
        def element_stiffness(e_idx):
            idx = self.elements[e_idx]
            Xe = self.X[idx]
            dofs = self.get_element_dofs(idx)
            ue = u_flat[dofs]
            Ke = jacfwd(grad(self.element_energy_fn))(ue, Xe, history[e_idx])
            return dofs, Ke

        dofs_all, K_elems = vmap(element_stiffness)(jnp.arange(self.elements.shape[0]))
        return jnp.sum(
            vmap(scatter_element_stiffness, in_axes=(0, 0, None))(
                K_elems, dofs_all, self.total_dofs
            ),
            axis=0,
        )

    def apply_dirichlet(self, K, f, fixed_dofs, fixed_vals):
        K = K.at[fixed_dofs, :].set(0.0)
        K = K.at[:, fixed_dofs].set(0.0)
        K = K.at[fixed_dofs, fixed_dofs].set(1.0)
        f = f.at[fixed_dofs].set(fixed_vals)
        return K, f

    def apply_neumann(self, f, neumann_dofs, neumann_vals):
        return f.at[neumann_dofs].add(neumann_vals)

    def solve_newton(
        self,
        u0,
        history,
        fixed_dofs,
        fixed_vals,
        neumann_dofs,
        neumann_vals,
        n_steps=10,
    ):
        u = u0
        for _ in range(n_steps):
            f_int = self.compute_residual(u, history)
            K = self.compute_stiffness(u, history)
            K_bc, f_bc = self.apply_dirichlet(K, -f_int, fixed_dofs, fixed_vals)
            f_bc = self.apply_neumann(f_bc, neumann_dofs, neumann_vals)
            delta_u = jnp.linalg.solve(K_bc, f_bc)
            u = u + delta_u
            history = vmap(update_history)(
                history, self.extract_ues(u), self.get_element_Xs()
            )
        return u, history

    def extract_ues(self, u_flat):
        def get_ue(e_idx):
            idx = self.elements[e_idx]
            dofs = self.get_element_dofs(idx)
            return u_flat[dofs]

        return vmap(get_ue)(jnp.arange(self.elements.shape[0]))

    def get_element_Xs(self):
        return vmap(lambda e: self.X[self.elements[e]])(
            jnp.arange(self.elements.shape[0])
        )

In [5]:
# -----------------------------
# EXAMPLE USAGE
# -----------------------------
# Simple 2D quad mesh: 2 elements, 4 nodes
X = jnp.array([[0, 0], [1, 0], [1, 1], [0, 1]])
elements = jnp.array([[0, 1, 2, 3]])
total_dofs = X.size


def energy_fn(ue, Xe, history):
    return neo_hookean_energy(ue, Xe, history)


history = jnp.zeros((elements.shape[0],))  # Dummy history per element
fem = FEMSolver(X, elements, energy_fn)
u0 = jnp.zeros(total_dofs)
fixed_dofs = jnp.array([0, 1])
fixed_vals = jnp.array([0.0, 0.0])
neumann_dofs = jnp.array([4, 5])
neumann_vals = jnp.array([0.1, 0.0])
u_final, history_final = fem.solve_newton(
    u0, history, fixed_dofs, fixed_vals, neumann_dofs, neumann_vals
)
print("Final displacement:", u_final)
print("Final history:", history_final)

TypeError: Updates tensor must be of rank 3; got 2.