In [None]:
%%bash
# Provision micromamba plus the conda environment used by this notebook.
# The environment is (re)built only when it is missing or when the spec
# embedded below differs from the one recorded at the last successful build.
#
# -e stops on errors, -u rejects unset variables, and pipefail makes a failed
# download visible instead of being masked by the tar stage of the pipeline.
set -euo pipefail

# ==================================================
# Basic configuration
# ==================================================

export MAMBA_ROOT_PREFIX=/usr/local/micromamba
export MAMBA_LOG_LEVEL=error
export PATH="${MAMBA_ROOT_PREFIX}/bin:${PATH}"

readonly ENV_NAME=fenicsx
readonly YML_FILE=fenicsx.yml
# Stores the hash of the spec the environment was last built from.
readonly HASH_FILE="${MAMBA_ROOT_PREFIX}/.${ENV_NAME}_yml.hash"
readonly MAMBA_BIN="${MAMBA_ROOT_PREFIX}/bin/micromamba"

# ==================================================
# 1. Ensure micromamba is installed
# ==================================================

if [[ ! -x "${MAMBA_BIN}" ]]; then
  echo "Installing micromamba..."
  mkdir -p "${MAMBA_ROOT_PREFIX}"
  # Unpack directly into the prefix so no stray ./bin directory is left in
  # the notebook's working directory.
  wget -qO- https://micromamba.snakepit.net/api/micromamba/linux-64/latest \
    | tar -xvjf - -C "${MAMBA_ROOT_PREFIX}" bin/micromamba
  chmod +x "${MAMBA_BIN}"
else
  echo "‚úÖ micromamba already installed."
fi

# ==================================================
# 2. Generate fenicsx.yml only if changed (hash-based)
# ==================================================

TMP_YML=$(mktemp)
# Clean the scratch file up on every exit path; this is a no-op once the file
# has been moved into place.
trap 'rm -f -- "${TMP_YML}"' EXIT

cat << 'EOF' > "$TMP_YML"
name: fenicsx
channels:
  - conda-forge
dependencies:
  - fenics-dolfinx=0.10
  - pyvista>=0.45.0
  - mpi4py
  - ipyparallel
  - scipy
  - vtk
  - pygraphviz
  - jupyter-book
  - jupytext
  - trame-client
  - trame-vtk
  - trame-server
  - trame-vuetify
  - trame
  - ipywidgets
  - sphinx>=6.0.0
  - python-gmsh
variables:
  PYVISTA_OFF_SCREEN: false
  PYVISTA_JUPYTER_BACKEND: "trame"
  LIBGL_ALWAYS_SOFTWARE: 1
EOF

NEW_HASH=$(sha256sum "$TMP_YML" | awk '{print $1}')

if [[ -f "${YML_FILE}" ]]; then
  OLD_YML_HASH=$(sha256sum "${YML_FILE}" | awk '{print $1}')
  if [[ "${OLD_YML_HASH}" != "${NEW_HASH}" ]]; then
    echo "üîÑ fenicsx.yml changed"
    mv -- "$TMP_YML" "${YML_FILE}"
  else
    rm -f -- "$TMP_YML"
  fi
else
  mv -- "$TMP_YML" "${YML_FILE}"
  echo "‚úÖ fenicsx.yml created"
fi

# ==================================================
# 3. Load previously stored YAML hash (if any)
# ==================================================

STORED_HASH=""
if [[ -f "${HASH_FILE}" ]]; then
  STORED_HASH=$(<"${HASH_FILE}")
fi

# ==================================================
# 4. Check whether the environment exists
# ==================================================

# awk reads the whole listing before exiting, so the producer never receives
# SIGPIPE (which would turn a match into a failure under pipefail).
ENV_EXISTS=false
if micromamba env list \
  | awk -v name="${ENV_NAME}" '$1 == name { found = 1 } END { exit !found }'; then
  ENV_EXISTS=true
fi

# ==================================================
# 5. Create / recreate environment if needed
# ==================================================

if [[ "${ENV_EXISTS}" == true && "${NEW_HASH}" == "${STORED_HASH}" ]]; then
  echo "‚úÖ ${ENV_NAME} environment is up to date. Skipping installation."
else
  echo "üîÑ (Re)creating ${ENV_NAME} environment..."

  if [[ "${ENV_EXISTS}" == true ]]; then
    # `micromamba remove` operates on package specs; deleting the whole
    # environment is done with the `env remove` subcommand.
    micromamba env remove -n "${ENV_NAME}" -y --quiet
  fi

  micromamba create -n "${ENV_NAME}" -f "${YML_FILE}" -y --quiet

  # Record the hash only after a successful build so a failed build is retried.
  echo "$NEW_HASH" > "${HASH_FILE}"

  echo "üéâ ${ENV_NAME} environment is ready."
fi

Installing micromamba...
bin/micromamba
‚úÖ fenicsx.yml created
üîÑ (Re)creating fenicsx environment...
üéâ fenicsx environment is ready.


---

In [None]:
from IPython.core.magic import register_cell_magic
import subprocess, textwrap, os, shlex, tempfile

# --------------------------------------------------
# MPI detection helpers
# --------------------------------------------------

