Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/advanced/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ Darcy flow, Richards equation, and variably-saturated groundwater modelling.

**[→ Porous Media Flow](porous-flow.md)**

### State Snapshots & Restore
A "stash for timesteps": snapshot the full model state, try a step,
restore exactly if you don't like it. For backtracking, adaptive Δt,
and predictor–corrector workflows.

**[→ State Snapshots & Restore](snapshot-restore.md)**

### Troubleshooting
Common issues, debugging strategies, and solutions.

Expand Down Expand Up @@ -85,6 +92,7 @@ custom-meshes
curved-boundary-conditions
mesh-adaptation
porous-flow
snapshot-restore
troubleshooting
api-patterns
SWARM-INTEGRATION-STATISTICS
Expand Down
206 changes: 206 additions & 0 deletions docs/advanced/snapshot-restore.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
---
title: "State Snapshots & Restore"
---

# State Snapshots & Restore

## Overview

`Model.snapshot()` and `Model.restore()` are a *stash for timesteps* —
a quick "hold that thought, I might need to come back" mechanism for
time-stepping code. Take a snapshot, try a step, and if you don't like
the result, restore and try again. The system is put back exactly as it
was, as if the discarded step never happened.

Typical uses:

- **Backtrack past an instability** — a step blows up; restore and
continue with a smaller Δt or a different scheme.
- **Adaptive Δt with an error / CFL check** — take a step, measure it,
restore and retry if it violated your criterion.
- **Predictor–corrector probing** — try a predictor, inspect the
corrector, fall back if it isn't converging.
- **Multi-stage time integration** (RK-style) — restore to the start
of a step between stages.

This is intentionally *not* archival checkpointing. It is fast,
in-memory, and meant to be used freely within a run. For long-term,
on-disk restart files, use the existing `mesh.write_timestep()` /
`read_timestep()` path, which is unchanged and serves a different
purpose.

## The API

```python
import underworld3 as uw

model = uw.get_default_model()

# ... set up mesh, variables, swarm, solvers, step a few times ...

token = model.snapshot() # capture everything, return a token

# ... take a speculative step you might regret ...

model.restore(token) # put everything back exactly
```

`snapshot()` returns a plain in-memory token. You can hold several at
once and restore any of them. `restore()` returns the model to the
exact state at the moment that token was captured.

## What is captured

You do not enumerate anything — `snapshot()` captures the full state
of the model automatically:

- mesh coordinates,
- all mesh-variable values,
- all swarm particle positions and swarm-variable values,
- solver-internal time-integration history (the `DDt` operators that
drive `AdvDiffusion`, viscoelastic stress history, etc.),
- everything on the model tracker (see below).

Restore rebuilds swarm populations from the snapshot, so it is correct
even if particles migrated, were added, or were lost between snapshot
and restore — that is exactly the situation restore exists for.

## The model tracker: time, step, and your own quantities

A subtle trap in time-stepping scripts: your loop counter and
simulation time usually live in plain Python variables, and
`restore()` has no way to know about them.

```python
model_time = 0.0
token = model.snapshot()
model_time = 5.0 # advance
model.restore(token)
# model_time is still 5.0 — restore cannot reach a local variable
```

`Model.tracker` solves this. It is a model-dwelling record of where the
run is — and anything you put on it is automatically captured and
restored.

```python
model.tracker.time = 0.0
model.tracker.step = 0

token = model.snapshot()

model.tracker.time = 5.0
model.tracker.step = 100

model.restore(token)

model.tracker.time # 0.0 — reverted automatically
model.tracker.step # 0 — reverted automatically
```

`time`, `step` and `dt` come pre-seeded as conventions, but they have
no special status. Any attribute you assign becomes managed state:

```python
model.tracker.peak_velocity = 0.0
model.tracker.energy_history = np.zeros(3)
```

These now travel with every snapshot and revert on every restore — no
extra code, no special handling in your solvers. Using the tracker is
optional; solvers do not depend on it. It is simply the place to keep
the things you want `restore()` to manage.

