In [None]:
# Reload edited project modules (e.g. CH_4_DFN_grid, pyMoBiMP) automatically
# before each cell is executed.
%load_ext autoreload
%autoreload 2

# No backend argument: IPython selects matplotlib's default interactive (GUI)
# backend, so figures open in separate windows rather than inline.
%matplotlib

In [None]:
import basix

import dolfinx as dfx

import matplotlib.pyplot as plt

from mpi4py.MPI import COMM_WORLD as comm, SUM

import numpy as np

import pyvista as pv

import ufl

from pyMoBiMP.cahn_hilliard_utils import (
    c_of_y, compute_chemical_potential, _free_energy as free_energy)

from CH_4_DFN_grid import (AnalyzeOCP,
                           DFN_function_space,
                           DFN_FEM_form,
                           create_particle_summation_measure,
                           FileOutput,
                           plot_solution_on_grid,
                           physical_setup,
                           TestCurrent,
                           time_stepping,
                           Voltage
                           )

## Create the mesh and function space

In [None]:
# Structured quadrilateral grid on [0, 1] x [0, n_particles]:
# n_radial cells along x and n_particles cells along y.
n_radial = 16
n_particles = 128

mesh = dfx.mesh.create_rectangle(
    comm,
    ((0., 0.), (1., n_particles)),
    (n_radial, n_particles),
    cell_type=dfx.mesh.CellType.quadrilateral,
)

# Function space for the coupled unknowns and the measure used to sum
# quantities per particle (both provided by CH_4_DFN_grid).
V = DFN_function_space(mesh)
dA = create_particle_summation_measure(mesh)

## Physical properties

In [None]:
# Parameters returned by physical_setup for the space V (presumably the
# physical model coefficients — see CH_4_DFN_grid.physical_setup).
# NOTE(review): Ls is rebound to V_cell.Ls in the voltage cell further down,
# so the value unpacked here is overwritten — confirm which one is intended.
A, a_ratios, L, Ls = physical_setup(V)

In [None]:
# Current iterate u and previous state u0; both are passed to DFN_FEM_form
# and time_stepping, and u0 receives the initial condition.
u, u0 = dfx.fem.Function(V), dfx.fem.Function(V)

# UFL views of the two components (y, mu) of each function.
(y, mu), (y0, mu0) = ufl.split(u), ufl.split(u0)

In [None]:
# Global applied current, held in a dolfinx Constant so its value can be
# changed later without rebuilding forms; the cell voltage is built from u
# and this current.
I_global = dfx.fem.Constant(mesh, 0.01)
V_cell = Voltage(u, I_global)

# Take Ls from the voltage object (this overrides the Ls returned by
# physical_setup).
Ls = V_cell.Ls

# Per-particle current expression (presumably the particle flux);
# not used further in this notebook.
I_particle = -Ls * (mu + V_cell)

## The FEM form

In [None]:
# Test function on the full space V.
v = ufl.TestFunction(V)

# Time-step size as a Constant; its value is assigned later through dt.value.
dt = dfx.fem.Constant(mesh, 0.0)

# Residual form coupling current (u) and previous (u0) states via V_cell.
F = DFN_FEM_form(u, u0, v, dt, V_cell)


In [None]:
# du = ufl.TrialFunction(V)

# dc, dmu = ufl.split(du)

# J0 = ufl.derivative(F, u, du)

# dVdu = Ls / L * dmu * v_c * a_ratios * dA
# dFdV = Ls * v_c * dt * dA

In [None]:
# dfx.fem.petsc.assemble_vector(dfx.fem.form(dFdV))[:].shape

In [None]:
# mat = dfx.fem.petsc.assemble_vector(dfx.fem.form(dVdu))

## Solver

In [None]:
from dolfinx.fem.petsc import NonlinearProblem as NonlinearProblemBase
from pyMoBiMP.fenicsx_utils import NewtonSolver
from petsc4py import PETSc

class NonlinearProblem(NonlinearProblemBase):
    """dolfinx ``NonlinearProblem`` that runs a user hook after every ``form`` call.

    The hook lets state that is not part of the FEM unknowns (in this notebook
    ``V_cell.update``) be refreshed from the current iterate ``x`` whenever the
    Newton solver prepares a new residual/Jacobian evaluation.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to the base-class constructor.
    callback : callable, optional
        Zero-argument function invoked at the end of each ``form`` call.
        Defaults to a no-op.
    """

    def __init__(self, *args, callback=lambda: None, **kwargs):
        super().__init__(*args, **kwargs)
        # Stored after base initialisation; only used by form().
        self.callback = callback

    def form(self, x):
        """Delegate to the base-class ``form(x)``, then invoke the hook."""
        super().form(x)
        self.callback()

