In [None]:
import numpy as np
import pybamm

from li02_model import LiO2_1D
from jiang_params import jiang_2020_params

def main():
    """Build, solve and plot the 1D Li-O2 cathode model at an applied J of 0.2.

    Pipeline: instantiate ``LiO2_1D`` and the ``jiang_2020_params`` parameter set,
    define a cathode-only 1D geometry on [0, Lc], mesh it uniformly with 80 points,
    discretise with finite volumes, solve with a CasADi solver over 0-2000 s
    (200 output times), and open ``QuickPlot.dynamic_plot()`` on the main outputs.
    """
    lio2 = LiO2_1D()
    param_values = jiang_2020_params()

    # Cathode-only geometry; its length is the symbolic parameter "Lc", which
    # process_geometry resolves from the parameter set.
    x_cathode = pybamm.SpatialVariable("x", domain="cathode", coord_sys="cartesian")
    cathode_length = pybamm.Parameter("Lc")
    geometry = {
        "cathode": {x_cathode: {"min": pybamm.Scalar(0), "max": cathode_length}}
    }
    param_values.process_geometry(geometry)

    # Uniform finite-volume mesh (raise the point count for smoother profiles).
    mesh = pybamm.Mesh(
        geometry,
        {"cathode": pybamm.Uniform1DSubMesh},
        {x_cathode: 80},
    )
    discretisation = pybamm.Discretisation(mesh, {"cathode": pybamm.FiniteVolume()})

    # Substitute parameter values first, then discretise the model in place.
    param_values.process_model(lio2)
    discretisation.process_model(lio2)

    cas_solver = pybamm.CasadiSolver(mode="safe", atol=1e-6, rtol=1e-6, root_tol=1e-5)
    times_s = np.linspace(0, 2000, 200)
    solution = cas_solver.solve(lio2, times_s, inputs={"J": 0.2})

    # QuickPlot handles spatially distributed variables automatically.
    plotted_variables = [
        "O2 concentration",
        "Porosity",
        "Solid potential",
        "Electrolyte potential",
        "Cathode interfacial current density",
        "Film drop [V]",
        "Active area",
    ]
    pybamm.QuickPlot(solution, output_variables=plotted_variables).dynamic_plot()

if __name__ == "__main__":
    main()


In [None]:
import numpy as np
import pybamm
import importlib
from li02_model import LiO2_1D
from jiang_params import jiang_2020_params

def _fmt_sci(value):
    """Format a number with ``.3e``; return ``"None"`` when it was never computed."""
    return "None" if value is None else format(value, ".3e")


def _print_diagnostics_summary(out, residual_norm, max_print):
    """Print the compact summary of a ``diagnose_solver_failure`` result dict.

    Safe to call on a partially filled dict (e.g. when ``sim.build()`` failed and
    every numeric field is still ``None``).
    """
    print("\n=== PyBaMM Solver Failure Diagnostics ===")
    print(f"built: {out['built']}")
    print(f"well_posed: {out['well_posed']}")
    print(
        f"y0: finite={out['y0_finite']}, min={_fmt_sci(out['y0_min'])}, "
        f"max={_fmt_sci(out['y0_max'])}"
    )
    # Report the algebraic block whenever it was evaluated, including the
    # non-finite case where no norm could be computed.
    if out["alg_residual_finite"] is not None:
        print(
            f"algebraic residual: finite={out['alg_residual_finite']}, "
            f"{residual_norm}-norm={_fmt_sci(out['alg_residual_norm'])}"
        )
    if out["rhs_finite"] is not None:
        print(
            f"rhs: finite={out['rhs_finite']}, min={_fmt_sci(out['rhs_min'])}, "
            f"max={_fmt_sci(out['rhs_max'])}"
        )
    else:
        print("rhs: (not evaluated)")
    if out["events"]:
        print("events at t0:")
        for ev in out["events"][:max_print]:
            trig = ev.get("triggered_at_t0", False)
            extra = f", eval_error={ev['eval_error']}" if "eval_error" in ev else ""
            print(f"  - {ev['name']}: {_fmt_sci(ev['value'])} (triggered={trig}){extra}")
    print(f"tiny_solve_ok: {out['tiny_solve_ok']}")
    if out["exception"] is not None:
        print(f"exception: {out['exception']}")
    print("========================================\n")


