# Test MODFLOW API

In [None]:
import sys
from pathlib import Path

import flopy
import pastas as ps
from pastas.timer import SolveTimer

from pastas_plugins import modflow as ppmf

# MODFLOW 6 binaries live in ./bin and are downloaded there on first use.
bindir = Path("bin")

# Binary names are platform specific: the shared library is libmf6.so on
# Linux, libmf6.dylib on macOS and libmf6.dll on Windows, where the
# executable also carries an ".exe" suffix.
if sys.platform.startswith("win"):
    mf6_exe, mf6_lib = "mf6.exe", "libmf6.dll"
elif sys.platform == "darwin":
    mf6_exe, mf6_lib = "mf6", "libmf6.dylib"
else:
    mf6_exe, mf6_lib = "mf6", "libmf6.so"

# Download when either binary is missing: the classic model needs the
# executable and the API model needs the shared library.
if not ((bindir / mf6_exe).exists() and (bindir / mf6_lib).exists()):
    bindir.mkdir(parents=True, exist_ok=True)
    flopy.utils.get_modflow(bindir, repo="modflow6")

dll = bindir / mf6_lib

ps.set_use_cache(True)

## Load Data

In [None]:
# Load the "collenteur_2019" dataset through pastas: heads, rain and evaporation.
ds = ps.load_dataset("collenteur_2019")

# Calibration target: the observed heads, without missing values.
head = ds["head"].squeeze().dropna()

# Stresses are used from 2002-11-01 onward. Rain is placed on a gap-free
# daily index and days without a value are treated as 0.0.
prec = (
    ds["rain"]
    .squeeze()
    .dropna()
    .resample("D")
    .asfreq()
    .fillna(0.0)
    .loc["2002-11-01":]
)
evap = ds["evap"].squeeze().dropna().loc["2002-11-01":]

## MODFLOW API

In [None]:
# MODFLOW packages: recharge driven by precipitation and evaporation, plus a
# general-head boundary.
rch = ppmf.ModflowRch(prec, evap)
ghb = ppmf.ModflowGhb()

# Pastas model whose stress response is computed by MODFLOW 6 through the
# shared library (`dll`) rather than the executable; output is not silenced.
ml1 = ps.Model(head, name="mftest")
mfml = ppmf.ModflowModelApi(model=ml1, dll=dll, sim_ws=Path("mftest"), silent=False)
mfml.add_modflow_package([rch, ghb])
ml1.add_stressmodel(mfml)

In [None]:
# Display the parameter table of the MODFLOW-API model.
ml1.parameters

In [None]:
# Parameters belonging to the stress model named "mfapi".
# NOTE(review): no name is passed to ModflowModelApi above, so "mfapi" is
# presumably its default stress-model name — confirm in pastas_plugins.
ml1.get_parameters("mfapi")

In [None]:
# Initial parameter values of the API model as a plain tuple for simulate().
p = tuple(ml1.parameters["initial"].to_numpy())

In [None]:
# %%time
# One forward run of the API model at the initial parameters; uncomment the
# %%time magic above (it must remain the first line of the cell) to time it.
sim = ml1.simulate(p=p)

In [None]:
# Calibrate the API model. SolveTimer's `timer` method is passed as the
# solver callback, presumably to report progress/elapsed time.
# NOTE(review): diff_step is presumably forwarded to the optimizer as the
# finite-difference step size — confirm against the pastas solver in use.
with SolveTimer() as t:
    ml1.solve(callback=t.timer, diff_step=1e-4)

## Classic MF6

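Solve the same model with the classic MODFLOW 6 implementation, which runs the `mf6` executable instead of calling the shared library through the API.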

In [None]:
# Same conceptual model, but simulated by running the MODFLOW 6 executable
# instead of calling the shared library through the API.
ml2 = ps.Model(head, name="mftest")
mfml2 = ppmf.ModflowModel(
    model=ml2,
    exe_name=bindir / "mf6",
    # Use a workspace separate from the API model's Path("mftest"), so the
    # two MODFLOW models do not share (and presumably overwrite) each
    # other's simulation files.
    sim_ws=Path("mftest_classic"),
    silent=True,
)
# Recharge from precipitation/evaporation and a general-head boundary, as
# for the API model.
rch = ppmf.ModflowRch(prec, evap)
ghb = ppmf.ModflowGhb()
mfml2.add_modflow_package([rch, ghb])
ml2.add_stressmodel(mfml2)

In [None]:
# Initial parameter values of the classic model as a plain tuple for simulate().
p = tuple(ml2.parameters["initial"].to_numpy())

In [None]:
%%time
# Time one forward run of the executable-based model at the initial parameters.
sim2 = ml2.simulate(p)

In [None]:
# Calibrate the executable-based model with the same settings as the API model.
# SolveTimer's `timer` method is passed as the solver callback, presumably to
# report progress/elapsed time; diff_step is presumably the optimizer's
# finite-difference step — confirm.
with SolveTimer() as t:
    ml2.solve(callback=t.timer, diff_step=1e-4)

## Pastas

In [None]:
# Reference: a plain Pastas recharge model with an exponential response function.
rm = ps.RechargeModel(prec=prec, evap=evap, rfunc=ps.Exponential(), name="rch")
ml3 = ps.Model(oseries=head, name="pastas")
ml3.add_stressmodel(rm)
ml3.solve()

In [None]:
# Compare the calibrated models: start from the Pastas reference model's plot
# and overlay the simulations of the API (ml1) and executable-based (ml2)
# MODFLOW models, each labelled with its R² formatted as a percentage.
ax = ml3.plot(figsize=(8, 3))
sim1 = ml1.simulate()
ax.plot(sim1.index, sim1, label=f"API ({ml1.stats.rsq():.2%})")
sim2 = ml2.simulate()
ax.plot(sim2.index, sim2, label=f"MF6 ({ml2.stats.rsq():.2%})")
# loc=(0, 1) anchors the legend's lower-left corner at the axes' top-left,
# i.e. one row of entries above the plot area.
ax.legend(loc=(0, 1), frameon=False, ncol=4)

## MODFLOW API objects (for testing)

## MODFLOW with callback