In [None]:
import json
from pathlib import Path

from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon


def _iter_polygon_rings(geometry):
    """Yield every coordinate ring of a GeoJSON Polygon or MultiPolygon dict."""
    geom_type = geometry.get("type")
    coordinates = geometry.get("coordinates") or []
    if geom_type == "Polygon":
        # Polygon coordinates are a list of rings: [[[x1, y1], [x2, y2], ...], ...]
        yield from coordinates
    elif geom_type == "MultiPolygon":
        # MultiPolygon coordinates are a list of polygons, each a list of rings.
        for polygon in coordinates:
            yield from polygon


def plot_segmentations_on_downscaled_image(
    image_path,
    geojson_path,
    orig_width,
    orig_height,
    alpha=0.3,
    edge_color="yellow",
    line_width=0.5,
):
    """Overlay cell-segmentation outlines on a downscaled histology image.

    The polygons in the GeoJSON file are expressed in the pixel coordinates of
    the original image; they are rescaled independently along x and y to the
    size of the downscaled image before being drawn.

    Args:
        image_path: Path to the *downscaled* histology image (PNG/JPEG/TIFF).
        geojson_path: Path to cell_segmentations.geojson generated on the
            original image.
        orig_width: Width of the original image (the one used to create the
            segmentations). Must be positive.
        orig_height: Height of the original image. Must be positive.
        alpha: Opacity of the drawn outlines.
        edge_color: Matplotlib colour of the outlines.
        line_width: Line width of the outlines.

    Raises:
        ValueError: If ``orig_width`` or ``orig_height`` is not positive.
    """
    # Validate before any I/O: a zero dimension would otherwise fail later with
    # a bare ZeroDivisionError, and a negative one would mirror the overlay.
    if orig_width <= 0 or orig_height <= 0:
        raise ValueError(
            "orig_width and orig_height must be positive, "
            f"got ({orig_width}, {orig_height})"
        )

    image_path = Path(image_path)
    geojson_path = Path(geojson_path)

    # 1. Load the downscaled image; convert() returns an independent in-memory
    #    copy, so the underlying file can be closed right away.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    new_width, new_height = img.size

    print(f"Original size: ({orig_width}, {orig_height})")
    print(f"Downscaled size: ({new_width}, {new_height})")

    # 2. Compute scaling factors
    scale_x = new_width / orig_width
    scale_y = new_height / orig_height
    print(f"Scale factors: scale_x={scale_x:.4f}, scale_y={scale_y:.4f}")

    # 3. Load GeoJSON (explicit UTF-8 instead of the platform default encoding)
    with open(geojson_path, "r", encoding="utf-8") as f:
        gj = json.load(f)

    features = gj.get("features") or []
    print(f"Loaded {len(features)} features from GeoJSON")

    # 4. Plot
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(img)
    ax.set_axis_off()

    # 5. Loop over each feature (cell polygon)
    skipped = 0
    for feat in features:
        # "geometry" may be present but null, so fall back to an empty dict.
        geom = feat.get("geometry") or {}
        if geom.get("type") not in ("Polygon", "MultiPolygon"):
            # Nothing to outline for e.g. Point or LineString geometries.
            skipped += 1
            continue

        for ring in _iter_polygon_rings(geom):
            if len(ring) < 3:
                # Degenerate ring: too few vertices to form an outline.
                continue

            # Index instead of unpacking so positions carrying a third
            # (elevation) value are still accepted.
            scaled_coords = [
                (pt[0] * scale_x, pt[1] * scale_y)
                for pt in ring
            ]

            poly = Polygon(
                scaled_coords,
                closed=True,
                fill=False,
                edgecolor=edge_color,
                linewidth=line_width,
                alpha=alpha,
            )
            ax.add_patch(poly)

    if skipped:
        print(f"Skipped {skipped} features without polygon geometry")

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    # --- User configuration: adjust these values before running ---

    # Segmentation polygons generated on the original image.
    geojson_path = "../data/visium_adult_mouse_brain/cell_segmentations.geojson"
    # Downscaled histology image the outlines are drawn on.
    image_path = "/path/to/downscaled_histology.png"

    # Pixel size of the *original* full-res image; the numbers below are
    # placeholders and must be replaced with the real width and height.
    orig_width, orig_height = 40000, 32000

    plot_segmentations_on_downscaled_image(
        image_path,
        geojson_path,
        orig_width,
        orig_height,
        alpha=0.5,
        edge_color="cyan",
        line_width=0.4,
    )
