In [None]:
# Re-import changed modules (e.g. the project-local fenicsx_utils and
# cahn_hilliard_utils) before each cell runs, so edits take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import dolfinx as dfx

from dolfinx.fem.petsc import NonlinearProblem

%matplotlib widget
from matplotlib import pyplot as plt
# plt.style.use('fivethirtyeight')

from mpi4py import MPI

import numpy as np

from petsc4py import PETSc

import pyvista

import random

import ufl

from fenicsx_utils import (evaluation_points_and_cells,
                           get_mesh_spacing,
                           time_stepping,
                           NewtonSolver,
                           Fenicx1DOutput,
                           RuntimeAnalysisBase)

# Workaround to get the pyvista visualisation running inside Jupyter: import
# the vuetify widget module that matches the trame client type.
# https://github.com/pyvista/pyvista/issues/4776 (accessed: 2024/01/16)
from trame.app import get_server

CLIENT_TYPE = get_server().client_type

if CLIENT_TYPE != 'vue2':
    from trame.widgets import vuetify3 as vuetify
else:
    # Author's note: "Strangely, the workaround does not work. Force vuetify2."
    # (forcing was done by replacing the condition with `if True:`; disabled).
    from trame.widgets import vuetify2

In [None]:
comm_world = MPI.COMM_WORLD  # all MPI processes; used for the mesh and the Newton solver

In [None]:
# Set up the mesh
# ---------------
n_elem = 128
mesh = dfx.mesh.create_unit_interval(comm_world, n_elem)

dx_cell = get_mesh_spacing(mesh)
print(f"Cell spacing: h = {dx_cell}")

# Uniform plotting grid on [0, 1] plus the locally owned points/cells that
# are needed to evaluate finite-element functions on it later.
n_plot_points = 101
x = np.linspace(0, 1, n_plot_points)
points_on_proc, cells = evaluation_points_and_cells(mesh, x)

In [None]:
# Piecewise-linear Lagrange element, used for both components of the state.
elem1 = ufl.FiniteElement("Lagrange", mesh.ufl_cell(), 1)
mixed_elem = elem1 * elem1
V = dfx.fem.FunctionSpace(mesh, mixed_elem)  # A mixed two-component function space

In [None]:
# Mixed-element state functions on V; u0 is presumably the previous time level
# (it is handed to the form and the time stepper alongside u) -- confirm.
u, u0 = dfx.fem.Function(V), dfx.fem.Function(V)

In [None]:
# Free energy density f(c). Its derivative df/dc (the chemical potential) is
# formed symbolically from this function with ufl.diff where needed.
a = 6. / 4
b = 0.2
cc = 5

# Alternative parameter set (kept for reference):
# a = 5. # 6. / 4
# b = 0. # 0.2
# cc = 0 # 5


def free_energy(u, log, sin):
    """Return f(u) = u log u + (1-u) log(1-u) + a u (1-u) + b sin(cc pi u).

    ``log`` and ``sin`` are injected so the same formula serves numpy arrays
    (``np.log``, ``np.sin``) and UFL expressions (``ufl.ln``, ``ufl.sin``).
    The coefficients ``a``, ``b`` and ``cc`` are read from module scope at
    call time, so changing them later changes the result.
    """
    return u * log(u) + (1 - u) * log(1 - u) + a * u * (1 - u) + b * sin(cc * np.pi * u)

fig, ax = plt.subplots()

eps = 1e-3  # keeps c away from 0 and 1, where the logarithms diverge

c_plot = np.linspace(eps, 1-eps, 200)

ax.plot(c_plot, free_energy(c_plot, np.log, np.sin))
ax.set_xlabel(r"$c$")
ax.set_ylabel(r"$f(c)$")
ax.set_title("Free energy density")

plt.show()

In [None]:
def c_of_y(y, exp):
    """Map the logit variable y back to the concentration c = exp(y) / (1 + exp(y)).

    ``exp`` is injected so the formula works for numpy (``np.exp``) and
    UFL (``ufl.exp``) alike.
    """
    e_y = exp(y)
    return e_y / (1 + e_y)

In [None]:
# Test the transformation

