Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add order option in GPCCA.plot_coarse_T #804

Merged
merged 4 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
67 changes: 59 additions & 8 deletions cellrank/tl/estimators/terminal_states/_gpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from cellrank import logging as logg
from cellrank._key import Key
from cellrank.tl._enum import ModeEnum
from cellrank.ul._docs import d
from cellrank.ul._docs import d, inject_docs
from cellrank.tl._utils import (
save_fig,
_eigengap,
Expand Down Expand Up @@ -46,6 +46,12 @@ class TermStatesMethod(ModeEnum): # noqa: D101
STABILITY = auto()


class CoarseTOrder(ModeEnum): # noqa: D101
STABILITY = auto() # diagonal
INCOMING = auto()
STAT_DIST = auto()


@d.dedent
class GPCCA(TermStatesEstimator, LinDriversMixin, SchurMixin, EigenMixin):
"""
Expand Down Expand Up @@ -426,10 +432,12 @@ def fit(
return self

@d.dedent
@inject_docs(o=CoarseTOrder)
def plot_coarse_T(
self,
show_stationary_dist: bool = True,
show_initial_dist: bool = False,
order: Optional[Literal["stability", "incoming", "stat_dist"]] = "stability",
cmap: Union[str, ListedColormap] = "viridis",
xtick_rotation: float = 45,
annotate: bool = True,
Expand All @@ -450,6 +458,12 @@ def plot_coarse_T(
Whether to show :attr:`coarse_stationary_distribution`, if present.
show_initial_dist
Whether to show :attr:`coarse_initial_distribution`.
order
How to order the coarse-grained transition matrix. Valid options are:

- `{o.STABILITY!r}` - order by the values on the diagonal.
- `{o.INCOMING!r}` - order by the incoming mass, excluding the diagonal.
- `{o.STAT_DIST!r}` - order by coarse stationary distribution. If not present, use `{o.STABILITY!r}`.
cmap
Colormap to use.
xtick_rotation
Expand All @@ -471,6 +485,44 @@ def plot_coarse_T(
%(just_plots)s
"""

def order_matrix(
order: Optional[CoarseTOrder],
) -> Tuple[pd.DataFrame, Optional[pd.Series], Optional[pd.Series]]:
coarse_T = self.coarse_T
init_d = self.coarse_initial_distribution
stat_d = self.coarse_stationary_distribution

if order is None:
return coarse_T, init_d, stat_d

order = CoarseTOrder(order)
if order == CoarseTOrder.STAT_DIST and stat_d is None:
order = CoarseTOrder.STABILITY
logg.warning(
f"Unable to order by `{CoarseTOrder.STAT_DIST}`, no coarse stationary distribution. "
f"Using `order={order}`"
)

if order == CoarseTOrder.INCOMING:
values = (coarse_T.sum(0) - np.diag(coarse_T)).argsort(kind="stable")
names = values.index[values][::-1]
elif order == CoarseTOrder.STABILITY:
names = coarse_T.index[
np.argsort(np.diag(coarse_T), kind="stable")[::-1]
]
elif order == CoarseTOrder.STAT_DIST:
names = stat_d.index[stat_d.argsort(kind="stable")][::-1]
else:
raise NotImplementedError(f"Order `{order}` is not yet implemented.")

coarse_T = coarse_T.loc[names][names]
if init_d is not None:
init_d = init_d[names]
if stat_d is not None:
stat_d = stat_d[names]

return coarse_T, init_d, stat_d

def stylize_dist(
ax: Axes, data: np.ndarray, xticks_labels: Sequence[str] = ()
) -> None:
Expand Down Expand Up @@ -515,7 +567,9 @@ def annotate_heatmap(im, valfmt: str = "{x:.2f}") -> None:
text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
texts.append(text)

def annotate_dist_ax(ax, data: np.ndarray, valfmt: str = "{x:.2f}"):
def annotate_dist_ax(
ax: Axes, data: np.ndarray, valfmt: str = "{x:.2f}"
) -> None:
if ax is None:
return

Expand All @@ -534,15 +588,12 @@ def annotate_dist_ax(ax, data: np.ndarray, valfmt: str = "{x:.2f}"):
**kw,
)

coarse_T = self.coarse_T
coarse_init_d = self.coarse_initial_distribution
coarse_stat_d = self.coarse_stationary_distribution

if coarse_T is None:
if self.coarse_T is None:
raise RuntimeError(
"Compute coarse-grained transition matrix first as `.compute_macrostates()` with `n_states > 1`."
)

coarse_T, coarse_init_d, coarse_stat_d = order_matrix(order)
if show_stationary_dist and coarse_stat_d is None:
logg.warning("Coarse stationary distribution is `None`, ignoring")
show_stationary_dist = False
Expand Down Expand Up @@ -576,7 +627,7 @@ def annotate_dist_ax(ax, data: np.ndarray, valfmt: str = "{x:.2f}"):
cax = fig.add_subplot(gs[:1, -1]) if show_cbar else None
init_ax, stat_ax = None, None

labels = list(self.coarse_T.columns)
labels = list(coarse_T.columns)

tmp = coarse_T
if show_initial_dist:
Expand Down
Binary file modified tests/_ground_truth_figures/gpcca_coarse_T.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_ground_truth_figures/gpcca_coarse_T_cmap.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_ground_truth_figures/gpcca_coarse_T_init_dist.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_ground_truth_figures/gpcca_coarse_T_no_annot.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_ground_truth_figures/gpcca_coarse_T_no_cbar.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_ground_truth_figures/gpcca_coarse_T_stat_dist.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_ground_truth_figures/gpcca_coarse_T_stat_init_dist.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_ground_truth_figures/gpcca_coarse_T_xtick_rot.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2463,6 +2463,10 @@ def test_gpcca_coarse_T_cmap(self, mc: GPCCA, fpath: str):
def test_gpcca_coarse_T_xtick_rot(self, mc: GPCCA, fpath: str):
mc.plot_coarse_T(xtick_rotation=0, dpi=DPI, save=fpath)

@compare(kind="gpcca")
def test_gpcca_coarse_T_no_order(self, mc: GPCCA, fpath: str):
mc.plot_coarse_T(order=None, dpi=DPI, save=fpath)

@compare(kind="gpcca")
def test_scvelo_gpcca_meta_states(self, mc: GPCCA, fpath: str):
mc.plot_macrostates(dpi=DPI, save=fpath)
Expand Down