Skip to content

Commit

Permalink
Rename TransportMapKernel -> RealTimeKernel (#1080)
Browse files Browse the repository at this point in the history
* Rename `TransportMapKernel` -> `RealTimeKernel`

* Enable testing on 3.11

* Update submodule
  • Loading branch information
michalk8 committed Jun 28, 2023
1 parent 9ef5a1f commit dad776a
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: ['3.8', '3.10'] # , '3.11']
python: ['3.8', '3.10', '3.11']
slepc: [noslepc]
include:
- os: macos-latest
Expand Down
4 changes: 2 additions & 2 deletions docs/about/version2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ Important changes in version 2
* Removal of the `external API`: anyone wishing to contribute to CellRank can do this now directly via :mod:`~cellrank.kernels` and
:mod:`~cellrank.estimators`. We welcome any contribution to CellRank, see our :doc:`contribution guide <../contributing>`, and feel free to
get in touch via an `issue <https://github.com/theislab/cellrank/issues/new/choose>`_ or `email <mailto:info@cellrank.org>`_.
* Replacement of the old `WOTKernel` with a new :class:`~cellrank.kernels.TransportMapKernel`: this is CellRank's interface
* Replacement of the old `WOTKernel` with a new :class:`~cellrank.kernels.RealTimeKernel`: this is CellRank's interface
with `moscot <https://moscot-tools.org>`_, enabling us to analyze large-scale time-course studies with additional spatial or lineage readout :cite:`klein:23,lange:23`. In addition,
the :class:`~cellrank.kernels.TransportMapKernel` interfaces with `Waddington-OT <https://broadinstitute.github.io/wot/>`_ :cite:`schiebinger:19`.
the :class:`~cellrank.kernels.RealTimeKernel` interfaces with `Waddington-OT <https://broadinstitute.github.io/wot/>`_ :cite:`schiebinger:19`.

There are many more changes and improvements in CellRank 2. For example, the computation of fate probabilities is 30x faster compared
to version 1, we fixed many bugs, and improved and extended our documentation and :doc:`tutorials <../notebooks/tutorials/index>`.
2 changes: 1 addition & 1 deletion docs/api/kernels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ takes in multi-view single-cell data and outputs a cell-cell transition matrix.
kernels.ConnectivityKernel
kernels.PseudotimeKernel
kernels.CytoTRACEKernel
kernels.TransportMapKernel
kernels.RealTimeKernel
kernels.PrecomputedKernel
2 changes: 1 addition & 1 deletion docs/notebooks
4 changes: 2 additions & 2 deletions src/cellrank/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from cellrank.kernels._experimental_time_kernel import ExperimentalTimeKernel
from cellrank.kernels._precomputed_kernel import PrecomputedKernel
from cellrank.kernels._pseudotime_kernel import PseudotimeKernel
from cellrank.kernels._transport_map_kernel import TransportMapKernel
from cellrank.kernels._real_time_kernel import RealTimeKernel
from cellrank.kernels._velocity_kernel import VelocityKernel

__all__ = [
Expand All @@ -16,6 +16,6 @@
"ExperimentalTimeKernel",
"PrecomputedKernel",
"PseudotimeKernel",
"TransportMapKernel",
"RealTimeKernel",
"VelocityKernel",
]
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from cellrank.kernels._base_kernel import UnidirectionalKernel
from cellrank.settings import settings

__all__ = ["TransportMapKernel"]
__all__ = ["RealTimeKernel"]

if TYPE_CHECKING:
from moscot.problems.spatiotemporal import SpatioTemporalProblem
Expand All @@ -53,7 +53,7 @@ class SelfTransitions(ModeEnum):

# TODO(michalk8): subclass the `ExperimentalTimeKernel`
@d.dedent
class TransportMapKernel(UnidirectionalKernel):
class RealTimeKernel(UnidirectionalKernel):
"""Kernel which computes transition matrix using optimal transport couplings.
This class should be constructed using either:
Expand Down Expand Up @@ -146,7 +146,7 @@ def compute_transition_matrix(
conn_weight: Optional[float] = None,
conn_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Any,
) -> "TransportMapKernel":
) -> "RealTimeKernel":
"""Compute transition matrix from optimal transport couplings.
Parameters
Expand Down Expand Up @@ -224,7 +224,7 @@ def from_moscot(
sparsify_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
copy: bool = False,
**kwargs: Any,
) -> "TransportMapKernel":
) -> "RealTimeKernel":
"""Construct the kernel from :mod:`moscot` :cite:`klein:23`.
Parameters
Expand All @@ -240,7 +240,7 @@ def from_moscot(
copy
Whether to copy the underlying arrays. Note that :class:`jax arrays <jax.Array>` are always copied.
kwargs
Keyword arguments for :class:`~cellrank.kernels.TransportMapKernel`.
Keyword arguments for :class:`~cellrank.kernels.RealTimeKernel`.
Returns
-------
Expand All @@ -259,8 +259,8 @@ def from_moscot(
problem = mt.problems.TemporalProblem(adata)
problem = problem.prepare(time_key="day").solve()
tmk = cr.kernels.TransportMapKernel.from_moscot(problem)
tmk = tmk.compute_transition_matrix()
rtk = cr.kernels.RealTimeKernel.from_moscot(problem)
rtk = rtk.compute_transition_matrix()
"""
from moscot.utils.subset_policy import SequentialPolicy, TriangularPolicy

