In [None]:
# Reload imported modules before each cell execution so that edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# No backend argument given: IPython picks the default matplotlib backend.
# NOTE(review): results in a GUI backend on most setups - confirm this is
# intended for headless / CI runs.
%matplotlib

In [None]:
import basix

import dolfinx as dfx

import matplotlib.pyplot as plt

from mpi4py.MPI import COMM_WORLD as comm, SUM

import numpy as np

import pyvista as pv

import ufl

from pyMoBiMP.cahn_hilliard_utils import (
    c_of_y, compute_chemical_potential, _free_energy as free_energy)

from CH_4_DFN_grid import (AnalyzeOCP,
                           DFN_function_space,
                           create_particle_summation_measure,
                           FileOutput,
                           plot_solution_on_grid,
                           physical_setup,
                           TestCurrent,
                           time_stepping,
                           Voltage
                           )

## Create the mesh and function space

In [None]:
n_radial = 16
n_particles = 16

# Rectangle [0, 1] x [0, n_particles] with one row of cells per particle:
# the first coordinate is used as the radial coordinate r below.
mesh = dfx.mesh.create_rectangle(
    comm, ((0., 0.), (1., n_particles)), (n_radial, n_particles))

V = DFN_function_space(mesh)

dA = create_particle_summation_measure(mesh)


def integrate(integrand, measure):
    """Return the integral of ``integrand`` w.r.t. ``measure`` over the mesh.

    ``dfx.fem.assemble_scalar`` only yields the contribution of the calling
    MPI rank, so the local values are summed over all ranks of the mesh
    communicator. On a single rank the reduction is a no-op.
    """
    local_value = dfx.fem.assemble_scalar(dfx.fem.form(integrand * measure))
    return mesh.comm.allreduce(local_value, op=SUM)


one = dfx.fem.Constant(mesh, 1.)
r, _ = ufl.SpatialCoordinate(mesh)

# (description, integrand, measure, expected value).
# The last two checks are consistent with dA evaluating its integrand at
# r = 1 for each particle: r integrates to n_particles and (1 - r) to zero.
measure_checks = [
    ("1 * dx", one, ufl.dx, n_particles),
    ("1 * dA", one, dA, n_particles),
    ("r * dA", r, dA, n_particles),
    ("(1 - r) * dA", 1 - r, dA, 0.),
]

for description, integrand, measure, expected in measure_checks:
    value = integrate(integrand, measure)

    print(value)
    assert np.isclose(value, expected), \
        f"{description}: got {value}, expected {expected}"

## Physical properties

In [None]:
# NOTE(review): the physical meaning of A, L and Ls is defined inside
# physical_setup (CH_4_DFN_grid) and is not visible from this notebook.
A, a_ratios, L, Ls = physical_setup(V)

# Make sure a_ratios sums up to 1 under the integral measure. assemble_scalar
# returns only the rank-local part, hence the reduction over all MPI ranks.
a_ratios_total = comm.allreduce(
    dfx.fem.assemble_scalar(dfx.fem.form(a_ratios * dA)), op=SUM)

assert np.isclose(a_ratios_total, 1.0, rtol=1e-12), a_ratios_total

In [None]:
# u holds the unknown that the Newton solver works on; u0 is the second
# state of the same space that enters the form as (c - c0) and is handed to
# time_stepping alongside u.
u, u0 = dfx.fem.Function(V), dfx.fem.Function(V)

# Symbolic components of the mixed functions: y is mapped to a concentration
# via c_of_y further below, mu is the second unknown of the system.
(y, mu), (y0, mu0) = ufl.split(u), ufl.split(u0)

In [None]:
# Global current, stored as a Constant so its value can be changed later
# without rebuilding the forms that reference it.
I_global = dfx.fem.Constant(mesh, 0.01)
# Project-defined voltage object built from the solution and the global
# current; its `update` method is later registered as the solver callback.
V_cell = Voltage(u, I_global)

