From d102c3b0eeade328c25eea520a165d570d80f38a Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sat, 13 Sep 2025 13:11:02 +0200 Subject: [PATCH] make ready for new spatialdata --- src/spatialdata_plot/pl/utils.py | 3 ++- tests/conftest.py | 2 +- tests/pl/test_render_labels.py | 26 +++++++++++++------------- tests/pl/test_render_points.py | 14 +++++++------- tests/pl/test_render_shapes.py | 32 +++++++++++++++++--------------- tests/pl/test_utils.py | 3 ++- 6 files changed, 42 insertions(+), 38 deletions(-) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index c795bbea..6a150f7f 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -780,8 +780,9 @@ def _set_color_source_vec( color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series` + # TODO check why table_name is not passed here. color_mapping = _get_categorical_color_mapping( - adata=sdata.table, + adata=sdata["table"], cluster_key=value_to_plot, color_source_vector=color_source_vector, cmap_params=cmap_params, diff --git a/tests/conftest.py b/tests/conftest.py index 1c085297..a4fdb789 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -154,7 +154,7 @@ def test_sdata_multiple_images_diverging_dims(): def sdata_blobs_shapes_annotated() -> SpatialData: """Get blobs sdata with continuous annotation of polygons.""" blob = blobs() - blob["table"].obs["region"] = "blobs_polygons" + blob["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * blob["table"].n_obs) blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5] return blob diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index 3f50effb..69c2da42 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -33,12 +33,12 @@ def test_plot_can_render_labels(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_labels(element="blobs_labels").pl.show() def test_plot_can_render_multiscale_labels(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = "blobs_multiscale_labels" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" sdata_blobs.pl.render_labels("blobs_multiscale_labels").pl.show() def test_plot_can_render_given_scale_of_multiscale_labels(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = "blobs_multiscale_labels" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" sdata_blobs.pl.render_labels("blobs_multiscale_labels", scale="scale1").pl.show() @@ -50,7 +50,7 @@ def test_plot_can_do_rasterization(self, sdata_blobs: SpatialData): img.attrs["transform"] = sdata_blobs["blobs_labels"].transform sdata_blobs["blobs_giant_labels"] = img - sdata_blobs["table"].obs["region"] = "blobs_giant_labels" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_giant_labels"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_giant_labels" sdata_blobs.pl.render_labels("blobs_giant_labels").pl.show() @@ -63,7 +63,7 @@ def test_plot_can_stop_rasterization_with_scale_full(self, sdata_blobs: SpatialD img.attrs["transform"] = sdata_blobs["blobs_labels"].transform sdata_blobs["blobs_giant_labels"] = img - sdata_blobs["table"].obs["region"] = "blobs_giant_labels" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_giant_labels"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_giant_labels" sdata_blobs.pl.render_labels("blobs_giant_labels", scale="full").pl.show() @@ -110,7 +110,7 @@ def _make_tablemodel_with_categorical_labels(sdata_blobs, label): max_col = max_col.str.replace("channel_", "ch").str.replace("_sum", "") max_col = pd.Categorical(max_col, categories=set(max_col), ordered=True) adata.obs["which_max"] = max_col - adata.obs["region"] = label + adata.obs["region"] = pd.Categorical([label] * adata.n_obs) del adata.uns["spatialdata_attrs"] table = TableModel.parse( adata=adata, @@ -142,7 +142,7 @@ def test_plot_two_calls_with_coloring_result_in_two_colorbars(self, sdata_blobs: sdata_blobs_local = deepcopy(sdata_blobs) table = sdata_blobs_local["table"].copy() - table.obs["region"] = "blobs_multiscale_labels" + table.obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * table.n_obs) table.uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" table = table[:, ~table.var_names.isin(["channel_0_sum"])] sdata_blobs_local["multi_table"] = table @@ -187,7 +187,7 @@ def test_plot_label_colorbar_uses_alpha_of_less_transparent_outline( def test_can_plot_with_one_element_color_table(self, sdata_blobs: SpatialData): table = sdata_blobs["table"].copy() - table.obs["region"] = "blobs_multiscale_labels" + table.obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * table.n_obs) table.uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" table = table[:, ~table.var_names.isin(["channel_0_sum"])] sdata_blobs["multi_table"] = table @@ -196,9 +196,9 @@ def test_can_plot_with_one_element_color_table(self, sdata_blobs: SpatialData): ).pl.show() def test_plot_subset_categorical_label_maintains_order(self, sdata_blobs: SpatialData): - max_col = sdata_blobs.table.to_df().idxmax(axis=1) - max_col = pd.Categorical(max_col, categories=sdata_blobs.table.to_df().columns, ordered=True) - sdata_blobs.table.obs["which_max"] = max_col + max_col = sdata_blobs["table"].to_df().idxmax(axis=1) + max_col = pd.Categorical(max_col, categories=sdata_blobs["table"].to_df().columns, ordered=True) + sdata_blobs["table"].obs["which_max"] = max_col _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") @@ -210,9 +210,9 @@ def test_plot_subset_categorical_label_maintains_order(self, sdata_blobs: Spatia ).pl.show(ax=axs[1]) def test_plot_subset_categorical_label_maintains_order_when_palette_overwrite(self, sdata_blobs: SpatialData): - max_col = sdata_blobs.table.to_df().idxmax(axis=1) - max_col = pd.Categorical(max_col, categories=sdata_blobs.table.to_df().columns, ordered=True) - sdata_blobs.table.obs["which_max"] = max_col + max_col = sdata_blobs["table"].to_df().idxmax(axis=1) + max_col = pd.Categorical(max_col, categories=sdata_blobs["table"].to_df().columns, ordered=True) + sdata_blobs["table"].obs["which_max"] = max_col _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index ad340f70..aa9d8ff7 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -37,7 +37,7 @@ def test_plot_can_render_points(self, sdata_blobs: SpatialData): def test_plot_can_filter_with_groups_default_palette(self, sdata_blobs: SpatialData): _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") - sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points" sdata_blobs.pl.render_points(color="genes", size=10).pl.show(ax=axs[0], legend_fontsize=6) @@ -46,7 +46,7 @@ def test_plot_can_filter_with_groups_default_palette(self, sdata_blobs: SpatialD def test_plot_can_filter_with_groups_custom_palette(self, sdata_blobs: SpatialData): _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") - sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points" sdata_blobs.pl.render_points(color="genes", size=10).pl.show(ax=axs[0], legend_fontsize=6) @@ -55,19 +55,19 @@ def test_plot_can_filter_with_groups_custom_palette(self, sdata_blobs: SpatialDa ) def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points" sdata_blobs.pl.render_points( color="genes", groups=["gene_a", "gene_b"], palette=["lightgreen", "darkblue"] ).pl.show() def test_plot_coloring_with_cmap(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points" sdata_blobs.pl.render_points(color="genes", cmap="rainbow").pl.show() def test_plot_can_stack_render_points(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points" ( sdata_blobs.pl.render_points(element="blobs_points", na_color="red", size=30) @@ -86,7 +86,7 @@ def test_plot_points_coercable_categorical_color(self, sdata_blobs: SpatialData) adata.obs["instance_id"] = np.arange(adata.n_obs) adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs) adata.obs["instance_id"] = list(range(adata.n_obs)) - adata.obs["region"] = "blobs_points" + adata.obs["region"] = pd.Categorical(["blobs_points"] * adata.n_obs) table = TableModel.parse(adata=adata, region_key="region", instance_key="instance_id", region="blobs_points") sdata_blobs["other_table"] = table @@ -100,7 +100,7 @@ def test_plot_points_categorical_color(self, sdata_blobs: SpatialData): adata.obs["instance_id"] = np.arange(adata.n_obs) adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs) adata.obs["instance_id"] = list(range(adata.n_obs)) - adata.obs["region"] = "blobs_points" + adata.obs["region"] = pd.Categorical(["blobs_points"] * adata.n_obs) table = TableModel.parse(adata=adata, region_key="region", instance_key="instance_id", region="blobs_points") sdata_blobs["other_table"] = table diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 953eb843..67bc70f9 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -101,7 +101,7 @@ def _make_multi(): def test_plot_can_color_from_geodataframe(self, sdata_blobs: SpatialData): blob = deepcopy(sdata_blobs) - blob["table"].obs["region"] = "blobs_polygons" + blob["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * blob["table"].n_obs) blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" blob.shapes["blobs_polygons"]["value"] = [1, 10, 1, 20, 1] blob.pl.render_shapes( @@ -115,7 +115,7 @@ def test_plot_can_scale_shapes(self, sdata_blobs: SpatialData): def test_plot_can_filter_with_groups(self, sdata_blobs: SpatialData): _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") - sdata_blobs["table"].obs["region"] = "blobs_polygons" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" sdata_blobs.shapes["blobs_polygons"]["cluster"] = "c1" sdata_blobs.shapes["blobs_polygons"].iloc[3:5, 1] = "c2" @@ -129,7 +129,7 @@ def test_plot_can_filter_with_groups(self, sdata_blobs: SpatialData): ) def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = "blobs_polygons" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" sdata_blobs.shapes["blobs_polygons"]["cluster"] = "c1" sdata_blobs.shapes["blobs_polygons"].iloc[3:5, 1] = "c2" @@ -142,13 +142,13 @@ def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData): ).pl.show() def test_plot_colorbar_respects_input_limits(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = "blobs_polygons" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" sdata_blobs.shapes["blobs_polygons"]["cluster"] = [1, 2, 3, 5, 20] sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster").pl.show() def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = "blobs_polygons" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" sdata_blobs.shapes["blobs_polygons"]["cluster"] = [1, 2, 3, 5, 20] norm = Normalize(vmin=0, vmax=5, clip=True) @@ -156,7 +156,7 @@ def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData): def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData): # subset to only shapes, should be unnecessary after rasterizeation of multiscale images is included - blob = SpatialData.from_elements_dict( + blob = SpatialData.init_from_elements( { "blobs_circles": sdata_blobs.shapes["blobs_circles"], "blobs_multipolygons": sdata_blobs.shapes["blobs_multipolygons"], @@ -169,7 +169,7 @@ def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData cropped_blob.pl.render_shapes().pl.show() def test_plot_can_plot_with_annotation_despite_random_shuffling(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = "blobs_circles" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) new_table = sdata_blobs["table"][:5] new_table.uns["spatialdata_attrs"]["region"] = "blobs_circles" new_table.obs["instance_id"] = np.array(range(5)) @@ -188,7 +188,7 @@ def test_plot_can_plot_with_annotation_despite_random_shuffling(self, sdata_blob sdata_blobs.pl.render_shapes("blobs_circles", color="annotation").pl.show() def test_plot_can_plot_queried_with_annotation_despite_random_shuffling(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = "blobs_circles" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) new_table = sdata_blobs["table"][:5].copy() new_table.uns["spatialdata_attrs"]["region"] = "blobs_circles" new_table.obs["instance_id"] = np.array(range(5)) @@ -216,11 +216,12 @@ def test_plot_can_plot_queried_with_annotation_despite_random_shuffling(self, sd sdata_cropped.pl.render_shapes("blobs_circles", color="annotation").pl.show() def test_plot_can_color_two_shapes_elements_by_annotation(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = "blobs_circles" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) new_table = sdata_blobs["table"][:10].copy() new_table.uns["spatialdata_attrs"]["region"] = ["blobs_circles", "blobs_polygons"] new_table.obs["instance_id"] = np.concatenate((np.array(range(5)), np.array(range(5)))) + new_table.obs["region"] = new_table.obs["region"].cat.add_categories(["blobs_polygons"]) new_table.obs.loc[5 * [False] + 5 * [True], "region"] = "blobs_polygons" new_table.obs["annotation"] = ["a", "b", "c", "d", "e", "v", "w", "x", "y", "z"] new_table.obs["annotation"] = new_table.obs["annotation"].astype("category") @@ -232,11 +233,12 @@ def test_plot_can_color_two_shapes_elements_by_annotation(self, sdata_blobs: Spa ).pl.show() def test_plot_can_color_two_queried_shapes_elements_by_annotation(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = "blobs_circles" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) new_table = sdata_blobs["table"][:10].copy() new_table.uns["spatialdata_attrs"]["region"] = ["blobs_circles", "blobs_polygons"] new_table.obs["instance_id"] = np.concatenate((np.array(range(5)), np.array(range(5)))) + new_table.obs["region"] = new_table.obs["region"].cat.add_categories(["blobs_polygons"]) new_table.obs.loc[5 * [False] + 5 * [True], "region"] = "blobs_polygons" new_table.obs["annotation"] = ["a", "b", "c", "d", "e", "v", "w", "x", "y", "z"] new_table.obs["annotation"] = new_table.obs["annotation"].astype("category") @@ -364,7 +366,7 @@ def test_plot_can_color_by_category_with_cmap(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_polygons", color="category", cmap="cool").pl.show() def test_plot_datashader_can_color_by_value(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = "blobs_polygons" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" sdata_blobs.shapes["blobs_polygons"]["value"] = [1, 10, 1, 20, 1] sdata_blobs.pl.render_shapes(element="blobs_polygons", color="value", method="datashader").pl.show() @@ -374,13 +376,13 @@ def test_plot_datashader_can_color_by_identical_value(self, sdata_blobs: Spatial We test this, because datashader internally scales the values, so when all shapes have the same value, the scaling would lead to all of them being assigned an alpha of 0, so we wouldn't see anything """ - sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" sdata_blobs.shapes["blobs_polygons"]["value"] = [1, 1, 1, 1, 1] sdata_blobs.pl.render_shapes(element="blobs_polygons", color="value", method="datashader").pl.show() def test_plot_datashader_shades_with_linear_cmap(self, sdata_blobs: SpatialData): - sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" sdata_blobs.shapes["blobs_polygons"]["value"] = [1, 2, 1, 20, 1] sdata_blobs.pl.render_shapes(element="blobs_polygons", color="value", method="datashader").pl.show() @@ -414,7 +416,7 @@ def test_plot_datashader_can_render_with_rgba_colored_outline(self, sdata_blobs: def test_plot_can_set_clims_clip(self, sdata_blobs: SpatialData): table_shapes = sdata_blobs["table"][:5].copy() table_shapes.obs.instance_id = list(range(5)) - table_shapes.obs["region"] = "blobs_circles" + table_shapes.obs["region"] = pd.Categorical(["blobs_circles"] * table_shapes.n_obs) table_shapes.obs["dummy_gene_expression"] = [i * 10 for i in range(5)] table_shapes.uns["spatialdata_attrs"]["region"] = "blobs_circles" sdata_blobs["new_table"] = table_shapes @@ -493,7 +495,7 @@ def test_plot_datashader_can_transform_circles(self, sdata_blobs: SpatialData): def test_plot_can_do_non_matching_table(self, sdata_blobs: SpatialData): table_shapes = sdata_blobs["table"][:3].copy() table_shapes.obs.instance_id = list(range(3)) - table_shapes.obs["region"] = "blobs_circles" + table_shapes.obs["region"] = pd.Categorical(["blobs_circles"] * table_shapes.n_obs) table_shapes.uns["spatialdata_attrs"]["region"] = "blobs_circles" sdata_blobs["new_table"] = table_shapes diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index d8ecb000..75f4cf29 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -1,6 +1,7 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import pandas as pd import pytest import scanpy as sc from spatialdata import SpatialData @@ -47,7 +48,7 @@ def test_plot_set_outline_accepts_str_or_float_or_list_thereof(self, sdata_blobs def test_plot_colnames_that_are_valid_matplotlib_greyscale_colors_are_not_evaluated_as_colors( self, sdata_blobs: SpatialData, colname: str ): - sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" sdata_blobs.shapes["blobs_polygons"][colname] = [1, 2, 3, 5, 20] sdata_blobs.pl.render_shapes("blobs_polygons", color=colname).pl.show()