Skip to content

Commit

Permalink
dim -> axis
Browse files Browse the repository at this point in the history
  • Loading branch information
ivirshup committed Dec 12, 2023
1 parent 3764a7f commit 0aef147
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
32 changes: 16 additions & 16 deletions scanpy/get/_aggregated.py
Expand Up @@ -152,7 +152,7 @@ def aggregate(
by: str | list[str],
func: AggType | Iterable[AggType],
*,
dim: Literal["obs", "var"] | None = None,
axis: Literal[0, 1] | None = None,
dof: int = 1,
layer: str | None = None,
obsm: str | None = None,
Expand All @@ -179,7 +179,7 @@ def aggregate(
Key of the column to be grouped-by.
func
How to aggregate.
dim
axis
Axis on which to find group by column.
dof
Degrees of freedom for variance. Defaults to 1.
Expand Down Expand Up @@ -221,36 +221,36 @@ def aggregate(
Note that this filters out any combination of groups that wasn't present in the original data.
"""
if dim not in ["obs", "var", None]:
raise ValueError(f"dim must be one of 'obs' or 'var', was '{dim}'")
if axis not in [0, 1, None]:
raise ValueError(f"axis must be one of 0 or 1, was '{axis}'")
# TODO replace with get helper
data = adata.X
if sum(p is not None for p in [varm, obsm, layer]) > 1:
raise TypeError("Please only provide one (or none) of varm, obsm, or layer")

if dim is None:
if axis is None:
if varm:
dim = "var"
axis = 1
else:
dim = "obs"
axis = 0

if varm is not None:
if dim != "var":
raise ValueError("varm can only be used when dim is 'var'")
if axis != 1:
raise ValueError("varm can only be used when axis is 1")
data = adata.varm[varm]
elif obsm is not None:
if dim != "obs":
raise ValueError("obsm can only be used when dim is 'obs'")
if axis != 0:
raise ValueError("obsm can only be used when axis is 0")
data = adata.obsm[obsm]
elif layer is not None:
data = adata.layers[layer]
if dim == "var":
if axis == 1:
data = data.T
elif dim == "var":
elif axis == 1:
# i.e., all of `varm`, `obsm`, `layers` are None so we use `X` which must be transposed
data = data.T

dim_df = getattr(adata, dim)
dim_df = getattr(adata, ["obs", "var"][axis])
categorical, new_label_df = _combine_categories(dim_df, by)
# Actual computation
layers = aggregate(
Expand All @@ -262,10 +262,10 @@ def aggregate(
result = AnnData(
layers=layers,
obs=new_label_df,
var=getattr(adata, "var" if dim == "obs" else "obs"),
var=getattr(adata, "var" if axis == 0 else "obs"),
)

if dim == "var":
if axis == 1:
return result.T
else:
return result
Expand Down
4 changes: 2 additions & 2 deletions scanpy/tests/test_aggregated.py
Expand Up @@ -120,7 +120,7 @@ def test_aggregate_axis(array_type, metric):
].copy()
adata.X = array_type(adata.X)
expected = sc.get.aggregate(adata, ["louvain"], metric)
actual = sc.get.aggregate(adata.T, ["louvain"], metric, dim="var").T
actual = sc.get.aggregate(adata.T, ["louvain"], metric, axis=1).T

assert_equal(expected, actual)

Expand Down Expand Up @@ -164,7 +164,7 @@ def test_aggregate_incorrect_dim():
adata = pbmc3k_processed().raw.to_adata()

with pytest.raises(ValueError, match="was 'foo'"):
sc.get.aggregate(adata, ["louvain"], "sum", dim="foo")
sc.get.aggregate(adata, ["louvain"], "sum", axis="foo")


@pytest.mark.parametrize(
Expand Down

0 comments on commit 0aef147

Please sign in to comment.