Skip to content

Commit

Permalink
Fix/assigning to categories (#1079)
Browse files Browse the repository at this point in the history
* Fix assigning to `.cat.categories`

* Remove dead code
  • Loading branch information
michalk8 committed Jun 28, 2023
1 parent ab03900 commit 9ef5a1f
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 245 deletions.
43 changes: 1 addition & 42 deletions src/cellrank/_utils/_colors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -278,44 +278,3 @@ def _compute_mean_color(cols: List[str]) -> str:
cols = np.array([colors.rgb_to_hsv(colors.to_rgb(c)) for c in cols])

return colors.to_hex(colors.hsv_to_rgb(np.mean(cols, axis=0)))


def _colors_in_order(
adata,
clusters: Optional[Iterable[str]] = None,
cluster_key: str = "clusters",
) -> List[Any]:
"""Get list of colors from AnnData in defined order.
Extracts a list of colors from ``adata.uns[cluster_key]`` in the order defined by the ``clusters``.
Parameters
----------
%(adata)s
clusters
Subset of the clusters we want the color for. Must be a subset of ``adata.obs['{cluster_key}'].cat.categories``.
cluster_key
Key from :attr:~`anndata.AnnData.obs``.
Returns
-------
List of colors in order defined by `clusters`.
"""
assert cluster_key in adata.obs, f"Could not find {cluster_key} in `adata.obs`."

if clusters is not None:
assert np.all(np.in1d(clusters, adata.obs[cluster_key].cat.categories)), "Not all `clusters` found."

assert f"{cluster_key}_colors" in adata.uns, f"No colors associated to {cluster_key} in `adata.uns`."

if clusters is None:
clusters = adata.obs[cluster_key].cat.categories

color_list = []
all_clusters = adata.obs[cluster_key].cat.categories

for cl in clusters:
mask = np.in1d(all_clusters, cl)
color_list.append(adata.uns[f"{cluster_key}_colors"][mask][0])

return color_list
43 changes: 1 addition & 42 deletions src/cellrank/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import pandas as pd
import scipy.sparse as sp
import scipy.stats as st
from pandas.api.types import infer_dtype, is_bool_dtype, is_categorical_dtype
from pandas.api.types import infer_dtype, is_categorical_dtype
from sklearn.cluster import KMeans
from statsmodels.stats.multitest import multipletests

Expand Down Expand Up @@ -528,17 +528,6 @@ def perm_test_extractor(res: Sequence[Tuple[np.ndarray, np.ndarray]]) -> Tuple[n
return corr, pvals, corr_ci_low, corr_ci_high


def _make_cat(labels: List[List[Any]], n_states: int, state_names: Sequence[str]) -> pd.Series:
"""Get categorical from list of lists."""
labels_new = np.repeat(np.nan, n_states)
for i, c in enumerate(labels):
labels_new[c] = i
labels_new = pd.Series(labels_new, index=state_names, dtype="category")
labels_new.cat.categories = labels_new.cat.categories.astype("int")

return labels_new


def _filter_cells(distances: sp.spmatrix, rc_labels: pd.Series, n_matches_min: int) -> pd.Series:
"""Filter out some cells that look like transient states based on their neighbors."""
if not is_categorical_dtype(rc_labels):
Expand Down Expand Up @@ -1014,16 +1003,6 @@ def _unique_order_preserving(iterable: Iterable[Hashable]) -> List[Hashable]:
return [i for i in iterable if i not in seen and not seen.add(i)]


def _info_if_obs_keys_categorical_present(
    adata: AnnData, keys: Iterable[str], msg_fmt: str, warn_once: bool = True
) -> None:
    """Log an info message for the keys of ``adata.obs`` that hold categorical data.

    Parameters
    ----------
    adata
        Annotated data object whose ``obs`` is inspected.
    keys
        Candidate keys. Keys missing from ``adata.obs`` or not of categorical dtype are ignored.
    msg_fmt
        Message template, formatted with the matching key as its only positional argument.
    warn_once
        If `True`, stop after the first logged message.

    Returns
    -------
    Nothing, just logs the messages.
    """
    for key in keys:
        if key not in adata.obs:
            continue
        # `is_categorical_dtype` is deprecated since pandas 2.1, inspect the dtype directly
        if isinstance(getattr(adata.obs[key], "dtype", None), pd.CategoricalDtype):
            logg.info(msg_fmt.format(key))
            if warn_once:
                break


def _one_hot(n, cat: Optional[int] = None) -> np.ndarray:
"""
One-hot encode cat to a vector of length n.
Expand Down Expand Up @@ -1383,26 +1362,6 @@ def _calculate_lineage_absorption_time_means(
return res


def _maybe_subset_hvgs(adata: AnnData, use_highly_variable: Optional[Union[bool, str]]) -> AnnData:
    """Optionally restrict ``adata`` to its highly variable genes.

    Parameters
    ----------
    adata
        Annotated data object.
    use_highly_variable
        `None` or `False` to keep all genes, `True` to use ``adata.var['highly_variable']``,
        or the name of the column in ``adata.var`` holding the boolean mask.

    Returns
    -------
    ``adata`` itself if no subsetting was requested, the key is missing or the column is not boolean
    (a warning is logged in the latter two cases), otherwise ``adata`` subset to the selected genes.
    """
    if use_highly_variable in (None, False):
        return adata

    if use_highly_variable is True:
        key = "highly_variable"
    else:
        key = use_highly_variable

    if key not in adata.var.keys():
        logg.warning(f"Unable to find HVGs in `adata.var[{key!r}]`. Using all genes")
        return adata

    hvg_mask = adata.var[key]
    if is_bool_dtype(hvg_mask):
        logg.info(f"Using `{np.sum(hvg_mask)}` HVGs from `adata.var[{key!r}]`")
        return adata[:, hvg_mask]

    logg.warning(
        f"Expected `adata.var[{key!r}]` to be of bool dtype, "
        f"found `{infer_dtype(hvg_mask)}`. Using all genes"
    )
    return adata


def _check_collection(
adata: AnnData,
needles: Iterable[str],
Expand Down
45 changes: 1 addition & 44 deletions src/cellrank/estimators/mixins/_fate_probabilities.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
import types
from typing import (
Any,
Dict,
Literal,
Mapping,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
from typing import Any, Dict, Literal, Mapping, NamedTuple, Optional, Sequence, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -148,39 +138,6 @@ def _plot_continuous(
...


def _normalize_abs_times(
keys: Sequence[str], time_to_absorption: Any = None
) -> Dict[Tuple[str, ...], Literal["mean", "var"]]:
if time_to_absorption is None:
return {}

if isinstance(time_to_absorption, (str, tuple)):
time_to_absorption = [time_to_absorption]
if not isinstance(time_to_absorption, dict):
time_to_absorption = {ln: "mean" for ln in time_to_absorption}

res = {}
for ln, moment in time_to_absorption.items():
if moment not in ("mean", "var"):
raise ValueError(f"Moment must be either `'mean'` or `'var'`, found `{moment!r}` in `{ln}`.")

seen = set()
if isinstance(ln, str):
ln = tuple(keys) if ln == "all" else (ln,)
sorted_ln = tuple(sorted(ln)) # preserve the user order

if sorted_ln not in seen:
seen.add(sorted_ln)
for lin in ln:
if lin not in keys:
raise ValueError(
f"Invalid absorbing state `{lin!r}` in `{ln}`. " f"Valid options are `{list(keys)}`."
)
res[tuple(ln)] = moment

return res


class FateProbsMixin:
"""Mixin that supports computation of fate probabilities and mean times to absorption."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,8 @@ def _set_categorical_labels(
series_query=series_query,
colors_reference=colors_reference,
)
categories.cat.categories = names
cats = categories.cat.categories
categories = categories.cat.rename_categories(dict(zip(cats, names)))
else:
colors = _create_categorical_colors(len(categories.cat.categories))

Expand Down
81 changes: 1 addition & 80 deletions src/cellrank/kernels/_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable

import wrapt

import numba as nb
import numpy as np
import pandas as pd
import scipy.sparse as sp
from numba import prange
from pandas.api.types import infer_dtype, is_categorical_dtype, is_numeric_dtype

Expand Down Expand Up @@ -59,16 +58,6 @@ def np_std(array: np.ndarray, axis: int) -> np.ndarray: # noqa
return _np_apply_along_axis(np.std, axis, array)


@nb.njit(**jit_kwargs)
def np_max(array: np.ndarray, axis: int) -> np.ndarray:  # noqa
    """Apply :func:`numpy.max` to ``array`` along ``axis`` through ``_np_apply_along_axis``."""
    reduced = _np_apply_along_axis(np.max, axis, array)
    return reduced


@nb.njit(**jit_kwargs)
def np_sum(array: np.ndarray, axis: int) -> np.ndarray:  # noqa
    """Apply :func:`numpy.sum` to ``array`` along ``axis`` through ``_np_apply_along_axis``."""
    reduced = _np_apply_along_axis(np.sum, axis, array)
    return reduced


@nb.njit(**jit_kwargs)
def norm(array: np.ndarray, axis: int) -> np.ndarray:  # noqa
    """Apply :func:`numpy.linalg.norm` to ``array`` along ``axis`` through ``_np_apply_along_axis``."""
    norms = _np_apply_along_axis(np.linalg.norm, axis, array)
    return norms
Expand Down Expand Up @@ -107,74 +96,6 @@ def _random_normal(
).T


@nb.njit(**jit_kwargs)
def _get_probs_for_zero_vec(size: int) -> Tuple[np.ndarray, np.ndarray]:
    """Get a vector with uniform probability and a vector of zeros.

    Parameters
    ----------
    size
        Size of the vector.

    Returns
    -------
    The probability and variance vectors.
    """
    # float32 doesn't have enough precision
    uniform = np.ones(size, dtype=np.float64) / size
    zeros = np.zeros(size, dtype=np.float64)

    return uniform, zeros


def _reconstruct_one(
data: np.ndarray,
mat: sp.csr_matrix,
ixs: Optional[np.ndarray] = None,
) -> Tuple[sp.csr_matrix, sp.csr_matrix]:
"""Transform :class:`~numpy.ndarray` into :class:`~scipy.sparse.csr_matrix`.
Parameters
----------
data
Array of shape ``(2 x number_of_nnz)``.
mat
The original sparse matrix.
ixs
Indices that were used to sort the data.
Returns
-------
The probability and correlation matrix.
"""
assert data.shape == (2, mat.nnz), f"Dimension or shape mismatch: `{data.shape}`, `{2, mat.nnz}`."

aixs = None
if ixs is not None:
aixs = np.argsort(ixs)
assert len(ixs) == mat.shape[0], f"Shape mismatch: `{ixs.shape}`, `{mat.shape}`"
mat = mat[ixs]

# strange bug happens when no copying and eliminating zeros from cors (it's no longer row-stochastic)
# only happens when using numba
probs = sp.csr_matrix((np.array(data[0]), np.array(mat.indices), np.array(mat.indptr)))
cors = sp.csr_matrix((np.array(data[1]), np.array(mat.indices), np.array(mat.indptr)))

if aixs is not None:
assert len(aixs) == probs.shape[0], f"Shape mismatch: `{ixs.shape}`, `{probs.shape}`."
probs, cors = probs[aixs], cors[aixs]

probs.eliminate_zeros()
cors.eliminate_zeros()

row_sums = np.array(probs.sum(1).squeeze())
close_to_1 = np.isclose(row_sums, 1.0)
if not np.all(close_to_1):
raise ValueError(f"Matrix is not row-stochastic. The following rows don't sum to 1: `{row_sums[~close_to_1]}`.")

return probs, cors


@nb.njit(**jit_kwargs)
def _calculate_starts(indptr: np.ndarray, ixs: np.ndarray) -> np.ndarray:
"""Get the position where to put the data.
Expand Down
36 changes: 0 additions & 36 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@
)
from cellrank.kernels._utils import (
_calculate_starts,
_get_probs_for_zero_vec,
_np_apply_along_axis,
_random_normal,
_reconstruct_one,
)
from cellrank.kernels.utils._similarity import (
_predict_transition_probabilities_jax,
Expand Down Expand Up @@ -876,45 +874,11 @@ def wrapped(axis: int, x: np.ndarray):
for fn in (np.var, np.std):
np.testing.assert_allclose(fn(x, axis=axis), _create_numba_fn(fn)(axis, x))

def test_zero_unif_sum_to_1_vector(self):
    """Check the shapes and the values of the vectors returned by ``_get_probs_for_zero_vec``."""
    size = 10
    uniform, zeros = _get_probs_for_zero_vec(size)

    assert zeros.shape == (size,)
    assert uniform.shape == (size,)
    # the second vector is all zeros, the first one is uniform and sums to 1
    np.testing.assert_array_equal(zeros, np.zeros_like(zeros))
    np.testing.assert_allclose(uniform, np.ones_like(uniform) / uniform.shape[0])
    assert np.isclose(uniform.sum(), 1.0)

def test_calculate_starts(self):
    """Check ``_calculate_starts`` on the index pointer of a 10 x 10 diagonal matrix."""
    n = 10
    indptr = sp.diags(np.ones(n)).tocsr().indptr
    starts = _calculate_starts(indptr, np.arange(n))

    np.testing.assert_array_equal(starts, np.arange(n + 1))

@pytest.mark.parametrize(("seed", "shuffle"), [(0, False), (1, False), (2, True), (3, True)])
def test_reconstruct_one(self, seed: int, shuffle: bool):
    """Check that ``_reconstruct_one`` recovers both matrices, with and without shuffled rows."""
    rng = np.random.default_rng(42)

    # row-stochastic matrix, the first column is set so that every row has a non-zero entry
    m1 = sp.random(100, 10, random_state=seed, density=0.5, format="lil")
    m1[:, 0] = 0.1
    m1 /= m1.sum(1)
    m1 = sp.csr_matrix(m1)

    # second matrix shares the sparsity structure of the first one
    m2 = sp.csr_matrix((rng.normal(size=m1.nnz), m1.indices, m1.indptr))

    ixs = None
    p_data, c_data = m1.data, m2.data
    if shuffle:
        ixs = np.arange(100)
        rng.shuffle(ixs)
        p_data, c_data = m1[ixs, :].data, m2[ixs, :].data
    data = np.vstack([p_data, c_data])

    r1, r2 = _reconstruct_one(data, m1, ixs=ixs)

    np.testing.assert_array_equal(r1.toarray(), m1.toarray())
    np.testing.assert_array_equal(r2.toarray(), m2.toarray())

@jax_not_installed_skip
@pytest.mark.parametrize(
("seed", "c", "s"),
Expand Down

0 comments on commit 9ef5a1f

Please sign in to comment.