def detect_mpi_impl(env):
    """Identify the MPI flavour of the ``mpiexec`` inside the ``fenicsx`` env.

    The probe runs ``mpiexec --version`` through ``micromamba run -n fenicsx``,
    i.e. the same launcher path the ``%%fenicsx`` magic uses to start jobs, so
    the answer describes the MPI that will actually be executed rather than
    whatever ``mpiexec`` happens to be first on the host ``PATH``.

    Args:
        env: environment mapping passed to ``subprocess.run``; its ``PATH``
            must allow ``micromamba`` to be found.

    Returns:
        ``"openmpi"`` or ``"mpich"``. ``"mpich"`` is also the fallback when the
        probe cannot be run or its banner is not recognised.
    """
    probe = ["micromamba", "run", "-n", "fenicsx", "mpiexec", "--version"]
    try:
        out = subprocess.run(
            probe, env=env, capture_output=True, text=True, timeout=15
        )
    except (OSError, ValueError, subprocess.SubprocessError):
        # Best effort: a missing binary or a timeout must not break the magic.
        return "mpich"

    banner = (out.stdout + out.stderr).lower()

    # Open MPI 4.x announces itself as "mpiexec (OpenRTE) x.y.z" and only
    # mentions "open-mpi.org" in the footer; newer releases print "(Open MPI)".
    openmpi_markers = ("open mpi", "open-mpi", "openrte", "open rte", "prrte")
    if any(marker in banner for marker in openmpi_markers):
        return "openmpi"
    if "mpich" in banner or "hydra" in banner:
        return "mpich"
    return "mpich"   # Safe default (Colab)

def mpi_version_string(env):
    """Return a one-line version banner for the env's ``mpiexec``.

    The probe goes through ``micromamba run -n fenicsx`` so the banner belongs
    to the launcher the ``%%fenicsx`` magic executes, not to a host-wide
    ``mpiexec``.

    Args:
        env: environment mapping passed to ``subprocess.run``.

    Returns:
        The first non-blank line of ``mpiexec --version``. When that line is
        the "HYDRA build details:" header, the following ``Version:`` entry is
        folded in so the result still carries a version number. Returns
        ``"unknown"`` if the probe fails or prints nothing.
    """
    probe = ["micromamba", "run", "-n", "fenicsx", "mpiexec", "--version"]
    try:
        out = subprocess.run(
            probe, env=env, capture_output=True, text=True, timeout=15
        )
    except (OSError, ValueError, subprocess.SubprocessError):
        return "unknown"

    lines = [ln.strip() for ln in (out.stdout + out.stderr).splitlines()]
    lines = [ln for ln in lines if ln]
    if not lines:
        return "unknown"

    if lines[0].lower().startswith("hydra"):
        for entry in lines[1:]:
            if entry.lower().startswith("version:"):
                return "MPICH (Hydra) " + entry.split(":", 1)[1].strip()
    return lines[0]

# --------------------------------------------------
# Cell magic
# --------------------------------------------------

@register_cell_magic
def fenicsx(line, cell):
    """Cell magic: run the cell body as a script in the ``fenicsx`` environment.

    Usage::

        %%fenicsx [-np N] [--info]

    The cell text is dedented, written to a temporary ``.py`` file and run
    with ``micromamba run -n fenicsx python``. With ``-np N`` (N > 1) the
    script is launched under MPI with N ranks. With ``--info`` the cell body
    is ignored and a short runtime report is printed instead.

    Captured stdout and stderr of the child process are printed once it has
    finished; a non-zero exit status is reported on an extra line.

    Raises:
        ValueError: if ``-np`` is not followed by a positive integer.
        subprocess.TimeoutExpired: if a normal (non ``--info``) run exceeds
            500 seconds.
    """
    args = shlex.split(line)
    code = textwrap.dedent(cell)
    info_mode = "--info" in args

    # Number of MPI ranks; validated so a typo gives a readable message.
    np = 1
    if "-np" in args:
        value_pos = args.index("-np") + 1
        if value_pos >= len(args):
            raise ValueError("%%fenicsx: -np requires a positive integer")
        try:
            np = int(args[value_pos])
        except ValueError:
            raise ValueError(
                f"%%fenicsx: -np requires a positive integer, got {args[value_pos]!r}"
            ) from None
        if np < 1:
            raise ValueError(
                f"%%fenicsx: -np requires a positive integer, got {np}"
            )

    # Base environment
    env = os.environ.copy()
    env.update({
        "PATH": "/usr/local/micromamba/bin:" + env.get("PATH", ""),
        "MAMBA_ROOT_PREFIX": "/usr/local/micromamba",
        "MAMBA_EXE": "/usr/local/micromamba/bin/micromamba",
        "OMPI_ALLOW_RUN_AS_ROOT": "1",
        "OMPI_ALLOW_RUN_AS_ROOT_CONFIRM": "1",
    })

    mpi_impl = detect_mpi_impl(env)
    mpi_ver  = mpi_version_string(env)

    # --------------------------------------------------
    # Helper: build the launch command for a script path
    # --------------------------------------------------

    def run_script(script):
        launcher = ["micromamba", "run", "-n", "fenicsx"]

        # A single rank does not need an MPI launcher.
        if np == 1:
            return launcher + ["python", script]

        if mpi_impl == "openmpi":
            return launcher + [
                "mpirun", "--oversubscribe", "--bind-to", "none",
                "-np", str(np), "python", script
            ]

        return launcher + ["mpiexec", "-n", str(np), "python", script]

    # --------------------------------------------------
    # Helper: write source to a temp file, run it, echo its output
    # --------------------------------------------------

    def run_source(source, timeout=None):
        # delete=False: the file must outlive the `with` so the child can read it.
        with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
            f.write(source)
            script = f.name

        try:
            res = subprocess.run(
                run_script(script),
                env=env, capture_output=True, text=True, timeout=timeout
            )
        finally:
            os.remove(script)

        print(res.stdout, end="")
        print(res.stderr, end="")
        if res.returncode != 0:
            # Without this a failure that prints nothing would look like success.
            print(f"\n[fenicsx] process exited with status {res.returncode}")

    # --------------------------------------------------
    # --info mode
    # --------------------------------------------------

    if info_mode:
        info_code = """
from mpi4py import MPI
import dolfinx, sys, platform, os

comm = MPI.COMM_WORLD
if comm.rank == 0:
    print("üêç Python         :", sys.version.split()[0])
    print("üì¶ dolfinx        :", dolfinx.__version__)
    print("üíª Platform       :", platform.platform())
    print("üßµ Running as root:", os.geteuid() == 0)
"""
        run_source(info_code)

        print("\nüîé fenicsx runtime info")
        print("-----------------------")
        print("Environment       : fenicsx")
        print(f"MPI implementation: {mpi_impl.upper()}")
        print(f"MPI version       : {mpi_ver}")
        print(f"MPI ranks (-np)   : {np}")
        return

    # --------------------------------------------------
    # Normal execution
    # --------------------------------------------------

    run_source(code, timeout=500)