def y_of_c(c, log):
    """Logit transform y = log(c / (1 - c)); ``log`` is np.log or ufl.ln."""
    odds = c / (1 - c)
    return log(odds)

# c_of_y must invert y_of_c on the plotting range; fail loudly under Run-All otherwise.
roundtrip_err = np.abs(c_of_y(y_of_c(c_plot, np.log), np.exp) - c_plot).max()
assert roundtrip_err < 1e-12, f"c_of_y does not invert y_of_c (max error {roundtrip_err:.2e})"
roundtrip_err

In [None]:
dt_over_h = 0.1  # time step as a fraction of the cell spacing
dt = dfx.fem.Constant(mesh, dx_cell * dt_over_h)

In [None]:
# The variational form
# --------------------
from cahn_hilliard_utils import cahn_hilliard_form


def c_of_y1(y):
    """Concentration as a UFL expression of the logit y."""
    return c_of_y(y, ufl.exp)
# Untransformed alternative: c_of_y1 = lambda y: y


# Charging current as a Constant, so its value can be changed during the run
# (the event handler sets it to 0); it is forwarded to the form via ``params``.
I_charge = dfx.fem.Constant(mesh, 0.1)
params = {"I_charge": I_charge}

F = cahn_hilliard_form(
    u, u0, dt,
    free_energy=lambda c: free_energy(c, ufl.ln, ufl.sin),
    theta=0.75,
    c_of_y=c_of_y1,
    M=lambda c: 0.1 * c * (1 - c),
    lam=1.0,
    **params,
)

In [None]:
# Initial data
# ------------
# The state vector holds (y, mu): y = y_of_c(c) is the logit of the
# concentration c and mu = df/dc is the chemical potential.

u_ini = dfx.fem.Function(V)

# Candidate initial concentration profiles; exactly one is selected below.
# The logit transform needs 0 < c < 1 everywhere, so "random" and "cosine"
# (which take values <= 0) are invalid here and are only kept for reference.
initial_profiles = {
    # Random (seeded for reproducibility)
    "random": lambda x: 0.01 * np.random.default_rng(42).standard_normal(x[0].shape),
    # Zero-mean
    "cosine": lambda x: np.cos(np.pi * x[0]),
    # Constant
    "constant": lambda x: eps * np.ones_like(x[0]),
    # Initial charge distribution
    "sine": lambda x: eps + 0.5 * np.sin(np.pi * x[0]),
}
c_ini_fun = initial_profiles["constant"]


# Store the logit y of the concentration into state component 0
# -------------------------------------------------------------

V_y = V.sub(0).collapse()[0]
c_ini = dfx.fem.Function(V_y)
c_ini.interpolate(c_ini_fun)

assert np.all((c_ini.x.array > 0.) & (c_ini.x.array < 1.)), \
    "initial concentration must lie strictly inside (0, 1) for the logit transform"

y_ini = dfx.fem.Expression(y_of_c(c_ini, ufl.ln), V_y.element.interpolation_points())
u_ini.sub(0).interpolate(y_ini)


# Store the chemical potential mu = df/dc into state component 1
# --------------------------------------------------------------

mu_element = u_ini.sub(1).function_space.element
c_ini_var = ufl.variable(c_ini)  # differentiation variable for ufl.diff
dFdc_ini = ufl.diff(free_energy(c_ini_var, ufl.ln, ufl.sin), c_ini_var)

u_ini.sub(1).interpolate(dfx.fem.Expression(dFdc_ini, mu_element.interpolation_points()))

u_ini.x.scatter_forward()

fig, ax = plt.subplots()

ax.plot(x, c_of_y(u_ini.sub(0).eval(points_on_proc, cells), np.exp), label=r"$c$")
ax.plot(x, u_ini.sub(1).eval(points_on_proc, cells), label=r"$\mu$")
ax.set_xlabel(r"$x$")
ax.set_title("Initial data")
ax.legend()

plt.show()