def diagnose_solver_failure(
    model_or_sim,
    parameter_values=None,
    inputs=None,
    t0=0.0,
    t_eval=None,
    solver=None,
    atol=1e-8,
    rtol=1e-8,
    residual_norm="inf",  # "inf" or "l2"
    max_print=12,
):
    """
    PyBaMM-only diagnostic helper for models failing with SolverError.

    Works best if you pass a pybamm.Simulation that already includes any mesh/geometry setup.
    If you pass a model, you MUST also pass parameter_values, and the model must be buildable
    by pybamm.Simulation(model, parameter_values=...).

    What it checks (in this order):
      1) Well-posedness checks (symbols, domains, BCs, etc.) on built model
      2) Initial condition vector y0 (finite? shape?)
      3) Algebraic residual at t=t0 using y0 (finite? large? which entries?)
      4) RHS evaluation at t=t0 using y0 (finite?)
      5) Event function values at t=t0 (already negative / triggered?)
      6) Optional: attempt a tiny integration step to see the first failure

    Parameters
    ----------
    model_or_sim : pybamm.Simulation or pybamm model
    parameter_values : pybamm.ParameterValues, required when a model is passed
    inputs : dict, optional
        Input parameters (e.g. {"J": 0.2}); copied, never mutated.
    t0 : float
        Time at which residuals/events are evaluated.
    t_eval : array-like, optional
        Times for the tiny solve; defaults to [t0, t0 + 1e-6].
    solver : pybamm solver, optional
        Overrides the simulation's solver. When a *model* is passed without a
        solver, a ``CasadiSolver(atol, rtol, mode="safe")`` is used.
    atol, rtol : float
        Tolerances for that fallback CasadiSolver.
    residual_norm : {"inf", "l2"}
        Norm reported for the algebraic residual (anything other than "l2" -> inf-norm).
    max_print : int
        Maximum number of bad indices / events reported.

    Returns
    -------
    dict
        Numeric diagnostics. ``"exception"`` is ``None`` or a string; when several
        stages fail, their messages are joined with ``"; "`` in the order they occurred.

    Raises
    ------
    ValueError
        If a model is passed without ``parameter_values``.
    """
    inputs = {} if inputs is None else dict(inputs)

    # --- Build a Simulation in a "minimal assumptions" way ---
    if isinstance(model_or_sim, pybamm.Simulation):
        sim = model_or_sim
        if solver is not None:
            sim.solver = solver
    else:
        if parameter_values is None:
            raise ValueError("If you pass a model, you must pass parameter_values.")
        if solver is None:
            # pybamm.Simulation substitutes the model's default solver when given
            # None, so atol/rtol would otherwise never be applied.
            solver = pybamm.CasadiSolver(atol=atol, rtol=rtol, mode="safe")
        sim = pybamm.Simulation(model_or_sim, parameter_values=parameter_values, solver=solver)

    # optional tiny default t_eval
    if t_eval is None:
        t_eval = np.array([t0, t0 + 1e-6])

    out = {
        "built": False,
        "well_posed": None,
        "y0_finite": None,
        "y0_min": None,
        "y0_max": None,
        "alg_residual_finite": None,
        "alg_residual_norm": None,
        "rhs_finite": None,
        "rhs_min": None,
        "rhs_max": None,
        "events": [],
        "tiny_solve_ok": None,
        "exception": None,
    }
    # Every failure message is kept so a later stage cannot hide an earlier one.
    errors = []

    def _norm(v):
        if v.size == 0:
            return 0.0  # empty residual: norm is trivially zero (np.max would raise)
        if residual_norm == "l2":
            return float(np.linalg.norm(v))
        return float(np.max(np.abs(v)))  # inf-norm

    def _bad_indices(v):
        # First `max_print` indices holding NaN/Inf.
        return np.where(~np.isfinite(v))[0][:max_print].tolist()

    try:
        # --- Build (parameter processing + discretisation) ---
        sim.build()
        out["built"] = True
        m = sim.built_model

        # --- Well-posedness ---
        # This catches lots of "structural" issues before the solver even runs.
        try:
            m.check_well_posedness()
            out["well_posed"] = True
        except Exception as e:
            out["well_posed"] = False
            errors.append(f"check_well_posedness failed: {e!r}")
            # keep going: sometimes it still provides useful residual info

        # --- Initial conditions vector ---
        y0 = m.concatenated_initial_conditions.evaluate(t=t0, inputs=inputs).reshape(-1)
        out["y0_finite"] = bool(np.all(np.isfinite(y0)))
        if y0.size:
            out["y0_min"] = float(np.nanmin(y0))
            out["y0_max"] = float(np.nanmax(y0))
        if not out["y0_finite"]:
            raise FloatingPointError(f"y0 contains NaN/Inf at indices {_bad_indices(y0)}")

        # --- Algebraic residual at t0 ---
        # If this is huge or non-finite, the consistent-state step will fail.
        if m.concatenated_algebraic is not None:
            alg = m.concatenated_algebraic.evaluate(t=t0, y=y0, inputs=inputs).reshape(-1)
            out["alg_residual_finite"] = bool(np.all(np.isfinite(alg)))
            if not out["alg_residual_finite"]:
                raise FloatingPointError(
                    f"Algebraic residual contains NaN/Inf at indices {_bad_indices(alg)}"
                )
            out["alg_residual_norm"] = _norm(alg)

        # --- RHS at t0 ---
        rhs = m.concatenated_rhs.evaluate(t=t0, y=y0, inputs=inputs).reshape(-1)
        out["rhs_finite"] = bool(np.all(np.isfinite(rhs)))
        if rhs.size:
            out["rhs_min"] = float(np.nanmin(rhs))
            out["rhs_max"] = float(np.nanmax(rhs))
        if not out["rhs_finite"]:
            raise FloatingPointError(f"RHS contains NaN/Inf at indices {_bad_indices(rhs)}")

        # --- Events at t0 ---
        # If any event is already negative at t0, the solver may stop instantly or fail.
        for ev in getattr(m, "events", None) or []:
            try:
                val = ev.expression.evaluate(t=t0, y=y0, inputs=inputs)
                val = float(np.array(val).reshape(-1)[0])
            except Exception as e:
                out["events"].append({"name": ev.name, "value": np.nan, "eval_error": repr(e)})
                continue
            out["events"].append({"name": ev.name, "value": val, "triggered_at_t0": bool(val < 0)})

        # --- Tiny solve attempt (often reproduces the exact failure mode quickly) ---
        if sim.solver is None:
            sim.solver = pybamm.CasadiSolver(atol=atol, rtol=rtol, mode="safe")

        try:
            sim.solve(t_eval=t_eval, inputs=inputs)
            out["tiny_solve_ok"] = True
        except Exception as e:
            out["tiny_solve_ok"] = False
            errors.append(f"tiny solve failed: {e!r}")

    except Exception as e:
        errors.append(repr(e))

    out["exception"] = "; ".join(errors) if errors else None

    _print_diagnostics_summary(out, residual_norm, max_print)
    return out

