# In [None]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import tkinter.font as tkfont
import os, re, random, cv2
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import segmentation_models_pytorch as smp
from sklearn.metrics import jaccard_score, mean_squared_error
from sklearn.linear_model import RANSACRegressor, LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from skimage.restoration import denoise_nl_means, estimate_sigma
from skimage import img_as_float

# ---------------------------
# Utilities
# ---------------------------
def natural_sort_key(s):
    """Return a sort key that orders embedded runs of digits numerically.

    The string is split around runs of ``0-9``; digit chunks become ``int`` and
    every other chunk is lower-cased, so ``"img2"`` sorts before ``"img10"``
    and the comparison ignores case.
    """
    key = []
    for chunk in re.split(r'([0-9]+)', s):
        if chunk.isdigit():
            key.append(int(chunk))
        else:
            key.append(chunk.lower())
    return key

# ---------------------------
# Boundary fitting & overlay
# ---------------------------
def fit_and_plot_surfaces_v2(
    image_np,
    boundary_mask,
    image_name="Image",
    save_path=None,
    initial_residual=30,
    final_residual=25,
    min_points=10,
    gt_mask=None,
    show_individual_fits=False,
    show_final_overlay=True,
    min_segment_length=10,
    max_gap=4
):
    """Fit cubic curves to the cornea/lens surfaces of a label mask and report metrics.

    Four surfaces are handled: the top and bottom edge of class 1 ("Cornea")
    and of class 2 ("Lens") in ``boundary_mask``. For each one, the processing is:

    1. Take the first/last labelled row of every column.
    2. Drop short or jumpy runs of points.
    3. Reject outliers with a RANSAC cubic (``PolynomialFeatures(3)`` +
       ``LinearRegression``).
    4. Fit a cubic with ``np.polyfit`` on the inliers and sample it at 300 points.

    When ``gt_mask`` is given and not all zero, the same extraction is applied
    to it. The result is used for a per-surface MSE against the fitted curve
    and for a GT thickness.

    Args:
        image_np: background image. It is stacked three times along a new last
            axis to build the RGB overlay, so presumably a 2-D grayscale array
            whose dtype accepts 0-255 colour values — TODO confirm with callers.
        boundary_mask: 2-D predicted label mask (``h, w = boundary_mask.shape``).
            Value 1 is treated as cornea, value 2 as lens.
        image_name: prefix for figure titles and saved file names.
        save_path: folder that figures are saved into (created if needed).
            ``None`` disables saving.
        initial_residual: RANSAC ``residual_threshold``.
        final_residual: NOTE(review): not used anywhere in this function.
        min_points: minimum number of filtered points needed to fit a surface.
            Also used as RANSAC ``min_samples``.
        gt_mask: optional ground-truth label mask with the same class values.
        show_individual_fits: show (and optionally save) one figure per surface.
        show_final_overlay: show (and optionally save) the combined overlay.
        min_segment_length: shortest run of continuous boundary points that is kept.
        max_gap: neighbouring points count as continuous when the column step
            is at most ``max_gap + 1`` and the row jump at most ``max_gap``.

    Returns:
        dict with three entries:

        - ``'boundary_metrics'``: per surface, ``mse`` (or ``None``), inlier
          count ``points``, the dense fit ``x_dense``/``y_dense`` and the
          overlay colours.
        - ``'thickness_metrics'``: mean/std thickness for ``'Cornea'`` and
          ``'Lens'``, predicted and GT, in the mask's row units; ``None`` when
          unavailable.
        - ``'final_overlay'``: the overlay array (image stacked to 3 channels
          with curves drawn in).
    """
    h, w = boundary_mask.shape
    final_overlay = np.stack([image_np]*3, axis=-1).copy()

    # Per-surface results; colours are RGB triplets painted into final_overlay.
    boundary_metrics = {
        'Cornea Top':    {'mse': None, 'points': 0, 'x_dense': [], 'y_dense': [], 'color_pred': [255, 0, 0],   'color_gt': [0, 255, 0]},
        'Cornea Bottom': {'mse': None, 'points': 0, 'x_dense': [], 'y_dense': [], 'color_pred': [255, 0, 0],   'color_gt': [0, 255, 0]},
        'Lens Top':      {'mse': None, 'points': 0, 'x_dense': [], 'y_dense': [], 'color_pred': [255, 165, 0], 'color_gt': [0, 0, 255]},
        'Lens Bottom':   {'mse': None, 'points': 0, 'x_dense': [], 'y_dense': [], 'color_pred': [255, 165, 0], 'color_gt': [0, 0, 255]}
    }

    thickness_metrics = {
        'Cornea': {'pred': {'mean': None, 'std': None}, 'gt': {'mean': None, 'std': None}},
        'Lens':   {'pred': {'mean': None, 'std': None}, 'gt': {'mean': None, 'std': None}}
    }

    # (label, class id in the mask, True = top edge / False = bottom edge)
    surfaces = [("Cornea Top", 1, True), ("Cornea Bottom", 1, False),
                ("Lens Top", 2, True),   ("Lens Bottom", 2, False)]

    def extract_filtered_boundaries(mask, class_id):
        """Return (x_top, y_top, x_bottom, y_bottom) of class_id after run filtering."""
        # For every column containing class_id, record its first (top) and last (bottom) row.
        x_top, y_top, x_bot, y_bot = [], [], [], []
        for x in range(w):
            ys = np.where(mask[:, x] == class_id)[0]
            if len(ys) > 0:
                x_top.append(x); y_top.append(ys[0])
                x_bot.append(x); y_bot.append(ys[-1])

        def filter_segments(x_arr, y_arr):
            # Group column-sorted points into continuous runs; a point extends the
            # current run when it is within max_gap + 1 columns and max_gap rows of
            # the previous one. Runs shorter than min_segment_length are dropped.
            if len(x_arr) < 2: return np.array([]), np.array([])
            order = np.argsort(x_arr)
            xs, ys = np.array(x_arr)[order], np.array(y_arr)[order]
            seg = []
            current = [(xs[0], ys[0])]
            for i in range(1, len(xs)):
                px, py = current[-1]
                cx, cy = xs[i], ys[i]
                if (cx - px <= max_gap + 1) and (abs(cy - py) <= max_gap):
                    current.append((cx, cy))
                else:
                    if len(current) >= min_segment_length: seg.extend(current)
                    current = [(cx, cy)]
            if len(current) >= min_segment_length: seg.extend(current)
            if not seg: return np.array([]), np.array([])
            return np.array([p[0] for p in seg]), np.array([p[1] for p in seg])

        xt, yt = filter_segments(x_top, y_top)
        xb, yb = filter_segments(x_bot, y_bot)
        return xt, yt, xb, yb

    def maybe_save(fig, fname):
        # Save only when a save_path was given; the folder is created on demand.
        if save_path:
            os.makedirs(save_path, exist_ok=True)
            fig.savefig(os.path.join(save_path, fname), dpi=200, bbox_inches='tight')

    # An all-zero GT mask is treated the same as "no GT".
    has_gt_global = (gt_mask is not None) and np.any(gt_mask)

    for label_text, class_id, top in surfaces:
        xt, yt, xb, yb = extract_filtered_boundaries(boundary_mask, class_id)
        x_pred = xt if top else xb
        y_pred = yt if top else yb
        # Too few points: leave this surface's metrics at their defaults.
        if len(x_pred) < min_points:
            continue

        # GT points for the same surface; any extraction error means "no GT" here.
        if has_gt_global:
            try:
                gxt, gyt, gxb, gyb = extract_filtered_boundaries(gt_mask, class_id)
                gt_x = gxt if top else gxb
                gt_y = gyt if top else gyb
            except Exception:
                gt_x, gt_y = np.array([]), np.array([])
        else:
            gt_x, gt_y = np.array([]), np.array([])

        # Robust pass: RANSAC with a cubic model; if it fails, every point is an inlier.
        try:
            ransac = RANSACRegressor(
                estimator=make_pipeline(PolynomialFeatures(3), LinearRegression()),
                residual_threshold=initial_residual,
                min_samples=min_points
            )
            ransac.fit(x_pred.reshape(-1, 1), y_pred)
            mask = ransac.inlier_mask_
            x_in, y_in = x_pred[mask], y_pred[mask]
            x_out, y_out = x_pred[~mask], y_pred[~mask]
        except Exception:
            x_in, y_in = x_pred, y_pred
            x_out, y_out = np.array([]), np.array([])

        # Final cubic fit on the inliers, sampled at 300 evenly spaced columns.
        coeffs = np.polyfit(x_in, y_in, 3)
        x_dense = np.linspace(min(x_in), max(x_in), 300)
        y_dense = np.polyval(coeffs, x_dense)

        boundary_metrics[label_text]['x_dense'] = x_dense
        boundary_metrics[label_text]['y_dense'] = y_dense
        boundary_metrics[label_text]['points']  = len(x_in)

        # MSE vs GT: GT is linearly interpolated at the fitted columns; columns
        # outside the GT's x-range become NaN and are excluded.
        if has_gt_global and len(gt_x) > 4:
            interp_gt = np.interp(x_dense, gt_x, gt_y, left=np.nan, right=np.nan)
            valid = ~np.isnan(interp_gt)
            if valid.sum() > 4:
                mse = mean_squared_error(interp_gt[valid], y_dense[valid])
                boundary_metrics[label_text]['mse'] = mse

        if show_individual_fits:
            fig = plt.figure(figsize=(10,6))
            plt.imshow(image_np, cmap='gray')
            if len(x_out)>0: plt.plot(x_out, y_out, 'rx', ms=4, label='Initial Rejected')
            plt.plot(x_pred, y_pred, 'yo', ms=3, label='All Points', alpha=0.5)
            # NOTE(review): the inliers are plotted twice ('Inliers' then 'Final Points');
            # no second refinement (e.g. with final_residual) happens in between.
            plt.plot(x_in, y_in, 'go', ms=4, label='Inliers')
            plt.plot(x_in, y_in, 'bo', ms=4, label='Final Points')
            plt.plot(x_dense, y_dense, 'c-', lw=2, label='Fitted Curve')
            if len(gt_x)>0: plt.plot(gt_x, gt_y, 'm--', lw=2, label="GT")
            m = boundary_metrics[label_text]['mse']
            t = f"{label_text} Fit (n={len(x_in)})" + (f"\nMSE: {m:.4f}" if m is not None else "")
            plt.title(t); plt.legend(); plt.tight_layout()
            maybe_save(fig, f"{image_name}_{label_text.replace(' ','_')}_fit.png")
            plt.show(); plt.close(fig)

        # Paint the fitted curve, then this surface's GT points on top of it
        # (so GT colour wins where they overlap); off-image pixels are skipped.
        for xi, yi in zip(x_dense, y_dense):
            xi, yi = int(round(xi)), int(round(yi))
            if 0 <= xi < w and 0 <= yi < h:
                final_overlay[yi, xi] = boundary_metrics[label_text]['color_pred']

        if has_gt_global and len(gt_x) > 0:
            for xi, yi in zip(gt_x, gt_y):
                xi, yi = int(round(xi)), int(round(yi))
                if 0 <= xi < w and 0 <= yi < h:
                    final_overlay[yi, xi] = boundary_metrics[label_text]['color_gt']

    def thickness(top_x, top_y, bottom_x, bottom_y):
        """Mean/std of (bottom - top) at the centre column and +/-5 columns.

        Samples outside either curve's x-range are ignored; returns (None, None)
        when a curve is empty or no sample is valid.
        """
        if len(top_x)==0 or len(bottom_x)==0: return None, None
        center = w // 2
        xs = [center-5, center, center+5]
        top_yi = np.interp(xs, top_x, top_y, left=np.nan, right=np.nan)
        bot_yi = np.interp(xs, bottom_x, bottom_y, left=np.nan, right=np.nan)
        v = ~np.isnan(top_yi) & ~np.isnan(bot_yi)
        if v.sum()==0: return None, None
        th = bot_yi[v] - top_yi[v]
        return float(np.mean(th)), float(np.std(th))

    # Predicted thickness from the dense fitted curves.
    for name, top_name, bot_name in [('Cornea', 'Cornea Top', 'Cornea Bottom'),
                                     ('Lens',   'Lens Top',   'Lens Bottom')]:
        if len(boundary_metrics[top_name]['x_dense']) and len(boundary_metrics[bot_name]['x_dense']):
            mean_t, std_t = thickness(boundary_metrics[top_name]['x_dense'], boundary_metrics[top_name]['y_dense'],
                                      boundary_metrics[bot_name]['x_dense'], boundary_metrics[bot_name]['y_dense'])
            thickness_metrics[name]['pred']['mean'] = mean_t
            thickness_metrics[name]['pred']['std']  = std_t

    # GT thickness from the filtered GT boundary points (no curve fitting).
    if has_gt_global:
        try:
            def get_gt(cid):
                xt, yt, xb, yb = extract_filtered_boundaries(gt_mask, cid)
                return (xt, yt), (xb, yb)
            (c_t, c_b) = get_gt(1)
            (l_t, l_b) = get_gt(2)
            if len(c_t[0]) and len(c_b[0]):
                m, s = thickness(c_t[0], c_t[1], c_b[0], c_b[1])
                thickness_metrics['Cornea']['gt']['mean'] = m
                thickness_metrics['Cornea']['gt']['std']  = s
            if len(l_t[0]) and len(l_b[0]):
                m, s = thickness(l_t[0], l_t[1], l_b[0], l_b[1])
                thickness_metrics['Lens']['gt']['mean'] = m
                thickness_metrics['Lens']['gt']['std']  = s
        # NOTE(review): any error here is silently ignored, leaving GT thickness as None.
        except Exception:
            pass

    if show_final_overlay:
        fig = plt.figure(figsize=(16,10))
        plt.imshow(final_overlay)
        legend = [
            plt.Line2D([0],[0], marker='o', color='w', label='Pred Cornea', markerfacecolor='red',    markersize=10),
            plt.Line2D([0],[0], marker='o', color='w', label='GT Cornea',   markerfacecolor='green',  markersize=10),
            plt.Line2D([0],[0], marker='o', color='w', label='Pred Lens',   markerfacecolor='orange', markersize=10),
            plt.Line2D([0],[0], marker='o', color='w', label='GT Lens',     markerfacecolor='blue',   markersize=10),
        ]
        # Thickness summary shown as an invisible legend entry.
        info = ["Thickness (px):"]
        for nm in ['Cornea','Lens']:
            if thickness_metrics[nm]['pred']['mean'] is not None:
                info.append(f"Pred {nm}: {thickness_metrics[nm]['pred']['mean']:.1f} ± {thickness_metrics[nm]['pred']['std']:.1f}")
            if has_gt_global and thickness_metrics[nm]['gt']['mean'] is not None:
                info.append(f"GT {nm}: {thickness_metrics[nm]['gt']['mean']:.1f} ± {thickness_metrics[nm]['gt']['std']:.1f}")
        legend.append(plt.Line2D([0],[0], marker='', color='w', label="\n".join(info), markersize=10))

        # The legend reports RMSE = sqrt(MSE) for each surface that has one.
        if has_gt_global:
            rtxt = ["RMSE per surface:"]
            for surf in ['Cornea Top','Cornea Bottom','Lens Top','Lens Bottom']:
                m = boundary_metrics[surf]['mse']
                if m is not None: rtxt.append(f"{surf}: {np.sqrt(m):.4f}")
            legend.append(plt.Line2D([0],[0], marker='', color='w', label="\n".join(rtxt), markersize=10))

        plt.legend(handles=legend, bbox_to_anchor=(1.05,1), loc='upper left', fontsize=10)
        plt.title(f"{image_name} - Boundary Analysis"); plt.tight_layout(); plt.subplots_adjust(right=0.7)
        maybe_save(fig, f"{image_name}_boundary_analysis.png")
        plt.show(); plt.close(fig)

    return {'boundary_metrics': boundary_metrics, 'thickness_metrics': thickness_metrics, 'final_overlay': final_overlay}

