## IMPORTANT

Sometimes the sample axes need to be read in the opposite way to normal, e.g. right to left instead of left to right.

If you find the stitched image looks off, then search for the parameters `invertx` and `inverty` in this notebook and change them according to the comments found next to those variables.

In [None]:
# Dataset selection: folder that holds the exported *_metadata.txt files and
# their matching *_sem.png / *_sem.npz tiles. Leave exactly one assignment
# active; the commented-out lines are alternative datasets.
# pythondata_folder = r"/mnt/data/semeds/20250825_sample 10 (eds)/Project 1/h5data"
#pythondata_folder = r"/mnt/data/semeds/20250826 # 10 after (eds)/Project 1/h5pythondata"
pythondata_folder = r"/mnt/data/semeds/20241210 Jonah 9 (eds)/Project 1/h5pythondata"
#pythondata_folder = r"/mnt/data/semeds/20241218 Jonah 10 (eds)/Project 1/h5pythondata"

# pythondata_folder = r"/media/lenr/Data/semeds/Florian Ti (eds)/h5data"

In [None]:
# --- user toggles ---
# Any truthy value mirrors the stage coordinates about the dataset's own
# min/max bounds before the relative coordinates (X_rel_um / Y_rel_um) are derived.
invertx = 1   # 0 = normal, 1 = mirror left↔right
inverty = 1   # 0 = normal, 1 = mirror top↔bottom

In [None]:
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Tuple
import os
import re

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import ipywidgets as w
from IPython.display import Javascript, display

In [None]:
# Dataset folder as a Path so it can be globbed and joined with file names.
folder = Path(pythondata_folder)

In [None]:
# Matches one "key : value" metadata line. Surrounding whitespace is excluded
# from both groups; the key is matched lazily, so the split happens at the
# earliest colon that leaves a non-empty key and a non-empty value.
META_RE = re.compile(r"^\s*(?P<key>.+?)\s*:\s*(?P<val>.+?)\s*$")

# Per-tile record: file paths, image geometry and stage position of one site.
@dataclass
class SiteRecord:
    """One SEM tile ("site") discovered in the dataset folder.

    Instances are created by build_site_records(); every field except the
    first three is optional and defaults to None.
    """
    basename: str                         # metadata file name without the "_metadata.txt" suffix
    png_path: str                         # path of the paired PNG image
    meta_path: str                        # path of the "*_metadata.txt" file
    npz_path: Optional[str] = None        # PNG path with the suffix swapped to ".npz"
    width_px: Optional[int] = None        # image width in pixels
    height_px: Optional[int] = None       # image height in pixels
    px_x_um: Optional[float] = None       # pixel step along X (µm per the field name — TODO confirm unit)
    px_y_um: Optional[float] = None       # pixel step along Y (µm per the field name — TODO confirm unit)
    tile_w_um: Optional[float] = None     # width_px * px_x_um
    tile_h_um: Optional[float] = None     # height_px * px_y_um
    stage_x_um: Optional[float] = None    # stage X read from the metadata, scaled by 1000
    stage_y_um: Optional[float] = None    # stage Y read from the metadata, scaled by 1000
    stage_z_um: Optional[float] = None    # stage Z read from the metadata, scaled by 1000

def parse_metadata_txt(path: Path) -> Dict[str, str]:
    """Read a metadata text file into a ``{key: value}`` dictionary.

    The file is decoded as UTF-8 with undecodable bytes ignored. Lines that do
    not match ``META_RE`` are skipped; when a key occurs more than once the
    last occurrence wins.
    """
    text = path.read_text(encoding="utf-8", errors="ignore")
    matches = (META_RE.match(raw_line) for raw_line in text.splitlines())
    return {
        hit.group("key").strip(): hit.group("val").strip()
        for hit in matches
        if hit
    }

def get_float(d: Dict[str, str], key: str) -> Optional[float]:
    """Look up ``key`` in ``d`` and convert the value to ``float``.

    Returns ``None`` when the key is absent or the value cannot be parsed.
    If the plain conversion fails, a second attempt is made with commas
    replaced by dots (decimal-comma notation).
    """
    raw = d.get(key)
    if raw is None:
        return None
    # Attempts are evaluated lazily, in order: as-is first, decimal comma second.
    attempts = (lambda s: s, lambda s: s.replace(",", "."))
    for prepare in attempts:
        try:
            return float(prepare(raw))
        except Exception:
            continue
    return None

def robust_image_size(p: Path) -> Tuple[int, int]:
    """Return the ``(width, height)`` reported by PIL for the image at ``p``."""
    with Image.open(p) as img:
        dimensions = img.size
    return dimensions

def pair_png_and_meta(folder: Path):
    """Pair every ``*_metadata.txt`` file in ``folder`` with a PNG image.

    For a metadata file ``<stem>_metadata.txt`` the preferred image is
    ``<stem>_sem.png``; if that file does not exist, the alphabetically first
    ``<stem>*.png`` is used instead. Metadata files without any matching PNG
    are skipped.

    Returns a list of ``(png_path, meta_path, stem)`` tuples ordered by
    metadata file name.
    """
    suffix = "_metadata.txt"
    pairs = []
    for meta in sorted(folder.glob(f"*{suffix}")):
        # Strip only the trailing suffix: str.replace() would also delete an
        # earlier occurrence of the same text inside the name and so corrupt
        # the stem used to locate the image.
        stem = meta.name.removesuffix(suffix)
        png = folder / f"{stem}_sem.png"
        if not png.exists():
            candidates = sorted(folder.glob(stem + "*.png"))
            png = candidates[0] if candidates else None
        if png is not None and png.exists():
            pairs.append((png, meta, stem))
    return pairs

def build_site_records(folder: Path) -> List[SiteRecord]:
    """Build one ``SiteRecord`` per (PNG, metadata) pair found in ``folder``.

    Image size comes from the "X Cells"/"Y Cells" header entries and falls
    back to the PNG's own size when either is missing or zero. Pixel steps
    come from "X Step"/"Y Step". Stage positions are multiplied by 1000
    (presumably mm → µm — TODO confirm against the exporter); a missing or
    unparsable stage entry yields ``None`` instead of raising ``TypeError``.
    """
    def _stage_um(md: Dict[str, str], axis: str) -> Optional[float]:
        # Scale only when a value was actually parsed.
        value = get_float(md, f"/1/EDS/Header/Stage Position/{axis}")
        return None if value is None else 1000 * value

    recs = []
    for png, meta, stem in pair_png_and_meta(folder):
        md = parse_metadata_txt(meta)

        width = int(get_float(md, "/1/Electron Image/Header/X Cells") or 0)
        height = int(get_float(md, "/1/Electron Image/Header/Y Cells") or 0)
        if width == 0 or height == 0:
            width, height = robust_image_size(png)

        px_x = get_float(md, "/1/Electron Image/Header/X Step")
        px_y = get_float(md, "/1/Electron Image/Header/Y Step")

        sem_npz = png.with_suffix(".npz")  # e.g. "..._sem.png" -> "..._sem.npz"
        recs.append(SiteRecord(
            basename=stem,
            png_path=str(png),
            meta_path=str(meta),
            npz_path=str(sem_npz),  # path is recorded whether or not the file exists
            width_px=width, height_px=height,
            px_x_um=px_x, px_y_um=px_y,
            tile_w_um=(width * px_x if (width and px_x) else None),
            tile_h_um=(height * px_y if (height and px_y) else None),
            stage_x_um=_stage_um(md, "X"),
            stage_y_um=_stage_um(md, "Y"),
            stage_z_um=_stage_um(md, "Z"),
        ))
    return recs

# simple clustering → grid row/col from stage coords
def _cluster_1d(vals: np.ndarray, tol: float):
    """Assign a cluster label to every value of a 1-D array.

    Values are visited in ascending order. A new cluster starts whenever a
    value lies more than ``tol`` away from the first value of the current
    cluster. Labels are consecutive integers starting at 0 and are returned
    in the original element order.
    """
    order = np.argsort(vals)
    labels = np.empty_like(order, dtype=int)
    if not len(order):
        return labels

    first, remaining = order[0], order[1:]
    anchor = vals[first]   # first (smallest) value of the current cluster
    current = 0
    labels[first] = current
    for idx in remaining:
        candidate = vals[idx]
        if abs(candidate - anchor) > tol:
            current += 1
            anchor = candidate
        labels[idx] = current
    return labels

def infer_grid_indices(df_pos: pd.DataFrame, tol_frac: float = 0.55) -> Tuple[np.ndarray,np.ndarray]:
    """Derive grid row/column labels by clustering relative stage coordinates.

    Per axis the clustering tolerance is ``tol_frac`` times the median tile
    size (``tile_w_um`` / ``tile_h_um``). When no usable tile size exists, the
    tolerance is 1/20 of the coordinate span, but at least 1.0.

    Returns ``(rows, cols)``: labels from ``Y_rel_um`` and ``X_rel_um``.
    """
    def _axis_labels(size_col: str, pos_col: str) -> np.ndarray:
        sizes = df_pos[size_col]
        positions = df_pos[pos_col].to_numpy(float)
        median_size = float(np.nanmedian(sizes)) if sizes.notna().any() else 0.0
        if median_size:
            tolerance = median_size * tol_frac
        else:
            tolerance = max(1.0, np.ptp(positions) / 20.0)
        return _cluster_1d(positions, tolerance)

    cols = _axis_labels("tile_w_um", "X_rel_um")
    rows = _axis_labels("tile_h_um", "Y_rel_um")
    return rows, cols