# NOTE(review): this rebinds `Ls`, which was already returned by
# physical_setup above - presumably both refer to the same quantity; confirm.
Ls = V_cell.Ls

# Per-particle current expressed through mu and the cell voltage; it enters
# the residual F1 through the dA measure.
I_particle = - Ls * (mu + V_cell)

## The FEM form

In [None]:
v_c, v_mu = ufl.TestFunctions(V)

# Weight of the new time level in the theta-scheme (1.0: only the new level
# contributes to mu_theta).
theta = 1.0
dt = dfx.fem.Constant(mesh, 1e-6)

c = c_of_y(y)
c0 = c_of_y(y0)

# Radial coordinate as a finite element function on the collapsed sub-space 0.
V0, dofs = V.sub(0).collapse()
r = dfx.fem.Function(V0)
r.interpolate(lambda x: x[0])


def M(c):
    """Return the mobility ``c * (1 - c)`` for the concentration ``c``."""
    return c * (1 - c)


lam = 0.1


def grad_c_bc(c):
    """Return the prescribed boundary value of the gradient of ``c``.

    Always zero; the argument is accepted but not used.
    """
    return 0.


# Radial weights of the volume and boundary integrals.
s_V = 4 * np.pi * r**2
# NOTE(review): 2*pi*r**2 differs from the 4*pi*r**2 used for s_V. It only
# multiplies grad_c_bc, which is zero, so it has no effect at present -
# confirm the intended factor before using a non-zero boundary gradient.
s_A = 2 * np.pi * r**2

dx = ufl.dx  # The volume element

mu_chem = compute_chemical_potential(free_energy, c)
# Theta-weighted combination of new and old mu. The old level carries the
# weight (1 - theta); the previous (theta - 1) had the wrong sign for any
# theta != 1 and is identical for the theta = 1 used here.
mu_theta = theta * mu + (1.0 - theta) * mu0

flux = M(c) * mu_theta.dx(0)

# Residual of the concentration update.
F1 = s_V * (c - c0) * v_c * dx
F1 += s_V * flux * v_c.dx(0) * dt * dx
F1 -= I_particle * v_c * dt * dA

# Residual of the equation that defines mu.
F2 = s_V * mu * v_mu * dx
F2 -= s_V * mu_chem * v_mu * dx
F2 -= lam * (s_V * c.dx(0) * v_mu.dx(0) * dx)
F2 += grad_c_bc(c) * (s_A * v_mu * dA)

F = F1 + F2

In [None]:
# Trial function of the mixed space: the direction of the linearisation.
du = ufl.TrialFunction(V)

dc, dmu = ufl.split(du)

# Symbolic derivative of the residual F with respect to u in direction du.
# NOTE(review): J0 is not handed to NonlinearProblem in the solver cell,
# which is constructed from F and u only.
J0 = ufl.derivative(F, u, du)

# NOTE(review): judging by the names these are meant as the dV/du and dF/dV
# factors of a voltage-coupling contribution to the Jacobian - confirm.
dVdu = Ls / L * dmu * v_c * a_ratios * dA
dFdV = Ls * dmu * dt * dA

# NOTE(review): bare expression whose result is not stored; it is only
# evaluated as the cell's output. Looks like unfinished scratch work -
# verify that multiplying two forms does what is intended here.
dVdu * dFdV

## Solver

In [None]:
from dolfinx.fem.petsc import NonlinearProblem as NonlinearProblemBase
from dolfinx.nls.petsc import NewtonSolver
from petsc4py import PETSc

class NonlinearProblem(NonlinearProblemBase):
    """Nonlinear problem that invokes a user hook on every ``form`` call.

    Behaves like the DOLFINx base class, except that ``callback`` is called
    (without arguments) each time ``form`` has been executed.
    """

    def __init__(self, *args, callback=lambda: None, **kwargs):
        """Set up the base problem and remember the hook.

        All positional and remaining keyword arguments are forwarded
        unchanged to the base class. ``callback`` defaults to a no-op.
        """
        super().__init__(*args, **kwargs)

        self.callback = callback

    def form(self, x):
        """Run the base-class ``form(x)``, then trigger the hook."""
        super().form(x)

        self.callback()