---

In [None]:
%%fenicsx --info

üêç Python         : 3.11.14
üì¶ dolfinx        : 0.10.0
üíª Platform       : Linux-6.6.105+-x86_64-with-glibc2.35
üßµ Running as root: True

üîé fenicsx runtime info
-----------------------
Environment       : fenicsx
MPI implementation: MPICH
MPI version       : mpiexec (OpenRTE) 4.1.2
MPI ranks (-np)   : 1


In [None]:
%%fenicsx -np 4 --info

üêç Python         : 3.11.14
üì¶ dolfinx        : 0.10.0
üíª Platform       : Linux-6.6.105+-x86_64-with-glibc2.35
üßµ Running as root: True

üîé fenicsx runtime info
-----------------------
Environment       : fenicsx
MPI implementation: MPICH
MPI version       : mpiexec (OpenRTE) 4.1.2
MPI ranks (-np)   : 4


In [None]:
%%fenicsx

# Smoke test: report the dolfinx version and the size of the world communicator.
from mpi4py import MPI
import dolfinx

dolfinx_version = dolfinx.__version__
world_size = MPI.COMM_WORLD.size

print("dolfinx :", dolfinx_version)
print("MPI size:", world_size)

dolfinx : 0.10.0
MPI size: 1


In [None]:
%%fenicsx -np 4

# Each rank announces itself; the order of the lines is not deterministic.
from mpi4py import MPI

world = MPI.COMM_WORLD
print(f"Hello from rank {world.rank} / {world.size}")

Hello from rank 0 / 4
Hello from rank 2 / 4
Hello from rank 1 / 4
Hello from rank 3 / 4


---

In [None]:
%%fenicsx -np 4

import numpy as np
from mpi4py import MPI
from petsc4py import PETSc

from dolfinx import mesh, fem, io
from dolfinx.fem.petsc import LinearProblem
import ufl

# --------------------------------------------------
# Build mesh + solve
# --------------------------------------------------

def solve_poisson(comm, nx=32, ny=32):
    """Solve the Poisson problem with unit source on the unit square.

    Uses first-order Lagrange elements on an ``nx`` x ``ny`` mesh with a
    homogeneous Dirichlet condition on the whole boundary, solved with CG and
    a hypre preconditioner. Rank 0 prints a short summary.

    Args:
        comm: MPI communicator the mesh is distributed over.
        nx, ny: number of cells in each direction.

    Returns:
        ``(domain, uh)``: the mesh and the computed solution.
    """
    domain = mesh.create_unit_square(comm, nx, ny)
    V = fem.functionspace(domain, ("Lagrange", 1))

    # --- Dirichlet condition u = 0 on every exterior facet ---
    def everywhere(x):
        # One flag per point column: select all boundary facets.
        return np.full(x.shape[1], True)

    facet_dim = domain.topology.dim - 1
    exterior_facets = mesh.locate_entities_boundary(domain, facet_dim, everywhere)
    boundary_dofs = fem.locate_dofs_topological(V, facet_dim, exterior_facets)

    zero = fem.Function(V)
    zero.x.array[:] = 0.0
    bc = fem.dirichletbc(zero, boundary_dofs)

    # --- Weak form: (grad u, grad v) = (f, v) with f = 1 ---
    trial = ufl.TrialFunction(V)
    test = ufl.TestFunction(V)
    source = fem.Constant(domain, PETSc.ScalarType(1.0))

    bilinear = ufl.inner(ufl.grad(trial), ufl.grad(test)) * ufl.dx
    linear = source * test * ufl.dx

    solver_options = {"ksp_type": "cg", "pc_type": "hypre"}
    problem = LinearProblem(
        bilinear,
        linear,
        bcs=[bc],
        petsc_options=solver_options,
        petsc_options_prefix="poisson_",
    )
    uh = problem.solve()

    # --- Diagnostics: each rank integrates its part, then the parts are summed ---
    squared_norm_local = fem.assemble_scalar(fem.form(uh * uh * ufl.dx))
    squared_norm = comm.allreduce(squared_norm_local, op=MPI.SUM)

    if comm.rank == 0:
        num_dofs = V.dofmap.index_map.size_global
        print("‚úÖ Poisson problem solved")
        print("   Number of dofs:", num_dofs)
        print("   MPI size      :", comm.size)
        print("   L2 norm       :", np.sqrt(squared_norm))

    return domain, uh

# --------------------------------------------------
# save_xdmf: collective I/O
# --------------------------------------------------

def save_xdmf(comm, domain, uh, filename="poisson.xdmf"):
    """Write the mesh and the function ``uh`` to an XDMF file.

    Every rank of ``comm`` must call this, because the file is opened on the
    communicator. Rank 0 prints a confirmation afterwards.
    """
    with io.XDMFFile(comm, filename, "w") as out_file:
        out_file.write_mesh(domain)
        out_file.write_function(uh)

    if comm.rank != 0:
        return
    print(f"üñºÔ∏è Saved {filename}")