# (optional) always reload your model file after edits
# Reloading picks up edits to li02_model.py without restarting the kernel; LiO2_1D is
# re-imported afterwards because the earlier `from ... import` bound the old class object.
import li02_model
importlib.reload(li02_model)
from li02_model import LiO2_1D

# --- Parameters: try to load Jiang set; otherwise fall back to a minimal dict ---
# The broad `except Exception` is deliberate: it covers both a failed import and any
# error raised inside jiang_2020_params() itself.
try:
    from jiang_params import jiang_2020_params
    params = jiang_2020_params()
    print("Loaded parameters from jiang_params.jiang_2020_params()")
except Exception as e:
    print("Could not import jiang_2020_params(); using minimal fallback params.\n", repr(e))
    # NOTE(review): these fallback values look like smoke-test placeholders; their units and
    # physical validity are not checked here — confirm before trusting any results.
    params = pybamm.ParameterValues(
        {
            "F": 96485.33212,
            "R": 8.314462618,
            "T": 298.15,
            "Eeq": 2.96,
            "Lc": 200e-6,
            "R_ohm": 0.0,
            "R_film": 0.0,
            "k_film": 1.0,
            "l_sigma": 1.0,
            "k_in": 1.0,
            "tau_ramp": 1.0,
            "D_Li": 1e-10,
            "D_O2": 1e-10,
            "S_O2": 1.0,
            "S_O2_vol": 1.0,
            "p": 1.0,
            "tP": 0.4,
            "sigma": 10.0,
            "kappa": 1.0,
            "beta": 1.0,
            "Bruggeman exponent": 1.5,
            "eps0": 0.8,
            "eps_s0": 0.2,
            "eps_min": 1e-4,
            "a0": 1e5,
            "r0": 1e-6,
            "lm": 1e-7,
            "c_Li_0": 1000.0,
            "cO2_0": 1.0,
            "cO2_ext": 1.0,
            "c_Li2O2_max": 1e6,
            "rho_Li2O2": 2140.0,
            "M_Li2O2": 45.88e-3,
            "V_cut": 2.0,
        }
    )