In [None]:
# Scan the dataset folder and tabulate one row per tile (columns = SiteRecord fields).
sites = build_site_records(folder)
site_df = pd.DataFrame([asdict(s) for s in sites])
# Bare last expression so the notebook renders the table.
site_df

In [None]:
# --- Add per-image low/high + dtype, then plot dtype counts once ---

def _load_gray_and_dtype(path: str | Path):
    """
    Load an image from disk, preserving bit depth when possible.
    Returns (array_2d, dtype_string).
    - Unsigned 16-bit modes (I;16 and its byte-order variants) are kept as uint16.
    - Mode 'I' is kept as int32.
    - Every other mode is converted to 8-bit grayscale (L) for stats.
    """
    # Pillow reports 16-bit grayscale as "I;16", "I;16L", "I;16B" or "I;16N"
    # depending on byte order. Matching only "I;16" would push the other
    # variants through the 8-bit conversion below and lose their bit depth.
    modes_16bit = ("I;16", "I;16L", "I;16B", "I;16N")

    p = Path(path)
    with Image.open(p) as im:
        mode = im.mode  # e.g., 'L', 'I;16', 'RGB', 'RGBA', 'I'
        if mode in modes_16bit:
            arr = np.array(im, dtype=np.uint16)
        elif mode == "I":
            arr = np.array(im, dtype=np.int32)
        else:
            # For any 8-bit (L/RGB/RGBA/...) convert to 8-bit grayscale
            if mode != "L":
                im = im.convert("L")
            arr = np.array(im, dtype=np.uint8)
    # Defensive: keep only the first channel if a 3-D array slipped through
    if arr.ndim == 3:
        arr = arr[..., 0]
    return arr, str(arr.dtype)

# Per-image intensity range and dtype, collected in site_df row order
mins, maxs, dtypes = [], [], []
for p in site_df["png_path"]:
    try:
        a, dt = _load_gray_and_dtype(p)
        a = a[np.isfinite(a)]  # boolean mask flattens and drops non-finite values
        nonempty = bool(a.size)
        row_stats = (
            float(a.min()) if nonempty else np.nan,
            float(a.max()) if nonempty else np.nan,
            dt,
        )
    except Exception:
        # Unreadable image: keep the row aligned with placeholder values
        row_stats = (np.nan, np.nan, "(unreadable)")
    mins.append(row_stats[0])
    maxs.append(row_stats[1])
    dtypes.append(row_stats[2])

# Store the three lists as new columns
for col, values in (("img_min", mins), ("img_max", maxs), ("img_dtype", dtypes)):
    site_df[col] = values

site_df.head()

In [None]:
print(f"Pairs found: {len(site_df)}")
print("Sizes (w×h):")
# Tally of distinct (width_px, height_px) combinations, most frequent first
print(site_df[["width_px", "height_px"]].value_counts().sort_values(ascending=False).to_string())

# Ensure we have absolute stage positions in µm
need_cols = {"stage_x_um","stage_y_um"}
# Fail early if either column is absent or contains any missing value
if not need_cols.issubset(site_df.columns) or site_df[list(need_cols)].isna().any().any():
    raise ValueError("stage_x_um / stage_y_um required (µm).")

# --- 1) start from ABSOLUTE stage coords (µm) ---
x_abs = site_df["stage_x_um"].astype(float).to_numpy()
y_abs = site_df["stage_y_um"].astype(float).to_numpy()

# --- 2) optionally invert within the dataset bounds ---
# Reflecting about the midpoint of [min, max] keeps the overall extent
# unchanged while reversing the order of the positions.
if invertx:
    xmin, xmax = float(np.nanmin(x_abs)), float(np.nanmax(x_abs))
    x_abs = (xmin + xmax) - x_abs

if inverty:
    ymin, ymax = float(np.nanmin(y_abs)), float(np.nanmax(y_abs))
    y_abs = (ymin + ymax) - y_abs

# (optional) keep the inverted absolutes for reference/debug
site_df["stage_x_um_inv"] = x_abs
site_df["stage_y_um_inv"] = y_abs

# --- 3) compute RELATIVE coords from the (possibly inverted) absolutes ---
# Shift so the smallest X and Y become 0
x0, y0 = float(np.nanmin(x_abs)), float(np.nanmin(y_abs))
site_df["X_rel_um"] = x_abs - x0
site_df["Y_rel_um"] = y_abs - y0

site_df

In [None]:
# === Global min/max + recommended clipping (from NPZ raw 'sem_data') ===

# knobs
CLIP_PERCENT    = 1.0      # e.g., 1% → use p1 / p99
SAMPLE_PER_TILE = 5000     # how many pixels to sample per image (0 = use all)
RNG_SEED        = 0        # seed of the sampling generator, so repeated runs draw the same pixels

def _load_npz_sem(path: str | Path):
    """Load raw SEM plane from *_sem.npz (key 'sem_data'). Returns (arr2d, dtype_str)."""
    p = Path(path)
    with np.load(p) as d:
        a = d["sem_data"]           # keep native dtype (often uint16)
    if a.ndim == 3:                  # just in case
        a = a[..., 0]
    return a, str(a.dtype)

# 0) ensure we have an npz_path column
if "npz_path" not in site_df.columns:
    raise ValueError("site_df is missing 'npz_path'. Add it when building records.")

# 1) per-file NPZ min/max (cached in df so we can reuse)
# Skipped entirely when all three columns already exist (e.g. on a re-run of this cell).
if not {"npz_min","npz_max","npz_dtype"}.issubset(site_df.columns):
    npz_mins, npz_maxs, npz_dtypes = [], [], []
    for p in site_df["npz_path"]:
        try:
            a, dt = _load_npz_sem(p)
            a = a[np.isfinite(a)]  # flattens and drops non-finite values
            npz_mins.append(float(a.min()) if a.size else np.nan)
            npz_maxs.append(float(a.max()) if a.size else np.nan)
            npz_dtypes.append(dt)
        except Exception:
            # Missing/unreadable archive: placeholder row keeps the lists aligned
            npz_mins.append(np.nan); npz_maxs.append(np.nan); npz_dtypes.append("(missing)")
    site_df["npz_min"]   = npz_mins
    site_df["npz_max"]   = npz_maxs
    site_df["npz_dtype"] = npz_dtypes

# quick global min/max from NPZs
# NaN rows (unreadable tiles) are ignored by nanmin/nanmax
GLOBAL_MIN = float(np.nanmin(site_df["npz_min"].to_numpy()))
GLOBAL_MAX = float(np.nanmax(site_df["npz_max"].to_numpy()))
print(f"Global raw min/max from NPZ: {GLOBAL_MIN:.3g}, {GLOBAL_MAX:.3g}")

# 2) robust global lo/hi by sampling pixels from NPZs
rng = np.random.default_rng(RNG_SEED)
samples = []
for p in site_df["npz_path"]:
    try:
        a, _ = _load_npz_sem(p)
        flat = a.ravel().astype(np.float64, copy=False)
        # Sample without replacement only when the tile has more pixels than requested
        if SAMPLE_PER_TILE and flat.size > SAMPLE_PER_TILE:
            idx = rng.choice(flat.size, SAMPLE_PER_TILE, replace=False)
            flat = flat[idx]
        samples.append(flat)
    except Exception:
        # Unreadable tiles simply contribute no samples
        continue

if samples:
    all_samples = np.concatenate(samples)
    if CLIP_PERCENT and CLIP_PERCENT > 0:
        # Symmetric clipping: CLIP_PERCENT from the bottom and from the top
        GLOBAL_LO, GLOBAL_HI = np.nanpercentile(all_samples, [CLIP_PERCENT, 100.0 - CLIP_PERCENT])
    else:
        GLOBAL_LO, GLOBAL_HI = float(np.nanmin(all_samples)), float(np.nanmax(all_samples))
else:
    # fallback to raw extremes if sampling failed
    GLOBAL_LO, GLOBAL_HI = GLOBAL_MIN, GLOBAL_MAX

print(f"Recommended NPZ clipping @ {CLIP_PERCENT}%: lo={GLOBAL_LO:.3g}, hi={GLOBAL_HI:.3g}")

# 3) stash constants in the dataframe for downstream plotting/normalization
# Scalars are broadcast, so every row carries the same value.
site_df["global_min"]   = GLOBAL_MIN
site_df["global_max"]   = GLOBAL_MAX
site_df["global_lo"]    = float(GLOBAL_LO)
site_df["global_hi"]    = float(GLOBAL_HI)
site_df["clip_percent"] = CLIP_PERCENT

