diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 6ada5397..76e57d82 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -570,6 +570,23 @@ def _render_points( coords += [col_for_color] points = points[coords].compute() + added_color_from_table = False + if col_for_color is not None and col_for_color not in points.columns: + color_values = get_values( + value_key=col_for_color, + sdata=sdata_filt, + element_name=element, + table_name=table_name, + table_layer=table_layer, + ) + points = points.merge( + color_values[[col_for_color]], + how="left", + left_index=True, + right_index=True, + ) + added_color_from_table = True + if groups is not None and col_for_color is not None: if col_for_color in points.columns: points_color_values = points[col_for_color] @@ -588,6 +605,14 @@ def _render_points( if len(points) <= 0: raise ValueError(f"None of the groups {groups} could be found in the column '{col_for_color}'.") + n_points = len(points) + points_pd_with_color = points + points_for_model = ( + points_pd_with_color.drop(columns=[col_for_color], errors="ignore") + if added_color_from_table and col_for_color is not None + else points_pd_with_color + ) + # we construct an anndata to hack the plotting functions if table_name is None: adata = AnnData( @@ -617,7 +642,7 @@ def _render_points( # Convert back to dask dataframe to modify sdata transformation_in_cs = sdata_filt.points[element].attrs["transform"][coordinate_system] - points = dask.dataframe.from_pandas(points, npartitions=1) + points = dask.dataframe.from_pandas(points_for_model, npartitions=1) sdata_filt.points[element] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"}) # restore transformation in coordinate system of interest set_transformation( @@ -658,6 +683,16 @@ def _render_points( render_type="points", ) + if added_color_from_table and col_for_color is not None: + points_with_color_dd = dask.dataframe.from_pandas(points_pd_with_color, npartitions=1) + sdata_filt.points[element] = PointsModel.parse(points_with_color_dd, coordinates={"x": "x", "y": "y"}) + set_transformation( + element=sdata_filt.points[element], + transformation=transformation_in_cs, + to_coordinate_system=coordinate_system, + ) + points = points_with_color_dd + # color_source_vector is None when the values aren't categorical if color_source_vector is None and render_params.transfunc is not None: color_vector = render_params.transfunc(color_vector) @@ -669,7 +704,7 @@ def _render_points( method = render_params.method if method is None: - method = "datashader" if len(points) > 10000 else "matplotlib" + method = "datashader" if n_points > 10000 else "matplotlib" if method != "matplotlib": # we only notify the user when we switched away from matplotlib diff --git a/tests/_images/Points_datashader_colors_from_table_obs.png b/tests/_images/Points_datashader_colors_from_table_obs.png new file mode 100644 index 00000000..35fc36fd Binary files /dev/null and b/tests/_images/Points_datashader_colors_from_table_obs.png differ diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 71a6c4e3..5e3fc38a 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -178,6 +178,32 @@ def test_plot_datashader_can_color_by_category(self, sdata_blobs: SpatialData): method="datashader", ).pl.show() + def test_plot_datashader_colors_from_table_obs(self, sdata_blobs: SpatialData): + n_obs = len(sdata_blobs["blobs_points"]) + obs = pd.DataFrame( + { + "instance_id": np.arange(n_obs), + "region": pd.Categorical(["blobs_points"] * n_obs), + "foo": pd.Categorical(np.where(np.arange(n_obs) % 2 == 0, "a", "b")), + } + ) + + table = TableModel.parse( + adata=AnnData(get_standard_RNG().normal(size=(n_obs, 3)), obs=obs), + region="blobs_points", + region_key="region", + instance_key="instance_id", + ) + sdata_blobs["datashader_table"] = table + + sdata_blobs.pl.render_points( + "blobs_points", + color="foo", + table_name="datashader_table", + method="datashader", + size=5, + ).pl.show() + def test_plot_datashader_can_use_sum_as_reduction(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_points( element="blobs_points", @@ -487,3 +513,31 @@ def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData): table_name="other_table", ).pl.show() ) + + +def test_datashader_colors_points_from_table_obs(sdata_blobs: SpatialData): + # Fast regression for https://github.com/scverse/spatialdata-plot/issues/479. + n_obs = len(sdata_blobs["blobs_points"]) + obs = pd.DataFrame( + { + "instance_id": np.arange(n_obs), + "region": pd.Categorical(["blobs_points"] * n_obs), + "foo": pd.Categorical(np.where(np.arange(n_obs) % 2 == 0, "a", "b")), + } + ) + + table = TableModel.parse( + adata=AnnData(get_standard_RNG().normal(size=(n_obs, 3)), obs=obs), + region="blobs_points", + region_key="region", + instance_key="instance_id", + ) + sdata_blobs["datashader_table"] = table + + sdata_blobs.pl.render_points( + "blobs_points", + color="foo", + table_name="datashader_table", + method="datashader", + size=5, + ).pl.show()