# --------------------------------------------------

# Solve on the world communicator, then write mesh and solution to disk.
world = MPI.COMM_WORLD
save_xdmf(world, *solve_poisson(world))


‚úÖ Poisson problem solved
   Number of dofs: 1089
   MPI size      : 4
   L2 norm       : 0.04115886586297522
üñºÔ∏è Saved poisson.xdmf


In [None]:
from google.colab import files

# Download the XDMF file and the poisson.h5 file that goes with it.
for artifact in ("poisson.xdmf", "poisson.h5"):
    files.download(artifact)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

---

In [None]:
%%fenicsx -np 1

import numpy as np
import matplotlib.pyplot as plt
from mpi4py import MPI

import dolfinx
from dolfinx.mesh import create_unit_interval
from dolfinx.fem import functionspace, Function
from dolfinx.geometry import (
    bb_tree,
    compute_collisions_points,
    compute_colliding_cells,
)

# ----------------------------
# Mesh & Function space
# ----------------------------
comm = MPI.COMM_WORLD

N = 8
mesh = create_unit_interval(comm, N)
V = functionspace(mesh, ("Lagrange", 1))

imap = V.dofmap.index_map

# ----------------------------
# Bounding box tree
# ----------------------------
tree = bb_tree(mesh, mesh.topology.dim)

# ----------------------------
# Plot grid
# ----------------------------
x_plot = np.linspace(0.0, 1.0, 400)
points = np.zeros((len(x_plot), 3))
points[:, 0] = x_plot

# ----------------------------
# Locate every sample point once
# ----------------------------
# The containing cell of a point does not depend on the basis function, so the
# search is done a single time instead of once per (basis function, point).
candidates = compute_collisions_points(tree, points)
colliding = compute_colliding_cells(mesh, candidates, points)

found_idx = []
found_cells = []
for k in range(len(points)):
    links = colliding.links(k)
    # A point may have no containing cell; check the links themselves, since
    # the adjacency list always has one node per queried point.
    if len(links) > 0:
        found_idx.append(k)
        found_cells.append(links[0])

found_idx = np.asarray(found_idx, dtype=np.int64)
found_cells = np.asarray(found_cells, dtype=np.int32)
found_points = points[found_idx]

# ----------------------------
# Plot all basis functions
# ----------------------------
plt.figure(figsize=(9, 4))

for i in range(imap.size_global):
    # Basis function phi_i: 1 at dof i, 0 at every other dof
    phi = Function(V)
    phi.x.array[:] = 0.0

    # Set i-th DOF = 1
    if i < imap.size_local:
        phi.x.array[i] = 1.0
    phi.x.scatter_forward()

    # Points that were not located keep the value 0.
    values = np.zeros(len(x_plot))
    if len(found_cells) > 0:
        values[found_idx] = phi.eval(found_points, found_cells).reshape(-1)

    plt.plot(x_plot, values)

plt.title("All P1 Lagrange basis functions on [0,1]")
plt.xlabel("x")
plt.ylabel("œÜ")
plt.grid(True)
plt.tight_layout()

plt.savefig("lagrange1_basis.png", dpi=200, bbox_inches="tight")

In [None]:
# `files` is the google.colab helper imported in an earlier cell; this requests
# a download of the figure saved by the previous cell.
files.download("lagrange1_basis.png")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

---

In [None]:
%%fenicsx -np 4

import numpy as np
import matplotlib.pyplot as plt
from mpi4py import MPI

import dolfinx
from dolfinx.mesh import create_unit_interval
from dolfinx.fem import Function, functionspace
from dolfinx.geometry import (
    bb_tree,
    compute_collisions_points,
    compute_colliding_cells,
)

# ----------------------------
# Mesh & Function space
# ----------------------------

comm = MPI.COMM_WORLD

N = 8
mesh = create_unit_interval(comm, N)
V = functionspace(mesh, ("Lagrange", 1))

# Index-map data is the same for every basis function, so compute it once.
imap = V.dofmap.index_map
ndofs = imap.size_global
local_size = imap.size_local
local_to_global = imap.local_to_global(
    np.arange(local_size, dtype=np.int32)
)

# Evaluation points
x_plot = np.linspace(0, 1, 400)
points = np.zeros((len(x_plot), 3))
points[:, 0] = x_plot

tree = bb_tree(mesh, mesh.topology.dim)

# Locate the sample points on this rank once; the result does not depend on
# which basis function is being evaluated.
candidates = compute_collisions_points(tree, points)
colliding = compute_colliding_cells(mesh, candidates, points)

found_idx = []
found_cells = []
for k in range(len(points)):
    links = colliding.links(k)
    if len(links) > 0:
        found_idx.append(k)
        found_cells.append(links[0])

found_idx = np.asarray(found_idx, dtype=np.int64)
found_cells = np.asarray(found_cells, dtype=np.int32)
found_points = points[found_idx]

# Only rank 0 draws, so only rank 0 needs a figure.
if comm.rank == 0:
    plt.figure(figsize=(9, 4))

for i in range(ndofs):

    # --- basis function phi_i ---
    phi = Function(V)
    phi.x.array[:] = 0.0

    # Convert global index i to a local index; only a rank that owns dof i
    # finds a match and sets the entry.
    # NOTE(review): ghost entries are not updated (no scatter_forward here).
    # The SUM reduction below appears to rely on only the owning rank
    # contributing non-zero values - confirm before adding a scatter.
    owned_match = np.flatnonzero(local_to_global == i)
    phi.x.array[owned_match] = 1.0

    values_local = np.zeros(len(points))
    if len(found_cells) > 0:
        values_local[found_idx] = phi.eval(found_points, found_cells).reshape(-1)

    # MPI sum
    values = np.zeros_like(values_local)
    comm.Allreduce(values_local, values, op=MPI.SUM)

    # Key point: keep adding curves to the same figure
    if comm.rank == 0:
        plt.plot(x_plot, values, lw=1)