# --- Build a Simulation with geometry + discretisation settings ---
# The globals defined here (model, params, geometry, submesh_types, var_pts,
# spatial_methods, sim) are reused by later cells, so keep their names stable.
model = LiO2_1D()

# "Lc" stays a symbolic Parameter here; presumably pybamm.Simulation resolves it from
# `params` when it builds — confirm if the mesh length looks wrong.
x = pybamm.SpatialVariable("x", domain="cathode", coord_sys="cartesian")
geometry = {"cathode": {x: {"min": pybamm.Scalar(0), "max": pybamm.Parameter("Lc")}}}
submesh_types = {"cathode": pybamm.Uniform1DSubMesh}
var_pts = {x: 80}
spatial_methods = {"cathode": pybamm.FiniteVolume()}

sim = pybamm.Simulation(
    model,
    parameter_values=params,
    geometry=geometry,
    submesh_types=submesh_types,
    var_pts=var_pts,
    spatial_methods=spatial_methods,
    solver=pybamm.CasadiSolver(mode="safe", atol=1e-8, rtol=1e-8),
)

# --- Call your diagnostic helper (assumes diagnose_solver_failure is defined above) ---
diag = diagnose_solver_failure(
    sim,
    inputs={"J": 0.20},                 # <- change current here
    t0=0.0,
    t_eval=np.array([0.0, 1e-6]),       # tiny step to reproduce early failure
)

# Bare expression so the notebook displays the returned diagnostics dict.
diag


In [None]:
import numpy as np
import pybamm

