# Poisson Solver

Solves the 2D Poisson equation on the unit square using FEniCSx (DOLFINx), studying both h-refinement and p-refinement. Computes the relative L2 error against the analytical solution and creates visualizations.

In [ ]:
import matplotlib.pyplot as plt
import numpy as np
from petsc4py import PETSc
from mpi4py import MPI
from dolfinx import fem, mesh, plot
import ufl
import pyvista
import matplotlib as mpl

# Import the necessary functions, including assemble_matrix
from dolfinx.fem.petsc import assemble_matrix, assemble_vector, create_vector, apply_lifting, set_bc

## Libraries

Imports tools for plotting, parallel computing (MPI/PETSc), finite elements, and visualization.

In [ ]:
# Exact solution of the manufactured Poisson problem
def analytical_solution(x):
    """Evaluate u(x, y) = sin(pi x) * sin(pi y) pointwise.

    ``x[0]`` holds the x-coordinates and ``x[1]`` the y-coordinates of the
    evaluation points; this is the callable handed to
    ``Function.interpolate`` elsewhere in the notebook.
    """
    xs, ys = x[0], x[1]
    return np.sin(np.pi * xs) * np.sin(np.pi * ys)

# Compute the L2 norm of the error
def compute_l2_error(u_numerical, V, domain):
    u_analytical = fem.Function(V)
    u_analytical.interpolate(analytical_solution)
    error_form = ufl.inner(u_numerical - u_analytical, u_numerical - u_analytical) * ufl.dx
    error_l2 = np.sqrt(fem.assemble_scalar(fem.form(error_form)))
    norm_analytical = np.sqrt(fem.assemble_scalar(fem.form(ufl.inner(u_analytical, u_analytical) * ufl.dx)))
    relative_error = error_l2 / norm_analytical if norm_analytical > 0 else error_l2
    return relative_error

## Solution and Error

Defines the analytical solution and computes the relative L2 error between the numerical and analytical solutions.

In [ ]:
# Solve the Poisson equation and return the solution for visualization
def solve_poisson(nx, ny, degree, cell_type, case_name):
    # Create mesh
    domain = mesh.create_rectangle(MPI.COMM_WORLD, [np.array([0, 0]), np.array([1, 1])], [nx, ny], cell_type)
    
    # Define function space
    V = fem.functionspace(domain, ("Lagrange", degree))
    
    # Boundary condition (u = 0 on the boundary)
    fdim = domain.topology.dim - 1
    boundary_facets = mesh.locate_entities_boundary(
        domain, fdim, lambda x: np.full(x.shape[1], True, dtype=bool))
    bc = fem.dirichletbc(PETSc.ScalarType(0), fem.locate_dofs_topological(V, fdim, boundary_facets), V)
    
    # Define variational problem: -Delta u = f
    u, v = ufl.TrialFunction(V), ufl.TestFunction(V)
    x = ufl.SpatialCoordinate(domain)  # Define spatial coordinates
    f = 2 * np.pi**2 * ufl.sin(np.pi * x[0]) * ufl.sin(np.pi * x[1])  # Right-hand side
    a = ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx  # Bilinear form
    L = f * v * ufl.dx  # Linear form
    
    # Assemble system
    bilinear_form = fem.form(a)
    linear_form = fem.form(L)
    A = assemble_matrix(bilinear_form, bcs=[bc])
    A.assemble()
    b = create_vector(linear_form)
    
    # Assemble right-hand side
    with b.localForm() as loc_b:
        loc_b.set(0)
    assemble_vector(b, linear_form)
    apply_lifting(b, [bilinear_form], [[bc]])
    b.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
    set_bc(b, [bc])
    
    # Solve
    uh = fem.Function(V)
    solver = PETSc.KSP().create(domain.comm)
    solver.setOperators(A)
    solver.setType(PETSc.KSP.Type.PREONLY)
    solver.getPC().setType(PETSc.PC.Type.LU)
    solver.solve(b, uh.x.petsc_vec)
    uh.x.scatter_forward()
    
    # Compute L2 error
    relative_error = compute_l2_error(uh, V, domain)
    
    return uh, V, domain, relative_error

## Solver

Solves the Poisson equation on a mesh with nx × ny cells, applies homogeneous Dirichlet boundary conditions, and computes the solution together with its relative L2 error.

