In [1]:
import os

# Persistent XLA compilation cache location. JAX reads it from the
# `jax_compilation_cache_dir` option (env var JAX_COMPILATION_CACHE_DIR);
# JAX_CACHE_DIR is not a JAX option, so exporting only that variable left the
# cache disabled. Override the location via JAX_COMPILATION_CACHE_DIR instead
# of editing the path below.
JAX_CACHE_DIR = os.environ.get(
    "JAX_COMPILATION_CACHE_DIR", "/cluster/scratch/mpundir/jax-cache"
)
# Still exported as before in case other tooling reads it.
# NOTE(review): confirm something actually uses JAX_CACHE_DIR, else drop it.
os.environ["JAX_CACHE_DIR"] = JAX_CACHE_DIR

import jax
jax.config.update("jax_enable_x64", True)  # use double-precision
jax.config.update("jax_compilation_cache_dir", JAX_CACHE_DIR)
# Persist every compiled program, regardless of how long it took to compile.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_platforms", "cpu")

from femsolver.quadrature import get_element
from femsolver.operator import IntegrateOperator
import jax.numpy as jnp

import matplotlib.pyplot as plt
import cmcrameri.cm as cmc

In [16]:
# --- Mesh generation ---
def generate_rectangular_mesh_tri(nx, ny, Lx, Ly):
    """Build a structured triangular mesh of the rectangle [0, Lx] x [0, Ly].

    The rectangle is divided into ``nx`` x ``ny`` equal cells. Each cell is
    split along its diagonal from grid point (i, j) to (i + 1, j + 1) into two
    triangles, both listed counter-clockwise (for positive ``Lx``, ``Ly``).

    Args:
        nx, ny: number of cells along x and y.
        Lx, Ly: side lengths of the rectangle along x and y.

    Returns:
        coords: ``((nx + 1) * (ny + 1), 2)`` array of node coordinates; the
            node at grid position (i, j) has id ``i * (ny + 1) + j``.
        elements: ``(2 * nx * ny, 3)`` integer connectivity array. Cells are
            ordered with i outer and j inner, and each cell contributes its
            lower triangle ``[n0, n1, n3]`` followed by its upper triangle
            ``[n0, n3, n2]``. A mesh with no cells (``nx == 0`` or
            ``ny == 0``) yields an integer array of shape ``(0, 3)``.
    """
    x = jnp.linspace(0, Lx, nx + 1)
    y = jnp.linspace(0, Ly, ny + 1)
    xv, yv = jnp.meshgrid(x, y, indexing="ij")
    coords = jnp.stack([xv.ravel(), yv.ravel()], axis=-1)

    # Lower-left node id of every cell, in (i outer, j inner) order -- the
    # same order a nested `for i ...: for j ...:` loop would visit.
    ci, cj = jnp.meshgrid(jnp.arange(nx), jnp.arange(ny), indexing="ij")
    n0 = (ci * (ny + 1) + cj).ravel()  # (i,     j)
    n1 = n0 + (ny + 1)                 # (i + 1, j)
    n2 = n0 + 1                        # (i,     j + 1)
    n3 = n1 + 1                        # (i + 1, j + 1)

    lower = jnp.stack([n0, n1, n3], axis=-1)
    upper = jnp.stack([n0, n3, n2], axis=-1)
    # Interleave so each cell's two triangles stay adjacent. The explicit
    # shape keeps a (0, 3) integer array when there are no cells, whereas
    # jnp.array([]) would have produced a float array of shape (0,).
    elements = jnp.stack([lower, upper], axis=1).reshape(2 * nx * ny, 3)
    return coords, elements

In [27]:

tri3 = get_element("tri3")
# get_quadrature() returns (points, weights); for tri3 this is a one-point
# rule at (1/3, 1/3) with weight 0.5.
quad_points, quad_weights = tri3.get_quadrature()
# Sanity check: the weights must sum to the reference-triangle area, which
# the displayed rule (a single weight of 0.5) implies is 1/2.
assert jnp.isclose(jnp.sum(quad_weights), 0.5), quad_weights
# Bare last expression: let Jupyter display the result instead of print().
(quad_points, quad_weights)

(Array([[0.33333333, 0.33333333]], dtype=float64), Array([0.5], dtype=float64))


In [18]:
def integrand(u):
    """Pointwise integrand handed to ``IntegrateOperator``: the identity.

    Returning ``u`` unchanged means the operator integrates the field itself,
    so for ``u = 1`` the integral over the mesh is the domain area (which is
    what the final check cell compares against ``Lx * Ly``).

    NOTE(review): what ``u`` is at call time (value at a quadrature point?
    per-element array?) is defined by femsolver's ``IntegrateOperator`` and is
    not visible in this notebook -- confirm against that API.
    """
    return u

In [19]:
# Integration operator for tri3 elements applying `integrand`; total_energy
# calls fem.integrate(u_cell, x_cell) on it.
fem = IntegrateOperator(element=tri3, integrand=integrand)

In [24]:
# Domain size and resolution: nx x ny rectangular cells, two triangles each.
Lx = 10
Ly = 10
nx = 10
ny = 10

coords, elements = generate_rectangular_mesh_tri(nx=nx, ny=ny, Lx=Lx, Ly=Ly)
n_nodes = coords.shape[0]
n_dofs_per_node = 1  # one unknown per node
n_dofs = n_dofs_per_node * n_nodes

# Mesh invariants: (nx + 1) * (ny + 1) nodes, two triangles per cell.
assert n_nodes == (nx + 1) * (ny + 1), n_nodes
assert elements.shape == (2 * nx * ny, 3), elements.shape

# Constant field u = 1: integrating it should recover the domain area.
u = jnp.ones(n_dofs)

In [25]:
# --- Total energy ---
def total_energy(u, coords, elements, fem):
    """Sum ``fem.integrate`` over every element of the mesh.

    The nodal values and nodal coordinates of each element are gathered via
    the connectivity, handed to ``fem.integrate``, and the per-element
    results are summed into a single value.

    Args:
        u: nodal values, indexed by global node id.
        coords: node coordinates, indexed by global node id.
        elements: connectivity; each row holds the global node ids of one
            element.
        fem: object exposing ``integrate(u_cell, x_cell)``, e.g. an
            ``IntegrateOperator``.

    Returns:
        The sum (``jnp.sum``) of all per-element contributions.
    """
    per_element = fem.integrate(u[elements], coords[elements])
    return jnp.sum(per_element)


In [26]:
# Integrating u = 1 must recover the domain area Lx * Ly. Compare with a
# tolerance: exact `==` on a floating-point sum only holds when every
# per-element contribution happens to be exactly representable.
area_recovered = jnp.isclose(total_energy(u, coords, elements, fem), Lx * Ly)
assert area_recovered, "integral of u = 1 does not match the domain area Lx * Ly"
area_recovered

Array(True, dtype=bool)