# PyTorch Object Segmentation + ESD Line Extraction Demo

This notebook demonstrates the integration of **PyTorch-based object segmentation** with the
**LineExtraction ESD (Edge Segment Detector)** framework.

## Pipeline Overview

```
┌─────────────────────┐    ┌──────────────────────┐    ┌─────────────────────┐    ┌────────────────┐
│  Input Image        │───>│ PyTorch Segmentation │───>│  Extract Contours   │───>│  ESD + Lines   │
│  (BSDS500/MDB)      │    │  (SAM / YOLO)        │    │  (skimage.measure)  │    │  (le_edge/lsd) │
└─────────────────────┘    └──────────────────────┘    └─────────────────────┘    └────────────────┘
```

**Interactive Features:**
- **SAM (Segment Anything):** Click on any object to segment it
- **YOLO:** Automatic instance segmentation of 80 COCO classes
- **Model switching:** Toggle between SAM and YOLO in real-time
- **ESD integration:** Convert object contours to line segments

**Prerequisites:**
```bash
# Build Python bindings
bazel build //libs/...

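# Install PyTorch, segmentation, and notebook dependencies
# (scikit-image is used for contour extraction, ipympl for `%matplotlib widget`)
uv pip install torch torchvision segment-anything ultralytics scikit-image matplotlib ipywidgets ipympl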
```
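
Before running the cells below, it can help to check which of these packages are actually importable in the active environment. The following is a minimal sketch using only the standard library; the helper name `missing_modules` is hypothetical and not part of LineExtraction, and the list uses import names rather than pip package names.

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names (not pip package names) used by this notebook.
required = ["torch", "torchvision", "segment_anything", "ultralytics", "skimage"]
missing = missing_modules(required)
print("Missing:", ", ".join(missing) if missing else "none")
```

Anything reported as missing can be installed with the `uv pip install` command above.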

## 1. Setup and Imports

Configure the Python environment to use LineExtraction bindings from Bazel output directories.
Import PyTorch, segmentation models, and LE modules.

In [None]:
import sys
import pathlib
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple, List, Union

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Enable interactive matplotlib backend for click events
%matplotlib widget

# --- Locate workspace root and add Bazel output dirs to sys.path ---
workspace = pathlib.Path.cwd()
while not (workspace / "MODULE.bazel").exists():
    if workspace == workspace.parent:
        raise RuntimeError("Cannot find LineExtraction workspace root (MODULE.bazel)")
    workspace = workspace.parent

# Add each binding's Bazel output directory
for lib in ["imgproc", "edge", "geometry", "eval", "lsd"]:
    p = workspace / f"bazel-bin/libs/{lib}/python"
    if p.exists():
        sys.path.insert(0, str(p))
    else:
        print(f"⚠ Not found: {p}  — run: bazel build //libs/{lib}/...")

# Add lsfm package for TestImages
sys.path.insert(0, str(workspace / "python"))

print(f"Workspace: {workspace}")

# Import LineExtraction modules
import le_imgproc
import le_edge
import le_geometry
import le_lsd

# Import test images helper
from lsfm.data import TestImages

# Check for PyTorch
try:
    import torch
    import torchvision
    print(f"PyTorch: {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
except ImportError as e:
    raise ImportError("PyTorch not installed. Run: uv pip install torch torchvision") from e

# Check for segmentation models
SAM_AVAILABLE = False
YOLO_AVAILABLE = False

try:
    from segment_anything import sam_model_registry, SamPredictor
    SAM_AVAILABLE = True
    print("✓ Segment Anything Model (SAM) available")
except ImportError:
    print("⚠ SAM not installed. Run: uv pip install segment-anything")

try:
    from ultralytics import YOLO
    YOLO_AVAILABLE = True
    print("✓ Ultralytics YOLO available")
except ImportError:
    print("⚠ YOLO not installed. Run: uv pip install ultralytics")

# Interactive widgets
import ipywidgets as widgets
from IPython.display import display, clear_output

print("\nAll modules loaded successfully.")

## 2. Model Abstraction Layer

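Define a common interface for segmentation models so that SAM and YOLO can be swapped without changing the rest of the pipeline.
Each model must implement:
- `load_model()`: download and initialize the model (called lazily on first prediction)
- `predict(image, point)`: return a `SegmentationResult` with masks, contours, labels, and scores

The base class also provides `extract_contours(mask)`, a static helper that extracts polygon contours from a binary mask.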
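
The lazy-loading part of this interface can be illustrated without any deep learning dependency. The sketch below is a simplified stand-in, not the notebook's actual base class: `LazyModel` and `ThresholdModel` are hypothetical names, and the "model" is just a threshold. It shows that `load_model()` runs once, on the first prediction.

```python
from abc import ABC, abstractmethod


class LazyModel(ABC):
    """Base class that defers expensive loading until first use."""

    def __init__(self) -> None:
        self._loaded = False
        self.load_count = 0  # only here to make the behaviour observable

    @abstractmethod
    def load_model(self) -> None:
        """Load weights; subclasses decide how."""

    def ensure_loaded(self) -> None:
        if not self._loaded:
            self.load_model()
            self._loaded = True


class ThresholdModel(LazyModel):
    """Toy stand-in for a segmenter: flags values above a threshold."""

    def load_model(self) -> None:
        self.load_count += 1
        self.threshold = 0.5  # pretend this came from a checkpoint

    def predict(self, values):
        self.ensure_loaded()
        return [v > self.threshold for v in values]


model = ThresholdModel()
first = model.predict([0.2, 0.9])
second = model.predict([0.7])
print(first, second, model.load_count)
```

Constructing the object stays cheap, which matters in the interactive demo where both models are declared but possibly only one is ever used.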

In [None]:
@dataclass
class SegmentationResult:
    """Result from a segmentation model."""
    masks: List[np.ndarray]  # List of binary masks (H, W), one per detected object
    contours: List[np.ndarray]  # List of contour arrays, each shape (N, 2) as (x, y)
    labels: List[str]  # Class labels for each mask
    scores: List[float]  # Confidence scores


class SegmentationModel(ABC):
    """Abstract base class for segmentation models."""
    
    def __init__(self, device: str = "cpu"):
        self.device = device
        self.model = None
        self._loaded = False
    
    @abstractmethod
    def load_model(self) -> None:
        """Load the model weights. Called lazily on first prediction."""
        pass
    
    @abstractmethod
    def predict(
        self, 
        image: np.ndarray, 
        point: Optional[Tuple[int, int]] = None
    ) -> SegmentationResult:
        """
        Predict segmentation masks for the image.
        
        Args:
            image: RGB image as numpy array (H, W, 3), uint8
            point: Optional (x, y) click coordinate for interactive segmentation (SAM)
        
        Returns:
            SegmentationResult with masks, contours, labels, and scores
        """
        pass
    
    def ensure_loaded(self) -> None:
        """Ensure model is loaded before prediction."""
        if not self._loaded:
            print(f"Loading {self.__class__.__name__}...")
            self.load_model()
            self._loaded = True
            print(f"✓ {self.__class__.__name__} ready")
    
    @staticmethod
    def extract_contours(mask: np.ndarray, min_area: int = 100) -> List[np.ndarray]:
        """
        Extract contours from a binary mask using marching squares.
        
        Args:
            mask: Binary mask (H, W), values 0 or 255
            min_area: Minimum contour area to keep
        
        Returns:
            List of contour arrays, each shape (N, 2) as (x, y) coordinates
        """
        # Use skimage for contour extraction (more reliable than cv2 for this use case)
        from skimage import measure
        
        # Ensure binary mask
        binary = (mask > 127).astype(np.uint8)
        
        # Find contours
        contours = measure.find_contours(binary, level=0.5)
        
        result = []
        for contour in contours:
            # skimage returns (row, col) = (y, x), convert to (x, y)
            contour_xy = contour[:, ::-1].astype(np.float32)
            
            # Filter by area
            if len(contour_xy) >= 3:
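                # Area via the shoelace formula. np.roll adds the closing edge, so
                # open contours (e.g. clipped at the image border) are handled too.
                x = contour_xy[:, 0].astype(np.float64)
                y = contour_xy[:, 1].astype(np.float64)
                area = 0.5 * abs(np.sum(x * np.roll(y, -1) - np.roll(x, -1) * y))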
                if area >= min_area:
                    result.append(contour_xy)
        
        return result


print("SegmentationModel base class defined.")

## 3. SAM (Segment Anything Model) Wrapper

Meta's SAM enables point-based interactive segmentation. Click anywhere on an object
and SAM generates a precise mask for that object.

**Model variants:**
- `vit_h` — Huge (2.4GB) — Highest quality
- `vit_l` — Large (1.2GB) — Good balance
- `vit_b` — Base (375MB) — Fastest
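
When asked for multiple masks, SAM's predictor returns several candidate masks together with a quality score for each, and the wrapper below keeps the highest scoring one. That selection step can be tried in isolation with plain NumPy. The arrays here are made-up stand-ins for SAM output, not real predictions.

```python
import numpy as np

# Three fake candidate masks (boolean, H x W) and their fake quality scores.
masks = np.zeros((3, 4, 4), dtype=bool)
masks[0, 0, 0] = True          # tiny region
masks[1, 1:3, 1:3] = True      # 2x2 block
masks[2, :, :] = True          # whole image
scores = np.array([0.61, 0.93, 0.78])

best_idx = int(np.argmax(scores))
# Convert the boolean mask to 0/255 uint8, the convention used in this notebook.
best_mask = masks[best_idx].astype(np.uint8) * 255

print(best_idx, best_mask.dtype, int(best_mask.sum()) // 255)
```

Taking the top score is the simplest policy. Depending on the use case, one might instead prefer the largest or smallest candidate, since the candidates often correspond to whole object, part, and sub-part.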

In [None]:
class SamSegmenter(SegmentationModel):
    """
    Segment Anything Model (SAM) wrapper for interactive point-based segmentation.
    
    Usage:
        sam = SamSegmenter(model_type="vit_b")
        result = sam.predict(image, point=(x, y))
    """
    
    # Model checkpoint URLs
    CHECKPOINTS = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    }
    
    def __init__(self, model_type: str = "vit_b", device: str = DEVICE):
        super().__init__(device)
        self.model_type = model_type
        self.predictor = None
        self._current_image = None
    
    def load_model(self) -> None:
        """Download and load SAM checkpoint."""
        if not SAM_AVAILABLE:
            raise ImportError("segment-anything not installed")
        
        import urllib.request
        
        # Download checkpoint if not cached
        cache_dir = workspace / ".cache" / "sam"
        cache_dir.mkdir(parents=True, exist_ok=True)
        
        checkpoint_url = self.CHECKPOINTS[self.model_type]
        checkpoint_name = checkpoint_url.split("/")[-1]
        checkpoint_path = cache_dir / checkpoint_name
        
        if not checkpoint_path.exists():
            print(f"Downloading SAM {self.model_type} checkpoint from {checkpoint_url}...")
            urllib.request.urlretrieve(checkpoint_url, checkpoint_path)
            print(f"✓ Saved to {checkpoint_path}")
        
        # Load model
        sam = sam_model_registry[self.model_type](checkpoint=str(checkpoint_path))
        sam.to(device=self.device)
        self.predictor = SamPredictor(sam)
    
    def predict(
        self, 
        image: np.ndarray, 
        point: Optional[Tuple[int, int]] = None
    ) -> SegmentationResult:
        """
        Predict segmentation mask for clicked point.
        
        Args:
            image: RGB image (H, W, 3), uint8
            point: (x, y) click coordinate — REQUIRED for SAM
        
        Returns:
            SegmentationResult with single mask for clicked object
        """
        self.ensure_loaded()
        
        if point is None:
            # Without a point, return empty result
            return SegmentationResult(masks=[], contours=[], labels=[], scores=[])
        
        # Set image (cached for efficiency)
        if self._current_image is not image:
            self.predictor.set_image(image)
            self._current_image = image
        
        # Predict with point prompt
        input_point = np.array([[point[0], point[1]]])
        input_label = np.array([1])  # 1 = foreground
        
        masks, scores, _ = self.predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        
        # Use highest scoring mask
        best_idx = np.argmax(scores)
        best_mask = (masks[best_idx] * 255).astype(np.uint8)
        best_score = float(scores[best_idx])
        
        # Extract contours
        contours = self.extract_contours(best_mask)
        
        return SegmentationResult(
            masks=[best_mask],
            contours=contours,
            labels=["object"],
            scores=[best_score],
        )


if SAM_AVAILABLE:
    print("SamSegmenter class defined.")
else:
    print("⚠ SamSegmenter unavailable (segment-anything not installed)")

## 4. YOLO v8 Segmentation Wrapper

Ultralytics YOLOv8-seg provides fast automatic instance segmentation for 80 COCO classes.
No user interaction required — automatically detects and segments all objects.

**Model variants:**
- `yolov8n-seg` — Nano (6.7MB) — Fastest
- `yolov8s-seg` — Small (22.4MB) — Good balance
- `yolov8m-seg` — Medium (50.5MB) — Higher accuracy
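
YOLO returns one mask per detected instance. The wrapper below also accepts an optional click point and then keeps only the instances whose mask covers that pixel. Here is a minimal NumPy sketch of that filter; `masks_containing_point` is a hypothetical helper, and unlike the wrapper it treats clicks outside the image as matching nothing.

```python
import numpy as np
from typing import List, Tuple


def masks_containing_point(masks: List[np.ndarray], point: Tuple[int, int]) -> List[int]:
    """Return indices of 0/255 masks whose pixel at (x, y) is foreground."""
    x, y = point
    hits = []
    for i, mask in enumerate(masks):
        h, w = mask.shape
        # Masks are indexed [row, col] = [y, x]; ignore clicks outside the image.
        if 0 <= y < h and 0 <= x < w and mask[y, x] >= 128:
            hits.append(i)
    return hits


a = np.zeros((4, 6), dtype=np.uint8)
a[0:2, 0:3] = 255   # top-left object
b = np.zeros((4, 6), dtype=np.uint8)
b[1:4, 2:6] = 255   # second object, overlapping the first at (x=2, y=1)

print(masks_containing_point([a, b], (2, 1)))  # pixel covered by both masks
print(masks_containing_point([a, b], (5, 3)))  # pixel covered by the second only
```

Note the index order: a click arrives as `(x, y)`, while the array lookup is `mask[y, x]`.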

In [None]:
class YoloSegmenter(SegmentationModel):
    """
    YOLOv8 instance segmentation wrapper for automatic object detection.
    
    Usage:
        yolo = YoloSegmenter(model_name="yolov8n-seg")
        result = yolo.predict(image)  # Returns all detected objects
    """
    
    def __init__(self, model_name: str = "yolov8n-seg", device: str = DEVICE):
        super().__init__(device)
        self.model_name = model_name
        self.conf_threshold = 0.25
    
    def load_model(self) -> None:
        """Load YOLOv8 segmentation model (auto-downloads from Ultralytics)."""
        if not YOLO_AVAILABLE:
            raise ImportError("ultralytics not installed")
        
        # Ultralytics downloads the weights on first use if they are not found locally
        self.model = YOLO(self.model_name)
        self.model.to(self.device)
    
    def predict(
        self, 
        image: np.ndarray, 
        point: Optional[Tuple[int, int]] = None
    ) -> SegmentationResult:
        """
        Predict segmentation masks for all detected objects.
        
        Args:
            image: RGB image (H, W, 3), uint8
            point: Optional (x, y) — if provided, only return object containing that point
        
        Returns:
            SegmentationResult with masks for all detected objects (or filtered by point)
        """
        self.ensure_loaded()
        
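        # Run inference. Ultralytics interprets NumPy input as BGR, so reverse the
        # channel order of the RGB image. retina_masks=True requests masks at the
        # original image resolution instead of the letterboxed inference size.
        bgr = np.ascontiguousarray(image[..., ::-1])
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            results = self.model(bgr, conf=self.conf_threshold, retina_masks=True, verbose=False)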
        
        masks = []
        contours = []
        labels = []
        scores = []
        
        if results and len(results) > 0:
            result = results[0]
            
            if result.masks is not None:
                h, w = image.shape[:2]
                
                for i, mask_data in enumerate(result.masks.data):
                    # Get mask as numpy array
                    mask = mask_data.cpu().numpy()
                    
                    # Resize to original image size if needed
                    if mask.shape != (h, w):
                        from PIL import Image as PILImage
                        mask_pil = PILImage.fromarray((mask * 255).astype(np.uint8))
                        mask_pil = mask_pil.resize((w, h), PILImage.NEAREST)
                        mask = np.array(mask_pil)
                    else:
                        mask = (mask * 255).astype(np.uint8)
                    
                    # Get class info
                    cls_id = int(result.boxes.cls[i])
                    conf = float(result.boxes.conf[i])
                    label = result.names[cls_id]
                    
                    # If point specified, check if this mask contains the point
                    if point is not None:
                        px, py = point
                        if 0 <= py < h and 0 <= px < w:
                            if mask[py, px] < 128:
                                continue  # Skip if point not in this mask
                    
                    # Extract contours
                    mask_contours = self.extract_contours(mask)
                    
                    masks.append(mask)
                    contours.extend(mask_contours)
                    labels.append(label)
                    scores.append(conf)
        
        return SegmentationResult(
            masks=masks,
            contours=contours,
            labels=labels,
            scores=scores,
        )


if YOLO_AVAILABLE:
    print("YoloSegmenter class defined.")
else:
    print("⚠ YoloSegmenter unavailable (ultralytics not installed)")

## 5. Contour-to-Lines Pipeline (ESD Integration)

This section converts segmentation contours to line segments using the LineExtraction framework.

**Pipeline:**
1. **Contour Simplification** — Douglas-Peucker algorithm to reduce point count
2. **ESD Edge Segments** — Use `le_edge.EsdDrawing` on contour image
3. **Line Fitting** — Extract line segments with `le_lsd` or direct fitting

The key insight: neural segmentation provides accurate object boundaries, while ESD/LSD
provide robust line extraction with sub-pixel precision.
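
To make the simplification step concrete, here is a small pure-Python sketch of Douglas-Peucker on a jittery L-shaped polyline. It is independent of the notebook's `ContourToLinesESD` class: the function name and the fixed `epsilon` are assumptions for illustration, whereas the notebook derives epsilon as a ratio of the contour perimeter. The sketch also handles the case where the first and last point coincide, as they do for closed contours.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def douglas_peucker(points: List[Point], epsilon: float) -> List[Point]:
    """Simplify a polyline; both endpoints are always kept."""
    if len(points) <= 2:
        return list(points)
    (x0, y0), (x1, y1) = points[0], points[-1]
    chord = math.hypot(x1 - x0, y1 - y0)

    def dist(p: Point) -> float:
        if chord < 1e-12:  # endpoints coincide: use distance to the start point
            return math.hypot(p[0] - x0, p[1] - y0)
        # Perpendicular distance to the chord via the 2D cross product
        return abs((p[0] - x0) * (y1 - y0) - (p[1] - y0) * (x1 - x0)) / chord

    idx = max(range(len(points)), key=lambda i: dist(points[i]))
    if dist(points[idx]) <= epsilon:
        return [points[0], points[-1]]
    # Split at the farthest point and simplify both halves.
    left = douglas_peucker(points[: idx + 1], epsilon)
    right = douglas_peucker(points[idx:], epsilon)
    return left[:-1] + right


# An L-shaped polyline with redundant, slightly jittered points.
polyline = [(0, 0), (1, 0.05), (2, -0.05), (3, 0), (3, 1), (3.05, 2), (3, 3)]
print(douglas_peucker(polyline, epsilon=0.2))

# A closed square (first point repeated at the end) keeps all its corners.
square = [(0, 0), (2, 0), (2, 2), (0, 2), (0, 0)]
print(douglas_peucker(square, epsilon=0.1))
```

The jitter of 0.05 is below the tolerance, so only the corner at `(3, 0)` survives between the two endpoints. Each consecutive pair of surviving points then becomes one candidate line segment.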

In [None]:
@dataclass
class LineExtractionResult:
    """Result from line extraction pipeline."""
    contours: List[np.ndarray]  # Original contours (x, y) points
    simplified_contours: List[np.ndarray]  # Douglas-Peucker simplified
    edge_segments: List[np.ndarray]  # ESD edge segment points
    line_segments: List[Tuple[Tuple[float, float], Tuple[float, float]]]  # Line endpoints


class ContourToLinesESD:
    """
    Pipeline to convert segmentation contours to line segments using ESD.
    
    Two approaches are supported:
    1. Direct contour simplification (Douglas-Peucker)
    2. ESD-based edge segment detection on contour image
    """
    
    def __init__(
        self,
        epsilon_ratio: float = 0.01,  # Douglas-Peucker epsilon as ratio of perimeter
        min_line_length: float = 10.0,  # Minimum line segment length in pixels
        esd_min_pixels: int = 5,  # Minimum pixels for ESD segment
    ):
        self.epsilon_ratio = epsilon_ratio
        self.min_line_length = min_line_length
        self.esd_min_pixels = esd_min_pixels
    
    def simplify_contour(self, contour: np.ndarray) -> np.ndarray:
        """
        Simplify contour using Douglas-Peucker algorithm.
        
        Args:
            contour: Contour points (N, 2) as (x, y)
        
        Returns:
            Simplified contour points
        """
        if len(contour) < 3:
            return contour
        
        # Calculate perimeter
        diffs = np.diff(contour, axis=0, append=contour[:1])
        perimeter = np.sum(np.linalg.norm(diffs, axis=1))
        
        # Epsilon as ratio of perimeter
        epsilon = self.epsilon_ratio * perimeter
        
        # Simple recursive Douglas-Peucker
        return self._douglas_peucker(contour, epsilon)
    
    def _douglas_peucker(self, points: np.ndarray, epsilon: float) -> np.ndarray:
        """Recursive Douglas-Peucker simplification."""
        if len(points) <= 2:
            return points
        
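        # Find point with maximum distance from line between first and last
        start, end = points[0], points[-1]
        line_vec = end - start
        line_len = np.linalg.norm(line_vec)
        offsets = points - start
        
        if line_len < 1e-10:
            # Closed contour (first == last): there is no chord to measure against,
            # so fall back to the distance from the start point
            dists = np.linalg.norm(offsets, axis=1)
        else:
            line_unit = line_vec / line_len
            # Perpendicular distance via the 2D cross product, written out
            # explicitly because np.cross on 2-vectors is deprecated in NumPy 2.0
            dists = np.abs(offsets[:, 0] * line_unit[1] - offsets[:, 1] * line_unit[0])
        max_idx = np.argmax(dists)
        max_dist = dists[max_idx]
        
        if max_dist > epsilon:
            # Recursively simplify both halves, sharing the split point
            left = self._douglas_peucker(points[:max_idx + 1], epsilon)
            right = self._douglas_peucker(points[max_idx:], epsilon)
            return np.vstack([left[:-1], right])
        else:
            return np.array([start, end])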
    
    def contour_to_edge_image(
        self, 
        contours: List[np.ndarray], 
        shape: Tuple[int, int]
    ) -> np.ndarray:
        """
        Render contours as a binary edge image.
        
        Args:
            contours: List of contour arrays (N, 2) as (x, y)
            shape: Image shape (H, W)
        
        Returns:
            Binary edge image (H, W), uint8
        """
        edge_img = np.zeros(shape, dtype=np.uint8)
        
        for contour in contours:
            pts = contour.astype(np.int32)
            for i in range(len(pts)):
                p1 = pts[i]
                p2 = pts[(i + 1) % len(pts)]
                # Draw line using Bresenham
                self._draw_line(edge_img, p1, p2)
        
        return edge_img
    
    def _draw_line(self, img: np.ndarray, p1: np.ndarray, p2: np.ndarray) -> None:
        """Draw line on image using Bresenham's algorithm."""
        x0, y0 = p1
        x1, y1 = p2
        h, w = img.shape
        
        dx = abs(x1 - x0)
        dy = -abs(y1 - y0)
        sx = 1 if x0 < x1 else -1
        sy = 1 if y0 < y1 else -1
        err = dx + dy
        
        while True:
            if 0 <= x0 < w and 0 <= y0 < h:
                img[y0, x0] = 255
            
            if x0 == x1 and y0 == y1:
                break
            
            e2 = 2 * err
            if e2 >= dy:
                err += dy
                x0 += sx
            if e2 <= dx:
                err += dx
                y0 += sy
    
    def extract_lines_from_contours(
        self, 
        contours: List[np.ndarray],
        image_shape: Tuple[int, int]
    ) -> LineExtractionResult:
        """
        Extract line segments from contours using ESD.
        
        Args:
            contours: List of contour arrays from segmentation
            image_shape: Original image shape (H, W)
        
        Returns:
            LineExtractionResult with all intermediate and final results
        """
        if not contours:
            return LineExtractionResult(
                contours=[],
                simplified_contours=[],
                edge_segments=[],
                line_segments=[],
            )
        
        # 1. Simplify contours using Douglas-Peucker
        simplified = [self.simplify_contour(c) for c in contours]
        
        # 2. Create edge image from contours
        edge_img = self.contour_to_edge_image(contours, image_shape)
        
        # 3. Use EdgeSource on the edge image to get gradients
        edge_source = le_imgproc.SobelGradient()
        edge_source.process(edge_img)
        
        # 4. Run ESD on the edge image
        try:
            esd = le_edge.EsdDrawing(min_pixels=self.esd_min_pixels)
            
            # For ESD we need: direction_map, magnitude, seeds
            # Create a simple NMS to get seeds
            nms = le_edge.NmsPatternFull()
            nms.process(edge_source)
            
            # Get NMS results
            nms_img = nms.img_nms()
            
            # Detect edge segments
            esd.detect(nms, edge_source)
            
            # Get segment points
            points = esd.points()
            segments = esd.segments()
            
            edge_segment_list = []
            for seg in segments:
                # Extract points for this segment
                seg_points = []
                for idx in range(seg.begin, seg.end):
                    if idx < len(points):
                        pt = points[idx]
                        # Convert linear index to (x, y)
                        y = pt // image_shape[1]
                        x = pt % image_shape[1]
                        seg_points.append([x, y])
                if seg_points:
                    edge_segment_list.append(np.array(seg_points))
            
        except Exception as e:
            print(f"ESD processing failed: {e}")
            edge_segment_list = []
        
        # 5. Extract line segments from simplified contours
        line_segments = []
        for simp in simplified:
            for i in range(len(simp) - 1):
                p1 = tuple(simp[i])
                p2 = tuple(simp[i + 1])
                
                # Filter by length
                length = np.linalg.norm(np.array(p2) - np.array(p1))
                if length >= self.min_line_length:
                    line_segments.append((p1, p2))
            
            # Close the contour if needed
            if len(simp) >= 3:
                p1 = tuple(simp[-1])
                p2 = tuple(simp[0])
                length = np.linalg.norm(np.array(p2) - np.array(p1))
                if length >= self.min_line_length:
                    line_segments.append((p1, p2))
        
        return LineExtractionResult(
            contours=contours,
            simplified_contours=simplified,
            edge_segments=edge_segment_list,
            line_segments=line_segments,
        )


# Instantiate pipeline
contour_pipeline = ContourToLinesESD()
print("ContourToLinesESD pipeline defined.")

## 6. Visualization Utilities

Helper functions for displaying results at each stage of the pipeline.
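
The mask overlay used below is a per-pixel alpha blend: inside the mask, each output pixel is `(1 - alpha) * original + alpha * color`. The following is a small self-contained sketch of that arithmetic on a 2x2 image; `blend_mask` is a hypothetical name, not the notebook's `overlay_mask`.

```python
import numpy as np


def blend_mask(image: np.ndarray, mask: np.ndarray, color, alpha: float) -> np.ndarray:
    """Alpha-blend `color` into a copy of `image` wherever the 0/255 `mask` is set."""
    out = image.copy()
    inside = mask > 127
    tint = np.asarray(color, dtype=np.float32)
    # Blend in float, then clip and cast back to uint8 (fractions are truncated).
    blended = (1.0 - alpha) * out[inside].astype(np.float32) + alpha * tint
    out[inside] = np.clip(blended, 0, 255).astype(np.uint8)
    return out


image = np.full((2, 2, 3), 100, dtype=np.uint8)      # uniform grey image
mask = np.array([[255, 0], [0, 0]], dtype=np.uint8)  # only the top-left pixel
result = blend_mask(image, mask, color=(0, 200, 255), alpha=0.5)
print(result[0, 0], result[1, 1])
```

Pixels outside the mask are left untouched, and the input image is not modified because the blend works on a copy.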

In [None]:
def draw_contours_on_image(
    image: np.ndarray,
    contours: List[np.ndarray],
    color: Tuple[int, int, int] = (0, 255, 0),
    thickness: int = 2,
) -> np.ndarray:
    """Draw contours on image copy."""
    result = image.copy()
    
    for contour in contours:
        pts = contour.astype(np.int32)
        for i in range(len(pts)):
            p1 = tuple(pts[i])
            p2 = tuple(pts[(i + 1) % len(pts)])
            # Simple line drawing
            draw_line_on_image(result, p1, p2, color, thickness)
    
    return result


def draw_line_on_image(
    img: np.ndarray,
    p1: Tuple[int, int],
    p2: Tuple[int, int],
    color: Tuple[int, int, int],
    thickness: int = 1,
) -> None:
    """Draw a line on image using simple algorithm."""
    x0, y0 = int(p1[0]), int(p1[1])
    x1, y1 = int(p2[0]), int(p2[1])
    h, w = img.shape[:2]
    
    # Bresenham with thickness: stamp a thickness x thickness square at each step
    half = thickness // 2
    dx = abs(x1 - x0)
    dy = -abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx + dy
    
    while True:
        for tx in range(-half, thickness - half):
            for ty in range(-half, thickness - half):
                px, py = x0 + tx, y0 + ty
                if 0 <= px < w and 0 <= py < h:
                    img[py, px] = color
        
        if x0 == x1 and y0 == y1:
            break
        
        e2 = 2 * err
        if e2 >= dy:
            err += dy
            x0 += sx
        if e2 <= dx:
            err += dx
            y0 += sy


def draw_lines_on_image(
    image: np.ndarray,
    lines: List[Tuple[Tuple[float, float], Tuple[float, float]]],
    color: Tuple[int, int, int] = (255, 0, 0),
    thickness: int = 2,
) -> np.ndarray:
    """Draw line segments on image copy."""
    result = image.copy()
    
    for p1, p2 in lines:
        draw_line_on_image(result, (int(p1[0]), int(p1[1])), (int(p2[0]), int(p2[1])), color, thickness)
    
    return result


def overlay_mask(
    image: np.ndarray,
    mask: np.ndarray,
    color: Tuple[int, int, int] = (0, 120, 255),
    alpha: float = 0.4,
) -> np.ndarray:
    """Overlay a binary mask on image with transparency."""
    result = image.copy()
    
    # Create colored overlay
    overlay = np.zeros_like(result)
    mask_bool = mask > 127
    overlay[mask_bool] = color
    
    # Blend
    result[mask_bool] = (
        (1 - alpha) * result[mask_bool] + alpha * overlay[mask_bool]
    ).astype(np.uint8)
    
    return result


def visualize_pipeline_result(
    image: np.ndarray,
    seg_result: SegmentationResult,
    line_result: LineExtractionResult,
    figsize: Tuple[int, int] = (16, 4),
) -> None:
    """
    Visualize the complete pipeline: Original → Segmentation → Contours → Lines
    """
    fig, axes = plt.subplots(1, 4, figsize=figsize)
    
    # 1. Original image
    axes[0].imshow(image)
    axes[0].set_title("Original Image")
    axes[0].axis("off")
    
    # 2. Segmentation mask overlay
    if seg_result.masks:
        masked = image.copy()
        colors = [(255, 100, 100), (100, 255, 100), (100, 100, 255), (255, 255, 100)]
        for i, mask in enumerate(seg_result.masks):
            color = colors[i % len(colors)]
            masked = overlay_mask(masked, mask, color, alpha=0.4)
        axes[1].imshow(masked)
        title = "Segmentation"
        if seg_result.labels:
            title += f"\n{', '.join(seg_result.labels[:3])}"
    else:
        axes[1].imshow(image)
        title = "No Segmentation"
    axes[1].set_title(title)
    axes[1].axis("off")
    
    # 3. Contours
    if line_result.contours:
        contour_img = draw_contours_on_image(image, line_result.contours, (0, 255, 0), 2)
        axes[2].imshow(contour_img)
        axes[2].set_title(f"Contours ({len(line_result.contours)})")
    else:
        axes[2].imshow(image)
        axes[2].set_title("No Contours")
    axes[2].axis("off")
    
    # 4. Line segments
    if line_result.line_segments:
        line_img = draw_lines_on_image(image, line_result.line_segments, (255, 0, 0), 2)
        axes[3].imshow(line_img)
        axes[3].set_title(f"Lines ({len(line_result.line_segments)})")
    else:
        axes[3].imshow(image)
        axes[3].set_title("No Lines")
    axes[3].axis("off")
    
    plt.tight_layout()
    plt.show()


print("Visualization utilities defined.")

## 7. Interactive Demo Application

The interactive application provides:
- **Image selection** from BSDS500/MDB datasets
- **Model switching** between SAM and YOLO
- **Click-to-segment** for SAM (click on any object)
- **Auto-detect** for YOLO (detects all objects)
- **Real-time line extraction** on segmented objects
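
In Matplotlib, click handlers are typically registered with `fig.canvas.mpl_connect("button_press_event", handler)`, and the event carries floating-point data coordinates in `event.xdata` and `event.ydata` (both `None` when the click lands outside the axes). Before such a click can be used as a segmentation prompt it has to become an integer pixel index. The sketch below shows one way to do that; `click_to_pixel` is a hypothetical helper, and it assumes the image is shown with `imshow` and its default extent, where pixel centres sit on integer coordinates.

```python
from typing import Optional, Tuple


def click_to_pixel(
    xdata: Optional[float], ydata: Optional[float], width: int, height: int
) -> Optional[Tuple[int, int]]:
    """Map click data coordinates to an in-bounds (x, y) pixel, or None."""
    if xdata is None or ydata is None:  # click landed outside the axes
        return None
    # Pixel centres sit on integer coordinates, so rounding (not truncating)
    # picks the pixel under the cursor.
    x = int(round(xdata))
    y = int(round(ydata))
    # Clamp so clicks on the outer half-pixel border stay inside the image.
    x = min(max(x, 0), width - 1)
    y = min(max(y, 0), height - 1)
    return (x, y)


print(click_to_pixel(12.6, 3.2, width=481, height=321))
print(click_to_pixel(480.7, -0.4, width=481, height=321))
print(click_to_pixel(None, 5.0, width=481, height=321))
```

Keeping the conversion in a plain function makes it easy to test without opening a figure.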

In [None]:
class InteractiveSegmentationDemo:
    """
    Interactive demo combining PyTorch segmentation with ESD line extraction.
    
    Features:
    - Switch between SAM and YOLO models
    - Click on image to segment (SAM mode)
    - Automatic segmentation (YOLO mode)
    - Real-time line extraction from contours
    """
    
    def __init__(self):
        self.test_images = TestImages()
        self.current_image: Optional[np.ndarray] = None
        self.current_path: Optional[pathlib.Path] = None
        
        # Models (lazy-loaded)
        self._sam: Optional[SamSegmenter] = None
        self._yolo: Optional[YoloSegmenter] = None
        self.current_model = "yolo"  # Default to YOLO (no click required)
        
        # Results
        self.seg_result: Optional[SegmentationResult] = None
        self.line_result: Optional[LineExtractionResult] = None
        
        # Pipeline
        self.pipeline = ContourToLinesESD()
        
        # UI elements
        self.fig = None
        self.ax = None
        self.output = widgets.Output()
        
    @property
    def sam(self) -> SamSegmenter:
        if self._sam is None:
            if not SAM_AVAILABLE:
                raise ImportError("SAM not available")
            self._sam = SamSegmenter()
        return self._sam
    
    @property
    def yolo(self) -> YoloSegmenter:
        if self._yolo is None:
            if not YOLO_AVAILABLE:
                raise ImportError("YOLO not available")
            self._yolo = YoloSegmenter()
        return self._yolo
    
    def load_image(self, path: Union[str, pathlib.Path]) -> np.ndarray:
        """Load image from path."""
        path = pathlib.Path(path)
        img = Image.open(path).convert("RGB")
        self.current_image = np.array(img)
        self.current_path = path
        
        # Reset results
        self.seg_result = None
        self.line_result = None
        
        return self.current_image
    
    def get_available_images(self, dataset: str = "bsds500", limit: int = 20) -> List[pathlib.Path]:
        """Get list of available images from dataset."""
        if dataset == "bsds500":
            return list(self.test_images.bsds500())[:limit]
        elif dataset == "mdb":
            scenes = list(self.test_images.stereo_scenes("Q"))[:limit]
            return [self.test_images.stereo_pair(s, "Q")[0] for s in scenes]
        elif dataset == "noise":
            return list(self.test_images.noise_images())
        else:
            return []
    
    def segment(self, point: Optional[Tuple[int, int]] = None) -> SegmentationResult:
        """Run segmentation on current image."""
        if self.current_image is None:
            raise ValueError("No image loaded")
        
        if self.current_model == "sam":
            if point is None:
                return SegmentationResult(masks=[], contours=[], labels=[], scores=[])
            self.seg_result = self.sam.predict(self.current_image, point)
        else:  # yolo
            self.seg_result = self.yolo.predict(self.current_image, point)
        
        return self.seg_result
    
    def extract_lines(self) -> LineExtractionResult:
        """Extract lines from segmentation result."""
        if self.seg_result is None or not self.seg_result.contours:
            self.line_result = LineExtractionResult(
                contours=[], simplified_contours=[], edge_segments=[], line_segments=[]
            )
        else:
            h, w = self.current_image.shape[:2]
            self.line_result = self.pipeline.extract_lines_from_contours(
                self.seg_result.contours, (h, w)
            )
        
        return self.line_result
    
    def run_pipeline(self, point: Optional[Tuple[int, int]] = None) -> None:
        """Run full pipeline and display results."""
        self.segment(point)
        self.extract_lines()
        
        with self.output:
            clear_output(wait=True)
            if self.current_image is not None and self.seg_result is not None:
                visualize_pipeline_result(
                    self.current_image,
                    self.seg_result,
                    self.line_result,
                )
    
    def on_click(self, event):
        """Handle mouse click on image."""
        if event.inaxes != self.ax:
            return
        
        x, y = int(event.xdata), int(event.ydata)
        print(f"Click at ({x}, {y})")
        
        self.run_pipeline(point=(x, y))
    
    def create_ui(self) -> widgets.Widget:
        """Create the interactive UI."""
        # Dataset selector
        dataset_dropdown = widgets.Dropdown(
            options=["bsds500", "mdb", "noise"],
            value="bsds500",
            description="Dataset:",
        )
        
        # Image selector (populated dynamically)
        image_dropdown = widgets.Dropdown(
            options=[],
            description="Image:",
        )
        
        # Model selector
        model_toggle = widgets.ToggleButtons(
            options=["YOLO (Auto)", "SAM (Click)"],
            value="YOLO (Auto)",
            description="Model:",
        )
        
        # Buttons
        load_btn = widgets.Button(description="Load Image", button_style="primary")
        detect_btn = widgets.Button(description="Detect Objects", button_style="success")
        
        # Status
        status = widgets.HTML(value="<i>Select an image to begin</i>")
        
        def update_images(change):
            paths = self.get_available_images(change["new"])
            image_dropdown.options = [(p.name, p) for p in paths]
        
        def on_load(btn):
            path = image_dropdown.value
            if path:
                self.load_image(path)
                status.value = f"<b>Loaded:</b> {path.name} ({self.current_image.shape[1]}x{self.current_image.shape[0]})"
                
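                # Show image and route clicks on it to on_click (used for SAM prompts)
                with self.output:
                    clear_output(wait=True)
                    fig, self.ax = plt.subplots(figsize=(8, 6))
                    self.ax.imshow(self.current_image)
                    self.ax.set_title(f"{path.name} - Click to segment (SAM) or press Detect (YOLO)")
                    self.ax.axis("off")
                    fig.canvas.mpl_connect("button_press_event", self.on_click)
                    plt.show()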
        
        def on_detect(btn):
            if self.current_image is None:
                status.value = "<span style='color:red'>Load an image first!</span>"
                return
            
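            if self.current_model == "sam":
                # SAM is prompted by a click, so the button has nothing to run.
                status.value = "<i>SAM selected: click an object in the image to segment it</i>"
                return
            
            status.value = "<i>Processing...</i>"
            self.run_pipeline()
            status.value = f"<b>Done:</b> {len(self.seg_result.masks)} objects, {len(self.line_result.line_segments)} lines"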
        
        def on_model_change(change):
            if "Auto" in change["new"]:
                self.current_model = "yolo"
            else:
                self.current_model = "sam"
        
        # Connect callbacks
        dataset_dropdown.observe(update_images, names="value")
        load_btn.on_click(on_load)
        detect_btn.on_click(on_detect)
        model_toggle.observe(on_model_change, names="value")
        
        # Initialize image list
        update_images({"new": "bsds500"})
        
        # Layout
        controls = widgets.VBox([
            widgets.HBox([dataset_dropdown, image_dropdown, load_btn]),
            widgets.HBox([model_toggle, detect_btn]),
            status,
        ])
        
        return widgets.VBox([controls, self.output])


# Create demo instance
demo = InteractiveSegmentationDemo()
print("InteractiveSegmentationDemo created.")

### Run the Interactive Demo

Execute the cell below to launch the interactive UI. 

**How to use:**
1. Select a dataset (BSDS500 recommended for diverse images)
2. Choose an image from the dropdown
3. Click **Load Image** to display it
4. Choose a model: **YOLO** for automatic detection, **SAM** for click-to-segment
5. Run the pipeline: click **Detect Objects** for YOLO, or click an object in the image for SAM
6. View results: Original → Segmentation → Contours → Lines
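
For the SAM path, the click has to be turned into a pixel prompt. A minimal sketch of that conversion, assuming a matplotlib-style event with `inaxes`, `xdata`, and `ydata` attributes; `click_to_pixel` is a hypothetical helper, not part of the notebook:

```python
from types import SimpleNamespace
from typing import Optional, Tuple


def click_to_pixel(event, ax, shape: Tuple[int, int]) -> Optional[Tuple[int, int]]:
    """Convert a click event to an (x, y) pixel prompt, or None if it missed.

    `event` only needs `inaxes`, `xdata`, `ydata` (as on a matplotlib
    MouseEvent); `shape` is the image (height, width).
    """
    if event.inaxes is not ax or event.xdata is None or event.ydata is None:
        return None  # click landed outside the image axes
    h, w = shape
    # Round to the nearest pixel center, then clamp to valid indices
    x = min(max(int(round(event.xdata)), 0), w - 1)
    y = min(max(int(round(event.ydata)), 0), h - 1)
    return (x, y)


# Stand-in objects so the sketch runs without a figure
ax = object()  # plays the role of the Axes showing the image
event = SimpleNamespace(inaxes=ax, xdata=10.6, ydata=3.2)
print(click_to_pixel(event, ax, shape=(5, 20)))  # (11, 3)
```

Rounding picks the pixel whose center is nearest to the click, since `imshow` places pixel centers on integer coordinates by default. The clamp covers clicks on the half-pixel border, which would otherwise round out of range.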

In [None]:
# Launch the interactive demo
ui = demo.create_ui()
display(ui)

## 8. Quick Demo (Non-Interactive)

For testing without the interactive UI, run this cell to process a single image.

In [None]:
# Quick demo: Load a test image and run the full pipeline

# Get first available image
test_images = TestImages()
available = list(test_images.bsds500())

if available:
    # Load image
    img_path = available[0]
    print(f"Loading: {img_path.name}")
    
    image = np.array(Image.open(img_path).convert("RGB"))
    print(f"Image shape: {image.shape}")
    
    # Run YOLO segmentation (if available)
    if YOLO_AVAILABLE:
        print("\nRunning YOLO segmentation...")
        yolo = YoloSegmenter()
        seg_result = yolo.predict(image)
        print(f"Detected {len(seg_result.masks)} objects: {seg_result.labels}")
        
        # Extract lines
        print("\nExtracting lines...")
        pipeline = ContourToLinesESD()
        line_result = pipeline.extract_lines_from_contours(
            seg_result.contours, 
            image.shape[:2]
        )
        print(f"Extracted {len(line_result.line_segments)} line segments")
        
        # Visualize
        visualize_pipeline_result(image, seg_result, line_result)
    elif SAM_AVAILABLE:
        print("\nYOLO not available, using SAM with center point...")
        h, w = image.shape[:2]
        center = (w // 2, h // 2)
        
        sam = SamSegmenter()
        seg_result = sam.predict(image, point=center)
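        # Guard against an empty result before indexing the scores
        if seg_result.scores:
            print(f"Segmented object at center: score={seg_result.scores[0]:.2f}")
        else:
            print("SAM returned no mask for the center point")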
        
        # Extract lines
        pipeline = ContourToLinesESD()
        line_result = pipeline.extract_lines_from_contours(
            seg_result.contours,
            image.shape[:2]
        )
        print(f"Extracted {len(line_result.line_segments)} line segments")
        
        visualize_pipeline_result(image, seg_result, line_result)
    else:
        print("⚠ Neither YOLO nor SAM available. Install: uv pip install ultralytics segment-anything torch")
else:
    print("⚠ No test images found. Ensure BSDS500 dataset is available.")

## Summary

This notebook demonstrated how PyTorch-based object segmentation integrates with the
LineExtraction ESD framework.

### Pipeline Components

| Stage | Component | Description |
|-------|-----------|-------------|
| **Segmentation** | SAM / YOLO | Neural network generates pixel-accurate object masks |
| **Contour Extraction** | skimage | Marching squares extracts polygon contours from masks |
| **Simplification** | Douglas-Peucker | Reduces contour points while preserving shape |
| **Line Extraction** | ESD + LSD | LineExtraction framework converts contours to line segments |
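
The simplification stage can be illustrated with a small pure-Python Douglas-Peucker routine. This is a sketch of the general algorithm under the assumption that a contour is an ordered list of `(x, y)` points; it is not necessarily the routine the notebook calls, and `simplify_polyline` is a hypothetical name:

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def _distance_to_chord(p: Point, a: Point, b: Point) -> float:
    """Perpendicular distance from p to the line through a and b."""
    dx, dy = b[0] - a[0], b[1] - a[1]
    length = math.hypot(dx, dy)
    if length == 0.0:
        # Degenerate chord (a == b): fall back to point distance
        return math.hypot(p[0] - a[0], p[1] - a[1])
    return abs(dx * (a[1] - p[1]) - dy * (a[0] - p[0])) / length


def simplify_polyline(points: List[Point], epsilon: float) -> List[Point]:
    """Drop points that lie within epsilon of the simplified polyline."""
    if len(points) < 3:
        return list(points)
    a, b = points[0], points[-1]
    # Find the interior point farthest from the chord a-b
    index, dmax = 0, 0.0
    for i in range(1, len(points) - 1):
        d = _distance_to_chord(points[i], a, b)
        if d > dmax:
            index, dmax = i, d
    if dmax <= epsilon:
        return [a, b]  # everything in between is close enough to the chord
    # Otherwise keep the farthest point and recurse on both halves
    left = simplify_polyline(points[: index + 1], epsilon)
    right = simplify_polyline(points[index:], epsilon)
    return left[:-1] + right  # avoid duplicating the split point


contour = [(0, 0), (1, 0.05), (2, 0), (3, 0.02), (4, 0), (4, 1), (4, 2)]
print(simplify_polyline(contour, epsilon=0.1))  # [(0, 0), (4, 0), (4, 2)]
```

A larger `epsilon` removes more points at the cost of shape fidelity. The corner at `(4, 0)` survives because it lies far from the chord between the endpoints.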

### Key Classes

- `SegmentationModel` — Abstract base for segmentation models
- `SamSegmenter` — Segment Anything Model wrapper (click-to-segment)
- `YoloSegmenter` — YOLOv8 instance segmentation (automatic)
- `ContourToLinesESD` — Pipeline connecting segmentation to LineExtraction
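
The relationship between the base class and a concrete model can be sketched with stand-in definitions. The field names and the `predict(image, point)` call shape follow how the notebook uses them; everything else, including `FullFrameSegmenter` and the use of plain lists instead of arrays, is an assumption made to keep the example dependency-free:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, Optional, Tuple


@dataclass
class SegmentationResult:
    """Stand-in carrying the four fields used in the notebook."""
    masks: list = field(default_factory=list)
    contours: list = field(default_factory=list)
    labels: List[str] = field(default_factory=list)
    scores: List[float] = field(default_factory=list)


class SegmentationModel(ABC):
    """Stand-in for the abstract base (assumed interface)."""

    @abstractmethod
    def predict(self, image, point: Optional[Tuple[int, int]] = None) -> SegmentationResult:
        ...


class FullFrameSegmenter(SegmentationModel):
    """Hypothetical placeholder model: treats the whole image as one object."""

    def predict(self, image, point=None) -> SegmentationResult:
        h, w = image.shape[:2]
        mask = [[True] * w for _ in range(h)]
        # Rectangle around the image border, as (x, y) corner points
        contour = [(0, 0), (w - 1, 0), (w - 1, h - 1), (0, h - 1)]
        return SegmentationResult(
            masks=[mask], contours=[contour], labels=["frame"], scores=[1.0]
        )


image = SimpleNamespace(shape=(4, 6, 3))  # anything with a .shape works here
result = FullFrameSegmenter().predict(image)
print(result.contours[0])  # [(0, 0), (5, 0), (5, 3), (0, 3)]
```

A placeholder like this can be useful for exercising the contour-to-lines stage without downloading model weights.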

### Extensions

1. **Add more models:** Implement `SegmentationModel` for other architectures (DeepLabV3, Mask R-CNN)
2. **Improve line fitting:** Use `le_lsd` detectors on edge images for sub-pixel accuracy
3. **Batch processing:** Process multiple images and aggregate statistics
4. **Export results:** Save line segments as SVG or other vector formats
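
As a starting point for the export idea, line segments can be written as SVG `<line>` elements using only the standard library. The sketch below assumes each segment is an `(x1, y1, x2, y2)` tuple in pixel coordinates, which may differ from the objects in `line_result.line_segments`; `segments_to_svg` is a hypothetical helper:

```python
from typing import Iterable, Tuple

Segment = Tuple[float, float, float, float]  # assumed layout: (x1, y1, x2, y2)


def segments_to_svg(
    segments: Iterable[Segment],
    width: int,
    height: int,
    stroke: str = "red",
    stroke_width: float = 1.0,
) -> str:
    """Render line segments as an SVG document sized to the source image."""
    parts = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" '
        f'height="{height}" viewBox="0 0 {width} {height}">'
    ]
    for x1, y1, x2, y2 in segments:
        parts.append(
            f'  <line x1="{x1:.2f}" y1="{y1:.2f}" x2="{x2:.2f}" y2="{y2:.2f}" '
            f'stroke="{stroke}" stroke-width="{stroke_width}"/>'
        )
    parts.append("</svg>")
    return "\n".join(parts)


svg = segments_to_svg([(0, 0, 10, 5), (3, 4, 3, 9)], width=20, height=10)
print(svg)
```

Writing the returned string to a `.svg` file gives a vector overlay with the same coordinate system as the image, because the `viewBox` matches the image size.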

### Dependencies

```bash
# Core
bazel build //libs/...  # Build LE Python bindings

# PyTorch ecosystem
uv pip install torch torchvision segment-anything ultralytics
```