In [None]:
import logging

import optimistix as optx

from atmodeller import (
    ChemicalSpecies,
    EquilibriumModel,
    Planet,
    SolverParameters,
    SpeciesNetwork,
    debug_logger,
    earth_oceans_to_hydrogen_mass,
)
from atmodeller.solubility import get_solubility_models

logger = debug_logger()
logger.setLevel(logging.INFO)

# For more output use DEBUG
# logger.setLevel(logging.DEBUG)

# Iteration

This notebook is available at `notebooks/iteration.ipynb` and is easiest to obtain by downloading the source code.

## Simple changing constraints

For models that require iterative updates where constraints evolve gradually&mdash;such as during time integration or other forms of sequential solving&mdash;*Atmodeller* can be used as follows. The order of the arguments and the size of the arrays must match those used to initialise the model, although the values themselves can vary between iterations.
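
The rule that only the values may change between iterations can be checked before each re-solve. The helper below, `constraints_match`, is hypothetical and not part of *Atmodeller*. It is a minimal sketch that assumes constraints are passed as a dictionary like `mass_constraints` in the code cell below, and it compares key order, value types, and array shapes between the initial and the updated constraints.

```python
def constraints_match(reference: dict, candidate: dict) -> bool:
    """Hypothetical check: does `candidate` have the same structure as `reference`?

    Compares key order, value types, and (for array-like values) shapes. Only the
    values themselves are allowed to differ.
    """
    # Same keys in the same order (dicts preserve insertion order since Python 3.7)
    if list(reference) != list(candidate):
        return False
    for key, ref_value in reference.items():
        new_value = candidate[key]
        # Same value type, e.g. a float must stay a float
        if type(ref_value) is not type(new_value):
            return False
        # Same shape for array-like values; plain floats have no `shape` attribute
        if getattr(ref_value, "shape", None) != getattr(new_value, "shape", None):
            return False
    return True


# Illustrative values only
initial = {"H": 1.0e20, "O": 6.25774e20}
print(constraints_match(initial, {"H": 0.9e20, "O": 5.9e20}))  # True: only values changed
print(constraints_match(initial, {"O": 5.9e20, "H": 0.9e20}))  # False: key order changed
```

Calling such a check inside an update loop turns a structural mismatch into an immediate, explicit failure rather than an unexpected error or recompilation deeper in the solve.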

A simple Python looping structure can be used to perform these updates. This approach is often the most intuitive way to couple *Atmodeller* with external codes or models that provide new constraints at each step. While not always the most performant strategy, especially for large parameter sweeps or high-resolution simulations, it offers a clear and flexible mechanism for driving iterative update processes.

In [None]:
# Atmodeller initialisation outside of the iterative update (e.g., time loop)

solubility_models = get_solubility_models()

H2_g = ChemicalSpecies.create_gas("H2")
H2O_g = ChemicalSpecies.create_gas("H2O", solubility=solubility_models["H2O_peridotite_sossi23"])
O2_g = ChemicalSpecies.create_gas("O2")
species = SpeciesNetwork((H2_g, H2O_g, O2_g))

planet = Planet()
model = EquilibriumModel(species)

# Optionally, set the solver and its parameters. For an iterative update loop, you typically want
# the solver to report failures (throw=True) so you can handle them. Otherwise, failed solutions
# will propagate through the loop and generate meaningless results.
solver = optx.Newton
solver_parameters = SolverParameters(solver=solver, throw=True)

# Solve once for the initial state
oceans = 1
h_kg = earth_oceans_to_hydrogen_mass(oceans)
o_kg = 6.25774e20

mass_constraints = {"H": h_kg, "O": o_kg}
model.solve(
    state=planet,
    mass_constraints=mass_constraints,
    solver_parameters=solver_parameters,
    solver_type="basic",
)

# Get the solution from the initial state to provide as the guess for the next solution, which
# usually works well when constraints are not changing much between iterations.
output = model.output

# Iterative loop parameters
start_index = 1
end_index = 4

# Using Atmodeller in the iterative update loop

# This is the update loop, where something changes and you want to re-solve using Atmodeller
for ii in range(start_index, end_index):
    # Let's say we update the mass constraints. The number of constraints and the value type (here,
    # floats) must remain the same as the initialised model, but you can update their values.
    logger.info("Iteration %d", ii)
    logger.info("Your code does something here to compute new masses")
    # For example, decrease H and O masses by factors that depend on the iteration number,
    # mimicking atmospheric escape or other loss processes.
    H_decrease = 1 - 0.1 * ii
    O_decrease = 1 - 0.05 * ii
    # Let's also change the melt fraction. We must create a new Planet with the desired properties.
    planet = Planet(mantle_melt_fraction=1 - 0.1 * ii)
    mass_constraints = {"H": h_kg * H_decrease, "O": o_kg * O_decrease}
    # These solves are fast because they reuse the JAX-compiled code after the first compilation.
    # We also pass the previous solution as initial_log_number_moles, which helps both convergence
    # and speed.
    logger.info("Atmodeller solve using JIT compiled code")
    model.solve(
        state=planet,  # Pass in the new planet
        mass_constraints=mass_constraints,  # Pass in the new constraints
        solver_parameters=solver_parameters,  # Keep this the same
        initial_log_number_moles=output.log_number_moles,  # Pass in the previous solution
    )
    # Update output with the new solution to use as the initial guess for the next iteration
    output = model.output

    # Quick look at the solution
    solution = output.quick_look()
    logger.info("solution = %s", solution)

    # If required, get the complete output as a dictionary to feed back into other calculations
    # during the time loop
    # solution_asdict = output.asdict()
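
The two ideas in the loop above, warm-starting each solve from the previous solution and failing loudly instead of propagating a bad solution (`throw=True`), can be illustrated with a toy problem. In the sketch below, `newton_solve` and `run_loop` are hypothetical stand-ins, not *Atmodeller* functions: the scalar equation x³ + x − c = 0 plays the role of the equilibrium solve, the slowly changing `c` plays the role of the mass constraints, and raising `RuntimeError` on non-convergence is an assumption standing in for whatever error the real solver raises.

```python
def newton_solve(c, x0, tol=1e-12, max_steps=50):
    """Toy solve of x**3 + x - c = 0 by Newton's method.

    Returns (root, number_of_steps) and raises RuntimeError if it does not
    converge, mimicking a solver configured to throw on failure.
    """
    x = x0
    for step in range(1, max_steps + 1):
        # The derivative 3x^2 + 1 is always >= 1, so the division is safe
        x = x - (x**3 + x - c) / (3 * x**2 + 1)
        if abs(x**3 + x - c) < tol:
            return x, step
    raise RuntimeError(f"Newton did not converge for c={c}")


def run_loop(constraints, cold_guess=1.0, warm_start=True):
    """Hypothetical update loop: re-solve for each new constraint value."""
    solutions, steps = [], []
    guess = cold_guess
    for c in constraints:
        try:
            x, n = newton_solve(c, guess)
        except RuntimeError:
            # If the warm start fails, retry once from the default guess; a second
            # failure propagates instead of feeding a bad solution into the loop
            x, n = newton_solve(c, cold_guess)
        solutions.append(x)
        steps.append(n)
        if warm_start:
            guess = x  # the previous solution seeds the next solve
    return solutions, steps


constraints = [10.0 - 0.1 * i for i in range(20)]  # slowly changing constraint
warm_x, warm_steps = run_loop(constraints, warm_start=True)
cold_x, cold_steps = run_loop(constraints, warm_start=False)
print(sum(warm_steps), sum(cold_steps))
```

Because the constraint changes only a little between iterations, the warm-started run needs noticeably fewer Newton steps in total than restarting every solve from the same fixed guess, while reaching the same roots.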

## Fully JAX-compatible approach

You may be thinking: "*Atmodeller* is a JAX-compatible code, so why would I embed JAX-compiled functions inside an inefficient Python for loop?" And you'd be absolutely right. While simple loops offer clarity, they also limit performance by returning control to the Python interpreter at every step. A more efficient approach is to integrate an *Atmodeller* solver into a JAX workflow so that the entire update sequence stays within JAX's functional, compiled execution model. By restructuring the iterative procedure into a form suitable for ``jax.lax.scan`` or similar control-flow primitives, the full computation can be jitted end-to-end, avoiding per-step Python overhead and allowing XLA to compile and optimise the entire sequence as a single program. This approach preserves the flexibility of iterative updates while achieving JAX-level performance and full accelerator compatibility.
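
The restructuring can be sketched with a toy problem. Below, `scan` is a pure-Python stand-in that follows the documented contract of ``jax.lax.scan``: a step function `f(carry, x) -> (new_carry, y)` is applied along the inputs, and the final carry plus all outputs are returned (the real ``jax.lax.scan`` stacks the outputs into arrays). The step function `step` is hypothetical: it warm-starts a fixed number of Newton iterations on the scalar equation x³ + x − c = 0 from the previous solution, standing in for an *Atmodeller* solve. Replacing it with a real solve assumes that solve can be written as a traceable function of the constraints, which this sketch does not show.

```python
NEWTON_STEPS = 8  # fixed count: a Python for loop over a fixed range is unrolled when traced


def scan(f, init, xs):
    """Pure-Python reference for the contract of jax.lax.scan."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def step(previous_x, c):
    """Hypothetical update: solve x**3 + x - c = 0, warm-started from previous_x."""
    x = previous_x
    # No early exit on the residual: under JAX tracing, stopping on a traced value
    # would require jax.lax.while_loop instead of a Python loop
    for _ in range(NEWTON_STEPS):
        x = x - (x**3 + x - c) / (3 * x**2 + 1)
    # Carry the solution to the next step and also record it as this step's output
    return x, x


constraints = [10.0 - 0.1 * i for i in range(20)]
final_x, solutions = scan(step, 1.0, constraints)
```

The key design choice is that the previous solution travels in the carry rather than in a Python variable updated by a loop, and that the step function only uses arithmetic JAX can trace. With those two properties, the same `step` could be passed to ``jax.lax.scan`` and the whole sequence wrapped in ``jax.jit``.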

In [1]:
# TODO