diff --git a/sandbox_harald.ipynb b/sandbox_harald.ipynb index 044dc7cc..4e3a2189 100644 --- a/sandbox_harald.ipynb +++ b/sandbox_harald.ipynb @@ -2,10 +2,27 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 36, "id": "125833fd", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/tim.treis/Documents/GitHub/spatialdata-plot/src/spatialdata_plot/pl/basic.py:11: AccessorRegistrationWarning: registration of accessor under name 'pl' for type is overriding a preexisting attribute with the same name.\n", + " \n" + ] + } + ], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -13,33 +30,31 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 37, "id": "633f35c7", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/tim.treis/anaconda3/envs/spatialdata/lib/python3.10/site-packages/geopandas/_compat.py:123: UserWarning: The Shapely GEOS version (3.11.1-CAPI-1.17.1) is incompatible with the GEOS version PyGEOS was compiled with (3.10.4-CAPI-1.16.2). Conversions between both will be slow.\n", - " warnings.warn(\n", - "/Users/tim.treis/Documents/GitHub/spatialdata/spatialdata/_compat.py:18: UserWarning: Geopandas was set to use PyGEOS, changing to shapely 2.0 with:\n", - "\n", - "\tgeopandas.options.use_pygeos = True\n", - "\n", - "If you intended to use PyGEOS, set the option to False.\n", - " warnings.warn(\n", - "/Users/tim.treis/anaconda3/envs/spatialdata/lib/python3.10/site-packages/anndata/experimental/pytorch/_annloader.py:18: UserWarning: Сould not load pytorch.\n", - " warnings.warn(\"Сould not load pytorch.\")\n" - ] - } - ], + "outputs": [], "source": [ "import spatialdata as sd\n", "import spatialdata_plot\n", + "import numpy as np\n", + "import pyarrow as pa\n", + "from anndata import AnnData\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", - "from typing import Union" + "from typing import Union\n", + "from numpy.random import default_rng\n", + "\n", + "RNG = default_rng()\n", + "from spatialdata._core.models import (\n", + " Image2DModel,\n", + " Labels2DModel,\n", + " Labels3DModel,\n", + " PointsModel,\n", + " PolygonsModel,\n", + " ShapesModel,\n", + " TableModel,\n", + ")\n" ] }, { @@ -52,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 38, "id": "86f42824", "metadata": {}, "outputs": [ @@ -60,9 +75,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "no parent found for : None\n", - "no parent found for : None\n", - "no parent found for : None\n", + "no parent found for : None\n", + "no parent found for : None\n", + "no parent found for : None\n", "/Users/tim.treis/anaconda3/envs/spatialdata/lib/python3.10/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n", " utils.warn_names_duplicates(\"obs\")\n" ] @@ -78,47 +93,251 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 39, + "id": "0519ddd6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/qg/qgc908995g3fc8qtss2fsbhhxyxxj4/T/ipykernel_10284/749062723.py:38: FutureWarning: X.dtype being converted to np.float32 from float64. In the next version of anndata (0.9) conversion will not be automatic. Pass dtype explicitly to avoid this warning. Pass `AnnData(X, dtype=X.dtype, ...)` to get the future behavour.\n", + " adata = AnnData(RNG.normal(size=(30, 10)), obs=pd.DataFrame(RNG.normal(size=(30, 3)), columns=[\"a\", \"b\", \"c\"]))\n", + "/Users/tim.treis/anaconda3/envs/spatialdata/lib/python3.10/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n", + " warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n" + ] + }, + { + "data": { + "text/plain": [ + "pyarrow.Table\n", + "x: double\n", + "y: double\n", + "points_assignment0: int64\n", + "----\n", + "x: [[0.8413944235105699,0.5615364450210023,-0.5556994053539558,1.3616613476380617,1.3414766906774243,...,-0.5651286351605515,1.2442477539941565,-1.008172775800909,0.9597715912866557,-0.46395487249913125]]\n", + "y: [[1.033588506324896,-1.3432120527072624,-0.3173451786855313,-1.0882631464969956,-1.1932496234882468,...,-0.6987319002728575,0.22930304650553632,0.6932577098741612,-0.5447879481915016,0.32888317716312865]]\n", + "points_assignment0: [[0,1,1,2,2,...,0,1,2,0,2]]" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def _get_points() -> dict[str, pa.Table]:\n", + " name = \"points\"\n", + " var_names = [np.arange(3), [\"genex\", \"geney\"]]\n", + "\n", + " out = {}\n", + " for i, v in enumerate(var_names):\n", + " name = f\"{name}_{i}\"\n", + " arr = RNG.normal(size=(1000, 2))\n", + " # randomly assign some values from v to the points\n", + " points_assignment0 = pd.Series(RNG.choice(v, size=arr.shape[0]))\n", + " annotations = pa.table(\n", + " {\"points_assignment0\": points_assignment0},\n", + " )\n", + " out[name] = PointsModel.parse(coords=arr, annotations=annotations)\n", + " return out\n", + "\n", + "\n", + "from geopandas import GeoDataFrame\n", + "from shapely.geometry import MultiPolygon, Polygon\n", + "\n", + "\n", + "\n", + "GeoDataFrame(\n", + " {\n", + " \"geometry\": [\n", + " Polygon(((0, 0), (0, 10), (10, 10), (10, 0))),\n", + " ]\n", + " }\n", + ")\n", + "\n", + "images = { \n", + " 'data1': sd.Image2DModel.parse(np.random.normal(size=(1, 100, 100)), dims=('c', 'y', 'x')),\n", + "} \n", + "\n", + "instance_key = \"instance_id\"\n", + "region_key = \"annotated_region\"\n", + "\n", + "adata = AnnData(RNG.normal(size=(30, 10)), obs=pd.DataFrame(RNG.normal(size=(30, 3)), columns=[\"a\", \"b\", \"c\"]))\n", + "adata.obs[instance_key] = [\"data1\"] * 3 + [\"data2\"] * 7 + [\"data3\"] * 20\n", + "adata.obs[region_key] = [\"data1\"] * 3 + [\"data2\"] * 7 + [\"data3\"] * 20\n", + "table = TableModel.parse(adata=adata, instance_key=instance_key, region_key=region_key)\n", + "sdata = sd.SpatialData(images=images, table=table, points=_get_points())\n", + "\n", + "sdata.points[\"points_0\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "d0dceb0b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "POLYGON ((2435.069566 6500.284703, 2476.390162 6535.178868, 2534.993939 6471.307336, 2743.287325 6256.796237, 2892.512438 6045.393994, 3110.132394 5597.718655, 3175.960311 5399.368749, 3130.896017 5374.899385, 3077.489401 5557.30352, 2850.542875 6011.196572, 2712.19876 6214.826674, 2493.024376 6434.001058, 2435.069566 6500.284703))\n", + "POLYGON ((2099.296063 6218.987584, 2143.972428 6257.045228, 2145.859029 6257.045228, 2146.338866 6256.459168, 2151.558392 6260.866921, 2210.003484 6332.509292, 2221.323093 6321.189682, 2222.09221 6320.430889, 2259.341867 6351.887251, 2357.789117 6236.588669, 2668.674769 5824.665181, 2827.226451 5490.463105, 2914.2115 5257.242323, 2873.763161 5235.279366, 2874.087247 5234.50716, 2845.513743 5219.940275, 2776.791914 5182.625157, 2600.279926 5537.095953, 2404.421965 5857.308174, 2099.296063 6218.987584))\n", + "POLYGON ((1915.9413 6061.89444, 1940.749011 6082.843905, 1940.219455 6083.477881, 2099.296063 6218.987584, 2404.421965 5857.308174, 2600.279926 5537.095953, 2776.791914 5182.625157, 2524.716288 5045.751152, 2468.153524 5171.805312, 2258.305709 5529.323812, 1917.88592 6059.383848, 1915.9413 6061.89444))\n", + "POLYGON ((1751.729672 5923.22201, 1915.9413 6061.89444, 1917.88592 6059.383848, 2258.305709 5529.323812, 2468.153524 5171.805312, 2524.716288 5045.751152, 2356.035402 4954.159479, 2281.622133 5132.944606, 2135.505877 5387.87084, 1911.668207 5717.409631, 1751.729672 5923.22201))\n", + "POLYGON ((1685.872345 5868.689447, 1743.695193 5916.437102, 1751.729672 5923.22201, 1911.668207 5717.409631, 2135.505877 5387.87084, 2281.622133 5132.944606, 2356.035402 4954.159479, 2299.438471 4923.428033, 2295.815906 4922.282422, 2190.900796 5146.937604, 2029.934809 5399.884154, 1792.318353 5710.318557, 1685.872345 5868.689447))\n", + "POLYGON ((2259.341867 6351.887251, 2435.069566 6500.284703, 2493.024376 6434.001058, 2712.19876 6214.826674, 2850.542875 6011.196572, 3077.489401 5557.30352, 3130.896017 5374.899385, 2914.2115 5257.242323, 2827.226451 5490.463105, 2668.674769 5824.665181, 2357.789117 6236.588669, 2259.341867 6351.887251))\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/tim.treis/Documents/GitHub/spatialdata-plot/src/spatialdata_plot/pl/basic.py:11: AccessorRegistrationWarning: registration of accessor under name 'pl' for type is overriding a preexisting attribute with the same name.\n", + " class PlotAccessor:\n" + ] + } + ], + "source": [ + "for key, value in merfish.polygons.items():\n", + " for geometry in value.geometry:\n", + " print(geometry)" + ] + }, + { + "cell_type": "code", + "execution_count": 80, "id": "0eecde65", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'library_id'" + "SpatialData object with:\n", + "├── Images\n", + "│ └── 'rasterized': SpatialImage[cyx] (1, 522, 575)\n", + "├── Points\n", + "│ └── 'single_molecule': pyarrow.Table shape: (3714642, 3) (2D points)\n", + "├── Polygons\n", + "│ └── 'anatomical': GeoDataFrame shape: (6, 1) (2D polygons)\n", + "├── Shapes\n", + "│ └── 'cells': AnnData with osbm.spatial (2399, 2)\n", + "└── Table\n", + " └── AnnData object with n_obs × n_vars = 2399 × 268\n", + " obs: 'cell_id'\n", + " uns: 'spatialdata_attrs': AnnData (2399, 268)" ] }, - "execution_count": 5, + "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "mibi.pp.get_region_key()" + "merfish# .pl.imshow().pl.render_polygon()\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 70, + "id": "35def7ca", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(0.267004, 0.004874, 0.329415, 1.0)]" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cmap = plt.cm.viridis\n", + "[cmap(i) for i in np.linspace(0, 1, len(merfish.polygons.keys()))]" + ] + }, + { + "cell_type": "code", + "execution_count": 43, "id": "2a6d74c0", "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['point16', 'point23', 'point8']" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "channels = [1,2]\n", + "\n", + "image_names = []\n", + "n_channels = []\n", + "for image_name, image in mibi.images.items():\n", + " image_names.append(image_name)\n", + " n_channels.append(image.shape[0])\n", + " \n", + "channels_in_image = pd.DataFrame({\n", + " \"image_name\": image_names,\n", + " \"n_channels\": n_channels\n", + "})\n", + " \n", + "# 3) drop images that don't have enough channels for the selection\n", + "channels_in_image = channels_in_image.query(\"n_channels-1 >= \" + str(max(channels)))\n", + "valid_images = channels_in_image.image_name.values.tolist() \n", + "\n", + "valid_images\n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "cc2a2d5f", + "metadata": {}, "outputs": [ { "ename": "NameError", - "evalue": "name 'mibitof' is not defined", + "evalue": "name 'imc' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m mibitof\n", - "\u001b[0;31mNameError\u001b[0m: name 'mibitof' is not defined" + "Cell \u001b[0;32mIn[44], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m imc\n", + "\u001b[0;31mNameError\u001b[0m: name 'imc' is not defined" ] } ], "source": [ - "mibitof" + "imc" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d258be1", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0b78a4b", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -143,7 +362,7 @@ " uns: 'spatialdata_attrs': AnnData (2399, 268)" ] }, - "execution_count": 6, + "execution_count": 67, "metadata": {}, "output_type": "execute_result" } @@ -174,7 +393,7 @@ " uns: 'spatialdata_attrs': AnnData (5906, 31053)" ] }, - "execution_count": 7, + "execution_count": 68, "metadata": {}, "output_type": "execute_result" } @@ -220,7 +439,7 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mPathNotFoundError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m codex \u001b[39m=\u001b[39m sd\u001b[39m.\u001b[39;49mread_zarr(\u001b[39m'\u001b[39;49m\u001b[39m/home/voehring/voehring/projects/2023-01-15_spatial_data/codex_all.zarr\u001b[39;49m\u001b[39m'\u001b[39;49m)\n", + "Cell \u001b[0;32mIn[69], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m codex \u001b[39m=\u001b[39m sd\u001b[39m.\u001b[39;49mread_zarr(\u001b[39m'\u001b[39;49m\u001b[39m/home/voehring/voehring/projects/2023-01-15_spatial_data/codex_all.zarr\u001b[39;49m\u001b[39m'\u001b[39;49m)\n", "File \u001b[0;32m~/Documents/GitHub/spatialdata/spatialdata/_io/read.py:81\u001b[0m, in \u001b[0;36mread_zarr\u001b[0;34m(store)\u001b[0m\n\u001b[1;32m 77\u001b[0m store \u001b[39m=\u001b[39m Path(store)\n\u001b[1;32m 79\u001b[0m fmt \u001b[39m=\u001b[39m SpatialDataFormatV01()\n\u001b[0;32m---> 81\u001b[0m f \u001b[39m=\u001b[39m zarr\u001b[39m.\u001b[39;49mopen(store, mode\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mr\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[1;32m 82\u001b[0m images \u001b[39m=\u001b[39m {}\n\u001b[1;32m 83\u001b[0m labels \u001b[39m=\u001b[39m {}\n", "File \u001b[0;32m~/anaconda3/envs/spatialdata/lib/python3.10/site-packages/zarr/convenience.py:122\u001b[0m, in \u001b[0;36mopen\u001b[0;34m(store, mode, zarr_version, path, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[39mreturn\u001b[39;00m open_group(_store, mode\u001b[39m=\u001b[39mmode, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m 121\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 122\u001b[0m \u001b[39mraise\u001b[39;00m PathNotFoundError(path)\n", "\u001b[0;31mPathNotFoundError\u001b[0m: nothing found at path ''" diff --git a/src/spatialdata_plot/__init__.py b/src/spatialdata_plot/__init__.py index db6f0d22..738bf465 100644 --- a/src/spatialdata_plot/__init__.py +++ b/src/spatialdata_plot/__init__.py @@ -1,6 +1,6 @@ from importlib.metadata import version -from . import pl, pp, tl +from . import pl, pp, tl, utils __all__ = ["pl", "pp", "tl"] diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index a1050b93..ca6c455b 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -1,4 +1,7 @@ +from typing import Union + import numpy as np +import matplotlib from matplotlib import pyplot as plt from ..accessor import register_spatial_data_accessor @@ -36,8 +39,69 @@ def imshow(self, ax=None, ncols=4, width=4, height=3, **kwargs): return self._sdata - def test_plot(self): - plt.plot(np.arange(10), np.arange(10)) + def render_polygon( + self, + ax: Union[matplotlib.axes.Axes, list[matplotlib.axes.Axes]] = None, + cmap=plt.cm.viridis, + alpha_boundary: float = 1.0, + alpha_fill: float = 0.3, + split_by=True, + **kwargs, + ): + + if ax is not None: + + if not (isinstance(ax, matplotlib.axes.Axes) or all([isinstance(a, matplotlib.axes.Axes) for a in ax])): + + raise TypeError("Parameter 'ax' must be one or more objects of of type 'matplotlib.axes.Axes'.") + + if not isinstance(cmap, matplotlib.colors.Colormap): + + raise TypeError("Parameter 'cmap' must be of type 'matplotlib.colors.Colormap'.") + + if not isinstance(alpha_boundary, (int, float)): + + raise TypeError("Parameter 'alpha_boundary' must be numeric.") + + if not (0 <= alpha_boundary <= 1): + + raise ValueError("Parameter 'alpha_boundary' must be between 0 and 1.") - def scatter(self): - plt.scatter(np.random.randn(20), np.random.randn(20)) + if not (0 <= alpha_fill <= 1): + + raise ValueError("Parameter 'alpha_fill' must be between 0 and 1.") + + if not isinstance(alpha_fill, (int, float)): + + raise TypeError("Parameter 'alpha_fill' must be numeric.") + + if not isinstance(split_by, bool): + + raise TypeError("Parameter 'split_by' must be of type 'bool'.") + + # TODO(ttreis): figure out nesting of geometries + # TODO(ttreis): figure out how to handle multiple polygon colours + # TODO(ttreis): include cmap + # TODO(ttreis): include optimal tiling if split + + # Figure out how many polygons to plot there are + + + if split_by: + for key, value in self._sdata.polygons.items(): + ax = ax or plt.gca() + for geometry in value.geometry: + ( + x, + y, + ) = geometry.exterior.xy + ax.plot(x, y, alpha=alpha_boundary, **kwargs) + + ( + x, + y, + ) = geometry.exterior.xy + ax.fill(x, y, alpha=alpha_fill, **kwargs) + + + return self._sdata diff --git a/src/spatialdata_plot/pp/basic.py b/src/spatialdata_plot/pp/basic.py index ca75defb..f361133b 100644 --- a/src/spatialdata_plot/pp/basic.py +++ b/src/spatialdata_plot/pp/basic.py @@ -2,12 +2,17 @@ import spatialdata as sd from anndata import AnnData +from geopandas import GeoDataFrame from ..accessor import register_spatial_data_accessor from .colorize import _colorize from .render import _render_label from .utils import _get_listed_colormap +from spatialdata._core.models import ( + PolygonsModel, +) + @register_spatial_data_accessor("pp") class PreprocessingAccessor: @@ -23,12 +28,12 @@ def _copy( shapes: Union[None, dict] = None, table: Union[dict, AnnData] = None, ) -> sd.SpatialData: - + """ Helper function to copies the references from the original SpatialData object to the subsetted SpatialData object. """ - + return sd.SpatialData( images=self._sdata.images if images is None else images, labels=self._sdata.labels if labels is None else labels, @@ -39,13 +44,13 @@ def _copy( ) def get_region_key(self) -> str: - + "Quick access to the data's region key." - + return self._sdata.table.uns["spatialdata_attrs"]["region_key"] def get_bb(self, x: Union[slice, list, tuple], y: Union[slice, list, tuple]) -> sd.SpatialData: - + """Get bounding box around a point. Parameters @@ -60,55 +65,53 @@ def get_bb(self, x: Union[slice, list, tuple], y: Union[slice, list, tuple]) -> sd.SpatialData subsetted SpatialData object """ - + if not isinstance(x, (slice, list, tuple)): - + raise TypeError("Parameter 'x' must be one of 'slice', 'list', 'tuple'.") - + if isinstance(x, (list, tuple)): - + if len(x) != 2: - + raise ValueError("Parameter 'x' must be of length 2.") - + if x[1] <= x[0]: - + raise ValueError("The current choice of 'x' would result in an empty slice.") - + # x is clean x = slice(x[0], x[1]) - + elif isinstance(x, slice): - + if x.stop <= x.start: - + raise ValueError("The current choice of 'x' would result in an empty slice.") - - + if not isinstance(y, (slice, list, tuple)): - + raise TypeError("Parameter 'y' must be one of 'slice', 'list', 'tuple'.") - + if isinstance(y, (list, tuple)): - + if len(y) != 2: - + raise ValueError("Parameter 'y' must be of length 2.") - + if y[1] <= y[0]: - + raise ValueError("The current choice of 'y' would result in an empty slice.") - + # y is clean y = slice(y[0], y[1]) - + elif isinstance(y, slice): - + if y.stop <= y.start: - + raise ValueError("The current choice of 'x' would result in an empty slice.") - - + selection = dict(x=x, y=y) # makes use of xarray sel method # TODO: error handling if selection is out of bounds @@ -138,26 +141,26 @@ def get_images(self, keys: Union[list, str], label_func: Callable = lambda x: x) # TODO: error handling if keys are not in images if not isinstance(keys, (list, str)): - + raise TypeError("Parameter 'keys' must either be of type 'str' or 'list'.") - if isinstance(keys, list): - + if isinstance(keys, list): + if not all([isinstance(key, str) for key in keys]): - + raise TypeError("All elements in 'keys' must be of type 'str'.") if isinstance(keys, str): keys = [keys] - + assert all([isinstance(key, str) for key in keys]) - + valid_keys = list(self._sdata.images.keys()) - + for key in keys: - + if key not in valid_keys: - + raise ValueError(f"Key '{key}' is not a valid key. Valid choices are: " + ", ".join(valid_keys)) selected_images = {key: img for key, img in self._sdata.images.items() if key in keys} @@ -168,16 +171,15 @@ def get_images(self, keys: Union[list, str], label_func: Callable = lambda x: x) new_table = None # make sure that table exists if hasattr(self._sdata, "table"): - + if hasattr(self._sdata.table, "obs"): - + # create mask of used keys mask = self._sdata.table.obs[self._sdata.pp.get_region_key()] mask = list(mask.str.contains("|".join(keys))) # print(mask) - + new_table = self._sdata.table[mask, :] - return self._copy(images=selected_images, labels=selected_labels, table=new_table) @@ -200,6 +202,107 @@ def get_channels(self, keys: Union[list, slice]) -> sd.SpatialData: return self._copy(images=channels_images) + def get_polygons(self, keys: Union[str, list[str]]) -> sd.SpatialData: + """Get polygons from a list of keys. + + Parameters + ---------- + keys : list + ist of keys to select + + Returns + ------- + sd.SpatialData + subsetted SpatialData object + """ + + if len(keys) == 0: + + raise ValueError("No keys specified") + + if not (isinstance(keys, str) or isinstance(keys, list)): + + raise TypeError("Parameter 'keys' must either be a string or a list of strings.") + + if isinstance(keys, list): + + if not all(isinstance(x, str) for x in keys): + + raise TypeError("Not all elements in 'keys' are strings.") + + # 1) collect and verify all polygons + + polygons_to_retain = [] + keys = [keys] if isinstance(keys, str) else keys + for key in keys: + + key = key.split("/") + + if len(key) > 2: + + raise ValueError(f"Key '{key}' is specified in an invalid format.") + + if len(key) == 2: + + key_major, key_minor = key + + elif len(key) == 1: + + key_major, key_minor = (key[0], None) + + # TODO(ttreis) error handling if key_minor cannot be converted to int + + if key_major not in list(self._sdata.polygons.keys()): + + raise ValueError(f"Polygon with key '{key_major}' does not exist.") + + if key_minor is not None: + + if int(key_minor) not in self._sdata.polygons[key_major].index: + + raise ValueError(f"Polygon with key '{key_major}/{key_minor}' does not exist.") + + key_minor = "all" if key_minor is None else key_minor + + polygons_to_retain.append([key_major, key_minor]) + + # 2) Collect polygons + polygons = {} + + # 2a) Initialize polygon dict + for key_major, key_minor in polygons_to_retain: + + if key_minor != "all": + + polygons[key_major] = [] + + # 2b) Collect polygons + for key_major, key_minor in polygons_to_retain: + + if key_minor == "all": + + polygons[key_major] = self._sdata.polygons[key_major] + + else: + + polygons[key_major].append( + self._sdata.polygons[key_major].loc[ + int(key_minor), + ]["geometry"] + ) + + # 2c) If not a full GeoDataFrame, convert subset to GeoDataFrame + for key_major in polygons.keys(): + + if isinstance(polygons[key_major], list): + + polygons[key_major] = GeoDataFrame(({"geometry": polygons[key_major]})) + polygons[key_major] = PolygonsModel.parse(polygons[key_major], name=key_major) + + sdata = self._copy(polygons=polygons) + + return sdata + def colorize( self, colors: List[str] = ["C0", "C1", "C2", "C3"], diff --git a/src/spatialdata_plot/tl/__init__.py b/src/spatialdata_plot/tl/__init__.py index 95a32cd2..c73a42e5 100644 --- a/src/spatialdata_plot/tl/__init__.py +++ b/src/spatialdata_plot/tl/__init__.py @@ -1 +1,5 @@ -from .basic import basic_tool +from .basic import PreprocessingAccessor + +__all__ = [ + "PreprocessingAccessor", +] diff --git a/src/spatialdata_plot/tl/basic.py b/src/spatialdata_plot/tl/basic.py index d215ade4..66a8b6ea 100644 --- a/src/spatialdata_plot/tl/basic.py +++ b/src/spatialdata_plot/tl/basic.py @@ -1,17 +1,57 @@ +from typing import Callable, List, Union + from anndata import AnnData +import spatialdata as sd + + +from ..accessor import register_spatial_data_accessor + + +@register_spatial_data_accessor("tl") +class ToolAccessor: + def __init__(self, sdata): + self._sdata = sdata + + def _copy( + self, + images: Union[None, dict] = None, + labels: Union[None, dict] = None, + points: Union[None, dict] = None, + polygons: Union[None, dict] = None, + shapes: Union[None, dict] = None, + table: Union[dict, AnnData] = None, + ) -> sd.SpatialData: + + """ + Helper function to copies the references from the original SpatialData + object to the subsetted SpatialData object. + """ + return sd.SpatialData( + images=self._sdata.images if images is None else images, + labels=self._sdata.labels if labels is None else labels, + points=self._sdata.points if points is None else points, + polygons=self._sdata.polygons if polygons is None else polygons, + shapes=self._sdata.shapes if shapes is None else shapes, + table=self._sdata.table if table is None else table, + ) -def basic_tool(adata: AnnData) -> int: - """Run a tool on the AnnData object. + def filter_polygon(self, polygon_key: str) -> sd.SpatialData: + """Subsets the SpatialData object to the polygon with the given key. - Parameters - ---------- - adata - The AnnData object to preprocess. + Parameters + ---------- + polygon_key + Key of the polygon to subset to. - Returns - ------- - Some integer value. - """ - print("Implement a tool to run on the AnnData object.") - return 0 + Returns + ------- + sd.SpatialData + """ + + if polygon_key not in self._sdata.polygons: + raise ValueError(f"Polygon with key {polygon_key} not found in SpatialData object.") + + polygons = {polygon_key: self._sdata.polygons[polygon_key]} + + return self._copy(polygons=polygons) \ No newline at end of file diff --git a/src/spatialdata_plot/utils.py b/src/spatialdata_plot/utils.py new file mode 100644 index 00000000..cde0aa4e --- /dev/null +++ b/src/spatialdata_plot/utils.py @@ -0,0 +1,30 @@ +import spatialdata as sd + + +def _confirm_membership(sd: sd.SpatialData, key: str, slot: str) -> None: + """ + Helper function to check if a key is part of the respective slot. + """ + + if not isinstance(key, str): + + raise TypeError("Parameter 'key' must be of type 'str'.") + + if not isinstance(slot, str): + + raise TypeError("Parameter 'slot' must be of type 'str'.") + + valid_slots = ["images", "labels", "points", "polygons", "shapes", "table"] + if slot not in valid_slots: + + raise ValueError(f"Parameter 'slot' must be one of {valid_slots}.") + + if len(getattr(sd, slot)) == 0: + + raise ValueError(f"Slot '{slot}' is empty.") + + if key not in getattr(sd, slot).keys(): + + raise ValueError(f"Key '{key}' not found in slot '{slot}'.") + + return True \ No newline at end of file