# Wrap the residual form F so that V_cell.update() runs every time the Newton
# solver calls problem.form(x) (see the NonlinearProblem subclass above).
problem = NonlinearProblem(F, u, callback=V_cell.update)
solver = NewtonSolver(comm, problem)

# Newton settings (attribute names presumably mirror dolfinx's NewtonSolver):
# do not raise on non-convergence, use the increment-based convergence test
# with relative tolerance 1e-7, and allow at most 50 iterations.
solver.error_on_nonconvergence = False
solver.convergence_criterion = "incremental"
solver.rtol = 1e-7
solver.max_it = 50

ksp = solver.krylov_solver
opts = PETSc.Options()
option_prefix = ksp.getOptionsPrefix()
# Direct solve of each Newton system: no Krylov iterations ("preonly") and an
# LU factorization as the "preconditioner".
opts[f"{option_prefix}ksp_type"] = "preonly"
opts[f"{option_prefix}pc_type"] = "lu"
# Options are written to the global PETSc database under the KSP's prefix and
# only take effect once setFromOptions() is called.
ksp.setFromOptions()

In [None]:
# Initial state: y = -6 everywhere. This corresponds to roughly c = 1e-3.
# u0.sub(0) is a view that shares the *entire* vector of u0, so writing
# `u0.sub(0).x.array[:] = -6` would overwrite the mu component as well.
# Interpolating into the sub-function writes only the y dofs.
u0.sub(0).interpolate(lambda x: np.full(x.shape[1], -6.))
u0.x.scatter_forward()

# First solve with a zero time step (presumably an equilibration of the
# state for the given y — TODO confirm against DFN_FEM_form).
dt.value = 0.

# u.interpolate(u0)  # Initial guess

residual = dfx.fem.form(F)


def residual_norm(form):
    """Return the global l2-norm of the assembled residual vector of `form`.

    Ghost contributions are summed onto their owning ranks before the norm is
    taken, so the value is also correct under MPI. The temporary PETSc vector
    is destroyed afterwards to avoid leaking it.
    """
    b = dfx.fem.petsc.assemble_vector(form)
    b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
    norm = b.norm()
    b.destroy()
    return norm


print(residual_norm(residual))
its, success = solver.solve(u)
error = residual_norm(residual)
print(its, error)
assert np.isclose(error, 0.), f"residual norm {error} after {its} Newton iterations"

## Simulation setup and output

In [None]:
# Final simulation time and the bounds passed to the adaptive time stepper.
T_final = 650.
dt_min, dt_max = 1e-9, 1e-2

In [None]:
# Runtime analysis written to a text file while the simulation runs.
rt_analysis = AnalyzeOCP(u, c_of_y, V_cell, filename="CH_4_DFN_rt.txt")

# XDMF output at 101 equally spaced times in [0, T_final]; the stored field
# is transformed with c_of_y before writing (presumably c instead of y).
output = FileOutput(
    u,
    np.linspace(0, T_final, 101),
    filename="CH_4_DFN.xdmf",
    variable_transform=c_of_y,
)

# Per-step hook handed to time_stepping.
callback = TestCurrent(u, V_cell, I_global)

In [None]:
# Integrate up to T_final with an adaptive step bounded by [dt_min, dt_max].
# `dt_increase` and `tol` presumably control step growth and step acceptance
# — see CH_4_DFN_grid.time_stepping.
time_stepping(
    solver, u, u0, T_final, dt, V_cell,
    dt_max=dt_max, dt_min=dt_min,
    dt_increase=1.1, tol=1e-4,
    runtime_analysis=rt_analysis, output=output, callback=callback,
)

In [None]:
# Post-process: evaluate c = c_of_y(y) for the final y field and plot it.
# The standalone (collapsed) y-space is taken directly from V. This avoids
# copying the solution just to obtain it, and avoids rebinding the notebook-level
# `y`, which holds the UFL component from ufl.split(u) used when building forms.
V_y = V.sub(0).collapse()[0]
c = dfx.fem.Function(V_y)
c.interpolate(
    dfx.fem.Expression(c_of_y(u.sub(0)), V_y.element.interpolation_points()))

plot_solution_on_grid(c)