site_df.head(3)

In [None]:
# === Min/Max of X_rel_um, Y_rel_um + quick plot ===

# Work on a copy so site_df itself is not modified by this cell
df_mm = site_df.copy()

# If relative coords missing, derive from stage coords
# NOTE(review): this fallback uses the raw stage columns, so invertx/inverty are not applied here.
if ("X_rel_um" not in df_mm.columns) or df_mm["X_rel_um"].isna().any():
    if {"stage_x_um","stage_y_um"}.issubset(df_mm.columns) and df_mm[["stage_x_um","stage_y_um"]].notna().all(axis=None):
        x0, y0 = float(df_mm["stage_x_um"].min()), float(df_mm["stage_y_um"].min())
        df_mm["X_rel_um"] = df_mm["stage_x_um"] - x0
        df_mm["Y_rel_um"] = df_mm["stage_y_um"] - y0
    else:
        raise ValueError("No X_rel_um/Y_rel_um or stage_x_um/stage_y_um available.")

# Drop NAs and compute mins/maxes
x = df_mm["X_rel_um"].astype(float).dropna()
y = df_mm["Y_rel_um"].astype(float).dropna()
xmin, xmax = float(x.min()), float(x.max())
ymin, ymax = float(y.min()), float(y.max())

print(f"X_rel_um: min={xmin:.6f} µm, max={xmax:.6f} µm, span={xmax-xmin:.6f} µm")
print(f"Y_rel_um: min={ymin:.6f} µm, max={ymax:.6f} µm, span={ymax-ymin:.6f} µm")

# Four corners for visualization
# The Y axis is reversed further down, so ymin is drawn at the top (hence "N" = ymin).
corners = {
    "NW": (xmin, ymin),
    "NE": (xmax, ymin),
    "SW": (xmin, ymax),
    "SE": (xmax, ymax),
}

fig = go.Figure()

# Invisible trace to auto-set bounds
fig.add_trace(go.Scatter(
    x=[xmin, xmax], y=[ymin, ymax],
    mode="markers", opacity=0, showlegend=False, hoverinfo="skip"
))

# Rectangle showing the span box
fig.add_shape(
    type="rect", x0=xmin, y0=ymin, x1=xmax, y1=ymax,
    line=dict(width=3), fillcolor=None
)

# Corner markers with labels
fig.add_trace(go.Scatter(
    x=[v[0] for v in corners.values()],
    y=[v[1] for v in corners.values()],
    mode="markers+text",
    text=list(corners.keys()),
    textposition="top center",
    name="Corners"
))

fig.update_layout(
    title="Extent of relative coordinates",
    xaxis_title="X_rel_um (µm)",
    yaxis_title="Y_rel_um (µm, top-down)",
    height=500, margin=dict(l=60, r=20, t=50, b=60),
    plot_bgcolor="white"
)
fig.update_xaxes(showline=True, mirror=True, zeroline=False, ticks="outside")
# scaleanchor="x" ties the Y scale to X (1:1 aspect); the reversed range puts Y=0 at the top
fig.update_yaxes(showline=True, mirror=True, zeroline=False, ticks="outside",
                 scaleanchor="x", autorange="reversed")  # top-left origin

fig.show()

In [None]:
# === Plot all positions (X_rel_um, Y_rel_um) with bounding rectangle ===

# Work on a copy so site_df itself is not modified by this cell
D = site_df.copy()

# Ensure relative coords exist
# NOTE(review): this fallback uses the raw stage columns, so invertx/inverty are not applied here.
if ("X_rel_um" not in D.columns) or D["X_rel_um"].isna().any():
    if {"stage_x_um","stage_y_um"}.issubset(D.columns) and D[["stage_x_um","stage_y_um"]].notna().all(axis=None):
        x0, y0 = float(D["stage_x_um"].min()), float(D["stage_y_um"].min())
        D["X_rel_um"] = D["stage_x_um"] - x0
        D["Y_rel_um"] = D["stage_y_um"] - y0
    else:
        raise ValueError("Need X_rel_um/Y_rel_um or stage_x_um/stage_y_um.")

# keep rows with finite coords
D = D[np.isfinite(D["X_rel_um"]) & np.isfinite(D["Y_rel_um"])].copy()

# bounds
xmin, xmax = float(D["X_rel_um"].min()), float(D["X_rel_um"].max())
ymin, ymax = float(D["Y_rel_um"].min()), float(D["Y_rel_um"].max())

# hover labels
# Preference order: "Label" column, then "basename", else the dataframe index
name_col = "Label" if "Label" in D.columns else ("basename" if "basename" in D.columns else None)
hover_text = (D[name_col].astype(str) if name_col else D.index.astype(str))

fig = go.Figure()

# bounding rectangle
fig.add_shape(type="rect", x0=xmin, y0=ymin, x1=xmax, y1=ymax,
              line=dict(width=2), fillcolor=None)

# all points
# customdata carries the dataframe index so the hover box can show it as "idx"
fig.add_trace(go.Scattergl(
    x=D["X_rel_um"],
    y=D["Y_rel_um"],
    mode="markers",
    marker=dict(size=5, opacity=0.7),
    text=hover_text,
    hovertemplate="idx=%{customdata}<br>X=%{x:.3f} µm<br>Y=%{y:.3f} µm<br>%{text}",
    customdata=D.index,
    name="positions"
))

fig.update_layout(
    title="All positions in relative-coordinate space",
    xaxis_title="X_rel_um (µm)",
    yaxis_title="Y_rel_um (µm, top-down)",
    height=650, margin=dict(l=60, r=20, t=50, b=60),
    plot_bgcolor="white",
)

# 1:1 aspect and top-left origin
fig.update_xaxes(showline=True, mirror=True, zeroline=False, ticks="outside")
fig.update_yaxes(showline=True, mirror=True, zeroline=False, ticks="outside",
                 scaleanchor="x", autorange="reversed")

fig.show()

In [None]:
# === Plot rectangles for every position ===

# 1) Build/refresh the dataframe if you haven't already in this session
# records = build_site_records(Path(r"/media/lenr/Data/semeds/20250401 #12 (eds)/Project 1/pythondata"))
# site_df = pd.DataFrame([asdict(r) for r in records])

# Work on a copy so the extra geometry columns do not leak into site_df
D = site_df.copy()

# 2) Ensure tile sizes in µm exist (px * µm/px as fallback)
# D.get(...) returns None when the column is absent, which creates an all-missing
# column; the mask `m` then selects exactly the rows that still need a value.
if ("TileWidth_um" not in D.columns) or D["TileWidth_um"].isna().any():
    if {"width_px","px_x_um"}.issubset(D.columns):
        D["TileWidth_um"] = D.get("TileWidth_um")
        m = D["TileWidth_um"].isna()
        D.loc[m, "TileWidth_um"] = D.loc[m, "width_px"] * D.loc[m, "px_x_um"]
if ("TileHeight_um" not in D.columns) or D["TileHeight_um"].isna().any():
    if {"height_px","px_y_um"}.issubset(D.columns):
        D["TileHeight_um"] = D.get("TileHeight_um")
        m = D["TileHeight_um"].isna()
        D.loc[m, "TileHeight_um"] = D.loc[m, "height_px"] * D.loc[m, "px_y_um"]

# 3) Ensure relative stage coords in µm exist
# NOTE(review): unlike the thumbnail cells further down, this condition does not test
# Y_rel_um for NaN, and the fallback ignores invertx/inverty — confirm that is intended.
if ("X_rel_um" not in D.columns) or D["X_rel_um"].isna().any() or ("Y_rel_um" not in D.columns):
    x0, y0 = float(D["stage_x_um"].min()), float(D["stage_y_um"].min())
    D["X_rel_um"] = D["stage_x_um"] - x0
    D["Y_rel_um"] = D["stage_y_um"] - y0

# 4) Keep only rows with complete geometry
need = ["X_rel_um","Y_rel_um","TileWidth_um","TileHeight_um"]
D = D.dropna(subset=need).copy()

# 5) Vectorized rectangle geometry
# (x0, y0) is the tile's relative stage coordinate; (x1, y1) adds the tile size to it
x0 = D["X_rel_um"].to_numpy(float)
y0 = D["Y_rel_um"].to_numpy(float)
x1 = (D["X_rel_um"] + D["TileWidth_um"]).to_numpy(float)
y1 = (D["Y_rel_um"] + D["TileHeight_um"]).to_numpy(float)

# 6) Build plotly figure
fig = go.Figure()

# invisible trace → sets bounds quickly
fig.add_trace(go.Scatter(
    x=[x0.min(), x1.max()],
    y=[y0.min(), y1.max()],
    mode="markers", opacity=0, showlegend=False, hoverinfo="skip"
))

