Skip to content

Commit

Permalink
Fix use_raw=None in scanpy.tl.score_genes (#1999)
Browse files Browse the repository at this point in the history
* Fix use_raw=None in scanpy.tl.score_genes

* Fix type of use_raw

* Add release notes #1999

* [ci skip] Improve release notes #1999

Co-authored-by: Isaac Virshup <ivirshup@gmail.com>
  • Loading branch information
michalk8 and ivirshup committed Oct 27, 2021
1 parent 1b74015 commit 8c07642
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/release-notes/1.8.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
.. rubric:: Bug fixes

- Fix plotting after :func:`scanpy.tl.filter_rank_genes_groups` :pr:`1942` :smaller:`S Rybakov`
- Fix :func:`scanpy.tl.score_genes` with ``use_raw=None`` looking up genes in :attr:`anndata.AnnData.var_names`
  even when :attr:`anndata.AnnData.raw` is present :pr:`1999` :smaller:`M Klein`

.. rubric:: Performance enhancements

Expand Down
17 changes: 17 additions & 0 deletions scanpy/tests/test_score_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,20 @@ def test_one_gene():
# https://github.com/theislab/scanpy/issues/1395
adata = _create_adata(100, 1000, p_zero=0, p_nan=0)
sc.tl.score_genes(adata, [adata.var_names[0]])


def test_use_raw_None():
    """Smoke test: ``score_genes`` runs with ``use_raw=None`` when ``.raw`` is set.

    The genes passed in are named by the raw object's ``var_names``; the test
    passes as long as the call does not raise.
    """
    adata = _create_adata(100, 1000, p_zero=0, p_nan=0)

    # Build the raw counterpart with its own var_names: "0", "1", ..., one
    # per variable.
    # NOTE(review): presumably these differ from the names produced by
    # _create_adata -- confirm against that helper.
    raw_source = adata.copy()
    raw_source.var_names = list(map(str, range(raw_source.n_vars)))
    adata.raw = raw_source

    genes = raw_source.var_names[:3]
    sc.tl.score_genes(adata, genes, use_raw=None)


@pytest.mark.parametrize("gene_pool", [[], ["foo", "bar"]])
def test_invalid_gene_pool(gene_pool):
    """``score_genes`` must raise ``ValueError`` for an unusable ``gene_pool``.

    Parametrized over an empty pool and a pool of two names
    (presumably absent from the generated ``var_names`` -- confirm against
    ``_create_adata``). The error message is expected to mention
    "reference set".
    """
    data = _create_adata(100, 1000, p_zero=0, p_nan=0)

    expect_reference_set_error = pytest.raises(ValueError, match="reference set")
    with expect_reference_set_error:
        sc.tl.score_genes(data, data.var_names[:3], gene_pool=gene_pool)
7 changes: 4 additions & 3 deletions scanpy/tools/_score_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def score_genes(
score_name: str = 'score',
random_state: AnyRandom = 0,
copy: bool = False,
use_raw: bool = None,
use_raw: Optional[bool] = None,
) -> Optional[AnnData]:
"""\
Score a set of genes [Satija15]_.
Expand Down Expand Up @@ -94,6 +94,7 @@ def score_genes(
"""
start = logg.info(f'computing score {score_name!r}')
adata = adata.copy() if copy else adata
use_raw = _check_use_raw(adata, use_raw)

if random_state is not None:
np.random.seed(random_state)
Expand All @@ -117,14 +118,14 @@ def score_genes(
gene_pool = list(var_names)
else:
gene_pool = [x for x in gene_pool if x in var_names]
if not gene_pool:
raise ValueError("No valid genes were passed for reference set.")

# Trying here to match the Seurat approach in scoring cells.
# Basically we need to compare genes against random genes in a matched
# interval of expression.

use_raw = _check_use_raw(adata, use_raw)
_adata = adata.raw if use_raw else adata

_adata_subset = (
_adata[:, gene_pool] if len(gene_pool) < len(_adata.var_names) else _adata
)
Expand Down

0 comments on commit 8c07642

Please sign in to comment.