if comm.rank == 0:
    plt.title("All P1 Lagrange basis functions on [0,1]")
    plt.xlabel("x")
    plt.ylabel("œÜ")
    plt.grid(True)

    plt.savefig("lagrange1_basis_np4.png", dpi=200, bbox_inches="tight")

In [None]:
# `files` is the google.colab helper imported in an earlier cell; this requests
# a download of the figure written by rank 0 in the previous cell.
files.download("lagrange1_basis_np4.png")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

---

In [None]:
%%fenicsx -np 4

from mpi4py import MPI
import numpy as np
import matplotlib.pyplot as plt

from dolfinx.mesh import create_unit_interval
from dolfinx.fem import functionspace, Function
from dolfinx.geometry import (
    bb_tree,
    compute_collisions_points,
    compute_colliding_cells,
)

# ----------------------------
# MPI
# ----------------------------
comm = MPI.COMM_WORLD
rank = comm.rank

# ----------------------------
# Mesh & space
# ----------------------------
N = 8
mesh = create_unit_interval(comm, N)
V = functionspace(mesh, ("Lagrange", 2))

# ----------------------------
# Bounding box tree
# ----------------------------
tree = bb_tree(mesh, mesh.topology.dim)

# ----------------------------
# Evaluation helpers (MPI-safe)
# ----------------------------
def locate_points(x_plot):
    """Find, for each x in ``x_plot``, a cell on this rank that contains it.

    Returns ``(idx, pts, cells)``: positions in ``x_plot`` that were found on
    this rank, their 3D coordinates, and one containing cell for each.
    """
    pts = np.zeros((len(x_plot), 3), dtype=mesh.geometry.x.dtype)
    pts[:, 0] = x_plot

    candidates = compute_collisions_points(tree, pts)
    colliding = compute_colliding_cells(mesh, candidates, pts)

    idx, cells = [], []
    for k in range(len(pts)):
        links = colliding.links(k)
        # A point this rank does not hold simply has no links.
        if len(links) > 0:
            idx.append(k)
            cells.append(links[0])

    idx = np.asarray(idx, dtype=np.int64)
    return idx, pts[idx], np.asarray(cells, dtype=np.int32)


def evaluate_function_global(phi, x_plot, located=None):
    """Evaluate ``phi`` at ``x_plot`` and sum the per-rank results.

    ``located`` may carry the result of ``locate_points(x_plot)`` so that the
    point search is not repeated for every function; it is computed on demand
    when omitted. Collective: every rank must call this.
    """
    if located is None:
        located = locate_points(x_plot)
    idx, pts, cells = located

    values_local = np.zeros(len(x_plot), dtype=float)
    if len(cells) > 0:
        values_local[idx] = phi.eval(pts, cells).reshape(-1)

    values_global = np.zeros_like(values_local)
    comm.Allreduce(values_local, values_global, op=MPI.SUM)

    return values_global

# ----------------------------
# Plot grid
# ----------------------------
# linspace is deterministic, so every rank can build the same grid locally.
x_plot = np.linspace(0.0, 1.0, 400)
located = locate_points(x_plot)

# ----------------------------
# Plot all basis functions
# ----------------------------
if rank == 0:
    plt.figure(figsize=(9, 4))

imap = V.dofmap.index_map
num_dofs = imap.size_global
local_range = imap.local_range

for i in range(num_dofs):
    phi = Function(V)
    phi.x.array[:] = 0.0

    # Only the rank that owns this global dof sets it to 1.
    # NOTE(review): ghost entries are not updated (no scatter_forward here).
    # The SUM in evaluate_function_global appears to rely on that - confirm
    # before adding a scatter.
    if local_range[0] <= i < local_range[1]:
        phi.x.array[i - local_range[0]] = 1.0

    y = evaluate_function_global(phi, x_plot, located)

    if rank == 0:
        plt.plot(x_plot, y)

if rank == 0:
    # The space above is degree 2, so label the figure accordingly.
    plt.title("All P2 Lagrange basis functions on [0,1] (MPI)")
    plt.xlabel("x")
    plt.ylabel("basis value")
    plt.grid(True)
    plt.tight_layout()

    plt.savefig("lagrange2_basis_np4.png", dpi=200, bbox_inches="tight")

In [None]:
# `files` is the google.colab helper imported in an earlier cell; this requests
# a download of the figure written by rank 0 in the previous cell.
files.download("lagrange2_basis_np4.png")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
%%fenicsx -np 4

import numpy as np
from mpi4py import MPI

import vtk
import pyvista

# Restrict VTK's stderr logging to error level and switch PyVista to
# off-screen rendering before any plotter is created.
vtk.vtkLogger.SetStderrVerbosity(vtk.vtkLogger.VERBOSITY_ERROR)
pyvista.OFF_SCREEN = True

import dolfinx
import ufl