# all rectangles as layout shapes (fast for a few thousand)
fig.update_layout(
    shapes=[dict(type="rect", x0=a, y0=b, x1=c, y1=d, line=dict(width=1))
            for a, b, c, d in zip(x0, y0, x1, y1)],
    title="All tiles (top-left origin)",
    height=800, margin=dict(l=60, r=20, t=50, b=60),
    plot_bgcolor="white",
)

# 1:1 aspect, top-down Y
fig.update_xaxes(title="X (µm)", showline=True, mirror=True, zeroline=False, ticks="outside")
fig.update_yaxes(title="Y (µm, top-down)", showline=True, mirror=True, zeroline=False,
                 ticks="outside", scaleanchor="x", autorange="reversed")

fig.show()

In [None]:
# === Overlay first/last N thumbnails on the grid ===

# Work on a copy so the extra geometry columns do not leak into site_df
D = site_df.copy()

# knobs
N_FIRST = 500            # number of tiles taken from the start of the table
N_LAST  = 500            # number of tiles taken from the end of the table
THUMB_MAX_SIDE = 400     # px (downscale if either side larger)
IMG_OPACITY = 1.0        # 0..1

# --- ensure geometry (same logic as before) ---
# D.get(...) returns None when the column is absent, which creates an all-missing
# column; the mask `m` then selects exactly the rows that still need a value.
if ("TileWidth_um" not in D.columns) or D["TileWidth_um"].isna().any():
    if {"width_px","px_x_um"}.issubset(D.columns):
        D["TileWidth_um"] = D.get("TileWidth_um")
        m = D["TileWidth_um"].isna()
        D.loc[m, "TileWidth_um"] = D.loc[m, "width_px"] * D.loc[m, "px_x_um"]
if ("TileHeight_um" not in D.columns) or D["TileHeight_um"].isna().any():
    if {"height_px","px_y_um"}.issubset(D.columns):
        D["TileHeight_um"] = D.get("TileHeight_um")
        m = D["TileHeight_um"].isna()
        D.loc[m, "TileHeight_um"] = D.loc[m, "height_px"] * D.loc[m, "px_y_um"]

# NOTE(review): this fallback uses the raw stage columns, so invertx/inverty are not applied here.
if ("X_rel_um" not in D.columns) or ("Y_rel_um" not in D.columns) or D["X_rel_um"].isna().any() or D["Y_rel_um"].isna().any():
    x0, y0 = float(D["stage_x_um"].min()), float(D["stage_y_um"].min())
    D["X_rel_um"] = D["stage_x_um"] - x0
    D["Y_rel_um"] = D["stage_y_um"] - y0

# rows we can actually place + have an image path
geom_cols = ["X_rel_um","Y_rel_um","TileWidth_um","TileHeight_um","png_path"]
Dv = D.dropna(subset=geom_cols).copy()

# choose first/last without duplicating the middle if dataset < N_FIRST+N_LAST
sel = pd.concat([Dv.head(N_FIRST), Dv.tail(N_LAST)])
# head and tail overlap on small datasets; keep each index label once
sel = sel.loc[~sel.index.duplicated(keep="first")].copy()

# --- helper: load + downscale to uint8 RGB ---
def load_png_uint8(path, max_side=900):
    """Load an image as an in-memory RGB ``PIL.Image``, downscaled if needed.

    Parameters
    ----------
    path : path of the image file to open.
    max_side : longest allowed side in pixels; larger images are resized with
        bilinear resampling, preserving the aspect ratio (each side stays >= 1).

    Returns
    -------
    PIL.Image.Image in mode "RGB" that no longer references the file.
    """
    # The context manager closes the file deterministically. convert() returns
    # a new, fully loaded image, so the result stays valid after the file is
    # closed and no array round-trip is needed to detach it.
    with Image.open(path) as src:
        im = src.convert("RGB")
    width, height = im.size
    longest = max(width, height)
    if longest > max_side:
        scale = max_side / float(longest)
        new_size = (max(1, int(round(width * scale))), max(1, int(round(height * scale))))
        im = im.resize(new_size, Image.BILINEAR)
    return im

# --- base figure with all rectangles (as before) ---
# Rectangles are drawn for every placeable tile (Dv), not only the selected thumbnails
x0 = Dv["X_rel_um"].to_numpy(float)
y0 = Dv["Y_rel_um"].to_numpy(float)
x1 = (Dv["X_rel_um"] + Dv["TileWidth_um"]).to_numpy(float)
y1 = (Dv["Y_rel_um"] + Dv["TileHeight_um"]).to_numpy(float)

fig = go.Figure()
# invisible trace that makes the axes span the full tile extent
fig.add_trace(go.Scatter(x=[x0.min(), x1.max()], y=[y0.min(), y1.max()],
                         mode="markers", opacity=0, showlegend=False, hoverinfo="skip"))
fig.update_layout(
    shapes=[dict(type="rect", x0=a, y0=b, x1=c, y1=d, line=dict(width=1))
            for a, b, c, d in zip(x0, y0, x1, y1)],
    title=f"Tiles with {len(sel)} thumbnails (top-left origin)",
    height=800, margin=dict(l=60, r=20, t=50, b=60),
    plot_bgcolor="white",
)
fig.update_xaxes(title="X (µm)", showline=True, mirror=True, zeroline=False, ticks="outside")
fig.update_yaxes(title="Y (µm, top-down)", showline=True, mirror=True, zeroline=False,
                 ticks="outside", scaleanchor="x", autorange="reversed")

# --- overlay thumbnails below the grid lines ---
missed = 0
for _, r in sel.iterrows():
    try:
        img = load_png_uint8(r["png_path"], max_side=THUMB_MAX_SIDE)
        # Image is anchored at its top-left corner and stretched to the tile's size in axis units
        fig.add_layout_image(dict(
            source=img,
            x=float(r["X_rel_um"]),
            y=float(r["Y_rel_um"]),
            xref="x", yref="y",
            sizex=float(r["TileWidth_um"]),
            sizey=float(r["TileHeight_um"]),
            xanchor="left", yanchor="top",
            sizing="stretch",
            layer="below",
            opacity=IMG_OPACITY,
        ))
    except Exception:
        # Best effort: count the failure and keep going with the remaining tiles
        missed += 1
if missed:
    print(f"Skipped {missed} images that could not be read.")

fig.show()

In [None]:
# === Overlay thumbnails from NPZ 'sem_data' with uniform NPZ-based clipping ===

# Work on a copy so the extra geometry columns do not leak into site_df
D = site_df.copy()

# knobs
N_FIRST = 500         # number of tiles taken from the start of the table
N_LAST  = 500         # number of tiles taken from the end of the table
THUMB_MAX_SIDE = 400  # px (downscale if either side larger)
IMG_OPACITY = 1.0     # 0..1
USE_GLOBAL = True     # True: use NPZ global_lo/global_hi; False: per-image NPZ min/max

# --- geometry prep (unchanged) ---
# D.get(...) returns None when the column is absent, which creates an all-missing
# column; the mask `m` then selects exactly the rows that still need a value.
if ("TileWidth_um" not in D.columns) or D["TileWidth_um"].isna().any():
    if {"width_px","px_x_um"}.issubset(D.columns):
        D["TileWidth_um"] = D.get("TileWidth_um")
        m = D["TileWidth_um"].isna()
        D.loc[m, "TileWidth_um"] = D.loc[m, "width_px"] * D.loc[m, "px_x_um"]
if ("TileHeight_um" not in D.columns) or D["TileHeight_um"].isna().any():
    if {"height_px","px_y_um"}.issubset(D.columns):
        D["TileHeight_um"] = D.get("TileHeight_um")
        m = D["TileHeight_um"].isna()
        D.loc[m, "TileHeight_um"] = D.loc[m, "height_px"] * D.loc[m, "px_y_um"]

# NOTE(review): this fallback uses the raw stage columns, so invertx/inverty are not applied here.
if ("X_rel_um" not in D.columns) or ("Y_rel_um" not in D.columns) or D["X_rel_um"].isna().any() or D["Y_rel_um"].isna().any():
    x0, y0 = float(D["stage_x_um"].min()), float(D["stage_y_um"].min())
    D["X_rel_um"] = D["stage_x_um"] - x0
    D["Y_rel_um"] = D["stage_y_um"] - y0

# rows we can actually place + have NPZ path
geom_cols = ["X_rel_um","Y_rel_um","TileWidth_um","TileHeight_um","npz_path"]
Dv = D.dropna(subset=geom_cols).copy()

# choose first/last without duplicating the middle if dataset < N_FIRST+N_LAST
sel = pd.concat([Dv.head(N_FIRST), Dv.tail(N_LAST)])
# head and tail overlap on small datasets; keep each index label once
sel = sel.loc[~sel.index.duplicated(keep="first")].copy()

# NPZ global lo/hi
# The columns hold one constant per dataset, so the first row is representative.
# lo_global / hi_global are only defined when USE_GLOBAL is true.
if USE_GLOBAL:
    assert {"global_lo","global_hi"}.issubset(D.columns), "Need global_lo/global_hi in site_df"
    lo_global = float(D["global_lo"].iloc[0])
    hi_global = float(D["global_hi"].iloc[0])

