# In [None]:
# ============================================================
# 2-PANEL OUTPUT PER FIELD (DAPI+53BP1 + classified overlay)
#
# What this script does:
#   1) Finds one "anchor" DAPI image per field-of-view (FOV), even if DAPI is not at Z0010.
#   2) Builds max-intensity projections (MIPs) across Z for:
#        - AF488 (53BP1 visualization)
#        - AF594 and AF647 (CRITICAL: classification uses these)
#   3) Segments nuclei from the anchor DAPI image (Cellpose if available; Otsu fallback).
#   4) Classifies nuclei using per-well AF594/AF647 cutoffs (exclusive calls):
#        - AF594-only => purple
#        - AF647-only => red
#        - otherwise => grayscale background
#   5) Saves a 2-panel PNG per FOV and logs metrics to a CSV.
#
# Notes:
#   - This script is meant for generating readable panels, not quantitative normalization.
# ============================================================

import os, re, glob, csv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tifffile import imread

from skimage.filters import threshold_otsu
from skimage.morphology import remove_small_holes, remove_small_objects, dilation, disk
from skimage.segmentation import clear_border
from skimage.measure import label

# ----------------------------
# Optional Cellpose (recommended)
# ----------------------------
try:
    from cellpose import models
except Exception:
    # Any import-time failure (package missing, broken backend, ...) simply
    # disables Cellpose; segmentation then relies on the Otsu path.
    USE_CELLPOSE = False
else:
    USE_CELLPOSE = True

# ----------------------------
# Config
# ----------------------------
# When True, every anchor is reprocessed and a new row is appended to the
# metrics CSV on each run (rows from earlier runs are not removed).
FORCE_RERUN = True  # set False to skip anchors already processed

# Root folder searched recursively for DAPI anchor images.
data_dir = "/content/drive/My Drive/representative images"
# Per-well cutoff table; must provide plate, well, channel and cutoff_linear columns.
cutoff_csv_path = "/content/drive/My Drive/per_well_cutoffs_all_plates_AF594_AF647_0.975_WITH_secondary_floors_high.csv"

# Destination for panels and bookkeeping files (created as soon as this cell runs).
output_dir = "/content/drive/My Drive/Output_MIP_panel_1"
os.makedirs(output_dir, exist_ok=True)

# Resume bookkeeping: one CSV row and one done-list line per processed anchor.
csv_out_path   = os.path.join(output_dir, "panel_metrics.csv")
done_list_path = os.path.join(output_dir, "panels_done.txt")

# Plates to process; compared with the leading integer of each image filename.
TARGET_PLATES = {934}

# DAPI display window (fixed): (vmin, vmax) used to clip and rescale DAPI for display
DISP_DAPI = (500, 2500)

# ----------------------------
# AF488 display parameters (NOT aggressive)
#   - percentiles for windowing
#   - gamma=1.0 means no extra darkening
# ----------------------------
AF488_P_LO  = 1.0
AF488_P_HI  = 99.7
AF488_GAMMA = 1.0

# Z-range used for max projections
Z_RANGE = range(0, 21)  # Z0000..Z0020 inclusive

# Channel key -> (channel token, channel name); the pair is inserted into
# filenames as "<token>(<name>)", e.g. "C01(AF488)".
CHANNEL_TAGS = {
    "DAPI":  ("C00", "DAPI"),
    "AF488": ("C01", "AF488"),  # 53BP1
    "AF594": ("C02", "AF594"),
    "AF647": ("C03", "AF647"),
}

# CSV output columns (also the order in which they are written)
COLS_ORDER = [
    "image_anchor", "panel_path", "plate", "well",
    "cutoff_594", "cutoff_647",
    "af488_vmin", "af488_vmax", "af488_gamma",
    "n_nuclei", "n_594_only", "n_647_only", "n_unclassified"
]

