This commit was created on GitHub.com and signed with GitHub’s verified signature.
* feat: stream EQ_predict outputs to disk per-batch (#170)
EQ_predict was buffering every batch's predict_step dict in memory before writing
anything to disk, so the predict job's CPU memory footprint scaled with the full
cohort (probabilities + per-batch query_embed at hidden_size * 4 bytes per row). For
large evaluation task spaces or cohorts, the job runs out of memory.
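
For a rough sense of scale, the buffered query_embed cost can be estimated as rows times hidden_size times 4 bytes (float32). The function name and the cohort numbers below are hypothetical, chosen only for illustration; they are not measurements from this project.

```python
def buffered_embed_bytes(n_rows: int, hidden_size: int, bytes_per_value: int = 4) -> int:
    """Bytes needed to hold one float32 embedding per row in memory."""
    return n_rows * hidden_size * bytes_per_value


# Hypothetical cohort: 10 million rows with hidden_size=768 (assumed values).
total = buffered_embed_bytes(10_000_000, 768)
print(f"{total / 2**30:.1f} GiB held before anything is written")
```

Streaming per batch keeps only one batch's worth of this in memory at a time.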
Replace the gather/hstack/write flow with a Lightning BasePredictionWriter
callback (write_interval='batch') paired with trainer.predict(..., return_predictions=False).
The callback opens a pyarrow.parquet.ParquetWriter at predict-start, appends one
row group per batch in write_on_batch_end (slicing the dataset's schema_df with a
running offset for identifiers), and is closed by main()'s try/finally so a
Python-level exception still flushes the parquet footer and leaves a valid
PredictionSchema parquet covering the batches that finished. Hard kills (SIGKILL)
remain unrecoverable by design.
Output format and CLI surface are unchanged; existing tests/test_predict_cli.py
passes without modification.
* chore: ruff-format trainer.predict call onto one line
Fixes the code-quality CI check (ruff-format reformatted the multi-line
trainer.predict(...) call to a single line since it fits within the project's
line-length limit).
* refuse multi-process predict in EQ_predict
Codex review on PR #172 flagged that _StreamingPredictionWriter opens a single
output_parquet and tracks a process-local schema_df offset, so multi-process
predict (DDP / multi-device / multi-node) would have every rank open the same
path and slice schema_df from offset 0 — corrupting the parquet and mispairing
probabilities with identifiers.
Add a startup guard that raises if trainer.world_size > 1, with an error message
pointing at the resolved_config.yaml knobs to override. Predict is single-pass
by design (per the module docstring); supporting sharded multi-rank writes is a
larger feature and out of scope here.
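
A sketch of what such a guard might look like. The function name, the exception type, and the message wording are assumptions for illustration; only the trainer.world_size > 1 condition comes from the description above. SimpleNamespace objects stand in for a real Trainer.

```python
from types import SimpleNamespace


def check_single_process(trainer):
    """Raise if the trainer would run predict across more than one process."""
    world_size = getattr(trainer, "world_size", 1)
    if world_size > 1:
        # RuntimeError is an assumed choice; the real guard may raise another type.
        raise RuntimeError(
            f"Streaming predict supports a single process, got world_size={world_size}. "
            "Configure one device and a non-distributed strategy."
        )


# Stand-ins for a Trainer; only the world_size attribute is consulted.
check_single_process(SimpleNamespace(world_size=1))  # passes silently

try:
    check_single_process(SimpleNamespace(world_size=2))
except RuntimeError as err:
    print(f"refused: {err}")
```

Failing at startup, before the writer opens the output path, means no rank ever touches the parquet file in an unsupported configuration.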
* test: writer unit tests + multi-process guard test (#172 review)
Address Codex's review on PR #172, which asked for writer-level tests covering
the gaps the subprocess CLI tests can't reach cheaply, plus a guard test for the
distributed/multi-device refusal path I added.
- Extract the multi-process guard into _check_single_process_trainer(trainer) so
  it's directly unit-testable without going through main().
- New tests/test_predict_logic.py with 5 unit tests:
  - single batch: identifiers + boolean_value reconstruction + probs land 1:1
  - multi-batch + partial last batch: row order preserved across the writer's
    offset boundary (the gap the CLI tests miss because their fixture fits in
    one batch under the demo model's batch_size)
  - empty cohort: zero-row schema_df produces a valid PredictionSchema parquet
  - guard accepts devices=1 / strategy=auto
  - guard raises on devices=2 / strategy=ddp (world_size > 1)
---------
Co-authored-by: gkondas <gkondas@umich.edu>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>