diff --git a/scanpy/get/_aggregated.py b/scanpy/get/_aggregated.py index 78888272b5..058fd9921f 100644 --- a/scanpy/get/_aggregated.py +++ b/scanpy/get/_aggregated.py @@ -152,7 +152,7 @@ def aggregate( by: str | list[str], func: AggType | Iterable[AggType], *, - dim: Literal["obs", "var"] | None = None, + axis: Literal[0, 1] | None = None, dof: int = 1, layer: str | None = None, obsm: str | None = None, @@ -179,7 +179,7 @@ def aggregate( Key of the column to be grouped-by. func How to aggregate. - dim + axis Axis on which to find group by column. dof Degrees of freedom for variance. Defaults to 1. @@ -221,36 +221,36 @@ def aggregate( Note that this filters out any combination of groups that wasn't present in the original data. """ - if dim not in ["obs", "var", None]: - raise ValueError(f"dim must be one of 'obs' or 'var', was '{dim}'") + if axis not in [0, 1, None]: + raise ValueError(f"axis must be one of 0 or 1, was '{axis}'") # TODO replace with get helper data = adata.X if sum(p is not None for p in [varm, obsm, layer]) > 1: raise TypeError("Please only provide one (or none) of varm, obsm, or layer") - if dim is None: + if axis is None: if varm: - dim = "var" + axis = 1 else: - dim = "obs" + axis = 0 if varm is not None: - if dim != "var": - raise ValueError("varm can only be used when dim is 'var'") + if axis != 1: + raise ValueError("varm can only be used when axis is 1") data = adata.varm[varm] elif obsm is not None: - if dim != "obs": - raise ValueError("obsm can only be used when dim is 'obs'") + if axis != 0: + raise ValueError("obsm can only be used when axis is 0") data = adata.obsm[obsm] elif layer is not None: data = adata.layers[layer] - if dim == "var": + if axis == 1: data = data.T - elif dim == "var": + elif axis == 1: # i.e., all of `varm`, `obsm`, `layers` are None so we use `X` which must be transposed data = data.T - dim_df = getattr(adata, dim) + dim_df = getattr(adata, ["obs", "var"][axis]) categorical, new_label_df = _combine_categories(dim_df, by) # Actual computation layers = aggregate( @@ -262,10 +262,10 @@ def aggregate( result = AnnData( layers=layers, obs=new_label_df, - var=getattr(adata, "var" if dim == "obs" else "obs"), + var=getattr(adata, "var" if axis == 0 else "obs"), ) - if dim == "var": + if axis == 1: return result.T else: return result diff --git a/scanpy/tests/test_aggregated.py b/scanpy/tests/test_aggregated.py index e6e8f0c5e3..9d035237ce 100644 --- a/scanpy/tests/test_aggregated.py +++ b/scanpy/tests/test_aggregated.py @@ -120,7 +120,7 @@ def test_aggregate_axis(array_type, metric): ].copy() adata.X = array_type(adata.X) expected = sc.get.aggregate(adata, ["louvain"], metric) - actual = sc.get.aggregate(adata.T, ["louvain"], metric, dim="var").T + actual = sc.get.aggregate(adata.T, ["louvain"], metric, axis=1).T assert_equal(expected, actual) @@ -164,7 +164,7 @@ def test_aggregate_incorrect_dim(): adata = pbmc3k_processed().raw.to_adata() with pytest.raises(ValueError, match="was 'foo'"): - sc.get.aggregate(adata, ["louvain"], "sum", dim="foo") + sc.get.aggregate(adata, ["louvain"], "sum", axis="foo") @pytest.mark.parametrize(