In [1]:
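import os
# Persistent compilation cache; the environment variable must be set before importing jax
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/cluster/scratch/mpundir/jax-cache"

import jax
jax.config.update("jax_enable_x64", True)  # use double precision
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)  # cache every compiled function
jax.config.update("jax_platforms", "cpu")  # run on CPU only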

from femsolver.quadrature import get_element
from femsolver.operator import Operator
import jax.numpy as jnp

import matplotlib.pyplot as plt
import cmcrameri.cm as cmc

In this notebook, we will demonstrate how to use the `Operator` class to integrate a function over a domain.

To keep things simple, we consider a rectangular 2D domain discretized with a triangular mesh. The length and width of the domain are denoted by `Lx` and `Ly`, and the number of cells in the x- and y-directions by `nx` and `ny`, respectively. Each rectangular cell is split into two triangles, so the mesh has `2 * nx * ny` triangular elements and `(nx + 1) * (ny + 1)` nodes.

In [2]:
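# --- Mesh generation ---
def generate_rectangular_mesh_tri(nx, ny, Lx, Ly):
    """Structured triangular mesh of [0, Lx] x [0, Ly] with nx x ny cells, two triangles per cell."""
    # Node coordinates on a regular (nx + 1) x (ny + 1) grid
    x = jnp.linspace(0, Lx, nx + 1)
    y = jnp.linspace(0, Ly, ny + 1)
    xv, yv = jnp.meshgrid(x, y, indexing="ij")
    coords = jnp.stack([xv.ravel(), yv.ravel()], axis=-1)  # (n_nodes, 2)

    def node_id(i, j):
        # Global index of grid node (i, j); matches the row-major ravel above
        return i * (ny + 1) + j

    elements = []
    for i in range(nx):
        for j in range(ny):
            n0 = node_id(i, j)          # lower-left
            n1 = node_id(i + 1, j)      # lower-right
            n2 = node_id(i, j + 1)      # upper-left
            n3 = node_id(i + 1, j + 1)  # upper-right
            # Split the cell along the n0-n3 diagonal; both triangles are counter-clockwise
            elements.append([n0, n1, n3])
            elements.append([n0, n3, n2])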
    return coords, jnp.array(elements)
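The double Python loop is easy to read but becomes slow for fine meshes. As a sketch only, the same connectivity can be built without loops by computing the lower-left node of every cell at once; `generate_rectangular_mesh_tri_vectorized` and `signed_areas` are hypothetical helpers, not part of `femsolver`. The area check confirms that every triangle is counter-clockwise and that the areas add up to `Lx * Ly`.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def generate_rectangular_mesh_tri_vectorized(nx, ny, Lx, Ly):
    """Hypothetical loop-free variant of generate_rectangular_mesh_tri.

    Uses the same node numbering (node (i, j) -> i * (ny + 1) + j) and the
    same two-triangles-per-cell split as the loop version.
    """
    x = jnp.linspace(0, Lx, nx + 1)
    y = jnp.linspace(0, Ly, ny + 1)
    xv, yv = jnp.meshgrid(x, y, indexing="ij")
    coords = jnp.stack([xv.ravel(), yv.ravel()], axis=-1)

    # Lower-left node of every cell, in the same (i outer, j inner) order as the loops
    i, j = jnp.meshgrid(jnp.arange(nx), jnp.arange(ny), indexing="ij")
    n0 = (i * (ny + 1) + j).ravel()
    n1 = n0 + (ny + 1)      # node (i + 1, j)
    n2 = n0 + 1             # node (i, j + 1)
    n3 = n0 + (ny + 1) + 1  # node (i + 1, j + 1)

    tri_a = jnp.stack([n0, n1, n3], axis=-1)
    tri_b = jnp.stack([n0, n3, n2], axis=-1)
    # Interleave so rows alternate tri_a, tri_b per cell, matching the loop order
    elements = jnp.stack([tri_a, tri_b], axis=1).reshape(-1, 3)
    return coords, elements


def signed_areas(coords, elements):
    """Signed area of each triangle (positive for counter-clockwise node order)."""
    p = coords[elements]  # (n_elements, 3, 2)
    e1 = p[:, 1] - p[:, 0]
    e2 = p[:, 2] - p[:, 0]
    return 0.5 * (e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])


coords, elements = generate_rectangular_mesh_tri_vectorized(nx=10, ny=10, Lx=10.0, Ly=10.0)
areas = signed_areas(coords, elements)
print(coords.shape, elements.shape)  # (121, 2) (200, 3)
print(bool(jnp.all(areas > 0)), float(areas.sum()))  # all positive, sum equal to Lx * Ly
```

The first two rows of `elements` are `[0, 11, 12]` and `[0, 12, 1]`, the same as the loop version produces for the first cell.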

We use the `get_element` function to obtain the three-node triangular element (`tri3`); its `get_quadrature` method returns the quadrature points and weights on the reference triangle. The `femsolver.quadrature` module provides `get_element` together with a few predefined elements.

In [24]:
tri3 = get_element("tri3")
print(tri3.get_quadrature())

(Array([[0.33333333, 0.33333333]], dtype=float64), Array([0.5], dtype=float64))
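The printed rule is a single point at the centroid (1/3, 1/3) with weight 0.5, which is the area of the reference triangle with vertices (0, 0), (1, 0), (0, 1). A short check, independent of `femsolver`, shows that this rule is exact for constant and linear integrands but not for quadratic ones (the exact integral of ξ² over the reference triangle is 1/12, while the rule gives 1/18). The helper `quad` is hypothetical.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# One-point rule printed above: centroid (1/3, 1/3), weight 0.5
points = jnp.array([[1.0 / 3.0, 1.0 / 3.0]])
weights = jnp.array([0.5])


def quad(f):
    """Apply the quadrature rule to f(xi, eta) on the reference triangle."""
    return jnp.sum(weights * f(points[:, 0], points[:, 1]))


# Exact integrals over the reference triangle: 1 -> 1/2, xi -> 1/6, xi^2 -> 1/12
print(quad(lambda xi, eta: jnp.ones_like(xi)))  # equals the exact value 1/2
print(quad(lambda xi, eta: xi))                 # equals the exact value 1/6
print(quad(lambda xi, eta: xi**2))              # gives 1/18 instead of 1/12
```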


Now, we create the `Operator` object. The `Operator` class takes two arguments:
- `element`: the element to use for the integration
- `integrand`: the function to integrate

Here, we pass a simple Python function that returns the argument it accepts, so the integrand is the field itself. The output of this function is the value of the integrand at a quadrature point, and the `Operator` class evaluates it at every quadrature point to compute the integral.

In [29]:
fem = Operator(element=tri3, integrand=lambda u: u)
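The notebook does not show how `Operator.integrate` works internally. As a hedged illustration only, assuming linear (P1) shape functions and an isoparametric map from the reference triangle, a per-element integrator could interpolate the nodal values to each quadrature point, evaluate the integrand there, and weight the result by the quadrature weight and the Jacobian determinant. `integrate_element` and `integrate_p1` are hypothetical names, not the `femsolver` API.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# Reference-triangle quadrature (same one-point rule as tri3 above)
QP = jnp.array([[1.0 / 3.0, 1.0 / 3.0]])
QW = jnp.array([0.5])

# Derivatives dN_a/d(xi, eta) of the P1 shape functions; constant, shape (3, 2)
DN_DXI = jnp.array([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]])


def shape_functions(xi, eta):
    """Linear (P1) shape functions N_a on the reference triangle."""
    return jnp.array([1.0 - xi - eta, xi, eta])


def integrate_element(integrand, u_nodes, x_nodes):
    """Integrate integrand(u) over one triangle.

    u_nodes: (3,) nodal values; x_nodes: (3, 2) nodal coordinates.
    """
    J = x_nodes.T @ DN_DXI  # (2, 2) Jacobian dx/dxi
    det_J = jnp.linalg.det(J)

    def at_qp(qp, w):
        N = shape_functions(qp[0], qp[1])
        u_q = N @ u_nodes  # interpolate u to the quadrature point
        return w * integrand(u_q) * jnp.abs(det_J)

    return jnp.sum(jax.vmap(at_qp)(QP, QW))


def integrate_p1(integrand, u_cell, x_cell):
    """Per-element integrals for a batch: u_cell (n_e, 3), x_cell (n_e, 3, 2)."""
    return jax.vmap(lambda u, x: integrate_element(integrand, u, x))(u_cell, x_cell)


# Two triangles covering the unit square; integrating u = 1 should give its area
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
elements = jnp.array([[0, 1, 3], [0, 3, 2]])
u = jnp.ones(4)
per_element = integrate_p1(lambda v: v, u[elements], coords[elements])
print(per_element, per_element.sum())  # two contributions of 0.5, summing to 1
```

Using `jax.vmap` twice (over quadrature points and over elements) keeps the element kernel written for a single triangle while still running as batched array operations.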

In [None]:
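# Domain size and number of cells in each direction
Lx = 10
Ly = 10
nx = 10
ny = 10

coords, elements = generate_rectangular_mesh_tri(nx=nx, ny=ny, Lx=Lx, Ly=Ly)
n_nodes = coords.shape[0]  # (nx + 1) * (ny + 1) = 121
n_dofs_per_node = 1        # scalar field: one unknown per node
n_dofs = n_dofs_per_node * n_nodes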

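Next, we define a function that gathers the nodal dof values and nodal coordinates into cell-wise arrays using the connectivity `elements`. Within this function, we call the `integrate` method of the `Operator` object and sum the returned contributions to obtain the integral over the whole domain.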

In [27]:
# --- Total energy ---
def total_energy(u, coords, elements, fem):
    u_cell = u[elements]
    x_cell = coords[elements]
    return jnp.sum(fem.integrate(u_cell, x_cell))
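The gather step `u[elements]` relies on NumPy-style integer-array indexing: indexing a `(n_nodes,)` array with a `(n_elements, 3)` index array yields a `(n_elements, 3)` array, and indexing `(n_nodes, 2)` coordinates yields `(n_elements, 3, 2)`. The tiny mesh below is the one `generate_rectangular_mesh_tri(1, 1, 1, 1)` would produce, written out by hand; the nodal values are arbitrary.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# 4 nodes, 2 triangles (node (i, j) -> i * 2 + j, as in generate_rectangular_mesh_tri)
coords = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
elements = jnp.array([[0, 2, 3], [0, 3, 1]])
u = jnp.array([10.0, 11.0, 12.0, 13.0])  # one dof per node

u_cell = u[elements]       # (n_elements, 3): nodal values of each triangle
x_cell = coords[elements]  # (n_elements, 3, 2): nodal coordinates of each triangle

print(u_cell.shape, x_cell.shape)  # (2, 3) (2, 3, 2)
print(u_cell)                      # rows [10, 12, 13] and [10, 13, 11]
```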


Now we set every entry of the solution vector `u` to 1.0. Since the integrand returns `u` unchanged, integrating it over the domain should give the area of the domain, `Lx * Ly` = 100.

In [30]:
u = jnp.ones(n_dofs)
jnp.isclose(total_energy(u, coords, elements, fem), Lx * Ly)  # tolerance-based float comparison

Array(True, dtype=bool)
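A constant field only tests the areas. Because the one-point centroid rule is exact for linear integrands, a stronger check is a linear field such as u = x, whose exact integral over the domain is Lx²·Ly/2 = 500. The sketch below reproduces this without `femsolver`, under the assumption of P1 interpolation (at the centroid every P1 shape function equals 1/3, so u there is the mean of the nodal values); `p1_centroid_integral` is a hypothetical helper. Comparing floating-point results with exact `==` is fragile in general, so it uses `jnp.isclose`.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

Lx, Ly, nx, ny = 10.0, 10.0, 10, 10

# Same node numbering as generate_rectangular_mesh_tri (node (i, j) -> i * (ny + 1) + j)
x = jnp.linspace(0, Lx, nx + 1)
y = jnp.linspace(0, Ly, ny + 1)
xv, yv = jnp.meshgrid(x, y, indexing="ij")
coords = jnp.stack([xv.ravel(), yv.ravel()], axis=-1)
i, j = jnp.meshgrid(jnp.arange(nx), jnp.arange(ny), indexing="ij")
n0 = (i * (ny + 1) + j).ravel()
n1, n2, n3 = n0 + ny + 1, n0 + 1, n0 + ny + 2
# Element order differs from the loop version, which does not matter for a sum
elements = jnp.concatenate([jnp.stack([n0, n1, n3], -1), jnp.stack([n0, n3, n2], -1)])


def p1_centroid_integral(u, coords, elements):
    """Integral of u over the mesh with P1 interpolation and the centroid rule."""
    p = coords[elements]
    e1, e2 = p[:, 1] - p[:, 0], p[:, 2] - p[:, 0]
    area = 0.5 * jnp.abs(e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])  # weight 0.5 * |det J|
    return jnp.sum(area * u[elements].mean(axis=1))


exact_x = 0.5 * Lx**2 * Ly  # integral of x over [0, Lx] x [0, Ly]
integral_x = p1_centroid_integral(coords[:, 0], coords, elements)
print(jnp.isclose(integral_x, exact_x))  # True
print(jnp.isclose(p1_centroid_integral(jnp.ones(coords.shape[0]), coords, elements), Lx * Ly))  # True
```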