In [1]:
import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/cluster/scratch/mpundir/jax-cache"  # persistent compilation cache location

import jax
jax.config.update("jax_enable_x64", True)  # use double-precision
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_platforms", "cpu")

from femsolver.quadrature import get_element
from femsolver.operator import Operator
import jax.numpy as jnp

import matplotlib.pyplot as plt
import cmcrameri.cm as cmc

In this notebook, we demonstrate how to simulate a bar under axial loading.

To keep things simple, we will consider a bar of length `L` and cross-sectional area `A`. The bar is subjected to an axial load `F` at its right end. The bar is fixed at its left end.

The bar is discretized into `nb_elem` elements (`nx` in the code below). Each element has length `L / nb_elem` and cross-sectional area `A`.

In [2]:
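# --- Mesh generation ---
def generate_bar_mesh(nx, lx):
    """Uniform mesh of `nx` two-node line elements on [lx[0], lx[1]].

    Nodes carry (x, y) coordinates; y is zero for every node.
    """
    xi = jnp.linspace(lx[0], lx[1], nx + 1)  # nx + 1 equally spaced node positions
    yi = jnp.zeros_like(xi)
    coordinates = jnp.vstack((xi.flatten(), yi.flatten())).T  # shape (nx + 1, 2)
    elements = list()
    for i in range(nx):
        elements.append([i, i + 1])  # element i connects nodes i and i + 1
    elements = jnp.unique(jnp.array(elements), axis=0)  # shape (nx, 2)

    return coordinates, elements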
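The loop above builds the connectivity one element at a time. Since element `i` always connects nodes `i` and `i + 1`, the same arrays can also be built with array operations alone. The following is a minimal sketch of that alternative; `generate_bar_mesh_vectorized` is a hypothetical name and not part of `femsolver`:

```python
import jax.numpy as jnp

def generate_bar_mesh_vectorized(nx, lx):
    """Hypothetical loop-free variant of generate_bar_mesh."""
    x = jnp.linspace(lx[0], lx[1], nx + 1)                   # node x-positions
    coordinates = jnp.stack((x, jnp.zeros_like(x)), axis=1)  # (nx + 1, 2), y = 0
    first = jnp.arange(nx)                                   # left node of each element
    elements = jnp.stack((first, first + 1), axis=1)         # (nx, 2) connectivity
    return coordinates, elements

coords_v, elements_v = generate_bar_mesh_vectorized(nx=4, lx=(0.0, 4.0))
print(elements_v.tolist())  # [[0, 1], [1, 2], [2, 3], [3, 4]]
```

Avoiding the Python loop keeps mesh generation fast for large `nx`, and the output is already ordered, so no `jnp.unique` call is needed.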

We use the `get_element` function to retrieve the two-node line element (`line2`), from which we obtain its quadrature points and weights and evaluate its shape functions. The `femsolver.quadrature` module provides `get_element` along with a few predefined elements.

In [3]:
line2 = get_element("line2")
print(line2.get_quadrature())
print(line2.get_shape_functions(jnp.array([0.0])))

(Array([[0.]], dtype=float64), Array([2.], dtype=float64))
(Array([0.5, 0.5], dtype=float64), Array([-0.5,  0.5], dtype=float64))
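The printed values are consistent with a one-point Gauss rule on the reference interval ξ ∈ [-1, 1] (point ξ = 0, weight 2) and linear Lagrange shape functions N₁ = (1 - ξ)/2 and N₂ = (1 + ξ)/2, whose derivatives with respect to ξ are -1/2 and 1/2. The sketch below reproduces them by hand; it assumes this reference convention and does not call `femsolver`:

```python
import jax
import jax.numpy as jnp

def shape_functions(xi):
    """Linear Lagrange shape functions on the reference interval [-1, 1]."""
    return jnp.array([(1.0 - xi) / 2.0, (1.0 + xi) / 2.0])

# dN/dxi via forward-mode automatic differentiation (one entry per node)
shape_derivatives = jax.jacfwd(shape_functions)

xi_q = jnp.array([0.0])  # one-point Gauss rule: quadrature point
w_q = jnp.array([2.0])   # its weight, equal to the length of [-1, 1]

print(shape_functions(0.0))    # both entries equal 0.5 at the element midpoint
print(shape_derivatives(0.0))  # -0.5 and 0.5
```

The one-point rule integrates each linear shape function exactly: 2 · N_i(0) = 1 = ∫ N_i dξ over [-1, 1].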


Now, we create the `Operator` object. The `Operator` class takes two arguments:
- `element`: the element to use for the integration
- `integrand`: the function to integrate

Here, the integrand is a simple Python function that returns its argument unchanged (`lambda u: u`). Its output is the value of the integrand at a quadrature point, and `Operator` evaluates it at every quadrature point to compute the integral.
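Conceptually, integrating over the mesh means interpolating the nodal values to each quadrature point, evaluating the integrand there, and summing the results weighted by the quadrature weight and the Jacobian determinant of the map from the reference interval to the physical element. The sketch below illustrates that idea for two-node line elements with the one-point rule printed above. It is a hand-written assumption of how such an operator can work, not `femsolver`'s implementation; `integrate_1d` is hypothetical, and it takes nodal x-coordinates only, whereas the notebook stores (x, y) pairs:

```python
import jax
import jax.numpy as jnp

XI_Q = jnp.array([0.0])  # quadrature point on [-1, 1] (assumed one-point rule)
W_Q = jnp.array([2.0])   # quadrature weight

def shape_functions(xi):
    return jnp.array([(1.0 - xi) / 2.0, (1.0 + xi) / 2.0])

def integrate_1d(integrand, u_cell, x_cell):
    """Hypothetical sketch: integral of integrand(u_h) over each line element.

    u_cell: (n_elem, 2) nodal values; x_cell: (n_elem, 2) nodal x-coordinates.
    Returns one integral per element, shape (n_elem,).
    """
    N = jax.vmap(shape_functions)(XI_Q)          # (n_q, 2) shape functions at quad points
    u_q = u_cell @ N.T                           # (n_elem, n_q) interpolated values
    det_J = (x_cell[:, 1] - x_cell[:, 0]) / 2.0  # dx/dxi = h / 2 for a linear map
    return jnp.sum(integrand(u_q) * W_Q, axis=1) * det_J

x_nodes = jnp.linspace(0.0, 10.0, 11)
x_cell = jnp.stack((x_nodes[:-1], x_nodes[1:]), axis=1)  # (10, 2)
u_cell = jnp.ones_like(x_cell)
print(jnp.sum(integrate_1d(lambda u: u, u_cell, x_cell)))  # close to 10, the bar length
```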

In [4]:
fem = Operator(element=line2, integrand=lambda u: u)

In [5]:
L = 10
nx = 10

coords, elements = generate_bar_mesh(nx=nx, lx=(0, L))
n_nodes = coords.shape[0]
n_dofs_per_node = 1
n_dofs = n_dofs_per_node * n_nodes

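Next, we define a function that gathers the nodal DOF values (and the nodal coordinates) into cell-wise arrays with one row per element. Within this function, we call the `integrate` method of the `Operator` object to compute the integral over each element and then sum the element contributions.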

In [6]:
# --- Total energy ---
def total_energy(u, coords, elements, fem):
    u_cell = u[elements]
    x_cell = coords[elements]
    return jnp.sum(fem.integrate(u_cell, x_cell))
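The expression `u[elements]` uses integer-array indexing: each row of `elements` lists the node indices of one element, so the result holds one row of nodal values per element, and a node shared by two elements appears in both rows. A small standalone example with made-up arrays for a three-element mesh:

```python
import jax.numpy as jnp

u_demo = jnp.array([10.0, 20.0, 30.0, 40.0])            # one value per node
elements_demo = jnp.array([[0, 1], [1, 2], [2, 3]])      # node indices of each element
coords_demo = jnp.stack((jnp.arange(4.0), jnp.zeros(4)), axis=1)  # (4, 2) node (x, y)

u_cell = u_demo[elements_demo]        # (3, 2): [[10, 20], [20, 30], [30, 40]]
x_cell = coords_demo[elements_demo]   # (3, 2, 2): element -> local node -> (x, y)
print(u_cell.shape, x_cell.shape)
```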


Now, we set every entry of the solution vector `u` to 1.0. Because the integrand is the identity, integrating `u = 1` over the domain should return the size of the domain, which for this one-dimensional bar is its length `L`.

In [7]:
u = jnp.ones(n_dofs)
total_energy(u, coords, elements, fem) == L

Array(True, dtype=bool)
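The result can be checked by hand. Assuming the one-point rule printed earlier (ξ = 0, w = 2) and a linear map from the reference interval to an element of length h = L / n_x, the Jacobian determinant is h/2, so each element contributes w · (h/2) · 1 = h:

```latex
\int_0^L 1 \, dx
  = \sum_{e=1}^{n_x} \sum_{q} w_q \, \lvert J_e \rvert \, u_h(\xi_q)
  = \sum_{e=1}^{n_x} 2 \cdot \frac{h}{2} \cdot 1
  = n_x h = L,
\qquad h = \frac{L}{n_x}.
```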

In [8]:
fem.gradient(u[elements], coords[elements])

Array([[0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.]], dtype=float64)
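The gradient vanishes because `u` is constant. For a two-node element with linear shape functions, the chain rule gives du/dx = (du/dξ) / (dx/dξ), which reduces to (u₂ - u₁)/h. The hand-written sketch below applies that formula; `element_gradient` is hypothetical and not `femsolver`'s `gradient` method, and it assumes nodal x-coordinates only. It returns zeros for a constant field and ones for u = x, in the same `(n_elem, 1)` shape as the output above:

```python
import jax.numpy as jnp

DN_DXI = jnp.array([-0.5, 0.5])  # dN/dxi of the linear shape functions (constant)

def element_gradient(u_cell, x_cell):
    """Hypothetical sketch: du/dx at the single quadrature point of each element.

    u_cell, x_cell: (n_elem, 2) nodal values and nodal x-coordinates.
    Returns shape (n_elem, 1): one row per element, one column per quad point.
    """
    dx_dxi = x_cell @ DN_DXI  # Jacobian dx/dxi = h / 2, shape (n_elem,)
    du_dxi = u_cell @ DN_DXI  # (u2 - u1) / 2, shape (n_elem,)
    return (du_dxi / dx_dxi)[:, None]  # chain rule: du/dx = (du/dxi) / (dx/dxi)

x_nodes = jnp.linspace(0.0, 10.0, 11)
x_cell = jnp.stack((x_nodes[:-1], x_nodes[1:]), axis=1)
print(element_gradient(jnp.ones_like(x_cell), x_cell).ravel())  # all zeros
print(element_gradient(x_cell, x_cell).ravel())                 # all ones (u = x)
```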

In [10]:
fem.interpolate(u[elements])

Array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]], dtype=float64)
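Interpolation evaluates u_h(ξ_q) = Σ N_i(ξ_q) u_i at each quadrature point, so the constant field gives 1 everywhere. At the single point ξ = 0 both shape functions equal 1/2, so the interpolated value is the average of the two nodal values. A standalone sketch, assuming that one-point rule (`interpolate_at_quadrature` is a hypothetical name):

```python
import jax.numpy as jnp

N_Q = jnp.array([[0.5, 0.5]])  # shape functions at xi = 0, shape (n_q=1, 2)

def interpolate_at_quadrature(u_cell):
    """Hypothetical sketch: u_h at each quadrature point, shape (n_elem, n_q)."""
    return u_cell @ N_Q.T

u_cell = jnp.array([[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]])  # u = x on three unit elements
print(interpolate_at_quadrature(u_cell).ravel())           # element midpoints 0.5, 1.5, 2.5
print(interpolate_at_quadrature(jnp.ones((3, 2))).ravel())  # constant field: 1, 1, 1
```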