# --- helpers (NPZ first, PNG fallback for rare missing files) -----------------
def load_npz_gray(path: str | Path) -> np.ndarray:
    """Return raw 2D array from *_sem.npz (key 'sem_data')."""
    with np.load(path) as z:
        a = z["sem_data"]
    if a.ndim == 3:
        a = a[..., 0]
    return a

def load_png_gray(path: str | Path) -> np.ndarray:
    """Open an image file and return it as an array in PIL mode "L".

    Images in any other mode are converted to "L" first.
    """
    from PIL import Image
    with Image.open(path) as src:
        gray = src if src.mode == "L" else src.convert("L")
        pixels = np.asarray(gray)
        return pixels

def load_raw_gray(row) -> np.ndarray:
    """Return the tile's pixel array: NPZ data when available, else the PNG.

    ``row`` must provide "npz_path" and "png_path". The PNG is used when the
    NPZ file does not exist or loading it raises any exception.
    """
    npz_file = Path(row["npz_path"])
    if npz_file.exists():
        try:
            raw = load_npz_gray(npz_file)
        except Exception:
            # fall through to the PNG below
            pass
        else:
            return raw
    # PNG fallback (already tone-mapped, but better than nothing)
    png_file = Path(row["png_path"])
    return load_png_gray(png_file)

def normalize_to_uint8(arr, lo, hi):
    """Clip to [lo,hi] then rescale to 0..255 uint8.

    Parameters
    ----------
    arr : array-like of pixel values.
    lo, hi : lower/upper clipping bounds. A non-finite bound (NaN/inf) is
        replaced by the minimum/maximum of the finite pixels in ``arr``, so
        missing global statistics cannot feed NaN into the integer cast.

    Returns
    -------
    uint8 array with the shape of ``arr``. All zeros when ``hi <= lo`` or when
    a bound had to be inferred but ``arr`` has no finite pixel. Non-finite
    pixels map to the nearest bound (NaN and -inf to ``lo``, +inf to ``hi``).
    Scaled values are truncated, not rounded, by the final cast.
    """
    a = np.asarray(arr).astype(np.float32, copy=False)
    lo, hi = float(lo), float(hi)

    if not (np.isfinite(lo) and np.isfinite(hi)):
        finite = a[np.isfinite(a)]
        if finite.size == 0:
            return np.zeros(a.shape, dtype=np.uint8)
        if not np.isfinite(lo):
            lo = float(finite.min())
        if not np.isfinite(hi):
            hi = float(finite.max())

    if hi <= lo:
        return np.zeros(a.shape, dtype=np.uint8)

    # Casting NaN/inf to uint8 is undefined, so pin such pixels to a bound first.
    a = np.nan_to_num(a, nan=lo, posinf=hi, neginf=lo)
    a = np.clip(a, lo, hi)
    a = (a - lo) * (255.0 / (hi - lo))
    return a.astype(np.uint8, copy=False)

def norm_thumb_from_row(row, max_side=400):
    """Build an RGB thumbnail for one tile row.

    Loads the tile via ``load_raw_gray``, normalizes it to 8 bit with either
    the global bounds (``USE_GLOBAL``) or the row's own ``npz_min``/``npz_max``
    (falling back to the array's min/max when those are unusable), then
    downsizes it so that neither side exceeds ``max_side``.
    """
    arr = load_raw_gray(row)

    if USE_GLOBAL:
        lo, hi = lo_global, hi_global
    else:
        # cached per-image bounds; recomputed from the pixels when not usable
        lo = float(row.get("npz_min", np.nan))
        hi = float(row.get("npz_max", np.nan))
        bounds_ok = np.isfinite(lo) and np.isfinite(hi) and hi > lo
        if not bounds_ok:
            lo, hi = float(np.min(arr)), float(np.max(arr))

    arr8 = normalize_to_uint8(arr, lo, hi)
    thumb = Image.fromarray(arr8, mode="L")
    width, height = thumb.size
    longest = max(width, height)
    if longest > max_side:
        scale = max_side / float(longest)
        target = (max(1, int(round(width*scale))), max(1, int(round(height*scale))))
        thumb = thumb.resize(target, Image.BILINEAR)
    return thumb.convert("RGB")

# --- base figure with all rectangles -----------------------------------------
# Rectangles are drawn for every placeable tile (Dv), not only the selected thumbnails
x0 = Dv["X_rel_um"].to_numpy(float)
y0 = Dv["Y_rel_um"].to_numpy(float)
x1 = (Dv["X_rel_um"] + Dv["TileWidth_um"]).to_numpy(float)
y1 = (Dv["Y_rel_um"] + Dv["TileHeight_um"]).to_numpy(float)

fig = go.Figure()
# invisible trace that makes the axes span the full tile extent
fig.add_trace(go.Scatter(x=[x0.min(), x1.max()], y=[y0.min(), y1.max()],
                         mode="markers", opacity=0, showlegend=False, hoverinfo="skip"))
fig.update_layout(
    shapes=[dict(type="rect", x0=a, y0=b, x1=c, y1=d, line=dict(width=1))
            for a, b, c, d in zip(x0, y0, x1, y1)],
    title=f"Tiles with {len(sel)} NPZ-normalized thumbnails (top-left origin)"
          + (" — global clip" if USE_GLOBAL else " — per-image NPZ clip"),
    height=800, margin=dict(l=60, r=20, t=50, b=60),
    plot_bgcolor="white",
)
fig.update_xaxes(title="X (µm)", showline=False, mirror=True, zeroline=False, ticks="outside")
fig.update_yaxes(title="Y (µm, top-down)", showline=False, mirror=True, zeroline=False,
                 ticks="outside", scaleanchor="x", autorange="reversed")

# hide the background grid so only tiles and thumbnails are visible
fig.update_xaxes(showgrid=False, zeroline=False)
fig.update_yaxes(showgrid=False, zeroline=False)

# --- overlay thumbnails below the grid lines ---------------------------------
missed = 0
for _, r in sel.iterrows():
    try:
        img = norm_thumb_from_row(r, max_side=THUMB_MAX_SIDE)
        # Image is anchored at its top-left corner and stretched to the tile's size in axis units
        fig.add_layout_image(dict(
            source=img,
            x=float(r["X_rel_um"]),
            y=float(r["Y_rel_um"]),
            xref="x", yref="y",
            sizex=float(r["TileWidth_um"]),
            sizey=float(r["TileHeight_um"]),
            xanchor="left", yanchor="top",
            sizing="stretch",
            layer="below",
            opacity=IMG_OPACITY,
        ))
    except Exception:
        # Best effort: count the failure and keep going with the remaining tiles
        missed += 1

if missed:
    print(f"Skipped {missed} thumbnails (missing/bad NPZ or PNG).")

fig.show()

In [None]:
# === Write summary_table.csv for stitch_h5data.py ===

# Work on a copy; the columns added below are written to the CSV but not to site_df
D = site_df.copy()

# 1) Ensure required columns exist / are derivable
need_xy = {"X_rel_um", "Y_rel_um"}
if not need_xy.issubset(D.columns):
    raise ValueError(f"Missing columns: {sorted(need_xy - set(D.columns))}")

# npz_path: if missing, derive from png_path (same stem, .npz)
if "npz_path" not in D.columns:
    if "png_path" not in D.columns:
        raise ValueError("Need 'npz_path' or 'png_path' to derive it.")
    D["npz_path"] = D["png_path"].map(lambda p: str(Path(p).with_suffix(".npz")))

# Z_layer: optional; default to 0
if "Z_layer" not in D.columns:
    D["Z_layer"] = 0

# 2) Make sure tile sizes are available in either form:
#    A) TileWidth_um/TileHeight_um (+ px_x_um/px_y_um)  OR  B) width_px/height_px
has_um  = {"TileWidth_um", "TileHeight_um"}.issubset(D.columns)
has_px  = {"width_px", "height_px"}.issubset(D.columns)

# If we lack µm sizes but have pixels + pixel size, compute them
if not has_um and has_px and {"px_x_um","px_y_um"}.issubset(D.columns):
    D["TileWidth_um"]  = D["width_px"]  * D["px_x_um"]
    D["TileHeight_um"] = D["height_px"] * D["px_y_um"]
    has_um = True

# If we lack pixel sizes but have µm sizes + px size, compute pixels (helps sanity checks downstream)
if not has_px and has_um and {"px_x_um","px_y_um"}.issubset(D.columns):
    D["width_px"]  = (D["TileWidth_um"]  / D["px_x_um"]).round().astype("Int64")
    D["height_px"] = (D["TileHeight_um"] / D["px_y_um"]).round().astype("Int64")
    has_px = True

# Final guard: we need at least ONE size representation resolvable
if not (has_um or has_px):
    raise ValueError(
        "summary_table.csv needs either (TileWidth_um & TileHeight_um [+ px_x_um/px_y_um]) "
        "or (width_px & height_px). Please add those to site_df."
    )