```{note} Reserved name
`state` is reserved on the tracker (it is the snapshot mechanism's own
hook). Do not use `model.tracker.state` for your own quantity.
```

```{note} git-stash semantics
Restore returns to *exactly* the captured point. A quantity you add to
the tracker *after* taking a snapshot is removed by a restore of that
snapshot — the same way `git stash pop` does not keep work you started
afterwards.
```

## Worked example: adaptive-Δt backtracking

A canonical CFL-controlled stepping loop. The speculative step is
taken, checked, and either kept or discarded:

```python
import numpy as np
import underworld3 as uw

model = uw.get_default_model()
# ... mesh, swarm, velocity field V_fn, solvers set up ...

cfl_limit = mesh.get_min_radius()
dt = 0.5

while model.tracker.time < t_end:
token = model.snapshot()
coords_before = swarm._particle_coordinates.data.copy()

# Speculative step at the current Δt.
swarm.advection(V_fn, delta_t=dt)
# ... your solves for this step ...

# CFL check.
moved = np.linalg.norm(
swarm._particle_coordinates.data - coords_before, axis=1
).max()

if moved > cfl_limit:
# Too big — discard and retry with a smaller Δt.
model.restore(token)
dt *= 0.5
continue

# Good step — commit.
model.tracker.time += dt
model.tracker.step += 1
dt = min(dt * 1.1, dt_max) # let Δt grow again
```

Because the swarm, fields, solver history *and* the tracker's `time` /
`step` are all captured, the `continue` path leaves no trace: the next
attempt starts from precisely where the failed one began.

## Guarantees and scope

```{note} What is guaranteed
- **Discarding a step leaves no trace.** A snapshot → speculative
step → restore → continue reproduces a run that never took the
speculative step *bit-for-bit*, including across MPI ranks and
through real PETSc solves.
- **Parallel-correct.** Works under MPI at any (fixed) rank count.
Restore recovers the exact global state even if the discarded step
migrated or lost particles across ranks.
```

```{warning} Limitations
- **In-memory only.** Snapshots live in process memory and are not
written to disk; they do not survive the process exiting. They are
also a full copy of model state — holding many large snapshots at
once costs memory.
- **Same rank count.** A snapshot taken on *N* MPI ranks is restored
on *N* ranks. Changing the rank count is not supported by this
mechanism (use the `write_timestep` restart path for that).
- **No mesh adaptation across a snapshot.** If the mesh is adapted
between snapshot and restore, restore refuses with a clear error
rather than corrupting state.
- **Recovery vs. a never-snapshotted run** is bit-exact for the
*discarded-step* guarantee above. Continuing after a restore that
ran a real solver may differ from a run that never snapshotted by a
small amount within solver tolerance — restore resyncs solver
fields rather than reproducing their exact internal buffers. This
does not affect the correctness of backtracking.
```

## Related

- [Parallel-Safe Scripting](parallel-computing.md) — MPI patterns;
snapshot/restore is parallel-correct at fixed rank count.
- Developer reference: the state-as-dataclass contract for adding new
snapshot-managed solver helpers lives in the developer guide.
3 changes: 3 additions & 0 deletions src/underworld3/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
restore,
)
from .state import Snapshottable, SnapshottableState
from .tracker import ModelTracker, TrackerState

__all__ = [
"CheckpointBackend",
Expand All @@ -37,4 +38,6 @@
"restore",
"Snapshottable",
"SnapshottableState",
"ModelTracker",
"TrackerState",
]
140 changes: 140 additions & 0 deletions src/underworld3/checkpoint/tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""Model-dwelling tracker — snapshot-managed evolving state.

``Model.tracker`` is the authoritative *record* of where a run is:
simulation time, step, dt, plus any user-registered quantities. It is
deliberately NOT something solvers depend on — solvers and DDt are
untouched, and a user need not use the tracker at all. Its one
superpower: everything living in the tracker is automatically
captured by ``Model.snapshot()`` and reverted by ``Model.restore()``,
whereas a loose Python variable (``model_time = 0.0`` in a script) is
not.

