Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions psi/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,29 @@ def maybe_reload(self) -> bool:
is ~1μs on modern kernels; actual reload happens only when setup
has finished a write.

If the backing file has vanished after we previously loaded it, drop
the in-memory entries instead of silently serving stale values — a
wiped ``cache.enc`` should not produce a cache that pretends entries
still exist. The next provider lookup repopulates as needed.

Returns:
True if a reload happened, False otherwise.
True if entries were reloaded or cleared, False otherwise.
"""
try:
current_mtime = self._path.stat().st_mtime_ns
except FileNotFoundError:
return False
if self._mtime_ns == 0:
return False
with self._lock:
logger.warning(
"Cache file {} disappeared after previous load; clearing "
"{} in-memory entries to avoid serving stale values.",
self._path,
len(self._entries),
)
self._entries = {}
self._mtime_ns = 0
return True
if current_mtime == self._mtime_ns:
return False
self.load()
Expand Down
49 changes: 49 additions & 0 deletions psi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,14 @@ def cache_init(
"Default: <config_dir>/cache.key.",
),
] = None,
force: Annotated[
bool,
typer.Option(
"--force",
help="Overwrite an existing cache file. The previous file is "
"rotated to '<path>.bak-<UTC timestamp>' before the new one is written.",
),
] = False,
config: ConfigOption = None,
) -> None:
"""Provision the cache encryption key and write an empty cache file."""
Expand All @@ -266,6 +274,8 @@ def cache_init(
cache_path = settings.cache.resolve_path(settings.state_dir)
cache_path.parent.mkdir(parents=True, exist_ok=True)

bak_path = _guard_existing_cache(cache_path, force=force)

if backend == "tpm":
raw_key = os.urandom(32)
target_key_path = key_path or (settings.config_dir / "cache.key")
Expand Down Expand Up @@ -298,6 +308,8 @@ def cache_init(

cache = Cache(cache_path, TpmBackend(key=raw_key))
cache.save()
if bak_path is not None:
console.print(f"Previous cache rotated to [bold]{bak_path}[/bold]", highlight=False)
console.print(
f"Sealed TPM key → [bold]{target_key_path}[/bold]\n"
f"Empty cache → [bold]{cache_path}[/bold]\n"
Expand All @@ -321,13 +333,50 @@ def cache_init(
cache.save()
finally:
hsm_backend.close()
if bak_path is not None:
console.print(f"Previous cache rotated to [bold]{bak_path}[/bold]", highlight=False)
console.print(
f"Empty cache → [bold]{cache_path}[/bold]\n"
"Cache will be unsealed via PKCS#11 at 'psi serve' startup.",
highlight=False,
)


def _guard_existing_cache(cache_path: Path, *, force: bool) -> Path | None:
"""Refuse to clobber an existing cache file unless ``force`` is set.

When ``force`` is set and the file exists, rotate it to
``<name>.bak-<UTC timestamp>`` so a misuse can be undone. The header
is not decrypted — wrapping in a backup is cheap and makes the
"I just nuked the populated cache" path recoverable.

Returns:
The path the previous file was rotated to, or ``None`` if no file
existed.

