diff --git a/docs/notebooks b/docs/notebooks index bd10c2c1a..6777840bc 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit bd10c2c1a7f7ddf5674403997bc2704533b85367 +Subproject commit 6777840bcc146924d208c0b9b02677da63620805 diff --git a/pyproject.toml b/pyproject.toml index 540f751da..4d094a94b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ maintainers = [ dependencies = [ - "anndata>=0.8", + "anndata>=0.9", "docrep>=0.3.0", "joblib>=0.13.1", "matplotlib>=3.5.0,<3.7.2", @@ -74,6 +74,7 @@ test = [ "pytest-mock>=3.5.0", "pytest-cov>=4", "coverage[toml]>=7", + "zarr", "igraph", "leidenalg", "Pillow", diff --git a/src/cellrank/_utils/_lineage.py b/src/cellrank/_utils/_lineage.py index 104467339..1531a9b4c 100644 --- a/src/cellrank/_utils/_lineage.py +++ b/src/cellrank/_utils/_lineage.py @@ -1195,17 +1195,15 @@ def _mutual_info(reference, query): return weights -_SPEC = IOSpec("array", "0.2.0") - - -@_REGISTRY.register_write(H5Group, Lineage, _SPEC) -@_REGISTRY.register_write(H5Group, LineageView, _SPEC) -@_REGISTRY.register_write(ZarrGroup, Lineage, _SPEC) -@_REGISTRY.register_write(ZarrGroup, LineageView, _SPEC) +@_REGISTRY.register_write(H5Group, Lineage, IOSpec("array", "0.2.0")) +@_REGISTRY.register_write(H5Group, LineageView, IOSpec("array", "0.2.0")) +@_REGISTRY.register_write(ZarrGroup, Lineage, IOSpec("array", "0.2.0")) +@_REGISTRY.register_write(ZarrGroup, LineageView, IOSpec("array", "0.2.0")) def _write_lineage( f: Any, k: str, elem: Union[Lineage, LineageView], + _writer: Any, dataset_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> None: - write_basic(f, k, elem=elem.X, dataset_kwargs=dataset_kwargs) + write_basic(f, k, elem=elem.X, _writer=_writer, dataset_kwargs=dataset_kwargs) diff --git a/tests/test_lineage.py b/tests/test_lineage.py index 6fca7343e..039bcf513 100644 --- a/tests/test_lineage.py +++ b/tests/test_lineage.py @@ -11,6 +11,8 @@ from matplotlib import colors +from anndata import AnnData, read_h5ad, read_zarr + from cellrank._utils import Lineage from cellrank._utils._colors import _compute_mean_color, _create_categorical_colors from cellrank._utils._lineage import _HT_CELLS, LineageView, PrimingDegree @@ -973,7 +975,7 @@ def test_double_view_owner(self, lineage: Lineage): assert x.owner is lineage -class TestPickling: +class TestIO: def test_pickle_normal(self, lineage: Lineage): handle = io.BytesIO() @@ -1015,6 +1017,24 @@ def test_pickle_transposed(self, lineage: Lineage): assert res._n_lineages == lineage._n_lineages assert res._is_transposed == lineage._is_transposed + def test_anndata_write(self, lineage: Lineage, tmp_path): + rng = np.random.default_rng(0) + adata = AnnData(rng.normal(size=(lineage.shape[0], 13))) + adata.obsm["lin"] = lineage + + assert isinstance(adata.obsm["lin"], Lineage) + + adata.write_h5ad(tmp_path / "tmp.h5ad") + adata.write_zarr(tmp_path / "tmp.zarr") + + adata_h5ad = read_h5ad(tmp_path / "tmp.h5ad") + adata_zarr = read_zarr(tmp_path / "tmp.zarr") + + assert isinstance(adata_h5ad.obsm["lin"], np.ndarray) + assert isinstance(adata_zarr.obsm["lin"], np.ndarray) + np.testing.assert_array_equal(adata_h5ad.obsm["lin"], adata.obsm["lin"].X) + np.testing.assert_array_equal(adata_zarr.obsm["lin"], adata.obsm["lin"].X) + class TestPriming: def test_invalid_method(self, lineage: Lineage):