# ============================================================
# Resumable helpers
# ============================================================
def load_completed_sets():
    """
    Collect anchors already processed based on CSV + text log.

    Reads the "image_anchor" column of the metrics CSV (if the file and the
    column exist) and every non-empty line of the done list (if it exists).
    Both sources are best-effort: an unreadable or malformed file is reported
    on stdout and skipped instead of aborting the run.

    Returns:
      set of anchor identifiers (strings) found in either source.
    """
    completed = set()

    if os.path.exists(csv_out_path):
        try:
            df_prev = pd.read_csv(csv_out_path)
        except (OSError, ValueError) as exc:
            # pandas parser / empty-file errors derive from ValueError.
            print(f"[resume] Could not read {csv_out_path}: {exc}")
        else:
            if "image_anchor" in df_prev.columns:
                # Drop missing values so a blank cell does not become the string "nan".
                completed.update(df_prev["image_anchor"].dropna().astype(str))

    if os.path.exists(done_list_path):
        try:
            with open(done_list_path, "r") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        completed.add(line)
        except (OSError, ValueError) as exc:
            # UnicodeDecodeError derives from ValueError.
            print(f"[resume] Could not read {done_list_path}: {exc}")

    return completed


def append_row_to_csv(row_dict):
    """
    Append one row to the metrics CSV + mark the anchor as done.

    Args:
      row_dict: mapping of column name -> value; keys missing from COLS_ORDER
                order are written as empty cells, extra keys are ignored.

    The header is written when the CSV does not exist yet OR exists but is
    empty; a header-less file would otherwise hide the "image_anchor" column
    from later reads of the CSV.
    """
    row = {k: row_dict.get(k, None) for k in COLS_ORDER}
    write_header = (not os.path.exists(csv_out_path)) or os.path.getsize(csv_out_path) == 0

    with open(csv_out_path, "a", newline="") as f:
        w = csv.DictWriter(f, fieldnames=COLS_ORDER)
        if write_header:
            w.writeheader()
        w.writerow(row)

    # The done list is written after the CSV row so an anchor is only marked
    # complete once its metrics have been stored.
    with open(done_list_path, "a") as f:
        f.write(f"{row['image_anchor']}\n")


# ============================================================
# Plate / well parsing
# ============================================================
def parse_plate_from_name(fname):
    """
    Return the plate number encoded at the start of a filename.

    The plate is the run of digits that opens the basename and is followed
    by an underscore; None is returned when the basename has no such prefix.
    """
    file_name = os.path.basename(fname)
    found = re.match(r"^(\d+)_", file_name)
    if found is None:
        return None
    return int(found.group(1))


def parse_folder_well_from_path(path):
    """
    Extract the well ID (row letter A-H plus 1-2 digits) from a path.

    Path components are scanned from the innermost outwards, skipping any
    component that ends in ".tif"; the first component whose last
    underscore-separated token looks like a well wins, e.g.:
      .../934_Gamma_1_4hr_CD4_CD19_B11/...  ->  "B11"
    If no component qualifies, a "_<well>_" token inside the basename is used.
    Returns None when neither source yields a well.
    """
    unix_path = path.replace("\\", "/")

    for segment in reversed(unix_path.split("/")):
        if segment.lower().endswith(".tif"):
            continue
        tail = segment.rsplit("_", 1)[-1]
        if re.fullmatch(r"[A-H]\d{1,2}", tail):
            return tail

    # Nothing usable in the folder names: look for _B11_ style token in the file name.
    found = re.search(r"_([A-H]\d{1,2})_", os.path.basename(unix_path))
    if found is None:
        return None
    return found.group(1)


# ============================================================
# Image IO helpers
# ============================================================
def derive_base_prefix(example_path):
    """
    Return the per-field path prefix of an image file.

    The trailing "_Z####_C##(Name)_M0000_ORG.tif" is removed from the
    basename, leaving everything that identifies the field (e.g. the S####
    token). When the basename does not follow that pattern, the last three
    underscore-separated tokens are dropped instead. The directory part of
    the input is kept.
    """
    folder, file_name = os.path.split(example_path)
    suffix = re.search(r"_Z\d{4}_C\d{2}\([^)]+\)_M0000_ORG\.tif$", file_name)

    if suffix is None:
        tokens = file_name.split("_")
        stem = "_".join(tokens[:-3])
    else:
        stem = file_name[:suffix.start()]

    return os.path.join(folder, stem)


