Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lineage class serialization #860

Merged
merged 4 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion cellrank/tl/_lineage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from typing import List, Tuple, Union, Mapping, TypeVar, Callable, Iterable, Optional
from typing import (
Any,
List,
Tuple,
Union,
Mapping,
TypeVar,
Callable,
Iterable,
Optional,
)
from typing_extensions import Literal

from copy import copy
Expand All @@ -20,6 +30,8 @@
_compute_mean_color,
_create_categorical_colors,
)
from anndata._io.specs.methods import H5Group, ZarrGroup, write_basic
from anndata._io.specs.registry import _REGISTRY, IOSpec

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1303,3 +1315,19 @@ def _mutual_info(reference, query):
weights[i, :] = mutual_info_regression(reference, target)

return weights


_SPEC = IOSpec("array", "0.2.0")


@_REGISTRY.register_write(H5Group, Lineage, _SPEC)
@_REGISTRY.register_write(H5Group, LineageView, _SPEC)
@_REGISTRY.register_write(ZarrGroup, Lineage, _SPEC)
@_REGISTRY.register_write(ZarrGroup, LineageView, _SPEC)
def _write_lineage(
f: Any,
k: str,
elem: Union[Lineage, LineageView],
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> None:
write_basic(f, k, elem=elem.X, dataset_kwargs=dataset_kwargs)
15 changes: 7 additions & 8 deletions cellrank/tl/estimators/terminal_states/_gpcca.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Dict, Tuple, Union, Mapping, Optional, Sequence
from typing_extensions import Literal

from copy import deepcopy
from enum import auto
from types import MappingProxyType
from pathlib import Path
Expand Down Expand Up @@ -1042,7 +1041,7 @@ def _write_macrostates(
self._set("_coarse_tmat", value=tmat, shadow_only=True)
self._set("_coarse_init_dist", value=init_dist, shadow_only=True)
self._set("_coarse_stat_dist", value=stat_dist, shadow_only=True)
self._set(obj=self.adata.uns, key=Key.uns.coarse(self.backward), value={"coarse_tmat": tmat, **dists})
self._set(obj=self.adata.uns, key=Key.uns.coarse(self.backward), value=AnnData(tmat, obs=dists))
else:
for attr in ["_schur_vectors", "_schur_matrix", "_coarse_tmat", "_coarse_init_dist", "_coarse_stat_dist"]:
self._set(attr, value=None, shadow_only=True)
Expand Down Expand Up @@ -1104,14 +1103,14 @@ def _read_from_adata(self, adata: AnnData, **kwargs: Any) -> bool:
self._macrostates_colors = self.macrostates_memberships.colors.copy()
self.params[key] = self._read_params(key)

tmat = deepcopy(self.adata.uns[Key.uns.coarse(self.backward)])
if not isinstance(tmat, dict):
tmat = self.adata.uns[Key.uns.coarse(self.backward)].copy()
if not isinstance(tmat, AnnData):
raise TypeError(f"Expected coarse-grained transition matrix to be stored "
f"as `dict`, found `{type(tmat).__name__}`.")
f"as `AnnData`, found `{type(tmat).__name__}`.")

self._coarse_tmat = tmat["coarse_tmat"]
self._coarse_init_dist = tmat["coarse_init_dist"]
self._coarse_stat_dist = tmat.get("coarse_stat_dist", None)
self._coarse_tmat = pd.DataFrame(tmat.X, index=tmat.obs_names, columns=tmat.obs_names)
self._coarse_init_dist = tmat.obs["coarse_init_dist"]
self._coarse_stat_dist = tmat.obs.get("coarse_stat_dist", None)

self._set(obj=self._shadow_adata.uns, key=Key.uns.coarse(self.backward), value=tmat)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
anndata<0.8 # TODO(mihalk8): remove me after #777
anndata>=0.8,<0.9
docrep>=0.3.0
joblib>=0.13.1
matplotlib>=3.3.0
Expand Down