# 3) Type hygiene (helps downstream)
# errors="coerce" turns unparsable entries into NaN; "Int64" is pandas'
# nullable integer dtype, so missing values survive the integer cast.
for c in ["X_rel_um","Y_rel_um","TileWidth_um","TileHeight_um","px_x_um","px_y_um"]:
    if c in D.columns:
        D[c] = pd.to_numeric(D[c], errors="coerce")
for c in ["width_px","height_px","Z_layer"]:
    if c in D.columns:
        D[c] = pd.to_numeric(D[c], errors="coerce").astype("Int64")

# 4) Choose an output folder: use the parent of the first NPZ (fallback: CWD)
try:
    out_dir = Path(D["npz_path"].dropna().iloc[0]).parent
except Exception:
    out_dir = Path.cwd()
out_path = out_dir / "summary_table.csv"
out_dir.mkdir(parents=True, exist_ok=True)

# 5) Column order: put stitcher-important columns first, then all others
first_cols = [
    "X_rel_um", "Y_rel_um", "Z_layer", "npz_path",
    # either/both of these groups may exist; include whichever you have:
    "TileWidth_um", "TileHeight_um", "px_x_um", "px_y_um",
    "width_px", "height_px",
]
ordered = [c for c in first_cols if c in D.columns] + [c for c in D.columns if c not in first_cols]
D = D[ordered]

# 6) Write CSV (utf-8, no index)
D.to_csv(out_path, index=False)
print(f"✅ Wrote {out_path}  ({len(D)} rows, {D.shape[1]} columns)")
print("First columns:", ordered[:10])

In [None]:
# --- Write config.txt with a single px_x_um if possible ---
# Relies on `D` and `out_dir` left behind by the summary_table.csv cell.

# We expect pixel sizes to be uniform across all tiles.
px_x_vals = []
px_y_vals = []

if "px_x_um" in D.columns:
    px_x_vals = D["px_x_um"].dropna().unique()
if "px_y_um" in D.columns:
    px_y_vals = D["px_y_um"].dropna().unique()

# A single value is accepted only when the column has exactly one distinct non-missing entry
px_x_um = float(px_x_vals[0]) if len(px_x_vals) == 1 else None
px_y_um = float(px_y_vals[0]) if len(px_y_vals) == 1 else None

# config.txt is written only for square pixels (X and Y size equal within np.isclose tolerance)
if px_x_um is not None and px_y_um is not None and np.isclose(px_x_um, px_y_um):
    config_value = px_x_um

    config_path = out_dir / "config.txt"
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(f"um_per_px={config_value}\n")

    print(f"Wrote config.txt with px_x_um={config_value}")
else:
    # Also reached when a column holds more than one distinct pixel size
    print("px_x_um and px_y_um differ or are missing; not writing config.txt")
    print("Unique px_x_um:", px_x_vals)
    print("Unique px_y_um:", px_y_vals)

In [None]:
# Histogram of brightness values across ALL tiles (NPZ) — final cell
# Assumes `site_df` exists and has an `npz_path` column pointing to .npz files
# The code streams arrays tile-by-tile to avoid loading everything into memory.

# --- knobs ---
NUM_BINS = 1024       # histogram bins (increase for finer granularity)
LOG_Y    = True       # plot counts on log scale for visibility

# --- discover NPZ files from the dataframe ---
# Only string paths that point at an existing file are kept
paths = []
if "npz_path" in site_df.columns:
    for p in site_df["npz_path"].dropna().tolist():
        if isinstance(p, str) and os.path.isfile(p):
            paths.append(p)

# Nothing to histogram: stop the cell here
if not paths:
    raise SystemExit("No NPZ files found in site_df['npz_path'].")

# --- pass 1: find global min/max and total pixels ---
gmin, gmax, total_px = None, None, 0
for npz in paths:
    try:
        # np.load on an .npz returns a lazy archive that holds an open file
        # handle; the context manager closes it after every tile instead of
        # leaving one open handle per file.
        with np.load(npz) as d:
            # Prefer the 'sem_data' key; otherwise use the first array in the archive
            arr = d["sem_data"] if "sem_data" in d.files else d[d.files[0]]
        if arr.ndim == 3 and arr.shape[-1] == 1:
            arr = arr[..., 0]
        amin = int(arr.min())
        amax = int(arr.max())
        gmin = amin if gmin is None else min(gmin, amin)
        gmax = amax if gmax is None else max(gmax, amax)
        total_px += arr.size
    except Exception as e:
        print(f"[warn] could not read {npz}: {e}")

if gmin is None or gmax is None:
    raise SystemExit("Could not compute global range from NPZ files.")

# Choose bin edges across observed data range
bin_edges = np.linspace(gmin, gmax, NUM_BINS + 1, dtype=np.float64)

# --- pass 2: accumulate histogram ---
hist = np.zeros(NUM_BINS, dtype=np.int64)
for npz in paths:
    try:
        with np.load(npz) as d:
            arr = d["sem_data"] if "sem_data" in d.files else d[d.files[0]]
        if arr.ndim == 3 and arr.shape[-1] == 1:
            arr = arr[..., 0]
        h, _ = np.histogram(arr.ravel(), bins=bin_edges)
        hist += h
    except Exception as e:
        # Report the skipped tile instead of dropping it silently, so a
        # histogram built from fewer tiles than expected can be explained.
        print(f"[warn] skipped {npz} in histogram pass: {e}")

# --- derive a few useful stats from the histogram ---
# Cumulative counts normalized to end at 1.0; the `else 1.0` guard avoids
# dividing by zero when the histogram is empty.
cdf = np.cumsum(hist).astype(np.float64)
cdf /= (cdf[-1] if cdf[-1] else 1.0)

def percentile_from_hist(cdf, edges, p):
    """Approximate the p-th percentile from a normalised cumulative histogram.

    Finds the first bin whose cumulative fraction in `cdf` is at least
    ``p / 100`` and returns that bin's LEFT edge from `edges`. The bin index is
    clamped so it always refers to a valid bin (0 .. len(edges) - 2).
    """
    target_fraction = p / 100.0
    last_bin = len(edges) - 2
    bin_index = np.clip(np.searchsorted(cdf, target_fraction, side="left"), 0, last_bin)
    return edges[bin_index]

# --- summary percentiles (approximate: resolved to histogram bin edges) ---
p1  = float(percentile_from_hist(cdf, bin_edges, 1.0))
p50 = float(percentile_from_hist(cdf, bin_edges, 50.0))
p99 = float(percentile_from_hist(cdf, bin_edges, 99.0))

print(f"Global NPZ range: min={gmin}, max={gmax}   (pixels: {total_px:,})")
print(f"Approx percentiles from histogram: p1≈{int(round(p1))}, p50≈{int(round(p50))}, p99≈{int(round(p99))}")

# --- plot with low/high clipping guides at each percentage listed in `clips` ---
import matplotlib.pyplot as plt

# percent of pixels cut from EACH tail of the distribution
clips = [0.5, 1.0, 5.0]

# (percent, low cutoff, high cutoff) derived from the histogram CDF
clip_vals = [
    (
        p,
        float(percentile_from_hist(cdf, bin_edges, p)),
        float(percentile_from_hist(cdf, bin_edges, 100.0 - p)),
    )
    for p in clips
]

# base plot
centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
plt.figure(figsize=(10, 5))
plt.plot(centers, hist, drawstyle="steps-mid", label="Histogram")
if LOG_Y:
    plt.yscale("log")

# One colour per clip level. Keys are the very same float values stored in
# `clips` / `clip_vals`, so the lookup below is exact.
clip_colors = dict(zip(clips, ["#d62728", "#ff7f0e", "#2ca02c"]))  # red, orange, green


def fmt_pct(x):
    """Format a percentage compactly: 0.5 -> "0.5", 1.0 -> "1", 5.0 -> "5"."""
    return f"{x:g}"


# draw guide lines + light shading of the clipped tails
for p, lo_v, hi_v in clip_vals:
    c = clip_colors[p]
    # Only the first line of each pair carries a label, so the legend shows each
    # percentage exactly once and the entry stands for both guide lines.
    plt.axvline(lo_v, color=c, linestyle="--", linewidth=1.5, label=f"{fmt_pct(p)}% lo/hi")
    plt.axvline(hi_v, color=c, linestyle="--", linewidth=1.5)
    # axvspan covers the full axes height without adding y data, so the
    # y-limits stay where the histogram put them (re-reading ylim and filling
    # between those values made the limits grow on every iteration).
    plt.axvspan(bin_edges[0], lo_v, color=c, alpha=0.06)
    plt.axvspan(hi_v, bin_edges[-1], color=c, alpha=0.06)

plt.title("Histogram of brightness values across all tiles (NPZ)")
plt.xlabel("Raw intensity value")
plt.ylabel("Pixel count" + (" (log)" if LOG_Y else ""))
plt.legend(loc="upper right", frameon=False)

plt.tight_layout()
plt.show()

