Skip to content

Commit

Permalink
Revert #841
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Mar 25, 2022
1 parent 18e13a5 commit 239d1da
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions cellrank/tl/estimators/terminal_states/_gpcca.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Dict, Tuple, Union, Mapping, Optional, Sequence
from typing_extensions import Literal

from copy import deepcopy
from enum import auto
from types import MappingProxyType
from pathlib import Path
Expand Down Expand Up @@ -1042,7 +1041,7 @@ def _write_macrostates(
self._set("_coarse_tmat", value=tmat, shadow_only=True)
self._set("_coarse_init_dist", value=init_dist, shadow_only=True)
self._set("_coarse_stat_dist", value=stat_dist, shadow_only=True)
self._set(obj=self.adata.uns, key=Key.uns.coarse(self.backward), value={"coarse_tmat": tmat, **dists})
self._set(obj=self.adata.uns, key=Key.uns.coarse(self.backward), value=AnnData(tmat, obs=dists))
else:
for attr in ["_schur_vectors", "_schur_matrix", "_coarse_tmat", "_coarse_init_dist", "_coarse_stat_dist"]:
self._set(attr, value=None, shadow_only=True)
Expand Down Expand Up @@ -1104,14 +1103,14 @@ def _read_from_adata(self, adata: AnnData, **kwargs: Any) -> bool:
self._macrostates_colors = self.macrostates_memberships.colors.copy()
self.params[key] = self._read_params(key)

tmat = deepcopy(self.adata.uns[Key.uns.coarse(self.backward)])
if not isinstance(tmat, dict):
tmat = self.adata.uns[Key.uns.coarse(self.backward)].copy()
if not isinstance(tmat, AnnData):
raise TypeError(f"Expected coarse-grained transition matrix to be stored "
f"as `dict`, found `{type(tmat).__name__}`.")
f"as `AnnData`, found `{type(tmat).__name__}`.")

self._coarse_tmat = tmat["coarse_tmat"]
self._coarse_init_dist = tmat["coarse_init_dist"]
self._coarse_stat_dist = tmat.get("coarse_stat_dist", None)
self._coarse_tmat = pd.DataFrame(tmat.X, index=tmat.obs_names, columns=tmat.obs_names)
self._coarse_init_dist = tmat.obs["coarse_init_dist"]
self._coarse_stat_dist = tmat.obs.get("coarse_stat_dist", None)

self._set(obj=self._shadow_adata.uns, key=Key.uns.coarse(self.backward), value=tmat)

Expand Down

0 comments on commit 239d1da

Please sign in to comment.