diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 21ed8226..23d06ce5 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -19,7 +19,7 @@ from scanpy._settings import settings as sc_settings from spatialdata import get_extent, get_values, join_spatialelement_table from spatialdata.models import PointsModel, ShapesModel, get_table_keys -from spatialdata.transformations import get_transformation, set_transformation +from spatialdata.transformations import set_transformation from spatialdata.transformations.transformations import Identity from xarray import DataTree @@ -44,7 +44,6 @@ _get_colors_for_categorical_obs, _get_extent_and_range_for_datashader_canvas, _get_linear_colormap, - _get_transformation_matrix_for_datashader, _hex_no_alpha, _is_coercable_to_float, _map_color_seg, @@ -186,10 +185,9 @@ def _render_shapes( sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy()) # apply transformations to the individual points - element_trans = get_transformation(sdata_filt.shapes[element], to_coordinate_system=coordinate_system) - tm = _get_transformation_matrix_for_datashader(element_trans) + tm = trans.get_matrix() transformed_element = sdata_filt.shapes[element].transform( - lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2] + lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm.T)[:, :2] ) transformed_element = ShapesModel.parse( gpd.GeoDataFrame( diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 574ca56b..1368110d 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -65,11 +65,8 @@ from spatialdata._core.query.relational_query import _locate_value from spatialdata._types import ArrayLike from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement - -# from spatialdata.transformations.transformations import Scale -from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Translation -from spatialdata.transformations import Sequence as SDSequence from spatialdata.transformations.operations import get_transformation +from spatialdata.transformations.transformations import Scale from xarray import DataArray, DataTree from spatialdata_plot._logging import logger @@ -2381,39 +2378,6 @@ def _prepare_transformation( return trans, trans_data -def _get_datashader_trans_matrix_of_single_element( - trans: Identity | Scale | Affine | MapAxis | Translation, -) -> npt.NDArray[Any]: - flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) - tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y")) - - if isinstance(trans, Identity): - return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - if isinstance(trans, (Scale | Affine)): - # idea: "flip the y-axis", apply transformation, flip back - flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix - return flip_and_transform - if isinstance(trans, MapAxis): - # no flipping needed - return tm - # for a Translation, we need the transposed transformation matrix - tm_T = tm.T - assert isinstance(tm_T, np.ndarray) - return tm_T - - -def _get_transformation_matrix_for_datashader( - trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence, -) -> npt.NDArray[Any]: - """Get the affine matrix needed to transform shapes for rendering with datashader.""" - if isinstance(trans, SDSequence): - tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - for x in trans.transformations: - tm = tm @ _get_datashader_trans_matrix_of_single_element(x) - return tm - return _get_datashader_trans_matrix_of_single_element(trans) - - def _datashader_map_aggregate_to_color( agg: DataArray, cmap: str | list[str] | ListedColormap,