Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from scanpy._settings import settings as sc_settings
from spatialdata import get_extent, get_values, join_spatialelement_table
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
from spatialdata.transformations import get_transformation, set_transformation
from spatialdata.transformations import set_transformation
from spatialdata.transformations.transformations import Identity
from xarray import DataTree

Expand All @@ -44,7 +44,6 @@
_get_colors_for_categorical_obs,
_get_extent_and_range_for_datashader_canvas,
_get_linear_colormap,
_get_transformation_matrix_for_datashader,
_hex_no_alpha,
_is_coercable_to_float,
_map_color_seg,
Expand Down Expand Up @@ -186,10 +185,9 @@ def _render_shapes(
sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())

# apply transformations to the individual points
element_trans = get_transformation(sdata_filt.shapes[element], to_coordinate_system=coordinate_system)
tm = _get_transformation_matrix_for_datashader(element_trans)
tm = trans.get_matrix()
transformed_element = sdata_filt.shapes[element].transform(
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2]
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm.T)[:, :2]
)
transformed_element = ShapesModel.parse(
gpd.GeoDataFrame(
Expand Down
38 changes: 1 addition & 37 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,8 @@
from spatialdata._core.query.relational_query import _locate_value
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement

# from spatialdata.transformations.transformations import Scale
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Translation
from spatialdata.transformations import Sequence as SDSequence
from spatialdata.transformations.operations import get_transformation
from spatialdata.transformations.transformations import Scale
from xarray import DataArray, DataTree

from spatialdata_plot._logging import logger
Expand Down Expand Up @@ -2381,39 +2378,6 @@ def _prepare_transformation(
return trans, trans_data


def _get_datashader_trans_matrix_of_single_element(
trans: Identity | Scale | Affine | MapAxis | Translation,
) -> npt.NDArray[Any]:
flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y"))

if isinstance(trans, Identity):
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
if isinstance(trans, (Scale | Affine)):
# idea: "flip the y-axis", apply transformation, flip back
flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix
return flip_and_transform
if isinstance(trans, MapAxis):
# no flipping needed
return tm
# for a Translation, we need the transposed transformation matrix
tm_T = tm.T
assert isinstance(tm_T, np.ndarray)
return tm_T


def _get_transformation_matrix_for_datashader(
trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence,
) -> npt.NDArray[Any]:
"""Get the affine matrix needed to transform shapes for rendering with datashader."""
if isinstance(trans, SDSequence):
tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
for x in trans.transformations:
tm = tm @ _get_datashader_trans_matrix_of_single_element(x)
return tm
return _get_datashader_trans_matrix_of_single_element(trans)


def _datashader_map_aggregate_to_color(
agg: DataArray,
cmap: str | list[str] | ListedColormap,
Expand Down