From 5e18f6e0f944365dc2e60d12cbbe7694ef016c1e Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Mon, 21 Apr 2025 16:57:33 +0200 Subject: [PATCH 1/7] Add solver to callback call --- src/cellflow/training/_callbacks.py | 23 ++++++++++++++++++++--- src/cellflow/training/_trainer.py | 4 ++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py index fda372fc..032f5a9d 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -14,6 +14,7 @@ compute_scalar_mmd, compute_sinkhorn_div, ) +from cellflow.solvers import _genot, _otfm __all__ = [ "BaseCallback", @@ -103,6 +104,7 @@ def on_log_iteration( self, validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], + solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at each validation/log iteration to compute metrics @@ -112,6 +114,8 @@ def on_log_iteration( Validation data in nested dictionary format with same keys as ``predicted_data`` predicted_data Predicted data in nested dictionary format with same keys as ``validation_data`` + solver + OTFM/GENOT solver with a conditional velocity field. Returns ------- @@ -176,6 +180,7 @@ def on_log_iteration( self, validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], + _, ) -> dict[str, float]: """Called at each validation/log iteration to compute metrics @@ -248,6 +253,7 @@ def on_log_iteration( self, validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], + _, ) -> dict[str, float]: """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction @@ -305,6 +311,7 @@ def on_log_iteration( self, validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], + _, ) -> dict[str, float]: """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction @@ -443,6 +450,7 @@ def on_log_iteration( self, valid_data: dict[str, dict[str, ArrayLike]], pred_data: dict[str, dict[str, ArrayLike]], + solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, Any]: """Called at each validation/log iteration to run callbacks. First computes metrics with computation callbacks and then logs data with logging callbacks. @@ -452,6 +460,8 @@ def on_log_iteration( Validation data in nested dictionary format with same keys as ``pred_data`` pred_data Predicted data in nested dictionary format with same keys as ``valid_data`` + solver + OTFM/GENOT solver with a conditional velocity field. Returns ------- @@ -460,7 +470,7 @@ def on_log_iteration( dict_to_log: dict[str, Any] = {} for callback in self.computation_callbacks: - results = callback.on_log_iteration(valid_data, pred_data) + results = callback.on_log_iteration(valid_data, pred_data, solver) dict_to_log.update(results) for callback in self.logging_callbacks: @@ -468,7 +478,12 @@ def on_log_iteration( return dict_to_log - def on_train_end(self, valid_data, pred_data) -> dict[str, Any]: + def on_train_end( + self, + valid_data: dict[str, dict[str, ArrayLike]], + pred_data: dict[str, dict[str, ArrayLike]], + solver: _otfm.OTFlowMatching | _genot.GENOT, + ) -> dict[str, Any]: """Called at the end of training to run callbacks. First computes metrics with computation callbacks and then logs data with logging callbacks. Parameters @@ -477,6 +492,8 @@ def on_train_end(self, valid_data, pred_data) -> dict[str, Any]: Validation data in nested dictionary format with same keys as ``pred_data`` pred_data: dict Predicted data in nested dictionary format with same keys as ``valid_data`` + solver + OTFM/GENOT solver with a conditional velocity field. Returns ------- @@ -485,7 +502,7 @@ def on_train_end(self, valid_data, pred_data) -> dict[str, Any]: dict_to_log: dict[str, Any] = {} for callback in self.computation_callbacks: - results = callback.on_log_iteration(valid_data, pred_data) + results = callback.on_log_iteration(valid_data, pred_data, solver) dict_to_log.update(results) for callback in self.logging_callbacks: diff --git a/src/cellflow/training/_trainer.py b/src/cellflow/training/_trainer.py index e898393c..68774e69 100644 --- a/src/cellflow/training/_trainer.py +++ b/src/cellflow/training/_trainer.py @@ -122,7 +122,7 @@ def train( valid_true_data, valid_pred_data = self._validation_step(valid_loaders, mode="on_log_iteration") # Run callbacks - metrics = crun.on_log_iteration(valid_true_data, valid_pred_data) # type: ignore[arg-type] + metrics = crun.on_log_iteration(valid_true_data, valid_pred_data, self.solver) # type: ignore[arg-type] self._update_logs(metrics) # Update progress bar @@ -133,7 +133,7 @@ def train( if num_iterations > 0: valid_true_data, valid_pred_data = self._validation_step(valid_loaders, mode="on_train_end") - metrics = crun.on_train_end(valid_true_data, valid_pred_data) + metrics = crun.on_train_end(valid_true_data, valid_pred_data, self.solver) self._update_logs(metrics) self.solver.is_trained = True From e65003157d44e75855903ef0b93e614cb9cfec1c Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Thu, 24 Apr 2025 18:09:57 +0200 Subject: [PATCH 2/7] Add typing and solver for on_train_end --- src/cellflow/training/_callbacks.py | 26 +++++++++++++++++--------- src/cellflow/training/_trainer.py | 3 ++- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py index 032f5a9d..e8abe219 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -115,7 +115,8 @@ def on_log_iteration( predicted_data Predicted data in nested dictionary format with same keys as ``validation_data`` solver - OTFM/GENOT solver with a conditional velocity field. + :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + solver with a conditional velocity field. Returns ------- @@ -128,6 +129,7 @@ def on_train_end( self, validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], + solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at the end of training to compute metrics @@ -137,6 +139,9 @@ def on_train_end( Validation data in nested dictionary format with same keys as ``predicted_data`` predicted_data Predicted data in nested dictionary format with same keys as ``validation_data`` + solver + :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + solver with a conditional velocity field. Returns ------- @@ -180,7 +185,7 @@ def on_log_iteration( self, validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], - _, + solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at each validation/log iteration to compute metrics @@ -205,6 +210,7 @@ def on_train_end( self, validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], + solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at the end of training to compute metrics @@ -215,7 +221,7 @@ def on_train_end( predicted_data Predicted data in nested dictionary format with same keys as ``validation_data`` """ - return self.on_log_iteration(validation_data, predicted_data) + return self.on_log_iteration(validation_data, predicted_data, solver) class PCADecodedMetrics(Metrics): @@ -253,7 +259,7 @@ def on_log_iteration( self, validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], - _, + _: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction @@ -311,7 +317,7 @@ def on_log_iteration( self, validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], - _, + _: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction @@ -461,7 +467,8 @@ def on_log_iteration( pred_data Predicted data in nested dictionary format with same keys as ``valid_data`` solver - OTFM/GENOT solver with a conditional velocity field. + :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + solver with a conditional velocity field. Returns ------- @@ -493,7 +500,8 @@ def on_train_end( pred_data: dict Predicted data in nested dictionary format with same keys as ``valid_data`` solver - OTFM/GENOT solver with a conditional velocity field. + :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + solver with a conditional velocity field. Returns ------- @@ -502,10 +510,10 @@ def on_train_end( dict_to_log: dict[str, Any] = {} for callback in self.computation_callbacks: - results = callback.on_log_iteration(valid_data, pred_data, solver) + results = callback.on_train_end(valid_data, pred_data, solver) dict_to_log.update(results) for callback in self.logging_callbacks: - callback.on_log_iteration(dict_to_log) # type: ignore[call-arg] + callback.on_train_end(dict_to_log) # type: ignore[call-arg] return dict_to_log diff --git a/src/cellflow/training/_trainer.py b/src/cellflow/training/_trainer.py index 68774e69..526bb157 100644 --- a/src/cellflow/training/_trainer.py +++ b/src/cellflow/training/_trainer.py @@ -19,7 +19,8 @@ class CellFlowTrainer: dataloader Data sampler. solver - OTFM/GENOT solver with a conditional velocity field. + :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + solver with a conditional velocity field. seed Random seed for subsampling validation data. From 5327a05a7b2b1d03c2a672e8e35caf4998cf3b6d Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Thu, 1 May 2025 12:37:09 +0200 Subject: [PATCH 3/7] Added test for custom callbacks --- tests/trainer/test_trainer.py | 60 +++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4823ae41..42db39fc 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -7,6 +7,7 @@ import cellflow from cellflow.solvers import _otfm +from cellflow.training import CellFlowTrainer, ComputationCallback, Metrics from cellflow.utils import match_linear x_test = jnp.ones((10, 5)) * 10 @@ -15,6 +16,24 @@ vf_rng = jax.random.PRNGKey(111) +class CustomCallback(ComputationCallback): + def __init__(self): + super().__init__() + + def on_train_begin(self): + pass + + def on_log_iteration(self, validation_data, predicted_data, solver): + pred_1 = solver.predict(x_test, cond) + pred_2 = solver.predict(x_test, cond) + return {"diff": np.sum(abs(pred_1 - pred_2), axis=0)} + + def on_train_end(self, validation_data, predicted_data, solver): + pred_1 = solver.predict(x_test, cond) + pred_2 = solver.predict(x_test, cond) + return {"diff": np.sum(abs(pred_1 - pred_2), axis=0)} + + class TestTrainer: @pytest.mark.parametrize("valid_freq", [10, 1]) def test_cellflow_trainer(self, dataloader, valid_freq): @@ -35,7 +54,7 @@ def test_cellflow_trainer(self, dataloader, valid_freq): rng=vf_rng, ) - trainer = cellflow.training.CellFlowTrainer(solver=model) + trainer = CellFlowTrainer(solver=model) trainer.train( dataloader=dataloader, num_iterations=2, @@ -71,9 +90,9 @@ def test_cellflow_trainer_with_callback(self, dataloader, valid_loader, use_vali ) metric_to_compute = "e_distance" - metrics_callback = cellflow.training.Metrics(metrics=[metric_to_compute]) + metrics_callback = Metrics(metrics=[metric_to_compute]) - trainer = cellflow.training.CellFlowTrainer(solver=model) + trainer = CellFlowTrainer(solver=model) trainer.train( dataloader=dataloader, valid_loaders=valid_loader if use_validdata else None, @@ -95,3 +114,38 @@ def test_cellflow_trainer_with_callback(self, dataloader, valid_loader, use_vali assert out[0].shape == (1, 12) assert isinstance(out[1], np.ndarray) assert out[1].shape == (1, 12) + + def test_cellflow_trainer_with_custom_callback(self, dataloader, valid_loader): + opt = optax.adam(1e-3) + vf = cellflow.networks.ConditionalVelocityField( + condition_mode="stochastic", + output_dim=5, + max_combination_length=2, + condition_embedding_dim=12, + hidden_dims=(32, 32), + decoder_dims=(32, 32), + ) + solver = _otfm.OTFlowMatching( + vf=vf, + match_fn=match_linear, + flow=dynamics.ConstantNoiseFlow(0.0), + optimizer=opt, + conditions=cond, + rng=vf_rng, + ) + + custom_callback = CustomCallback() + + trainer = CellFlowTrainer(solver=solver) + trainer.train( + dataloader=dataloader, + valid_loaders=valid_loader, + num_iterations=2, + valid_freq=1, + callbacks=[custom_callback], + ) + + logs = trainer.training_logs + assert "diff" in logs + assert isinstance(logs["diff"][0], np.ndarray) + assert logs["diff"][0].shape == (5,) From 54e23862897da5166d405a11719d5554c199e2af Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Thu, 1 May 2025 16:16:21 +0200 Subject: [PATCH 4/7] Rename flow --- tests/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index e2332145..9d52fb41 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -128,7 +128,7 @@ def test_cellflow_trainer_with_custom_callback(self, dataloader, valid_loader): solver = _otfm.OTFlowMatching( vf=vf, match_fn=match_linear, - flow=dynamics.ConstantNoiseFlow(0.0), + probability_path=dynamics.ConstantNoiseFlow(0.0), optimizer=opt, conditions=cond, rng=vf_rng, From facf52d1c6b9e7d65761c65316df0513f2bf07f8 Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Mon, 12 May 2025 13:23:52 +0200 Subject: [PATCH 5/7] Added source data parameter + stochastic test --- src/cellflow/training/_callbacks.py | 26 +++++++++++++++++++++----- src/cellflow/training/_trainer.py | 12 +++++++----- tests/trainer/test_trainer.py | 22 ++++++++++++++-------- 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py index e8abe219..fa9b9921 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -102,6 +102,7 @@ def on_train_begin(self) -> Any: @abc.abstractmethod def on_log_iteration( self, + source_data: dict[str, dict[str, ArrayLike]], validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, @@ -110,6 +111,8 @@ def on_log_iteration( Parameters ---------- + source_data + Source data in nested dictionary format with same keys as ``validation_data`` validation_data Validation data in nested dictionary format with same keys as ``predicted_data`` predicted_data @@ -127,6 +130,7 @@ def on_log_iteration( @abc.abstractmethod def on_train_end( self, + source_data: dict[str, dict[str, ArrayLike]], validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, @@ -135,6 +139,8 @@ def on_train_end( Parameters ---------- + source_data + Source data in nested dictionary format with same keys as ``validation_data`` validation_data Validation data in nested dictionary format with same keys as ``predicted_data`` predicted_data @@ -183,6 +189,7 @@ def on_train_begin(self, *args: Any, **kwargs: Any) -> Any: def on_log_iteration( self, + source_data: dict[str, dict[str, ArrayLike]], validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, @@ -208,6 +215,7 @@ def on_log_iteration( def on_train_end( self, + source_data: dict[str, dict[str, ArrayLike]], validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, @@ -221,7 +229,7 @@ def on_train_end( predicted_data Predicted data in nested dictionary format with same keys as ``validation_data`` """ - return self.on_log_iteration(validation_data, predicted_data, solver) + return self.on_log_iteration(source_data, validation_data, predicted_data, solver) class PCADecodedMetrics(Metrics): @@ -257,9 +265,10 @@ def __init__( def on_log_iteration( self, + _source_data: dict[str, dict[str, ArrayLike]], validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], - _: _otfm.OTFlowMatching | _genot.GENOT, + _solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction @@ -315,9 +324,10 @@ def __init__( def on_log_iteration( self, + _source_data: dict[str, dict[str, ArrayLike]], validation_data: dict[str, dict[str, ArrayLike]], predicted_data: dict[str, dict[str, ArrayLike]], - _: _otfm.OTFlowMatching | _genot.GENOT, + _solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction @@ -454,6 +464,7 @@ def on_train_begin(self) -> Any: def on_log_iteration( self, + source_data: dict[str, dict[str, ArrayLike]], valid_data: dict[str, dict[str, ArrayLike]], pred_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, @@ -462,6 +473,8 @@ def on_log_iteration( Parameters ---------- + source_data + Source data in nested dictionary format with same keys as ``valid_data`` valid_data Validation data in nested dictionary format with same keys as ``pred_data`` pred_data @@ -477,7 +490,7 @@ def on_log_iteration( dict_to_log: dict[str, Any] = {} for callback in self.computation_callbacks: - results = callback.on_log_iteration(valid_data, pred_data, solver) + results = callback.on_log_iteration(source_data, valid_data, pred_data, solver) dict_to_log.update(results) for callback in self.logging_callbacks: @@ -487,6 +500,7 @@ def on_log_iteration( def on_train_end( self, + source_data: dict[str, dict[str, ArrayLike]], valid_data: dict[str, dict[str, ArrayLike]], pred_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, @@ -495,6 +509,8 @@ def on_train_end( Parameters ---------- + source_data + Source data in nested dictionary format with same keys as ``valid_data`` valid_data: dict Validation data in nested dictionary format with same keys as ``pred_data`` pred_data: dict @@ -510,7 +526,7 @@ def on_train_end( dict_to_log: dict[str, Any] = {} for callback in self.computation_callbacks: - results = callback.on_train_end(valid_data, pred_data, solver) + results = callback.on_train_end(source_data, valid_data, pred_data, solver) dict_to_log.update(results) for callback in self.logging_callbacks: diff --git a/src/cellflow/training/_trainer.py b/src/cellflow/training/_trainer.py index 526bb157..aace5464 100644 --- a/src/cellflow/training/_trainer.py +++ b/src/cellflow/training/_trainer.py @@ -52,6 +52,7 @@ def _validation_step( """Compute predictions for validation data.""" # TODO: Sample fixed number of conditions to validate on + valid_source_data: dict[str, dict[str, ArrayLike]] = {} valid_pred_data: dict[str, dict[str, ArrayLike]] = {} valid_true_data: dict[str, dict[str, ArrayLike]] = {} for val_key, vdl in val_data.items(): @@ -59,10 +60,11 @@ def _validation_step( src = batch["source"] condition = batch.get("condition", None) true_tgt = batch["target"] + valid_source_data[val_key] = src valid_pred_data[val_key] = jax.tree.map(self.solver.predict, src, condition) valid_true_data[val_key] = true_tgt - return valid_true_data, valid_pred_data + return valid_source_data, valid_true_data, valid_pred_data def _update_logs(self, logs: dict[str, Any]) -> None: """Update training logs.""" @@ -120,10 +122,10 @@ def train( if ((it - 1) % valid_freq == 0) and (it > 1): # Get predictions from validation data - valid_true_data, valid_pred_data = self._validation_step(valid_loaders, mode="on_log_iteration") + valid_source_data, valid_true_data, valid_pred_data = self._validation_step(valid_loaders, mode="on_log_iteration") # Run callbacks - metrics = crun.on_log_iteration(valid_true_data, valid_pred_data, self.solver) # type: ignore[arg-type] + metrics = crun.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, self.solver) # type: ignore[arg-type] self._update_logs(metrics) # Update progress bar @@ -133,8 +135,8 @@ def train( pbar.set_postfix(postfix_dict) if num_iterations > 0: - valid_true_data, valid_pred_data = self._validation_step(valid_loaders, mode="on_train_end") - metrics = crun.on_train_end(valid_true_data, valid_pred_data, self.solver) + valid_source_data, valid_true_data, valid_pred_data = self._validation_step(valid_loaders, mode="on_train_end") + metrics = crun.on_train_end(valid_source_data, valid_true_data, valid_pred_data, self.solver) self._update_logs(metrics) self.solver.is_trained = True diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 9d52fb41..25792464 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -14,23 +14,28 @@ t_test = jnp.ones((10, 1)) cond = {"pert1": jnp.ones((1, 2, 3))} vf_rng = jax.random.PRNGKey(111) +callback_rng = jax.random.PRNGKey(222) class CustomCallback(ComputationCallback): - def __init__(self): + def __init__(self, rng): super().__init__() + self.rng = rng + self.rng_1, self.rng_2, self.rng = jax.random.split(self.rng, 3) def on_train_begin(self): pass - def on_log_iteration(self, validation_data, predicted_data, solver): - pred_1 = solver.predict(x_test, cond) - pred_2 = solver.predict(x_test, cond) + def on_log_iteration(self, source_data, validation_data, predicted_data, solver): + source_array = source_data["val"]["my_naming_of_pert"] + pred_1 = solver.predict(source_array, cond, rng=self.rng_1) + pred_2 = solver.predict(source_array, cond, rng=self.rng_2) return {"diff": np.sum(abs(pred_1 - pred_2), axis=0)} - def on_train_end(self, validation_data, predicted_data, solver): - pred_1 = solver.predict(x_test, cond) - pred_2 = solver.predict(x_test, cond) + def on_train_end(self, source_data, validation_data, predicted_data, solver): + source_array = source_data["val"]["my_naming_of_pert"] + pred_1 = solver.predict(source_array, cond, rng=self.rng_1) + pred_2 = solver.predict(source_array, cond, rng=self.rng_2) return {"diff": np.sum(abs(pred_1 - pred_2), axis=0)} @@ -134,7 +139,7 @@ def test_cellflow_trainer_with_custom_callback(self, dataloader, valid_loader): rng=vf_rng, ) - custom_callback = CustomCallback() + custom_callback = CustomCallback(rng=callback_rng) trainer = CellFlowTrainer(solver=solver) trainer.train( @@ -149,3 +154,4 @@ def test_cellflow_trainer_with_custom_callback(self, dataloader, valid_loader): assert "diff" in logs assert isinstance(logs["diff"][0], np.ndarray) assert logs["diff"][0].shape == (5,) + assert 0 < np.mean(logs["diff"][0]) < 10 From f1e87043d9153304043d5b6e6e0c186c68fa8720 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 May 2025 11:24:11 +0000 Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/cellflow/training/_trainer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/cellflow/training/_trainer.py b/src/cellflow/training/_trainer.py index aace5464..98a594c1 100644 --- a/src/cellflow/training/_trainer.py +++ b/src/cellflow/training/_trainer.py @@ -122,7 +122,9 @@ def train( if ((it - 1) % valid_freq == 0) and (it > 1): # Get predictions from validation data - valid_source_data, valid_true_data, valid_pred_data = self._validation_step(valid_loaders, mode="on_log_iteration") + valid_source_data, valid_true_data, valid_pred_data = self._validation_step( + valid_loaders, mode="on_log_iteration" + ) # Run callbacks metrics = crun.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, self.solver) # type: ignore[arg-type] @@ -135,7 +137,9 @@ def train( pbar.set_postfix(postfix_dict) if num_iterations > 0: - valid_source_data, valid_true_data, valid_pred_data = self._validation_step(valid_loaders, mode="on_train_end") + valid_source_data, valid_true_data, valid_pred_data = self._validation_step( + valid_loaders, mode="on_train_end" + ) metrics = crun.on_train_end(valid_source_data, valid_true_data, valid_pred_data, self.solver) self._update_logs(metrics) From b6c0bdbdee8d348d79df579dc52cad002126a404 Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Sat, 17 May 2025 17:42:27 +0200 Subject: [PATCH 7/7] fix docs --- src/cellflow/training/_callbacks.py | 186 +++++++++++++++++----------- 1 file changed, 112 insertions(+), 74 deletions(-) diff --git a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py index fa9b9921..e96f9711 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -102,21 +102,21 @@ def on_train_begin(self) -> Any: @abc.abstractmethod def on_log_iteration( self, - source_data: dict[str, dict[str, ArrayLike]], - validation_data: dict[str, dict[str, ArrayLike]], - predicted_data: dict[str, dict[str, ArrayLike]], + valid_source_data: dict[str, dict[str, ArrayLike]], + valid_true_data: dict[str, dict[str, ArrayLike]], + valid_pred_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at each validation/log iteration to compute metrics Parameters ---------- - source_data - Source data in nested dictionary format with same keys as ``validation_data`` - validation_data - Validation data in nested dictionary format with same keys as ``predicted_data`` - predicted_data - Predicted data in nested dictionary format with same keys as ``validation_data`` + valid_source_data + Source data in nested dictionary format with same keys as ``valid_true_data`` + valid_true_data + Validation data in nested dictionary format with same keys as ``valid_pred_data`` + valid_pred_data + Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` solver with a conditional velocity field. @@ -130,21 +130,21 @@ def on_log_iteration( @abc.abstractmethod def on_train_end( self, - source_data: dict[str, dict[str, ArrayLike]], - validation_data: dict[str, dict[str, ArrayLike]], - predicted_data: dict[str, dict[str, ArrayLike]], + valid_source_data: dict[str, dict[str, ArrayLike]], + valid_true_data: dict[str, dict[str, ArrayLike]], + valid_pred_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at the end of training to compute metrics Parameters ---------- - source_data - Source data in nested dictionary format with same keys as ``validation_data`` - validation_data - Validation data in nested dictionary format with same keys as ``predicted_data`` - predicted_data - Predicted data in nested dictionary format with same keys as ``validation_data`` + valid_source_data + Source data in nested dictionary format with same keys as ``valid_true_data`` + valid_true_data + Validation data in nested dictionary format with same keys as ``valid_pred_data`` + valid_pred_data + Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` solver with a conditional velocity field. @@ -189,24 +189,33 @@ def on_train_begin(self, *args: Any, **kwargs: Any) -> Any: def on_log_iteration( self, - source_data: dict[str, dict[str, ArrayLike]], - validation_data: dict[str, dict[str, ArrayLike]], - predicted_data: dict[str, dict[str, ArrayLike]], + valid_source_data: dict[str, dict[str, ArrayLike]], + valid_true_data: dict[str, dict[str, ArrayLike]], + valid_pred_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at each validation/log iteration to compute metrics Parameters ---------- - validation_data - Validation data in nested dictionary format with same keys as ``predicted_data`` - predicted_data - Predicted data in nested dictionary format with same keys as ``valid_data`` + valid_source_data + Source data in nested dictionary format with same keys as ``valid_true_data`` + valid_true_data + Validation data in nested dictionary format with same keys as ``valid_pred_data`` + valid_pred_data + Predicted data in nested dictionary format with same keys as ``valid_true_data`` + solver + :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + solver with a conditional velocity field. + + Returns + ------- + Computed metrics between the true validation data and predicted validation data as a dictionary """ metrics = {} for metric in self.metrics: - for k in validation_data.keys(): - out = jtu.tree_map(metric_to_func[metric], validation_data[k], predicted_data[k]) + for k in valid_true_data.keys(): + out = jtu.tree_map(metric_to_func[metric], valid_true_data[k], valid_pred_data[k]) out_flattened = jt.flatten(out)[0] for agg_fn in self.metric_aggregation: metrics[f"{k}_{metric}_{agg_fn}"] = agg_fn_to_func[agg_fn](out_flattened) @@ -215,21 +224,30 @@ def on_log_iteration( def on_train_end( self, - source_data: dict[str, dict[str, ArrayLike]], - validation_data: dict[str, dict[str, ArrayLike]], - predicted_data: dict[str, dict[str, ArrayLike]], + valid_source_data: dict[str, dict[str, ArrayLike]], + valid_true_data: dict[str, dict[str, ArrayLike]], + valid_pred_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at the end of training to compute metrics Parameters ---------- - validation_data - Validation data in nested dictionary format with same keys as ``predicted_data`` - predicted_data - Predicted data in nested dictionary format with same keys as ``validation_data`` + valid_source_data + Source data in nested dictionary format with same keys as ``valid_true_data`` + valid_true_data + Validation data in nested dictionary format with same keys as ``valid_pred_data`` + valid_pred_data + Predicted data in nested dictionary format with same keys as ``valid_true_data`` + solver + :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + solver with a conditional velocity field. + + Returns + ------- + Computed metrics between the true validation data and predicted validation data as a dictionary """ - return self.on_log_iteration(source_data, validation_data, predicted_data, solver) + return self.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver) class PCADecodedMetrics(Metrics): @@ -265,24 +283,34 @@ def __init__( def on_log_iteration( self, - _source_data: dict[str, dict[str, ArrayLike]], - validation_data: dict[str, dict[str, ArrayLike]], - predicted_data: dict[str, dict[str, ArrayLike]], - _solver: _otfm.OTFlowMatching | _genot.GENOT, + valid_source_data: dict[str, dict[str, ArrayLike]], + valid_true_data: dict[str, dict[str, ArrayLike]], + valid_pred_data: dict[str, dict[str, ArrayLike]], + solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction Parameters ---------- - validation_data - Validation data in nested dictionary format with same keys as ``predicted_data`` - predicted_data - Predicted data in nested dictionary format with same keys as ``validation_data`` + valid_source_data + Source data in nested dictionary format with same keys as ``valid_true_data`` + valid_true_data + Validation data in nested dictionary format with same keys as ``valid_pred_data`` + valid_pred_data + Predicted data in nested dictionary format with same keys as ``valid_true_data`` + solver + :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + solver with a conditional velocity field. + + Returns + ------- + Computed metrics between the reconstructed true validation data and reconstructed + predicted validation data as a dictionary. """ - validation_data_decoded = jtu.tree_map(self.reconstruct_data, validation_data) - predicted_data_decoded = jtu.tree_map(self.reconstruct_data, predicted_data) + valid_true_data_decoded = jtu.tree_map(self.reconstruct_data, valid_true_data) + predicted_data_decoded = jtu.tree_map(self.reconstruct_data, valid_pred_data) - metrics = super().on_log_iteration(validation_data_decoded, predicted_data_decoded) + metrics = super().on_log_iteration(valid_true_data_decoded, predicted_data_decoded) metrics = {f"{self.log_prefix}{k}": v for k, v in metrics.items()} return metrics @@ -324,27 +352,37 @@ def __init__( def on_log_iteration( self, - _source_data: dict[str, dict[str, ArrayLike]], - validation_data: dict[str, dict[str, ArrayLike]], - predicted_data: dict[str, dict[str, ArrayLike]], - _solver: _otfm.OTFlowMatching | _genot.GENOT, + valid_source_data: dict[str, dict[str, ArrayLike]], + valid_true_data: dict[str, dict[str, ArrayLike]], + valid_pred_data: dict[str, dict[str, ArrayLike]], + solver: _otfm.OTFlowMatching | _genot.GENOT, ) -> dict[str, float]: """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction Parameters ---------- - validation_data - Validation data in nested dictionary format with same keys as ``predicted_data`` - predicted_data - Predicted data in nested dictionary format with same keys as ``validation_data`` + valid_source_data + Source data in nested dictionary format with same keys as ``valid_true_data`` + valid_true_data + Validation data in nested dictionary format with same keys as ``valid_pred_data`` + valid_pred_data + Predicted data in nested dictionary format with same keys as ``valid_true_data`` + solver + :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + solver with a conditional velocity field. + + Returns + ------- + Computed metrics between the reconstructed true validation data and reconstructed + predicted validation data as a dictionary. """ - validation_data_in_anndata = jtu.tree_map(self._create_anndata, validation_data) - predicted_data_in_anndata = jtu.tree_map(self._create_anndata, predicted_data) + valid_true_data_in_anndata = jtu.tree_map(self._create_anndata, valid_true_data) + predicted_data_in_anndata = jtu.tree_map(self._create_anndata, valid_pred_data) - validation_data_decoded = jtu.tree_map(self.reconstruct_data, validation_data_in_anndata) + valid_true_data_decoded = jtu.tree_map(self.reconstruct_data, valid_true_data_in_anndata) predicted_data_decoded = jtu.tree_map(self.reconstruct_data, predicted_data_in_anndata) - metrics = super().on_log_iteration(validation_data_decoded, predicted_data_decoded) + metrics = super().on_log_iteration(valid_true_data_decoded, predicted_data_decoded) metrics = {f"{self.log_prefix}{k}": v for k, v in metrics.items()} return metrics @@ -464,7 +502,7 @@ def on_train_begin(self) -> Any: def on_log_iteration( self, - source_data: dict[str, dict[str, ArrayLike]], + valid_source_data: dict[str, dict[str, ArrayLike]], valid_data: dict[str, dict[str, ArrayLike]], pred_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, @@ -473,12 +511,12 @@ def on_log_iteration( Parameters ---------- - source_data - Source data in nested dictionary format with same keys as ``valid_data`` - valid_data - Validation data in nested dictionary format with same keys as ``pred_data`` - pred_data - Predicted data in nested dictionary format with same keys as ``valid_data`` + valid_source_data + Source data in nested dictionary format with same keys as ``valid_true_data`` + valid_true_data + Validation data in nested dictionary format with same keys as ``valid_pred_data`` + valid_pred_data + Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` solver with a conditional velocity field. @@ -490,7 +528,7 @@ def on_log_iteration( dict_to_log: dict[str, Any] = {} for callback in self.computation_callbacks: - results = callback.on_log_iteration(source_data, valid_data, pred_data, solver) + results = callback.on_log_iteration(valid_source_data, valid_data, pred_data, solver) dict_to_log.update(results) for callback in self.logging_callbacks: @@ -500,7 +538,7 @@ def on_log_iteration( def on_train_end( self, - source_data: dict[str, dict[str, ArrayLike]], + valid_source_data: dict[str, dict[str, ArrayLike]], valid_data: dict[str, dict[str, ArrayLike]], pred_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, @@ -509,12 +547,12 @@ def on_train_end( Parameters ---------- - source_data - Source data in nested dictionary format with same keys as ``valid_data`` - valid_data: dict - Validation data in nested dictionary format with same keys as ``pred_data`` - pred_data: dict - Predicted data in nested dictionary format with same keys as ``valid_data`` + valid_source_data + Source data in nested dictionary format with same keys as ``valid_true_data`` + valid_true_data + Validation data in nested dictionary format with same keys as ``valid_pred_data`` + valid_pred_data + Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` solver with a conditional velocity field. @@ -526,7 +564,7 @@ def on_train_end( dict_to_log: dict[str, Any] = {} for callback in self.computation_callbacks: - results = callback.on_train_end(source_data, valid_data, pred_data, solver) + results = callback.on_train_end(valid_source_data, valid_data, pred_data, solver) dict_to_log.update(results) for callback in self.logging_callbacks: