195 changes: 137 additions & 58 deletions src/cellflow/training/_callbacks.py
@@ -14,6 +14,7 @@
compute_scalar_mmd,
compute_sinkhorn_div,
)
from cellflow.solvers import _genot, _otfm

__all__ = [
"BaseCallback",
@@ -101,17 +102,24 @@ def on_log_iteration(
@abc.abstractmethod
def on_log_iteration(
self,
validation_data: dict[str, dict[str, ArrayLike]],
predicted_data: dict[str, dict[str, ArrayLike]],
valid_source_data: dict[str, dict[str, ArrayLike]],
valid_true_data: dict[str, dict[str, ArrayLike]],
valid_pred_data: dict[str, dict[str, ArrayLike]],
solver: _otfm.OTFlowMatching | _genot.GENOT,
) -> dict[str, float]:
"""Called at each validation/log iteration to compute metrics

Parameters
----------
validation_data
Validation data in nested dictionary format with same keys as ``predicted_data``
predicted_data
Predicted data in nested dictionary format with same keys as ``validation_data``
valid_source_data
Source data in nested dictionary format with same keys as ``valid_true_data``
valid_true_data
Validation data in nested dictionary format with same keys as ``valid_pred_data``
valid_pred_data
Predicted data in nested dictionary format with same keys as ``valid_true_data``
solver
:class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
solver with a conditional velocity field.

Returns
-------
@@ -122,17 +130,24 @@ def on_log_iteration(
@abc.abstractmethod
def on_train_end(
self,
validation_data: dict[str, dict[str, ArrayLike]],
predicted_data: dict[str, dict[str, ArrayLike]],
valid_source_data: dict[str, dict[str, ArrayLike]],
valid_true_data: dict[str, dict[str, ArrayLike]],
valid_pred_data: dict[str, dict[str, ArrayLike]],
solver: _otfm.OTFlowMatching | _genot.GENOT,
) -> dict[str, float]:
"""Called at the end of training to compute metrics

Parameters
----------
validation_data
Validation data in nested dictionary format with same keys as ``predicted_data``
predicted_data
Predicted data in nested dictionary format with same keys as ``validation_data``
valid_source_data
Source data in nested dictionary format with same keys as ``valid_true_data``
valid_true_data
Validation data in nested dictionary format with same keys as ``valid_pred_data``
valid_pred_data
Predicted data in nested dictionary format with same keys as ``valid_true_data``
solver
:class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
solver with a conditional velocity field.

Returns
-------
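
With this change, a user-defined callback has to accept the source data and the solver in addition to the true and predicted data. Below is a minimal sketch of such a callback against the new four-argument signature; the import path and the class name are assumptions for illustration, only the argument order is taken from this hunk.

import numpy as np

from cellflow.training import BaseCallback  # assumed public import path


class CellCountCallback(BaseCallback):
    """Toy callback: log how many cells were predicted per validation set."""

    def on_train_begin(self) -> None:
        pass

    def on_log_iteration(self, valid_source_data, valid_true_data, valid_pred_data, solver) -> dict[str, float]:
        # One entry per validation set, summed over all conditions in that set.
        return {
            f"{val_key}_n_predicted_cells": float(sum(np.asarray(arr).shape[0] for arr in conds.values()))
            for val_key, conds in valid_pred_data.items()
        }

    def on_train_end(self, valid_source_data, valid_true_data, valid_pred_data, solver) -> dict[str, float]:
        return self.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver)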
@@ -174,22 +189,33 @@ def on_train_begin(self, *args: Any, **kwargs: Any) -> Any:

def on_log_iteration(
self,
validation_data: dict[str, dict[str, ArrayLike]],
predicted_data: dict[str, dict[str, ArrayLike]],
valid_source_data: dict[str, dict[str, ArrayLike]],
valid_true_data: dict[str, dict[str, ArrayLike]],
valid_pred_data: dict[str, dict[str, ArrayLike]],
solver: _otfm.OTFlowMatching | _genot.GENOT,
) -> dict[str, float]:
"""Called at each validation/log iteration to compute metrics

Parameters
----------
validation_data
Validation data in nested dictionary format with same keys as ``predicted_data``
predicted_data
Predicted data in nested dictionary format with same keys as ``valid_data``
valid_source_data
Source data in nested dictionary format with same keys as ``valid_true_data``
valid_true_data
Validation data in nested dictionary format with same keys as ``valid_pred_data``
valid_pred_data
Predicted data in nested dictionary format with same keys as ``valid_true_data``
solver
:class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
solver with a conditional velocity field.

Returns
-------
Computed metrics between the true validation data and predicted validation data as a dictionary
"""
metrics = {}
for metric in self.metrics:
for k in validation_data.keys():
out = jtu.tree_map(metric_to_func[metric], validation_data[k], predicted_data[k])
for k in valid_true_data.keys():
out = jtu.tree_map(metric_to_func[metric], valid_true_data[k], valid_pred_data[k])
out_flattened = jt.flatten(out)[0]
for agg_fn in self.metric_aggregation:
metrics[f"{k}_{metric}_{agg_fn}"] = agg_fn_to_func[agg_fn](out_flattened)
@@ -198,19 +224,30 @@ def on_log_iteration(

def on_train_end(
self,
validation_data: dict[str, dict[str, ArrayLike]],
predicted_data: dict[str, dict[str, ArrayLike]],
valid_source_data: dict[str, dict[str, ArrayLike]],
valid_true_data: dict[str, dict[str, ArrayLike]],
valid_pred_data: dict[str, dict[str, ArrayLike]],
solver: _otfm.OTFlowMatching | _genot.GENOT,
) -> dict[str, float]:
"""Called at the end of training to compute metrics

Parameters
----------
validation_data
Validation data in nested dictionary format with same keys as ``predicted_data``
predicted_data
Predicted data in nested dictionary format with same keys as ``validation_data``
valid_source_data
Source data in nested dictionary format with same keys as ``valid_true_data``
valid_true_data
Validation data in nested dictionary format with same keys as ``valid_pred_data``
valid_pred_data
Predicted data in nested dictionary format with same keys as ``valid_true_data``
solver
:class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
solver with a conditional velocity field.

Returns
-------
Computed metrics between the true validation data and predicted validation data as a dictionary
"""
return self.on_log_iteration(validation_data, predicted_data)
return self.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver)


class PCADecodedMetrics(Metrics):
@@ -246,22 +283,34 @@ def __init__(

def on_log_iteration(
self,
validation_data: dict[str, dict[str, ArrayLike]],
predicted_data: dict[str, dict[str, ArrayLike]],
valid_source_data: dict[str, dict[str, ArrayLike]],
valid_true_data: dict[str, dict[str, ArrayLike]],
valid_pred_data: dict[str, dict[str, ArrayLike]],
solver: _otfm.OTFlowMatching | _genot.GENOT,
) -> dict[str, float]:
"""Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction

Parameters
----------
validation_data
Validation data in nested dictionary format with same keys as ``predicted_data``
predicted_data
Predicted data in nested dictionary format with same keys as ``validation_data``
valid_source_data
Source data in nested dictionary format with same keys as ``valid_true_data``
valid_true_data
Validation data in nested dictionary format with same keys as ``valid_pred_data``
valid_pred_data
Predicted data in nested dictionary format with same keys as ``valid_true_data``
solver
:class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
solver with a conditional velocity field.

Returns
-------
Computed metrics between the reconstructed true validation data and reconstructed
predicted validation data as a dictionary.
"""
validation_data_decoded = jtu.tree_map(self.reconstruct_data, validation_data)
predicted_data_decoded = jtu.tree_map(self.reconstruct_data, predicted_data)
valid_true_data_decoded = jtu.tree_map(self.reconstruct_data, valid_true_data)
predicted_data_decoded = jtu.tree_map(self.reconstruct_data, valid_pred_data)

metrics = super().on_log_iteration(validation_data_decoded, predicted_data_decoded)
metrics = super().on_log_iteration(valid_true_data_decoded, predicted_data_decoded)
metrics = {f"{self.log_prefix}{k}": v for k, v in metrics.items()}
return metrics
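
As a rough illustration of the decode step, the sketch below assumes reconstruct_data maps PCA-space arrays back to gene space via the inverse transform of a fitted scikit-learn PCA; the actual projection is configured in the callback's constructor, which is not shown in this hunk.

import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 50))                # toy cells x genes matrix
pca = PCA(n_components=10).fit(X)

def reconstruct_data(x_latent: np.ndarray) -> np.ndarray:
    # Map a (cells x n_components) array back to (cells x genes).
    return pca.inverse_transform(x_latent)

latent = pca.transform(X[:16])
decoded = reconstruct_data(latent)            # shape (16, 50)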

@@ -303,25 +352,37 @@ def __init__(

def on_log_iteration(
self,
validation_data: dict[str, dict[str, ArrayLike]],
predicted_data: dict[str, dict[str, ArrayLike]],
valid_source_data: dict[str, dict[str, ArrayLike]],
valid_true_data: dict[str, dict[str, ArrayLike]],
valid_pred_data: dict[str, dict[str, ArrayLike]],
solver: _otfm.OTFlowMatching | _genot.GENOT,
) -> dict[str, float]:
"""Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction

Parameters
----------
validation_data
Validation data in nested dictionary format with same keys as ``predicted_data``
predicted_data
Predicted data in nested dictionary format with same keys as ``validation_data``
valid_source_data
Source data in nested dictionary format with same keys as ``valid_true_data``
valid_true_data
Validation data in nested dictionary format with same keys as ``valid_pred_data``
valid_pred_data
Predicted data in nested dictionary format with same keys as ``valid_true_data``
solver
:class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
solver with a conditional velocity field.

Returns
-------
Computed metrics between the reconstructed true validation data and reconstructed
predicted validation data as a dictionary.
"""
validation_data_in_anndata = jtu.tree_map(self._create_anndata, validation_data)
predicted_data_in_anndata = jtu.tree_map(self._create_anndata, predicted_data)
valid_true_data_in_anndata = jtu.tree_map(self._create_anndata, valid_true_data)
predicted_data_in_anndata = jtu.tree_map(self._create_anndata, valid_pred_data)

validation_data_decoded = jtu.tree_map(self.reconstruct_data, validation_data_in_anndata)
valid_true_data_decoded = jtu.tree_map(self.reconstruct_data, valid_true_data_in_anndata)
predicted_data_decoded = jtu.tree_map(self.reconstruct_data, predicted_data_in_anndata)

metrics = super().on_log_iteration(validation_data_decoded, predicted_data_decoded)
metrics = super().on_log_iteration(valid_true_data_decoded, predicted_data_decoded)
metrics = {f"{self.log_prefix}{k}": v for k, v in metrics.items()}
return metrics
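
This variant first wraps each array in an AnnData object before decoding. The snippet below only illustrates that wrapping step on a toy array; the real _create_anndata and reconstruct_data are set up in the callback's constructor and may attach additional metadata.

import anndata as ad
import numpy as np

def create_anndata(x: np.ndarray) -> ad.AnnData:
    # Minimal stand-in for _create_anndata: wrap a (cells x features) array
    # so an AnnData-based decoder can consume it.
    return ad.AnnData(X=np.asarray(x, dtype=np.float32))

adata = create_anndata(np.random.default_rng(0).normal(size=(8, 10)))
print(adata.shape)  # (8, 10)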

@@ -441,17 +502,24 @@ def on_train_begin(self) -> Any:

def on_log_iteration(
self,
valid_source_data: dict[str, dict[str, ArrayLike]],
valid_data: dict[str, dict[str, ArrayLike]],
pred_data: dict[str, dict[str, ArrayLike]],
solver: _otfm.OTFlowMatching | _genot.GENOT,
) -> dict[str, Any]:
"""Called at each validation/log iteration to run callbacks. First computes metrics with computation callbacks and then logs data with logging callbacks.

Parameters
----------
valid_data
Validation data in nested dictionary format with same keys as ``pred_data``
pred_data
Predicted data in nested dictionary format with same keys as ``valid_data``
valid_source_data
Source data in nested dictionary format with same keys as ``valid_true_data``
valid_true_data
Validation data in nested dictionary format with same keys as ``valid_pred_data``
valid_pred_data
Predicted data in nested dictionary format with same keys as ``valid_true_data``
solver
:class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
solver with a conditional velocity field.

Returns
-------
@@ -460,23 +528,34 @@ def on_log_iteration(
dict_to_log: dict[str, Any] = {}

for callback in self.computation_callbacks:
results = callback.on_log_iteration(valid_data, pred_data)
results = callback.on_log_iteration(valid_source_data, valid_data, pred_data, solver)
dict_to_log.update(results)

for callback in self.logging_callbacks:
callback.on_log_iteration(dict_to_log) # type: ignore[call-arg]

return dict_to_log
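
Note that the logging callbacks in the loop above receive only the merged metrics dictionary, not the raw validation data. A minimal logging-style callback therefore looks like the following sketch; the class name and print-based output are illustrative and not part of cellflow.

class PrintLogger:
    """Toy logging callback: print every logged value."""

    def on_train_begin(self) -> None:
        pass

    def on_log_iteration(self, dict_to_log: dict) -> None:
        for name, value in dict_to_log.items():
            print(f"{name}: {value}")

    def on_train_end(self, dict_to_log: dict) -> None:
        self.on_log_iteration(dict_to_log)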

def on_train_end(self, valid_data, pred_data) -> dict[str, Any]:
def on_train_end(
self,
valid_source_data: dict[str, dict[str, ArrayLike]],
valid_data: dict[str, dict[str, ArrayLike]],
pred_data: dict[str, dict[str, ArrayLike]],
solver: _otfm.OTFlowMatching | _genot.GENOT,
) -> dict[str, Any]:
"""Called at the end of training to run callbacks. First computes metrics with computation callbacks and then logs data with logging callbacks.

Parameters
----------
valid_data: dict
Validation data in nested dictionary format with same keys as ``pred_data``
pred_data: dict
Predicted data in nested dictionary format with same keys as ``valid_data``
valid_source_data
Source data in nested dictionary format with same keys as ``valid_true_data``
valid_true_data
Validation data in nested dictionary format with same keys as ``valid_pred_data``
valid_pred_data
Predicted data in nested dictionary format with same keys as ``valid_true_data``
solver
:class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
solver with a conditional velocity field.

Returns
-------
@@ -485,10 +564,10 @@ def on_train_end(self, valid_data, pred_data) -> dict[str, Any]:
dict_to_log: dict[str, Any] = {}

for callback in self.computation_callbacks:
results = callback.on_log_iteration(valid_data, pred_data)
results = callback.on_train_end(valid_source_data, valid_data, pred_data, solver)
dict_to_log.update(results)

for callback in self.logging_callbacks:
callback.on_log_iteration(dict_to_log) # type: ignore[call-arg]
callback.on_train_end(dict_to_log) # type: ignore[call-arg]

return dict_to_log
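
For orientation, a hedged usage sketch of the runner follows; the CallbackRunner and Metrics constructor arguments are assumptions inferred from the attributes used above, not a confirmed API, and the metric name is a placeholder.

from cellflow.training import CallbackRunner, Metrics  # assumed public import path

# Assumed constructors: a callbacks sequence that the runner splits into computation
# and logging callbacks, and a Metrics callback configured by metric name.
runner = CallbackRunner(callbacks=[Metrics(metrics=["mmd"], metric_aggregation=["mean"])])

# The trainer then calls, at every validation step and once at the end:
#   runner.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver)
#   runner.on_train_end(valid_source_data, valid_true_data, valid_pred_data, solver)
# Both return a flat dict mapping metric names to scalar values.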
19 changes: 13 additions & 6 deletions src/cellflow/training/_trainer.py
@@ -19,7 +19,8 @@ class CellFlowTrainer:
dataloader
Data sampler.
solver
OTFM/GENOT solver with a conditional velocity field.
:class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
solver with a conditional velocity field.
seed
Random seed for subsampling validation data.

@@ -51,17 +52,19 @@ def _validation_step(
"""Compute predictions for validation data."""
# TODO: Sample fixed number of conditions to validate on

valid_source_data: dict[str, dict[str, ArrayLike]] = {}
valid_pred_data: dict[str, dict[str, ArrayLike]] = {}
valid_true_data: dict[str, dict[str, ArrayLike]] = {}
for val_key, vdl in val_data.items():
batch = vdl.sample(mode=mode)
src = batch["source"]
condition = batch.get("condition", None)
true_tgt = batch["target"]
valid_source_data[val_key] = src
valid_pred_data[val_key] = jax.tree.map(self.solver.predict, src, condition)
valid_true_data[val_key] = true_tgt

return valid_true_data, valid_pred_data
return valid_source_data, valid_true_data, valid_pred_data
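
The returned triple shares one layout, sketched below with toy arrays: validation-set name on the outer level, condition on the inner level, and cells-by-features arrays at the leaves.

import numpy as np

# Toy illustration of the three aligned dictionaries returned by _validation_step.
valid_source_data = {"val_set": {"drug_a": np.zeros((32, 10))}}  # source cells passed to solver.predict
valid_true_data = {"val_set": {"drug_a": np.ones((32, 10))}}     # observed target cells
valid_pred_data = {"val_set": {"drug_a": np.ones((32, 10))}}     # predictions, same keys as the other two
assert valid_source_data.keys() == valid_true_data.keys() == valid_pred_data.keys()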

def _update_logs(self, logs: dict[str, Any]) -> None:
"""Update training logs."""
@@ -119,10 +122,12 @@ def train(

if ((it - 1) % valid_freq == 0) and (it > 1):
# Get predictions from validation data
valid_true_data, valid_pred_data = self._validation_step(valid_loaders, mode="on_log_iteration")
valid_source_data, valid_true_data, valid_pred_data = self._validation_step(
valid_loaders, mode="on_log_iteration"
)

# Run callbacks
metrics = crun.on_log_iteration(valid_true_data, valid_pred_data) # type: ignore[arg-type]
metrics = crun.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, self.solver) # type: ignore[arg-type]
self._update_logs(metrics)

# Update progress bar
@@ -132,8 +137,10 @@ def train(
pbar.set_postfix(postfix_dict)

if num_iterations > 0:
valid_true_data, valid_pred_data = self._validation_step(valid_loaders, mode="on_train_end")
metrics = crun.on_train_end(valid_true_data, valid_pred_data)
valid_source_data, valid_true_data, valid_pred_data = self._validation_step(
valid_loaders, mode="on_train_end"
)
metrics = crun.on_train_end(valid_source_data, valid_true_data, valid_pred_data, self.solver)
self._update_logs(metrics)

self.solver.is_trained = True
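
As a small sanity check on the validation schedule in train: with the condition ((it - 1) % valid_freq == 0) and (it > 1), callbacks fire every valid_freq iterations once the first valid_freq iterations have completed.

# Reproduces the trigger condition from the training loop with valid_freq = 3.
valid_freq = 3
trigger_iters = [it for it in range(1, 12) if ((it - 1) % valid_freq == 0) and (it > 1)]
print(trigger_iters)  # [4, 7, 10]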