def approximate_function(N: int, degree: int):
    """Interpolate g(x) = x + sin(pi x) cos(3 pi x) and render it.

    All ranks interpolate ``g`` on a distributed unit-interval mesh. Rank 0
    then repeats the interpolation on a serial (``COMM_SELF``) mesh and saves
    ``result.png`` showing the exact curve, the degree-``degree`` interpolant
    and the degree-1 nodal points.

    Args:
        N: number of cells of the unit-interval mesh.
        degree: polynomial degree of the Lagrange space.
    """
    world = MPI.COMM_WORLD

    def g(x):
        s = x[0]
        return s + np.sin(np.pi * s) * np.cos(3 * np.pi * s)

    def lifted_grid(space):
        """Interpolate g into `space` and lift its nodes to height u."""
        fn = dolfinx.fem.Function(space)
        fn.interpolate(g)
        grid = pyvista.UnstructuredGrid(*dolfinx.plot.vtk_mesh(space))
        grid.point_data["u"] = fn.x.array
        return grid.warp_by_scalar("u", normal=[0, 1, 0])

    # -------------------------
    # MPI FEM computation (all ranks)
    # -------------------------
    mesh = dolfinx.mesh.create_unit_interval(world, N)
    V = dolfinx.fem.functionspace(mesh, ("Lagrange", degree))
    u = dolfinx.fem.Function(V)
    u.interpolate(g)

    # -------------------------
    # Visualization (rank 0 only)
    # -------------------------
    if world.rank != 0:
        return

    # Reference (exact) curve sampled on a fine grid
    x_ref = np.linspace(0, 1, 1000)
    g_ref = g(x_ref.reshape(1, -1))

    # Serial copy of the mesh, used only for plotting
    vis_mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_SELF, N)

    high_order_space = dolfinx.fem.functionspace(vis_mesh, ("Lagrange", degree))
    warped_tessellate = lifted_grid(high_order_space).tessellate()

    linear_space = dolfinx.fem.functionspace(vis_mesh, ("Lagrange", 1))
    lin_warped = lifted_grid(linear_space)

    # -------------------------
    # PyVista plotting
    # -------------------------
    plotter = pyvista.Plotter(off_screen=True)

    # Exact curve as one polyline through all reference points
    exact_points = np.vstack([x_ref, g_ref, np.zeros_like(x_ref)]).T
    n_exact = len(exact_points)
    exact_poly = pyvista.PolyData(exact_points)
    exact_poly.lines = np.hstack([[n_exact], np.arange(n_exact)])
    plotter.add_mesh(exact_poly, color="red", line_width=4, label="Exact")

    # FEM approximation: curve plus nodal markers
    plotter.add_mesh(
        warped_tessellate,
        color="blue",
        style="wireframe",
        line_width=3,
        label="Approximation",
    )
    plotter.add_mesh(lin_warped, color="blue", style="points", point_size=10)

    plotter.view_xy()
    plotter.add_legend()
    plotter.screenshot("result.png")
    plotter.close()


# Five cells, quadratic elements.
approximate_function(N=5, degree=2)

In [None]:
%%fenicsx -np 4

import numpy as np
from mpi4py import MPI

import vtk
import pyvista

# Restrict VTK's stderr logging to error level and switch PyVista to
# off-screen rendering before any plotter is created.
vtk.vtkLogger.SetStderrVerbosity(vtk.vtkLogger.VERBOSITY_ERROR)
pyvista.OFF_SCREEN = True

import dolfinx

# -----------------------------------------
# Exact function
# -----------------------------------------
def g(x):
    """Exact function x + sin(pi x) cos(3 pi x), evaluated on the row x[0]."""
    s = x[0]
    return s + np.sin(np.pi * s) * np.cos(3 * np.pi * s)