Expand Down Expand Up @@ -305,7 +305,7 @@ def from_wot(
path: Union[str, pathlib.Path],
time_key: str,
**kwargs: Any,
) -> "TransportMapKernel":
) -> "RealTimeKernel":
"""Construct the kernel from Waddington-OT :cite:`schiebinger:19`.
Parameters
Expand All @@ -317,7 +317,7 @@ def from_wot(
time_key
Key in :attr:`~anndata.AnnData.obs` containing the experimental time.
kwargs
Keyword arguments for :class:`~cellrank.kernels.TransportMapKernel`.
Keyword arguments for :class:`~cellrank.kernels.RealTimeKernel`.
Returns
-------
Expand All @@ -336,8 +336,8 @@ def from_wot(
ot_model = wot.ot.OTModel(adata, day_field="day")
ot_model.compute_all_transport_maps(tmap_out="tmaps/")
tmk = cr.kernels.TransportMapKernel.from_wot(adata, path="tmaps/", time_key="day")
tmk = tmk.compute_transition_matrix()
rtk = cr.kernels.RealTimeKernel.from_wot(adata, path="tmaps/", time_key="day")
rtk = rtk.compute_transition_matrix()
"""
path = pathlib.Path(path)
dtype = type(adata.obs[time_key].iloc[0])
Expand Down
16 changes: 8 additions & 8 deletions tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
CytoTRACEKernel,
PrecomputedKernel,
PseudotimeKernel,
TransportMapKernel,
RealTimeKernel,
VelocityKernel,
)
from cellrank.kernels._base_kernel import (
Expand Down Expand Up @@ -1124,14 +1124,14 @@ def test_inversion_bwd(self, adata: AnnData):
np.testing.assert_array_equal(np.max(pt) - pt, k.pseudotime)


class TestTransportMapKernel:
class TestRealTimeKernel:
@pytest.mark.parametrize("policy", ["sequential", "triu"])
@pytest.mark.parametrize("n", [5, 7])
def test_default_initialization(self, adata: AnnData, n: int, policy: str):
adata.obs["exp_time"] = pd.cut(adata.obs["dpt_pseudotime"], n)
cats = adata.obs["exp_time"].cat.categories

tmk = TransportMapKernel(adata, time_key="exp_time", policy=policy)
tmk = RealTimeKernel(adata, time_key="exp_time", policy=policy)

if policy == "sequential":
assert tmk.couplings == {key: None for key in zip(cats[:-1], cats[1:])}
Expand All @@ -1155,7 +1155,7 @@ def test_explicit_initialization(self, adata: AnnData, correct_shape: bool):
val.var_names = adata.obs_names[col == tgt]
couplings[src, tgt] = val

tmk = TransportMapKernel(adata, couplings=couplings, time_key="exp_time")
tmk = RealTimeKernel(adata, couplings=couplings, time_key="exp_time")
assert tmk.couplings == couplings

if correct_shape:
Expand All @@ -1180,7 +1180,7 @@ def test_explicit_shuffle(self, adata_large: AnnData):
rng.shuffle(ixs)
adata_large = adata_large[ixs].copy()

tmk = TransportMapKernel(adata_large, time_key="time", couplings=expected)
tmk = RealTimeKernel(adata_large, time_key="time", couplings=expected)
tmk = tmk.compute_transition_matrix()
tmat = tmk.transition_matrix

Expand Down Expand Up @@ -1220,7 +1220,7 @@ def test_from_moscot(

problem = problem.prepare(policy=policy, time_key="exp_time", xy_callback_kwargs={"n_comps": 5}).solve()

tmk = TransportMapKernel.from_moscot(
tmk = RealTimeKernel.from_moscot(
problem,
sparse_mode=sparse_mode,
)
Expand All @@ -1245,7 +1245,7 @@ def test_from_wot(self, adata: AnnData, tmpdir):
ot_model = wot.ot.OTModel(adata, day_field="exp_time", growth_iters=gr_iters)
ot_model.compute_all_transport_maps(tmap_out=f"{tmpdir}/")

tmk = TransportMapKernel.from_wot(adata, path=tmpdir, time_key="exp_time")
tmk = RealTimeKernel.from_wot(adata, path=tmpdir, time_key="exp_time")
obs = pd.read_csv(tmpdir / "tmaps_g.txt", index_col=0, sep="\t")
tmk = tmk.compute_transition_matrix()

Expand All @@ -1271,7 +1271,7 @@ def test_from_moscot_set_solution(self, adata_large: AnnData):
expected[src, tgt] = tmp = tmp / np.sum(tmp, axis=-1, keepdims=True)
subprob.set_solution(tmp)

tmk = TransportMapKernel.from_moscot(problem)
tmk = RealTimeKernel.from_moscot(problem)
for (src, tgt), actual in tmk.couplings.items():
np.testing.assert_allclose(actual.X, expected[src, tgt], rtol=1e-6, atol=1e-6)

Expand Down

0 comments on commit dad776a

Please sign in to comment.