def read_single_z_from_anchor(anchor_path, ch_tag, ch_name):
    """
    Load one channel at the anchor's own Z plane.

    For DAPI (C00) the anchor file itself is read. For any other channel the
    "C00(DAPI)" token in the anchor's basename is replaced by
    "<ch_tag>(<ch_name>)" and that sibling file is read.

    Raises:
      FileNotFoundError: the resolved file does not exist.
    """
    wants_anchor_itself = (ch_tag == "C00") and (ch_name == "DAPI")

    if wants_anchor_itself:
        target = anchor_path
    else:
        folder = os.path.dirname(anchor_path)
        file_name = os.path.basename(anchor_path)
        swapped = re.sub(r"C00\(DAPI\)", f"{ch_tag}({ch_name})", file_name)
        target = os.path.join(folder, swapped)

    if os.path.exists(target):
        return imread(target)

    raise FileNotFoundError(f"No file found for derived path: {target}")


def read_channel_maxproj_from_anchor(anchor_path, ch_tag, ch_name, z_range):
    """
    Max-intensity projection for a channel stack at the SAME field-of-view.

    Args:
      anchor_path: path of the field's anchor DAPI image.
      ch_tag, ch_name: channel token and name used in filenames, e.g. ("C02", "AF594").
      z_range: iterable of integer Z indices; missing planes are skipped.

    Returns:
      Element-wise maximum over all planes that exist on disk. If no planes
      are found, falls back to the single-Z image derived from the anchor.

    Raises:
      ValueError: two planes of the same stack have different shapes.
      FileNotFoundError: no plane exists and the single-Z fallback is missing too.
    """
    base_prefix = derive_base_prefix(anchor_path)
    mip = None

    for z in z_range:
        fp = f"{base_prefix}_Z{z:04d}_{ch_tag}({ch_name})_M0000_ORG.tif"
        if not os.path.exists(fp):
            continue

        plane = imread(fp)
        if mip is None:
            mip = plane
            continue

        # Running maximum: only the current plane and the accumulator are
        # held in memory instead of the whole Z stack.
        if plane.shape != mip.shape:
            raise ValueError(
                f"Plane shape {plane.shape} differs from earlier planes {mip.shape}: {fp}"
            )
        mip = np.maximum(mip, plane)

    if mip is not None:
        return mip

    return read_single_z_from_anchor(anchor_path, ch_tag, ch_name)


