Skip to content

Commit

Permalink
Backport PR #2571: Fix scatter palette name
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Oct 9, 2023
1 parent 46969b4 commit 308df00
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 12 deletions.
8 changes: 3 additions & 5 deletions scanpy/plotting/_anndata.py
Expand Up @@ -271,15 +271,14 @@ def _scatter_obs(
palette_was_none = False
if palette is None:
palette_was_none = True
if isinstance(palette, cabc.Sequence):
if isinstance(palette, cabc.Sequence) and not isinstance(palette, str):
if not is_color_like(palette[0]):
palettes = palette
else:
palettes = [palette]
else:
palettes = [palette for _ in range(len(keys))]
for i, palette in enumerate(palettes):
palettes[i] = _utils.default_palette(palette)
palettes = [_utils.default_palette(palette) for palette in palettes]

if basis is not None:
component_name = (
Expand Down Expand Up @@ -372,8 +371,7 @@ def add_centroid(centroids, name, Y, mask):
centroids[name] = Y_mask[i]

# loop over all categorical annotation and plot it
for i, ikey in enumerate(categoricals):
palette = palettes[i]
for ikey, palette in zip(categoricals, palettes):
key = keys[ikey]
_utils.add_colors_for_categorical_sample_annotation(
adata, key, palette, force_update_colors=not palette_was_none
Expand Down
10 changes: 6 additions & 4 deletions scanpy/plotting/_utils.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
import collections.abc as cabc
from abc import ABC
from functools import lru_cache
from typing import Union, List, Sequence, Tuple, Collection, Optional, Callable, Literal
import anndata

Expand Down Expand Up @@ -315,10 +315,12 @@ def savefig_or_show(
pl.close() # clear figure


def default_palette(palette: Union[Sequence[str], Cycler, None] = None) -> Cycler:
def default_palette(
palette: Union[str, Sequence[str], Cycler, None] = None
) -> str | Cycler:
if palette is None:
return rcParams['axes.prop_cycle']
elif not isinstance(palette, Cycler):
elif not isinstance(palette, (str, Cycler)):
return cycler(color=palette)
else:
return palette
Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
15 changes: 12 additions & 3 deletions scanpy/tests/test_plotting.py
Expand Up @@ -1233,12 +1233,21 @@ def test_scatter_specify_layer_and_raw():
sc.pl.umap(pbmc, color="HES4", use_raw=True, layer="layer")


def test_scatter_no_basis_per_obs(image_comparer):
@pytest.mark.parametrize("color", ["n_genes", "bulk_labels"])
def test_scatter_no_basis_per_obs(image_comparer, color):
"""Test scatterplot of per-obs points with no basis"""
save_and_compare_images = image_comparer(ROOT, FIGS, tol=15)
pbmc = pbmc68k_reduced()
sc.pl.scatter(pbmc, x="HES4", y="percent_mito", color="n_genes", use_raw=False)
save_and_compare_images("scatter_HES_percent_mito_n_genes")
sc.pl.scatter(
pbmc,
x="HES4",
y="percent_mito",
color=color,
use_raw=False,
# palette only applies to categorical, i.e. color=='bulk_labels'
palette='Set2',
)
save_and_compare_images(f"scatter_HES_percent_mito_{color}")


def test_scatter_no_basis_per_var(image_comparer):
Expand Down

0 comments on commit 308df00

Please sign in to comment.