diff --git a/cellrank/tl/_lineage.py b/cellrank/tl/_lineage.py index 655ab6d7d..da00f5afd 100644 --- a/cellrank/tl/_lineage.py +++ b/cellrank/tl/_lineage.py @@ -1,4 +1,14 @@ -from typing import List, Tuple, Union, Mapping, TypeVar, Callable, Iterable, Optional +from typing import ( + Any, + List, + Tuple, + Union, + Mapping, + TypeVar, + Callable, + Iterable, + Optional, +) from typing_extensions import Literal from copy import copy @@ -20,6 +30,8 @@ _compute_mean_color, _create_categorical_colors, ) +from anndata._io.specs.methods import H5Group, ZarrGroup, write_basic +from anndata._io.specs.registry import _REGISTRY, IOSpec import numpy as np import pandas as pd @@ -1303,3 +1315,19 @@ def _mutual_info(reference, query): weights[i, :] = mutual_info_regression(reference, target) return weights + + +_SPEC = IOSpec("array", "0.2.0") + + +@_REGISTRY.register_write(H5Group, Lineage, _SPEC) +@_REGISTRY.register_write(H5Group, LineageView, _SPEC) +@_REGISTRY.register_write(ZarrGroup, Lineage, _SPEC) +@_REGISTRY.register_write(ZarrGroup, LineageView, _SPEC) +def _write_lineage( + f: Any, + k: str, + elem: Union[Lineage, LineageView], + dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), +) -> None: + write_basic(f, k, elem=elem.X, dataset_kwargs=dataset_kwargs) diff --git a/cellrank/tl/estimators/terminal_states/_gpcca.py b/cellrank/tl/estimators/terminal_states/_gpcca.py index f92e1f970..ee8b37ce9 100644 --- a/cellrank/tl/estimators/terminal_states/_gpcca.py +++ b/cellrank/tl/estimators/terminal_states/_gpcca.py @@ -1,7 +1,6 @@ from typing import Any, Dict, Tuple, Union, Mapping, Optional, Sequence from typing_extensions import Literal -from copy import deepcopy from enum import auto from types import MappingProxyType from pathlib import Path @@ -1042,7 +1041,7 @@ def _write_macrostates( self._set("_coarse_tmat", value=tmat, shadow_only=True) self._set("_coarse_init_dist", value=init_dist, shadow_only=True) self._set("_coarse_stat_dist", value=stat_dist, shadow_only=True) - self._set(obj=self.adata.uns, key=Key.uns.coarse(self.backward), value={"coarse_tmat": tmat, **dists}) + self._set(obj=self.adata.uns, key=Key.uns.coarse(self.backward), value=AnnData(tmat, obs=dists)) else: for attr in ["_schur_vectors", "_schur_matrix", "_coarse_tmat", "_coarse_init_dist", "_coarse_stat_dist"]: self._set(attr, value=None, shadow_only=True) @@ -1104,14 +1103,14 @@ def _read_from_adata(self, adata: AnnData, **kwargs: Any) -> bool: self._macrostates_colors = self.macrostates_memberships.colors.copy() self.params[key] = self._read_params(key) - tmat = deepcopy(self.adata.uns[Key.uns.coarse(self.backward)]) - if not isinstance(tmat, dict): + tmat = self.adata.uns[Key.uns.coarse(self.backward)].copy() + if not isinstance(tmat, AnnData): raise TypeError(f"Expected coarse-grained transition matrix to be stored " - f"as `dict`, found `{type(tmat).__name__}`.") + f"as `AnnData`, found `{type(tmat).__name__}`.") - self._coarse_tmat = tmat["coarse_tmat"] - self._coarse_init_dist = tmat["coarse_init_dist"] - self._coarse_stat_dist = tmat.get("coarse_stat_dist", None) + self._coarse_tmat = pd.DataFrame(tmat.X, index=tmat.obs_names, columns=tmat.obs_names) + self._coarse_init_dist = tmat.obs["coarse_init_dist"] + self._coarse_stat_dist = tmat.obs.get("coarse_stat_dist", None) self._set(obj=self._shadow_adata.uns, key=Key.uns.coarse(self.backward), value=tmat) diff --git a/requirements.txt b/requirements.txt index e32024ae3..432616ff7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -anndata<0.8 # TODO(mihalk8): remove me after #777 +anndata>=0.8,<0.9 docrep>=0.3.0 joblib>=0.13.1 matplotlib>=3.3.0