# ============================================================
# Anchor selection: one DAPI per field (any-Z)
# ============================================================
def choose_anchor_dapi_per_field(root, target_plates=None, prefer_z=10):
    """
    Find all DAPI images at any Z, group by field, choose one anchor per field.

    Args:
      root: directory searched recursively.
      target_plates: optional iterable of plate numbers; when given, only
                     files whose name starts with one of these plates are kept.
      prefer_z: Z index used as anchor whenever the field has that plane;
                otherwise the middle entry of the field's sorted Z list is used.

    Returns:
      Sorted list of anchor file paths, one per field.
    """
    # Escape the root so glob metacharacters in directory names ("[", "*", "?")
    # are matched literally; only the trailing pattern is a real wildcard.
    patt = os.path.join(glob.escape(root), "**", "*_Z????_C00(DAPI)_M0000_ORG.tif")
    all_dapi = sorted(glob.glob(patt, recursive=True))

    if target_plates is not None:
        tset = set(target_plates)
        all_dapi = [p for p in all_dapi if parse_plate_from_name(p) in tset]

    z_re = re.compile(r"_Z(\d{4})_C00\(DAPI\)_M0000_ORG\.tif$")

    by_field = {}  # base_prefix -> set(Z)
    for fp in all_dapi:
        mz = z_re.search(os.path.basename(fp))
        if not mz:
            # "Z????" in the glob also matches non-numeric tokens; ignore those.
            continue
        by_field.setdefault(derive_base_prefix(fp), set()).add(int(mz.group(1)))

    anchors = []
    for base, zs in by_field.items():
        zs_sorted = sorted(zs)
        z_pick = prefer_z if prefer_z in zs else zs_sorted[len(zs_sorted) // 2]
        anchor = f"{base}_Z{z_pick:04d}_C00(DAPI)_M0000_ORG.tif"
        if os.path.exists(anchor):
            anchors.append(anchor)

    return sorted(anchors)


# ============================================================
# Per-well cutoffs
# ============================================================
def load_per_well_cutoffs(path):
    """
    Read the per-well cutoff table and normalise its columns.

    Column names are stripped and lower-cased. "plate" becomes a nullable
    integer, "well" and "channel" become stripped strings, and
    "cutoff_linear" becomes numeric (unparseable entries turn into NaN).

    Raises:
      KeyError: one of plate / well / channel / cutoff_linear is missing.
    """
    table = pd.read_csv(path, low_memory=False)
    table.columns = [name.strip().lower() for name in table.columns]

    need = {"plate", "well", "channel", "cutoff_linear"}
    if need - set(table.columns):
        raise KeyError(f"Per-well cutoff CSV must have {need}, got {set(table.columns)}")

    table["plate"] = pd.to_numeric(table["plate"], errors="coerce").astype("Int64")
    for text_col in ("well", "channel"):
        table[text_col] = table[text_col].astype(str).str.strip()
    table["cutoff_linear"] = pd.to_numeric(table["cutoff_linear"], errors="coerce")

    return table


# Cutoff table loaded once when this cell runs (reads cutoff_csv_path from disk);
# process_one looks up each field's AF594/AF647 cutoffs in it.
CUTOFFS = load_per_well_cutoffs(cutoff_csv_path)


def get_cutoffs_for_plate_well(dfc, plate, well):
    """
    Look up the (AF594 cutoff, AF647 cutoff) pair for one well.

    Rows are matched on the string form of the plate and on the well. Within
    the matching rows, a channel is recognised by containing "594" / "647"
    (case-insensitive) and the first non-missing cutoff is used.
    Either element is None when it cannot be determined; (None, None) is
    returned when plate or well is None or no row matches.
    """
    if plate is None or well is None:
        return (None, None)

    same_plate = dfc["plate"].astype(str) == str(plate)
    same_well = dfc["well"] == str(well)
    rows = dfc[same_plate & same_well]
    if rows.empty:
        return (None, None)

    def first_cutoff(tag):
        hits = rows["channel"].str.contains(tag, case=False, na=False)
        values = rows.loc[hits, "cutoff_linear"].dropna()
        if len(values):
            return float(values.iloc[0])
        return None

    return (first_cutoff("594"), first_cutoff("647"))


# ============================================================
# Display scaling (simple, standard)
# ============================================================
def clip_and_scale(img, vmin, vmax):
    """
    Map an image onto [0, 1] using a fixed display window.

    Values are clipped to [vmin, vmax] and then rescaled linearly; a tiny
    epsilon in the denominator avoids division by zero when vmax == vmin.
    The result is a float array.
    """
    span = vmax - vmin + 1e-9
    clipped = np.clip(np.asarray(img, dtype=float), vmin, vmax)
    scaled = (clipped - vmin) / span
    return np.clip(scaled, 0, 1)


def robust_window(img, p_lo=1.0, p_hi=99.8):
    """
    Compute a percentile window for display.

    Args:
      img: array-like image.
      p_lo, p_hi: lower / upper percentiles (0-100).

    Returns:
      (vmin, vmax) as Python floats with vmax > vmin, both always finite.

    Only finite pixels are considered, so NaN and +/-inf can never leak into
    the window. Fallbacks, in order:
      - percentile window is degenerate -> use the finite min/max
      - still degenerate (constant image) -> (value, value + 1)
      - no finite pixels at all (empty or all-NaN input) -> (0, 1)
    """
    x = np.asarray(img, dtype=float)
    finite = x[np.isfinite(x)]

    if finite.size == 0:
        return 0.0, 1.0

    vmin, vmax = (float(v) for v in np.percentile(finite, [p_lo, p_hi]))

    if vmax <= vmin:
        vmin, vmax = float(finite.min()), float(finite.max())

    if vmax <= vmin:
        vmax = vmin + 1.0

    return vmin, vmax


def two_color_blue_green(dapi, af488, dapi_window, p_lo, p_hi, gamma=1.0):
    """
    Compose the DAPI + AF488 display image.

    DAPI is scaled with the fixed window `dapi_window` and shown in blue;
    AF488 is scaled with a percentile window (p_lo, p_hi) and shown in green,
    optionally raised to `gamma` (skipped when gamma is None or 1.0). The red
    channel is left at zero.

    Returns:
      (rgb, vmin488, vmax488) where rgb has the channels on the last axis and
      vmin488 / vmax488 is the AF488 window that was applied.
    """
    blue = clip_and_scale(dapi, *dapi_window)

    vmin488, vmax488 = robust_window(af488, p_lo=p_lo, p_hi=p_hi)
    green = clip_and_scale(af488, vmin488, vmax488)

    needs_gamma = gamma is not None and float(gamma) != 1.0
    if needs_gamma:
        green = np.power(green, float(gamma))

    channels = [np.zeros_like(blue), green, blue]  # R, G, B
    return np.stack(channels, axis=-1), vmin488, vmax488


def grayscale_bg(two_col):
    """
    Collapse an RGB panel into a gray RGB image used as overlay background.

    The gray level is 0.25*R + 0.60*G + 0.15*B, clipped to [0, 1], and is
    repeated across three output channels.
    """
    red, green, blue = two_col[..., 0], two_col[..., 1], two_col[..., 2]
    gray = np.clip(0.25 * red + 0.60 * green + 0.15 * blue, 0, 1)
    return np.repeat(gray[..., np.newaxis], 3, axis=-1)


def overlay_classification(bg_rgb, mask594, mask647, alpha=0.92):
    """
    Paint classified pixels on top of the background image.

      - pixels in mask594 but not mask647 => purple (R + B)
      - pixels in mask647 but not mask594 => red (R)
      - every other pixel keeps the background value

    Painted pixels are an alpha blend of the colour and the background
    (alpha = colour weight). The input image is not modified.
    """
    only_594 = mask594 & (~mask647)
    only_647 = mask647 & (~mask594)

    base = bg_rgb.copy()

    tint = np.zeros_like(base)
    tint[..., 2] = only_594.astype(float)                          # B: purple only
    tint[..., 0] = only_647.astype(float) + only_594.astype(float)  # R: red and purple

    blended = (1 - alpha) * base + alpha * tint
    painted = (only_647 | only_594)[..., None]
    return np.where(painted, blended, base)


# ============================================================
# Segmentation
# ============================================================
# Cellpose model created on first use and reused for every later field, so
# the model is not rebuilt once per image.
_CELLPOSE_MODEL = None


def _clean_nuclei_mask(mask):
    """Shared binary clean-up: drop specks, fill holes, drop border objects, dilate by 1 px."""
    mask = remove_small_objects(mask, 64)
    mask = remove_small_holes(mask, 64)
    mask = clear_border(mask)
    return dilation(mask, footprint=disk(1))


def segment_nuclei(dapi_img):
    """
    Segment nuclei mask from DAPI using Cellpose (if available) or Otsu.

    Args:
      dapi_img: 2-D DAPI image (assumed single plane — the caller passes the anchor Z).

    Returns:
      Boolean foreground mask after clean-up. Note that the Cellpose instance
      labels are collapsed to a binary mask here, so nuclei that touch are no
      longer separable downstream.

    If Cellpose is enabled but fails for any reason, the failure is printed
    and the Otsu threshold path is used for this image instead.
    """
    global _CELLPOSE_MODEL

    if USE_CELLPOSE:
        try:
            if _CELLPOSE_MODEL is None:
                _CELLPOSE_MODEL = models.Cellpose(gpu=False, model_type="nuclei")
            masks, *_ = _CELLPOSE_MODEL.eval(
                [dapi_img],
                channels=[0, 0],
                diameter=None,
                flow_threshold=0.4,
                cellprob_threshold=0.0,
            )
            return _clean_nuclei_mask(masks[0] > 0)
        except Exception as exc:
            # Deliberate best-effort: keep going with Otsu, but say so — the
            # two methods give visibly different masks.
            print(
                f"[segment_nuclei] Cellpose failed ({type(exc).__name__}: {exc}); "
                f"falling back to Otsu."
            )

    thr = threshold_otsu(dapi_img.astype(np.float32))
    return _clean_nuclei_mask(dapi_img > thr)


# ============================================================
# Classification (per nucleus mean; exclusive)
# ============================================================
def classify_cells_by_channel(nuc_mask, af594, af647, cutoff594, cutoff647):
    """
    Classify each connected nucleus by its mean AF594 / AF647 intensity.

    For each connected component of `nuc_mask`:
      - compute mean(AF594) and mean(AF647) over its pixels
      - compare with the cutoffs (strictly greater = positive)
      - positive for exactly one channel => that channel's "only" class,
        otherwise (both or neither) => unclassified

    Args:
      nuc_mask: boolean nuclei mask.
      af594, af647: intensity images with the same shape as nuc_mask.
      cutoff594, cutoff647: per-well cutoffs; if either is None or non-finite,
                            every nucleus is reported as unclassified.

    Returns:
      mask594_only, mask647_only, n_nuclei, n594, n647, n_unclassified

    Raises:
      ValueError: an intensity image does not match the mask's shape.

    Per-label means are computed in one pass with np.bincount rather than one
    full-image comparison per nucleus. For floating-point images the means may
    differ from a per-nucleus mean in the last bits (different summation order).
    """
    labels = label(nuc_mask)
    nlab = int(labels.max()) if labels.size else 0

    mask594_only = np.zeros_like(nuc_mask, dtype=bool)
    mask647_only = np.zeros_like(nuc_mask, dtype=bool)

    cutoffs_ok = all(c is not None and np.isfinite(c) for c in (cutoff594, cutoff647))
    if not cutoffs_ok:
        return mask594_only, mask647_only, nlab, 0, 0, nlab
    if nlab == 0:
        return mask594_only, mask647_only, 0, 0, 0, 0

    flat_labels = labels.ravel()
    # counts[k] = number of pixels carrying label k (k = 0 is background).
    counts = np.bincount(flat_labels, minlength=nlab + 1)

    def mean_per_label(img, name):
        img = np.asarray(img)
        if img.shape != labels.shape:
            raise ValueError(f"{name} shape {img.shape} does not match nuclei mask {labels.shape}")
        sums = np.bincount(flat_labels, weights=img.ravel().astype(float), minlength=nlab + 1)
        with np.errstate(divide="ignore", invalid="ignore"):
            return sums / counts

    mean594 = mean_per_label(af594, "af594")
    mean647 = mean_per_label(af647, "af647")

    # Only real nuclei take part: exclude background and any label without pixels.
    present = counts > 0
    present[0] = False

    is594 = mean594 > cutoff594
    is647 = mean647 > cutoff647
    only594 = present & is594 & ~is647
    only647 = present & is647 & ~is594

    # The per-label flags double as lookup tables indexed by label id.
    mask594_only = only594[labels]
    mask647_only = only647[labels]

    n_594 = int(only594.sum())
    n_647 = int(only647.sum())
    n_un = int(present.sum()) - n_594 - n_647

    return mask594_only, mask647_only, nlab, n_594, n_647, n_un


# ============================================================
# Process one anchor -> save 1x2 panel
# ============================================================
def process_one(anchor_path, outdir):
    """
    Build and save the 1x2 panel for one field, and return its metrics row.

    Panels:
      - Left: DAPI + AF488 (MIP) panel
      - Right: classification overlay (AF594/AF647 MIPs + per-well cutoffs)

    Args:
      anchor_path: path of the field's anchor DAPI image (single Z).
      outdir: directory that receives the PNG.

    Returns:
      dict with one value per metrics CSV column (paths are stored relative
      to data_dir).

    The figure is always closed, even if plotting or saving fails, so failed
    fields do not accumulate open figures while the run continues.
    """
    # Anchor DAPI (single Z)
    dapi = read_single_z_from_anchor(anchor_path, *CHANNEL_TAGS["DAPI"])

    # Visualization channel: AF488 MIP
    af488 = read_channel_maxproj_from_anchor(anchor_path, *CHANNEL_TAGS["AF488"], z_range=Z_RANGE)

    # Classification channels: AF594 + AF647 MIPs (important!)
    af594 = read_channel_maxproj_from_anchor(anchor_path, *CHANNEL_TAGS["AF594"], z_range=Z_RANGE)
    af647 = read_channel_maxproj_from_anchor(anchor_path, *CHANNEL_TAGS["AF647"], z_range=Z_RANGE)

    # Plate / well + cutoffs
    plate = parse_plate_from_name(anchor_path)
    well  = parse_folder_well_from_path(anchor_path)
    cutoff_594, cutoff_647 = get_cutoffs_for_plate_well(CUTOFFS, plate, well)
    has_cutoffs = (
        cutoff_594 is not None and cutoff_647 is not None
        and np.isfinite(cutoff_594) and np.isfinite(cutoff_647)
    )

    # Segment nuclei
    nuc = segment_nuclei(dapi)

    # Classify nuclei
    mask594, mask647, n_nuc, n594, n647, nun = classify_cells_by_channel(
        nuc, af594, af647, cutoff_594, cutoff_647
    )

    # Build display panels
    two, vmin488, vmax488 = two_color_blue_green(
        dapi, af488,
        dapi_window=DISP_DAPI,
        p_lo=AF488_P_LO, p_hi=AF488_P_HI,
        gamma=AF488_GAMMA
    )
    bg = grayscale_bg(two)
    over = overlay_classification(bg, mask594, mask647, alpha=0.92)

    # Output name: path relative to data_dir, flattened with "__", minus the channel suffix.
    anchor_rel = os.path.relpath(anchor_path, data_dir)
    base_bn = anchor_rel.replace("/", "__").replace("_C00(DAPI)_M0000_ORG.tif", "")
    out_path = os.path.join(outdir, f"{base_bn}_two_panel_classified.png")

    # Plot + save
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    try:
        axes[0].imshow(two)
        axes[0].set_title(
            f"DAPI (blue) + 53BP1 AF488 (green)\n"
            f"AF488 window p{AF488_P_LO}-p{AF488_P_HI}, gamma={AF488_GAMMA}"
        )

        axes[1].imshow(over)
        if has_cutoffs:
            axes[1].set_title(
                f"Classified (per-well cutoffs)\n"
                f"AF594-only purple | AF647-only red | unclassified grayscale\n"
                f"plate={plate} well={well}  594>{cutoff_594:.2f}  647>{cutoff_647:.2f}"
            )
        else:
            axes[1].set_title(f"Classified (per-well cutoffs)\nplate={plate} well={well} (cutoffs missing)")

        for ax in axes:
            ax.axis("off")

        fig.tight_layout()
        fig.savefig(out_path, dpi=220, bbox_inches="tight")
    finally:
        # Release the figure no matter what happened above.
        plt.close(fig)

    return {
        "image_anchor": anchor_rel,
        "panel_path": os.path.relpath(out_path, data_dir),
        "plate": plate,
        "well": well,
        "cutoff_594": cutoff_594,
        "cutoff_647": cutoff_647,
        "af488_vmin": float(vmin488),
        "af488_vmax": float(vmax488),
        "af488_gamma": float(AF488_GAMMA),
        "n_nuclei": int(n_nuc),
        "n_594_only": int(n594),
        "n_647_only": int(n647),
        "n_unclassified": int(nun),
    }


# ============================================================
# Main run
# ============================================================
# One anchor DAPI image per field, restricted to TARGET_PLATES.
anchors = choose_anchor_dapi_per_field(data_dir, target_plates=TARGET_PLATES, prefer_z=10)
print(f"Found {len(anchors)} field anchors for plates {sorted(TARGET_PLATES)}.")

# Anchors recorded by earlier runs (metrics CSV + done list).
completed = load_completed_sets()
processed = skipped = 0

for p in anchors:
    # Anchors are identified by their path relative to data_dir.
    anchor_rel = os.path.relpath(p, data_dir)

    # Expected panel file name; built the same way as inside process_one.
    rel_for_name = anchor_rel.replace("/", "__")
    base_bn = rel_for_name.replace("_C00(DAPI)_M0000_ORG.tif", "")
    fig_path = os.path.join(output_dir, f"{base_bn}_two_panel_classified.png")

    # Skip only when reruns are not forced, the anchor was logged as done
    # AND its panel file is still on disk.
    if (not FORCE_RERUN) and (anchor_rel in completed) and os.path.exists(fig_path):
        skipped += 1
        continue

    try:
        row = process_one(p, output_dir)
        append_row_to_csv(row)
        processed += 1
        print(f"✔ {anchor_rel}")
    except Exception as e:
        # A failing field is reported and the loop moves on to the next one;
        # failures are counted neither as processed nor as skipped.
        print(f"✖ {anchor_rel} -> {e}")

print(f"\nDone. Processed: {processed} | Skipped: {skipped}")
print(f"Panels -> {output_dir}")
print(f"Metrics CSV -> {csv_out_path}")
print(f"Done list -> {done_list_path}")