In [ ]:
# Visualize the analytical solution as a static image
def visualize_analytical():
    domain = mesh.create_rectangle(MPI.COMM_WORLD, [np.array([0, 0]), np.array([1, 1])], [50, 50], mesh.CellType.quadrilateral)
    V = fem.functionspace(domain, ("Lagrange", 1))
    u_analytical = fem.Function(V)
    u_analytical.interpolate(analytical_solution)
    
    pyvista.start_xvfb()
    grid = pyvista.UnstructuredGrid(*plot.vtk_mesh(V))
    grid.point_data["u_analytical"] = u_analytical.x.array
    warped = grid.warp_by_scalar("u_analytical", factor=1.0)
    
    plotter = pyvista.Plotter(off_screen=True)
    viridis = mpl.colormaps.get_cmap("viridis").resampled(25)
    sargs = dict(title_font_size=20, label_font_size=15, fmt="%.2e", color="black",
                 position_x=0.1, position_y=0.8, width=0.8, height=0.1)
    plotter.add_mesh(warped, show_edges=True, lighting=False, cmap=viridis,
                     scalar_bar_args=sargs, clim=[0, 1])
    plotter.view_xy()
    plotter.camera.zoom(1.5)
    plotter.add_title("Analytical Solution", font_size=12)
    plotter.show(screenshot="solution_analytical.png")
    plotter.close()

## Visualization

Creates an image of the analytical solution using PyVista.

In [ ]:
# Define cases
h_refinement_cases = [
    (5, 5, 1, mesh.CellType.quadrilateral, "h-ref: 5x5, p=1"),
    (20, 20, 1, mesh.CellType.quadrilateral, "h-ref: 20x20, p=1"),
    (50, 50, 1, mesh.CellType.quadrilateral, "h-ref: 50x50, p=1"),
    (100, 100, 1, mesh.CellType.quadrilateral, "h-ref: 100x100, p=1")
]

p_refinement_cases = [
    (20, 20, 3, mesh.CellType.triangle, "p-ref: degree=3, tri"),
    (20, 20, 6, mesh.CellType.triangle, "p-ref: degree=6, tri"),
    (20, 20, 4, mesh.CellType.quadrilateral, "p-ref: degree=4, quad")
]

all_cases = h_refinement_cases + p_refinement_cases

# Store results
relative_errors = []
case_names = []
solutions = []

# Run simulations for all cases and store solutions
for nx, ny, degree, cell_type, case_name in all_cases:
    print(f"Running case: {case_name}")
    uh, V, domain, rel_err = solve_poisson(nx, ny, degree, cell_type, case_name)
    relative_errors.append(rel_err)
    case_names.append(case_name)
    solutions.append((uh, V, domain, case_name, rel_err))

## Cases

Runs the h-refinement and p-refinement cases and stores the results.

In [ ]:
# Create GIF animation of refinement progression
pyvista.start_xvfb()
plotter = pyvista.Plotter(off_screen=True)
plotter.open_gif("refinement_evolution.gif", fps=2)

viridis = mpl.colormaps.get_cmap("viridis").resampled(25)
sargs = dict(title_font_size=20, label_font_size=15, fmt="%.2e", color="black",
             position_x=0.1, position_y=0.8, width=0.8, height=0.1)

for uh, V, domain, case_name, rel_err in solutions:
    grid = pyvista.UnstructuredGrid(*plot.vtk_mesh(V))
    grid.point_data["uh"] = uh.x.array
    warped = grid.warp_by_scalar("uh", factor=1.0)
    
    plotter.clear()
    plotter.add_mesh(warped, show_edges=True, lighting=False, cmap=viridis,
                     scalar_bar_args=sargs, clim=[0, 1])
    plotter.view_xy()
    plotter.camera.zoom(1.5)
    plotter.add_title(f"{case_name}\nL2 Error: {rel_err:.2e}", font_size=12)
    plotter.write_frame()

plotter.close()

## GIF

Creates a GIF showing the numerical solution for each case.

In [ ]:
# Visualize the analytical solution
visualize_analytical()

# Plot comparison of relative errors
plt.figure(figsize=(10, 6))
bars = plt.bar(range(len(case_names)), relative_errors, tick_label=case_names)
plt.xlabel('Case')
plt.ylabel('Relative L2 Error')
plt.title('Relative L2 Error Comparison for h- and p-Refinement')
plt.yscale('log')  # Use logarithmic scale for better visualization
plt.xticks(rotation=45, ha='right')
plt.grid(True, which="both", ls="--")

# Add error values on top of bars
for bar, error in zip(bars, relative_errors):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height(), f'{error:.2e}',
             ha='center', va='bottom')

plt.tight_layout()
plt.savefig('relative_l2_error_comparison_poisson.png')
plt.show()

# Print relative errors
print("\nRelative L2 Errors:")
for name, error in zip(case_names, relative_errors):
    print(f"{name}: {error:.2e}")

## Output

Generates the analytical-solution image and the error bar chart, and prints the relative L2 errors.