In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import dolfinx as dfx

from dolfinx.fem.petsc import NonlinearProblem

%matplotlib widget
from matplotlib import pyplot as plt
# plt.style.use('fivethirtyeight')

from mpi4py import MPI

import numpy as np

from petsc4py import PETSc

import random

import ufl

from cahn_hilliard_utils import (
    charge_discharge_stop, 
    AnalyzeOCP,
    y_of_c,
    c_of_y,
    populate_initial_data)

from fenicsx_utils import (evaluation_points_and_cells,
                           get_mesh_spacing,
                           time_stepping,
                           NewtonSolver,
                           Fenicx1DOutput)

from gmsh_utils import dfx_spherical_mesh

from plotting_utils import (
    add_arrow, 
    plot_charging_cycle, 
    plot_time_sequence,
    animate_time_series)

# Global MPI communicator; the mesh and the Newton solver below are built on it.
comm_world = MPI.COMM_WORLD

In [None]:
# Discretization
# --------------
# 1D mesh of the unit interval [0, 1] with n_elem cells.

n_elem = 128  # number of cells
mesh = dfx.mesh.create_unit_interval(comm_world, n_elem)

# Mesh spacing h as reported by get_mesh_spacing.
dx_cell = get_mesh_spacing(mesh)
print(f"Cell spacing: h = {dx_cell}")

# Sample points for the plotting cells below, located on the mesh once here.
x = np.linspace(0.0, 1.0, num=101)
points_on_proc, cells = evaluation_points_and_cells(mesh, x)

# Initial time-step size: one hundredth of the mesh spacing.
dt = dfx.fem.Constant(mesh, dx_cell * 0.01)

In [None]:
# Piecewise-linear (degree 1) Lagrange element, used for both components.
elem1 = ufl.FiniteElement("Lagrange", mesh.ufl_cell(), 1)

# Two-component mixed element (elem1 x elem1) and its function space.
mixed_element = elem1 * elem1
V = dfx.fem.FunctionSpace(mesh, mixed_element)

In [None]:
# The mixed-element functions
u = dfx.fem.Function(V)
u0 = dfx.fem.Function(V)

In [None]:
# Free energy density f(c); its derivative df/dc is the chemical potential.
a = 6. / 4
b = 0.2
cc = 5

# Alternative parameter set:
# a = 5. # 6. / 4
# b = 0. # 0.2
# cc = 0 # 5


def free_energy(u, log=np.log, sin=np.sin):
    """Free energy density f(u) = u ln u + (1-u) ln(1-u) + a u (1-u) + b sin(cc pi u).

    The logarithm and sine are injected so that the same formula can be
    evaluated numerically on NumPy arrays (``np.log`` / ``np.sin``, the
    defaults) and symbolically in UFL (``ufl.ln`` / ``ufl.sin``).

    The parameters ``a``, ``b`` and ``cc`` are read from the module globals
    at call time.

    Parameters
    ----------
    u : array_like or UFL expression
        Concentration. The logarithms require it to lie in the open interval (0, 1).
    log, sin : callable
        Natural logarithm and sine implementations matching the type of ``u``.

    Returns
    -------
    The free energy density, of the same kind as ``u``.
    """
    return u * log(u) + (1 - u) * log(1 - u) + a * u * (1 - u) + b * sin(cc * np.pi * u)

# Small offset from the endpoints: f contains log(c) and log(1 - c), which
# diverge at c = 0 and c = 1. eps is also reused by later cells.
eps = 1e-3

c_plot = np.linspace(eps, 1 - eps, 200)

fig, ax = plt.subplots()
ax.plot(c_plot, free_energy(c_plot, np.log, np.sin))
ax.set_xlabel("$c$")
ax.set_ylabel("$f(c)$")
ax.set_title("Free energy density")

plt.show()

In [None]:
# Experimental setup
# ------------------

T_final = 2.  # final simulation time

# Charging current, stored as a mesh Constant.
I_charge = dfx.fem.Constant(mesh, 1.0)


def experiment(t, u, I_charge, **kwargs):
    """Event handler passed to ``time_stepping``.

    Delegates to ``charge_discharge_stop``. It supplies ``c_of_y`` evaluated
    with ``ufl.exp`` and forwards every other keyword argument (e.g.
    ``stop_at_empty``, ``cycling``) unchanged.
    """
    def c_of_y_ufl(y):
        return c_of_y(y, ufl.exp)

    return charge_discharge_stop(
        t, u, I_charge,
        c_of_y=c_of_y_ufl,
        **kwargs,
    )


# Extra keyword arguments handed to time_stepping (and presumably forwarded to
# the event handler): constant current, no stop when empty, no cycling.
event_params = {"I_charge": I_charge, "stop_at_empty": False, "cycling": False}

In [None]:
# The variational form
# --------------------
from cahn_hilliard_utils import cahn_hilliard_form


def c_of_y1(y):
    """Apply ``c_of_y`` with the symbolic exponential ``ufl.exp``."""
    return c_of_y(y, ufl.exp)
# Alternative: identity map, c_of_y=lambda y: y


def free_energy_ufl(c):
    """UFL evaluation of ``free_energy`` (uses ``ufl.ln`` and ``ufl.sin``)."""
    return free_energy(c, ufl.ln, ufl.sin)