# -----------------------------------------
# Main routine
# -----------------------------------------
def plot_for_degree(N, degree):
    """Render the degree-``degree`` interpolant of ``g`` and its error.

    All ranks interpolate ``g`` on a distributed unit-interval mesh. Rank 0
    then rebuilds the interpolant on a serial (``COMM_SELF``) mesh and writes
    two screenshots: ``solution_P{degree}.png`` (exact curve, interpolant,
    nodal points) and ``error_P{degree}.png`` (exact minus interpolant).

    Args:
        N: number of cells of the unit-interval mesh.
        degree: polynomial degree of the Lagrange space.
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    def polyline(pts):
        """Wrap an (n, 3) point array as one connected PyVista line."""
        poly = pyvista.PolyData(pts)
        poly.lines = np.hstack([[len(pts)], np.arange(len(pts))])
        return poly

    def nodal_grid(space):
        """Interpolate g into `space`; return a grid with nodal values as "u"."""
        fn = dolfinx.fem.Function(space)
        fn.interpolate(g)
        grid = pyvista.UnstructuredGrid(*dolfinx.plot.vtk_mesh(space))
        grid.point_data["u"] = fn.x.array
        return grid

    # -----------------------------
    # MPI FEM computation
    # -----------------------------
    mesh = dolfinx.mesh.create_unit_interval(comm, N)
    V = dolfinx.fem.functionspace(mesh, ("Lagrange", degree))
    u = dolfinx.fem.Function(V)
    u.interpolate(g)

    # -----------------------------
    # Visualization (rank 0 only)
    # -----------------------------
    if rank != 0:
        return

    # Reference grid
    x_ref = np.linspace(0.0, 1.0, 2000)
    g_ref = g(x_ref.reshape(1, -1)).ravel()

    # Serial mesh for visualization
    vis_mesh = dolfinx.mesh.create_unit_interval(
        MPI.COMM_SELF, N
    )

    # High-order FEM (smooth curve)
    V_vis = dolfinx.fem.functionspace(vis_mesh, ("Lagrange", degree))
    pv_grid = nodal_grid(V_vis)
    warped = pv_grid.warp_by_scalar("u", normal=[0, 1, 0])
    warped_tess = warped.tessellate()

    # Linear nodes (points)
    V_lin = dolfinx.fem.functionspace(vis_mesh, ("Lagrange", 1))
    lin_warped = nodal_grid(V_lin).warp_by_scalar("u", normal=[0, 1, 0])

    # -----------------------------
    # Exact curve (PolyData)
    # -----------------------------
    exact_poly = polyline(
        np.column_stack([x_ref, g_ref, np.zeros_like(x_ref)])
    )

    # -----------------------------
    # Error computation (FAST)
    # -----------------------------
    x_vis = np.asarray(pv_grid.points[:, 0])
    u_vis_vals = np.asarray(pv_grid.point_data["u"])

    # np.interp requires increasing sample positions and does not check for
    # it; grid points follow dof numbering, so sort them by x first.
    order = np.argsort(x_vis)
    # NOTE: np.interp joins the nodal values linearly, so for degree > 1 the
    # curve between nodes is only an approximation of the interpolant.
    fem_interp = np.interp(x_ref, x_vis[order], u_vis_vals[order])
    error = g_ref - fem_interp

    error_poly = polyline(
        np.column_stack([x_ref, error, np.zeros_like(x_ref)])
    )

    # =====================================================
    # FIGURE 1: Exact vs FEM
    # =====================================================
    plotter = pyvista.Plotter(
        off_screen=True,
        window_size=(1600, 450)
    )
    plotter.set_background("white")
    plotter.enable_parallel_projection()
    plotter.hide_axes()

    plotter.add_mesh(
        exact_poly,
        color="black",
        line_width=4,
        label="Exact"
    )

    plotter.add_mesh(
        warped_tess,
        color="blue",
        style="wireframe",
        line_width=2,
        label=f"FEM P{degree}"
    )

    plotter.add_mesh(
        lin_warped,
        color="blue",
        style="points",
        point_size=8
    )

    plotter.add_legend(size=(0.25, 0.3), loc="upper left")
    plotter.view_xy()
    plotter.screenshot(f"solution_P{degree}.png")
    plotter.close()

    # =====================================================
    # FIGURE 2: Error only
    # =====================================================
    plotter_err = pyvista.Plotter(
        off_screen=True,
        window_size=(1600, 350)
    )
    plotter_err.set_background("white")
    plotter_err.enable_parallel_projection()
    plotter_err.hide_axes()

    plotter_err.add_mesh(
        error_poly,
        color="red",
        line_width=3,
        label="Error (Exact ‚àí FEM)"
    )

    plotter_err.add_legend()
    plotter_err.view_xy()
    plotter_err.screenshot(f"error_P{degree}.png")
    plotter_err.close()


# -----------------------------------------
# Run for several degrees
# -----------------------------------------
# Produce the solution and error figures for P1, P2 and P3 on an 8-cell mesh.
for degree in (1, 2, 3):
    plot_for_degree(N=8, degree=degree)

[0m[33m2025-12-30 15:59:49.298 (   4.780s) [    7F12B6734440]vtkXOpenGLRenderWindow.:1458  WARN| bad X server connection. DISPLAY=[0m


In [None]:
%%fenicsx -np 4

import numpy as np
from mpi4py import MPI
import dolfinx
import ufl


# ---------------------------------
# Exact function (UFL + NumPy Í≤∏Ïö©)
# ---------------------------------
def g_expr(x):
    """UFL form of the exact function, built from the first coordinate of x."""
    s = x[0]
    return s + ufl.sin(ufl.pi * s) * ufl.cos(3 * ufl.pi * s)


def g_numpy(x):
    """NumPy form of the exact function x + sin(pi x) cos(3 pi x)."""
    oscillation = np.sin(np.pi * x) * np.cos(3 * np.pi * x)
    return x + oscillation


# ---------------------------------
# Error computation
# ---------------------------------
def compute_errors(N, degree):
    """Interpolation error of the exact function on an N-cell unit interval.

    Collective over ``MPI.COMM_WORLD``: each rank integrates its share and the
    shares are summed across the mesh communicator before the square root.

    Returns:
        ``(h, L2_error, H1_error)`` with ``h = 1/N``; ``H1_error`` is the
        H1 seminorm (gradient part only).
    """
    mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, N)
    V = dolfinx.fem.functionspace(mesh, ("Lagrange", degree))

    u = dolfinx.fem.Function(V)
    u.interpolate(lambda x: g_numpy(x[0]))

    # Pointwise difference between the interpolant and the exact expression.
    diff = u - g_expr(ufl.SpatialCoordinate(mesh))

    def local_integral(integrand):
        return dolfinx.fem.assemble_scalar(dolfinx.fem.form(integrand * ufl.dx))

    # Assemble both local contributions first, then reduce.
    l2_local = local_integral(diff ** 2)
    h1_local = local_integral(ufl.inner(ufl.grad(diff), ufl.grad(diff)))

    L2_error = np.sqrt(mesh.comm.allreduce(l2_local, op=MPI.SUM))
    H1_error = np.sqrt(mesh.comm.allreduce(h1_local, op=MPI.SUM))

    return 1.0 / N, L2_error, H1_error


# ---------------------------------
# Run convergence study
# ---------------------------------
def convergence_table(degree):
    """Print an h-refinement table of L2/H1 errors and observed rates.

    Every rank runs the error computation (it is collective); only rank 0
    prints. The rate column compares each row with the previous one.
    """
    results = [compute_errors(N, degree) for N in (10, 20, 40, 80)]

    if MPI.COMM_WORLD.rank != 0:
        return

    print(f"\n=== P{degree} FEM convergence ===")
    print(" h        L2 error     rate      H1 error     rate")
    print("--------------------------------------------------")

    previous = None
    for h, L2, H1 in results:
        if previous is None:
            # First row has nothing to compare against.
            print(f"{h:6.4f}  {L2:10.3e}    ---   {H1:10.3e}    ---")
        else:
            h0, L20, H10 = previous
            refinement = np.log(h0 / h)
            rate_L2 = np.log(L20 / L2) / refinement
            rate_H1 = np.log(H10 / H1) / refinement

            print(
                f"{h:6.4f}  {L2:10.3e}  {rate_L2:6.2f}  "
                f"{H1:10.3e}  {rate_H1:6.2f}"
            )
        previous = (h, L2, H1)


# ---------------------------------
# Execute
# ---------------------------------
# One table per polynomial degree.
for degree in (1, 2, 3):
    convergence_table(degree=degree)


=== P1 FEM convergence ===
 h        L2 error     rate      H1 error     rate
--------------------------------------------------
0.1000   5.083e-02    ---    1.620e+00    ---
0.0500   1.303e-02    1.96   8.254e-01    0.97
0.0250   3.277e-03    1.99   4.147e-01    0.99
0.0125   8.204e-04    2.00   2.076e-01    1.00

=== P2 FEM convergence ===
 h        L2 error     rate      H1 error     rate
--------------------------------------------------
0.1000   3.978e-03    ---    2.579e-01    ---
0.0500   5.055e-04    2.98   6.552e-02    1.98
0.0250   6.344e-05    2.99   1.645e-02    1.99
0.0125   7.939e-06    3.00   4.116e-03    2.00

=== P3 FEM convergence ===
 h        L2 error     rate      H1 error     rate
--------------------------------------------------
0.1000   2.883e-04    ---    2.733e-02    ---
0.0500   1.825e-05    3.98   3.462e-03    2.98
0.0250   1.144e-06    4.00   4.343e-04    3.00
0.0125   7.158e-08    4.00   5.433e-05    3.00


In [None]:
%%fenicsx -np 4

import numpy as np
from mpi4py import MPI
import dolfinx
import ufl
import matplotlib.pyplot as plt

# World communicator and this process's rank. The functions below read `comm`
# as a global, and the plotting section at the end is gated on `rank == 0`.
comm = MPI.COMM_WORLD
rank = comm.rank


# ----------------------------
# Exact function
# ----------------------------
def g_expr(x):
    """UFL form of the exact function, built from the first coordinate of x."""
    s = x[0]
    return s + ufl.sin(ufl.pi * s) * ufl.cos(3 * ufl.pi * s)


def g_numpy(x):
    """NumPy form of the exact function x + sin(pi x) cos(3 pi x)."""
    oscillation = np.sin(np.pi * x) * np.cos(3 * np.pi * x)
    return x + oscillation


# ----------------------------
# Error computation (ALL ranks)
# ----------------------------
def compute_errors(N, degree):
    """Interpolation error of the exact function on an N-cell unit interval.

    Collective over the module-level ``comm``: every rank must call it.

    Returns:
        ``(h, L2, H1)`` with ``h = 1/N``; ``H1`` is the gradient-only
        (seminorm) error.
    """
    mesh = dolfinx.mesh.create_unit_interval(comm, N)
    space = dolfinx.fem.functionspace(mesh, ("Lagrange", degree))

    interpolant = dolfinx.fem.Function(space)
    interpolant.interpolate(lambda x: g_numpy(x[0]))

    residual = interpolant - g_expr(ufl.SpatialCoordinate(mesh))
    grad_residual = ufl.grad(residual)

    l2_form = dolfinx.fem.form(residual**2 * ufl.dx)
    h1_form = dolfinx.fem.form(ufl.inner(grad_residual, grad_residual) * ufl.dx)

    # Local integrals first, then the global sums.
    l2_local = dolfinx.fem.assemble_scalar(l2_form)
    h1_local = dolfinx.fem.assemble_scalar(h1_form)

    l2_global = np.sqrt(comm.allreduce(l2_local, MPI.SUM))
    h1_global = np.sqrt(comm.allreduce(h1_local, MPI.SUM))

    return 1.0 / N, l2_global, h1_global


# ----------------------------
# Convergence data (ALL ranks)
# ----------------------------
def convergence_data(degree):
    """Collect (h, L2, H1) over a fixed sequence of mesh refinements.

    Returns three NumPy arrays of equal length, one entry per mesh.
    Collective: every rank must call it.
    """
    rows = [compute_errors(N, degree) for N in (10, 20, 40, 80, 160)]

    # Transpose the list of (h, L2, H1) rows into three columns.
    h, L2, H1 = (np.array(column) for column in zip(*rows))
    return h, L2, H1


# ----------------------------
# Compute first (ALL ranks)
# ----------------------------
# Degree -> (h, L2, H1) arrays; computed on every rank before any plotting.
results = {p: convergence_data(p) for p in [1, 2, 3]}


# ----------------------------
# Plot (rank 0 only)
# ----------------------------
if rank == 0:
    fig = plt.figure(figsize=(7, 6))
    ax = plt.gca()

    markers = {1: "o", 2: "s", 3: "^"}
    for p, marker in markers.items():
        h, L2, H1 = results[p]

        # Measured errors: solid for L2, dashed for H1
        ax.loglog(h, L2, marker + "-", label=f"P{p}  L2")
        ax.loglog(h, H1, marker + "--", label=f"P{p}  H1")

        # Reference slopes anchored at the coarsest mesh:
        # order p+1 for L2 and order p for H1
        scaled_h = h / h[0]
        ax.loglog(h, L2[0] * scaled_h ** (p + 1), "k:", alpha=0.4)
        ax.loglog(h, H1[0] * scaled_h ** p, "k--", alpha=0.4)

    # Finer meshes to the right
    ax.invert_xaxis()
    ax.set_xlabel("Mesh size h")
    ax.set_ylabel("Error")
    ax.set_title("FEM convergence (L2 and H1)")
    ax.legend()
    ax.grid(True, which="both", ls=":")

    fig.tight_layout()
    fig.savefig("convergence.png", dpi=300)