Raises:
ConfigError: If the file exists and ``force`` is False.
"""
from datetime import UTC, datetime

from psi.errors import ConfigError

if not cache_path.exists():
return None
if not force:
msg = (
f"Cache file already exists at {cache_path}. Re-running 'psi cache "
"init' would replace it with an empty cache and the existing "
"entries would be unrecoverable. Pass --force to overwrite; the "
"previous file is rotated to a '.bak-<timestamp>' sibling first."
)
raise ConfigError(msg)
timestamp = datetime.now(tz=UTC).strftime("%Y%m%dT%H%M%SZ")
bak = cache_path.with_name(f"{cache_path.name}.bak-{timestamp}")
cache_path.rename(bak)
return bak


@cache_app.command(name="refresh")
def cache_refresh(config: ConfigOption = None) -> None:
"""Re-run setup to refresh every cached secret from providers."""
Expand Down
9 changes: 9 additions & 0 deletions psi/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,12 @@ class SecretNotFoundError(PsiError):

class DriftDetectedError(PsiError):
"""Podman secret state diverged from the fetch — drop-ins are incomplete."""


class OrphanedSecretsError(PsiError):
    """Raised when Podman shell secrets have no backing mapping file.

    Lookups for such secrets return 404, so any container depending on
    them fails to start. Not the same as drift: drift is a secret present
    in Podman but missing from the fetch; an orphan is a secret present in
    Podman with no mapping on disk.
    """
74 changes: 69 additions & 5 deletions psi/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import httpx
from loguru import logger

from psi.errors import DriftDetectedError, ProviderError
from psi.errors import DriftDetectedError, OrphanedSecretsError, ProviderError
from psi.systemd import daemon_reload

if TYPE_CHECKING:
Expand Down Expand Up @@ -42,10 +42,15 @@ def run_setup(
provider: If set, only process workloads using this provider.

Raises:
DriftDetectedError: when one or more Podman secrets under ``<workload>--*``
are missing from the current fetch. Drop-ins are still written
and systemd is still reloaded — the error fires at the end so
the caller (and the setup systemd unit) sees a non-zero exit.
OrphanedSecretsError: when one or more Podman shell secrets have no
backing mapping file in ``state_dir``. Lookups for these will
return 404 and consuming containers will fail to start. Takes
precedence over drift since the failure mode is harder.
DriftDetectedError: when one or more Podman secrets under
``<workload>--*`` are missing from the current fetch. Drop-ins
are still written and systemd is still reloaded — the error
fires at the end so the caller (and the setup systemd unit)
sees a non-zero exit.
"""
settings.state_dir.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -80,8 +85,34 @@ def run_setup(

logger.info("Reloading systemd...")
daemon_reload(settings.scope)

orphans = _check_orphans(settings)
for orphan in orphans:
logger.warning(
"Orphan: Podman secret '{}' has no mapping file in {} — lookups "
"will return 404 and any container that uses this secret will "
"fail to start. Re-create the mapping (e.g. 'psi nitrokeyhsm "
"store {}') or remove the stale entry with 'podman secret rm {}'.",
orphan,
settings.state_dir,
orphan,
orphan,
)

logger.info("Setup complete.")

if orphans:
msg = (
f"Orphaned secrets detected: {len(orphans)} Podman shell "
"secret(s) have no backing mapping file. Lookups will return "
"404 and consuming containers will fail to start. Re-create "
"the mappings or remove the stale Podman secrets. Run 'psi "
"setup --dry-run' for per-secret details."
)
if drift:
msg += f" Additionally, {len(drift)} drift entry/entries — see warnings above."
raise OrphanedSecretsError(msg)

if drift:
msg = (
f"Drift detected: {len(drift)} Podman secret(s) not present in "
Expand Down Expand Up @@ -447,6 +478,39 @@ def _workload_podman_names(workload_name: str, secrets: list[dict]) -> set[str]:
}


def _check_orphans(settings: PsiSettings) -> list[str]:
    """List Podman shell secrets whose mapping file is absent from ``state_dir``.

    A ``shell``-driver Podman secret without a matching
    ``state_dir/<SECRET_ID>`` file means ``psi serve`` answers 404 for any
    lookup and the consuming container cannot start — yet ``setup`` itself
    can still succeed, since it only re-registers Infisical-provider
    secrets. This check surfaces that otherwise-silent regression before
    something restarts.

    Returns:
        Sorted names of orphaned Podman secrets. An unreachable Podman API
        yields an empty list; the primary fetch-and-register path would
        already have failed loudly in that case.
    """
    try:
        shell_secrets = _list_podman_shell_secrets()
    except httpx.HTTPError as e:
        logger.warning("Cannot list Podman secrets to check for orphans: {}", e)
        return []
    missing: list[str] = []
    for entry in shell_secrets:
        name = entry.get("Spec", {}).get("Name", "")
        secret_id = entry.get("ID", "")
        # Entries lacking a name or ID are unverifiable; skip them, exactly
        # as the mapping-file check cannot apply to them.
        if name and secret_id and not (settings.state_dir / secret_id).exists():
            missing.append(name)
    return sorted(missing)


def _check_workload_drift(
workload_name: str,
merged: dict[str, str],
Expand Down
46 changes: 46 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,17 @@ def test_missing_file_returns_false_no_crash(
cache = Cache(tmp_path / "does-not-exist.enc", backend)
assert cache.maybe_reload() is False

def test_clears_entries_when_file_vanishes_after_load(self, cache: Cache) -> None:
    """Deleting cache.enc must drop in-memory entries rather than serve stale data."""
    key, value = "k", b"v"
    cache.set(key, value)
    cache.save()
    assert cache.get(key) == value

    # Simulate an operator wiping the backing file out from under us.
    cache.path.unlink()

    reloaded = cache.maybe_reload()
    assert reloaded is True
    assert cache.get(key) is None


class TestLegacyV1PayloadDiscarded:
def test_v1_payload_is_treated_as_empty_with_fresh_hmac_key(
Expand All @@ -293,3 +304,38 @@ def test_v1_payload_is_treated_as_empty_with_fresh_hmac_key(
fresh.load()
assert len(fresh) == 0
assert fresh.get("abc123hex") is None


class TestGuardExistingCache:
    """A populated cache file must never be silently clobbered by 'psi cache init'."""

    def test_no_existing_file_returns_none(self, tmp_path: Path) -> None:
        from psi.cli import _guard_existing_cache

        assert _guard_existing_cache(tmp_path / "cache.enc", force=False) is None

    def test_existing_file_without_force_raises(self, tmp_path: Path) -> None:
        from psi.cli import _guard_existing_cache
        from psi.errors import ConfigError

        target = tmp_path / "cache.enc"
        target.write_bytes(b"existing payload")

        with pytest.raises(ConfigError, match="already exists"):
            _guard_existing_cache(target, force=False)
        # Refusing must be non-destructive.
        assert target.exists(), "guard must not delete the existing file when refusing"

    def test_existing_file_with_force_rotates_to_bak(self, tmp_path: Path) -> None:
        from psi.cli import _guard_existing_cache

        target = tmp_path / "cache.enc"
        payload = b"existing payload"
        target.write_bytes(payload)

        rotated = _guard_existing_cache(target, force=True)
        assert rotated is not None
        assert rotated.exists()
        assert rotated.read_bytes() == payload
        assert not target.exists()
        assert rotated.name.startswith("cache.enc.bak-")
Loading
Loading