v0.3.9: Stream predictions to parquet instead of buffering in memory (#172)

@github-actions github-actions released this 01 May 14:48
· 6 commits to main since this release
8dbe5da
* feat: stream EQ_predict outputs to disk per-batch (#170)

EQ_predict buffered every batch's predict_step dict in memory before writing
anything to disk, so the predict job's CPU memory footprint scaled with the full
cohort (probabilities plus a per-batch query_embed costing hidden_size * 4 B per
row). For very large evaluation task spaces or cohorts, the job runs out of memory.
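
As a rough back-of-the-envelope illustration of that scaling, the sketch below estimates the buffered size of query_embed alone. The cohort size and hidden_size are hypothetical values chosen for illustration, not figures from this project, and `buffered_embed_bytes` is a made-up helper name.

```python
BYTES_PER_FLOAT32 = 4  # the "4 B" in hidden_size * 4 B/row


def buffered_embed_bytes(n_rows: int, hidden_size: int) -> int:
    """Bytes needed to hold every row's query_embed in memory at once."""
    return n_rows * hidden_size * BYTES_PER_FLOAT32


# Hypothetical cohort: 10 million rows with hidden_size=768.
total = buffered_embed_bytes(n_rows=10_000_000, hidden_size=768)
print(f"{total / 2**30:.1f} GiB")  # 28.6 GiB
```

With per-batch streaming, only one batch of embeddings needs to be resident at a time, so the peak no longer grows with the cohort.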

Replace the gather/hstack/write flow with a Lightning BasePredictionWriter
callback (write_interval='batch') paired with trainer.predict(..., return_predictions=False).
The callback opens a pyarrow.parquet.ParquetWriter at predict-start and appends
one row group per batch in write_on_batch_end, slicing the dataset's schema_df
with a running offset to recover identifiers. main() closes the writer in a
try/finally, so a Python-level exception still flushes the parquet footer and
leaves a valid PredictionSchema parquet covering the batches that finished.
Hard kills (SIGKILL) remain unrecoverable by design.

Output format and CLI surface are unchanged; existing tests/test_predict_cli.py
passes without modification.

* chore: ruff-format trainer.predict call onto one line

Fixes the code-quality CI check (ruff-format reformatted the multi-line
trainer.predict(...) call to a single line since it fits within the project's
line-length limit).

* refuse multi-process predict in EQ_predict

Codex review on PR #172 flagged that _StreamingPredictionWriter opens a single
output_parquet and tracks a process-local schema_df offset. Under multi-process
predict (DDP / multi-device / multi-node), every rank would open the same path
and slice schema_df from offset 0, corrupting the parquet and mispairing
probabilities with identifiers.

Add a startup guard that raises if trainer.world_size > 1, with an error message
pointing at the resolved_config.yaml knobs to override. Predict is single-pass
by design (per the module docstring); supporting sharded multi-rank writes is a
larger feature and out of scope here.
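
A minimal sketch of such a guard is shown below. It assumes only that the trainer exposes a world_size attribute, as the text above states; the exception type, the error wording, and the SimpleNamespace stand-in for a real trainer are illustrative assumptions rather than the project's actual code.

```python
from types import SimpleNamespace


def check_single_process_trainer(trainer) -> None:
    """Raise if the trainer would run predict on more than one process."""
    world_size = trainer.world_size
    if world_size > 1:
        # Illustrative message; the real one names specific config knobs.
        raise RuntimeError(
            f"Streaming predict supports a single process, got world_size={world_size}. "
            "Override the trainer settings in resolved_config.yaml so only one process runs."
        )


check_single_process_trainer(SimpleNamespace(world_size=1))  # passes silently

message = None
try:
    check_single_process_trainer(SimpleNamespace(world_size=2))
except RuntimeError as err:
    message = str(err)
    print(message)
```

Failing at startup is cheaper than detecting a corrupted file afterwards, since no rank has opened the output path yet.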

* test: writer unit tests + multi-process guard test (#172 review)

Address Codex's review on PR #172, which asked for writer-level tests covering
the gaps the subprocess CLI tests can't reach cheaply, plus a guard test for the
distributed/multi-device refusal path I added.

- Extract the multi-process guard into _check_single_process_trainer(trainer) so
  it's directly unit-testable without going through main().
- New tests/test_predict_logic.py with 5 unit tests:
  - single batch: identifiers + boolean_value reconstruction + probs land 1:1
  - multi-batch + partial last batch: row order preserved across the writer's
    offset boundary (the gap the CLI tests miss because their fixture fits in
    one batch under the demo model's batch_size)
  - empty cohort: zero-row schema_df produces a valid PredictionSchema parquet
  - guard accepts devices=1 / strategy=auto
  - guard raises on devices=2 / strategy=ddp (world_size > 1)
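
The multi-batch case listed above hinges on the running offset staying aligned across batch boundaries, including a short final batch. A dependency-free sketch of that property follows; `pair_by_offset`, `chunk`, and the sample data are hypothetical stand-ins, not the contents of tests/test_predict_logic.py.

```python
def pair_by_offset(identifiers, batches):
    """Pair each batch's values with identifiers using a running offset."""
    offset = 0
    rows = []
    for batch in batches:
        ids = identifiers[offset : offset + len(batch)]
        rows.extend(zip(ids, batch))
        offset += len(batch)
    return rows


def chunk(values, batch_size):
    """Split values into batches; the last one may be shorter."""
    return [values[i : i + batch_size] for i in range(0, len(values), batch_size)]


identifiers = ["a", "b", "c", "d", "e"]
probabilities = [0.1, 0.2, 0.3, 0.4, 0.5]
batches = chunk(probabilities, batch_size=2)  # batch sizes 2, 2, 1

rows = pair_by_offset(identifiers, batches)
assert rows == list(zip(identifiers, probabilities))  # order survives the boundaries
assert pair_by_offset([], []) == []  # empty cohort yields no rows
```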

---------

Co-authored-by: gkondas <gkondas@umich.edu>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>