# print the thresholds (rounded to integers)
for p, lo_v, hi_v in clip_vals:
    print(f"{fmt_pct(p)}% clip → lo={int(round(lo_v))}, hi={int(round(hi_v))}")

In [None]:
def build_downsampled_mosaic_h5(
    df,
    num_tiles: int = 40,
    max_canvas_px: int = 5000,
    px_per_um_target: float | None = None,
):
    """
    Assemble a reduced-resolution RGB mosaic from the first `num_tiles` usable rows of `df`.

    Columns read from `df`:
        - TileWidth_um / TileHeight_um, or width_px * px_x_um / height_px * px_y_um
          to fill them in where they are missing
        - X_rel_um / Y_rel_um, or stage_x_um / stage_y_um (shifted so the minimum
          is 0) when the relative columns are missing or incomplete
        - npz_path (rows without it are dropped)

    Each tile image comes from norm_thumb_from_row(row, max_side=1024), which is
    expected to return a PIL image (a 2-D result is expanded to three channels).

    The canvas scale is chosen so the longer side of the bounding box is at most
    `max_canvas_px`; `px_per_um_target`, when given, is used unless it would
    exceed that limit.

    Side effects: tries to save the mosaic as PNG into the module-level `out_dir`
    (failure is only printed) and prints the canvas size.

    Returns:
        merged_rgb : uint8 array (H, W, 3)
        extent     : [xmin, xmax, ymin, ymax] in µm

    Raises:
        ValueError : missing size/npz columns, no usable tiles, or zero extent.
    """
    frame = df.copy()  # never modify the caller's dataframe

    # --- physical tile size (µm): fill gaps from pixel count × pixel pitch ---
    size_specs = (
        (
            "TileWidth_um", "width_px", "px_x_um",
            "Need either TileWidth_um or (width_px, px_x_um) in site_df.",
        ),
        (
            "TileHeight_um", "height_px", "px_y_um",
            "Need either TileHeight_um or (height_px, px_y_um) in site_df.",
        ),
    )
    for size_col, count_col, pitch_col, error_text in size_specs:
        if size_col in frame.columns and not frame[size_col].isna().any():
            continue  # column is present and complete
        if not {count_col, pitch_col}.issubset(frame.columns):
            raise ValueError(error_text)
        frame[size_col] = frame.get(size_col)
        missing = frame[size_col].isna()
        frame.loc[missing, size_col] = frame.loc[missing, count_col] * frame.loc[missing, pitch_col]

    # --- relative stage coordinates (µm), rebuilt when absent or incomplete ---
    rel_present = {"X_rel_um", "Y_rel_um"}.issubset(frame.columns)
    if not rel_present or frame["X_rel_um"].isna().any() or frame["Y_rel_um"].isna().any():
        origin_x = float(frame["stage_x_um"].min())
        origin_y = float(frame["stage_y_um"].min())
        frame["X_rel_um"] = frame["stage_x_um"] - origin_x
        frame["Y_rel_um"] = frame["stage_y_um"] - origin_y

    # --- select the tiles that have everything needed for placement ---
    if "npz_path" not in frame.columns:
        raise ValueError("Need 'npz_path' column in site_df to build mosaic.")
    required = ["X_rel_um", "Y_rel_um", "TileWidth_um", "TileHeight_um", "npz_path"]
    tiles = frame.dropna(subset=required).copy().head(num_tiles)
    if tiles.empty:
        raise ValueError("No tiles with geometry/npz info to build a mosaic from.")

    # --- bounding box of the selected tiles (µm) ---
    x_min = float(tiles["X_rel_um"].min())
    x_max = float((tiles["X_rel_um"] + tiles["TileWidth_um"]).max())
    y_min = float(tiles["Y_rel_um"].min())
    y_max = float((tiles["Y_rel_um"] + tiles["TileHeight_um"]).max())

    span_x_um = x_max - x_min
    span_y_um = y_max - y_min
    longest_um = max(span_x_um, span_y_um)
    if longest_um <= 0:
        raise ValueError("Degenerate extent (width/height in µm is zero).")

    # --- canvas scale: honour the requested scale unless it overflows the cap ---
    fit_scale = max_canvas_px / longest_um
    if px_per_um_target is None:
        px_per_um = fit_scale
    else:
        requested = float(px_per_um_target)
        px_per_um = fit_scale if longest_um * requested > max_canvas_px else requested

    canvas_w = max(1, int(round(span_x_um * px_per_um)))
    canvas_h = max(1, int(round(span_y_um * px_per_um)))
    canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)

    # --- resize every tile to its footprint and copy it onto the canvas ---
    for _, tile_row in tiles.iterrows():
        thumb = norm_thumb_from_row(tile_row, max_side=1024)

        paste_w = max(1, int(round(float(tile_row["TileWidth_um"]) * px_per_um)))
        paste_h = max(1, int(round(float(tile_row["TileHeight_um"]) * px_per_um)))

        pixels = np.asarray(thumb.resize((paste_w, paste_h), Image.BILINEAR), dtype=np.uint8)
        if pixels.ndim == 2:
            # single-channel result: replicate into three channels
            pixels = np.repeat(pixels[..., None], 3, axis=2)

        left = int(round((float(tile_row["X_rel_um"]) - x_min) * px_per_um))
        top = int(round((float(tile_row["Y_rel_um"]) - y_min) * px_per_um))
        right = min(canvas_w, left + paste_w)    # clip at the canvas border
        bottom = min(canvas_h, top + paste_h)

        if right <= left or bottom <= top:
            continue  # nothing of this tile lands on the canvas

        canvas[top:bottom, left:right, :] = pixels[: bottom - top, : right - left, :]

    merged_rgb = canvas
    extent = [x_min, x_max, y_min, y_max]

    # best-effort PNG export into out_dir; any failure is reported, not raised
    try:
        outpath = os.path.join(out_dir, f"merged_first{len(tiles)}.png")
        Image.fromarray(merged_rgb, mode="RGB").save(outpath)
        print(f"Saved downsampled mosaic to: {os.path.abspath(outpath)}")
    except Exception as e:
        print("Could not save mosaic PNG:", e)

    print(f"Canvas: {merged_rgb.shape[1]}x{merged_rgb.shape[0]} px, px/µm ≈ {px_per_um:.4f}")
    return merged_rgb, extent

# ---- build the mosaic from the h5-derived site table ----
NUM_TILES_FOR_MOSAIC = 40
merged_rgb, extent = build_downsampled_mosaic_h5(
    site_df,
    num_tiles=NUM_TILES_FOR_MOSAIC,
)

In [None]:
%matplotlib widget

# ---------- figure: mosaic shown in physical (µm) coordinates ----------
fig, ax = plt.subplots(figsize=(9, 4))
ax.set_title("Zoom (toolbar) to inspect, then LEFT-click two points → Process")
im = ax.imshow(merged_rgb, extent=extent, origin='upper')
ax.set_xlabel("X (µm)")
ax.set_ylabel("Y (µm)")
ax.set_aspect('equal')

# Empty artists that the click handler fills in: two markers and the
# segment joining them.
p1_artist = ax.plot([], [], 'o', ms=8, label="P1")[0]
p2_artist = ax.plot([], [], 'o', ms=8, label="P2")[0]
line_artist = ax.plot([], [], '-', lw=2, label="P1–P2")[0]
ax.legend(loc="upper right")

# Current selection: at most two (x_um, y_um) tuples.
picked = []

# ---------- UI: two buttons above a bordered status area ----------
btn_reset, btn_process = (
    w.Button(description=caption, button_style=look)
    for caption, look in (("Reset", "warning"), ("Process", "success"))
)
status_out = w.Output(layout={"border":"1px solid #ddd"})
display(w.HBox([btn_reset, btn_process]), status_out)

# File that receives the computed rotation angle.
outname = "rotation.txt"
outpath = os.path.join(out_dir, outname)

def _status(msg) -> None:
    """Show `msg` in the `status_out` widget in place of the previous message."""
    # The clear is issued first, before entering the widget's capture context;
    # keep this order.
    status_out.clear_output(wait=True)
    with status_out:
        # print() runs inside the `with status_out` block so the text is
        # directed to the widget.
        print(msg)

def reset(_btn=None):
    """Forget both picks, blank their artists and redraw.

    `_btn` is the optional argument supplied by the button callback; it is unused.
    """
    del picked[:]
    for artist in (p1_artist, p2_artist, line_artist):
        artist.set_data([], [])
    fig.canvas.draw_idle()
    _status("Cleared picks. Left-click two points, then press Process.")

# wire the Reset button to the handler
btn_reset.on_click(reset)

def rotation_to_horizontal_deg(angle_deg: float) -> float:
    """
    Return the rotation (deg, CCW positive) that brings a segment at absolute
    angle `angle_deg` (deg, CCW from +X) onto the horizontal, i.e. to an angle
    of 0 modulo 180.

    The result lies in (-90, +90]: a vertical segment (±90°) yields +90.
    """
    # Bring the input into [-180, 180) first.
    wrapped = (angle_deg + 180) % 360 - 180
    correction = -wrapped  # undoing the angle makes the segment horizontal
    # A segment has no direction, so corrections 180° apart are equivalent;
    # choose the one with the smaller magnitude.
    if correction > 90:
        return correction - 180
    if correction <= -90:
        return correction + 180
    return correction