def charge_balance_check(sim, t=0.0, J=None, npts=200, atol=1e-6, rtol=1e-3, verbose=True):
    """
    Consistency check for the algebraic charge-balance coupling in the cathode domain.

    Verifies (approximately):
        ∫ a*j dx  ≈  i_l(L) - i_l(0)
        ∫ a*j dx  ≈ -(i_s(L) - i_s(0))

    where
        i_s = -sigma_eff * d(phi_s)/dx
        i_l = -kappa_eff * d(phi_l)/dx     (no concentrated-solution term)

    All fields are read from the solved ``sim.solution`` (``sol[name](t, x)``) on a
    uniform grid of ``npts`` points over [0, Lc]. Gradients are second-order finite
    differences on that grid (``numpy.gradient``), so the boundary currents only
    approximate the model's discrete fluxes. Loosen ``rtol`` for coarse meshes.

    Parameters
    ----------
    sim : pybamm.Simulation
        Must already be built AND solved (sim.solve called).
    t : float
        Time at which to evaluate (default 0.0).
    J : float or None
        Applied current density used as input "J". If None, we won't compare to J directly.
    npts : int
        Number of points for the domain integral and the finite-difference gradients (>= 2).
    atol, rtol : float
        Absolute/relative tolerances used to report pass/fail.
    verbose : bool
        Print a compact report.

    Returns
    -------
    dict with values of boundary currents, integral, and mismatches.

    Raises
    ------
    ValueError
        If the simulation has not been solved, or ``npts < 2``.
    """
    if sim.solution is None:
        raise ValueError("sim.solution is None. Run sim.solve(...) before calling this helper.")
    if npts < 2:
        raise ValueError("npts must be at least 2 to form gradients and an integral.")

    sol = sim.solution
    param = sim.parameter_values

    def _scalar_param(name):
        # Numeric value of a (scalar) parameter from the simulation's ParameterValues.
        return float(param.evaluate(pybamm.Parameter(name)))

    # Domain length and evaluation grid
    Lc = _scalar_param("Lc")
    x = np.linspace(0.0, Lc, npts)

    def _profile(name):
        # Solution variable `name` interpolated onto the grid at time t, as a flat
        # array of length npts (a size-1 result is broadcast over the grid).
        vals = np.asarray(sol[name](t, x), dtype=float).reshape(-1)
        if vals.size == 1:
            vals = np.full(x.shape, vals[0])
        return vals

    # Primary fields, using the variable names the model exposes
    eps = _profile("Porosity")
    phi_s = _profile("Solid potential")
    phi_l = _profile("Electrolyte potential")
    j_c = _profile("Cathode interfacial current density")
    a = _profile("Active area")

    # Reconstruct the same "clipped" porosity used in the model
    eps_s0 = _scalar_param("eps_s0")
    eps_floor = 1e-6
    eps_cap = 1e-6
    eps_clip = np.minimum(np.maximum(eps, eps_floor), 1.0 - eps_s0 - eps_cap)

    # Effective transport (matching the model)
    b = _scalar_param("Bruggeman exponent")
    sigma = _scalar_param("sigma")
    kappa = _scalar_param("kappa")
    sigma_eff = sigma * (1.0 - eps_clip) ** b
    kappa_eff = kappa * eps_clip ** b

    # Currents (1D); second-order one-sided differences at the ends when possible
    edge_order = 2 if npts >= 3 else 1
    i_s = -sigma_eff * np.gradient(phi_s, x, edge_order=edge_order)
    i_l = -kappa_eff * np.gradient(phi_l, x, edge_order=edge_order)

    # Volumetric source and its domain integral; np.trapz is deprecated in NumPy 2.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    src_x = a * j_c
    src_int = float(trapezoid(src_x, x))  # ∫ a*j dx

    i_s_0, i_s_L = float(i_s[0]), float(i_s[-1])
    i_l_0, i_l_L = float(i_l[0]), float(i_l[-1])

    # Flux differences
    rhs_from_il = i_l_L - i_l_0
    rhs_from_is = -(i_s_L - i_s_0)

    # Mismatches
    mis_il = src_int - rhs_from_il
    mis_is = src_int - rhs_from_is

    # Relative errors (guard against divide-by-zero)
    denom_il = max(atol, abs(src_int), abs(rhs_from_il))
    denom_is = max(atol, abs(src_int), abs(rhs_from_is))
    rel_il = abs(mis_il) / denom_il
    rel_is = abs(mis_is) / denom_is

    ok_il = (abs(mis_il) <= atol) or (rel_il <= rtol)
    ok_is = (abs(mis_is) <= atol) or (rel_is <= rtol)

    out = {
        "t": t,
        "Lc": Lc,
        "J_input": J,
        "i_s(0)": i_s_0,
        "i_s(L)": i_s_L,
        "i_l(0)": i_l_0,
        "i_l(L)": i_l_L,
        "int_a_j_dx": src_int,
        "il_flux_diff": rhs_from_il,
        "is_flux_diff": rhs_from_is,
        "mismatch_vs_il": mis_il,
        "mismatch_vs_is": mis_is,
        "relerr_vs_il": rel_il,
        "relerr_vs_is": rel_is,
        "pass_vs_il": ok_il,
        "pass_vs_is": ok_is,
    }

    if verbose:
        print("\n=== Charge-balance consistency check ===")
        print(f"t = {t:.6g} s,  Lc = {Lc:.6g} m")
        if J is not None:
            print(f"J input = {J:.6g} A/m^2")
        print("\nBoundary currents (A/m^2):")
        print(f"  i_s(0) = {i_s_0:.6e}    i_s(L) = {i_s_L:.6e}")
        print(f"  i_l(0) = {i_l_0:.6e}    i_l(L) = {i_l_L:.6e}")
        print("\nIntegral/source balance (A/m^2):")
        print(f"  ∫ a*j dx      = {src_int:.6e}")
        print(f"  i_l(L)-i_l(0) = {rhs_from_il:.6e}   mismatch = {mis_il:.6e}   rel = {rel_il:.3e}   pass={ok_il}")
        print(f" -(i_s(L)-i_s(0))= {rhs_from_is:.6e}   mismatch = {mis_is:.6e}   rel = {rel_is:.3e}   pass={ok_is}")

        if J is not None:
            print("\nBC sanity (only meaningful if your BCs enforce these):")
            print(f"  compare i_s(L) to J:  i_s(L)-J = {i_s_L - J:.6e}")
            print(f"  compare i_l(0) to J:  i_l(0)-J = {i_l_0 - J:.6e}")
            print(f"  compare i_l(L) to 0:  i_l(L)   = {i_l_L:.6e}")
            print(f"  compare i_s(0) to 0:  i_s(0)   = {i_s_0:.6e}")

    return out