# ---------------------------
# Main UI
# ---------------------------
class OCTSegmentationUI:
    def __init__(self, root):
        """Build the main window and initialise the pipeline state.

        Args:
            root: the Tk window that hosts the UI.

        The default image/mask folders can be overridden with the
        ``OCT_IMAGE_DIR`` / ``OCT_MASK_DIR`` environment variables. When those
        are unset, the original hard-coded locations are used, so existing
        setups keep working unchanged.
        """
        self.root = root
        self.root.title("OCT Image Segmentation")
        self.root.geometry("900x620")
        self._setup_styles()

        # Defaults. The fallbacks are user-specific absolute paths; the env vars
        # let other machines point the UI elsewhere without editing the code.
        self.DEFAULT_IMAGE_DIR = os.environ.get(
            "OCT_IMAGE_DIR", "/Users/wanneslutsdemartelaere/Desktop/OCT Data/ALL DATA")
        self.DEFAULT_MASK_DIR = os.environ.get(
            "OCT_MASK_DIR", "/Users/wanneslutsdemartelaere/Desktop/OCT Data/ALL MASKS")
        # (width, height): compared with PIL ``Image.size`` and passed as cv2 ``dsize``.
        self.TARGET_SIZE = (768, 767)
        self.MODEL_OPTIONS = {i: f"Model_ALLDATA_model_{i}.pth" for i in range(1, 6)}
        default_model = 4  # single source for both the plain attribute and the radio-button var

        # Vars
        self.image_path = tk.StringVar(value=self.DEFAULT_IMAGE_DIR)
        self.mask_path  = tk.StringVar()
        self.use_masks  = tk.BooleanVar(value=False)
        self.selected_model = default_model
        self.model_var = tk.IntVar(value=default_model)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.save_images = tk.BooleanVar(value=False)
        self.show_summary = tk.BooleanVar(value=True)
        self.vis_option = tk.IntVar(value=4)  # default: All three

        self.create_widgets()

    # ---------- Styling ----------
    def _setup_styles(self):
        """Configure the ttk theme, colour palette and fonts used by every widget.

        Stores the colours in ``self._palette`` and the heading, sub-heading and
        section fonts in ``self._title_font``, ``self._subtitle_font`` and
        ``self._section_font``. The theme switch and the menu/fixed font tweaks
        are best-effort: failures are ignored.
        """
        style = ttk.Style()
        try:
            style.theme_use('clam')
        except Exception:
            pass  # keep the platform's default theme

        palette = {
            'bg': '#F5F5F7',
            'panel': '#FFFFFF',
            'text': '#1D1D1F',
            'subtext': '#6E6E73',
            'accent': '#0A84FF',
            'accentHo': '#0066D6',
            'border': '#E5E5EA',
        }
        self._palette = palette
        self.root.configure(bg=palette['bg'])

        default_font = tkfont.nametofont('TkDefaultFont')
        default_font.configure(family='SF Pro Text', size=12)
        try:
            tkfont.nametofont('TkMenuFont').configure(family='SF Pro Text', size=12)
            tkfont.nametofont('TkFixedFont').configure(family='SF Mono', size=12)
        except Exception:
            pass

        self._title_font = tkfont.Font(family='SF Pro Display', size=20, weight='bold')
        self._subtitle_font = tkfont.Font(family='SF Pro Text', size=12)
        self._section_font = tkfont.Font(family='SF Pro Text', size=13, weight='bold')

        bg, panel, text, accent = palette['bg'], palette['panel'], palette['text'], palette['accent']
        style_table = [
            ('TFrame', {'background': bg}),
            ('TLabel', {'background': bg, 'foreground': text}),
            ('TCheckbutton', {'background': bg, 'foreground': text}),
            ('TRadiobutton', {'background': bg, 'foreground': text}),
            ('TEntry', {'fieldbackground': panel}),
            ('Card.TFrame', {'background': panel, 'relief': 'flat'}),
            ('Heading.TLabel', {'background': bg, 'foreground': text, 'font': self._title_font}),
            ('Subheading.TLabel', {'background': bg, 'foreground': palette['subtext'], 'font': self._subtitle_font}),
            ('Section.TLabel', {'background': panel, 'foreground': text, 'font': self._section_font}),
            ('Primary.TButton', {'background': accent, 'foreground': 'white', 'padding': 10, 'borderwidth': 0}),
            ('Accent.TButton', {'background': accent, 'foreground': 'white', 'padding': 8, 'borderwidth': 0}),
            ('TSeparator', {'background': palette['border']}),
        ]
        for style_name, options in style_table:
            style.configure(style_name, **options)

        # Both button styles darken to the hover accent while active.
        for button_style in ('Primary.TButton', 'Accent.TButton'):
            style.map(button_style, background=[('active', palette['accentHo'])])

    def _style_toplevel(self, tl):
        tl.configure(bg=self._palette['bg'])

    def _rounded_card(self, parent, pad=16):
        """Create a rounded-corner "card" and return the frame to fill with widgets.

        A Canvas is packed into *parent*, and a ``Card.TFrame`` is embedded in
        it. Behind the frame, a rounded rectangle is drawn from two overlapping
        rectangles plus four quarter-circle arcs, all tagged ``'rr'``. The shape
        is redrawn whenever the canvas or the frame is reconfigured.

        Args:
            parent: container widget the canvas is packed into.
            pad: inner padding of the card frame, also used to size the canvas.
        """
        canvas = tk.Canvas(parent, bg=self._palette['bg'], highlightthickness=0, bd=0)
        canvas.pack(fill=tk.X, pady=8)
        frame = ttk.Frame(canvas, style='Card.TFrame', padding=pad)
        window_id = canvas.create_window(0, 0, anchor='nw', window=frame)
        radius = 16
        diameter = 2 * radius

        def _repaint(_event=None):
            canvas.delete('rr')
            width = max(canvas.winfo_width(), 400)
            height = frame.winfo_reqheight() + pad
            canvas.config(height=height + pad)
            left, top, right, bottom = 6, 6, width - 6, height
            shape = dict(fill=self._palette['panel'], outline='', tags='rr')
            # Cross of two rectangles, then one arc per corner (NW, NE, SW, SE).
            canvas.create_rectangle(left + radius, top, right - radius, bottom, **shape)
            canvas.create_rectangle(left, top + radius, right, bottom - radius, **shape)
            corners = (
                (left, top, 90),
                (right - diameter, top, 0),
                (left, bottom - diameter, 180),
                (right - diameter, bottom - diameter, 270),
            )
            for corner_x, corner_y, start_angle in corners:
                canvas.create_arc(corner_x, corner_y, corner_x + diameter, corner_y + diameter,
                                  start=start_angle, extent=90, **shape)
            canvas.coords(window_id, left + pad // 2, top + pad // 2)

        canvas.bind('<Configure>', _repaint)
        frame.bind('<Configure>', lambda e: _repaint())
        return frame

    # ---------- Layout ----------
    def create_widgets(self):
        """Build the main window: header, the two option cards and the run button.

        The Step 1 card holds the image-folder entry bound to ``self.image_path``.
        The Step 2 card holds the "Use Masks" checkbox plus the mask entry and
        browse button. Those two start disabled and are kept as
        ``self.mask_entry`` / ``self.mask_browse_btn`` so that
        :meth:`toggle_mask_input` can switch them.
        """
        main = ttk.Frame(self.root, padding=20)
        main.pack(expand=True, fill=tk.BOTH)

        # Header: title, tagline, divider.
        header = ttk.Frame(main)
        header.pack(fill=tk.X, pady=(0, 12))
        ttk.Label(header, text='👁️  OCT Image Segmentation',
                  style='Heading.TLabel').pack(anchor=tk.W)
        ttk.Label(header, text='End-to-end pipeline • Elegant, simple, fast',
                  style='Subheading.TLabel').pack(anchor=tk.W, pady=(2, 0))
        ttk.Separator(main).pack(fill=tk.X, pady=10)

        cards = ttk.Frame(main)
        cards.pack(expand=True, fill=tk.BOTH)

        # Card 1: image folder.
        image_card = self._rounded_card(cards)
        ttk.Label(image_card, text='Step 1: Image Selection', style='Section.TLabel').pack(anchor=tk.W)
        hint = ttk.Label(image_card, text='Default path will be used if left blank.',
                         background=self._palette['panel'], foreground=self._palette['subtext'])
        hint.pack(anchor=tk.W, pady=(2, 6))
        image_row = ttk.Frame(image_card, style='Card.TFrame')
        image_row.pack(fill=tk.X, pady=4)
        ttk.Entry(image_row, textvariable=self.image_path, width=60).pack(side=tk.LEFT, padx=(0, 8))
        ttk.Button(image_row, text='Browse…', style='Accent.TButton',
                   command=self.browse_image).pack(side=tk.LEFT)

        # Card 2: optional ground-truth masks (disabled until "Use Masks" is ticked).
        mask_card = self._rounded_card(cards)
        ttk.Label(mask_card, text='Step 2: Mask Options', style='Section.TLabel').pack(anchor=tk.W)
        ttk.Checkbutton(mask_card, text='Use Masks', variable=self.use_masks,
                        command=self.toggle_mask_input).pack(anchor=tk.W, pady=(6, 4))
        mask_row = ttk.Frame(mask_card, style='Card.TFrame')
        mask_row.pack(fill=tk.X, pady=(0, 4))
        self.mask_entry = ttk.Entry(mask_row, textvariable=self.mask_path, width=60, state=tk.DISABLED)
        self.mask_entry.pack(side=tk.LEFT, padx=(0, 8))
        self.mask_browse_btn = ttk.Button(mask_row, text='Browse…', style='Accent.TButton',
                                          command=self.browse_mask, state=tk.DISABLED)
        self.mask_browse_btn.pack(side=tk.LEFT)

        # Footer: divider and the right-aligned pipeline trigger.
        footer = ttk.Frame(main)
        footer.pack(fill=tk.X, pady=(10, 0))
        ttk.Separator(footer).pack(fill=tk.X, pady=(0, 10))
        actions = ttk.Frame(footer)
        actions.pack(fill=tk.X)
        ttk.Button(actions, text='Run Pipeline', style='Primary.TButton',
                   command=self.start_pipeline).pack(side=tk.RIGHT)

    # ---------- Path handlers ----------
    def toggle_mask_input(self):
        if self.use_masks.get():
            self.mask_entry.config(state=tk.NORMAL)
            self.mask_browse_btn.config(state=tk.NORMAL)
            self.mask_path.set(self.DEFAULT_MASK_DIR)
        else:
            self.mask_entry.config(state=tk.DISABLED)
            self.mask_browse_btn.config(state=tk.DISABLED)
            self.mask_path.set("")

    def browse_image(self):
        """Let the user pick an image folder and store it in ``self.image_path``.

        The dialog starts in the default image folder if it exists, otherwise
        at ``/``. Cancelling leaves the current path unchanged.
        """
        start_dir = "/"
        if os.path.exists(self.DEFAULT_IMAGE_DIR):
            start_dir = self.DEFAULT_IMAGE_DIR
        chosen = filedialog.askdirectory(initialdir=start_dir)
        if chosen:
            self.image_path.set(chosen)

    def browse_mask(self):
        """Let the user pick a mask folder and store it in ``self.mask_path``.

        The dialog starts in the default mask folder if it exists, otherwise
        at ``/``. Cancelling leaves the current path unchanged.
        """
        start_dir = "/"
        if os.path.exists(self.DEFAULT_MASK_DIR):
            start_dir = self.DEFAULT_MASK_DIR
        chosen = filedialog.askdirectory(initialdir=start_dir)
        if chosen:
            self.mask_path.set(chosen)

    # ---------- Pipeline ----------
    def start_pipeline(self):
        if not self.validate_paths(): return
        image_dir = self.image_path.get() or self.DEFAULT_IMAGE_DIR
        mask_dir  = self.mask_path.get() if self.use_masks.get() else None
        resized = self.check_and_resize_images(image_dir, mask_dir)
        if resized: self.show_noise_reduction_dialog()
        else:       self.start_segmentation()

    def validate_paths(self):
        img = self.image_path.get() or self.DEFAULT_IMAGE_DIR
        if not os.path.exists(img):
            messagebox.showerror("Error", f"Image path does not exist:\n{img}")
            return False
        if self.use_masks.get():
            msk = self.mask_path.get() or self.DEFAULT_MASK_DIR
            if not os.path.exists(msk):
                messagebox.showerror("Error", f"Mask path does not exist:\n{msk}")
                return False
        return True

    def check_and_resize_images(self, image_dir, mask_dir=None):
        need = False
        for f in os.listdir(image_dir):
            if f.lower().endswith(('.png','.jpg','.jpeg','.tiff','.bmp')):
                try:
                    with Image.open(os.path.join(image_dir,f)) as im:
                        if im.size != self.TARGET_SIZE: need=True; break
                except: pass
        if not need and mask_dir and self.use_masks.get():
            for f in os.listdir(mask_dir):
                if f.lower().endswith(('.png','.jpg','.jpeg','.tiff','.bmp')):
                    try:
                        with Image.open(os.path.join(mask_dir,f)) as im:
                            if im.size != self.TARGET_SIZE: need=True; break
                    except: pass
        if need: self.resize_images(image_dir, mask_dir)
        return need

    def resize_images(self, image_dir, mask_dir=None):
        for f in os.listdir(image_dir):
            if f.lower().endswith(('.png','.jpg','.jpeg','.tiff','.bmp')):
                p = os.path.join(image_dir, f)
                try:
                    img = cv2.imread(p, cv2.IMREAD_UNCHANGED)
                    if img.shape[1]!=self.TARGET_SIZE[0] or img.shape[0]!=self.TARGET_SIZE[1]:
                        cv2.imwrite(p, cv2.resize(img, self.TARGET_SIZE, interpolation=cv2.INTER_AREA))
                except Exception as e:
                    print("Resize image fail:", f, e)
        if mask_dir and self.use_masks.get():
            for f in os.listdir(mask_dir):
                if f.lower().endswith(('.png','.jpg','.jpeg','.tiff','.bmp')):
                    p = os.path.join(mask_dir, f)
                    try:
                        mk = cv2.imread(p, cv2.IMREAD_UNCHANGED)
                        if mk.shape[1]!=self.TARGET_SIZE[0] or mk.shape[0]!=self.TARGET_SIZE[1]:
                            cv2.imwrite(p, cv2.resize(mk, self.TARGET_SIZE, interpolation=cv2.INTER_NEAREST))
                    except Exception as e:
                        print("Resize mask fail:", f, e)

    # ---------- Progress & dialogs (rounded styling) ----------
    def _rounded_dialog(self, title, size="560x220"):
        """Open a fixed-size, styled Toplevel that contains one rounded card.

        Args:
            title: window title.
            size: Tk geometry string such as ``"560x220"``.

        Returns:
            ``(dialog, card_frame)``; callers add their widgets to the card.
        """
        dialog = tk.Toplevel(self.root)
        self._style_toplevel(dialog)
        dialog.title(title)
        dialog.geometry(size)
        dialog.resizable(False, False)
        outer = ttk.Frame(dialog, padding=14)
        outer.pack(expand=True, fill=tk.BOTH)
        return dialog, self._rounded_card(outer, pad=14)

    def show_noise_reduction_dialog(self):
        """Tell the user the images were resized and offer optional noise reduction.

        "Yes" hands the dialog to :meth:`run_noise_reduction`, which closes it.
        "No" closes the dialog and goes straight to :meth:`start_segmentation`.
        """
        dlg, card = self._rounded_dialog("Additional Processing", "560x240")
        ttk.Label(card, text="Images were resized", style='Section.TLabel').pack(anchor=tk.W)
        prompt = ttk.Label(card, text="Would you like to also apply noise reduction processing?",
                           background=self._palette['panel'], foreground=self._palette['subtext'])
        prompt.pack(anchor=tk.W, pady=(4, 8))
        buttons = ttk.Frame(card, style='Card.TFrame')
        buttons.pack(pady=(6, 0))

        def on_yes():
            self.run_noise_reduction(dlg)

        def on_no():
            dlg.destroy()
            self.start_segmentation()

        ttk.Button(buttons, text="Yes, Apply Noise Reduction", style='Primary.TButton',
                   command=on_yes).pack(side=tk.LEFT, padx=6)
        ttk.Button(buttons, text="No, Continue", style='Accent.TButton',
                   command=on_no).pack(side=tk.LEFT, padx=6)

    def create_progress_window(self, title, label):
        """Open a progress dialog and keep handles for :meth:`update_progress`.

        Sets ``self.progress_window``, ``self.progress_bar`` (determinate) and
        ``self.progress_label`` (initially ``"0/0"``).
        """
        window, body = self._rounded_dialog(title, "600x180")
        self.progress_window = window
        panel_bg = self._palette['panel']
        ttk.Label(body, text=label, background=panel_bg).pack(anchor=tk.W, pady=(0, 6))
        self.progress_bar = ttk.Progressbar(body, orient='horizontal', length=520, mode='determinate')
        self.progress_bar.pack(pady=4)
        self.progress_label = ttk.Label(body, text="0/0", background=panel_bg,
                                        foreground=self._palette['subtext'])
        self.progress_label.pack(anchor=tk.W)

    def update_progress(self, i, n):
        self.progress_bar["value"] = (i/max(n,1))*100
        self.progress_label.config(text=f"{i}/{n}")
        self.progress_window.update()

    # ---------- Noise reduction ----------
    def run_noise_reduction(self, dialog=None):
        """Denoise every image in the image folder in place, then start segmentation.

        Each image is converted to 8-bit grayscale and refined in two passes:

        1. A mild NL-means estimate is compared with the current image.
        2. Pixels whose absolute difference exceeds 10 grey levels are replaced
           by a strong NL-means estimate.

        The result overwrites the original file. Per-file failures are printed
        and skipped. A progress window is shown only when there is more than one
        file.

        Args:
            dialog: optional prompt window to close before processing starts.
        """
        if dialog:
            dialog.destroy()
        image_dir = self.image_path.get() or self.DEFAULT_IMAGE_DIR
        files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff'))]
        n = len(files)
        show_progress = n > 1
        if show_progress:
            self.create_progress_window("Processing Images", "Applying noise reduction to images…")
            self.update_progress(0, n)
        try:
            for idx, f in enumerate(sorted(files, key=natural_sort_key), 1):
                try:
                    p = os.path.join(image_dir, f)
                    # Decode fully and release the file handle before p is overwritten below.
                    with Image.open(p) as src:
                        img = src.convert("L")
                    work = img_as_float(np.array(img))
                    for _ in range(2):  # two adaptive refinement passes
                        s = np.mean(estimate_sigma(work, channel_axis=None))
                        mild = denoise_nl_means(work, h=1.1*s, fast_mode=True, patch_size=2, patch_distance=10, channel_axis=None)
                        # Quantise to 0..255 but subtract in int16: uint8 subtraction wraps
                        # around (10 - 20 == 246) and flagged the wrong pixels.
                        mild_q = (np.clip(mild, 0, 1) * 255).astype(np.int16)
                        work_q = (np.clip(work, 0, 1) * 255).astype(np.int16)
                        mask = (np.abs(work_q - mild_q) > 10).astype(np.uint8)
                        strong = denoise_nl_means(work, h=5*s, fast_mode=False, patch_size=3, patch_distance=10, channel_axis=None)
                        # Replace only the pixels the mild pass flagged as noisy.
                        work = (1 - mask) * work + mask * strong
                    Image.fromarray((np.clip(work, 0, 1) * 255).astype(np.uint8)).save(p)
                    if show_progress:
                        self.update_progress(idx, n)
                except Exception as e:
                    print("NR fail:", f, e)
        finally:
            if show_progress and hasattr(self, 'progress_window'):
                self.progress_window.destroy()
        self.start_segmentation()

    # ---------- Segmentation ----------
    def start_segmentation(self):
        """Ask which trained model (1-5) to use, then open the data-selection dialog.

        The radio buttons are bound to ``self.model_var``. "Continue" copies the
        choice into ``self.selected_model`` and moves on.
        """
        dlg, card = self._rounded_dialog("Model Selection", "420x330")
        ttk.Label(card, text="Select Model Version", style='Section.TLabel').pack(anchor=tk.W, pady=(0, 6))
        options = ttk.Frame(card, style='Card.TFrame')
        options.pack(anchor=tk.W)
        for number in range(1, 6):
            ttk.Radiobutton(options, text=f"Model {number}", variable=self.model_var,
                            value=number).pack(anchor=tk.W)

        def proceed():
            self.set_selected_model(self.model_var.get())
            dlg.destroy()
            self.show_data_selection_dialog()

        ttk.Button(card, text="Continue", style='Primary.TButton', command=proceed).pack(pady=(10, 0))

    def set_selected_model(self, m): self.selected_model = m

    def show_data_selection_dialog(self):
        """Let the user choose which images to segment and how to visualise them.

        Data choice (a local IntVar):

        - 1 = all images
        - 2 = 10 random images
        - 3 = a comma-separated list typed into ``self.custom_entry``, which is
          only enabled for choice 3

        The visualisation mode is bound to ``self.vis_option``: 0 = numbers
        only, 1 = areas, 2 = edges, 3 = smoothed edges, 4 = all three.
        "Run Segmentation" closes the dialog and calls :meth:`run_segmentation`
        with the data choice.
        """
        dlg, card = self._rounded_dialog("Data & Visualization", "620x520")
        ttk.Label(card, text="Select Data to Segment", style='Section.TLabel').pack(anchor=tk.W)
        data_choice = tk.IntVar(value=1)
        data_modes = (
            ("Segment all images", 1),
            ("Segment 10 random images", 2),
            ("Select specific images (comma separated)", 3),
        )
        for caption, value in data_modes:
            ttk.Radiobutton(card, text=caption, variable=data_choice, value=value).pack(anchor=tk.W)
        self.custom_entry = ttk.Entry(card, width=40, state=tk.DISABLED)
        self.custom_entry.pack(anchor=tk.W, pady=(4, 8))

        def sync_custom_entry(*_):
            wants_list = data_choice.get() == 3
            self.custom_entry.config(state=tk.NORMAL if wants_list else tk.DISABLED)

        data_choice.trace_add('write', sync_custom_entry)

        ttk.Separator(card).pack(fill=tk.X, pady=8)
        ttk.Label(card, text="Visualization", style='Section.TLabel').pack(anchor=tk.W, pady=(2, 2))
        vis_modes = (
            ("No visualization — just final numbers", 0),
            ("Segmented Areas", 1),
            ("Boundary Edges", 2),
            ("Smoothed Boundary Edges", 3),
            ("All three (areas + edges + smoothed)", 4),
        )
        for caption, value in vis_modes:
            ttk.Radiobutton(card, text=caption, variable=self.vis_option, value=value).pack(anchor=tk.W)

        ttk.Checkbutton(card, text="Save outputs", variable=self.save_images).pack(anchor=tk.W, pady=(6, 2))
        ttk.Checkbutton(card, text="Show summary metrics (IoU, RMSE)",
                        variable=self.show_summary).pack(anchor=tk.W)

        def run():
            dlg.destroy()
            self.run_segmentation(data_choice.get())

        ttk.Button(card, text="Run Segmentation", style='Primary.TButton', command=run).pack(pady=(12, 0))

    def run_segmentation(self, option):
        """Segment the selected images with the chosen model and show the results.

        Args:
            option: which images to segment.

                - 1 = every image
                - 2 = up to 10 random images
                - anything else = the 1-based, comma-separated indices typed into
                  ``self.custom_entry``; out-of-range numbers are dropped

        For each image, the per-pixel argmax of the model output is cropped to
        the original size reported by the dataset. When a mask folder is in use,
        IoU is computed for class 1 (cornea), class 2 (lens) and their mean;
        otherwise all three are set to -1. The collected results are passed to
        :meth:`visualize_results`.
        """
        image_dir = self.image_path.get() or self.DEFAULT_IMAGE_DIR
        # Same blank-entry fallback as validate_paths, so the validated folder is the one used.
        mask_dir = (self.mask_path.get() or self.DEFAULT_MASK_DIR) if self.use_masks.get() else None

        dataset = OCTFolderDataset(image_dir, mask_dir)

        # Resolve indices before loading the model so a typo fails fast.
        if option == 1:
            indices = list(range(len(dataset)))
        elif option == 2:
            indices = random.sample(range(len(dataset)), min(10, len(dataset)))
        else:
            raw = self.custom_entry.get()
            try:
                indices = [int(x.strip()) - 1 for x in raw.split(",")]
            except ValueError:
                indices = []
            indices = [i for i in indices if 0 <= i < len(dataset)]
            if not indices:
                messagebox.showerror("Error", "Invalid format. Use numbers like: 1,2,3")
                return

        model_path = self.MODEL_OPTIONS[self.selected_model]
        # encoder_weights=None: the strict load_state_dict below overwrites every
        # parameter, so fetching ImageNet weights first is wasted work (and fails offline).
        model = smp.Unet(encoder_name="timm-efficientnet-b2", encoder_weights=None, classes=3, activation=None)
        # NOTE(review): torch.load unpickles the checkpoint; only load trusted model files.
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        model = model.to(self.device)
        model.eval()

        self.create_progress_window("Segmenting Images", "Running segmentation on images…")
        self.update_progress(0, len(indices))

        results = []
        try:
            for i, idx in enumerate(indices, 1):
                # With masks the dataset yields (image, gt, orig_size); without, (image, orig_size).
                if mask_dir:
                    img_t, gt_t, orig = dataset[idx]
                else:
                    img_t, orig = dataset[idx]
                    gt_t = None

                with torch.no_grad():
                    out = model(img_t.unsqueeze(0).to(self.device))
                    pred = torch.argmax(out, dim=1).squeeze().cpu().numpy()

                # Drop any padding beyond the original (rows, cols).
                pred = pred[:orig[0], :orig[1]]

                if gt_t is not None:
                    gt_mask = gt_t[:orig[0], :orig[1]].numpy()
                    iou_c = self.compute_iou(gt_mask, pred, 1)
                    iou_l = self.compute_iou(gt_mask, pred, 2)
                    miou = (iou_c + iou_l) / 2
                else:
                    gt_mask = None
                    iou_c = iou_l = miou = -1

                results.append({
                    "image": dataset.image_filenames[idx],
                    "image_path": os.path.join(image_dir, dataset.image_filenames[idx]),
                    "gt": gt_mask,
                    "pred": pred,
                    "iou_cornea": iou_c,
                    "iou_lens": iou_l,
                    "mean_iou": miou,
                    "index": idx,
                    "orig_size": orig
                })
                self.update_progress(i, len(indices))
        finally:
            # Close the progress dialog even if inference fails part-way.
            self.progress_window.destroy()

        self.visualize_results(results, image_dir)

    # ---------- Viz & metrics ----------
    def _sibling_output_dir(self, image_dir):
        parent = os.path.dirname(image_dir.rstrip(os.sep))
        base   = os.path.basename(image_dir.rstrip(os.sep))
        outdir = os.path.join(parent, f"{base}_results_model_{self.selected_model}")
        return outdir

    def visualize_results(self, results, image_dir):
        """Render the requested figures, run the summary/export step, report completion.

        ``self.vis_option`` selects the figures:

        - 1 = area overlays
        - 2 = boundary fits with ``max_gap=1``
        - 3 = boundary fits with ``max_gap=4``
        - 4 = all three
        - 0 = none

        An RMSE accumulator dict, keyed by surface name, is handed to the
        boundary views and then to the summary step. The summary runs when
        "Show summary" is ticked or in numbers-only mode (0). The sibling output
        folder is created when saving or in numbers-only mode.
        """
        mode = self.vis_option.get()
        save = self.save_images.get()
        outdir = self._sibling_output_dir(image_dir)
        numbers_only = mode == 0
        surface_rmse = {surface: [] for surface in ("Cornea Top", "Cornea Bottom", "Lens Top", "Lens Bottom")}

        if save or numbers_only:
            os.makedirs(outdir, exist_ok=True)

        if mode in (1, 4):
            self.show_area_visualization(results, save, outdir)
        # (max_gap, modes that request it): raw edges first, then smoothed edges.
        for gap, wanted_by in ((1, (2, 4)), (4, (3, 4))):
            if mode in wanted_by:
                self.show_boundary_visualization(results, save, outdir, max_gap=gap, all_surface_rmse=surface_rmse)

        if self.show_summary.get() or numbers_only:
            self.summary_and_exports(results, surface_rmse, outdir)

        messagebox.showinfo("Done", f"Processed {len(results)} image(s).\nOutputs: {outdir if (save or numbers_only) else 'not saved'}")

    def show_area_visualization(self, cases, save, outdir):
        """Show each image beside its segmentation-area overlay.

        For each result dict:

        * The source image is loaded as grayscale and cropped to ``orig_size``.
          If it cannot be read, a binary rendering of the prediction is used.
        * The image is blended 60/40 with the GT/prediction colour overlay from
          ``overlay_mask``, whose colours match the "area" legend.
        * The mean IoU is added to the title when it is available (>= 0).
        * If ``save`` is true, the blended image is written to ``outdir`` as
          ``areas_<image name>``.
        """
        for r in cases:
            try:
                # Context manager releases the file handle deterministically.
                with Image.open(r["image_path"]) as src:
                    img = np.array(src.convert('L'))[:r["orig_size"][0], :r["orig_size"][1]]
            except Exception:
                # Best-effort fallback so one unreadable file doesn't abort the batch.
                img = (r["pred"] > 0).astype(np.uint8) * 255

            # overlay_mask accepts gt=None and paints predictions yellow/magenta,
            # exactly as the "area" legend states. pred_to_rgb's red/blue would
            # contradict the legend (blue is "GT Lens" there).
            overlay = self.overlay_mask(r["gt"], r["pred"])
            blended = (0.6 * np.stack([img] * 3, -1) + 0.4 * overlay).astype(np.uint8)

            fig, axes = plt.subplots(1, 2, figsize=(12, 6))
            axes[0].imshow(img, cmap='gray')
            axes[0].axis('off')
            axes[0].set_title(f"{r['image']}\nIndex: {r['index']+1}")

            title = "Segmentation Areas"
            if r["mean_iou"] >= 0:  # negative IoU marks "no ground truth"
                title += f"\nMean IoU: {r['mean_iou']:.3f}"
            axes[1].imshow(blended)
            axes[1].axis('off')
            axes[1].set_title(title)
            self.add_legend(axes[1], mode="area")
            plt.tight_layout()

            if save:
                Image.fromarray(blended).save(os.path.join(outdir, f"areas_{r['image']}"))
            plt.show()
            plt.close(fig)

    def show_boundary_visualization(self, cases, save, outdir, max_gap=1, all_surface_rmse=None):
        """Fit and display the cornea/lens surfaces for every result.

        For each result dict:

        * The source image is loaded as grayscale and cropped to ``orig_size``.
          If it cannot be read, a binary rendering of the prediction is used.
        * Boundary pixels are extracted from the prediction with ``max_gap``.
        * ``fit_and_plot_surfaces_v2`` fits and shows the surfaces.
        * If the result has ground truth and ``all_surface_rmse`` is given,
          the square root of each non-None surface MSE is appended to it.
        * If ``save`` is true, the final overlay is also written to ``outdir``.
        """
        for case in cases:
            gt = case["gt"]
            try:
                size = case["orig_size"]
                img = np.array(Image.open(case["image_path"]).convert('L'))[:size[0], :size[1]]
            except Exception:
                img = (case["pred"] > 0).astype(np.uint8) * 255

            boundary = self.extract_boundaries(case["pred"], target_shape=img.shape, max_gap=max_gap)
            stem = os.path.splitext(case['image'])[0]
            tag = "_smoothed" if max_gap > 1 else "_edges"
            fit = fit_and_plot_surfaces_v2(
                image_np=img,
                boundary_mask=boundary,
                image_name=stem + tag,
                save_path=outdir if save else None,
                gt_mask=gt,
                show_individual_fits=False,
                show_final_overlay=True,
                max_gap=max_gap,
            )

            if gt is not None and all_surface_rmse is not None:
                for surface, metric in fit['boundary_metrics'].items():
                    mse = metric['mse']
                    if mse is None:
                        continue
                    all_surface_rmse[surface].append(np.sqrt(mse))

            if not save:
                continue
            prefix = "edges" if max_gap == 1 else "smoothed_edges"
            target = os.path.join(outdir, f"{prefix}_{case['image']}")
            Image.fromarray(fit['final_overlay']).save(target)

    def summary_and_exports(self, cases, all_surface_rmse, outdir):
        """Report aggregate metrics and export per-image IoU / RMSE tables.

        The summary shows:

        * mean ± std of cornea and lens IoU over images with ground truth
          (IoU >= 0);
        * mean ± std of every surface RMSE list in ``all_surface_rmse``.

        It is printed to the console and shown in a message box. Then
        ``per_image_iou.csv`` and ``per_image_surface_rmse.csv`` are written
        into ``outdir``, which is created if missing.

        Args:
            cases: per-image result dicts built by the segmentation loop.
            all_surface_rmse: surface name -> list of RMSE values.
            outdir: destination folder for the CSV files.
        """
        summary_text = self._summary_metrics_text(cases, all_surface_rmse)
        print("\n" + summary_text + "\n")
        messagebox.showinfo("Summary Metrics", summary_text)

        # Callers do not always create outdir (e.g. summary without saving images).
        os.makedirs(outdir, exist_ok=True)
        self._export_iou_csv(cases, os.path.join(outdir, "per_image_iou.csv"))
        self._export_surface_rmse_csv(cases, os.path.join(outdir, "per_image_surface_rmse.csv"))

    def _summary_metrics_text(self, cases, all_surface_rmse):
        """Build the multi-line IoU/RMSE summary shown to the user."""
        def mean_std(values):
            if len(values) == 0:
                return "n/a", "n/a"
            return f"{np.mean(values):.4f}", f"{np.std(values):.4f}"

        # A negative IoU is the "no ground truth" sentinel set upstream.
        iou_c = [r["iou_cornea"] for r in cases if r["iou_cornea"] >= 0]
        iou_l = [r["iou_lens"] for r in cases if r["iou_lens"] >= 0]
        mc, sc = mean_std(iou_c)
        ml, sl = mean_std(iou_l)

        lines = [
            "📊 Summary metrics (all images):",
            f"Cornea IoU: {mc} ± {sc} (n={len(iou_c)})",
            f"Lens   IoU: {ml} ± {sl} (n={len(iou_l)})",
            "",
        ]
        for surf, arr in all_surface_rmse.items():
            if len(arr):
                lines.append(f"{surf} RMSE: {np.mean(arr):.4f} ± {np.std(arr):.4f} (n={len(arr)})")
            else:
                lines.append(f"{surf} RMSE: n/a")
        return "\n".join(lines)

    def _export_iou_csv(self, cases, path):
        """Write one row per image with cornea, lens and mean IoU.

        Images without ground truth carry the negative sentinel IoU. Their
        cells are left empty, as in the RMSE table, instead of writing
        ``-1.000000`` as if it were a metric. The ``csv`` module quotes image
        names that contain commas or quotes.
        """
        import csv

        def cell(value):
            return "" if value < 0 else f"{value:.6f}"

        with open(path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(["image", "iou_cornea", "iou_lens", "mean_iou"])
            for r in cases:
                writer.writerow([r["image"], cell(r["iou_cornea"]), cell(r["iou_lens"]), cell(r["mean_iou"])])

    def _export_surface_rmse_csv(self, cases, path):
        """Write per-image RMSE of each fitted surface against ground truth.

        Boundaries are re-extracted and refitted once with ``max_gap=4``
        (nothing drawn) so every row uses the same fit. Cells are empty for
        images without ground truth and for surfaces whose MSE is None or NaN.
        """
        import csv

        surfaces = ['Cornea Top', 'Cornea Bottom', 'Lens Top', 'Lens Bottom']
        with open(path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(["image", "cornea_top", "cornea_bottom", "lens_top", "lens_bottom"])
            for r in cases:
                if r["gt"] is None:
                    writer.writerow([r["image"], "", "", "", ""])
                    continue
                try:
                    with Image.open(r["image_path"]) as src:
                        img = np.array(src.convert('L'))[:r["orig_size"][0], :r["orig_size"][1]]
                except Exception:
                    # Best-effort fallback: fit on a rendering of the prediction.
                    img = (r["pred"] > 0).astype(np.uint8) * 255
                bmask = self.extract_boundaries(r["pred"], target_shape=img.shape, max_gap=4)
                res = fit_and_plot_surfaces_v2(
                    image_np=img, boundary_mask=bmask, save_path=None,
                    show_individual_fits=False, show_final_overlay=False,
                    gt_mask=r["gt"], max_gap=4,
                )
                row = [r["image"]]
                for surf in surfaces:
                    mse = res['boundary_metrics'][surf]['mse']
                    rmse = np.nan if mse is None else np.sqrt(mse)
                    row.append("" if np.isnan(rmse) else f"{rmse:.6f}")
                writer.writerow(row)

    # ---------- Helpers ----------
    def pred_to_rgb(self, pred):
        """Colour a class map: class 1 red, class 2 blue, anything else black.

        Returns a uint8 array of shape ``pred.shape + (3,)``.
        """
        palette = {1: [255, 0, 0], 2: [0, 0, 255]}
        rgb = np.zeros(pred.shape + (3,), dtype=np.uint8)
        for class_id, colour in palette.items():
            rgb[pred == class_id] = colour
        return rgb

    def overlay_mask(self, gt, pred):
        """Paint ground truth and prediction classes into one RGB image.

        Colours:

        * GT class 1 green, GT class 2 blue;
        * predicted class 1 yellow, predicted class 2 magenta.

        Prediction colours are applied last, so they win where both masks
        label a pixel. ``gt=None`` is treated as an empty mask. ``pred`` must
        be 2-D.
        """
        if gt is None:
            gt = np.zeros_like(pred)
        rows, cols = pred.shape
        canvas = np.zeros((rows, cols, 3), dtype=np.uint8)
        # Order matters: later layers overwrite earlier ones.
        layers = (
            (gt, 1, [0, 255, 0]),
            (gt, 2, [0, 0, 255]),
            (pred, 1, [255, 255, 0]),
            (pred, 2, [255, 0, 255]),
        )
        for source, class_id, colour in layers:
            canvas[source == class_id] = colour
        return canvas

    def compute_iou(self, gt, pred, cid):
        """Return the intersection-over-union of class ``cid`` between two masks.

        Both masks are flattened and compared element-wise, so they must have
        the same number of elements (ValueError otherwise).

        * If the class is absent from both masks, the result is 1.0.
        * If the class is present in only one mask, the result is 0.0.
        * Otherwise the result is ``|gt ∩ pred| / |gt ∪ pred|``, which is the
          same value sklearn's binary ``jaccard_score`` gives. Computing it
          from counts avoids that function's per-call validation and
          confusion-matrix overhead on full-resolution masks.
        """
        gt_b = np.asarray(gt).ravel() == cid
        pr_b = np.asarray(pred).ravel() == cid
        if gt_b.size != pr_b.size:
            raise ValueError(
                f"compute_iou: masks differ in size ({gt_b.size} vs {pr_b.size})"
            )
        gt_n = int(np.count_nonzero(gt_b))
        pr_n = int(np.count_nonzero(pr_b))
        if gt_n == 0 and pr_n == 0:
            return 1.0
        if gt_n == 0 or pr_n == 0:
            return 0.0
        inter = int(np.count_nonzero(gt_b & pr_b))
        return inter / (gt_n + pr_n - inter)

    def add_legend(self, ax, mode="area"):
        """Attach a colour legend to ``ax`` describing the chosen overlay.

        ``mode="area"`` lists the filled-region colours and ``mode="boundary"``
        the fitted-surface colours. Any other mode returns without touching
        ``ax``. The legend is anchored just outside the top-right corner.
        """
        legend_table = (
            ("area", (
                ('green', 'GT Cornea'),
                ('blue', 'GT Lens'),
                ('yellow', 'Pred Cornea'),
                ('magenta', 'Pred Lens'),
            )),
            ("boundary", (
                ('red', 'Pred Cornea'),
                ('green', 'GT Cornea'),
                ('orange', 'Pred Lens'),
                ('blue', 'GT Lens'),
            )),
        )
        entries = next((spec for key, spec in legend_table if mode == key), None)
        if entries is None:
            return

        handles = [mpatches.Patch(color=colour, label=label) for colour, label in entries]
        ax.legend(handles=handles, loc='upper right', bbox_to_anchor=(1.35, 1))

    def extract_boundaries(self, mask, class_ids=(1,2), target_shape=None, min_segment_length=10, max_gap=4):
        """Trace the upper and lower edge of each class region column by column.

        For every id in ``class_ids``:

        * In each column, find the first and the last row holding that class.
        * Chain these edge points left to right into runs. A point joins the
          current run when it is at most ``max_gap + 1`` columns past the
          previous point and at most ``max_gap`` rows away from it.
        * Drop runs with fewer than ``min_segment_length`` points.

        Returns a uint8 array shaped like ``mask`` holding the class id on
        kept edge pixels and 0 elsewhere; later ids overwrite earlier ones.
        When ``target_shape`` is given, the result is cropped to its first
        two dimensions.
        """
        def chain_runs(points, gap, min_len):
            # Keep only points that belong to long enough, near-continuous runs.
            if not points:
                return []
            ordered = sorted(points, key=lambda p: p[0])
            kept, run = [], [ordered[0]]
            for point in ordered[1:]:
                prev_x, prev_y = run[-1]
                x, y = point
                if x - prev_x <= gap + 1 and abs(y - prev_y) <= gap:
                    run.append(point)
                    continue
                if len(run) >= min_len:
                    kept.extend(run)
                run = [point]
            if len(run) >= min_len:
                kept.extend(run)
            return kept

        n_rows, _ = mask.shape
        edges = np.zeros_like(mask, dtype=np.uint8)
        for cid in class_ids:
            hit = mask == cid
            cols = np.flatnonzero(hit.any(axis=0))
            if cols.size == 0:
                continue
            # argmax on a boolean column returns the first True row; on the
            # row-reversed array it counts from the bottom.
            first_rows = hit.argmax(axis=0)[cols]
            last_rows = n_rows - 1 - hit[::-1].argmax(axis=0)[cols]
            xs = cols.tolist()
            for x, y in chain_runs(list(zip(xs, first_rows.tolist())), max_gap, min_segment_length):
                edges[y, x] = cid
            for x, y in chain_runs(list(zip(xs, last_rows.tolist())), max_gap, min_segment_length):
                edges[y, x] = cid
        if target_shape is not None:
            edges = edges[:target_shape[0], :target_shape[1]]
        return edges

# ---------------------------
# Dataset
# ---------------------------
class OCTFolderDataset(Dataset):
    """Grayscale OCT images, and optionally label masks, read from folders.

    Images are listed from ``image_dir`` in natural sort order, skipping
    hidden files and non-image extensions. If ``mask_dir`` is an existing
    directory, its images are listed the same way and paired with the images
    by position.

    ``__getitem__`` returns ``(image, mask, (h, w))`` when masks are
    available and ``(image, (h, w))`` otherwise:

    * ``image`` is a ``(3, H, W)`` float32 tensor in [0, 1], with the gray
      channel repeated.
    * ``mask`` is an ``(H, W)`` int64 class map: gray 201-255 -> 1,
      80-200 -> 2, everything else 0.
    * Both are zero-padded on the bottom/right so H and W are multiples of 32.
    * ``(h, w)`` is the unpadded image size.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')

    def __init__(self, image_dir, mask_dir=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = self._list_images(image_dir)
        has_masks = bool(mask_dir) and os.path.isdir(mask_dir)
        self.mask_filenames = self._list_images(mask_dir) if has_masks else None
        # NOTE(review): masks are paired with images purely by sorted position,
        # so a missing/extra mask file silently mispairs (or raises IndexError
        # in __getitem__). Consider validating counts or matching filename stems.

    @classmethod
    def _list_images(cls, folder):
        """Return non-hidden image filenames in ``folder``, naturally sorted."""
        files = [
            f for f in os.listdir(folder)
            if not f.startswith('.') and f.lower().endswith(cls.IMAGE_EXTENSIONS)
        ]
        return sorted(files, key=natural_sort_key)

    def __len__(self):
        return len(self.image_filenames)

    def pad_to_32(self, t):
        """Zero-pad the last two dims of ``t`` (bottom/right) up to multiples of 32.

        Presumably required so the segmentation network's down/up-sampling
        stages divide evenly. TODO: confirm against the encoder depth.
        """
        h, w = t.shape[-2:]
        nh = ((h + 31) // 32) * 32
        nw = ((w + 31) // 32) * 32
        return F.pad(t, (0, nw - w, 0, nh - h))

    def __getitem__(self, idx):
        ip = os.path.join(self.image_dir, self.image_filenames[idx])
        # The context manager releases the file handle deterministically.
        # Otherwise it can stay open until garbage collection, e.g. for
        # multi-frame TIFFs.
        with Image.open(ip) as img:
            arr = np.array(img.convert('L'), dtype=np.float32) / 255.0
        h, w = arr.shape
        t = self.pad_to_32(torch.from_numpy(arr).unsqueeze(0)).repeat(3, 1, 1)

        if self.mask_filenames is None:
            return t, (h, w)

        mp = os.path.join(self.mask_dir, self.mask_filenames[idx])
        with Image.open(mp) as mask_img:
            mk = np.array(mask_img.convert('L'), dtype=np.uint8)
        cls = np.zeros_like(mk, dtype=np.int64)
        cls[(mk >= 201) & (mk <= 255)] = 1
        cls[(mk >= 80) & (mk <= 200)] = 2
        cls = self.pad_to_32(torch.from_numpy(cls).unsqueeze(0)).squeeze(0)
        return t, cls, (h, w)

# ---------------------------
# Main
# ---------------------------
if __name__ == "__main__":
    # Script entry point: create the Tk root window, build the segmentation UI
    # on it, and run Tk's event loop until the window is closed.
    root = tk.Tk()
    app = OCTSegmentationUI(root)
    root.mainloop()