def mobility(c):
    """Coefficient M(c) = c (1 - c), which vanishes at c = 0 and c = 1."""
    return 1. * c * (1 - c)


# Extra keyword arguments for the form: the charging current.
params = {"I_charge": I_charge}

F = cahn_hilliard_form(
    u, u0, dt,
    free_energy=free_energy_ufl,
    theta=0.75,
    c_of_y=c_of_y1,
    M=mobility,
    lam=0.1,
    **params,
)

In [None]:
# Initial data
# ------------

u_ini = dfx.fem.Function(V)


def c_ini_fun(x):
    """Initial concentration: the constant ``eps`` at every point."""
    return eps * np.ones_like(x[0])


# Alternative initial charge distribution:
# c_ini_fun = lambda x: eps + 0.5 * np.sin(np.pi * x[0])

populate_initial_data(u_ini, c_ini_fun, lambda c: free_energy(c, ufl.ln, ufl.sin))

# Inspect both components of the initial data at the plotting points.
# Dedicated figure/axes names keep the global `fig`/`ax` of other cells intact.
fig_ini, ax_ini = plt.subplots()
ax_ini.plot(x, c_of_y(u_ini.sub(0).eval(points_on_proc, cells), np.exp), label="$c$")
ax_ini.plot(x, u_ini.sub(1).eval(points_on_proc, cells), label="second component")
ax_ini.set_xlabel("$x$")
ax_ini.set_title("Initial data")
ax_ini.legend()

plt.show()

In [None]:
# Nonlinear problem F(u) = 0 for the mixed unknown u (dolfinx.fem.petsc),
# solved with the NewtonSolver imported from fenicsx_utils -- note: not the
# dolfinx NewtonSolver class.
problem = NonlinearProblem(F, u)

solver = NewtonSolver(comm_world, problem)

In [None]:
# Start the time integration from the prepared initial data.
u.interpolate(u_ini)

# Output recorder for n_out equally spaced times on [0, T_final]; the plotting
# points x are passed along (presumably as the sampling grid).
n_out = 501
output = Fenicx1DOutput(u, np.linspace(0, T_final, n_out), x)

# Runtime analysis, using the UFL version of c_of_y.
rt_analysis = AnalyzeOCP(c_of_y=lambda y: c_of_y(y, ufl.exp))

# Integrate up to T_final. dt_increase and dt_max are handed to the step-size
# logic of time_stepping; event_params is forwarded as extra keywords.
time_stepping(
    solver, u, u0, T_final, dt,
    dt_increase=1.0,
    dt_max=1e-3,
    event_handler=experiment,
    output=output,
    runtime_analysis=rt_analysis,
    **event_params,
)

In [None]:
# Plot the recorded output over time; the NumPy version of c_of_y (np.exp)
# is passed because it is applied to sampled values, not UFL expressions.
fig, ax = plot_time_sequence(output, lambda y: c_of_y(y, np.exp))

plt.show()

In [None]:
# 3D spherical mesh for the PyVista visualization below; the other two
# return values of dfx_spherical_mesh are discarded.
mesh_3d, _, _ = dfx_spherical_mesh(resolution=1.)

In [None]:
from plotting_utils import PyvistaAnimation

# Animate the recorded 1D output on the 3D sphere mesh (presumably the unit
# interval is used as the radial coordinate -- TODO confirm in plotting_utils).
anim = PyvistaAnimation(
    output,
    mesh_3d=mesh_3d,
    c_of_y=lambda y: c_of_y(y, np.exp),  # NumPy version, applied to sampled values
    res=1.0,
    clim=[0.0, 1.0],  # fixed color range matching the concentration range [0, 1]
    cmap="hot",
)

# NOTE(review): the slider widget is bound to a name rather than being the
# cell's last expression, so Jupyter will not auto-display it unless
# get_slider_widget shows it itself -- confirm.
widget = anim.get_slider_widget()

In [None]:
# Render the animation as a GIF; the return value is the cell's output
# (presumably a displayable object -- confirm in plotting_utils).
anim.get_gif_animation()

In [None]:
q, f_bar, mu_bc = np.array(rt_analysis.data).T
t = rt_analysis.t

fig, ax = plot_charging_cycle(q, mu_bc, eps)

# Analytic derivative of free_energy, for comparison with the recorded curve:
#   d/dc [c ln c + (1-c) ln(1-c)] = ln(c / (1-c))
#   d/dc [a c (1-c)]              = a (1 - 2c)
#   d/dc [b sin(cc pi c)]         = b cc pi cos(cc pi c)   <- chain-rule factor cc*pi
# TODO: use AD or something else to generalize
q_plot = np.linspace(eps, 1 - eps, 101)
dFdc = (
    np.log(q_plot / (1 - q_plot))
    + a * (1 - 2 * q_plot)
    + b * cc * np.pi * np.cos(cc * np.pi * q_plot)
)

ax.plot(q_plot, -dFdc, 'r--', label=r"$-f'(q)$")
ax.legend()  # otherwise the label above is never shown

# savefig does not create missing directories.
from pathlib import Path
out_dir = Path("pp_output")
out_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(out_dir / "CH_4_min_charging_cycle.pdf")

plt.show()