def debug_well_posedness_structure(sim):
    """Print which variables the built model treats as algebraic vs differential.

    Calls ``sim.build()``, then lists the keys of ``built_model.algebraic`` and
    ``built_model.rhs``, every ``built_model.y_slices`` entry with its ``domain``
    attribute (``None`` if absent) and slice value, and finally the three counts.
    Returns ``None``; output goes to stdout only.
    """
    sim.build()
    built = sim.built_model

    algebraic_vars = [*built.algebraic]
    differential_vars = [*built.rhs]

    sections = (
        ("=== Algebraic keys (variables with algebraic equations) ===", algebraic_vars),
        ("\n=== Differential keys (variables with rhs) ===", differential_vars),
    )
    for heading, keys in sections:
        print(heading)
        for key in keys:
            print(" -", key)

    # Domains make duplicated variables (same name, different domain) easy to spot.
    print("\n=== y_slices (what is actually in the state vector) ===")
    for state_var, state_slices in built.y_slices.items():
        domain = getattr(state_var, "domain", None)
        print(f" - {state_var}  domain={domain}  slice={state_slices}")

    print("\nCounts:")
    print("  #algebraic keys:", len(algebraic_vars))
    print("  #rhs keys:", len(differential_vars))
    print("  #y_slices vars:", len(built.y_slices))

debug_well_posedness_structure(sim)


# Use ONE applied-current value for both the solve and the check, so the
# "BC sanity" comparisons in charge_balance_check refer to the current that was
# actually applied (previously the solve used J=0 while the check was told J=2.0).
J_applied = 0.0

sim = pybamm.Simulation(
    model,
    parameter_values=params,
    geometry=geometry,
    submesh_types=submesh_types,
    var_pts=var_pts,
    spatial_methods=spatial_methods,
    solver=pybamm.CasadiSolver(mode="safe", atol=1e-8, rtol=1e-8),
)
t_eval=np.array([0.0, 1e-6])
sim.solve(t_eval, inputs={"J": J_applied})
res = charge_balance_check(sim, t=0.0, J=J_applied, npts=300)


In [None]:
m = sim.built_model

# 1) Get the actual well-posedness failure details
try:
    m.check_well_posedness()
    print("Well-posedness: OK")
except Exception as e:
    print("Well-posedness FAILED with:\n", repr(e))

# 2) Compute algebraic residual at t=0 using y0
t0 = 0.0
inputs = {"J": 0.20}  # match your run
y0 = m.concatenated_initial_conditions.evaluate(t=t0, inputs=inputs).reshape(-1)

alg = m.concatenated_algebraic.evaluate(t=t0, y=y0, inputs=inputs).reshape(-1)

# PyBaMM lays out the state vector as [differential states | algebraic states], and the
# algebraic residual has one entry per algebraic state. Residual entry i therefore belongs
# to y[n_rhs + i]; y_slices index the FULL state vector, so this offset is needed below.
n_rhs = y0.size - alg.size

if alg.size:
    print("\nAlgebraic residual inf-norm:", float(np.max(np.abs(alg))))
else:
    print("\nAlgebraic residual is empty (no algebraic equations).")

# 3) Show top offending residual indices (largest magnitude)
k = 15
idx = np.argsort(np.abs(alg))[-k:][::-1]
print(f"\nTop {idx.size} algebraic residual entries:")
for i in idx:
    print(f"  idx={int(i):6d}  y_index={int(i) + n_rhs:6d}  residual={float(alg[i]): .3e}")

# 4) Map residual indices back to variables (slice info)
#    This helps identify whether it's phi_s, phi_l, j_c, etc.
print("\nResidual index → variable mapping (by y_slices):")
y_idx = idx + n_rhs
for name, slc in m.y_slices.items():
    # y_slices values may be a single slice or a list of slices.
    slices = slc if isinstance(slc, (list, tuple)) else [slc]
    hits = [int(j) for j in y_idx for s in slices if s.start <= j < s.stop]
    if hits:
        spans = ", ".join(f"slice({s.start},{s.stop})" for s in slices)
        print(f"  {name}: {spans} contains {len(hits)} of top-{idx.size}")


In [None]:
m = sim.built_model
# Residual indices of interest: 239, 198, then the contiguous run 180..192
# (or substitute your computed idx).
idx = np.array([239, 198, *range(180, 193)])

