Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

obs. vs. obs for all grouped plots using BasePlot #2769

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
2877f27
matrixplot functionality
yugeji Nov 28, 2023
dc676d4
added convert_tidy
yugeji Nov 28, 2023
a334e9b
Distinguish MultiIndex
lisadratva Nov 28, 2023
0ce9b8e
functionality for dotplot
yugeji Nov 28, 2023
23a00bf
remove copy-pasta
yugeji Nov 28, 2023
0c7b83e
functionality for stacked_violin
yugeji Nov 28, 2023
78a3973
bug fixes to categories
yugeji Nov 29, 2023
ea6d8ee
Adapt figure width to conditions
lisadratva Nov 29, 2023
9085641
switch to changing obs_tidy instead of categories
yugeji Nov 29, 2023
10da093
revert some changes
yugeji Nov 29, 2023
170e535
bug fix
yugeji Nov 29, 2023
a64f09e
bug fix
yugeji Nov 29, 2023
adebf2b
Add groupby_cols to common_plot_args
lisadratva Nov 30, 2023
fc03fdc
Show groupby_cols in docstrings
lisadratva Nov 30, 2023
408bbeb
Groupby_cols category counter for fig formatting
lisadratva Nov 30, 2023
9d6def7
Add detail to docstrings
lisadratva Dec 1, 2023
f578977
Correct max if nans present
lisadratva Dec 1, 2023
32b0687
Fix row formatting for stacked_violin
lisadratva Dec 1, 2023
70ebe65
Fix stacked_violin swap_axes issue
lisadratva Dec 1, 2023
824d507
Fix when groupby_cols missing
lisadratva Dec 1, 2023
4d6800a
ValueError for overlapping arguments
lisadratva Dec 1, 2023
8003859
Merge branch 'main' into master
flying-sheep Mar 18, 2024
91590f0
Fixup merge
flying-sheep Mar 18, 2024
3de4f91
Smaller diff
flying-sheep Mar 18, 2024
5d57811
Missed some
flying-sheep Mar 18, 2024
ecf6255
Merge branch 'main' into master
flying-sheep Mar 21, 2024
e6766df
Cleanup
flying-sheep Mar 21, 2024
1c4740e
undo
flying-sheep Mar 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,7 +2060,7 @@ def _prepare_dataframe(
# and does not need to be given
groupby = groupby.copy() # copy to not modify user passed parameter
groupby.remove(groupby_index)
keys = list(groupby) + list(np.unique(var_names))
keys = [*groupby, *np.unique(var_names)]
obs_tidy = get.obs_df(
adata, keys=keys, layer=layer, use_raw=use_raw, gene_symbols=gene_symbols
)
Expand Down
71 changes: 66 additions & 5 deletions scanpy/plotting/_baseplot_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from warnings import warn

import numpy as np
import pandas as pd
from matplotlib import gridspec
from matplotlib import pyplot as plt

Expand Down Expand Up @@ -97,6 +98,7 @@
var_names: _VarNames | Mapping[str, _VarNames],
groupby: str | Sequence[str],
*,
groupby_cols: str | Sequence[str] = (),
use_raw: bool | None = None,
log: bool = False,
num_categories: int = 7,
Expand All @@ -120,7 +122,10 @@
self.var_group_positions = var_group_positions
self.var_group_rotation = var_group_rotation
self.width, self.height = figsize if figsize is not None else (None, None)

self.groupby = [groupby] if isinstance(groupby, str) else groupby
self.groupby_cols = (
[groupby_cols] if isinstance(groupby_cols, str) else groupby_cols
)
self.has_var_groups = (
True
if var_group_positions is not None and len(var_group_positions) > 0
Expand All @@ -132,13 +137,33 @@
self.categories, self.obs_tidy = _prepare_dataframe(
adata,
self.var_names,
groupby,
self.groupby,
use_raw=use_raw,
log=log,
num_categories=num_categories,
layer=layer,
gene_symbols=gene_symbols,
)
# reset obs_tidy if using groupby_cols
if len(self.groupby_cols) > 0:
if overlap := (set(self.groupby) & set(self.groupby_cols)):
raise ValueError(

Check warning on line 150 in scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_baseplot_class.py#L149-L150

Added lines #L149 - L150 were not covered by tests
f"`groupby` and `groupby_cols` have overlapping elements: {overlap}."
)
# TODO: check whether we need the product of the per-key category counts instead of their sum
self.categories_cols = adata.obs.loc[:, self.groupby_cols].nunique().sum()
_, self.obs_tidy = _prepare_dataframe(

Check warning on line 155 in scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_baseplot_class.py#L154-L155

Added lines #L154 - L155 were not covered by tests
adata,
self.var_names,
[*self.groupby, *self.groupby_cols],
use_raw,
log,
num_categories,
layer=layer,
gene_symbols=gene_symbols,
)
else:
self.categories_cols = 0
if len(self.categories) > self.MAX_NUM_CATEGORIES:
warn(
f"Over {self.MAX_NUM_CATEGORIES} categories found. "
Expand All @@ -159,7 +184,6 @@
return

self.adata = adata
self.groupby = [groupby] if isinstance(groupby, str) else groupby
self.log = log
self.kwds = kwds

Expand Down Expand Up @@ -372,6 +396,11 @@
_sort = True if sort is not None else False
_ascending = True if sort == "ascending" else False
counts_df = self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending)
# with `groupby_cols`, count over `groupby` in `adata.obs` instead; this could replace the previous line entirely, but it is slower
if len(self.groupby_cols) > 0:
counts_df = self.adata.obs[self.groupby].value_counts(

Check warning on line 401 in scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_baseplot_class.py#L401

Added line #L401 was not covered by tests
sort=_sort, ascending=_ascending
)

if _sort:
self.categories_order = counts_df.index
Expand Down Expand Up @@ -586,7 +615,7 @@
self._plot_colorbar(color_legend_ax, normalize)
return_ax_dict["color_legend_ax"] = color_legend_ax

def _mainplot(self, ax):
def _mainplot(self, ax: Axes):
y_labels = self.categories
x_labels = self.var_names

Expand Down Expand Up @@ -655,7 +684,8 @@
if self.height is None:
mainplot_height = len(self.categories) * category_height
mainplot_width = (
len(self.var_names) * category_width + self.group_extra_size
len(self.var_names) * category_width * (1 + self.categories_cols)
+ self.group_extra_size
)
if self.are_axes_swapped:
mainplot_height, mainplot_width = mainplot_width, mainplot_height
Expand Down Expand Up @@ -857,6 +887,37 @@
self.make_figure()
plt.savefig(filename, bbox_inches=bbox_inches, **kwargs)

def _convert_tidy_to_stacked(self, values_df: pd.DataFrame) -> pd.DataFrame:
"""\
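Utility function used to convert a data frame indexed by the combined `groupby` and `groupby_cols` labels into the format needed when `groupby_cols` is used: the `groupby_cols` levels are unstacked into columns.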
"""
label = values_df.index.name
stacked_df = values_df.reset_index()
stacked_df.index = pd.MultiIndex.from_tuples(

Check warning on line 896 in scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_baseplot_class.py#L894-L896

Added lines #L894 - L896 were not covered by tests
stacked_df[label].str.split("_").tolist(),
names=self.groupby + self.groupby_cols,
)
stacked_df = stacked_df.drop(label, axis=1).unstack(level=self.groupby_cols)

Check warning on line 900 in scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_baseplot_class.py#L900

Added line #L900 was not covered by tests

# recreate the original formatting of values_df
values_df = stacked_df.reset_index(drop=True)
if isinstance(stacked_df.index, pd.MultiIndex):
values_df.index = (

Check warning on line 905 in scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_baseplot_class.py#L903-L905

Added lines #L903 - L905 were not covered by tests
stacked_df.index.to_series()
.apply(lambda x: "_".join(map(str, x)))
.values
)
else:
values_df.index = (

Check warning on line 911 in scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_baseplot_class.py#L911

Added line #L911 was not covered by tests
stacked_df.index.to_series()
.apply(lambda x: "".join(map(str, x)))
.values
)
values_df.columns = (

Check warning on line 916 in scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_baseplot_class.py#L916

Added line #L916 was not covered by tests
stacked_df.columns.to_series().apply(lambda x: "_".join(map(str, x))).values
)
return values_df

Check warning on line 919 in scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_baseplot_class.py#L919

Added line #L919 was not covered by tests

def _reorder_categories_after_dendrogram(self, dendrogram) -> None:
"""\
Function used by plotting functions that need to reorder the groupby
Expand Down
2 changes: 2 additions & 0 deletions scanpy/plotting/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@
then the `var_group_labels` and `var_group_positions` are set.
groupby
The key of the observation grouping to consider.
groupby_cols
The key(s) of the observation grouping used to group the columns. Must not overlap with `groupby`.
use_raw
Use `raw` attribute of `adata` if present.
log
Expand Down
17 changes: 14 additions & 3 deletions scanpy/plotting/_dotplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
var_names: _VarNames | Mapping[str, _VarNames],
groupby: str | Sequence[str],
*,
groupby_cols: str | Sequence[str] = (),
use_raw: bool | None = None,
log: bool = False,
num_categories: int = 7,
Expand Down Expand Up @@ -169,6 +170,7 @@
adata,
var_names,
groupby,
groupby_cols=groupby_cols,
use_raw=use_raw,
log=log,
num_categories=num_categories,
Expand Down Expand Up @@ -204,6 +206,8 @@
obs_bool.groupby(level=0, observed=True).sum()
/ obs_bool.groupby(level=0, observed=True).count()
)
if len(groupby_cols) > 0:
dot_size_df = self._convert_tidy_to_stacked(dot_size_df)

Check warning on line 210 in scanpy/plotting/_dotplot.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_dotplot.py#L210

Added line #L210 was not covered by tests

if dot_color_df is None:
# 2. compute mean expression value
Expand All @@ -227,6 +231,8 @@
pass
else:
logg.warning("Unknown type for standard_scale, ignored")
if len(groupby_cols) > 0:
dot_color_df = self._convert_tidy_to_stacked(dot_color_df)

Check warning on line 235 in scanpy/plotting/_dotplot.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_dotplot.py#L235

Added line #L235 was not covered by tests
else:
# check that both matrices have the same shape
if dot_color_df.shape != dot_size_df.shape:
Expand Down Expand Up @@ -568,7 +574,7 @@
self._plot_colorbar(color_legend_ax, normalize)
return_ax_dict["color_legend_ax"] = color_legend_ax

def _mainplot(self, ax):
def _mainplot(self, ax: Axes):
# work on a copy of the dataframes. This is to avoid changes
# on the original data frames after repetitive calls to the
# DotPlot object, for example once with swap_axes and other without
Expand Down Expand Up @@ -737,7 +743,7 @@
mean_flat = dot_color.values.flatten()
cmap = plt.get_cmap(cmap)
if dot_max is None:
dot_max = np.ceil(max(frac) * 10) / 10
dot_max = np.ceil(np.nanmax(frac) * 10) / 10
else:
if dot_max < 0 or dot_max > 1:
raise ValueError("`dot_max` value has to be between 0 and 1")
Expand All @@ -758,6 +764,8 @@
# rescale size to match smallest_dot and largest_dot
size = size * (largest_dot - smallest_dot) + smallest_dot
normalize = check_colornorm(vmin, vmax, vcenter, norm)
# call the norm on the non-NaN values only, to circumvent unexpected behavior with NaN in matplotlib
normalize(mean_flat[~np.isnan(mean_flat)])

if color_on == "square":
if edge_color is None:
Expand Down Expand Up @@ -871,6 +879,7 @@
var_names: _VarNames | Mapping[str, _VarNames],
groupby: str | Sequence[str],
*,
groupby_cols: str | Sequence[str] = (),
use_raw: bool | None = None,
log: bool = False,
num_categories: int = 7,
Expand Down Expand Up @@ -907,6 +916,7 @@
Makes a *dot plot* of the expression values of `var_names`.

For each var_name and each `groupby` category a dot is plotted.
Columns can optionally be grouped by specifying `groupby_cols`.
Each dot represents two values: mean expression within each category
(visualized by color) and fraction of cells expressing the `var_name` in the
category (visualized by the size of the dot). If `groupby` is not given,
Expand Down Expand Up @@ -1013,7 +1023,8 @@
dp = DotPlot(
adata,
var_names,
groupby,
groupby=groupby,
groupby_cols=groupby_cols,
use_raw=use_raw,
log=log,
num_categories=num_categories,
Expand Down
10 changes: 9 additions & 1 deletion scanpy/plotting/_matrixplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
var_names: _VarNames | Mapping[str, _VarNames],
groupby: str | Sequence[str],
*,
groupby_cols: str | Sequence[str] = (),
use_raw: bool | None = None,
log: bool = False,
num_categories: int = 7,
Expand All @@ -147,6 +148,7 @@
adata,
var_names,
groupby,
groupby_cols=groupby_cols,
use_raw=use_raw,
log=log,
num_categories=num_categories,
Expand Down Expand Up @@ -189,6 +191,9 @@
else:
logg.warning("Unknown type for standard_scale, ignored")

if len(groupby_cols) > 0:
values_df = self._convert_tidy_to_stacked(values_df)

Check warning on line 195 in scanpy/plotting/_matrixplot.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_matrixplot.py#L195

Added line #L195 was not covered by tests

self.values_df = values_df

self.cmap = self.DEFAULT_COLORMAP
Expand Down Expand Up @@ -252,7 +257,7 @@

return self

def _mainplot(self, ax):
def _mainplot(self, ax: Axes):
# work on a copy of the dataframes. This is to avoid changes
# on the original data frames after repetitive calls to the
# MatrixPlot object, for example once with swap_axes and other without
Expand Down Expand Up @@ -339,6 +344,7 @@
var_names: _VarNames | Mapping[str, _VarNames],
groupby: str | Sequence[str],
*,
groupby_cols: str | Sequence[str] = (),
use_raw: bool | None = None,
log: bool = False,
num_categories: int = 7,
Expand Down Expand Up @@ -367,6 +373,7 @@
) -> MatrixPlot | dict[str, Axes] | None:
"""\
Creates a heatmap of the mean expression values per group of each var_names.
Columns can optionally be grouped by specifying `groupby_cols`.

This function provides a convenient interface to the :class:`~scanpy.pl.MatrixPlot`
class. If you need more flexibility, you should use :class:`~scanpy.pl.MatrixPlot`
Expand Down Expand Up @@ -432,6 +439,7 @@
adata,
var_names,
groupby=groupby,
groupby_cols=groupby_cols,
use_raw=use_raw,
log=log,
num_categories=num_categories,
Expand Down