In [None]:
class AnalyzeOCP(RuntimeAnalysisBase):
    """Record charge, mean chemical potential and boundary chemical potential.

    Each call of ``analyze`` appends one row ``[charge, chem_pot, mu_bc]`` to
    ``self.data`` (presumably created by ``RuntimeAnalysisBase`` -- confirm):

    * ``charge``   -- integral of ``3 * c * r**2`` over the mesh,
    * ``chem_pot`` -- integral of ``3 * df/dc(c) * r**2`` over the mesh,
    * ``mu_bc``    -- the chemical potential ``mu`` evaluated at ``x = 1``.

    NOTE(review): the weight ``3 r**2`` looks like a volume average over a
    unit sphere in radial coordinates -- confirm intent.
    """

    def setup(self, *args, **kwargs):
        # Nothing extra to initialise; defer to the base class.
        return super().setup(*args, **kwargs)

    def analyze(self, u_state, t):
        """Append ``[charge, chem_pot, mu_bc]`` for the state at time ``t``.

        Parameters
        ----------
        u_state : dolfinx.fem.Function
            Mixed state; component 0 is the logit ``y``, component 1 is ``mu``.
        t : float
            Current time, forwarded to the base class.

        Returns
        -------
        The result of the base-class ``analyze``.
        """
        domain = u_state.function_space.mesh
        comm = domain.comm

        y, mu = u_state.split()

        c = ufl.variable(c_of_y(y, ufl.exp))
        dFdc = ufl.diff(free_energy(c, ufl.ln, ufl.sin), c)

        r = ufl.SpatialCoordinate(domain)

        # assemble_scalar returns only this process's contribution, so the
        # integrals are summed over all processes.
        charge = comm.allreduce(
            dfx.fem.assemble_scalar(dfx.fem.form(3 * c * r**2 * ufl.dx)), op=MPI.SUM)
        chem_pot = comm.allreduce(
            dfx.fem.assemble_scalar(dfx.fem.form(3 * dFdc * r**2 * ufl.dx)), op=MPI.SUM)

        # Only processes that own the point x = 1 get a value from ``eval``;
        # gather the values so that every process records the same mu_bc.
        x_bc, cells_bc = evaluation_points_and_cells(domain, np.array([1.]))
        mu_local = (np.asarray(mu.eval(x_bc, cells_bc)).ravel()
                    if len(cells_bc) > 0 else np.empty(0))
        mu_bc = float(np.concatenate(comm.allgather(mu_local))[0])

        self.data.append([charge, chem_pot, mu_bc])

        return super().analyze(u_state, t)

In [None]:
# Nonlinear problem F(u) = 0 for the state u. No Jacobian form is passed, so
# NonlinearProblem presumably derives it from F automatically -- confirm.
problem = NonlinearProblem(F, u)

# Project-local Newton solver (from fenicsx_utils) acting on ``problem``.
solver = NewtonSolver(comm_world, problem)

In [None]:
# Inspect the collapsed subspace holding the logit y (component 0 of V).
V_y_collapsed, _ = V.sub(0).collapse()
V_y_collapsed

In [None]:
T = 10.      # ending time
n_out = 51   # number of output times

# Start the time stepping from the prepared initial state.
u.interpolate(u_ini)

def event(t, u, I_charge, c_thr=0.99):
    """Switch off the charging current once max(c) exceeds ``c_thr``.

    Used as ``event_handler`` of ``time_stepping``; ``I_charge`` arrives via
    ``**params``. If the largest nodal concentration exceeds ``c_thr`` while
    ``I_charge`` is positive, ``I_charge.value`` is set to 0 in place.

    Parameters
    ----------
    t : float
        Current time (not used).
    u : dolfinx.fem.Function
        Mixed state; component 0 holds the logit ``y`` of the concentration.
    I_charge : dolfinx.fem.Constant
        Charging current, modified in place.
    c_thr : float, optional
        Concentration threshold (default 0.99).
    """
    V = u.function_space
    comm = V.mesh.comm

    # Positions of the y-dofs inside the mixed vector u.x.array.
    _, y_dofs = V.sub(0).collapse()
    y_local = u.x.array[y_dofs]

    # For a Lagrange space, interpolating c_of_y(y) just evaluates it at the
    # dof points, and c_of_y is increasing, so the largest nodal concentration
    # is c_of_y(max y). This avoids JIT-compiling an Expression on every call.
    # The -inf default covers processes without local dofs.
    max_y = comm.allreduce(np.max(y_local, initial=-np.inf), op=MPI.MAX)
    max_c = c_of_y(max_y, np.exp)

    if max_c > c_thr and I_charge.value > 0.:
        if comm.rank == 0:
            print(f">>> maximum concentration exceeds threshold (max(c) = {max_c:1.3f} > {c_thr:1.3f}).")
            print(">>> Stop charging.")
        I_charge.value = 0.