def iter_slices(slc):
    """Yield the ``slice`` objects contained in a ``y_slices`` entry.

    Accepts a single ``slice`` or a list/tuple of them. Non-slice items inside a
    list/tuple are skipped, and any other input yields nothing.
    """
    if isinstance(slc, (list, tuple)):
        yield from (part for part in slc if isinstance(part, slice))
    elif isinstance(slc, slice):
        yield slc

# `idx` holds indices into the ALGEBRAIC residual, while y_slices index the FULL state
# vector, which PyBaMM lays out as [differential states | algebraic states]. Shift by the
# position of the first algebraic state before matching against y_slices.
alg_offset = min(
    (s.start for v in m.algebraic for s in iter_slices(m.y_slices.get(v, []))),
    default=0,
)
y_idx = [int(i) + alg_offset for i in idx]

print(f"Residual index → variable mapping (by y_slices, algebraic offset = {alg_offset}):")
for name, slc in m.y_slices.items():
    hits = 0
    hit_ranges = []
    for s in iter_slices(slc):
        h = [j for j in y_idx if s.start <= j < s.stop]
        if h:
            hits += len(h)
            hit_ranges.append(f"y[{min(h)}..{max(h)}] in slice({s.start},{s.stop})")
    if hits:
        print(f"  {name}: {hits} hits; " + "; ".join(hit_ranges))

In [None]:
m = sim.built_model

# Sizes of the two equation dictionaries (one entry per variable key).
print("n_rhs vars:", len(m.rhs))
print("n_algebraic vars:", len(m.algebraic))

print("\nAlgebraic variable names:")
for alg_var in m.algebraic:
    print(" -", alg_var.name)

# Python class of each algebraic equation's root expression node; useful when
# looking for a variable that appears to be "defined twice" conceptually.
print("\nAlgebraic equation symbols (one per algebraic var):")
for alg_var, alg_eqn in m.algebraic.items():
    print(f" - {alg_var.name}: {type(alg_eqn).__name__}")


In [None]:
m = sim.built_model
# Report whether a variable named "Porosity" is a key of each equation dictionary.
for eqn_kind, eqn_dict in (("rhs", m.rhs), ("algebraic", m.algebraic)):
    key_names = {var.name for var in eqn_dict}
    print(f"Porosity in {eqn_kind}?", "Porosity" in key_names)


In [None]:
m = sim.built_model
# Dump every algebraic equation, each preceded by a header naming its variable.
for alg_var, alg_eqn in m.algebraic.items():
    header = f"=== Algebraic equation for: {alg_var.name} ==="
    print("\n" + header + "\n")
    print(alg_eqn)  # the symbol's string form


In [None]:
m = sim.built_model
t0 = 0.0
inputs = {"J": 0.20}
y0 = m.concatenated_initial_conditions.evaluate(t=t0, inputs=inputs).reshape(-1)
alg = m.concatenated_algebraic.evaluate(t=t0, y=y0, inputs=inputs).reshape(-1)

# Split the algebraic residual by algebraic variable.
# PyBaMM lays out y as [differential states | algebraic states] and the algebraic residual
# has one entry per algebraic state. The residual chunk of a variable occupying
# y[s.start:s.stop] is therefore alg[s.start - n_rhs : s.stop - n_rhs].
n_rhs = y0.size - alg.size
print(f"n_rhs states = {n_rhs}, n_algebraic states = {alg.size}")

for v in m.algebraic.keys():
    slc = m.y_slices[v]  # slices in full y (a single slice or a list of slices)
    slices = slc if isinstance(slc, (list, tuple)) else [slc]
    for s in slices:
        lo, hi = s.start - n_rhs, s.stop - n_rhs
        chunk = alg[lo:hi]
        if chunk.size == 0:
            print(f" - {v.name}: slice({s.start},{s.stop}) -> empty residual chunk")
            continue
        worst = int(np.argmax(np.abs(chunk)))
        print(
            f" - {v.name}: slice({s.start},{s.stop})  n={chunk.size}  "
            f"finite={bool(np.all(np.isfinite(chunk)))}  "
            f"inf-norm={float(np.max(np.abs(chunk))):.3e}  "
            f"worst residual idx={lo + worst} (value {float(chunk[worst]): .3e})"
        )

In [None]:
import sys
import pybamm

# Record interpreter and PyBaMM versions (PyBaMM APIs differ between releases).
for label, version in (("Python version:", sys.version), ("PyBaMM version:", pybamm.__version__)):
    print(label, version)