Add managed quantities by plain attribute assignment::

model.tracker.time = 0.0
model.tracker.step = 0
model.tracker.my_diagnostic = np.zeros(3)

Any attribute set on the tracker whose name does not start with an
underscore is a managed state variable: part of every snapshot,
restored exactly on rollback, with no special status in solvers and
no dataclass authoring required. Underscore-prefixed names are
internal and not managed.

``time``, ``step`` and ``dt`` are ordinary managed entries pre-seeded
with sensible defaults (``0.0`` / ``0`` / ``None``). They are
conventions, not privileged fields — consistent with the design
intent that user-added quantities are first-class.
"""

from __future__ import annotations

import copy
from dataclasses import dataclass, field

from underworld3.utilities._api_tools import uw_object

from .state import SnapshottableState


@dataclass
class TrackerState(SnapshottableState):
"""Snapshot of a :class:`ModelTracker`.

The tracker is extensible, so the State carries an open mapping
rather than fixed fields. ``time`` / ``step`` / ``dt`` are
ordinary entries in ``managed``.
"""

_schema_version: int = 1
managed: dict = field(default_factory=dict)


class ModelTracker(uw_object):
"""One per :class:`underworld3.Model`, auto-registered as a
:class:`~underworld3.checkpoint.Snapshottable` state-bearer so the
snapshot machinery captures and restores it with no extra
plumbing. See the module docstring for the user-facing contract.
"""

def __init__(self):
# _managed must exist before any public attribute assignment
# routes through __setattr__.
object.__setattr__(
self, "_managed", {"time": 0.0, "step": 0, "dt": None}
)
super().__init__() # uw_object: sets self._uw_id (underscore)

# --- attribute routing: public -> managed, underscore -> real ---

def __setattr__(self, name, value):
if name.startswith("_"):
object.__setattr__(self, name, value)
return
# Respect class-level data descriptors — notably the `state`
# property. Without this guard, `tracker.state = ...` (done by
# the snapshot machinery on restore) would be captured as a
# managed quantity instead of invoking the property setter,
# and restore would silently no-op. `state` is therefore a
# reserved name and cannot be a user-managed quantity.
cls_attr = getattr(type(self), name, None)
if hasattr(cls_attr, "__set__") or hasattr(cls_attr, "__get__"):
object.__setattr__(self, name, value)
return
self._managed[name] = value

def __getattr__(self, name):
# __getattr__ only fires when normal lookup fails, so it never
# shadows real attributes or class properties (state,
# instance_number, ...).
if name.startswith("_"):
raise AttributeError(name)
managed = object.__getattribute__(self, "_managed")
if name in managed:
return managed[name]
raise AttributeError(
f"ModelTracker has no managed quantity {name!r}; assign "
f"model.tracker.{name} = ... to create it"
)

def __delattr__(self, name):
if name.startswith("_"):
object.__delattr__(self, name)
elif name in self._managed:
del self._managed[name]
else:
raise AttributeError(name)

# --- convenience ---

def __contains__(self, name):
return name in self._managed

def keys(self):
"""Names of all managed quantities (including time/step/dt)."""
return list(self._managed.keys())

def __repr__(self):
items = ", ".join(f"{k}={v!r}" for k, v in self._managed.items())
return f"ModelTracker({items})"

# --- Snapshottable contract ---

@property
def state(self) -> TrackerState:
# Deep-copy on read so a held .state is isolated from later
# mutation even if not routed through the snapshot machinery.
return TrackerState(managed=copy.deepcopy(self._managed))

@state.setter
def state(self, s: TrackerState) -> None:
if s._schema_version != TrackerState._schema_version:
raise ValueError(
f"TrackerState schema version mismatch: snapshot "
f"{s._schema_version} vs current "
f"{TrackerState._schema_version}"
)
# Replace wholesale: restore returns to exactly the captured
# point, so a quantity added *after* the snapshot is dropped
# on restore (git-stash semantics).
object.__setattr__(self, "_managed", copy.deepcopy(s.managed))
Loading
Loading