diff --git a/scanpy/plotting/_anndata.py b/scanpy/plotting/_anndata.py index a904a0778..9c31fc735 100755 --- a/scanpy/plotting/_anndata.py +++ b/scanpy/plotting/_anndata.py @@ -271,15 +271,14 @@ def _scatter_obs( palette_was_none = False if palette is None: palette_was_none = True - if isinstance(palette, cabc.Sequence): + if isinstance(palette, cabc.Sequence) and not isinstance(palette, str): if not is_color_like(palette[0]): palettes = palette else: palettes = [palette] else: palettes = [palette for _ in range(len(keys))] - for i, palette in enumerate(palettes): - palettes[i] = _utils.default_palette(palette) + palettes = [_utils.default_palette(palette) for palette in palettes] if basis is not None: component_name = ( @@ -372,8 +371,7 @@ def add_centroid(centroids, name, Y, mask): centroids[name] = Y_mask[i] # loop over all categorical annotation and plot it - for i, ikey in enumerate(categoricals): - palette = palettes[i] + for ikey, palette in zip(categoricals, palettes): key = keys[ikey] _utils.add_colors_for_categorical_sample_annotation( adata, key, palette, force_update_colors=not palette_was_none diff --git a/scanpy/plotting/_utils.py b/scanpy/plotting/_utils.py index 0a6fa70d8..76639a81f 100644 --- a/scanpy/plotting/_utils.py +++ b/scanpy/plotting/_utils.py @@ -1,7 +1,7 @@ +from __future__ import annotations + import warnings import collections.abc as cabc -from abc import ABC -from functools import lru_cache from typing import Union, List, Sequence, Tuple, Collection, Optional, Callable, Literal import anndata @@ -315,10 +315,12 @@ def savefig_or_show( pl.close() # clear figure -def default_palette(palette: Union[Sequence[str], Cycler, None] = None) -> Cycler: +def default_palette( + palette: Union[str, Sequence[str], Cycler, None] = None +) -> str | Cycler: if palette is None: return rcParams['axes.prop_cycle'] - elif not isinstance(palette, Cycler): + elif not isinstance(palette, (str, Cycler)): return cycler(color=palette) else: return palette diff --git a/scanpy/tests/_images/scatter_HES_percent_mito_bulk_labels/expected.png b/scanpy/tests/_images/scatter_HES_percent_mito_bulk_labels/expected.png new file mode 100644 index 000000000..0ddd02eb9 Binary files /dev/null and b/scanpy/tests/_images/scatter_HES_percent_mito_bulk_labels/expected.png differ diff --git a/scanpy/tests/test_plotting.py b/scanpy/tests/test_plotting.py index a51b37860..cd1919b8c 100644 --- a/scanpy/tests/test_plotting.py +++ b/scanpy/tests/test_plotting.py @@ -1233,12 +1233,21 @@ def test_scatter_specify_layer_and_raw(): sc.pl.umap(pbmc, color="HES4", use_raw=True, layer="layer") -def test_scatter_no_basis_per_obs(image_comparer): +@pytest.mark.parametrize("color", ["n_genes", "bulk_labels"]) +def test_scatter_no_basis_per_obs(image_comparer, color): """Test scatterplot of per-obs points with no basis""" save_and_compare_images = image_comparer(ROOT, FIGS, tol=15) pbmc = pbmc68k_reduced() - sc.pl.scatter(pbmc, x="HES4", y="percent_mito", color="n_genes", use_raw=False) - save_and_compare_images("scatter_HES_percent_mito_n_genes") + sc.pl.scatter( + pbmc, + x="HES4", + y="percent_mito", + color=color, + use_raw=False, + # palette only applies to categorical, i.e. color=='bulk_labels' + palette='Set2', + ) + save_and_compare_images(f"scatter_HES_percent_mito_{color}") def test_scatter_no_basis_per_var(image_comparer):