
# Task Sampler — Visual Sanity Check

Edit the CONFIG below, then run all cells.  
This will:
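1. Load your **ImageMetadata** via `data.dataset.get_image_metadata(...)`.
2. Build a **RamRaysDataset** from the (capped) training metadata; if its imports do not resolve, the notebook stops at that cell.
3. Sample tasks for the configured region with **TaskDataset** (`data.task_dataset`), the notebook's region task sampler.
4. **Display the RGB images** selected for support and query for the first few tasks, then overlay the sampled rays on those images.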

> If `RamRaysDataset` fails to import because of `nerfs.ray_sampling`, the import cell raises and the notebook stops there; make sure the project root is on your Python path (the CONFIG cell sets this up). This notebook does not redefine any project classes; everything is imported from your files.


In [None]:
# =====================
# CONFIG
# =====================
# Bootstrap the working directory and import path, then declare the
# sampler/task parameters used by the cells below.
import math
import os
import sys
from pathlib import Path

# Author's checkout: when the kernel already starts here, nothing is touched.
_KNOWN_PROJECT_ROOT = "/mnt/nas_drive/psklavos/crexdata/MCLNF-FDA"

# Project root = one level above this notebook's directory
# (project_dir/jupyter -> project_dir). It is resolved only once per kernel:
# recomputing it from the *current* cwd on a re-run would climb one more
# directory every time this cell is executed.
if "ROOT" not in globals():
    if "__file__" in globals():
        ROOT = Path(__file__).resolve().parents[1]
    else:
        ROOT = Path(os.getcwd()).parents[0]

if os.getcwd() != _KNOWN_PROJECT_ROOT:
    if os.getcwd() != str(ROOT):
        os.chdir(ROOT)
    # Avoid stacking duplicate entries when the cell is re-run.
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

print("CWD set to:", os.getcwd())


# Sampler/task params
# NOTE(review): later cells assign NUM_TASKS_TO_SHOW, RAYS_SUPPORT and
# RAYS_QUERY again, so the values below only apply until those cells run.
NUM_TASKS_TO_SHOW = 4
S = 3
Q = 2
REGION_ID = 1
RAYS_SUPPORT = 1000
RAYS_QUERY = 1000

In [None]:
import torch
import matplotlib.pyplot as plt
from PIL import Image
from data.dataset import get_image_metadata, cap_metadata

# Dataset location and per-region mask directory.
DATA_PATH = "data/drz/out/prepared"  # e.g., "/data/out/partner_germany_site_01"
SCALE = .25            # 1.0 => original resolution
# Masks for the configured region live under DATA_PATH/masks/<mask set>/<REGION_ID>.
MASK_DIR = Path(DATA_PATH) / 'masks' / 'g22_kmeans_bm110_ss13' / str(REGION_ID)

# Load metadata (train & val)
train_md, val_md = get_image_metadata(DATA_PATH, SCALE, mask_dir=MASK_DIR)

print(f"Loaded metadata: train={len(train_md)} images, val={len(val_md)} images")

# Show one example path; prefer train, fall back to val, and do not crash
# with an IndexError when both splits came back empty.
if len(train_md) > 0:
    _example_path = train_md[0].image_path
elif len(val_md) > 0:
    _example_path = val_md[0].image_path
else:
    _example_path = "n/a (no images found)"
print(f"Example image path: {_example_path}")

In [None]:
from data.ram_rays_dataset import RamRaysDataset

# Build the in-RAM ray dataset from a capped subset of the training images.
MAX_IMAGES = 250
train_md = cap_metadata(train_md, MAX_IMAGES)
kwargs = dict(center_pixels=True, device=torch.device("cpu"))
train_ds = RamRaysDataset(metadata_items=train_md, **kwargs)

# `use_ds` is the name every later cell reads the dataset through.
# (The former `train_ds if train_ds is not None else None` was a no-op.)
use_ds = train_ds
print(f"Using dataset with {len(use_ds):,} rays from {use_ds._num_images} images.")

# Spot-check one ray and the image it belongs to. Both values are printed
# (previously the first expression was evaluated and silently discarded), and
# the probe index is clamped so small datasets do not raise IndexError.
# NOTE(review): assumes len(use_ds) is the number of rows in `_rays` — confirm.
if len(use_ds) > 0:
    _probe = min(1000, len(use_ds) - 1)
    print(f"ray[{_probe}] =", use_ds._rays[_probe])
    print(f"img_index[{_probe}] =", use_ds._img_indices[_probe])

In [None]:
from data.task_dataset import TaskDataset
# Task-sampling parameters for this cell.
# NOTE(review): RAYS_SUPPORT / RAYS_QUERY assigned here replace the values
# declared in the CONFIG cell.
RAYS_SUPPORT = 4000
RAYS_QUERY   = 2000
MIN_RAYS     = 100_000     # skip if region has fewer than this many rays
# IMAGE_CAP    = 0.40     # max fraction per image in each split (set None to disable)
# STEPS_PER_EPOCH = 1000  # or None for infinite stream

# Task stream for the configured region, built on top of the RAM ray dataset.
task_ds = TaskDataset(
    ram_ds=use_ds,
    cell_id=REGION_ID,
    S_target=RAYS_SUPPORT,
    Q_target=RAYS_QUERY,
    image_cap=0.4,
    min_rays_cell=MIN_RAYS,
    assignment_checkpoint=0.7,
    debug=True,
    routing_policy="dda",
    cells=(1, 7, 7),
)


# Or, if you want a DataLoader (recommended; matches your MultiLoader pattern):
from torch.utils.data import DataLoader
task_loader = DataLoader(
    task_ds,
    batch_size=3,              # one Task per batch entry
    num_workers=0,             # keep 0 unless your buffers are immutable/thread-safe
    pin_memory=True,
    shuffle=False,
    drop_last=False,
    collate_fn=lambda xs: xs,  # keep list[Task]
)

# Example: pull one batch and print a one-line summary per task.
# (The former `task = batch[0]` was dead: the loop variable overwrote it.)
batch = next(iter(task_loader))
for task in batch:
    print("Region:", task.cell_id, "S:", int(task.metrics["S"]), "Q:", int(task.metrics["Q"]))


In [None]:
# Upper bound on how many tasks are pulled from the TaskDataset iterator below.
NUM_TASKS_TO_LOAD = 10
# How many of the collected tasks get a figure in this cell
# (reassigns the value declared in the CONFIG cell).
NUM_TASKS_TO_SHOW = 1

# Helper: draw a single image file into the current matplotlib axes.
def show_image(path, title=None, max_size=256):
    """Display the image at *path* as RGB on the current axes.

    The image is shrunk (never enlarged) so that its longer side is at most
    *max_size* pixels, the axes frame is hidden, and *title* is set when it
    is truthy.
    """
    picture = Image.open(path).convert("RGB")
    width, height = picture.size
    shrink = min(1.0, max_size / max(width, height))
    if shrink < 1.0:
        target = (int(width * shrink), int(height * shrink))
        picture = picture.resize(target, Image.LANCZOS)
    plt.imshow(picture)
    plt.axis("off")
    if title:
        plt.title(title)

# Lookup: image index -> metadata record (the record carries the image path).
# The dataset assigns unified image indices across train+val in get_image_metadata.
id2md = {meta.image_index: meta for meta in (train_md + val_md)}

# ---- SAMPLE TASKS (from TaskDataset) ----
# Pull at most NUM_TASKS_TO_LOAD tasks; stop early if the stream runs dry.
tasks = []
it = iter(task_ds)  # task_ds is the TaskDataset instance for this region
while len(tasks) < NUM_TASKS_TO_LOAD:
    try:
        tasks.append(next(it))
    except StopIteration:
        break

n_visualized = min(NUM_TASKS_TO_SHOW, len(tasks))
print(f"Collected {len(tasks)} tasks; visualizing first {n_visualized}.")

# ---- VISUALIZE ----
# One figure per task: support images on the first row, query images on the second.
for t_idx, task in enumerate(tasks[:NUM_TASKS_TO_SHOW], 1):
    # Per-ray image ids of each split (stored under 'img_indices' or 'img_ids').
    sup_tensor = task.support.get("img_indices", task.support.get("img_ids"))
    qry_tensor = task.query.get("img_indices", task.query.get("img_ids"))

    sup_ids = [] if sup_tensor is None else torch.unique(sup_tensor).tolist()
    qry_ids = [] if qry_tensor is None else torch.unique(qry_tensor).tolist()

    # Columns fit the larger split; the query row exists only if there are queries.
    cols = max(1, len(sup_ids), len(qry_ids))
    rows = 2 if qry_ids else 1

    plt.figure(figsize=(3.5 * cols, 3.5 * rows))

    # --- Support row (row 1) --- cols >= len(sup_ids), so every id gets a panel
    for ax_idx, img_id in enumerate(sup_ids, 1):
        plt.subplot(rows, cols, ax_idx)
        md = id2md[img_id]
        show_image(md.image_path, title=f"Support img_id={img_id}")

    # --- Query row (row 2) --- panels start right after the support row
    if rows == 2:
        for ax_idx, img_id in enumerate(qry_ids, cols + 1):
            plt.subplot(rows, cols, ax_idx)
            md = id2md[img_id]
            show_image(md.image_path, title=f"Query img_id={img_id}")

    # Ray counts: prefer the task's own metrics, else count the stored rays.
    support_fallback = len(task.support.get("rays", []))
    S = int(task.metrics.get("S", support_fallback))
    query_fallback = len(task.query.get("rays", []))
    Q = int(task.metrics.get("Q", query_fallback))
    plt.suptitle(f"Task {t_idx}: S={S} Q={Q}", y=0.99)
    plt.tight_layout()
    plt.show()

    if not task.warnings:
        continue
    print("Warnings:")
    for warning in task.warnings:
        print(f"Task {t_idx}: {warning}")
            


In [None]:
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe

# --- colors (support/query only) ---
# RGB uint8 colours blended into the image by overlay_masks.
COLOR_SUP  = np.array([  0, 255, 255], np.uint8)   # cyan
COLOR_QRY  = np.array([255, 255, 0], np.uint8)   # yellow (R=255, G=255, B=0)
CELL_EDGE  = (0.70, 0.00, 0.70)                    # magenta-like, in 0..1 floats

# Blend weights of the overlay colour (1.0 would fully replace the pixel).
ALPHA_SUP  = 0.75
ALPHA_QRY  = 0.65

# Half-width, in pixels, of the square neighbourhood painted around each ray pixel.
DILATE_RADIUS       = 1
# Per-image cap on rays; larger selections are randomly subsampled.
MAX_RAYS_PER_IMAGE  = 200_000
# Number of collected tasks to draw in this cell.
NUM_TASKS_TO_SHOW   = 10 

# ----------------- helpers -----------------

def _dilate(mask, ys, xs, R):
    """Mark pixels (ys, xs) in boolean *mask*, grown by a square of radius R.

    Modifies *mask* in place and returns None. With R <= 0 only the given
    pixels are set. Otherwise every offset (dy, dx) in [-R, R]^2 is applied
    and shifted positions falling outside the mask are dropped.
    """
    if R <= 0:
        mask[ys, xs] = True
        return
    n_rows, n_cols = mask.shape
    offsets = range(-R, R + 1)
    for dy in offsets:
        rows = ys + dy
        rows_ok = (rows >= 0) & (rows < n_rows)
        for dx in offsets:
            cols = xs + dx
            # keep only positions that are inside the mask on both axes
            keep = rows_ok & (cols >= 0) & (cols < n_cols)
            mask[rows[keep], cols[keep]] = True

def project_dirs_to_pixels(md, d_world):
    """Project world-space ray directions to pixel coordinates of image *md*.

    Directions are rotated into the camera frame with the transpose of the
    rotation part of `md.c2w`. Intrinsics `(fx, fy, cx, cy)` whose largest
    value is <= 2.5 are treated as normalized and scaled by the image size.
    Because the camera's forward axis may be -Z or +Z, both conventions are
    projected and the one that lands more directions inside the image wins
    (ties go to -Z).

    Returns:
        (u, v): two 1-D numpy arrays with the in-bounds pixel coordinates.
    """
    pose = md.c2w.to(d_world.device).float()
    world_to_cam_rot = pose[:3, :3].t()
    d_cam = (world_to_cam_rot @ d_world.T).T

    fx, fy, cx, cy = (float(value) for value in md.intrinsics)
    W, H = int(md.W), int(md.H)
    if max(fx, fy, cx, cy) <= 2.5:
        fx, cx = fx * W, cx * W
        fy, cy = fy * H, cy * H

    depth = d_cam[:, 2]

    def _project(forward_sign):
        # pick the directions in front of the camera for this convention
        if forward_sign < 0:
            in_front, denom = depth < -1e-6, -depth
        else:
            in_front, denom = depth > 1e-6, depth
        if not in_front.any():
            return np.empty(0), np.empty(0)
        dirs = d_cam[in_front]
        denom = denom[in_front]
        u = fx * (dirs[:, 0] / denom) + cx
        v = fy * (dirs[:, 1] / denom) + cy
        inside = (u >= 0) & (u < W) & (v >= 0) & (v < H)
        return u[inside].cpu().numpy(), v[inside].cpu().numpy()

    neg_u, neg_v = _project(-1.0)
    pos_u, pos_v = _project(+1.0)
    if neg_u.size >= pos_u.size:
        return (neg_u, neg_v)
    return (pos_u, pos_v)

def build_mask_from_indices(md, idx_tensor, use_ds):
    """Rasterise the rays selected by *idx_tensor* into a boolean (H, W) mask.

    The direction columns (3:6) of `use_ds._rays` are projected into image
    *md*; the resulting pixels are rounded, de-duplicated and dilated by
    DILATE_RADIUS. Selections larger than MAX_RAYS_PER_IMAGE are randomly
    subsampled first.

    Returns:
        (mask, n_pixels): the mask and the number of distinct pixels hit
        (0 when nothing was selected or nothing projected into the image).
    """
    W, H = int(md.W), int(md.H)
    mask = np.zeros((H, W), dtype=bool)
    if idx_tensor is None or idx_tensor.numel() == 0:
        return mask, 0

    chosen = idx_tensor
    if chosen.numel() > MAX_RAYS_PER_IMAGE:
        shuffled = torch.randperm(chosen.numel(), device=chosen.device)
        chosen = chosen[shuffled[:MAX_RAYS_PER_IMAGE]]

    directions = use_ds._rays[chosen, 3:6].float()
    u, v = project_dirs_to_pixels(md, directions)
    if u.size == 0:
        return mask, 0

    # u is non-empty from here on, so there is always at least one pixel
    cols = np.rint(np.clip(u, 0, W - 1)).astype(np.int32)
    rows = np.rint(np.clip(v, 0, H - 1)).astype(np.int32)
    distinct = np.unique(np.stack([rows, cols], 1), axis=0)
    rows, cols = distinct[:, 0], distinct[:, 1]
    _dilate(mask, rows, cols, DILATE_RADIUS)
    return mask, cols.size

def overlay_masks(md, sup_mask, qry_mask):
    """Return the image of *md* as an (H, W, 3) uint8 array with both masks tinted.

    The image is loaded from `md.image_path` and resized to the mask shape
    when needed. Support pixels are blended with COLOR_SUP/ALPHA_SUP; query
    pixels that are not also support pixels are blended with
    COLOR_QRY/ALPHA_QRY, so overlapping pixels show the support colour.
    """
    base = Image.open(md.image_path).convert("RGB")
    H, W = sup_mask.shape
    if base.size != (W, H):
        base = base.resize((W, H), Image.LANCZOS)
    out = np.array(base, dtype=np.uint8)

    # keep the two layers disjoint at pixel level (support takes precedence)
    qry_only = np.logical_and(qry_mask, ~sup_mask)

    layers = (
        (sup_mask, COLOR_SUP, ALPHA_SUP),
        (qry_only, COLOR_QRY, ALPHA_QRY),
    )
    for layer_mask, color, alpha in layers:
        if not np.any(layer_mask):
            continue
        tinted = (1 - alpha) * out[layer_mask] + alpha * color
        out[layer_mask] = tinted.astype(np.uint8)
    return out

def _get_w2c(md):
    """Return the 4x4 world-to-camera matrix of *md* as a float64 numpy array.

    Uses `md.w2c` when that attribute is present and not None; otherwise the
    rigid transform `md.c2w` is inverted analytically (R^T and -R^T t).

    Tensors are detached before conversion so metadata whose poses require
    grad can be converted too (`Tensor.numpy()` raises on such tensors).
    """
    w2c_attr = getattr(md, "w2c", None)
    if w2c_attr is not None:
        return w2c_attr.detach().cpu().numpy().astype(np.float64)
    c2w = md.c2w.detach().cpu().numpy().astype(np.float64)
    R, t = c2w[:3, :3], c2w[:3, 3]
    w2c = np.eye(4, dtype=np.float64)
    w2c[:3, :3] = R.T
    w2c[:3, 3] = -R.T @ t
    return w2c

def project_point(md, Pw, eps=1e-9):
    """Project one world point *Pw* into image *md*.

    Returns (u, v, visible). (u, v) is always computed, using a depth clamped
    away from zero by *eps*; `visible` is true only when the point has camera
    depth z > eps and (u, v) falls inside [0, W) x [0, H). Intrinsics whose
    largest value is <= 2.5 are treated as normalized and scaled to pixels.
    """
    world_to_cam = _get_w2c(md)
    Pc = world_to_cam[:3, :3] @ Pw + world_to_cam[:3, 3]

    fx, fy, cx, cy = (float(value) for value in md.intrinsics)
    W, H = int(md.W), int(md.H)
    if max(fx, fy, cx, cy) <= 2.5:
        fx, cx = fx * W, cx * W
        fy, cy = fy * H, cy * H

    z = float(Pc[2])
    # Safe denominator: keep z unless it is (near) zero, then use +/- eps.
    if abs(z) > eps:
        denom = z
    elif z >= 0:
        denom = eps
    else:
        denom = -eps

    u = fx * (Pc[0] / denom) + cx
    v = fy * (Pc[1] / denom) + cy
    visible = (z > eps) and (0 <= u < W) and (0 <= v < H)
    return (u, v, visible)



def _clip_2d_to_rect(p0, p1, W, H):
    """Clip segment p0-p1 to the rectangle [0, W] x [0, H] (Cohen–Sutherland).

    Returns (True, q0, q1) with the clipped endpoints, or (False, None, None)
    when both endpoints lie beyond the same side of the rectangle.
    """
    LEFT, RIGHT, BOTTOM, TOP = 1, 2, 4, 8

    def outcode(x, y):
        # bit mask of the rectangle sides the point lies beyond
        bits = 0
        if x < 0:
            bits |= LEFT
        elif x > W:
            bits |= RIGHT
        if y < 0:
            bits |= TOP
        elif y > H:
            bits |= BOTTOM
        return bits

    (x0, y0), (x1, y1) = p0, p1
    code0, code1 = outcode(x0, y0), outcode(x1, y1)
    while code0 | code1:
        if code0 & code1:
            # both endpoints share an outside half-plane: nothing to draw
            return False, None, None
        outside = code0 if code0 else code1
        # move the outside endpoint onto the boundary it violates
        if outside & TOP:
            x, y = x0 + (x1 - x0) * (0 - y0) / (y1 - y0), 0
        elif outside & BOTTOM:
            x, y = x0 + (x1 - x0) * (H - y0) / (y1 - y0), H
        elif outside & RIGHT:
            x, y = W, y0 + (y1 - y0) * (W - x0) / (x1 - x0)
        else:  # LEFT
            x, y = 0, y0 + (y1 - y0) * (0 - x0) / (x1 - x0)
        if outside == code0:
            x0, y0 = x, y
            code0 = outcode(x0, y0)
        else:
            x1, y1 = x, y
            code1 = outcode(x1, y1)
    return True, (x0, y0), (x1, y1)

def draw_cell_edges(ax, md, cell_bounds, color=CELL_EDGE, lw=3.0):
    """Draw the 12 edges of the axis-aligned box *cell_bounds* on axes *ax*.

    Args:
        ax: matplotlib axes already showing the image of *md*.
        md: image metadata used by `project_point`.
        cell_bounds: indexable pair (lo, hi) of 3-element tensors.
        color, lw: line colour and width; a black outline is stroked beneath.

    Each corner is projected with `project_point`. An edge is skipped when
    neither endpoint is visible, otherwise it is clipped to the image
    rectangle before plotting.

    NOTE(review): a later cell rebinds `_clip_2d_to_rect` to a polygon
    clipper with a different signature; this function needs the segment
    version `(p0, p1, W, H)` defined in the same cell.
    """
    lo = cell_bounds[0].detach().cpu().numpy()
    hi = cell_bounds[1].detach().cpu().numpy()
    # Corners 0-3 form the z=lo ring, 4-7 the z=hi ring in the same order.
    C = np.array([
        [lo[0], lo[1], lo[2]], [hi[0], lo[1], lo[2]],
        [hi[0], hi[1], lo[2]], [lo[0], hi[1], lo[2]],
        [lo[0], lo[1], hi[2]], [hi[0], lo[1], hi[2]],
        [hi[0], hi[1], hi[2]], [lo[0], hi[1], hi[2]],
    ], dtype=np.float64)
    E = [(0,1),(1,2),(2,3),(3,0),(4,5),(5,6),(6,7),(7,4),(0,4),(1,5),(2,6),(3,7)]
    W, H = int(md.W), int(md.H)
    for a, b in E:
        ua, va, oka = project_point(md, C[a])
        ub, vb, okb = project_point(md, C[b])
        # skip edges with no visible endpoint
        if not oka and not okb:
            continue
        # Unpack the endpoints only after checking `ok`: a rejected segment
        # comes back as (False, None, None) and unpacking None raises TypeError.
        ok, q0, q1 = _clip_2d_to_rect((ua, va), (ub, vb), W - 1, H - 1)
        if not ok:
            continue
        (x0, y0), (x1, y1) = q0, q1
        ax.plot([x0, x1], [y0, y1], '-', color=color, linewidth=lw, alpha=1.0, zorder=3,
                path_effects=[pe.Stroke(linewidth=lw+1.5, foreground='k'), pe.Normal()])

# ----------------- visualize + assert -----------------

# For every task: verify that support and query rays are disjoint, then draw
# each touched image with the support/query ray pixels tinted on top.
for t_idx, task in enumerate(tasks[:NUM_TASKS_TO_SHOW], 1):
    # --- (A) hard assertions for ray-level disjointness ---
    sup_idx = task.support.get("idx")
    qry_idx = task.query.get("idx")
    assert sup_idx is not None and qry_idx is not None, "Task must carry global 'idx' for both splits."
    # no duplicates inside a split
    sup_unique = torch.unique(sup_idx)
    assert sup_idx.numel() == sup_unique.numel(), "Duplicates inside support!"
    qry_unique = torch.unique(qry_idx)
    assert qry_idx.numel() == qry_unique.numel(), "Duplicates inside query!"
    # the union must be as large as both parts together
    merged = torch.unique(torch.cat([sup_unique, qry_unique]))
    assert merged.numel() == sup_unique.numel() + qry_unique.numel(), "S/Q intersect at ray level!"

    # --- group selected global ray indices per image ---
    sup_img = task.support["img_indices"]
    qry_img = task.query["img_indices"]
    sup_map = {
        img_id: sup_idx[sup_img == img_id]
        for img_id in torch.unique(sup_img).tolist()
    }
    qry_map = {
        img_id: qry_idx[qry_img == img_id]
        for img_id in torch.unique(qry_img).tolist()
    }

    # grid with at most four columns, as many rows as needed
    img_ids = sorted(sup_map.keys() | qry_map.keys())
    cols = min(4, max(1, len(img_ids)))
    rows = int(np.ceil(len(img_ids) / cols))
    plt.figure(figsize=(4.6 * cols, 4.6 * rows))

    for panel, img_id in enumerate(img_ids, 1):
        ax = plt.subplot(rows, cols, panel)
        md = id2md[img_id]

        # build masks (an image missing from a split yields an empty mask)
        sup_mask, _ = build_mask_from_indices(md, sup_map.get(img_id), use_ds)
        qry_mask, _ = build_mask_from_indices(md, qry_map.get(img_id), use_ds)

        # overlay
        rgb = overlay_masks(md, sup_mask, qry_mask)
        ax.imshow(rgb)
        ax.axis("off")
        # lock view to the image rectangle
        ax.set_xlim(0, int(md.W))
        ax.set_ylim(int(md.H), 0)

        # draw the micro-cell box (task.bounds must be the cell AABB [2,3])
        # draw_cell_edges(ax, md, task.bounds)

        ax.set_title(f"Task {t_idx} | img_id={img_id}", fontsize=9)

    plt.suptitle(
        f"Task {t_idx}: region={task.cell_id}  cell={task.cell_id}  "
        f"S={int(task.metrics['S'])}  Q={int(task.metrics['Q'])}  "
        f"cell_rays={int(task.metrics['total_cell'])}",
        y=0.99, fontsize=10
    )
    plt.tight_layout()
    plt.show()


In [None]:
# =========================
# Task visualizer (full)
# =========================
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.patheffects as pe

# ----------------- config -----------------
# Number of collected tasks the main loop of this cell draws.
NUM_TASKS_TO_SHOW   = 10

# colors
# NOTE: the scatter calls in this cell's main loop pass literal colours,
# so COLOR_SUP / COLOR_QRY are not read by the code in this cell.
COLOR_SUP  = np.array([  0, 255, 255], np.uint8)  # cyan
COLOR_QRY  = np.array([255, 255,   0], np.uint8)  # yellow
CELL_FACE  = (0.70, 0.00, 0.70)                   # magenta-ish
CELL_EDGE  = (0.35, 0.00, 0.35)

# Opacity of the support / query scatter points and of the filled cell faces.
ALPHA_SUP  = 0.80
ALPHA_QRY  = 0.70
CELL_ALPHA = 0.22

# ray subsampling for speed (per-image)
MAX_RAYS_PER_IMAGE = 200_000

# geometry threshold (what “counts” as inside the cell)
# The larger of the two thresholds is applied (see incell_midpoints_world).
MIN_OVERLAP_FRAC = 0.03   # >= 3% of cell diagonal
MIN_OVERLAP_ABS  = 0.0    # or absolute distance in scene units

# ----------------- low-level helpers -----------------
def _get_w2c(md):
    """Return the world-to-camera 4x4 matrix of *md* as float64 numpy.

    `md.w2c` is used directly when the attribute exists; otherwise the rigid
    transform `md.c2w` is inverted as [R^T | -R^T t].
    """
    if hasattr(md, "w2c"):
        return md.w2c.detach().cpu().numpy().astype(np.float64)
    cam_to_world = md.c2w.detach().cpu().numpy().astype(np.float64)
    rot_t = cam_to_world[:3, :3].T
    world_to_cam = np.eye(4, dtype=np.float64)
    world_to_cam[:3, :3] = rot_t
    world_to_cam[:3, 3] = -rot_t @ cam_to_world[:3, 3]
    return world_to_cam

def project_point(md, Pw, eps=1e-9):
    """Project a single world point to pixel (u,v), return (u,v,visible).

    (u, v) is computed even for points at or behind the camera plane, using
    a depth clamped to +/- eps; `visible` additionally requires z > eps and
    (u, v) inside [0, W) x [0, H).
    """
    extrinsic = _get_w2c(md)
    rotation, translation = extrinsic[:3, :3], extrinsic[:3, 3]
    Pc = (rotation @ Pw) + translation

    fx, fy, cx, cy = [float(k) for k in md.intrinsics]
    W, H = int(md.W), int(md.H)
    normalized = max(fx, fy, cx, cy) <= 2.5
    if normalized:  # normalized intrinsics → pixels
        fx, fy, cx, cy = fx * W, fy * H, cx * W, cy * H

    z = float(Pc[2])
    if abs(z) > eps:
        denom = z
    else:
        denom = eps if z >= 0 else -eps

    u = cx + fx * (Pc[0] / denom)
    v = cy + fy * (Pc[1] / denom)
    visible = (z > eps) and (0 <= u < W) and (0 <= v < H)
    return (u, v, visible)

# ----------------- cell projection (filled faces) -----------------
# Index tables into the 8 corners returned by _cell_corners_numpy
# (corners 0-3: z=lo ring, corners 4-7: z=hi ring, same x/y order).
# The 12 box edges: bottom ring, top ring, then the four vertical edges.
_CELL_EDGES = [(0,1),(1,2),(2,3),(3,0),(4,5),(5,6),(6,7),(7,4),(0,4),(1,5),(2,6),(3,7)]
# The 6 box faces, each as a loop of four corner indices.
_CELL_FACES = [
    (0,1,2,3),  # z=lo
    (4,5,6,7),  # z=hi
    (0,1,5,4),  # y=lo
    (2,3,7,6),  # y=hi
    (1,2,6,5),  # x=hi
    (0,3,7,4),  # x=lo
]

def _cell_corners_numpy(cell_bounds):
    """Return the 8 corners of the box (lo, hi) as an (8, 3) float64 array.

    Rows 0-3 are the z=lo ring in the order (lo,lo), (hi,lo), (hi,hi),
    (lo,hi) for (x, y); rows 4-7 repeat that ring at z=hi.
    """
    lo = cell_bounds[0].detach().cpu().numpy()
    hi = cell_bounds[1].detach().cpu().numpy()
    ring_xy = [(lo[0], lo[1]), (hi[0], lo[1]), (hi[0], hi[1]), (lo[0], hi[1])]
    corners = [[x, y, z] for z in (lo[2], hi[2]) for x, y in ring_xy]
    return np.array(corners, dtype=np.float64)

def _clip_2d_to_rect(poly, W, H):
    """Sutherland–Hodgman clip of polygon 'poly' against [0,W-1]x[0,H-1].

    *poly* is a sequence of (x, y) vertices. It is clipped in turn against
    the left, right, top and bottom boundary; the result is a list of
    vertices, empty when nothing remains inside the rectangle.
    """
    def clip_against(polygon, axis, limit, keep_ge):
        # Clip against one boundary: the line `coordinate[axis] == limit`,
        # keeping the side >= limit (keep_ge) or <= limit (otherwise).
        kept = []
        if not polygon:
            return kept
        other = 1 - axis

        def is_inside(P):
            return P[axis] >= limit if keep_ge else P[axis] <= limit

        def crossing(A, B):
            # point where A->B meets the boundary; the 1e-20 guards the division
            moved = A[other] + (B[other] - A[other]) * (limit - A[axis]) / (B[axis] - A[axis] + 1e-20)
            return (limit, moved) if axis == 0 else (moved, limit)

        prev = polygon[-1]
        for cur in polygon:
            if is_inside(cur):
                if not is_inside(prev):
                    kept.append(crossing(prev, cur))
                kept.append(cur)
            elif is_inside(prev):
                kept.append(crossing(prev, cur))
            prev = cur
        return kept

    boundaries = (
        (0, 0, True),        # left
        (0, W - 1, False),   # right
        (1, 0, True),        # top
        (1, H - 1, False),   # bottom
    )
    for axis, limit, keep_ge in boundaries:
        poly = clip_against(poly, axis, limit, keep_ge)
        if not poly:
            break
    return poly

def _project_cell_faces(md, cell_bounds):
    """Project the six faces of the box *cell_bounds* into image *md*.

    All eight corners are projected regardless of visibility; each face is
    then clipped to the image rectangle. Returns the list of clipped face
    polygons that still have at least three vertices.
    """
    corners = _cell_corners_numpy(cell_bounds)
    W, H = int(md.W), int(md.H)

    # project all corners (even if behind), keeping only (u, v)
    UV = np.array([project_point(md, corner)[:2] for corner in corners], float)

    # build & clip 6 faces
    clipped = (
        _clip_2d_to_rect([(UV[k, 0], UV[k, 1]) for k in face], W, H)
        for face in _CELL_FACES
    )
    return [poly for poly in clipped if len(poly) >= 3]

def draw_cell_filled(ax, md, cell_bounds, face_alpha=CELL_ALPHA, edge_alpha=0.9,
                     face_color=CELL_FACE, edge_color=CELL_EDGE, lw=1.5):
    """Draw the box *cell_bounds* on *ax* as translucent faces plus outlined edges.

    Args:
        ax: matplotlib axes showing the image of *md*.
        md: image metadata used for projection.
        cell_bounds: indexable pair (lo, hi) of 3-element tensors.
        face_alpha, edge_alpha: opacity of the face patches / edge lines.
        face_color, edge_color, lw: colours and line width.

    NOTE(review): corners are projected even when `project_point` reports
    them as not visible, so boxes partly behind the camera may look wrong.
    """
    for poly in _project_cell_faces(md, cell_bounds):
        ax.add_patch(mpatches.Polygon(poly, closed=True,
                                      facecolor=face_color, edgecolor=edge_color,
                                      linewidth=lw, alpha=face_alpha, zorder=3))

    # draw edges stronger
    C = _cell_corners_numpy(cell_bounds)
    W, H = int(md.W), int(md.H)
    for (i, j) in _CELL_EDGES:
        u0, v0, _ = project_point(md, C[i])
        u1, v1, _ = project_point(md, C[j])
        seg = _clip_2d_to_rect([(u0, v0), (u1, v1)], W, H)
        if len(seg) < 2:
            continue
        # The polygon clipper treats the segment as a 2-vertex polygon. When
        # the first endpoint lies outside it returns the boundary crossing
        # twice before the inside endpoint, so `seg[:2]` would collapse the
        # edge to a single point. All returned points lie on the original
        # segment, so take the two that are extreme along its direction.
        pts = np.asarray(seg, dtype=float)
        direction = np.array([u1 - u0, v1 - v0], dtype=float)
        along = (pts - np.array([u0, v0], dtype=float)) @ direction
        start, end = pts[along.argmin()], pts[along.argmax()]
        ax.plot([start[0], end[0]], [start[1], end[1]], '-', color=edge_color,
                linewidth=lw, alpha=edge_alpha, zorder=4,
                path_effects=[pe.Stroke(linewidth=lw+1.2, foreground='k'), pe.Normal()])

# ----------------- in-cell midpoint computation -----------------
def _ray_cell_segment_np(rays_np, cell_bounds_np):
    """Compute [t_entry, t_exit] of ray vs AABB; returns (hit, t0, t1).

    Args:
        rays_np: (N, >=8) array; columns 0:3 origin, 3:6 direction,
            6 near, 7 far.
        cell_bounds_np: indexable pair (lo, hi) of 3-vectors.

    Returns:
        hit: (N,) bool, true where the clamped interval is non-empty.
        t_entry, t_exit: (N,) slab interval clamped to [max(near, 0), far].
    """
    o = rays_np[:, :3]; d = rays_np[:, 3:6]
    near = rays_np[:, 6]; far = rays_np[:, 7]
    lo, hi = cell_bounds_np[0], cell_bounds_np[1]
    # Replace (near-)zero direction components by a signed tiny value.
    # np.sign(0) is 0, so `np.sign(d) * 1e-12` left exact zeros in place and
    # 1/0 -> inf, which becomes NaN (0 * inf) for rays starting exactly on a
    # slab plane. copysign always yields +/-1e-12.
    safe_d = np.where(np.abs(d) < 1e-12, np.copysign(1e-12, d), d)
    invd = 1.0 / safe_d
    t0 = (lo - o) * invd
    t1 = (hi - o) * invd
    tmin = np.minimum(t0, t1).max(axis=1)
    tmax = np.maximum(t0, t1).min(axis=1)
    t_entry = np.maximum.reduce([tmin, near, np.zeros_like(near)])
    t_exit  = np.minimum(tmax, far)
    hit = t_exit > t_entry
    return hit, t_entry, t_exit

def incell_midpoints_world(rays_sel, cell_bounds):
    """World-space midpoints of the ray segments lying inside the cell.

    A ray is kept when it hits the box and its in-cell parameter length
    (t_exit - t_entry) is at least
    max(MIN_OVERLAP_ABS, MIN_OVERLAP_FRAC * cell diagonal).
    NOTE(review): that length equals a distance only for unit-length
    directions — confirm against the dataset.

    returns:
      P (M,3) world midpoints of the kept rays,
      keep_mask over the K input rays,
      overlap_lengths (only for kept rays)
    """
    if rays_sel is None or rays_sel.numel() == 0:
        return np.empty((0, 3)), np.zeros((0,), bool), np.empty((0,))

    rays_np = rays_sel.detach().cpu().numpy()
    bounds_np = cell_bounds.detach().cpu().numpy()
    hit, t_in, t_out = _ray_cell_segment_np(rays_np, bounds_np)

    diagonal = np.linalg.norm(bounds_np[1] - bounds_np[0])
    min_len = max(MIN_OVERLAP_ABS, MIN_OVERLAP_FRAC * diagonal)
    span = t_out - t_in
    keep = hit & (span >= min_len)
    if not keep.any():
        return np.empty((0, 3)), keep, np.empty((0,))

    t_mid = 0.5 * (t_in[keep] + t_out[keep])
    origins, directions = rays_np[keep, :3], rays_np[keep, 3:6]
    return origins + directions * t_mid[:, None], keep, span[keep]

def project_points_uv(md, P):
    """Project world points (M,3) to integer pixels (N,2), N <= M.

    Only points that `project_point` reports as visible are returned;
    coordinates are rounded to the nearest integer.
    """
    if P.size == 0:
        return np.empty((0, 2), int)
    projected = (project_point(md, Pw) for Pw in P)
    pixels = [(int(round(u)), int(round(v))) for u, v, vis in projected if vis]
    if not pixels:
        return np.empty((0, 2), int)
    return np.array(pixels, int)

# ----------------- plotting + assertions -----------------
def assert_ray_disjoint_and_incell(sup_idx, qry_idx, cell_bounds):
    """Hard checks: S/Q disjointness + all rays overlap cell by threshold.

    Raises AssertionError when support and query share a ray index, or when
    any selected ray (looked up in the global `use_ds._rays`) is rejected by
    `incell_midpoints_world` for *cell_bounds*.
    """
    # disjointness: the union must be as large as both parts together
    sup_unique = torch.unique(sup_idx)
    qry_unique = torch.unique(qry_idx)
    union = torch.unique(torch.cat([sup_unique, qry_unique]))
    assert union.numel() == sup_unique.numel() + qry_unique.numel(), "S/Q intersect at ray level!"

    # geometry: every ray of each split must pass the in-cell overlap test
    for idx, name in ((sup_idx, "support"), (qry_idx, "query")):
        _, keep, _ = incell_midpoints_world(use_ds._rays[idx], cell_bounds)
        if np.all(keep):
            continue
        bad = int((~keep).sum())
        raise AssertionError(f"{bad} {name} rays do not overlap cell by "
                             f"min_len=max({MIN_OVERLAP_ABS}, {MIN_OVERLAP_FRAC}*diag).")

def build_img_to_indices_map(img_tensor, idx_tensor):
    """Group the entries of *idx_tensor* by the matching value in *img_tensor*.

    Returns a dict {image id: tensor of indices}. Groups with more than
    MAX_RAYS_PER_IMAGE entries are randomly subsampled to that size.
    """
    groups = {}
    for img_id in torch.unique(img_tensor).tolist():
        members = idx_tensor[img_tensor == img_id]
        count = members.numel()
        # optional speed-up: cap very large groups
        if count > MAX_RAYS_PER_IMAGE:
            shuffled = torch.randperm(count, device=members.device)
            members = members[shuffled[:MAX_RAYS_PER_IMAGE]]
        groups[img_id] = members
    return groups

def print_task_summary(task, sup_map, qry_map, Ls, Lq):
    """Print a text summary of one task.

    Prints the cell id with S/Q sizes and image counts, the per-image ray
    counts of both splits, min/mean/max of the overlap lengths *Ls* and *Lq*
    ("n/a" when empty), and any warnings attached to the task.
    """
    def per_image_counts(groups):
        return {img_id: int(groups[img_id].numel()) for img_id in sorted(groups)}

    def overlap_stats(lengths):
        if not lengths.size:
            return "n/a"
        return f"min={lengths.min():.4g} mean={lengths.mean():.4g} max={lengths.max():.4g}"

    n_sup_imgs = len(sorted(sup_map.keys()))
    n_qry_imgs = len(sorted(qry_map.keys()))
    print(f"Task cell={task.cell_id} | S={int(task.metrics['S'])} Q={int(task.metrics['Q'])} | "
          f"S_imgs={n_sup_imgs} Q_imgs={n_qry_imgs} | disjoint_ok={bool(task.metrics.get('image_disjoint_ok',1.0))}")

    # per-image counts
    print("  support imgs & counts:", per_image_counts(sup_map))
    print("  query   imgs & counts:", per_image_counts(qry_map))

    # overlap stats
    print("  overlap S:", overlap_stats(Ls), "| overlap Q:", overlap_stats(Lq))

    for note in (task.warnings or ()):
        print("  warning:", note)

# ----------------- MAIN LOOP -----------------
# For every task: check for duplicate ray indices, print a summary, then draw
# each touched image with the in-cell ray midpoints and the projected cell box.
for t_idx, task in enumerate(tasks[:NUM_TASKS_TO_SHOW], 1):
    sup_idx = task.support["idx"]
    qry_idx = task.query["idx"]
    assert sup_idx.numel() == torch.unique(sup_idx).numel(), "Duplicates in support!"
    assert qry_idx.numel() == torch.unique(qry_idx).numel(), "Duplicates in query!"
    # assert_ray_disjoint_and_incell(sup_idx, qry_idx, task.bounds)

    sup_img = task.support["img_indices"]
    qry_img = task.query["img_indices"]
    sup_map = build_img_to_indices_map(sup_img, sup_idx)
    qry_map = build_img_to_indices_map(qry_img, qry_idx)

    # precompute overlap lengths (for printing)
    _, _, Ls = incell_midpoints_world(use_ds._rays[sup_idx], task.bounds)
    _, _, Lq = incell_midpoints_world(use_ds._rays[qry_idx], task.bounds)
    print_task_summary(task, sup_map, qry_map, Ls, Lq)

    # figure layout: at most four columns, as many rows as needed
    img_ids = sorted(set(sup_map) | set(qry_map))
    cols = min(4, max(1, len(img_ids)))
    rows = int(np.ceil(len(img_ids) / cols))
    plt.figure(figsize=(4.8*cols, 4.8*rows))

    for i, img_id in enumerate(img_ids, 1):
        ax = plt.subplot(rows, cols, i)
        md = id2md[img_id]
        W, H = int(md.W), int(md.H)
        base = Image.open(md.image_path).convert("RGB")
        # Projections and the axis limits below use md.W x md.H. Resize the
        # file to that size (as overlay_masks does) so points and the cell
        # box line up with the picture when the file has another resolution.
        if base.size != (W, H):
            base = base.resize((W, H), Image.LANCZOS)
        ax.imshow(base)
        ax.axis("off")
        ax.set_xlim(0, W)
        ax.set_ylim(H, 0)

        # SUPPORT (cyan) and QUERY (yellow) points: in-cell ray midpoints
        layers = (
            (sup_map, (0, 1, 1), ALPHA_SUP),
            (qry_map, (1, 1, 0), ALPHA_QRY),
        )
        for split_map, rgb, alpha in layers:
            if img_id not in split_map:
                continue
            P, _, _ = incell_midpoints_world(use_ds._rays[split_map[img_id]], task.bounds)
            uv = project_points_uv(md, P)
            if uv.size:
                ax.scatter(uv[:,0], uv[:,1], s=6, marker='o', linewidths=0,
                           c=[rgb], alpha=alpha, zorder=5)

        # draw filled voxel (its zorder keeps it under the points)
        draw_cell_filled(ax, md, task.bounds, face_alpha=CELL_ALPHA)

        ax.set_title(f"Task {t_idx} | img_id={img_id}", fontsize=9)

    # Every other cell reads `task.cell_id`; use `cid` only if the task has it.
    # NOTE(review): confirm whether Task exposes a separate region id.
    region_id = getattr(task, "cid", task.cell_id)
    plt.suptitle(
        f"Task {t_idx}: region={region_id}  cell={task.cell_id}  "
        f"S={int(task.metrics['S'])}  Q={int(task.metrics['Q'])}  "
        f"cell_rays={int(task.metrics['total_cell'])}",
        y=0.99, fontsize=10
    )
    plt.tight_layout()
    plt.show()
