# Periodic homogenization of linear elasticity

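This notebook follows Jeremy Bleyer's example: https://bleyerj.github.io/comet-fenicsx/tours/homogenization/periodic_elasticity/periodic_elasticity.html

For a prescribed macroscopic strain $\boldsymbol{E}$, we solve the following auxiliary problem on the unit cell $\mathcal{A}$, where $\boldsymbol{v}$ is the periodic displacement fluctuation: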

$$
\begin{split}\begin{cases}\operatorname{div} \boldsymbol{\sigma} = \boldsymbol{0} & \text{in } \mathcal{A} \\
\boldsymbol{\sigma} = \mathbb{C}(\boldsymbol{y}):\boldsymbol{\varepsilon} & \text{for }\boldsymbol{y}\in\mathcal{A} \\
\boldsymbol{\varepsilon} = \boldsymbol{E} + \nabla^s \boldsymbol{v} & \text{in } \mathcal{A} \\
\boldsymbol{v} & \text{is } \mathcal{A}\text{-periodic} \\
\boldsymbol{T}=\boldsymbol{\sigma}\cdot\boldsymbol{n} & \text{is } \mathcal{A}\text{-antiperiodic}
\end{cases}
\end{split}
$$

In [None]:
# [collapse: code] Colab Setup (Install Dependencies)

# Only run this if we are in Google Colab: the test looks for "google.colab"
# in the string form of the active IPython shell object.
if "google.colab" in str(get_ipython()):
    print("Installing dependencies using uv...")
    # Install uv with pip (quiet output)
    !pip install -q uv
    # Install system dependencies
    !apt-get install -qq gmsh
    # Use uv to install Python dependencies into the system environment
    !uv pip install --system matplotlib meshio
    # NOTE(review): presumably this package pulls in tatva and the other
    # libraries imported below - confirm against its project metadata.
    !uv pip install --system "git+https://github.com/smec-ethz/tatva-docs.git"
    print("Installation complete!")

In [None]:
import gmsh
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import Array
from tatva import Mesh, Operator, compound, element

# Switch JAX to 64-bit types for the whole session (JAX defaults to 32-bit).
jax.config.update("jax_enable_x64", True)

## Meshing

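We generate a parallelogram-shaped unit cell with circular inclusions centred on its corners using `gmsh`. The mesh contains both a matrix phase and an inclusion phase, and its four boundary edges are tagged so that opposite edges can be paired periodically.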

In [None]:
# [collapse: code] Code for mesh generation and extraction

from itertools import chain


def plot_mesh(mesh: Mesh, ax: plt.Axes | None = None) -> None:
    """Draw the triangulation of ``mesh`` as a thin gray wireframe.

    Args:
        mesh: Mesh whose ``coords`` (x in column 0, y in column 1) and
            ``elements`` (triangle connectivity) are plotted.
        ax: Axes to draw on. A new figure and axes are created when omitted.
    """
    target = plt.subplots()[1] if ax is None else ax
    x, y = mesh.coords[:, 0], mesh.coords[:, 1]
    target.triplot(x, y, mesh.elements, color="gray", linewidth=0.5)
    target.set_aspect("equal")


def extract_physical_groups(tag_map: dict) -> dict[str, np.ndarray]:
    """Collect the element connectivity of every 1D and 2D physical group.

    Reads the current Gmsh model, so it must be called between
    ``gmsh.initialize()`` and ``gmsh.finalize()`` after the mesh was generated.

    Args:
        tag_map: Mapping from Gmsh node tag to zero-based node index.

    Returns:
        Dictionary keyed by physical-group name. Each value is an ``int32``
        array with one row per element holding zero-based node indices.
        A group without elements maps to an empty ``(0, dim + 1)`` array.
    """
    print("Extracting physical groups from Gmsh model...")
    physical_groups: dict[str, np.ndarray] = {}

    for dim, pg_tag in chain(
        gmsh.model.getPhysicalGroups(dim=1), gmsh.model.getPhysicalGroups(dim=2)
    ):
        name = gmsh.model.getPhysicalName(dim, pg_tag)

        # Geometric entities (curves or surfaces) that belong to this group
        entities = gmsh.model.getEntitiesForPhysicalGroup(dim, pg_tag)

        blocks = []
        for ent in entities:
            # All mesh elements on this entity, grouped by element type
            types, _, node_tags_by_type = gmsh.model.mesh.getElements(dim, ent)

            for etype, ntags in zip(types, node_tags_by_type):
                # Ask Gmsh how many nodes this element type has instead of
                # guessing it from the numeric type id.
                n_nodes = gmsh.model.mesh.getElementProperties(etype)[3]
                blocks.append(np.asarray(ntags, dtype=np.int64).reshape(-1, n_nodes))

        if not blocks:
            physical_groups[name] = np.zeros((0, dim + 1), dtype=np.int32)
            continue

        tags = np.vstack(blocks)
        # Translate Gmsh node tags into zero-based node indices.
        physical_groups[name] = np.array(
            [[tag_map[t] for t in row] for row in tags.tolist()], dtype=np.int32
        )

    return physical_groups


def extract_mesh_data() -> tuple[Array, Array, dict[str, np.ndarray]]:
    """Read nodes, triangles and physical groups from the current Gmsh model.

    Node tags are translated to zero-based indices through one shared
    tag-to-index map, so the triangle connectivity and the physical groups
    always refer to the same rows of the returned coordinate array.

    Returns:
        ``(nodes, elements, groups)`` where ``nodes`` holds the x/y coordinates
        (one row per node), ``elements`` holds three zero-based node indices per
        triangle, and ``groups`` is the result of ``extract_physical_groups``.
    """
    node_tags, node_coords, _ = gmsh.model.mesh.getNodes()
    # Position of each tag in the node list = row of that node in `nodes`.
    tag_map = {int(tag): i for i, tag in enumerate(node_tags)}
    # Gmsh returns flat x, y, z triplets; keep only x and y.
    nodes = jnp.array(node_coords).reshape(-1, 3)[:, :2]

    # Only the first 2D element type is read; it is reshaped as 3-node triangles.
    _, _, elem_node_tags = gmsh.model.mesh.getElements(2)
    tri_tags = np.asarray(elem_node_tags[0]).reshape(-1, 3)
    elements = jnp.array([[tag_map[t] for t in tri] for tri in tri_tags.tolist()])

    pg = extract_physical_groups(tag_map)
    return nodes, elements, pg


# Unit-cell geometry: a parallelogram whose top edge is shifted by `c` in x.
Lx = 1.0  # length of the bottom edge
Ly = np.sqrt(3) / 2.0 * Lx  # height of the cell
c = 0.5 * Lx  # horizontal shift of the top edge relative to the bottom edge
R = 0.2 * Lx  # radius of the disks centred on the four corners
h = 0.02 * Lx  # mesh size, used as both min and max characteristic length

# Corner order: bottom-left, bottom-right, top-right, top-left.
corners = np.array([[0.0, 0.0], [Lx, 0.0], [Lx + c, Ly], [c, Ly]])
a1 = corners[1, :] - corners[0, :]  # first vector generating periodicity
a2 = corners[3, :] - corners[0, :]  # second vector generating periodicity

gdim = 2  # domain geometry dimension
fdim = 1  # facets dimension
gmsh.initialize()

occ = gmsh.model.occ
model_rank = 0  # NOTE(review): not used anywhere in the visible notebook
# Build the parallelogram outline and the surface it encloses.
points = [occ.add_point(*corner, 0) for corner in corners]
lines = [occ.add_line(points[i], points[(i + 1) % 4]) for i in range(4)]
loop = occ.add_curve_loop(lines)
unit_cell = occ.add_plane_surface([loop])
# One full disk per corner; only the part inside the cell is kept below.
inclusions = [occ.add_disk(*corner, 0, R, R) for corner in corners]
vol_dimTag = (gdim, unit_cell)
# Intersect the cell with the disks while keeping the cell itself
# (removeObject=False); out[0] lists the resulting inclusion pieces.
out = occ.intersect(
    [vol_dimTag], [(gdim, incl) for incl in inclusions], removeObject=False
)
incl_dimTags = out[0]
occ.synchronize()
# Remove the inclusion pieces from the cell to obtain the matrix, keeping the
# pieces themselves (removeTool=False) as separate surfaces.
occ.cut([vol_dimTag], incl_dimTags, removeTool=False)
occ.synchronize()

# tag physical domains and facets
gmsh.model.addPhysicalGroup(gdim, [vol_dimTag[1]], 1, name="Matrix")
gmsh.model.addPhysicalGroup(
    gdim,
    [tag for _, tag in incl_dimTags],
    2,
    name="Inclusions",
)
# NOTE(review): the curve tags below are hard-coded; they presumably match the
# numbering OpenCASCADE produces for exactly this sequence of boolean
# operations - re-check them if the geometry construction above changes.
gmsh.model.addPhysicalGroup(fdim, [7, 20, 10], 1, name="bottom")
gmsh.model.addPhysicalGroup(fdim, [9, 19, 16], 2, name="right")
gmsh.model.addPhysicalGroup(fdim, [15, 18, 12], 3, name="top")
gmsh.model.addPhysicalGroup(fdim, [11, 17, 5], 4, name="left")
gmsh.option.setNumber("Mesh.CharacteristicLengthMin", h)
gmsh.option.setNumber("Mesh.CharacteristicLengthMax", h)

gmsh.model.mesh.generate(gdim)

# Pull everything out of Gmsh before finalizing; the extraction helpers query
# the live Gmsh model.
nodes, elements, pg = extract_mesh_data()
mesh = Mesh(coords=nodes, elements=elements, groups=pg)
gmsh.finalize()

plot_mesh(mesh)

## System setup

We define the material properties for both the matrix and inclusion phases, and set up the strain energy functional.

In [None]:
from typing import NamedTuple

from jax_autovmap import autovmap


class Material(NamedTuple):
    """Isotropic linear-elastic material stored as its two Lamé constants."""

    mu: float  # Shear modulus
    lmbda: float  # First Lamé parameter

    @classmethod
    def from_youngs_poisson_2d(
        cls, E: float, nu: float, plane_stress: bool = False
    ) -> "Material":
        """Build a 2D material from Young's modulus and Poisson's ratio.

        Args:
            E: Young's modulus.
            nu: Poisson's ratio.
            plane_stress: Use the plane-stress expression for ``lmbda`` when
                true; otherwise the standard (plane-strain) expression is used.

        Returns:
            The corresponding ``Material``.
        """
        shear = E / 2 / (1 + nu)
        if not plane_stress:
            return cls(mu=shear, lmbda=E * nu / (1 - 2 * nu) / (1 + nu))
        return cls(mu=shear, lmbda=2 * nu * shear / (1 - nu))


# Phase materials given as (Young's modulus, Poisson's ratio); `plane_stress`
# is left at its default (False). `total_energy` assigns mat1 to the matrix
# and mat2 to the inclusions.
mat1 = Material.from_youngs_poisson_2d(50e3, 0.2)
mat2 = Material.from_youngs_poisson_2d(210e3, 0.3)


@autovmap(grad_u=2)
def compute_strain(grad_u):
    """Return the symmetric part of a displacement gradient.

    The body is written for a single square matrix ``grad_u``.
    NOTE(review): ``autovmap(grad_u=2)`` presumably maps the function over any
    leading batch dimensions - confirm in the jax_autovmap documentation.
    """
    symmetric_sum = grad_u + grad_u.T
    return 0.5 * symmetric_sum


@autovmap(eps=2, mu=0, lmbda=0)
def compute_stress(eps, mu, lmbda):
    """Return the isotropic linear-elastic stress for a 2x2 strain tensor.

    Evaluates ``2 * mu * eps + lmbda * tr(eps) * I`` with the 2x2 identity.

    Args:
        eps: Strain tensor (2x2).
        mu: Shear modulus (scalar).
        lmbda: Lamé parameter (scalar).
    """
    shear_part = 2 * mu * eps
    volumetric_part = lmbda * jnp.trace(eps) * jnp.eye(2)
    return shear_part + volumetric_part


@autovmap(grad_u=2, eps_hat=2, mu=0, lmbda=0)
def strain_energy(grad_u, eps_hat, mu, lmbda):
    """Return the strain-energy density ``0.5 * sigma : eps``.

    The total strain is the symmetric gradient of the displacement plus the
    macroscopic strain ``eps_hat``.

    Args:
        grad_u: Displacement gradient (2x2).
        eps_hat: Macroscopic strain tensor (2x2).
        mu: Shear modulus (scalar).
        lmbda: Lamé parameter (scalar).
    """
    total_strain = compute_strain(grad_u) + eps_hat
    total_stress = compute_stress(total_strain, mu, lmbda)
    double_contraction = jnp.einsum("ij,ij->", total_stress, total_strain)
    return 0.5 * double_contraction

## Periodic boundary conditions

We establish the periodic boundary conditions by creating mappings between master and slave edges.

In [None]:
from numpy.typing import NDArray

# Array types accepted by the helpers below: a NumPy array or a JAX array.
ArrayLike = NDArray | Array

# Per-phase views of the mesh: same mesh object, but with the element list
# replaced by the connectivity of one physical group.
mesh_matrix = mesh._replace(elements=mesh.groups["Matrix"])
mesh_inclusion = mesh._replace(elements=mesh.groups["Inclusions"])
# One Tri3 operator per phase (used for the energy) and one over the whole
# mesh (used when post-processing stresses).
op_matrix = Operator(mesh_matrix, element.Tri3())
op_inclusion = Operator(mesh_inclusion, element.Tri3())
op = Operator(mesh, element.Tri3())


def edge_bijection(
    coords: ArrayLike,
    m_group: ArrayLike,
    s_group: ArrayLike,
    *,
    axis: int = 0,
    offset: float = 0.0,
) -> Array:
    """Pair the interior nodes of a master edge with nodes of a slave edge.

    Nodes are compared along the coordinate axis *other* than ``axis``
    (``axis=0`` matches on y, ``axis=1`` matches on x). The two extreme nodes
    of each edge along that axis (the corners) are excluded. Every remaining
    master node is paired with the slave node whose coordinate is closest to
    the master coordinate plus ``offset``.

    The pairing is a nearest-neighbour search; it is not checked that every
    slave node is used exactly once.

    Args:
        coords: Node coordinates, one row per node (NumPy or JAX array).
        m_group: Node indices of the master edge elements (any shape).
        s_group: Node indices of the slave edge elements (any shape).
        axis: Axis separating the two edges; matching uses the other axis.
        offset: Shift of the slave edge relative to the master edge along
            the matching axis.

    Returns:
        Integer array of shape ``(n_master_interior, 2)`` whose rows are
        ``[master_node, slave_node]``.
    """
    # Convert up front: a NumPy array cannot be indexed with the traced node
    # index inside ``jax.vmap`` below.
    coords = jnp.asarray(coords)
    along = axis ^ 1  # bitwise xor to get the other axis

    def interior_nodes(group: ArrayLike) -> Array:
        """Unique node ids of ``group`` without the two end (corner) nodes."""
        ids = jnp.unique(jnp.asarray(group))
        pos = coords[ids, along]
        return ids[(pos != jnp.min(pos)) & (pos != jnp.max(pos))]

    m_nodes = interior_nodes(m_group)
    s_nodes = interior_nodes(s_group)

    def find_closest_slave(m_node: ArrayLike) -> Array:
        """Return ``[m_node, closest slave node]`` for one master node."""
        diffs = coords[s_nodes, along] - (coords[m_node, along] + offset)
        return jnp.array([m_node, s_nodes[jnp.argmin(diffs**2)]])

    return jax.vmap(find_closest_slave)(m_nodes)


# Index of the mesh node closest to each geometric corner (same order as
# `corners`).
corner_nodes = [
    jnp.argmin(jnp.linalg.norm(mesh.coords - corner, axis=1)) for corner in corners
]
# The first corner acts as master for the three remaining corners.
corner_m = jnp.repeat(corner_nodes[0], 3)
corner_s = jnp.array(corner_nodes[1:])
# Left edge (master) -> right edge (slave); with axis=0 nodes are matched on y.
left_right = edge_bijection(
    mesh.coords, mesh.groups["left"], mesh.groups["right"], axis=0
)
# Bottom edge (master) -> top edge (slave); with axis=1 nodes are matched on x,
# and the top edge is shifted by Lx / 2 relative to the bottom edge.
bottom_top = edge_bijection(
    mesh.coords, mesh.groups["bottom"], mesh.groups["top"], axis=1, offset=Lx * 1 / 2
)
# Rows of [master_corner, slave_corner].
corner_map = jnp.vstack([corner_m, corner_s]).T

## Lifter and constraint setup

We use the `Lifter` to handle periodic boundary conditions and Dirichlet constraints.

In [None]:
from tatva.sparse import reduce_sparsity_pattern
from tatva.lifter import PeriodicMap, Lifter, DirichletBC
from tatva.sparse._extraction import create_sparsity_pattern_master_slave


class Solution(compound.Compound):
    # Single field `u`, declared with the shape of the node-coordinate array
    # (one entry per node and coordinate direction).
    # NOTE(review): the notebook relies on `Solution.size`, on indexing
    # `Solution.u[...]`, and on `Solution(flat_vector)` - see the
    # tatva.compound documentation for their exact semantics.
    u = compound.field(mesh.coords.shape)


# Stack the master/slave pairs of both edge maps and of the corners.
# NOTE(review): `Solution.u[nodes, :]` presumably returns the flat DOF indices
# of both displacement components of `nodes`, so that each piece becomes an
# (n_dofs, 2) array with masters in column 0 and slaves in column 1 - confirm.
periodic_map = jnp.concatenate(
    [
        jnp.array([Solution.u[nodes, :] for nodes in left_right.T]).T,
        jnp.array([Solution.u[nodes, :] for nodes in bottom_top.T]).T,
        jnp.array([Solution.u[nodes, :] for nodes in corner_map.T]).T,
    ]
)

# Sparsity pattern with slave DOFs redirected to their masters: the third
# argument is the identity index map with every slave entry (column 1)
# overwritten by its master (column 0).
sparsity = create_sparsity_pattern_master_slave(
    mesh,
    2,
    jnp.arange(Solution.size).at[periodic_map[:, 1]].set(periodic_map[:, 0]),
)
# Keep every index except the DOFs of the master corner, which are fixed by
# the Dirichlet condition below.
sparsity = reduce_sparsity_pattern(
    sparsity, jnp.setdiff1d(jnp.arange(0, sparsity.shape[0]), Solution.u[[corner_m[0]]])
)

# The lifter combines a Dirichlet condition on the master corner with the
# periodic map, passed as (column 1, column 0) of `periodic_map`.
lifter = Lifter(
    Solution.size,
    DirichletBC(Solution.u[[corner_m[0]]]),
    PeriodicMap(periodic_map[:, 1], periodic_map[:, 0]),
)

## Energy functional and sparse Jacobian

We define the total energy functional and compute its sparse Jacobian for efficient solving.

In [None]:
import scipy.sparse as sp
from tatva import sparse


def total_energy(u_flat: Array, eps_hat: Array) -> Array:
    """Return the strain energy of the unit cell summed over both phases.

    Args:
        u_flat: Flat vector of all degrees of freedom, unpacked via ``Solution``.
        eps_hat: Macroscopic strain tensor added to the fluctuation strain.

    Returns:
        Energy of the inclusions (material ``mat2``) plus energy of the matrix
        (material ``mat1``).
    """
    (u,) = Solution(u_flat)

    def phase_energy(phase_op, material):
        """Integrate the strain-energy density of one phase."""
        density = strain_energy(
            phase_op.grad(u), eps_hat, material.mu, material.lmbda
        )
        return phase_op.integrate(density)

    # Inclusion gets Material 2, Matrix gets Material 1
    return phase_energy(op_inclusion, mat2) + phase_energy(op_matrix, mat1)


def lagrangian(u_free: Array, eps_hat: Array) -> Array:
    """Return the total energy as a function of the reduced unknowns.

    Args:
        u_free: Reduced vector of unknowns, expanded to the full DOF vector by
            ``lifter.lift_from_zeros``.
        eps_hat: Macroscopic strain tensor.
    """
    return total_energy(lifter.lift_from_zeros(u_free), eps_hat)


# Residual = derivative of the energy with respect to its first argument
# (the reduced unknowns).
residual = jax.jacrev(lagrangian)

# Convert the sparsity pattern from (row, col) index pairs to CSR so that the
# row pointer and column index arrays can be reused below and by the solver.
sparsity_csr = sp.csr_matrix(
    (sparsity.data, (sparsity.indices[:, 0], sparsity.indices[:, 1])),
    shape=sparsity.shape,
)
indptr = sparsity_csr.indptr
indices = sparsity_csr.indices

# NOTE(review): presumably a column colouring of the pattern that lets the
# sparse forward-mode Jacobian be built from few JVP evaluations - confirm in
# the tatva.sparse documentation.
colors = sparse.distance2_colors(indptr, indices, lifter.size_reduced)
jacobian = sparse.jacfwd(residual, indptr, indices, colors)

## Solve for unit strains

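We solve the cell problem for three elementary macroscopic strain states (pure $\varepsilon_{xx}$, pure $\varepsilon_{yy}$ and pure shear), each multiplied by a common scale factor, and store the resulting displacement fields.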

In [None]:
from dataclasses import dataclass
from soldis import linear
import jax.experimental.sparse as jsparse


class DirectSolverSparse(linear.DirectLinearSolver):
    """Direct sparse solver that reuses the notebook's CSR sparsity pattern.

    Only the stored values of ``A`` are read; the column indices and row
    pointers come from the module-level ``indices`` and ``indptr``.
    NOTE(review): this assumes ``A.data`` is ordered like that CSR pattern -
    confirm against ``tatva.sparse.jacfwd``.
    """

    def __call__(self, A: jsparse.BCOO, b: Array) -> Array:  # type: ignore
        values = A.data
        return jsparse.linalg.spsolve(values, indices, indptr, b)


@dataclass
class Result:
    """Solution of one load case."""

    u: NDArray  # total displacement (affine part plus periodic fluctuation)
    eps_hat: NDArray  # macroscopic strain tensor applied in this load case


# Linearisation point: all reduced unknowns set to zero.
v0 = jnp.zeros(lifter.size_reduced)
solver = DirectSolverSparse()

scale = 0.3  # common factor applied to all three load cases
eps_hat_list = scale * jnp.array(
    [
        [[1.0, 0.0], [0.0, 0.0]],
        [[0.0, 0.0], [0.0, 1.0]],
        [[0.0, 0.5], [0.5, 0.0]],
    ]
)  # unit strain tensors

results = []

for eps_hat in eps_hat_list:
    # Single linear solve J(v0) v = -r(v0); the energy is quadratic in the
    # displacement, so one step from zero gives the fluctuation.
    v = solver(jacobian(v0, eps_hat), -residual(v0, eps_hat))
    # Affine displacement eps_hat . x produced by the macroscopic strain.
    u0 = (eps_hat @ mesh.coords.T).T
    # Total displacement = affine part + periodic fluctuation.
    u = u0 + Solution(lifter.lift_from_zeros(v)).u
    results.append(Result(u=u, eps_hat=eps_hat))

## Compute homogenized stiffness tensor

We compute the effective (homogenized) stiffness tensor using automatic differentiation: the cell-averaged stress is differentiated with respect to the macroscopic strain.

In [None]:
def stress(u: Array) -> Array:
    """Return the per-element stress tensor for a displacement field.

    The strain is taken from the gradient of ``u`` on the full-mesh operator.
    Elements whose nodes all appear in the inclusion connectivity get the
    inclusion material ``mat2``; all others get the matrix material ``mat1``.

    Args:
        u: Nodal displacement field.

    Returns:
        Stress tensors, one per element.
    """
    in_inclusion = jnp.isin(op.mesh.elements, mesh_inclusion.elements).all(
        axis=1
    )  # (nels,)
    eps = compute_strain(op.grad(u).squeeze())
    sig_matrix = compute_stress(eps, mat1.mu, mat1.lmbda)
    sig_inclusion = compute_stress(eps, mat2.mu, mat2.lmbda)
    return jnp.where(in_inclusion[:, None, None], sig_inclusion, sig_matrix)


def func(eps_hat_voigt: Array) -> Array:
    """Map a macroscopic strain to the cell-averaged stress (Voigt vectors).

    Args:
        eps_hat_voigt: Macroscopic strain ``[eps_xx, eps_yy, gamma_xy]`` with
            the engineering shear ``gamma_xy = 2 * eps_xy``. This matches the
            shear load case used elsewhere in the notebook, whose tensor has
            off-diagonal entries of 0.5.

    Returns:
        Area-averaged stress ``[sig_xx, sig_yy, sig_xy]``. Its Jacobian with
        respect to ``eps_hat_voigt`` is the homogenized stiffness matrix in
        Voigt notation.
    """
    # Engineering shear -> tensor shear component.
    eps_xy = 0.5 * eps_hat_voigt[2]
    eps_hat = jnp.array([[eps_hat_voigt[0], eps_xy], [eps_xy, eps_hat_voigt[1]]])
    v = linear.CG()(jacobian(v0, eps_hat), -residual(v0, eps_hat))
    u = (eps_hat @ mesh.coords.T).T + Solution(lifter.lift_from_zeros(v)).u
    sig = stress(u)

    # The homogenized stress is the area average over the cell, so each
    # element's stress is weighted by its triangle area rather than averaged
    # with equal weights.
    tri = mesh.coords[mesh.elements]
    e1 = tri[:, 1] - tri[:, 0]
    e2 = tri[:, 2] - tri[:, 0]
    areas = 0.5 * jnp.abs(e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])
    sig_avg = jnp.einsum("e,eij->ij", areas, sig) / jnp.sum(areas)
    return jnp.array([sig_avg[0, 0], sig_avg[1, 1], sig_avg[0, 1]])


# Jacobian of the average stress with respect to the macroscopic strain,
# evaluated at a strain vector of ones.
C_hom = jax.jacfwd(func)(jnp.ones(3))
print("Homogenized stiffness tensor (Voigt notation):")
print(C_hom)

## Plot Solution

We visualize the displacement fields for the different unit strain cases.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
strain_labels = [r"$\varepsilon_{xx}$", r"$\varepsilon_{yy}$", r"$\varepsilon_{xy}$"]

# One filled-contour panel per load case, coloured by the magnitude of the
# nodal displacement vector.
for result, ax, label in zip(results, axes, strain_labels):
    u_mag = jnp.linalg.norm(result.u, axis=1)
    triplot = ax.tricontourf(
        mesh.coords[:, 0],
        mesh.coords[:, 1],
        mesh.elements,
        u_mag,
        levels=20,
        cmap="viridis",
    )
    ax.set_aspect("equal")
    ax.set_title(f"Displacement magnitude for {label}")
    plt.colorbar(triplot, ax=ax)

plt.tight_layout()
plt.show()