def process(_btn=None):
    """Compute the levelling rotation from the two picked points and record it.

    Requires exactly two distinct entries in `picked`; otherwise a hint is shown
    in the status area and nothing is written. On success the rotation is
    written to `outpath` (overwriting), appended as ``rotation=<value>`` to
    ``config.txt`` in `out_dir`, summarised in the status area, and a JSON
    summary is offered to the browser clipboard.

    `_btn` is the optional argument supplied by the button callback; it is unused.
    """
    import json  # local import: only needed to serialise the clipboard payload

    if len(picked) != 2:
        _status("Need exactly two points. Left-click twice, then press Process.")
        return

    (x1, y1), (x2, y2) = picked
    dx, dy = x2 - x1, y2 - y1
    dist = float(np.hypot(dx, dy))
    if dist == 0:
        _status("The two points coincide. Pick two distinct points.")
        return

    # Angle of the segment in the axes' data coordinates (µm), CCW from +X.
    # NOTE(review): the mosaic is drawn with extent=[xmin, xmax, ymin, ymax], so
    # the data Y axis presumably increases upward on screen — confirm that this
    # sign convention matches whatever consumes rotation.txt / config.txt.
    angle_deg = float(np.degrees(np.arctan2(dy, dx)))
    rot_deg = float(rotation_to_horizontal_deg(angle_deg))

    # Write file
    with open(outpath, "w", encoding="utf-8") as f:
        f.write(f"{rot_deg:.6f}\n")

    # Appends on every run, so config.txt keeps one line per Process click.
    config_path = out_dir / "config.txt"
    with open(config_path, "a", encoding="utf-8") as f:
        f.write(f"rotation={rot_deg:.6f}\n")

    # Visual + status
    _status(
        f"P1=({x1:.2f}, {y1:.2f}) µm | P2=({x2:.2f}, {y2:.2f}) µm\n"
        f"Δ=({dx:.2f}, {dy:.2f}) µm  | |Δ|={dist:.2f} µm  | angle={angle_deg:.2f}°\n"
        f"→ rotate by {rot_deg:.2f}° (CCW) to make the line horizontal\n"
        f"Wrote: {outpath}"
    )

    # optional: copy to clipboard
    payload = {"P1_um":[x1,y1], "P2_um":[x2,y2], "angle_deg":angle_deg, "rotation_deg":rot_deg}
    # writeText() takes a string. The inner dumps produces the JSON text, the
    # outer one quotes it as a JavaScript string literal; embedding the dict's
    # repr would hand the browser an object, which is copied as "[object Object]".
    js_text = json.dumps(json.dumps(payload))
    try:
        display(Javascript(f"navigator.clipboard.writeText({js_text})"))
    except Exception:
        # clipboard access is a convenience only; never fail the callback for it
        pass

btn_process.on_click(process)

def onclick(event):
    """Mouse-press handler that records up to two pick points on the mosaic.

    Ignored unless the press is inside `ax`, uses the left button, and no
    toolbar mode is active. The first press sets P1, the second sets P2 and
    draws the connecting line, and a further press starts over with a new P1.
    """
    # 1 = left, 2 = middle, 3 = right
    if event.inaxes is not ax or getattr(event, "button", 1) != 1:
        return
    toolbar = getattr(fig.canvas, "toolbar", None)
    if toolbar and getattr(toolbar, "mode", ""):
        return  # a toolbar tool is active; the press belongs to it

    x, y = float(event.xdata), float(event.ydata)

    if len(picked) == 1:
        # second point: complete the pair and draw the segment
        (x_first, y_first), = picked
        picked.append((x, y))
        p2_artist.set_data([x], [y])
        line_artist.set_data([x_first, x], [y_first, y])
        _status("P1 and P2 selected. Press Process.")
    else:
        # either nothing was picked yet, or a finished pair is being replaced
        fresh_start = not picked
        picked[:] = [(x, y)]
        p1_artist.set_data([x], [y])
        p2_artist.set_data([], [])
        line_artist.set_data([], [])
        if fresh_start:
            _status(f"Picked P1=({x:.2f}, {y:.2f}) µm. Pick P2, then press Process.")
        else:
            _status(f"Restarted: P1=({x:.2f}, {y:.2f}) µm. Pick P2, then press Process.")

    fig.canvas.draw_idle()

# register the handler and show the figure as the cell output
cid = fig.canvas.mpl_connect('button_press_event', onclick)
fig

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# -------------------- loader (handles all variants we've seen) -------------------- #
def _detect_channel_axis_3d(shape):
    """
    Given a 3D shape, pick which axis is the spectral axis.

    Heuristic based on formats we've seen:
      - Channel length is typically 1024 or 2048.
      - If multiple axes match {1024, 2048}, prefer the LAST one.
      - If none match, fall back to the largest dimension.
    """
    candidates = [i for i, d in enumerate(shape) if d in (1024, 2048)]
    if len(candidates) == 1:
        return candidates[0]
    elif len(candidates) > 1:
        return max(candidates)  # prefer last occurrence (works for 1408,2048,1024)
    else:
        return int(np.argmax(shape))  # fallback

def load_eds_array(p: Path) -> np.ndarray:
    """
    Read the EDS array stored in NPZ file `p`.

    The "eds_data" entry is used when the archive has one; otherwise the first
    entry is taken.

    Returns:
      • 3D input -> the array with the axis chosen by _detect_channel_axis_3d
        moved to the end, i.e. (H, W, C). This covers both a channels-first
        layout such as (C, H, W) and an already channels-last (H, W, C) one.
      • 2D input -> returned unchanged; it is assumed to be (N_pixels, C).

    Raises:
      ValueError if the stored array is neither 2D nor 3D.
    """
    with np.load(p, allow_pickle=False) as archive:
        entry_names = archive.files
        chosen = "eds_data" if "eds_data" in entry_names else next(iter(entry_names))
        data = np.asarray(archive[chosen])

    if data.ndim not in (2, 3):
        raise ValueError(f"{p.name}: expected 2D or 3D EDS array, got shape {data.shape}")

    if data.ndim == 2:
        # flattened layout: the last axis is assumed to be the channel axis already
        return data

    spectral_axis = _detect_channel_axis_3d(data.shape)
    if spectral_axis == 2:
        return data
    return np.moveaxis(data, spectral_axis, -1)  # -> (H, W, C)


# -------------------- pick the EDS file to inspect -------------------- #
n = 10  # 1-based position of the wanted *_eds.npz file in the sorted listing

root = Path(pythondata_folder)  # or h5pythondata_folder, depending on your notebook
eds_files = sorted(root.rglob("*_eds.npz"))
print(f"Found {len(eds_files)} EDS NPZ file(s) under {root}")

if len(eds_files) == 0:
    raise RuntimeError("No *_eds.npz files found; nothing to plot.")

if n < 1 or n > len(eds_files):
    raise ValueError(f"Requested n={n} but there are only {len(eds_files)} files.")

target_path = eds_files[n - 1]
print(f"Using file #{n}: {target_path}")

# -------------------- load the data and derive two spectra -------------------- #
arr = load_eds_array(target_path)
print("Loaded array shape:", arr.shape, "dtype:", arr.dtype)

if arr.ndim not in (2, 3):
    raise ValueError(f"Unexpected ndim={arr.ndim} after loading EDS data.")

if arr.ndim == 3:
    H, W, C = arr.shape
    cy, cx = H // 2, W // 2
    # spectrum summed over every pixel of the map
    sum_spec = arr.reshape(-1, C).sum(axis=0).astype(float)
    # spectrum of the pixel at the map centre
    one_pixel = arr[cy, cx, :].astype(float)
else:
    # flattened (N, C) layout: rows are treated as pixels; the first row is
    # used as the single-pixel example
    N, C = arr.shape
    sum_spec = arr.sum(axis=0).astype(float)
    one_pixel = arr[0, :].astype(float)

channels = np.arange(C)

# -------------------- plots: map-sum first, then single pixel -------------------- #
# The pyplot state API is used on purpose; each spectrum gets its own figure.
# Add plt.yscale("log") inside the loop if a log scale is wanted.
spectrum_plots = (
    (
        sum_spec,
        "Total counts (sum over all pixels)",
        f"EDS map-sum spectrum for file #{n}\n{target_path.name}",
    ),
    (
        one_pixel,
        "Counts (single pixel)",
        f"Single-pixel spectrum – file #{n}",
    ),
)
for spectrum, y_label, plot_title in spectrum_plots:
    plt.figure(figsize=(10, 4))
    plt.plot(channels, spectrum)
    plt.xlabel("Channel index")
    plt.ylabel(y_label)
    plt.title(plot_title)
    plt.tight_layout()
    plt.show()