# Output times; Fenicx1DOutput presumably samples u on the plotting grid x.
output_times = np.linspace(0, T, n_out)
output = Fenicx1DOutput(u, output_times, x)

# Collects [charge, chem_pot, mu_bc] rows during the run.
rt_analysis = AnalyzeOCP()

t_out, x_out, y_out = time_stepping(
    solver, u, u0, T, n_out, dt,
    event_handler=event,
    output=output,
    runtime_analysis=rt_analysis,
    **params,
)

In [None]:
fig, axs = plt.subplots(2, 1, sharex=True)

t_out, data_out = output.get_output(return_time=True)

data_out = np.array(data_out).squeeze()

# Colour intensity encodes time; the last snapshot reaches full intensity.
color_scale = max(len(t_out) - 1, 1)

for it_out, (data_t, _t) in enumerate(zip(data_out, t_out)):
    y_t, mu_t = data_t[0], data_t[1]
    shade = it_out / color_scale

    axs[0].plot(x_out, c_of_y(y_t, np.exp), color=(shade, 0, 0))
    axs[1].plot(x_out, mu_t, color=(0, 0, shade))

axs[0].set_ylabel(r"$c$")
axs[1].set_ylabel(r"$\mu$")
axs[1].set_xlabel(r"$x$")
fig.suptitle("Snapshots (darker = earlier)")

plt.show()

In [None]:
fig, ax = plt.subplots()

# Chemical potential carried in the state vector at the final time.
mu = u.sub(1).eval(points_on_proc, cells)

# df/dc evaluated from the final concentration, for comparison.
W = V.sub(1).collapse()[0]
c = ufl.variable(c_of_y(u.sub(0), ufl.exp))
dFdc = ufl.diff(free_energy(c, ufl.ln, ufl.sin), c)

chem_pot = dfx.fem.Function(W)
chem_pot.interpolate(dfx.fem.Expression(dFdc, W.element.interpolation_points()))
chem_pot = chem_pot.eval(points_on_proc, cells)

# Both quantities were evaluated at the points of the plotting grid x.
ax.plot(x, mu, label=r"$\mu$ (state)")
ax.plot(x, chem_pot, label=r"$\partial_c f(c)$")
ax.set_xlabel(r"$x$")
ax.set_title("Final time")
ax.legend()

plt.show()

In [None]:
q, f_bar, mu = np.array(rt_analysis.data).T
t = rt_analysis.t

fig, ax = plt.subplots()

ax1 = ax.twinx()

(p1,) = ax1.plot(t, q, '--', label=r"$q$")
ax1.set_ylabel(r"$q$")

# f_bar is the weighted average of df/dc (``chem_pot`` in AnalyzeOCP), i.e. a
# mean chemical potential rather than a mean free energy.
(p2,) = ax.plot(t, mu, label=r"$\left. \mu \right|_{\partial \omega_I}$")
(p3,) = ax.plot(t, f_bar, label=r"$\overline{\partial_c f(c)}$")
ax.set_xlabel(r"$t$")
ax.set_ylabel("chemical potential")

lines = [p1, p2, p3]
ax.legend(lines, [line.get_label() for line in lines], loc="lower right")

plt.show()

In [None]:
fig, ax = plt.subplots()

ax.plot(q, mu, label=r"$\left. \mu \right|_{\partial \omega_I}$")
# f_bar is the weighted average of df/dc recorded by AnalyzeOCP.
ax.plot(q, f_bar, label=r"$\overline{\partial_c f(c)}$")
# NOTE(review): f(q) is a free energy, while the two curves above are chemical
# potentials (df/dc); for an OCP comparison df/dc(q) may be intended -- confirm.
ax.plot(q, free_energy(q, np.log, np.sin), label=r"$f(q)$")

ax.set_xlabel(r"$q$")

ax.legend()

plt.show()