# V_cell.update is run on every `form` call of the problem.
problem = NonlinearProblem(F, u, callback=V_cell.update)
solver = NewtonSolver(comm, problem)

# Newton settings. Non-convergence is reported through the return value of
# solver.solve instead of raising.
solver.convergence_criterion = "incremental"
solver.rtol = 1e-5
solver.max_it = 50
solver.error_on_nonconvergence = False

# Linear solves inside Newton: direct LU without Krylov iterations.
ksp = solver.krylov_solver
option_prefix = ksp.getOptionsPrefix()
opts = PETSc.Options()
for option_name, option_value in (("ksp_type", "preonly"), ("pc_type", "lu")):
    opts[f"{option_prefix}{option_name}"] = option_value
ksp.setFromOptions()

In [None]:
# Initial state: y = -6 everywhere. This corresponds to roughly c = 1e-3.
# `u0.sub(0).x` is the dof vector of the *whole* mixed function, so assigning
# to `u0.sub(0).x.array[:]` would overwrite the mu dofs as well. Address only
# the dofs that belong to component 0 instead.
_, y_dofs = V.sub(0).collapse()
u0.x.array[y_dofs] = -6

# With dt = 0 the time-dependent terms of the residual drop out for this
# first solve.
dt.value = 0.

# u.interpolate(u0)  # Initial guess

residual = dfx.fem.form(F)

print(dfx.fem.petsc.assemble_vector(residual).norm())
its, success = solver.solve(u)
print(its, dfx.fem.petsc.assemble_vector(residual).norm())

# The solver is configured not to raise on non-convergence, so report it
# here instead of continuing silently.
if not success:
    print(f"WARNING: initial solve did not converge within {its} iterations.")

## Simulation setup and output

In [None]:
# End time handed to time_stepping and used as the end of the output grid.
T_final = 650.0

# Bounds handed to time_stepping as dt_min / dt_max.
dt_min, dt_max = 1e-9, 1e-1

In [None]:
# Runtime analysis hook; writes to a text file.
rt_analysis = AnalyzeOCP(u, c_of_y, V_cell, filename="CH_4_DFN_rt.txt")

# File output on 101 equidistant times in [0, T_final]; the solution is
# passed through c_of_y before it is handed to the writer.
output_times = np.linspace(0, T_final, 101)
output = FileOutput(
    u,
    output_times,
    filename="CH_4_DFN.xdmf",
    variable_transform=c_of_y,
)

# Hook passed to time_stepping as `callback`.
callback = TestCurrent(u, V_cell, I_global)

In [None]:
# Step-size control parameters.
# NOTE(review): the exact semantics of dt_increase and tol are defined in
# time_stepping (CH_4_DFN_grid) - confirm there.
step_control = {
    "dt_max": dt_max,
    "dt_min": dt_min,
    "dt_increase": 1.5,
    "tol": 1e-2,
}

# Hooks that are handed through to the time loop.
hooks = {
    "runtime_analysis": rt_analysis,
    "output": output,
    "callback": callback,
}

time_stepping(solver, u, u0, T_final, dt, V_cell, **step_control, **hooks)

In [None]:
# From here on `y` and `c` refer to the computed solution (finite element
# functions), no longer to the symbolic quantities used in the weak form.
y = u.sub(0)

# Stand-alone function on the collapsed space of component 0 ...
c_space = u.sub(0).collapse().function_space
c = dfx.fem.Function(c_space)

# ... filled with c_of_y(y) evaluated at the element's interpolation points.
c_expression = dfx.fem.Expression(
    c_of_y(y), c_space.element.interpolation_points())
c.interpolate(c_expression)

plot_solution_on_grid(c)