Skip to content

Commit

Permalink
Fix lineage class serialization (#860)
Browse files Browse the repository at this point in the history
* Add IOspec for Lineage class

* Revert #841

* Fix IOSpec version to 0.2.0, remove reader
  • Loading branch information
michalk8 committed Apr 21, 2022
1 parent ac6efec commit 3d0ae2f
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 10 deletions.
30 changes: 29 additions & 1 deletion cellrank/tl/_lineage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from typing import List, Tuple, Union, Mapping, TypeVar, Callable, Iterable, Optional
from typing import (
Any,
List,
Tuple,
Union,
Mapping,
TypeVar,
Callable,
Iterable,
Optional,
)
from typing_extensions import Literal

from copy import copy
Expand All @@ -20,6 +30,8 @@
_compute_mean_color,
_create_categorical_colors,
)
from anndata._io.specs.methods import H5Group, ZarrGroup, write_basic
from anndata._io.specs.registry import _REGISTRY, IOSpec

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1303,3 +1315,19 @@ def _mutual_info(reference, query):
weights[i, :] = mutual_info_regression(reference, target)

return weights


_SPEC = IOSpec("array", "0.2.0")


@_REGISTRY.register_write(H5Group, Lineage, _SPEC)
@_REGISTRY.register_write(H5Group, LineageView, _SPEC)
@_REGISTRY.register_write(ZarrGroup, Lineage, _SPEC)
@_REGISTRY.register_write(ZarrGroup, LineageView, _SPEC)
def _write_lineage(
f: Any,
k: str,
elem: Union[Lineage, LineageView],
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> None:
write_basic(f, k, elem=elem.X, dataset_kwargs=dataset_kwargs)
15 changes: 7 additions & 8 deletions cellrank/tl/estimators/terminal_states/_gpcca.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Dict, Tuple, Union, Mapping, Optional, Sequence
from typing_extensions import Literal

from copy import deepcopy
from enum import auto
from types import MappingProxyType
from pathlib import Path
Expand Down Expand Up @@ -1042,7 +1041,7 @@ def _write_macrostates(
self._set("_coarse_tmat", value=tmat, shadow_only=True)
self._set("_coarse_init_dist", value=init_dist, shadow_only=True)
self._set("_coarse_stat_dist", value=stat_dist, shadow_only=True)
self._set(obj=self.adata.uns, key=Key.uns.coarse(self.backward), value={"coarse_tmat": tmat, **dists})
self._set(obj=self.adata.uns, key=Key.uns.coarse(self.backward), value=AnnData(tmat, obs=dists))
else:
for attr in ["_schur_vectors", "_schur_matrix", "_coarse_tmat", "_coarse_init_dist", "_coarse_stat_dist"]:
self._set(attr, value=None, shadow_only=True)
Expand Down Expand Up @@ -1104,14 +1103,14 @@ def _read_from_adata(self, adata: AnnData, **kwargs: Any) -> bool:
self._macrostates_colors = self.macrostates_memberships.colors.copy()
self.params[key] = self._read_params(key)

tmat = deepcopy(self.adata.uns[Key.uns.coarse(self.backward)])
if not isinstance(tmat, dict):
tmat = self.adata.uns[Key.uns.coarse(self.backward)].copy()
if not isinstance(tmat, AnnData):
raise TypeError(f"Expected coarse-grained transition matrix to be stored "
f"as `dict`, found `{type(tmat).__name__}`.")
f"as `AnnData`, found `{type(tmat).__name__}`.")

self._coarse_tmat = tmat["coarse_tmat"]
self._coarse_init_dist = tmat["coarse_init_dist"]
self._coarse_stat_dist = tmat.get("coarse_stat_dist", None)
self._coarse_tmat = pd.DataFrame(tmat.X, index=tmat.obs_names, columns=tmat.obs_names)
self._coarse_init_dist = tmat.obs["coarse_init_dist"]
self._coarse_stat_dist = tmat.obs.get("coarse_stat_dist", None)

self._set(obj=self._shadow_adata.uns, key=Key.uns.coarse(self.backward), value=tmat)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
anndata<0.8 # TODO(mihalk8): remove me after #777
anndata>=0.8,<0.9
docrep>=0.3.0
joblib>=0.13.1
matplotlib>=3.3.0
Expand Down

0 comments on commit 3d0ae2f

Please sign in to comment.