Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,8 +780,9 @@ def _set_color_source_vec(

color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series`

# TODO check why table_name is not passed here.
color_mapping = _get_categorical_color_mapping(
adata=sdata.table,
adata=sdata["table"],
cluster_key=value_to_plot,
color_source_vector=color_source_vector,
cmap_params=cmap_params,
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_sdata_multiple_images_diverging_dims():
def sdata_blobs_shapes_annotated() -> SpatialData:
"""Get blobs sdata with continuous annotation of polygons."""
blob = blobs()
blob["table"].obs["region"] = "blobs_polygons"
blob["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * blob["table"].n_obs)
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
return blob
Expand Down
26 changes: 13 additions & 13 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ def test_plot_can_render_labels(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_labels(element="blobs_labels").pl.show()

def test_plot_can_render_multiscale_labels(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_multiscale_labels"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
sdata_blobs.pl.render_labels("blobs_multiscale_labels").pl.show()

def test_plot_can_render_given_scale_of_multiscale_labels(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_multiscale_labels"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
sdata_blobs.pl.render_labels("blobs_multiscale_labels", scale="scale1").pl.show()

Expand All @@ -50,7 +50,7 @@ def test_plot_can_do_rasterization(self, sdata_blobs: SpatialData):
img.attrs["transform"] = sdata_blobs["blobs_labels"].transform
sdata_blobs["blobs_giant_labels"] = img

sdata_blobs["table"].obs["region"] = "blobs_giant_labels"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_giant_labels"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_giant_labels"

sdata_blobs.pl.render_labels("blobs_giant_labels").pl.show()
Expand All @@ -63,7 +63,7 @@ def test_plot_can_stop_rasterization_with_scale_full(self, sdata_blobs: SpatialD
img.attrs["transform"] = sdata_blobs["blobs_labels"].transform
sdata_blobs["blobs_giant_labels"] = img

sdata_blobs["table"].obs["region"] = "blobs_giant_labels"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_giant_labels"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_giant_labels"

sdata_blobs.pl.render_labels("blobs_giant_labels", scale="full").pl.show()
Expand Down Expand Up @@ -110,7 +110,7 @@ def _make_tablemodel_with_categorical_labels(sdata_blobs, label):
max_col = max_col.str.replace("channel_", "ch").str.replace("_sum", "")
max_col = pd.Categorical(max_col, categories=set(max_col), ordered=True)
adata.obs["which_max"] = max_col
adata.obs["region"] = label
adata.obs["region"] = pd.Categorical([label] * adata.n_obs)
del adata.uns["spatialdata_attrs"]
table = TableModel.parse(
adata=adata,
Expand Down Expand Up @@ -142,7 +142,7 @@ def test_plot_two_calls_with_coloring_result_in_two_colorbars(self, sdata_blobs:
sdata_blobs_local = deepcopy(sdata_blobs)

table = sdata_blobs_local["table"].copy()
table.obs["region"] = "blobs_multiscale_labels"
table.obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * table.n_obs)
table.uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
table = table[:, ~table.var_names.isin(["channel_0_sum"])]
sdata_blobs_local["multi_table"] = table
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_plot_label_colorbar_uses_alpha_of_less_transparent_outline(

def test_can_plot_with_one_element_color_table(self, sdata_blobs: SpatialData):
table = sdata_blobs["table"].copy()
table.obs["region"] = "blobs_multiscale_labels"
table.obs["region"] = pd.Categorical(["blobs_multiscale_labels"] * table.n_obs)
table.uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
table = table[:, ~table.var_names.isin(["channel_0_sum"])]
sdata_blobs["multi_table"] = table
Expand All @@ -196,9 +196,9 @@ def test_can_plot_with_one_element_color_table(self, sdata_blobs: SpatialData):
).pl.show()

def test_plot_subset_categorical_label_maintains_order(self, sdata_blobs: SpatialData):
max_col = sdata_blobs.table.to_df().idxmax(axis=1)
max_col = pd.Categorical(max_col, categories=sdata_blobs.table.to_df().columns, ordered=True)
sdata_blobs.table.obs["which_max"] = max_col
max_col = sdata_blobs["table"].to_df().idxmax(axis=1)
max_col = pd.Categorical(max_col, categories=sdata_blobs["table"].to_df().columns, ordered=True)
sdata_blobs["table"].obs["which_max"] = max_col

_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

Expand All @@ -210,9 +210,9 @@ def test_plot_subset_categorical_label_maintains_order(self, sdata_blobs: Spatia
).pl.show(ax=axs[1])

def test_plot_subset_categorical_label_maintains_order_when_palette_overwrite(self, sdata_blobs: SpatialData):
max_col = sdata_blobs.table.to_df().idxmax(axis=1)
max_col = pd.Categorical(max_col, categories=sdata_blobs.table.to_df().columns, ordered=True)
sdata_blobs.table.obs["which_max"] = max_col
max_col = sdata_blobs["table"].to_df().idxmax(axis=1)
max_col = pd.Categorical(max_col, categories=sdata_blobs["table"].to_df().columns, ordered=True)
sdata_blobs["table"].obs["which_max"] = max_col

_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

Expand Down
14 changes: 7 additions & 7 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_plot_can_render_points(self, sdata_blobs: SpatialData):
def test_plot_can_filter_with_groups_default_palette(self, sdata_blobs: SpatialData):
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points"

sdata_blobs.pl.render_points(color="genes", size=10).pl.show(ax=axs[0], legend_fontsize=6)
Expand All @@ -46,7 +46,7 @@ def test_plot_can_filter_with_groups_default_palette(self, sdata_blobs: SpatialD
def test_plot_can_filter_with_groups_custom_palette(self, sdata_blobs: SpatialData):
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points"

sdata_blobs.pl.render_points(color="genes", size=10).pl.show(ax=axs[0], legend_fontsize=6)
Expand All @@ -55,19 +55,19 @@ def test_plot_can_filter_with_groups_custom_palette(self, sdata_blobs: SpatialDa
)

def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points"
sdata_blobs.pl.render_points(
color="genes", groups=["gene_a", "gene_b"], palette=["lightgreen", "darkblue"]
).pl.show()

def test_plot_coloring_with_cmap(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points"
sdata_blobs.pl.render_points(color="genes", cmap="rainbow").pl.show()

def test_plot_can_stack_render_points(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points"
(
sdata_blobs.pl.render_points(element="blobs_points", na_color="red", size=30)
Expand All @@ -86,7 +86,7 @@ def test_plot_points_coercable_categorical_color(self, sdata_blobs: SpatialData)
adata.obs["instance_id"] = np.arange(adata.n_obs)
adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs)
adata.obs["instance_id"] = list(range(adata.n_obs))
adata.obs["region"] = "blobs_points"
adata.obs["region"] = pd.Categorical(["blobs_points"] * adata.n_obs)
table = TableModel.parse(adata=adata, region_key="region", instance_key="instance_id", region="blobs_points")
sdata_blobs["other_table"] = table

Expand All @@ -100,7 +100,7 @@ def test_plot_points_categorical_color(self, sdata_blobs: SpatialData):
adata.obs["instance_id"] = np.arange(adata.n_obs)
adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs)
adata.obs["instance_id"] = list(range(adata.n_obs))
adata.obs["region"] = "blobs_points"
adata.obs["region"] = pd.Categorical(["blobs_points"] * adata.n_obs)
table = TableModel.parse(adata=adata, region_key="region", instance_key="instance_id", region="blobs_points")
sdata_blobs["other_table"] = table

Expand Down
32 changes: 17 additions & 15 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _make_multi():

def test_plot_can_color_from_geodataframe(self, sdata_blobs: SpatialData):
blob = deepcopy(sdata_blobs)
blob["table"].obs["region"] = "blobs_polygons"
blob["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * blob["table"].n_obs)
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
blob.shapes["blobs_polygons"]["value"] = [1, 10, 1, 20, 1]
blob.pl.render_shapes(
Expand All @@ -115,7 +115,7 @@ def test_plot_can_scale_shapes(self, sdata_blobs: SpatialData):
def test_plot_can_filter_with_groups(self, sdata_blobs: SpatialData):
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

sdata_blobs["table"].obs["region"] = "blobs_polygons"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["cluster"] = "c1"
sdata_blobs.shapes["blobs_polygons"].iloc[3:5, 1] = "c2"
Expand All @@ -129,7 +129,7 @@ def test_plot_can_filter_with_groups(self, sdata_blobs: SpatialData):
)

def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_polygons"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["cluster"] = "c1"
sdata_blobs.shapes["blobs_polygons"].iloc[3:5, 1] = "c2"
Expand All @@ -142,21 +142,21 @@ def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData):
).pl.show()

def test_plot_colorbar_respects_input_limits(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_polygons"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["cluster"] = [1, 2, 3, 5, 20]
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster").pl.show()

def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_polygons"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["cluster"] = [1, 2, 3, 5, 20]
norm = Normalize(vmin=0, vmax=5, clip=True)
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"], norm=norm).pl.show()

def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData):
# subset to only shapes, should be unnecessary after rasterizeation of multiscale images is included
blob = SpatialData.from_elements_dict(
blob = SpatialData.init_from_elements(
{
"blobs_circles": sdata_blobs.shapes["blobs_circles"],
"blobs_multipolygons": sdata_blobs.shapes["blobs_multipolygons"],
Expand All @@ -169,7 +169,7 @@ def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData
cropped_blob.pl.render_shapes().pl.show()

def test_plot_can_plot_with_annotation_despite_random_shuffling(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_circles"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
new_table = sdata_blobs["table"][:5]
new_table.uns["spatialdata_attrs"]["region"] = "blobs_circles"
new_table.obs["instance_id"] = np.array(range(5))
Expand All @@ -188,7 +188,7 @@ def test_plot_can_plot_with_annotation_despite_random_shuffling(self, sdata_blob
sdata_blobs.pl.render_shapes("blobs_circles", color="annotation").pl.show()

def test_plot_can_plot_queried_with_annotation_despite_random_shuffling(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_circles"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
new_table = sdata_blobs["table"][:5].copy()
new_table.uns["spatialdata_attrs"]["region"] = "blobs_circles"
new_table.obs["instance_id"] = np.array(range(5))
Expand Down Expand Up @@ -216,11 +216,12 @@ def test_plot_can_plot_queried_with_annotation_despite_random_shuffling(self, sd
sdata_cropped.pl.render_shapes("blobs_circles", color="annotation").pl.show()

def test_plot_can_color_two_shapes_elements_by_annotation(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_circles"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
new_table = sdata_blobs["table"][:10].copy()
new_table.uns["spatialdata_attrs"]["region"] = ["blobs_circles", "blobs_polygons"]
new_table.obs["instance_id"] = np.concatenate((np.array(range(5)), np.array(range(5))))

new_table.obs["region"] = new_table.obs["region"].cat.add_categories(["blobs_polygons"])
new_table.obs.loc[5 * [False] + 5 * [True], "region"] = "blobs_polygons"
new_table.obs["annotation"] = ["a", "b", "c", "d", "e", "v", "w", "x", "y", "z"]
new_table.obs["annotation"] = new_table.obs["annotation"].astype("category")
Expand All @@ -232,11 +233,12 @@ def test_plot_can_color_two_shapes_elements_by_annotation(self, sdata_blobs: Spa
).pl.show()

def test_plot_can_color_two_queried_shapes_elements_by_annotation(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_circles"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
new_table = sdata_blobs["table"][:10].copy()
new_table.uns["spatialdata_attrs"]["region"] = ["blobs_circles", "blobs_polygons"]
new_table.obs["instance_id"] = np.concatenate((np.array(range(5)), np.array(range(5))))

new_table.obs["region"] = new_table.obs["region"].cat.add_categories(["blobs_polygons"])
new_table.obs.loc[5 * [False] + 5 * [True], "region"] = "blobs_polygons"
new_table.obs["annotation"] = ["a", "b", "c", "d", "e", "v", "w", "x", "y", "z"]
new_table.obs["annotation"] = new_table.obs["annotation"].astype("category")
Expand Down Expand Up @@ -364,7 +366,7 @@ def test_plot_can_color_by_category_with_cmap(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="category", cmap="cool").pl.show()

def test_plot_datashader_can_color_by_value(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_polygons"
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["value"] = [1, 10, 1, 20, 1]
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="value", method="datashader").pl.show()
Expand All @@ -374,13 +376,13 @@ def test_plot_datashader_can_color_by_identical_value(self, sdata_blobs: Spatial
We test this, because datashader internally scales the values, so when all shapes have the same value,
the scaling would lead to all of them being assigned an alpha of 0, so we wouldn't see anything
"""
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["value"] = [1, 1, 1, 1, 1]
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="value", method="datashader").pl.show()

def test_plot_datashader_shades_with_linear_cmap(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["value"] = [1, 2, 1, 20, 1]
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="value", method="datashader").pl.show()
Expand Down Expand Up @@ -414,7 +416,7 @@ def test_plot_datashader_can_render_with_rgba_colored_outline(self, sdata_blobs:
def test_plot_can_set_clims_clip(self, sdata_blobs: SpatialData):
table_shapes = sdata_blobs["table"][:5].copy()
table_shapes.obs.instance_id = list(range(5))
table_shapes.obs["region"] = "blobs_circles"
table_shapes.obs["region"] = pd.Categorical(["blobs_circles"] * table_shapes.n_obs)
table_shapes.obs["dummy_gene_expression"] = [i * 10 for i in range(5)]
table_shapes.uns["spatialdata_attrs"]["region"] = "blobs_circles"
sdata_blobs["new_table"] = table_shapes
Expand Down Expand Up @@ -493,7 +495,7 @@ def test_plot_datashader_can_transform_circles(self, sdata_blobs: SpatialData):
def test_plot_can_do_non_matching_table(self, sdata_blobs: SpatialData):
table_shapes = sdata_blobs["table"][:3].copy()
table_shapes.obs.instance_id = list(range(3))
table_shapes.obs["region"] = "blobs_circles"
table_shapes.obs["region"] = pd.Categorical(["blobs_circles"] * table_shapes.n_obs)
table_shapes.uns["spatialdata_attrs"]["region"] = "blobs_circles"
sdata_blobs["new_table"] = table_shapes

Expand Down
Loading
Loading