diff --git a/scanpy/plotting/_anndata.py b/scanpy/plotting/_anndata.py index 76dbfe298b..09ad4171b2 100755 --- a/scanpy/plotting/_anndata.py +++ b/scanpy/plotting/_anndata.py @@ -273,15 +273,14 @@ def _scatter_obs( palette_was_none = False if palette is None: palette_was_none = True - if isinstance(palette, cabc.Sequence): + if isinstance(palette, cabc.Sequence) and not isinstance(palette, str): if not is_color_like(palette[0]): palettes = palette else: palettes = [palette] else: palettes = [palette for _ in range(len(keys))] - for i, palette in enumerate(palettes): - palettes[i] = _utils.default_palette(palette) + palettes = [_utils.default_palette(palette) for palette in palettes] if basis is not None: component_name = ( @@ -375,8 +374,7 @@ def add_centroid(centroids, name, Y, mask): centroids[name] = Y_mask[i] # loop over all categorical annotation and plot it - for i, ikey in enumerate(categoricals): - palette = palettes[i] + for ikey, palette in zip(categoricals, palettes): key = keys[ikey] _utils.add_colors_for_categorical_sample_annotation( adata, key, palette, force_update_colors=not palette_was_none diff --git a/scanpy/plotting/_utils.py b/scanpy/plotting/_utils.py index 3ad9ed82ad..403f4640cc 100644 --- a/scanpy/plotting/_utils.py +++ b/scanpy/plotting/_utils.py @@ -1,7 +1,7 @@ +from __future__ import annotations + import warnings import collections.abc as cabc -from abc import ABC -from functools import lru_cache from typing import Union, List, Sequence, Tuple, Collection, Optional, Callable, Literal import anndata @@ -320,10 +320,12 @@ def savefig_or_show( pl.close() # clear figure -def default_palette(palette: Union[Sequence[str], Cycler, None] = None) -> Cycler: +def default_palette( + palette: Union[str, Sequence[str], Cycler, None] = None +) -> str | Cycler: if palette is None: return rcParams['axes.prop_cycle'] - elif not isinstance(palette, Cycler): + elif not isinstance(palette, (str, Cycler)): return cycler(color=palette) else: return palette diff --git a/scanpy/tests/_images/scatter_HES_percent_mito_bulk_labels/expected.png b/scanpy/tests/_images/scatter_HES_percent_mito_bulk_labels/expected.png new file mode 100644 index 0000000000..0ddd02eb90 Binary files /dev/null and b/scanpy/tests/_images/scatter_HES_percent_mito_bulk_labels/expected.png differ diff --git a/scanpy/tests/test_neighbors.py b/scanpy/tests/test_neighbors.py index b9cae23552..7ef8b48b1e 100644 --- a/scanpy/tests/test_neighbors.py +++ b/scanpy/tests/test_neighbors.py @@ -145,7 +145,7 @@ def test_distances_euclidean(mocker, neigh, method): ), ], ) -def test_distances_all(neigh, transformer, knn): +def test_distances_all(neigh: Neighbors, transformer, knn): neigh.compute_neighbors( n_neighbors, transformer=transformer, method='gauss', knn=knn ) diff --git a/scanpy/tests/test_plotting.py b/scanpy/tests/test_plotting.py index e1442ff905..70328cab33 100644 --- a/scanpy/tests/test_plotting.py +++ b/scanpy/tests/test_plotting.py @@ -1290,14 +1290,23 @@ def test_scatter_specify_layer_and_raw(): sc.pl.umap(pbmc, color="HES4", use_raw=True, layer="layer") -def test_scatter_no_basis_per_obs(image_comparer): +@pytest.mark.parametrize("color", ["n_genes", "bulk_labels"]) +def test_scatter_no_basis_per_obs(image_comparer, color): """Test scatterplot of per-obs points with no basis""" save_and_compare_images = partial(image_comparer, ROOT, tol=15) pbmc = pbmc68k_reduced() - sc.pl.scatter(pbmc, x="HES4", y="percent_mito", color="n_genes", use_raw=False) - save_and_compare_images("scatter_HES_percent_mito_n_genes") + sc.pl.scatter( + pbmc, + x="HES4", + y="percent_mito", + color=color, + use_raw=False, + # palette only applies to categorical, i.e. color=='bulk_labels' + palette='Set2', + ) + save_and_compare_images(f"scatter_HES_percent_mito_{color}") def test_scatter_no_basis_per_var(image_comparer):