From c4940e6039a293ffdeb1604a107119259716e806 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Tue, 23 Sep 2025 23:03:20 +0200 Subject: [PATCH 1/5] moved changes over --- src/spatialdata_plot/pl/render.py | 2 +- src/spatialdata_plot/pl/utils.py | 100 +----------------------------- 2 files changed, 2 insertions(+), 100 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 21ed8226..abf0b704 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -189,7 +189,7 @@ def _render_shapes( element_trans = get_transformation(sdata_filt.shapes[element], to_coordinate_system=coordinate_system) tm = _get_transformation_matrix_for_datashader(element_trans) transformed_element = sdata_filt.shapes[element].transform( - lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2] + lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm.T)[:, :2] ) transformed_element = ShapesModel.parse( gpd.GeoDataFrame( diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 574ca56b..2a79ee66 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -66,10 +66,8 @@ from spatialdata._types import ArrayLike from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement -# from spatialdata.transformations.transformations import Scale -from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Translation -from spatialdata.transformations import Sequence as SDSequence from spatialdata.transformations.operations import get_transformation +from spatialdata.transformations.transformations import Scale from xarray import DataArray, DataTree from spatialdata_plot._logging import logger @@ -2381,102 +2379,6 @@ def _prepare_transformation( return trans, trans_data -def _get_datashader_trans_matrix_of_single_element( - trans: Identity | Scale | Affine | MapAxis | Translation, -) -> npt.NDArray[Any]: - flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) - tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y")) - - if isinstance(trans, Identity): - return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - if isinstance(trans, (Scale | Affine)): - # idea: "flip the y-axis", apply transformation, flip back - flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix - return flip_and_transform - if isinstance(trans, MapAxis): - # no flipping needed - return tm - # for a Translation, we need the transposed transformation matrix - tm_T = tm.T - assert isinstance(tm_T, np.ndarray) - return tm_T - - -def _get_transformation_matrix_for_datashader( - trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence, -) -> npt.NDArray[Any]: - """Get the affine matrix needed to transform shapes for rendering with datashader.""" - if isinstance(trans, SDSequence): - tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - for x in trans.transformations: - tm = tm @ _get_datashader_trans_matrix_of_single_element(x) - return tm - return _get_datashader_trans_matrix_of_single_element(trans) - - -def _datashader_map_aggregate_to_color( - agg: DataArray, - cmap: str | list[str] | ListedColormap, - color_key: None | list[str] = None, - min_alpha: float = 40, - span: None | list[float] = None, - clip: bool = True, -) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]: - """ds.tf.shade() part, ensuring correct clipping behavior. - - If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results. - This ensures the correct clipping behavior, because else datashader would always automatically clip. - """ - if not clip and isinstance(cmap, Colormap) and span is not None: - # in case we use datashader together with a Normalize object where clip=False - # why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372 - agg_in = agg.where((agg >= span[0]) & (agg <= span[1])) - img_in = ds.tf.shade( - agg_in, - cmap=cmap, - span=(span[0], span[1]), - how="linear", - color_key=color_key, - min_alpha=min_alpha, - ) - - agg_under = agg.where(agg < span[0]) - img_under = ds.tf.shade( - agg_under, - cmap=[to_hex(cmap.get_under())[:7]], - min_alpha=min_alpha, - color_key=color_key, - ) - - agg_over = agg.where(agg > span[1]) - img_over = ds.tf.shade( - agg_over, - cmap=[to_hex(cmap.get_over())[:7]], - min_alpha=min_alpha, - color_key=color_key, - ) - - # stack the 3 arrays manually: go from under, through in to over and always overlay the values where alpha=0 - stack = img_under.to_numpy().base - if stack is None: - stack = img_in.to_numpy().base - else: - stack[stack[:, :, 3] == 0] = img_in.to_numpy().base[stack[:, :, 3] == 0] - img_over = img_over.to_numpy().base - if img_over is not None: - stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0] - return stack - - return ds.tf.shade( - agg, - cmap=cmap, - color_key=color_key, - min_alpha=min_alpha, - span=span, - how="linear", - ) - - def _hex_no_alpha(hex: str) -> str: """ Return a hex color string without an alpha component. From ef16e29b0326fab494dce7ea13d31a95fadaba09 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Sep 2025 21:04:43 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spatialdata_plot/pl/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 2a79ee66..b0227003 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -65,7 +65,6 @@ from spatialdata._core.query.relational_query import _locate_value from spatialdata._types import ArrayLike from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement - from spatialdata.transformations.operations import get_transformation from spatialdata.transformations.transformations import Scale from xarray import DataArray, DataTree From c58c465cc17665f1852f2f6ef3f80b9d85e64d2a Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Tue, 23 Sep 2025 23:13:31 +0200 Subject: [PATCH 3/5] added second missing function --- src/spatialdata_plot/pl/render.py | 3 +- src/spatialdata_plot/pl/utils.py | 62 +++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index abf0b704..f9ac5df7 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -186,8 +186,7 @@ def _render_shapes( sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy()) # apply transformations to the individual points - element_trans = get_transformation(sdata_filt.shapes[element], to_coordinate_system=coordinate_system) - tm = _get_transformation_matrix_for_datashader(element_trans) + tm = trans.get_matrix() transformed_element = sdata_filt.shapes[element].transform( lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm.T)[:, :2] ) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index b0227003..5c97055c 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2378,6 +2378,68 @@ def _prepare_transformation( return trans, trans_data +def _datashader_map_aggregate_to_color( + agg: DataArray, + cmap: str | list[str] | ListedColormap, + color_key: None | list[str] = None, + min_alpha: float = 40, + span: None | list[float] = None, + clip: bool = True, +) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]: + """ds.tf.shade() part, ensuring correct clipping behavior. + + If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results. + This ensures the correct clipping behavior, because else datashader would always automatically clip. + """ + if not clip and isinstance(cmap, Colormap) and span is not None: + # in case we use datashader together with a Normalize object where clip=False + # why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372 + agg_in = agg.where((agg >= span[0]) & (agg <= span[1])) + img_in = ds.tf.shade( + agg_in, + cmap=cmap, + span=(span[0], span[1]), + how="linear", + color_key=color_key, + min_alpha=min_alpha, + ) + + agg_under = agg.where(agg < span[0]) + img_under = ds.tf.shade( + agg_under, + cmap=[to_hex(cmap.get_under())[:7]], + min_alpha=min_alpha, + color_key=color_key, + ) + + agg_over = agg.where(agg > span[1]) + img_over = ds.tf.shade( + agg_over, + cmap=[to_hex(cmap.get_over())[:7]], + min_alpha=min_alpha, + color_key=color_key, + ) + + # stack the 3 arrays manually: go from under, through in to over and always overlay the values where alpha=0 + stack = img_under.to_numpy().base + if stack is None: + stack = img_in.to_numpy().base + else: + stack[stack[:, :, 3] == 0] = img_in.to_numpy().base[stack[:, :, 3] == 0] + img_over = img_over.to_numpy().base + if img_over is not None: + stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0] + return stack + + return ds.tf.shade( + agg, + cmap=cmap, + color_key=color_key, + min_alpha=min_alpha, + span=span, + how="linear", + ) + def _hex_no_alpha(hex: str) -> str: """ Return a hex color string without an alpha component. From e053cb9ed47d5b7b6d9fd8cec2f5023418527316 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Sep 2025 21:14:27 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spatialdata_plot/pl/render.py | 3 +-- src/spatialdata_plot/pl/utils.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index f9ac5df7..23d06ce5 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -19,7 +19,7 @@ from scanpy._settings import settings as sc_settings from spatialdata import get_extent, get_values, join_spatialelement_table from spatialdata.models import PointsModel, ShapesModel, get_table_keys -from spatialdata.transformations import get_transformation, set_transformation +from spatialdata.transformations import set_transformation from spatialdata.transformations.transformations import Identity from xarray import DataTree @@ -44,7 +44,6 @@ _get_colors_for_categorical_obs, _get_extent_and_range_for_datashader_canvas, _get_linear_colormap, - _get_transformation_matrix_for_datashader, _hex_no_alpha, _is_coercable_to_float, _map_color_seg, diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 5c97055c..1368110d 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2440,6 +2440,7 @@ def _datashader_map_aggregate_to_color( how="linear", ) + def _hex_no_alpha(hex: str) -> str: """ Return a hex color string without an alpha component. From 65f35e1fc176161120e7e3fe56b7d1122a6a3ef4 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Tue, 23 Sep 2025 23:14:55 +0200 Subject: [PATCH 5/5] removed unused func --- src/spatialdata_plot/pl/render.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index f9ac5df7..5bc81e4f 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -44,7 +44,6 @@ _get_colors_for_categorical_obs, _get_extent_and_range_for_datashader_canvas, _get_linear_colormap, - _get_transformation_matrix_for_datashader, _hex_no_alpha, _is_coercable_to_float, _map_color_seg,