Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/spatialdata_plot/_logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# from https://github.com/scverse/spatialdata/blob/main/src/spatialdata/_logging.py

import logging
import re
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
from _pytest.logging import LogCaptureFixture


def _setup_logger() -> "logging.Logger":
Expand All @@ -21,3 +28,44 @@ def _setup_logger() -> "logging.Logger":


logger = _setup_logger()


@contextmanager
def logger_warns(
caplog: "LogCaptureFixture",
logger: logging.Logger,
match: str | None = None,
level: int = logging.WARNING,
) -> Iterator[None]:
"""
Context manager similar to pytest.warns, but for logging.Logger.

Usage:
with logger_warns(caplog, logger, match="Found 1 NaN"):
call_code_that_logs()
"""
# Store initial record count to only check new records
initial_record_count = len(caplog.records)

# Add caplog's handler directly to the logger to capture logs even if propagate=False
handler = caplog.handler
logger.addHandler(handler)
original_level = logger.level
logger.setLevel(level)

# Use caplog.at_level to ensure proper capture setup
with caplog.at_level(level, logger=logger.name):
try:
yield
finally:
logger.removeHandler(handler)
logger.setLevel(original_level)

# Only check records that were added during this context
records = [r for r in caplog.records[initial_record_count:] if r.levelno >= level]

if match is not None:
pattern = re.compile(match)
if not any(pattern.search(r.getMessage()) for r in records):
msgs = [r.getMessage() for r in records]
raise AssertionError(f"Did not find log matching {match!r} in records: {msgs!r}")
95 changes: 52 additions & 43 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import warnings
from collections import abc
from copy import copy

Expand Down Expand Up @@ -49,7 +48,6 @@
_get_extent_and_range_for_datashader_canvas,
_get_linear_colormap,
_hex_no_alpha,
_is_coercable_to_float,
_map_color_seg,
_maybe_set_colors,
_mpl_ax_contains_elements,
Expand Down Expand Up @@ -94,20 +92,7 @@ def _render_shapes(
)
sdata_filt[table_name] = table = joined_table

if (
col_for_color is not None
and table_name is not None
and col_for_color in sdata_filt[table_name].obs.columns
and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O"
and not _is_coercable_to_float(color_col)
):
warnings.warn(
f"Converting copy of '{col_for_color}' column to categorical dtype for categorical plotting. "
f"Consider converting before plotting.",
UserWarning,
stacklevel=2,
)
sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category")
shapes = sdata_filt[element]

# get color vector (categorical or continuous)
color_source_vector, color_vector, _ = _set_color_source_vec(
Expand All @@ -121,6 +106,7 @@ def _render_shapes(
cmap_params=render_params.cmap_params,
table_name=table_name,
table_layer=table_layer,
coordinate_system=coordinate_system,
)

values_are_categorical = color_source_vector is not None
Expand All @@ -144,12 +130,25 @@ def _render_shapes(

# continuous case: leave NaNs as NaNs; utils maps them to na_color during draw
if color_source_vector is None and not values_are_categorical:
color_vector = np.asarray(color_vector, dtype=float)
if np.isnan(color_vector).any():
nan_count = int(np.isnan(color_vector).sum())
msg = f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'."
warnings.warn(msg, UserWarning, stacklevel=2)
logger.warning(msg)
_series = color_vector if isinstance(color_vector, pd.Series) else pd.Series(color_vector)

try:
color_vector = np.asarray(_series, dtype=float)
except (TypeError, ValueError):
nan_count = int(_series.isna().sum())
if nan_count:
logger.warning(
f"Found {nan_count} NaN values in color data. "
"These observations will be colored with the 'na_color'."
)
color_vector = _series.to_numpy()
else:
if np.isnan(color_vector).any():
nan_count = int(np.isnan(color_vector).sum())
logger.warning(
f"Found {nan_count} NaN values in color data. "
"These observations will be colored with the 'na_color'."
)

# Using dict.fromkeys here since set returns in arbitrary order
# remove the color of NaN values, else it might be assigned to a category
Expand Down Expand Up @@ -476,10 +475,33 @@ def _render_shapes(
if not values_are_categorical:
vmin = render_params.cmap_params.norm.vmin
vmax = render_params.cmap_params.norm.vmax
if vmin is None:
vmin = float(np.nanmin(color_vector))
if vmax is None:
vmax = float(np.nanmax(color_vector))
if vmin is None or vmax is None:
# Extract numeric values only (filter out strings and other non-numeric types)
if isinstance(color_vector, np.ndarray):
if np.issubdtype(color_vector.dtype, np.number):
# Already numeric, can use directly
numeric_values = color_vector
else:
# Mixed types - extract only numeric values using pandas
numeric_values = pd.to_numeric(color_vector, errors="coerce")
numeric_values = numeric_values[np.isfinite(numeric_values)]
if len(numeric_values) > 0:
if vmin is None:
vmin = float(np.nanmin(numeric_values))
if vmax is None:
vmax = float(np.nanmax(numeric_values))
else:
# No numeric values found, use defaults
if vmin is None:
vmin = 0.0
if vmax is None:
vmax = 1.0
else:
# Not a numpy array, use defaults
if vmin is None:
vmin = 0.0
if vmax is None:
vmax = 1.0
_cax.set_clim(vmin=vmin, vmax=vmax)

if (
Expand Down Expand Up @@ -541,31 +563,16 @@ def _render_points(
coords = ["x", "y"]

if table_name is not None and col_for_color not in points.columns:
warnings.warn(
logger.warning(
f"Annotating points with {col_for_color} which is stored in the table `{table_name}`. "
f"To improve performance, it is advisable to store point annotations directly in the .parquet file.",
UserWarning,
stacklevel=2,
f"To improve performance, it is advisable to store point annotations directly in the .parquet file."
)

if col_for_color is None or (
table_name is not None
and (col_for_color in sdata_filt[table_name].obs.columns or col_for_color in sdata_filt[table_name].var_names)
):
points = points[coords].compute()
if (
col_for_color
and col_for_color in sdata_filt[table_name].obs.columns
and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O"
and not _is_coercable_to_float(color_col)
):
warnings.warn(
f"Converting copy of '{col_for_color}' column to categorical dtype for categorical "
f"plotting. Consider converting before plotting.",
UserWarning,
stacklevel=2,
)
sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category")
else:
coords += [col_for_color]
points = points[coords].compute()
Expand Down Expand Up @@ -683,6 +690,7 @@ def _render_points(
alpha=render_params.alpha,
table_name=table_name,
render_type="points",
coordinate_system=coordinate_system,
)

if added_color_from_table and col_for_color is not None:
Expand Down Expand Up @@ -1219,6 +1227,7 @@ def _render_labels(
cmap_params=render_params.cmap_params,
table_name=table_name,
table_layer=table_layer,
coordinate_system=coordinate_system,
)

# rasterize could have removed labels from label
Expand Down
Loading