# Poisson Equation Solver with Finite Elements

This notebook demonstrates solving the 2D Poisson equation with the finite element method using FEniCSx (dolfinx). The equation is solved on the unit square with homogeneous Dirichlet boundary conditions (u = 0 on the boundary). Both h-refinement (increasing mesh resolution) and p-refinement (increasing polynomial degree) are used to study convergence. The code computes the relative L2 error between the numerical and analytical solutions, visualizes the results, and creates a GIF that shows how the solution changes with refinement.

In [ ]:
import matplotlib.pyplot as plt
import numpy as np
from petsc4py import PETSc
from mpi4py import MPI
from dolfinx import fem, mesh, plot
import ufl
import pyvista
import matplotlib as mpl

# PETSc-backed helpers from dolfinx: matrix/vector assembly and Dirichlet BC application
from dolfinx.fem.petsc import assemble_matrix, assemble_vector, create_vector, apply_lifting, set_bc

## Importing Libraries

The code begins by importing the required libraries:
- `matplotlib.pyplot` and `numpy` for plotting and numerical operations.
- `petsc4py` and `mpi4py` for parallel computing and linear algebra.
- `dolfinx` modules for finite element mesh creation, function spaces, and solvers.
- `ufl` for defining variational forms.
- `pyvista` for 3D visualization.
- `matplotlib` for colormap handling.
- Specific functions from `dolfinx.fem.petsc` are imported for assembling matrices and vectors and for applying Dirichlet boundary conditions to the right-hand side (`apply_lifting`, `set_bc`); the linear system itself is solved with a PETSc `KSP` object.

In [ ]:
# Analytical (exact) solution u(x, y) = sin(pi x) sin(pi y)
def analytical_solution(x):
    return np.sin(np.pi * x[0]) * np.sin(np.pi * x[1])

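# Compute the relative L2 norm of the error
def compute_l2_error(u_numerical, V, domain):
    # Evaluate the exact solution symbolically with UFL instead of interpolating it
    # into V: an interpolant in the same space as u_numerical can underestimate the error
    x = ufl.SpatialCoordinate(domain)
    u_exact = ufl.sin(ufl.pi * x[0]) * ufl.sin(ufl.pi * x[1])
    error_form = fem.form(ufl.inner(u_numerical - u_exact, u_numerical - u_exact) * ufl.dx)
    norm_form = fem.form(ufl.inner(u_exact, u_exact) * ufl.dx)
    # assemble_scalar returns only this process's contribution, so sum over all ranks
    error_l2 = np.sqrt(domain.comm.allreduce(fem.assemble_scalar(error_form), op=MPI.SUM))
    norm_exact = np.sqrt(domain.comm.allreduce(fem.assemble_scalar(norm_form), op=MPI.SUM))
    relative_error = error_l2 / norm_exact if norm_exact > 0 else error_l2
    return relative_error

## Defining the Analytical Solution and Error Computation

Two key functions are defined:
1. `analytical_solution(x)`: Evaluates u(x, y) = sin(pi x) sin(pi y) at an array of points. This function satisfies -Delta u = f with f = 2 pi^2 sin(pi x) sin(pi y) and vanishes on the boundary of the unit square, so it is the exact solution of the problem solved below. It is used to interpolate the exact solution for visualization.
2. `compute_l2_error(u_numerical, V, domain)`: Calculates the relative L2 error, i.e. the L2 norm of the difference between `u_numerical` and the exact solution divided by the L2 norm of the exact solution. The exact solution is written as a UFL expression on `domain` rather than interpolated into `V`, because interpolating it into the same space as the numerical solution can hide part of the discretization error. `V` is kept in the signature for compatibility but is not needed. Since `fem.assemble_scalar` returns only the local contribution of each MPI process, both integrals are summed with `allreduce`.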
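Both claims above can be checked without dolfinx. The NumPy-only sketch below uses a five-point finite-difference stencil to confirm that -Delta u equals f for the analytical solution, and a midpoint rule on a uniform grid to show what the relative L2 error measures. The helpers `fd_neg_laplacian` and `relative_l2_error` are hypothetical, and the midpoint rule merely stands in for the quadrature that `fem.assemble_scalar` performs.

```python
import numpy as np

def u_exact(x, y):
    return np.sin(np.pi * x) * np.sin(np.pi * y)

def f_source(x, y):
    return 2 * np.pi**2 * np.sin(np.pi * x) * np.sin(np.pi * y)

def fd_neg_laplacian(u, x, y, h=1e-3):
    # Five-point stencil approximation of -(u_xx + u_yy) at (x, y)
    return -(u(x + h, y) + u(x - h, y) + u(x, y + h) + u(x, y - h) - 4 * u(x, y)) / h**2

def relative_l2_error(u_num, u_ref, n=200):
    # Midpoint rule on an n x n grid of cell centres in the unit square
    s = (np.arange(n) + 0.5) / n
    X, Y = np.meshgrid(s, s, indexing="ij")
    diff = u_num(X, Y) - u_ref(X, Y)
    err = np.sqrt(np.sum(diff**2) / n**2)
    norm = np.sqrt(np.sum(u_ref(X, Y) ** 2) / n**2)
    return err / norm

# -Delta u and f agree at an interior point (up to the O(h^2) stencil error)
print(fd_neg_laplacian(u_exact, 0.3, 0.7), f_source(0.3, 0.7))

# A "numerical" solution with a 1% amplitude error has a relative L2 error of 0.01
print(relative_l2_error(lambda x, y: 1.01 * u_exact(x, y), u_exact))
```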

In [ ]:
# Solve the Poisson equation and return the solution for visualization
def solve_poisson(nx, ny, degree, cell_type, case_name):
    # Create mesh
    domain = mesh.create_rectangle(MPI.COMM_WORLD, [np.array([0, 0]), np.array([1, 1])], [nx, ny], cell_type)
    
    # Define function space
    V = fem.functionspace(domain, ("Lagrange", degree))
    
    # Boundary condition (u = 0 on the boundary)
    fdim = domain.topology.dim - 1
    boundary_facets = mesh.locate_entities_boundary(
        domain, fdim, lambda x: np.full(x.shape[1], True, dtype=bool))
    bc = fem.dirichletbc(PETSc.ScalarType(0), fem.locate_dofs_topological(V, fdim, boundary_facets), V)
    
    # Define variational problem: -Delta u = f
    u, v = ufl.TrialFunction(V), ufl.TestFunction(V)
    x = ufl.SpatialCoordinate(domain)  # Define spatial coordinates
    f = 2 * np.pi**2 * ufl.sin(np.pi * x[0]) * ufl.sin(np.pi * x[1])  # Right-hand side
    a = ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx  # Bilinear form
    L = f * v * ufl.dx  # Linear form
    
    # Assemble system
    bilinear_form = fem.form(a)
    linear_form = fem.form(L)
    A = assemble_matrix(bilinear_form, bcs=[bc])
    A.assemble()
    b = create_vector(linear_form)
    
    # Assemble right-hand side
    with b.localForm() as loc_b:
        loc_b.set(0)
    assemble_vector(b, linear_form)
    apply_lifting(b, [bilinear_form], [[bc]])
    b.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
    set_bc(b, [bc])
    
    # Solve
    uh = fem.Function(V)
    solver = PETSc.KSP().create(domain.comm)
    solver.setOperators(A)
    solver.setType(PETSc.KSP.Type.PREONLY)
    solver.getPC().setType(PETSc.PC.Type.LU)
    solver.solve(b, uh.x.petsc_vec)
    uh.x.scatter_forward()
    
    # Compute L2 error
    relative_error = compute_l2_error(uh, V, domain)
    
    return uh, V, domain, relative_error

## Solving the Poisson Equation

The `solve_poisson` function implements the finite element solution of the Poisson equation -Delta u = f on a rectangular mesh. Key steps include:
- **Mesh Creation**: A rectangular mesh is created with `nx` x `ny` elements, using the specified `cell_type` (triangle or quadrilateral).
- **Function Space**: A Lagrange finite element space of degree `degree` is defined.
- **Boundary Conditions**: Dirichlet boundary conditions (u = 0) are applied on the domain boundary.
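- **Variational Form**: Multiplying by a test function v and integrating by parts gives the weak form: find u in V such that the integral of grad(u) · grad(v) over the domain equals the integral of f v for every test function v. The source term f = 2 pi^2 sin(pi x) sin(pi y) is the one that corresponds to the analytical solution.
- **Assembly**: The bilinear and linear forms are assembled into a PETSc matrix `A` and vector `b`. Passing `bcs=[bc]` to `assemble_matrix` zeros the rows and columns of the Dirichlet DOFs and places 1 on their diagonal; `apply_lifting` moves the known boundary contributions to the right-hand side, the ghost update sums contributions shared between MPI processes, and `set_bc` writes the prescribed boundary values into `b`.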
- **Solving**: The linear system is solved using a direct LU solver from PETSc.
- **Error Computation**: The relative L2 error is computed using the `compute_l2_error` function.
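The boundary-condition handling in the Assembly step can be mimicked on a small dense system. The NumPy sketch below assumes that `assemble_matrix(..., bcs=...)`, `apply_lifting` and `set_bc` together amount to symmetric elimination of the Dirichlet DOFs; `apply_dirichlet_symmetric` is a hypothetical helper, not part of dolfinx. Because this notebook prescribes u = 0 on the whole boundary, the lifting term vanishes there, so the example uses non-zero boundary values to make its effect visible.

```python
import numpy as np

def apply_dirichlet_symmetric(A, b, bc_dofs, bc_values):
    """Hypothetical helper: symmetric elimination of Dirichlet DOFs (dense, serial)."""
    A_mod = np.array(A, dtype=float)  # copies A
    b_mod = np.array(b, dtype=float)  # copies b
    g = np.zeros(A_mod.shape[0])
    g[bc_dofs] = bc_values
    # "Lifting": subtract the contribution of the known boundary values from b
    b_mod -= A_mod @ g
    # Zero the Dirichlet rows and columns and place 1 on their diagonal
    A_mod[bc_dofs, :] = 0.0
    A_mod[:, bc_dofs] = 0.0
    A_mod[bc_dofs, bc_dofs] = 1.0
    # "set_bc": write the prescribed values into the right-hand side
    b_mod[bc_dofs] = bc_values
    return A_mod, b_mod

# 1D P1 stiffness matrix for -u'' = 0 on [0, 1] with 4 elements
n_elems = 4
h = 1.0 / n_elems
A = np.zeros((n_elems + 1, n_elems + 1))
for e in range(n_elems):
    A[e:e + 2, e:e + 2] += np.array([[1.0, -1.0], [-1.0, 1.0]]) / h
b = np.zeros(n_elems + 1)

# Boundary values u(0) = 1 and u(1) = 3, so the exact solution is 1 + 2x
A_mod, b_mod = apply_dirichlet_symmetric(A, b, [0, n_elems], [1.0, 3.0])
u = np.linalg.solve(A_mod, b_mod)
print(u)                            # nodal values of 1 + 2x
print(np.allclose(A_mod, A_mod.T))  # the modified matrix stays symmetric
```

Keeping the matrix symmetric is what allows symmetric solvers such as Cholesky or conjugate gradients to be used on the modified system.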

The function returns the numerical solution (`uh`), function space (`V`), domain, and relative error for further processing.
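The same sequence (assemble the stiffness matrix and load vector, impose homogeneous Dirichlet conditions, solve with a direct LU factorization) can be followed by hand in one dimension. The sketch below is a hypothetical NumPy analogue, not the notebook's code: P1 elements for -u'' = pi^2 sin(pi x) on [0, 1] with u(0) = u(1) = 0, where the boundary DOFs are simply removed from the system and `np.linalg.solve` (an LU factorization via LAPACK) plays the role of the PETSc PREONLY/LU solver.

```python
import numpy as np

def solve_poisson_1d(n_elems):
    """Hypothetical 1D analogue: P1 FEM for -u'' = pi^2 sin(pi x), u(0) = u(1) = 0."""
    n_nodes = n_elems + 1
    x = np.linspace(0.0, 1.0, n_nodes)
    h = 1.0 / n_elems
    K = np.zeros((n_nodes, n_nodes))  # global stiffness matrix
    M = np.zeros((n_nodes, n_nodes))  # global mass matrix
    k_loc = np.array([[1.0, -1.0], [-1.0, 1.0]]) / h
    m_loc = np.array([[2.0, 1.0], [1.0, 2.0]]) * h / 6.0
    # Assembly: add each element's local matrices into the global ones
    for e in range(n_elems):
        idx = [e, e + 1]
        K[np.ix_(idx, idx)] += k_loc
        M[np.ix_(idx, idx)] += m_loc
    # Load vector, with f interpolated onto the P1 space
    b = M @ (np.pi**2 * np.sin(np.pi * x))
    # Homogeneous Dirichlet conditions: keep only the interior DOFs
    interior = np.arange(1, n_nodes - 1)
    u = np.zeros(n_nodes)
    # np.linalg.solve factorizes the matrix with LU (LAPACK gesv)
    u[interior] = np.linalg.solve(K[np.ix_(interior, interior)], b[interior])
    return x, u

for n_elems in (8, 16, 32):
    x, u = solve_poisson_1d(n_elems)
    err = np.max(np.abs(u - np.sin(np.pi * x)))
    print(f"{n_elems:3d} elements: max nodal error = {err:.2e}")
```

Halving h should reduce the maximum nodal error by roughly a factor of four, the O(h^2) behaviour expected for linear elements with an interpolated load.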

In [ ]:
# Visualize the analytical solution as a static image
def visualize_analytical():
    domain = mesh.create_rectangle(MPI.COMM_WORLD, [np.array([0, 0]), np.array([1, 1])], [50, 50], mesh.CellType.quadrilateral)
    V = fem.functionspace(domain, ("Lagrange", 1))
    u_analytical = fem.Function(V)
    u_analytical.interpolate(analytical_solution)
    
    pyvista.start_xvfb()
    grid = pyvista.UnstructuredGrid(*plot.vtk_mesh(V))
    grid.point_data["u_analytical"] = u_analytical.x.array
    warped = grid.warp_by_scalar("u_analytical", factor=1.0)
    
    plotter = pyvista.Plotter(off_screen=True)
    viridis = mpl.colormaps.get_cmap("viridis").resampled(25)
    sargs = dict(title_font_size=20, label_font_size=15, fmt="%.2e", color="black",
                 position_x=0.1, position_y=0.8, width=0.8, height=0.1)
    plotter.add_mesh(warped, show_edges=True, lighting=False, cmap=viridis,
                     scalar_bar_args=sargs, clim=[0, 1])
    plotter.view_xy()
    plotter.camera.zoom(1.5)
    plotter.add_title("Analytical Solution", font_size=12)
    plotter.show(screenshot="solution_analytical.png")
    plotter.close()

## Visualizing the Analytical Solution

The `visualize_analytical` function creates a static visualization of the analytical solution using PyVista:
- A 50x50 quadrilateral mesh is created, and the analytical solution is interpolated onto a linear Lagrange function space.
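- The interpolated solution is used to warp the mesh, displacing each point out of the xy-plane by the solution value. Because the camera is set with `view_xy()`, the plot is seen from directly above, so the solution value is conveyed mainly by color.
- The plot uses a 25-level resampled Viridis colormap with the color range fixed to [0, 1], adds a scalar bar and a title, and is saved as `solution_analytical.png`.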
- The visualization is performed off-screen to support headless environments.

In [ ]:
# Define cases
h_refinement_cases = [
    (5, 5, 1, mesh.CellType.quadrilateral, "h-ref: 5x5, p=1"),
    (20, 20, 1, mesh.CellType.quadrilateral, "h-ref: 20x20, p=1"),
    (50, 50, 1, mesh.CellType.quadrilateral, "h-ref: 50x50, p=1"),
    (100, 100, 1, mesh.CellType.quadrilateral, "h-ref: 100x100, p=1")
]

p_refinement_cases = [
    (20, 20, 3, mesh.CellType.triangle, "p-ref: degree=3, tri"),
    (20, 20, 6, mesh.CellType.triangle, "p-ref: degree=6, tri"),
    (20, 20, 4, mesh.CellType.quadrilateral, "p-ref: degree=4, quad"),
    # (20, 20, 8, mesh.CellType.quadrilateral, "p-ref: degree=8, quad")  # Excluded to avoid visualization issues
]

all_cases = h_refinement_cases + p_refinement_cases

# Store results
relative_errors = []
case_names = []
solutions = []

# Run simulations for all cases and store solutions
for nx, ny, degree, cell_type, case_name in all_cases:
    print(f"Running case: {case_name}")
    uh, V, domain, rel_err = solve_poisson(nx, ny, degree, cell_type, case_name)
    relative_errors.append(rel_err)
    case_names.append(case_name)
    solutions.append((uh, V, domain, case_name, rel_err))

## Defining and Running Refinement Cases

The code defines two sets of test cases:
- **h-refinement**: Uses a fixed polynomial degree (p=1) and increases mesh resolution (5x5, 20x20, 50x50, 100x100).
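- **p-refinement**: Uses a fixed mesh (20x20) and higher polynomial degrees (3 and 6 on triangles; 4 on quadrilaterals). The 20x20, p=1 case from the h-refinement set acts as the degree-1 reference. Because two cell types are mixed, these cases show the effect of raising the degree rather than forming a single refinement sequence.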

Each case is run using the `solve_poisson` function, and the results (solution, function space, domain, case name, and error) are stored for visualization and comparison.
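Comparing h- and p-refinement is fairest at similar cost, so it helps to know how many degrees of freedom each case uses. The sketch below counts them (including the boundary DOFs that the Dirichlet condition later fixes), assuming a `create_rectangle` mesh with quadrilaterals or with the default right-diagonal triangles, for which the degree-p Lagrange nodes form a (p*nx + 1) x (p*ny + 1) lattice. `num_lagrange_dofs` is a hypothetical helper; inside the notebook, `V.dofmap.index_map.size_global` should report the same count for these scalar spaces.

```python
def num_lagrange_dofs(nx, ny, degree):
    """Hypothetical helper: scalar Lagrange DOFs on an nx x ny rectangle mesh.

    Assumes quadrilateral cells, or triangles with the default 'right'
    diagonal; in both cases the nodes form a (degree*nx + 1) x (degree*ny + 1)
    lattice, boundary nodes included.
    """
    return (degree * nx + 1) * (degree * ny + 1)

# (nx, ny, degree, label) mirroring the notebook's cases
cases = [
    (5, 5, 1, "h-ref: 5x5, p=1"),
    (20, 20, 1, "h-ref: 20x20, p=1"),
    (50, 50, 1, "h-ref: 50x50, p=1"),
    (100, 100, 1, "h-ref: 100x100, p=1"),
    (20, 20, 3, "p-ref: degree=3, tri"),
    (20, 20, 6, "p-ref: degree=6, tri"),
    (20, 20, 4, "p-ref: degree=4, quad"),
]
for nx, ny, degree, label in cases:
    print(f"{label}: {num_lagrange_dofs(nx, ny, degree)} DOFs")
```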

In [ ]:
# Create GIF animation of refinement progression
pyvista.start_xvfb()
plotter = pyvista.Plotter(off_screen=True)
plotter.open_gif("refinement_evolution.gif", fps=2)

viridis = mpl.colormaps.get_cmap("viridis").resampled(25)
sargs = dict(title_font_size=20, label_font_size=15, fmt="%.2e", color="black",
             position_x=0.1, position_y=0.8, width=0.8, height=0.1)

for uh, V, domain, case_name, rel_err in solutions:
    grid = pyvista.UnstructuredGrid(*plot.vtk_mesh(V))
    grid.point_data["uh"] = uh.x.array
    warped = grid.warp_by_scalar("uh", factor=1.0)
    
    plotter.clear()
    plotter.add_mesh(warped, show_edges=True, lighting=False, cmap=viridis,
                     scalar_bar_args=sargs, clim=[0, 1])
    plotter.view_xy()
    plotter.camera.zoom(1.5)
    plotter.add_title(f"{case_name}\nL2 Error: {rel_err:.2e}", font_size=12)
    plotter.write_frame()

plotter.close()

## Creating a GIF Animation

A GIF animation (`refinement_evolution.gif`) is created to visualize the progression of numerical solutions across all refinement cases:
- Each frame shows the numerical solution for a specific case, warped by the solution values.
- The title includes the case name and the L2 error.
- The animation runs at 2 frames per second and uses the same Viridis colormap and a fixed color range of [0, 1] in every frame, so the frames are directly comparable.

In [ ]:
# Visualize the analytical solution
visualize_analytical()

# Plot comparison of relative errors
plt.figure(figsize=(10, 6))
bars = plt.bar(range(len(case_names)), relative_errors, tick_label=case_names)
plt.xlabel('Case')
plt.ylabel('Relative L2 Error')
plt.title('Relative L2 Error Comparison for h- and p-Refinement')
plt.yscale('log')  # Use logarithmic scale for better visualization
plt.xticks(rotation=45, ha='right')
plt.grid(True, which="both", ls="--")

# Add error values on top of bars
for bar, error in zip(bars, relative_errors):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height(), f'{error:.2e}',
             ha='center', va='bottom')

plt.tight_layout()
plt.savefig('relative_l2_error_comparison_poisson.png')
plt.show()

# Print relative errors
print("\nRelative L2 Errors:")
for name, error in zip(case_names, relative_errors):
    print(f"{name}: {error:.2e}")

## Final Visualizations and Output

The code concludes with:
- **Analytical Solution Visualization**: The `visualize_analytical` function is called to generate `solution_analytical.png`.
- **Error Comparison Plot**: A bar plot compares the relative L2 errors across all cases, using a logarithmic scale for clarity. Error values are annotated above each bar. The plot is saved as `relative_l2_error_comparison_poisson.png`.
- **Error Output**: The relative L2 errors for each case are printed to the console for reference.
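The bar chart compares errors case by case; the observed order of convergence between successive h-refinement levels condenses the same data into one number per step. For p = 1 Lagrange elements and this smooth solution, standard a priori estimates predict an L2 error of order h^2, so rates near 2 would be expected. The helper `observed_rates` below is a hypothetical addition, and the check uses synthetic errors rather than the notebook's output. Since `all_cases` lists `h_refinement_cases` first, the first four entries of `relative_errors` are the ones to pair with h = 1/nx.

```python
import numpy as np

def observed_rates(h_values, errors):
    """Hypothetical helper: observed order p from e ~ C h^p between successive levels."""
    h = np.asarray(h_values, dtype=float)
    e = np.asarray(errors, dtype=float)
    return np.log(e[:-1] / e[1:]) / np.log(h[:-1] / h[1:])

# Mesh sizes of the h-refinement cases (h = 1/nx)
h_demo = [1 / 5, 1 / 20, 1 / 50, 1 / 100]

# Synthetic errors proportional to h^2 (not the notebook's results) give rates of 2
print(observed_rates(h_demo, [0.3 * h**2 for h in h_demo]))

# With the notebook's output available:
# print(observed_rates(h_demo, relative_errors[:4]))
```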