# <b>Diabetic Retinopathy Lesion Classification using Machine Learning</b>

- <b>Name: </b>Sibin Shibu
- <b>Student ID: </b>A00014748
- <b>Subject: </b>MSc Dissertation
- <b>Supervisor: </b>Dr. Mohammed Farhan Khan
- <b>Submission Date: </b> August 23, 2025

In [None]:
!rm -r sample_data

In [None]:
# # Delete my_dir

# import shutil
# folder_path = ['myproject/processed_images/segmentation']

# for folder in folder_path:
#   shutil.rmtree(f'/content/{folder}')

# <b>1. Import Required Libraries</b>

In [None]:
# Standard libraries
import os
import random

# Image processing library
import cv2

# data processing libraries
import numpy as np
import pandas as pd
from tqdm import tqdm

# Visualization libraries
import seaborn as sns
import matplotlib.pyplot as plt

# <b>2. Data Preparation</b>

## 2.1 Dataset Extraction (Unzipping)

In [None]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("nguyenhung1903/diaretdb1-v21")
print("Path to dataset files:", path)

In [None]:
# Print the dataset root, then a two-level listing of what it contains.
print(f"{path}")

# os.listdir returns entries in arbitrary order; sorting keeps the listing
# stable between runs.
for item in sorted(os.listdir(path)):
    item_path = os.path.join(path, item)

    # Every top-level entry (folder or file) is printed at the first bullet
    # level, so a top-level file can no longer be mistaken for a child of the
    # folder printed just before it.
    print(f"- {item}")

    # Folders additionally get their direct children listed one level deeper.
    if os.path.isdir(item_path):
        for sub_item in sorted(os.listdir(item_path)):
            print(f"  - {sub_item}")

## 2.2 Define Input Directories and Paths

In [None]:
# Root of the extracted dataset inside the kagglehub download location
# (`path` is set by the download cell above).
base_path = f'{path}/ddb1_v02_01'

# Image Directory: fundus images, looked up by file name in later cells
img_dir = f'{base_path}/images'

# Groundtruth Directory: XML annotation files parsed in section 2.3
gtruth_dir = f'{base_path}/groundtruth'

## 2.3 Parse Ground Truth XML Annotations

In [None]:
import xml.etree.ElementTree as ET

# Parser function to extract annotations
def parse_xml_annotation(xml_file):
    """Read one ground-truth XML file and return its markings as dicts.

    Every ``<marking>`` element found anywhere in the document is inspected.
    A marking with a ``<polygonregion>`` always yields a record; a marking
    with a ``<circleregion>`` yields a record only when both its centroid
    and its radius are present. Markings with neither region are ignored.

    Args:
        xml_file: Path (or file object) accepted by ``ET.parse``.

    Returns:
        A list of dicts with the keys ``type``, ``centroid``,
        ``representative_pt``, ``radius``, ``polygon`` and ``confidence``.
        ``radius`` is ``None`` for polygons and ``polygon`` is ``None`` for
        circles. ``centroid`` is ``(None, None)`` when a polygon has no
        centroid text.
    """
    def _ints(text):
        # "x,y" -> [int(x), int(y)]
        return [int(part) for part in text.split(',')]

    def _record(kind, centre, rep, radius, polygon, confidence):
        return {
            'type': kind,
            'centroid': centre,
            'representative_pt': rep,
            'radius': radius,
            'polygon': polygon,
            'confidence': confidence,
        }

    annotations = []
    for marking in ET.parse(xml_file).getroot().findall('.//marking'):
        kind = marking.findtext('markingtype')
        confidence = marking.findtext('confidencelevel')

        # The representative point is optional.
        rep_text = marking.findtext('representativepoint/coords2d')
        rep = tuple(_ints(rep_text)) if rep_text else None

        # Polygon markings take precedence over circle markings.
        polygon_node = marking.find('polygonregion')
        if polygon_node is not None:
            centre_text = polygon_node.findtext('centroid/coords2d')
            if centre_text:
                cx, cy = _ints(centre_text)
            else:
                cx, cy = None, None

            # Only the <coords2d> elements that are DIRECT children of the
            # polygon region are vertices; the centroid's own <coords2d> is
            # nested one level deeper and is therefore not picked up here.
            vertices = []
            for node in polygon_node.findall('coords2d'):
                text = node.text
                if text and ',' in text:
                    x, y = _ints(text)
                    vertices.append((x, y))

            annotations.append(
                _record(kind, (cx, cy), rep, None, vertices, confidence)
            )
            continue

        circle_node = marking.find('circleregion')
        if circle_node is None:
            continue

        centre_text = circle_node.findtext('centroid/coords2d')
        radius_text = circle_node.findtext('radius')
        if centre_text and radius_text:
            cx, cy = _ints(centre_text)
            annotations.append(
                _record(kind, (cx, cy), rep, int(radius_text), None, confidence)
            )
    return annotations

## 2.4 Create Unified Annotation DataFrame

In [None]:
# Build and display annotation DataFrame
def build_annotation_dataframe():
  """Collect the annotations of every XML file in ``gtruth_dir`` into one table.

  Files are visited in sorted name order; anything whose name does not end in
  ``.xml`` (case-insensitive) is skipped. A file is tagged ``'plain'`` when its
  name contains ``'_plain'`` and ``'non_plain'`` otherwise.

  Returns:
      A DataFrame with one row per annotation and the columns gtruth_name,
      gtruth_type, lesion_type, centroid, representative_point, radius,
      polygon and confidence.
  """
  xml_names = [
      name for name in sorted(os.listdir(gtruth_dir))
      if name.lower().endswith('.xml')
  ]

  rows = []
  for xml_name in xml_names:
      # The ground-truth flavour is encoded in the file name.
      source_kind = 'plain' if '_plain' in xml_name else 'non_plain'
      parsed = parse_xml_annotation(os.path.join(gtruth_dir, xml_name))
      rows.extend(
          {
              'gtruth_name':            xml_name,
              'gtruth_type':          source_kind,
              'lesion_type':          ann['type'],
              'centroid':             ann['centroid'],
              'representative_point': ann['representative_pt'],
              'radius':               ann['radius'],
              'polygon':              ann['polygon'],
              'confidence':           ann['confidence'],
          }
          for ann in parsed
      )
  return pd.DataFrame(rows)

In [None]:
df = build_annotation_dataframe()
df

## 2.5 Link Annotations with Corresponding Raw Images

In [None]:
# Keep only annotations that come from the non-plain ground-truth files AND
# belong to the lesion classes of interest. Both conditions are combined into
# one boolean mask followed by a single .copy(): filtering a second time after
# the copy would hand back an un-copied slice again, and the column
# assignments below could then raise SettingWithCopyWarning.
excluded_types = ['Disc', 'Fundus_area', 'IRMA', 'Neovascularisation']
keep_rows = (df['gtruth_type'] == 'non_plain') & ~df['lesion_type'].isin(excluded_types)
df = df[keep_rows].copy()

# Derive image_name: the trailing "_<digits>.xml" of the ground-truth file
# name is replaced by ".png".
df['image_name'] = df['gtruth_name'].str.replace(r"_\d+\.xml$", ".png", regex=True)

# Expand the (x, y) tuples into separate numeric columns, keeping row alignment
df[['centroid_x', 'centroid_y']] = pd.DataFrame(df['centroid'].tolist(), index=df.index)
df[['rep_x', 'rep_y']] = pd.DataFrame(df['representative_point'].tolist(), index=df.index)

# Drop the original tuple columns and the ground-truth bookkeeping columns
df = df.drop(columns=['gtruth_name', 'gtruth_type', 'centroid', 'representative_point'])

# Rename Lesion categories to the short codes used in the rest of the notebook
rename_map = {
    "Red_small_dots": "MA",
    "Haemorrhages"  : "HE",
    "Hard_exudates" : "EX",
    "Soft_exudates" : "CWS"
}
df['lesion_type'] = df['lesion_type'].replace(rename_map)

# Organize dataframe columns
df = df[['image_name', 'lesion_type', 'centroid_x', 'centroid_y', 'radius', 'polygon', 'rep_x', 'rep_y', 'confidence']]
df

## 2.6 Save Raw Annotations

In [None]:
# Save dataframe
df.to_csv('annotations.csv', index=False)

## 2.7 Overlay Annotations on Raw Images

In [None]:
def overlay_annotations(df, img_dir, output_dir, default_color=(0, 255, 0)):
  """
  Draw the annotations of every image onto that image and save the result.

  Rows of ``df`` are grouped by ``image_name``. For each group the image is
  read from ``img_dir``, every row is drawn (circle outline when ``radius`` is
  set, closed polyline when ``polygon`` is a list, filled dot at the
  representative point when ``rep_x`` is set) and the result is written to
  ``output_dir`` under the same file name.

  Args:
      df: Annotation table with at least ``image_name`` and ``lesion_type``;
          the drawing columns are read per row.
      img_dir: Directory holding the source images.
      output_dir: Directory for the annotated copies (created if missing).
      default_color: BGR colour for lesion types missing from the colour map.

  Images that cannot be read or written are reported with a warning and
  skipped; they do not stop the run.
  """

  # Colours are given in OpenCV's BGR channel order.
  color_map = {
      'MA'  : (255,   0, 255),  # Magenta
      'HE'  : (  0,   0, 255),  # Red
      'EX'  : (  0, 255, 255),  # Yellow
      'CWS' : (255,   0,   0),  # Blue
  }

  # Ensure output directory exists
  os.makedirs(output_dir, exist_ok=True)

  # Process images grouped by name
  for img_file, group in df.groupby('image_name'):
      img_path = os.path.join(img_dir, img_file)
      img = cv2.imread(img_path)

      # cv2.imread signals failure by returning None instead of raising.
      if img is None:
          print(f"[WARNING] Missing image: {img_file}")
          continue

      # Draw annotations
      for _, row in group.iterrows():
          lesion = row['lesion_type']
          col = color_map.get(lesion, default_color)

          # Circle lesions (if radius is available)
          if pd.notnull(row.get('radius', np.nan)):
              center = (int(row['centroid_x']), int(row['centroid_y']))
              cv2.circle(img, center, int(row['radius']), col, 2)

          # Polygon lesions
          if isinstance(row.get('polygon', None), list):
              pts = np.array(row['polygon'], dtype=np.int32).reshape(-1, 1, 2)
              cv2.polylines(img, [pts], isClosed=True, color=col, thickness=2)

          # Representative point (if available)
          if pd.notnull(row.get('rep_x', np.nan)):
              rep = (int(row['rep_x']), int(row['rep_y']))
              cv2.circle(img, rep, 3, col, -1)

      # Save annotated image; cv2.imwrite returns False when it cannot write,
      # so report that instead of silently losing the file.
      out_path = os.path.join(output_dir, img_file)
      if not cv2.imwrite(out_path, img):
          print(f"[WARNING] Could not save annotated image: {out_path}")

  print("Image annotation completed!")
  print(f"Annotated images saved in: {output_dir}")

In [None]:
output_dir = 'raw_overlay_annotations'
os.makedirs(output_dir, exist_ok=True)

overlay_annotations(df, img_dir, output_dir)

In [None]:
def display_images(folder, cols=5, row_height=4):
  """Show every ``.png`` file of ``folder`` in a grid of ``cols`` columns.

  Args:
      folder: Directory to scan (not recursive); files are shown in sorted
          name order.
      cols: Number of grid columns.
      row_height: Figure height, in inches, allotted to each grid row.

  A folder without ``.png`` files only prints a warning. A file that OpenCV
  cannot decode keeps its grid slot but is shown as an empty, labelled panel.
  """
  files = [f for f in sorted(os.listdir(folder)) if f.endswith(".png")]
  if not files:
      # Nothing to draw; avoids building a zero-row, zero-height figure.
      print(f"[WARNING] No .png images found in: {folder}")
      return

  rows = -(-len(files) // cols)  # ceiling division

  plt.figure(figsize=(20, row_height * rows))
  for i, file in enumerate(files, 1):
      img = cv2.imread(os.path.join(folder, file))
      plt.subplot(rows, cols, i)
      if img is not None:
          # OpenCV loads BGR; matplotlib expects RGB.
          plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
          plt.title(file)
      else:
          # cv2.imread returns None for unreadable files; passing that on to
          # cvtColor would abort the whole grid.
          plt.title(f"{file} (unreadable)")
      plt.axis("off")

  plt.tight_layout()
  plt.show()

In [None]:
display_images('raw_overlay_annotations')

# <b>3. Exploratory Data Analysis (EDA)</b>

In [None]:
df.describe()

In [None]:
df['lesion_type'].value_counts()

## 3.1 Lesion Type Distribution Pie Chart

In [None]:
# Pie chart showing how the annotations are distributed over the lesion classes.
# The counts are computed once and reused for both wedge sizes and labels.
type_counts = df['lesion_type'].value_counts()

# Set figure size
plt.figure(figsize=(8, 7))

# Create the pie chart (wedge labels are the lesion type codes)
plt.pie(type_counts, labels=type_counts.index, autopct='%1.1f%%')

# Add title and legend
plt.title('Lesion Type Distribution', fontsize=14, fontweight='bold')
plt.legend(title="Lesion Type")

# Show plot
plt.show()

## 3.2 Lesion Type vs Confidence Bar Plot

In [None]:


# Count annotations per (lesion type, confidence) pair and pivot the result so
# that rows are lesion types and columns are confidence levels; pairs that
# never occur become 0.
pair_sizes = df.groupby(["lesion_type", "confidence"]).size()
counts = pair_sizes.unstack(fill_value=0)

# Row order: most frequent lesion type first.
order = df['lesion_type'].value_counts().index
counts = counts.reindex(order)

# One "Dark2" colour per confidence column.
palette = sns.color_palette("Dark2", n_colors=counts.shape[1])

# Stacked bars: each bar is a lesion type, each segment a confidence level.
counts.plot(kind="bar", stacked=True, figsize=(10,6), color=palette)

plt.xlabel("Lesion Type")
plt.ylabel("Count")
plt.title("Lesion Type vs Confidence", fontsize=14, fontweight='bold')
plt.xticks(rotation=0)
plt.legend(title="Confidence")
plt.tight_layout()
plt.show()

## 3.3 Annotations: Polygon vs Centroid Distribution Bar Plot

In [None]:
# Split the annotations by whether a polygon outline is stored for them, then
# count each kind.
has_polygon = df['polygon'].notna()
kind_labels = {True: 'Polygon-based', False: 'Centroid-based'}
polygon_counts = has_polygon.map(kind_labels).value_counts()

# One bar per annotation kind
plt.figure(figsize=(6,4))
polygon_counts.plot(kind='bar', color=['skyblue', 'salmon'])

plt.xlabel('Annotation Type')
plt.ylabel('Count')
plt.title("Annotations: Polygon vs Centroid", fontsize=14, fontweight='bold')
plt.xticks(rotation=0)
plt.tight_layout()
plt.show()


## 3.4 Top 20 Images with Most Lesions

In [None]:
# Annotation counts per image (rows) and lesion type (columns); combinations
# that never occur become 0.
per_image_sizes = df.groupby(['image_name', 'lesion_type']).size()
lesion_counts = per_image_sizes.unstack(fill_value=0)

# The 20 images with the largest total number of annotations.
totals = lesion_counts.sum(axis=1)
top_20_images = totals.nlargest(20).index
top_lesion_counts = lesion_counts.loc[top_20_images]

# Stacked bar chart: one bar per image, one segment per lesion type.
top_lesion_counts.plot(kind='bar', stacked=True, figsize=(14, 6), colormap='tab20')

plt.xlabel('Image Name')
plt.ylabel('Number of Lesions')
plt.title('Top 20 Images with Most Lesions', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right')
# The legend sits outside the axes so it does not cover the bars.
plt.legend(title='Lesion Type', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

# <b>4. Image Pre-processing</b>

## 4.1 Define Processed Image Directory

In [None]:
output1 = "preprocessed"
output2 = "preprocessed_stages"

os.makedirs(output1, exist_ok=True)
os.makedirs(output2, exist_ok=True)

## 4.2 Pipeline

In [None]:
# ==================== Single Image Preprocessing ====================
def preprocess_image_with_annotations_stages(img_path, annot_df, target_size):
    """Preprocess one image and rescale its annotations, keeping every stage.

    Pipeline: resize -> green channel -> CLAHE -> 3x3 median filter ->
    3x3 elliptical morphological opening -> scaling to [0, 1].

    Args:
        img_path: Path of the image to read with OpenCV.
        annot_df: Annotation table for all images; the rows whose
            ``image_name`` equals the base name of ``img_path`` are used.
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``(normalized, annots, stages)`` where ``normalized`` is the final
        float32 image in [0, 1], ``annots`` is a copy of this image's
        annotation rows with coordinates, radius and polygons scaled to
        ``target_size``, and ``stages`` maps a stage name to the image
        produced by that stage.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    image_name = os.path.basename(img_path)
    stages = {}

    # Read image. cv2.imread returns None instead of raising when the file is
    # missing or cannot be decoded, so fail here with a message that names it.
    image = cv2.imread(img_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    stages["Original"] = image.copy()

    # Filter annotations (copy so the caller's table is left untouched)
    annots = annot_df[annot_df['image_name'] == image_name].copy()

    # === 1. Resize ===
    original_w, original_h = image.shape[1], image.shape[0]
    resized_img = cv2.resize(image, target_size)
    scale_x = target_size[0] / original_w
    scale_y = target_size[1] / original_h
    stages["Resized"] = resized_img.copy()

    def scale_polygon(polygon):
        # Non-list entries (e.g. missing polygons) are passed through as-is.
        if isinstance(polygon, list):
            return [(px * scale_x, py * scale_y) for px, py in polygon]
        return polygon

    for coord in ['centroid_x', 'rep_x']:
        annots[coord] *= scale_x
    for coord in ['centroid_y', 'rep_y']:
        annots[coord] *= scale_y

    # A circle has a single radius, so it is scaled by the mean of both factors.
    if 'radius' in annots.columns:
        annots['radius'] *= (scale_x + scale_y) / 2

    annots['polygon'] = annots['polygon'].apply(scale_polygon)

    # === 2. Green channel === (index 1 in OpenCV's BGR layout)
    green = resized_img[:, :, 1]
    stages["Green Channel"] = green.copy()

    # === 3. CLAHE ===
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(green)
    stages["CLAHE"] = enhanced.copy()

    # === 4. Median Filter ===
    denoised = cv2.medianBlur(enhanced, 3)
    stages["Median Filtered"] = denoised.copy()

    # === 5. Morphological Opening ===
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    morph_cleaned = cv2.morphologyEx(denoised, cv2.MORPH_OPEN, kernel)
    stages["Morphological Opening"] = morph_cleaned.copy()

    # === 6. Normalize ===
    normalized = morph_cleaned.astype('float32') / 255.0
    # The stage image is stored as 8-bit again so every stage can be displayed
    # the same way.
    stages["Normalized"] = (normalized * 255).astype(np.uint8)
    return normalized, annots, stages

In [None]:
# ==================== Utility to Save Stages ====================
def save_stages_grid(stages_dict, save_path):
    """Render every stage image into one three-column figure and save it.

    Args:
        stages_dict: Mapping of stage name -> image array, drawn in the
            mapping's iteration order. Three-dimensional arrays are treated
            as BGR colour images, all others as grayscale.
        save_path: Destination passed to ``plt.savefig``.
    """
    cols = 3
    rows = -(-len(stages_dict) // cols)  # ceiling division

    plt.figure(figsize=(12, 4 * rows))
    for slot, (stage_name, img) in enumerate(stages_dict.items(), start=1):
        plt.subplot(rows, cols, slot)
        if img.ndim == 3:
            # Colour stages come from OpenCV (BGR); matplotlib expects RGB.
            plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        else:
            plt.imshow(img, cmap='gray')
        plt.title(stage_name)
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(save_path)
    # Close the figure so figures do not pile up across a batch run.
    plt.close()

## 4.3 Save Pre-processed Output

In [None]:
# ==================== Batch Processing ====================
def process_all_images(img_dir, prep_path, stages_path, annot_df, target_size):
  """Preprocess every annotated image and return the rescaled annotations.

  For each distinct ``image_name`` in ``annot_df`` the image is preprocessed,
  the final image is written to ``prep_path`` as 8-bit, the stage overview is
  written to ``stages_path`` and the image's rescaled annotation rows are
  collected.

  Args:
      img_dir: Directory holding the source images.
      prep_path: Output directory for the preprocessed images.
      stages_path: Output directory for the stage-overview figures.
      annot_df: Annotation table with an ``image_name`` column.
      target_size: ``(width, height)`` forwarded to the preprocessing step.

  Returns:
      One DataFrame with the rescaled annotations of all images (fresh
      0..N-1 index). An annotation table without rows yields an empty copy
      of itself.
  """
  # pd.concat raises ValueError on an empty list, so an annotation table
  # without rows is answered directly.
  if len(annot_df) == 0:
    return annot_df.copy()

  # cv2.imwrite does not create missing directories; make sure both exist.
  os.makedirs(prep_path, exist_ok=True)
  os.makedirs(stages_path, exist_ok=True)

  all_annots = []
  for img_name in tqdm(annot_df['image_name'].unique(), desc="Processing images"):
    img_path = os.path.join(img_dir, img_name)

    # Preprocess the image and annotations
    processed_img, updated_annots, stages = preprocess_image_with_annotations_stages(img_path, annot_df, target_size)

    # Save final preprocessed image (float [0, 1] -> 8-bit). cv2.imwrite
    # reports failure through its return value only, so surface it.
    save_path = os.path.join(prep_path, img_name)
    if not cv2.imwrite(save_path, (processed_img * 255).astype('uint8')):
      print(f"[WARNING] Could not save preprocessed image: {save_path}")

    # Save the stages grid
    base_name, _ = os.path.splitext(img_name)
    stage_save_path = os.path.join(stages_path, f"{base_name}.png")
    save_stages_grid(stages, stage_save_path)

    # Collect updated annotations
    all_annots.append(updated_annots)

  # Merge all updated annotations into a single DataFrame
  return pd.concat(all_annots, ignore_index=True)

## 4.4 Update and Save Annotations

In [None]:
# Define target resizing size as (width, height)
target_size = (512, 512)

# Process all images. The output folders are the ones created in section 4.1
# (output1 = preprocessed images, output2 = stage overviews), so the names are
# defined in a single place.
# NOTE: `df` is replaced by the rescaled annotations. Re-running this cell
# without re-running section 2 would rescale coordinates that are already
# scaled.
df = process_all_images(img_dir, output1, output2, df, target_size)

# Save preprocessed annotations
df.to_csv('preprocessed_annotations.csv', index=False)

In [None]:
display_images("preprocessed")

In [None]:
plt.figure(figsize=(10, 9))
img_path = 'preprocessed_stages/diaretdb1_image005.png'
img = cv2.imread(img_path)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

plt.imshow(img_rgb)
plt.title('Preprocessing Stages', fontsize=16, fontweight='bold')
plt.axis('off')

# Display plot
plt.tight_layout()
plt.show()

# <b>5. Segmentation</b>

## 5.1 Define Input Directory

In [None]:
# Segmentation Mask Directory
output = "segmentation"
subdirs = ["major_vessels", "minor_vessels", "optic_disc", "combined", "images", "overlay"]

for s in subdirs:
  os.makedirs(os.path.join(output, s), exist_ok=True)

## 5.2 Segmentation Process


In [None]:
from skimage.filters import frangi

# The dilation kernel is the same for every image, so it is built once.
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

for img_name in tqdm(df['image_name'].unique()):
  img_path = os.path.join("preprocessed", img_name)
  img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
  # cv2.imread returns None for a missing/unreadable file; every later cell
  # expects a mask per image, so stop with a message that names the file.
  if img is None:
    raise FileNotFoundError(f"Could not read preprocessed image: {img_path}")
  h, w = img.shape

  # === 1. Frangi filter (vessels) ===
  vesselness = frangi(img)
  v_min, v_max = vesselness.min(), vesselness.max()
  if v_max > v_min:
    # Min-max normalisation to [0, 1] so the thresholds below are comparable
    # across images.
    vesselness_norm = (vesselness - v_min) / (v_max - v_min)
  else:
    # Flat response: dividing by zero would fill the map with NaN. An all-zero
    # map produces the same (empty) vessel masks without the warnings.
    vesselness_norm = np.zeros_like(vesselness)

  # === 2. Separate masks with improved thresholds ===
  major_vessels = (vesselness_norm > 0.25).astype(np.uint8) * 255
  minor_vessels = ((vesselness_norm > 0.1) & (vesselness_norm <= 0.25)).astype(np.uint8) * 255

  # === 3. Optic disc mask ===
  # Heuristic: a filled disc centred on the brightest point of the blurred
  # image, with a radius of 8% of the shorter image side.
  od_mask = np.zeros((h, w), dtype=np.uint8)
  blurred = cv2.GaussianBlur(img, (15, 15), 0)
  _, _, _, max_loc = cv2.minMaxLoc(blurred)
  cx, cy = max_loc
  r = int(min(h, w) * 0.08)
  cv2.circle(od_mask, (cx, cy), r, 255, -1)

  # === 4. Save individual masks ===
  cv2.imwrite(os.path.join(output, "major_vessels", img_name), major_vessels)
  cv2.imwrite(os.path.join(output, "minor_vessels", img_name), minor_vessels)
  cv2.imwrite(os.path.join(output, "optic_disc", img_name), od_mask)

  # === 5. Save combined mask ===
  combined_mask = cv2.bitwise_or(major_vessels, minor_vessels)
  combined_mask = cv2.bitwise_or(combined_mask, od_mask)

  # Two dilation passes widen the mask slightly before it is applied.
  combined_mask = cv2.dilate(combined_mask, kernel, iterations=2)
  cv2.imwrite(os.path.join(output, "combined", img_name), combined_mask)

  # === 6. Zero out the masked structures (vessels and optic disc) ===
  combined_mask = (combined_mask > 0).astype(np.uint8)

  cleaned = img.copy()
  cleaned[combined_mask == 1] = 0
  cv2.imwrite(os.path.join(output, "images", img_name), cleaned)

## 5.3 Overlay Segmented Mask on Preprocessed Image

In [None]:
def _read_dilated_mask(kind, name, kernel):
  """Load one saved structure mask as grayscale and thicken it by one dilation."""
  mask = cv2.imread(os.path.join("segmentation", kind, name), cv2.IMREAD_GRAYSCALE)
  return cv2.dilate(mask, kernel, iterations=1)

for img_name in tqdm(df['image_name'].unique()):
  img_path = os.path.join("preprocessed", img_name)

  # The preprocessed image is single-channel; convert it so colours can be
  # painted on top.
  img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
  img_bgr = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

  # Masks are dilated for better visibility.
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2,2))
  major_mask = _read_dilated_mask("major_vessels", img_name, kernel)
  minor_mask = _read_dilated_mask("minor_vessels", img_name, kernel)
  od_mask = _read_dilated_mask("optic_disc", img_name, kernel)

  # Paint order matters where masks overlap: later assignments win.
  overlay = img_bgr.copy()
  overlay[major_mask > 0] = (0, 0, 255)     # Red
  overlay[minor_mask > 0] = (0, 255, 0)   # Green
  overlay[od_mask > 0] = (255, 0, 0)        # Blue

  # Blend the painted copy with the plain image and save it.
  blended = cv2.addWeighted(img_bgr, 0.5, overlay, 0.8, 0)
  cv2.imwrite(os.path.join("segmentation", "overlay", img_name), blended)

## 5.4 Display Segmentation Output

### 5.4.1 Stages

In [None]:
def visualize_result(image_path, major_path, minor_path, od_path, mask_path, cleaned_path, overlay_path):
  """Show the segmentation outputs of one image in a 3x3 grid (7 panels used).

  Args:
      image_path: Preprocessed image.
      major_path: Major-vessel mask.
      minor_path: Minor-vessel mask.
      od_path: Optic-disc mask.
      mask_path: Combined mask.
      cleaned_path: Image with the combined mask zeroed out.
      overlay_path: Colour overlay image.
  """
  # The first six panels are drawn in grayscale, in this order.
  gray_panels = [
      ("Preprocessed Image", cv2.imread(image_path, cv2.IMREAD_UNCHANGED)),
      ("Major Vessels", cv2.imread(major_path, cv2.IMREAD_GRAYSCALE)),
      ("Minor Vessels", cv2.imread(minor_path, cv2.IMREAD_GRAYSCALE)),
      ("Optic Disc", cv2.imread(od_path, cv2.IMREAD_GRAYSCALE)),
      ("Combined Mask", cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)),
      ("Cleaned Image", cv2.imread(cleaned_path, cv2.IMREAD_UNCHANGED)),
  ]

  # The overlay is a colour image: OpenCV loads BGR, matplotlib wants RGB.
  overlay = cv2.cvtColor(cv2.imread(overlay_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

  # Bigger canvas
  plt.figure(figsize=(14, 12))

  for slot, (title, picture) in enumerate(gray_panels, start=1):
      plt.subplot(3, 3, slot)
      plt.imshow(picture, cmap='gray')
      plt.title(title)

  # Seventh panel: the colour overlay
  plt.subplot(3, 3, 7)
  plt.imshow(overlay)
  plt.title("Overlay Image")

  plt.tight_layout(pad=3.0)  # add padding between subplots
  plt.show()

In [None]:
img_name = df['image_name'].unique()[1]
visualize_result(
  os.path.join("preprocessed", img_name),
  os.path.join("segmentation", "major_vessels", img_name),
  os.path.join("segmentation", "minor_vessels", img_name),
  os.path.join("segmentation", "optic_disc", img_name),
  os.path.join("segmentation", "combined", img_name),
  os.path.join("segmentation", "images", img_name),
  os.path.join("segmentation", "overlay", img_name)
)

### 5.4.2 Major Vessels

In [None]:
display_images("segmentation/major_vessels")

### 5.4.3 Minor Vessels

In [None]:
display_images("segmentation/minor_vessels")

### 5.4.4 Optic Disc

In [None]:
display_images("segmentation/optic_disc")

### 5.4.5 Combined Mask

In [None]:
display_images("segmentation/combined")


### 5.4.6 Segmented Images

In [None]:
display_images("segmentation/images")

### 5.4.7 Overlay

In [None]:
display_images("segmentation/overlay")

# <b>6. Lesion Labelling and Feature Extraction</b>

In [None]:
import math
from skimage.filters import threshold_otsu
from skimage.measure import regionprops, label
from skimage.morphology import convex_hull_image
from scipy.ndimage import distance_transform_edt
from scipy.ndimage import binary_erosion as ndi_binary_erosion

# ----------------- helpers -----------------
def build_lesion_image(img_rgb, g_res):
    """Return an HxWx3 image whose channels are [R of img_rgb, g_res, B of img_rgb].

    The green plane of ``img_rgb`` is replaced by ``g_res``; both inputs must
    share the same height and width.
    """
    red_plane = img_rgb[:, :, 0]
    blue_plane = img_rgb[:, :, 2]
    return np.stack((red_plane, g_res, blue_plane), axis=-1)

def rgb_to_hsv01(img_rgb):
    """Convert an RGB uint8 image to float32 HSV with every channel in [0, 1].

    OpenCV's 8-bit HSV stores H in [0, 179] and S, V in [0, 255]; each channel
    is divided by its own full-scale value.
    """
    hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV).astype(np.float32)
    for channel, full_scale in enumerate((179.0, 255.0, 255.0)):
        hsv[..., channel] /= full_scale
    return hsv

def masked_stats(win, mask=None):
    """Return (mean, variance, min, max) of ``win`` as Python floats.

    Args:
        win: Array of values.
        mask: Optional boolean array of the same shape; only positions that
            are True contribute. ``None`` uses every value.

    Returns:
        A 4-tuple of floats; all zeros when no value is selected.
    """
    values = win.reshape(-1) if mask is None else win[mask]
    if values.size == 0:
        return 0.0, 0.0, 0.0, 0.0
    return (
        float(values.mean()),
        float(values.var()),
        float(values.min()),
        float(values.max()),
    )

def threshold_blob(winG):
    """Binarise a window and keep only the blob covering its centre pixel.

    The threshold is Otsu's; if that computation raises, the 85th percentile
    of the window is used instead. Pixels >= threshold are foreground and are
    grouped with 8-connectivity. Returns a boolean mask of the component that
    contains the centre pixel, or an all-False mask when the centre pixel is
    background.
    """
    try:
        cutoff = threshold_otsu(winG)
    except Exception:
        cutoff = np.percentile(winG, 85)

    binary = winG >= cutoff
    centre_row = binary.shape[0] // 2
    centre_col = binary.shape[1] // 2

    components = label(binary, connectivity=2)
    centre_label = components[centre_row, centre_col]
    if centre_label == 0:
        # Label 0 is background: there is no blob at the centre.
        return np.zeros_like(binary, dtype=bool)
    return components == centre_label

def shape_features(bw):
  """Return 11 shape descriptors of the binary blob ``bw``.

  Order of the returned list: area, perimeter, circularity, eccentricity,
  major axis length, minor axis length, orientation, solidity
  (area / convex-hull area), extent (area / bounding-box area), equivalent
  diameter, perimeter-to-area ratio. An empty mask yields eleven zeros.

  Args:
      bw: 2-D mask; expected to be boolean (True = blob pixel).
  """
  A = float(bw.sum())
  if A == 0:
      return [0.0]*11

  def _first_available(region, *names):
      # Newer scikit-image releases spell the axis properties
      # axis_major_length / axis_minor_length; older ones use
      # major_axis_length / minor_axis_length. Reading only one spelling with
      # a 0.0 default would silently zero the feature on a release that lacks
      # it, so both are tried in turn.
      for name in names:
          value = getattr(region, name, None)
          if value is not None:
              return float(value)
      return 0.0

  mask = bw.astype(bool)

  # Perimeter = number of blob pixels removed by one binary erosion
  # (border_value=0 treats everything outside the window as background).
  eroded = ndi_binary_erosion(mask, border_value=0)
  edge = mask ^ eroded
  P = float(edge.sum())

  props = regionprops(bw.astype(np.uint8))[0]
  circ = (4.0*math.pi*A)/(P*P) if P>0 else 0.0
  eqd  = 2.0*math.sqrt(A/math.pi)

  hull = convex_hull_image(mask)
  ch_area = float(hull.sum()) or A   # fall back to A if the hull is empty

  min_row, min_col, max_row, max_col = props.bbox[:4]
  bbox_area = max(1.0, (max_row-min_row)*(max_col-min_col))
  extent = A / bbox_area

  # A > 0 is guaranteed by the early return above and ch_area > 0 by the
  # fallback, so the two ratios need no further guards.
  return [
      A, P, circ,
      float(getattr(props,"eccentricity",0.0)),
      _first_available(props, "axis_major_length", "major_axis_length"),
      _first_available(props, "axis_minor_length", "minor_axis_length"),
      float(getattr(props,"orientation",0.0)),
      A/ch_area,
      extent, eqd,
      P/A
  ]


def vessel_distance_map(m_vessel):
    """Return, per pixel, the Euclidean distance to the nearest vessel pixel.

    Args:
        m_vessel: Mask in which non-zero marks a vessel, or ``None``.

    Returns:
        A float32 array of the same shape (0 on vessel pixels), or ``None``
        when no mask is given.
    """
    if m_vessel is None:
        return None
    # distance_transform_edt measures the distance from each non-zero element
    # to the nearest zero, so the mask is inverted first (background -> 1).
    background = (m_vessel == 0).astype(np.uint8)
    distances = distance_transform_edt(background)
    return distances.astype(np.float32)

def parse_polygon(poly):
    """
    Normalise a polygon given in one of several forms to an Nx2 float32 array.

    Accepts:
      - list/tuple/array of [[x,y],...]
      - string holding a Python literal such as '[(x1,y1),(x2,y2),...]'
      - string of the form 'x1 y1; x2 y2; ...' (commas are treated as spaces)
    Returns the Nx2 float32 array, or None when the value is missing (None or
    NaN) or cannot be interpreted as a list of 2-D points.
    """
    import ast  # only needed for the string form

    if poly is None or (isinstance(poly, float) and np.isnan(poly)):
        return None

    if isinstance(poly, (list, tuple, np.ndarray)):
        arr = np.asarray(poly, dtype=np.float32)
        return arr if (arr.ndim == 2 and arr.shape[1] == 2) else None

    if isinstance(poly, str):
        s = poly.strip()

        # 1) Python-literal form. ast.literal_eval only evaluates literals (no
        #    code execution). Only the errors that parsing/conversion can
        #    raise are caught, so unrelated failures are not hidden.
        try:
            arr = np.asarray(ast.literal_eval(s), dtype=np.float32)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            arr = None
        if arr is not None and arr.ndim == 2 and arr.shape[1] == 2:
            return arr

        # 2) "x y; x y; ..." form. Tokens that are empty, have fewer than two
        #    parts or are not numeric are skipped.
        pts = []
        for token in s.replace(',', ' ').split(';'):
            parts = token.split()
            if len(parts) < 2:
                continue
            try:
                pts.append([float(parts[0]), float(parts[1])])
            except ValueError:
                continue
        if pts:
            return np.asarray(pts, dtype=np.float32)
    return None

def od_info_from_mask(od_mask):
    """Locate the optic disc in a mask.

    The largest connected non-zero region is taken as the disc.

    Returns:
        ``((x, y), diameter)`` where ``(x, y)`` is the region centroid and
        ``diameter`` the longer of its two axis lengths, or ``(None, None)``
        when the mask is ``None`` or contains no region.
    """
    if od_mask is None:
        return None, None

    regions = regionprops(label((od_mask != 0).astype(np.uint8)))
    if not regions:
        return None, None

    biggest = max(regions, key=lambda region: region.area)
    # regionprops reports the centroid as (row, col); callers want (x, y).
    row, col = biggest.centroid
    diameter = max(biggest.major_axis_length, biggest.minor_axis_length)
    return (float(col), float(row)), float(diameter)

# ----------------- main extractor -----------------
def extract_features(
    img_bgr,                # OpenCV-loaded original (BGR); resized to g_res's size below
    g_res,                  # green residual (HxW, single channel) – reference size
    image_name="",
    window=17, stride=5,
    m_major_vessel=None,    # major-vessel mask, same HxW as g_res (non-zero = vessel)
    m_minor_vessel=None,    # minor-vessel mask, same HxW as g_res (non-zero = vessel)
    m_od=None,              # optic disc mask, same HxW as g_res (non-zero = disc)
    gt_df=None              # ground-truth DataFrame (optional)
):
    """
    Slide a window over the image and build one feature row per sample point.

    Sample points lie on a regular grid (step ``stride``) and are kept only if
    the point itself is outside the vessel / optic-disc masks and at least 30%
    of its ``window`` x ``window`` neighbourhood is outside them too.

    Returns a DataFrame with:
      metadata: image, x, y
      30 features: 16 intensity (R/G/S/I mean,var,min,max) +
                   11 shape/struct +
                   3 context (distance_to_vessel, distance_to_OD, distance_to_OD_normalized)
      labels: lesion, lesion_group, lesion_type (defaults set to background)

    When ``gt_df`` is given, the rows whose ``image_name`` column equals
    ``image_name`` are used to label the sample points: a point inside an
    annotation's polygon (or, failing that, inside its centroid/radius circle)
    gets that annotation's lesion type. Annotations are applied in table
    order, so a later annotation overwrites an earlier one on shared points.

    Raises:
      ValueError: if ``img_bgr`` or ``g_res`` is None, if ``g_res`` is not
        2-D, or if a provided mask does not match the size of ``g_res``.
    """
    # ---- sanity checks ----
    if img_bgr is None or g_res is None:
        raise ValueError("img_bgr or g_res is None. Check your paths/loads.")
    if g_res.ndim != 2:
        raise ValueError("g_res must be a single-channel (HxW) uint8 image.")

    # --- 1) Resize raw RGB to match g_res size (no cropping) ---
    Ht, Wt = g_res.shape[:2]
    img_bgr = cv2.resize(img_bgr, (Wt, Ht), interpolation=cv2.INTER_LINEAR)

    # Optional: verify masks sizes if provided
    for name, m in (("m_major_vessel", m_major_vessel),
                    ("m_minor_vessel", m_minor_vessel),
                    ("m_od", m_od)):
        if m is not None and m.shape[:2] != (Ht, Wt):
            raise ValueError(f"{name} must have shape {(Ht, Wt)}, got {m.shape[:2]}.")

    # --- 2) Convert to RGB and prepare planes ---
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    H, W, _ = img_rgb.shape
    half    = window//2

    # "Lesion image": original R and B planes with g_res in place of G.
    # S is the HSV saturation of that image, I the mean of its three planes.
    lesion = build_lesion_image(img_rgb, g_res).astype(np.uint8)
    R, G, B = lesion[:,:,0].astype(np.float32), lesion[:,:,1].astype(np.float32), lesion[:,:,2].astype(np.float32)
    hsv     = rgb_to_hsv01(lesion)
    S       = hsv[:,:,1].astype(np.float32)
    I       = ((R+G+B)/3.0).astype(np.float32)

    # --- 3) Build masks (foreground) ---
    # Any non-zero mask value counts as "structure present" (-> 1).
    def _bin(x): return None if x is None else (x != 0).astype(np.uint8)
    m_major = _bin(m_major_vessel)
    m_minor = _bin(m_minor_vessel)
    m_od    = _bin(m_od)

    # Union of both vessel masks (or whichever one is available; may be None).
    if m_major is not None and m_minor is not None:
        m_vessel = np.clip(m_major + m_minor, 0, 1)
    elif m_major is not None:
        m_vessel = m_major
    else:
        m_vessel = m_minor

    # m_fg == 1 marks pixels that are neither vessel nor optic disc.
    m_fg = np.ones((H,W), dtype=np.uint8)
    if m_vessel is not None:
        m_fg[m_vessel==1] = 0
    if m_od is not None:
        m_fg[m_od==1] = 0

    # --- 4) Context maps ---
    DT_vessel = vessel_distance_map(m_vessel) if m_vessel is not None else None
    od_center, od_diameter = od_info_from_mask(m_od)

    # --- 5) Feature extraction ---
    # The grid starts/ends `half` pixels from the border so every window fits
    # inside the image.
    rows=[]
    for y in range(half, H-half, stride):
        # Skip rows of the grid that contain no foreground pixel at all.
        if np.count_nonzero(m_fg[y, half:W-half]) == 0:
            continue
        for x in range(half, W-half, stride):
            if m_fg[y, x] == 0:
                continue

            ys, ye = y-half, y+half+1
            xs, xe = x-half, x+half+1
            W_fg   = m_fg[ys:ye, xs:xe].astype(bool)
            # Require at least 30% of the window to be foreground.
            if W_fg.mean() < 0.30:
                continue

            WR, WG, WS, WI = R[ys:ye,xs:xe], G[ys:ye,xs:xe], S[ys:ye,xs:xe], I[ys:ye,xs:xe]

            # 16 intensity: (mean, var, min, max) over the foreground pixels of
            # each of the four planes
            feats=[]
            for Wp in (WR, WG, WS, WI):
                feats.extend(masked_stats(Wp, W_fg))

            # 11 shape: the threshold is computed from the foreground pixels of
            # the G window (Otsu, 85th percentile as fallback) but applied to
            # the whole window; only the blob covering the window centre is
            # kept.
            # NOTE(review): this repeats the logic of threshold_blob() with an
            # added foreground restriction — consider sharing one helper.
            try:
                vals = WG[W_fg] if W_fg.any() else WG
                t = threshold_otsu(vals)
            except Exception:
                vals = WG[W_fg] if W_fg.any() else WG
                t = np.percentile(vals, 85)
            BW = (WG >= t)
            cy, cx = BW.shape[0]//2, BW.shape[1]//2
            lbl = label(BW, connectivity=2)
            cc  = lbl[cy, cx]
            BW  = (lbl == cc) if cc != 0 else np.zeros_like(BW, dtype=bool)
            struct = shape_features(BW)

            # 3 context: distances default to 0.0 when the corresponding mask
            # is unavailable
            d_vessel = float(DT_vessel[y,x]) if DT_vessel is not None else 0.0
            if od_center is not None and od_diameter and od_diameter>0:
                dx, dy = float(x-od_center[0]), float(y-od_center[1])
                d_od = float(math.hypot(dx,dy))
                d_od_norm = d_od/float(od_diameter)
            else:
                d_od, d_od_norm = 0.0, 0.0

            # assemble row
            row = {"image": image_name, "x": int(x), "y": int(y)}
            planes=["R","G","S","I"]; stats=["mean","var","min","max"]
            # feats is laid out plane-major: 4 stats for R, then G, S, I.
            for pi,p in enumerate(planes):
                for si,s in enumerate(stats):
                    row[f"{p}_{s}"] = feats[pi*4+si]
            # Must stay in the same order as the list returned by shape_features().
            names=["area","perimeter","circularity","eccentricity",
                   "major_axis","minor_axis","orientation",
                   "solidity","extent","equiv_diam","perim_area_ratio"]
            for n,v in zip(names, struct):
                row[n]=v
            row["distance_to_vessel"] = d_vessel
            row["distance_to_OD"] = d_od
            row["distance_to_OD_normalized"] = d_od_norm

            # default labels
            row["lesion"] = "not_lesion"
            row["lesion_group"] = "background"
            row["lesion_type"] = "background"

            rows.append(row)

    df = pd.DataFrame(rows)

    # --- 6) Label with GT (polygon → circle), filtered by image_name ---
    if gt_df is not None and len(df):
        gt_rows = gt_df[gt_df["image_name"] == image_name]
        if len(gt_rows):
            xs = df["x"].to_numpy(); ys = df["y"].to_numpy()
            for _, gt in gt_rows.iterrows():
                lesion_type = gt.get("lesion_type", "background")
                # MA/HE are "red" lesions, EX/CWS "bright" ones.
                group = "red" if lesion_type in ("MA","HE") else ("bright" if lesion_type in ("EX","CWS") else "background")
                inside = None
                poly = parse_polygon(gt.get("polygon", None))
                if poly is not None and len(poly) >= 3:
                    # Rasterise the polygon and look up every sample point in it.
                    poly_i = poly.astype(np.int32).reshape(-1,1,2)
                    poly_mask = np.zeros((H,W), dtype=np.uint8)
                    cv2.fillPoly(poly_mask, [poly_i], 1)
                    inside = poly_mask[ys, xs].astype(bool)
                # Fall back to the centroid/radius circle when there is no
                # usable polygon or the polygon contains no sample point.
                if inside is None or not inside.any():
                    cx = float(gt.get("centroid_x", np.nan))
                    cy = float(gt.get("centroid_y", np.nan))
                    r  = float(gt.get("radius", 0.0))
                    # NOTE(review): a NaN radius passes the `r<=0` test; the
                    # comparison against r*r is then False everywhere, so no
                    # point is labelled — an explicit np.isnan(r) check would
                    # make that intent visible.
                    if not (np.isnan(cx) or np.isnan(cy) or r<=0):
                        inside = ((xs - cx)**2 + (ys - cy)**2) <= (r*r)
                    else:
                        inside = np.zeros_like(xs, dtype=bool)
                if inside.any():
                    df.loc[inside, "lesion"]       = "lesion"
                    df.loc[inside, "lesion_group"] = group
                    df.loc[inside, "lesion_type"]  = lesion_type

    return df

In [None]:
# Ground truth DataFrame (full table); extract_features picks the rows of the
# image it is given.
gt_df = df

# One feature table per image, merged in the next cell.
df_list = []
for img_name in tqdm(df['image_name'].unique()):
  # raw RGB fundus and the vessel/optic-disc-free image from section 5
  img_path   = f"{img_dir}/{img_name}"
  g_res_path = f"segmentation/images/{img_name}"

  img_bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
  g_res   = cv2.imread(g_res_path, cv2.IMREAD_GRAYSCALE)

  # structure masks written by the segmentation step
  masks = {
      kind: cv2.imread(f"segmentation/{kind}/{img_name}", cv2.IMREAD_GRAYSCALE)
      for kind in ("major_vessels", "minor_vessels", "optic_disc")
  }

  df_list.append(
      extract_features(
          img_bgr=img_bgr,
          g_res=g_res,
          image_name=img_name,
          window=17,
          stride=5,
          m_major_vessel=masks["major_vessels"],
          m_minor_vessel=masks["minor_vessels"],
          m_od=masks["optic_disc"],
          gt_df=gt_df,
      )
  )

In [None]:
# Merge the per-image feature tables. Each of them is indexed 0..k-1, so a
# plain concat would leave duplicate index labels in df_main;
# ignore_index=True assigns one unique 0..N-1 index instead.
df_main = pd.concat(df_list, ignore_index=True)
df_main

In [None]:
df_main.shape

In [None]:
df_main.columns

In [None]:
df_main['lesion'].value_counts()

In [None]:
df_main['lesion_group'].value_counts()

In [None]:
df_main['lesion_type'].value_counts()

In [None]:
df_main.to_csv("df_main_stride_5.csv", index=False)

# <b>7. Model Building</b>

## 7.1 Bright

### 7.1.1 Bright vs Background

#### Logistic Regression

In [None]:
# ===== Stage 1 — Bright vs Background (Logistic Regression) =====
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (confusion_matrix, classification_report,
                             roc_auc_score, average_precision_score,
                             ConfusionMatrixDisplay, roc_curve, auc)
import matplotlib.pyplot as plt

SEED = 42
TEST_SIZE = 0.20
N_BG_BRIGHT = 50_000  # undersample size for background (tweak if needed)

# 1) Build dataset (undersample background + all bright), make target, shuffle
# NOTE(review): background is undersampled BEFORE the train/test split, so the
# test set carries the same artificial class ratio as the training set and the
# metrics below describe that ratio, not the raw one.
df_bg = df_main[df_main['lesion_group'] == 'background']
df_bright = df_main[df_main['lesion_group'] == 'bright']
df_bg_s = df_bg.sample(n=min(N_BG_BRIGHT, len(df_bg)), random_state=SEED)

df_s1_bright = pd.concat([df_bg_s, df_bright], ignore_index=True)
df_s1_bright['target'] = np.where(df_s1_bright['lesion_group'] == 'bright', 'bright', 'background')
df_s1_bright = df_s1_bright.sample(frac=1, random_state=SEED).reset_index(drop=True)

# 2) Leakage-safe split by image
#    (rows sharing an 'image' value always land on the same side of the split)
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, test_idx = next(gss.split(df_s1_bright, groups=df_s1_bright['image']))
train_df = df_s1_bright.iloc[train_idx].copy()
test_df  = df_s1_bright.iloc[test_idx].copy()

# 3) Features / labels
#    Everything that is not an identifier, a coordinate or a label is a feature.
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in df_s1_bright.columns if c not in drop_cols]
X_train, y_train = train_df[feat_cols], train_df['target']
X_test,  y_test  = test_df[feat_cols],  test_df['target']

# 4) Model: Impute + Scale + Logistic Regression (balanced)
#    Solver settings match the other Logistic Regression cells of this notebook
#    (lbfgs, max_iter=1000) so the LR baselines are configured alike. The
#    previous saga / max_iter=200 setting risks stopping before convergence,
#    and n_jobs is documented as parallelising over classes, which brings
#    nothing to a binary problem.
lr_stage1_bright = Pipeline([
    ('impute', SimpleImputer(strategy='median')),
    ('scale', StandardScaler()),
    ('clf', LogisticRegression(
        solver='lbfgs', max_iter=1000, class_weight='balanced',
        C=1.0, random_state=SEED
    ))
])

lr_stage1_bright.fit(X_train, y_train)

# 5) Scores (bright = positive), predictions at 0.5
#    Look the column up by class name instead of assuming its position.
bright_idx = list(lr_stage1_bright.named_steps['clf'].classes_).index('bright')
y_scores = lr_stage1_bright.predict_proba(X_test)[:, bright_idx]
y_pred = np.where(y_scores >= 0.5, 'bright', 'background')

# 6) Metrics (bright positive)
y_true_bin = (y_test == 'bright').astype(int)
y_pred_bin = (y_pred == 'bright').astype(int)
# labels=[0, 1] pins the matrix to 2x2. Without it the matrix shrinks to 1x1
# when only one class occurs in truth and prediction, and the 4-way unpack fails.
tn, fp, fn, tp = confusion_matrix(y_true_bin, y_pred_bin, labels=[0, 1]).ravel()

sensitivity = tp / (tp + fn) if (tp + fn) else 0.0   # recall for bright
specificity = tn / (tn + fp) if (tn + fp) else 0.0   # TNR for background
accuracy   = (tp + tn) / (tp + tn + fp + fn)
# ROC-AUC is undefined unless the test set holds both classes; report NaN then.
has_pos_and_neg = (tp + fn) > 0 and (tn + fp) > 0
roc_auc    = roc_auc_score(y_true_bin, y_scores) if has_pos_and_neg else float('nan')
pr_auc     = average_precision_score(y_true_bin, y_scores)

print("\n=== Stage 1 — Bright vs Background (LogReg) ===")
print(f"Sensitivity (bright+): {sensitivity:.4f}")
print(f"Specificity (background): {specificity:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")
print(f"PR-AUC (bright+): {pr_auc:.4f}")
print("\nClassification report:")
print(classification_report(y_test, y_pred, digits=4))

# 7) Confusion matrix plots (counts + normalized)
labels = ['background','bright']
cm = confusion_matrix(y_test, y_pred, labels=labels)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(values_format='d')
plt.title('Stage 1 Bright — Confusion Matrix (counts)')
plt.show()

# normalize='true' divides each row by its total, i.e. per true class
cm_norm = confusion_matrix(y_test, y_pred, labels=labels, normalize='true')
ConfusionMatrixDisplay(confusion_matrix=cm_norm, display_labels=labels).plot(values_format='.2f')
plt.title('Stage 1 Bright — Confusion Matrix (row-normalized)')
plt.show()

# 8) ROC curve
fpr, tpr, _ = roc_curve(y_true_bin, y_scores)
roc_auc_val = auc(fpr, tpr)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {roc_auc_val:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)   # chance diagonal
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Stage 1 Bright — ROC Curve (Logistic Regression)')
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.show()

# 9) Save for model comparison later
results_lr_b1 = {
    'stage': 'Stage 1 — Bright vs Background',
    'model': 'LogReg',
    'sensitivity': sensitivity,
    'specificity': specificity,
    'accuracy': accuracy,
    'roc_auc': roc_auc,
    'pr_auc': pr_auc,
    'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
}

#### Random Forest

In [None]:
# ===== Stage 1 — Bright vs Background (Random Forest) =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (confusion_matrix, classification_report,
                             roc_auc_score, average_precision_score,
                             ConfusionMatrixDisplay, roc_curve, auc)

SEED = 42
TEST_SIZE = 0.20
N_BG_BRIGHT = 50_000   # undersample background size (tweak for speed/balance)

# 1) Build dataset (undersample background + all bright), make target, shuffle
# NOTE(review): background is undersampled BEFORE the train/test split, so the
# test set carries the same artificial class ratio as the training set.
df_bg = df_main[df_main['lesion_group'] == 'background']
df_bright = df_main[df_main['lesion_group'] == 'bright']
df_bg_s = df_bg.sample(n=min(N_BG_BRIGHT, len(df_bg)), random_state=SEED)

df_s1_bright = pd.concat([df_bg_s, df_bright], ignore_index=True)
df_s1_bright['target'] = np.where(df_s1_bright['lesion_group'] == 'bright', 'bright', 'background')
df_s1_bright = df_s1_bright.sample(frac=1, random_state=SEED).reset_index(drop=True)

# 2) Leakage-safe split by image
#    (rows sharing an 'image' value always land on the same side of the split)
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, test_idx = next(gss.split(df_s1_bright, groups=df_s1_bright['image']))
train_df = df_s1_bright.iloc[train_idx].copy()
test_df  = df_s1_bright.iloc[test_idx].copy()

# 3) Features / labels
#    Everything that is not an identifier, a coordinate or a label is a feature.
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in df_s1_bright.columns if c not in drop_cols]
X_train, y_train = train_df[feat_cols], train_df['target']
X_test,  y_test  = test_df[feat_cols],  test_df['target']

# 4) Model: Impute -> RandomForest (trees don't need scaling)
rf_stage1_bright = Pipeline([
    ('impute', SimpleImputer(strategy='median')),
    ('clf', RandomForestClassifier(
        n_estimators=300,
        max_depth=None,                 # set to e.g. 16 if you want extra regularization
        min_samples_leaf=1,
        n_jobs=-1,
        class_weight='balanced_subsample',
        random_state=SEED
    ))
])

rf_stage1_bright.fit(X_train, y_train)

# 5) Scores & predictions (bright = positive, threshold 0.5)
#    Look the column up by class name instead of assuming its position.
bright_idx = list(rf_stage1_bright.named_steps['clf'].classes_).index('bright')
y_scores = rf_stage1_bright.predict_proba(X_test)[:, bright_idx]
y_pred = np.where(y_scores >= 0.5, 'bright', 'background')

# 6) Metrics (bright positive)
y_true_bin = (y_test == 'bright').astype(int)
y_pred_bin = (y_pred == 'bright').astype(int)
# labels=[0, 1] pins the matrix to 2x2. Without it the matrix shrinks to 1x1
# when only one class occurs in truth and prediction, and the 4-way unpack fails.
tn, fp, fn, tp = confusion_matrix(y_true_bin, y_pred_bin, labels=[0, 1]).ravel()

sensitivity = tp / (tp + fn) if (tp + fn) else 0.0   # recall for bright
specificity = tn / (tn + fp) if (tn + fp) else 0.0   # TNR for background
accuracy   = (tp + tn) / (tp + tn + fp + fn)
# ROC-AUC is undefined unless the test set holds both classes; report NaN then.
has_pos_and_neg = (tp + fn) > 0 and (tn + fp) > 0
roc_auc    = roc_auc_score(y_true_bin, y_scores) if has_pos_and_neg else float('nan')
pr_auc     = average_precision_score(y_true_bin, y_scores)

print("\n=== Stage 1 — Bright vs Background (Random Forest) ===")
print(f"Sensitivity (bright+): {sensitivity:.4f}")
print(f"Specificity (background): {specificity:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")
print(f"PR-AUC (bright+): {pr_auc:.4f}")
print("\nClassification report:")
print(classification_report(y_test, y_pred, digits=4))

# 7) Confusion matrix plots (counts + normalized)
labels = ['background','bright']
cm = confusion_matrix(y_test, y_pred, labels=labels)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(values_format='d')
plt.title('Stage 1 Bright — Confusion Matrix (counts)')
plt.show()

# normalize='true' divides each row by its total, i.e. per true class
cm_norm = confusion_matrix(y_test, y_pred, labels=labels, normalize='true')
ConfusionMatrixDisplay(confusion_matrix=cm_norm, display_labels=labels).plot(values_format='.2f')
plt.title('Stage 1 Bright — Confusion Matrix (row-normalized)')
plt.show()

# 8) ROC curve
fpr, tpr, _ = roc_curve(y_true_bin, y_scores)
roc_auc_val = auc(fpr, tpr)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {roc_auc_val:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)   # chance diagonal
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Stage 1 Bright — ROC Curve (Random Forest)')
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.show()

# 9) Save for model comparison later
results_rf_b1 = {
    'stage': 'Stage 1 — Bright vs Background',
    'model': 'RandomForest',
    'sensitivity': sensitivity,
    'specificity': specificity,
    'accuracy': accuracy,
    'roc_auc': roc_auc,
    'pr_auc': pr_auc,
    'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
}


#### XGBoost

In [None]:
# ===== Stage 1 — Bright vs Background (XGBoost — simple) =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    confusion_matrix, classification_report, roc_auc_score,
    average_precision_score, ConfusionMatrixDisplay, roc_curve, auc
)
from xgboost import XGBClassifier

SEED = 42
TEST_SIZE = 0.20
N_BG_BRIGHT = 50_000  # undersample background size (tweak for speed/balance)

# 1) Build dataset (undersample background + all bright), make target, shuffle
# NOTE(review): background is undersampled BEFORE the train/test split, so the
# test set carries the same artificial class ratio as the training set.
df_bg = df_main[df_main['lesion_group'] == 'background']
df_bright = df_main[df_main['lesion_group'] == 'bright']
df_bg_s = df_bg.sample(n=min(N_BG_BRIGHT, len(df_bg)), random_state=SEED)

df_s1_bright = pd.concat([df_bg_s, df_bright], ignore_index=True)
df_s1_bright['target'] = np.where(df_s1_bright['lesion_group']=='bright','bright','background')
df_s1_bright = df_s1_bright.sample(frac=1, random_state=SEED).reset_index(drop=True)

# 2) Leakage-safe split by image
#    (rows sharing an 'image' value always land on the same side of the split)
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, test_idx = next(gss.split(df_s1_bright, groups=df_s1_bright['image']))
train_df = df_s1_bright.iloc[train_idx].copy()
test_df  = df_s1_bright.iloc[test_idx].copy()

# 3) Features / labels
#    Everything that is not an identifier, a coordinate or a label is a feature.
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in df_s1_bright.columns if c not in drop_cols]

X_train_full, y_train_full_lbl = train_df[feat_cols], train_df['target']
X_test,         y_test_lbl     = test_df[feat_cols],  test_df['target']

# Binarize labels for XGBoost (bright=1, background=0)
# NOTE: from here on y_train / y_test are integer NumPy arrays, unlike the
# string-labelled Series of the same name in the LR and RF cells.
y_train = (y_train_full_lbl == 'bright').astype(int).values
y_test  = (y_test_lbl == 'bright').astype(int).values

# 4) Impute NaNs (trees don't need scaling)
#    The imputer is fitted on the training rows only and reused for the test rows.
imp = SimpleImputer(strategy='median')
X_train_np = imp.fit_transform(X_train_full)
X_test_np  = imp.transform(X_test)

# 5) Class imbalance handling: scale_pos_weight = (neg/pos) on TRAIN
pos = int(y_train.sum())
neg = int(len(y_train) - pos)
scale_pos_weight = (neg / pos) if pos > 0 else 1.0
print(f"scale_pos_weight (train): {scale_pos_weight:.3f}  (neg={neg}, pos={pos})")

# 6) Model (simple, version-friendly)
xgb_stage1_bright = XGBClassifier(
    objective='binary:logistic',
    max_depth=6,
    n_estimators=400,          # keep modest; raise if you want
    learning_rate=0.10,        # a bit higher since no early stopping
    subsample=0.8,
    colsample_bytree=0.8,
    reg_lambda=1.0,
    scale_pos_weight=scale_pos_weight,
    n_jobs=-1,
    random_state=SEED
)

# 7) Fit
xgb_stage1_bright.fit(X_train_np, y_train)

# 8) Scores & predictions (bright = positive)
y_scores = xgb_stage1_bright.predict_proba(X_test_np)[:, 1]
y_pred   = (y_scores >= 0.5).astype(int)  # 1=bright, 0=background

# 9) Metrics
# labels=[0, 1] pins the matrix to 2x2. Without it the matrix shrinks to 1x1
# when only one class occurs in truth and prediction, and the 4-way unpack fails.
tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=[0, 1]).ravel()
sensitivity = tp / (tp + fn) if (tp + fn) else 0.0   # recall for bright
specificity = tn / (tn + fp) if (tn + fp) else 0.0   # TNR for background
accuracy   = (tp + tn) / (tp + tn + fp + fn)
# ROC-AUC is undefined unless the test set holds both classes; report NaN then.
has_pos_and_neg = (tp + fn) > 0 and (tn + fp) > 0
roc_auc    = roc_auc_score(y_test, y_scores) if has_pos_and_neg else float('nan')
pr_auc     = average_precision_score(y_test, y_scores)

print("\n=== Stage 1 — Bright vs Background (XGBoost — simple) ===")
print(f"Sensitivity (bright+): {sensitivity:.4f}")
print(f"Specificity (background): {specificity:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")
print(f"PR-AUC (bright+): {pr_auc:.4f}")
print("\nClassification report:")
# Map the 0/1 arrays back to class names so the report is readable
print(classification_report(
    pd.Series(np.where(y_test==1,'bright','background')),
    pd.Series(np.where(y_pred==1,'bright','background')),
    digits=4
))

# 10) Confusion matrix (counts + normalized)
labels = ['background','bright']
cm = confusion_matrix(np.where(y_test==1,'bright','background'),
                      np.where(y_pred==1,'bright','background'),
                      labels=labels)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(values_format='d')
plt.title('Stage 1 Bright — Confusion Matrix (counts, XGBoost)')
plt.show()

# normalize='true' divides each row by its total, i.e. per true class
cm_norm = confusion_matrix(np.where(y_test==1,'bright','background'),
                           np.where(y_pred==1,'bright','background'),
                           labels=labels, normalize='true')
ConfusionMatrixDisplay(confusion_matrix=cm_norm, display_labels=labels).plot(values_format='.2f')
plt.title('Stage 1 Bright — Confusion Matrix (row-normalized, XGBoost)')
plt.show()

# 11) ROC curve
fpr, tpr, _ = roc_curve(y_test, y_scores)
roc_auc_val = auc(fpr, tpr)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {roc_auc_val:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)   # chance diagonal
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Stage 1 Bright — ROC Curve (XGBoost)')
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.show()

# 12) Save for model comparison later
results_xgb_b1 = {
    'stage': 'Stage 1 — Bright vs Background',
    'model': 'XGBoost',
    'sensitivity': sensitivity,
    'specificity': specificity,
    'accuracy': accuracy,
    'roc_auc': roc_auc,
    'pr_auc': pr_auc,
    'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
}


#### Model Comparison

In [None]:
import pandas as pd

# Collect your saved result dicts here
rows = [results_lr_b1, results_rf_b1, results_xgb_b1]  # LR / RF / XGB (Stage 1 Bright)

# Build comparison table: one row per model, best ROC AUC first.
# Balanced accuracy is the mean of sensitivity and specificity.
tbl = pd.DataFrame([
    {
        "Model": r["model"],
        "Sensitivity (bright)": r["sensitivity"],
        "Specificity (background)": r["specificity"],
        "Balanced Accuracy": 0.5 * (r["sensitivity"] + r["specificity"]),
        "Accuracy": r["accuracy"],
        "ROC AUC": r["roc_auc"],
        "PR AUC (bright)": r["pr_auc"],
        "TP": r["tp"], "FP": r["fp"], "TN": r["tn"], "FN": r["fn"]
    }
    for r in rows
]).round(4).sort_values("ROC AUC", ascending=False)

print(tbl)

# Optional: pretty display (Jupyter)
# Best-effort only: display() does not exist outside IPython/Jupyter.
# `except Exception` (not a bare `except`) keeps KeyboardInterrupt and
# SystemExit from being swallowed here.
try:
    display(tbl.style.highlight_max(subset=["Sensitivity (bright)", "Specificity (background)",
                                            "Balanced Accuracy", "Accuracy", "ROC AUC", "PR AUC (bright)"],
                                    color="#d5f5e3"))
except Exception:
    pass

# Export for your report
tbl.to_csv("stage1_bright_model_comparison.csv", index=False)
print("\nLaTeX (paste into your paper):\n")
print(tbl.to_latex(index=False, float_format="%.4f", caption="Stage 1 (Bright vs Background) — Model comparison",
                   label="tab:s1_bright_models"))


### 7.1.2 Bright: EX vs CWS

#### Logistic Regression

In [None]:
# ===== Stage 2 — Bright (EX vs CWS) — Logistic Regression =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (confusion_matrix, classification_report,
                             roc_auc_score, average_precision_score,
                             ConfusionMatrixDisplay, roc_curve, auc)

SEED = 42
TEST_SIZE = 0.20
EX_TO_CWS_RATIO_IN_TRAIN = 2.0   # set None to skip train-only rebalance

# 1) Filter: bright-only, EX vs CWS
b2 = df_main[df_main['lesion_group']=='bright'].copy()
b2 = b2[b2['lesion_type'].isin(['EX','CWS'])].copy()
b2['target'] = b2['lesion_type']

# 2) Group-safe split by image
#    (rows sharing an 'image' value always land on the same side of the split)
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, test_idx = next(gss.split(b2, groups=b2['image']))
train_df, test_df = b2.iloc[train_idx].copy(), b2.iloc[test_idx].copy()

# optional: ensure both classes appear in test; reshuffle once if not
# (only a single retry - the metrics step below copes with a test set that
# still holds one class)
if test_df['target'].nunique() < 2:
    train_idx, test_idx = next(GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED+1)
                               .split(b2, groups=b2['image']))
    train_df, test_df = b2.iloc[train_idx].copy(), b2.iloc[test_idx].copy()

# 3) Train-only rebalance (undersample EX)
#    Keep at most EX_TO_CWS_RATIO_IN_TRAIN EX rows per CWS row; the test set
#    is left untouched.
if EX_TO_CWS_RATIO_IN_TRAIN is not None:
    ex_tr  = train_df[train_df['target']=='EX']
    cws_tr = train_df[train_df['target']=='CWS']
    target_ex = int(min(len(ex_tr), EX_TO_CWS_RATIO_IN_TRAIN * len(cws_tr))) if len(cws_tr)>0 else len(ex_tr)
    ex_down = ex_tr.sample(n=target_ex, random_state=SEED) if len(ex_tr) > target_ex else ex_tr
    train_df = pd.concat([ex_down, cws_tr], ignore_index=True).sample(frac=1, random_state=SEED).reset_index(drop=True)

# 4) Features / labels
#    Everything that is not an identifier, a coordinate or a label is a feature.
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in b2.columns if c not in drop_cols]
X_train, y_train = train_df[feat_cols], train_df['target']
X_test,  y_test  = test_df[feat_cols],  test_df['target']

# 5) Model: Impute + Scale + Logistic Regression
lr_stage2_bright = Pipeline([
    ('impute', SimpleImputer(strategy='median')),
    ('scale', StandardScaler()),
    ('clf', LogisticRegression(
        solver='lbfgs', max_iter=1000, class_weight='balanced',
        C=1.0, random_state=SEED
    ))
])

lr_stage2_bright.fit(X_train, y_train)

# 6) Scores & predictions (EX = positive)
#    Look the column up by class name instead of assuming its position.
ex_idx = list(lr_stage2_bright.named_steps['clf'].classes_).index('EX')
y_scores = lr_stage2_bright.predict_proba(X_test)[:, ex_idx]
y_pred   = np.where(y_scores >= 0.5, 'EX', 'CWS')

# 7) Metrics
y_true_bin = (y_test == 'EX').astype(int)
y_pred_bin = (y_pred == 'EX').astype(int)
# labels=[0, 1] pins the matrix to 2x2. Without it the matrix shrinks to 1x1
# when only one class occurs in truth and prediction, and the 4-way unpack fails.
tn, fp, fn, tp = confusion_matrix(y_true_bin, y_pred_bin, labels=[0, 1]).ravel()
sensitivity = tp/(tp+fn) if (tp+fn) else 0.0    # EX recall
specificity = tn/(tn+fp) if (tn+fp) else 0.0    # CWS TNR
accuracy   = (tp+tn)/(tp+tn+fp+fn)
# ROC-AUC is undefined unless the test set holds both classes; report NaN then.
has_pos_and_neg = (tp + fn) > 0 and (tn + fp) > 0
roc_auc    = roc_auc_score(y_true_bin, y_scores) if has_pos_and_neg else float('nan')
pr_auc     = average_precision_score(y_true_bin, y_scores)

print("\n=== Stage 2 — Bright (EX vs CWS) — Logistic Regression ===")
print(f"Sensitivity (EX+): {sensitivity:.4f}")
print(f"Specificity (CWS): {specificity:.4f}")
print(f"Accuracy:          {accuracy:.4f}")
print(f"ROC-AUC:           {roc_auc:.4f}")
print(f"PR-AUC (EX+):      {pr_auc:.4f}")
print("\nClassification report:")
print(classification_report(y_test, y_pred, digits=4))

# 8) Confusion matrices
labels = ['CWS','EX']  # order shown
cm = confusion_matrix(y_test, y_pred, labels=labels)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(values_format='d')
plt.title('Stage 2 Bright — Confusion Matrix (counts, LR)')
plt.show()

# normalize='true' divides each row by its total, i.e. per true class
cm_norm = confusion_matrix(y_test, y_pred, labels=labels, normalize='true')
ConfusionMatrixDisplay(confusion_matrix=cm_norm, display_labels=labels).plot(values_format='.2f')
plt.title('Stage 2 Bright — Confusion Matrix (row-normalized, LR)')
plt.show()

# 9) ROC
fpr, tpr, _ = roc_curve(y_true_bin, y_scores)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {roc_auc:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)   # chance diagonal
plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate')
plt.title('Stage 2 Bright — ROC (LR)')
plt.legend(loc='lower right'); plt.grid(True, alpha=0.3); plt.show()

# 10) Save results dict
results_lr_b2 = {
    'stage':'Stage 2 — Bright (EX vs CWS)', 'model':'LogReg',
    'sensitivity':sensitivity,'specificity':specificity,'accuracy':accuracy,
    'roc_auc':roc_auc,'pr_auc':pr_auc,'tp':tp,'tn':tn,'fp':fp,'fn':fn
}


#### Random Forest

In [None]:
# ===== Stage 2 — Bright (EX vs CWS) — Random Forest =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (confusion_matrix, classification_report,
                             roc_auc_score, average_precision_score,
                             ConfusionMatrixDisplay, roc_curve, auc)

SEED = 42
TEST_SIZE = 0.20
EX_TO_CWS_RATIO_IN_TRAIN = 2.0   # set None to skip

# 1) Data (bright-only, EX vs CWS)
b2 = df_main[df_main['lesion_group']=='bright'].copy()
b2 = b2[b2['lesion_type'].isin(['EX','CWS'])].copy()
b2['target'] = b2['lesion_type']

# 2) Group-safe split
#    (rows sharing an 'image' value always land on the same side of the split)
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, test_idx = next(gss.split(b2, groups=b2['image']))
train_df, test_df = b2.iloc[train_idx].copy(), b2.iloc[test_idx].copy()
# Retry once with another seed if the test set holds a single class; the
# metrics step below copes with a test set that still holds one class.
if test_df['target'].nunique() < 2:
    train_idx, test_idx = next(GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED+1)
                               .split(b2, groups=b2['image']))
    train_df, test_df = b2.iloc[train_idx].copy(), b2.iloc[test_idx].copy()

# 3) Train-only rebalance (EX downsample)
#    Keep at most EX_TO_CWS_RATIO_IN_TRAIN EX rows per CWS row; the test set
#    is left untouched.
if EX_TO_CWS_RATIO_IN_TRAIN is not None:
    ex_tr  = train_df[train_df['target']=='EX']
    cws_tr = train_df[train_df['target']=='CWS']
    target_ex = int(min(len(ex_tr), EX_TO_CWS_RATIO_IN_TRAIN * len(cws_tr))) if len(cws_tr)>0 else len(ex_tr)
    ex_down = ex_tr.sample(n=target_ex, random_state=SEED) if len(ex_tr) > target_ex else ex_tr
    train_df = pd.concat([ex_down, cws_tr], ignore_index=True).sample(frac=1, random_state=SEED).reset_index(drop=True)

# 4) Features / labels
#    Everything that is not an identifier, a coordinate or a label is a feature.
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in b2.columns if c not in drop_cols]
X_train, y_train = train_df[feat_cols], train_df['target']
X_test,  y_test  = test_df[feat_cols],  test_df['target']

# 5) Model: Impute -> RandomForest
rf_stage2_bright = Pipeline([
    ('impute', SimpleImputer(strategy='median')),
    ('clf', RandomForestClassifier(
        n_estimators=300, max_depth=None, min_samples_leaf=1,
        n_jobs=-1, class_weight='balanced_subsample', random_state=SEED
    ))
])
rf_stage2_bright.fit(X_train, y_train)

# 6) Scores & predictions (EX positive)
#    Look the column up by class name instead of assuming its position.
ex_idx = list(rf_stage2_bright.named_steps['clf'].classes_).index('EX')
y_scores = rf_stage2_bright.predict_proba(X_test)[:, ex_idx]
y_pred   = np.where(y_scores >= 0.5, 'EX', 'CWS')

# 7) Metrics
y_true_bin = (y_test == 'EX').astype(int)
y_pred_bin = (y_pred == 'EX').astype(int)
# labels=[0, 1] pins the matrix to 2x2. Without it the matrix shrinks to 1x1
# when only one class occurs in truth and prediction, and the 4-way unpack fails.
tn, fp, fn, tp = confusion_matrix(y_true_bin, y_pred_bin, labels=[0, 1]).ravel()
sensitivity = tp/(tp+fn) if (tp+fn) else 0.0    # EX recall
specificity = tn/(tn+fp) if (tn+fp) else 0.0    # CWS TNR
accuracy   = (tp+tn)/(tp+tn+fp+fn)
# ROC-AUC is undefined unless the test set holds both classes; report NaN then.
has_pos_and_neg = (tp + fn) > 0 and (tn + fp) > 0
roc_auc    = roc_auc_score(y_true_bin, y_scores) if has_pos_and_neg else float('nan')
pr_auc     = average_precision_score(y_true_bin, y_scores)

print("\n=== Stage 2 — Bright (EX vs CWS) — Random Forest ===")
print(f"Sensitivity (EX+): {sensitivity:.4f}")
print(f"Specificity (CWS): {specificity:.4f}")
print(f"Accuracy:          {accuracy:.4f}")
print(f"ROC-AUC:           {roc_auc:.4f}")
print(f"PR-AUC (EX+):      {pr_auc:.4f}")
print("\nClassification report:")
print(classification_report(y_test, y_pred, digits=4))

# 8) Confusion matrices
labels = ['CWS','EX']
cm = confusion_matrix(y_test, y_pred, labels=labels)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(values_format='d')
plt.title('Stage 2 Bright — Confusion Matrix (counts, RF)'); plt.show()

# normalize='true' divides each row by its total, i.e. per true class
cm_norm = confusion_matrix(y_test, y_pred, labels=labels, normalize='true')
ConfusionMatrixDisplay(confusion_matrix=cm_norm, display_labels=labels).plot(values_format='.2f')
plt.title('Stage 2 Bright — Confusion Matrix (row-normalized, RF)'); plt.show()

# 9) ROC
# (roc_curve is already imported at the top of this cell)
fpr, tpr, _ = roc_curve(y_true_bin, y_scores)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {roc_auc:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)   # chance diagonal
plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate')
plt.title('Stage 2 Bright — ROC (RF)')
plt.legend(loc='lower right'); plt.grid(True, alpha=0.3); plt.show()

# 10) Save results
results_rf_b2 = {
    'stage':'Stage 2 — Bright (EX vs CWS)', 'model':'RandomForest',
    'sensitivity':sensitivity,'specificity':specificity,'accuracy':accuracy,
    'roc_auc':roc_auc,'pr_auc':pr_auc,'tp':tp,'tn':tn,'fp':fp,'fn':fn
}


#### XGBoost

In [None]:
# ===== Stage 2 — Bright (EX vs CWS) — XGBoost (simple) =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit
from sklearn.impute import SimpleImputer
from sklearn.metrics import (confusion_matrix, classification_report,
                             roc_auc_score, average_precision_score,
                             ConfusionMatrixDisplay, roc_curve, auc)

from xgboost import XGBClassifier

SEED = 42
TEST_SIZE = 0.20
EX_TO_CWS_RATIO_IN_TRAIN = 2.0   # set None to skip

# 1) Data (bright-only, EX vs CWS)
b2 = df_main[df_main['lesion_group']=='bright'].copy()
b2 = b2[b2['lesion_type'].isin(['EX','CWS'])].copy()
b2['target'] = b2['lesion_type']

# 2) Group-safe split
#    (rows sharing an 'image' value always land on the same side of the split)
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, test_idx = next(gss.split(b2, groups=b2['image']))
train_df, test_df = b2.iloc[train_idx].copy(), b2.iloc[test_idx].copy()
# Retry once with another seed if the test set holds a single class; the
# metrics step below copes with a test set that still holds one class.
if test_df['target'].nunique() < 2:
    train_idx, test_idx = next(GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED+1)
                               .split(b2, groups=b2['image']))
    train_df, test_df = b2.iloc[train_idx].copy(), b2.iloc[test_idx].copy()

# 3) Train-only rebalance (EX downsample)
#    Keep at most EX_TO_CWS_RATIO_IN_TRAIN EX rows per CWS row; the test set
#    is left untouched.
if EX_TO_CWS_RATIO_IN_TRAIN is not None:
    ex_tr  = train_df[train_df['target']=='EX']
    cws_tr = train_df[train_df['target']=='CWS']
    target_ex = int(min(len(ex_tr), EX_TO_CWS_RATIO_IN_TRAIN * len(cws_tr))) if len(cws_tr)>0 else len(ex_tr)
    ex_down = ex_tr.sample(n=target_ex, random_state=SEED) if len(ex_tr) > target_ex else ex_tr
    train_df = pd.concat([ex_down, cws_tr], ignore_index=True).sample(frac=1, random_state=SEED).reset_index(drop=True)

# 4) Features / labels
#    Everything that is not an identifier, a coordinate or a label is a feature.
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in b2.columns if c not in drop_cols]
X_train, y_train_lbl = train_df[feat_cols], train_df['target']
X_test,  y_test_lbl  = test_df[feat_cols],  test_df['target']

# Binarize (EX=1, CWS=0) for XGB
# NOTE: from here on y_train / y_test are integer NumPy arrays, unlike the
# string-labelled Series of the same name in the LR and RF cells.
y_train = (y_train_lbl == 'EX').astype(int).values
y_test  = (y_test_lbl == 'EX').astype(int).values

# 5) Impute (trees don't need scaling)
#    The imputer is fitted on the training rows only and reused for the test rows.
imp = SimpleImputer(strategy='median')
X_train_np = imp.fit_transform(X_train)
X_test_np  = imp.transform(X_test)

# 6) Imbalance weight from TRAIN
pos = int(y_train.sum()); neg = int(len(y_train) - pos)
scale_pos_weight = (neg / pos) if pos > 0 else 1.0
print(f"scale_pos_weight (train): {scale_pos_weight:.3f}  (neg={neg}, pos={pos})")

# 7) Model (simple; no early stopping to avoid version differences)
xgb_stage2_bright = XGBClassifier(
    objective='binary:logistic',
    max_depth=6,
    n_estimators=400,
    learning_rate=0.10,
    subsample=0.8,
    colsample_bytree=0.8,
    reg_lambda=1.0,
    scale_pos_weight=scale_pos_weight,
    n_jobs=-1,
    random_state=SEED
)
xgb_stage2_bright.fit(X_train_np, y_train)

# 8) Scores & predictions (EX positive)
y_scores = xgb_stage2_bright.predict_proba(X_test_np)[:, 1]
y_pred_bin = (y_scores >= 0.5).astype(int)
y_pred_lbl = np.where(y_pred_bin==1, 'EX', 'CWS')

# 9) Metrics
# labels=[0, 1] pins the matrix to 2x2. Without it the matrix shrinks to 1x1
# when only one class occurs in truth and prediction, and the 4-way unpack fails.
tn, fp, fn, tp = confusion_matrix(y_test, y_pred_bin, labels=[0, 1]).ravel()
sensitivity = tp/(tp+fn) if (tp+fn) else 0.0    # EX recall
specificity = tn/(tn+fp) if (tn+fp) else 0.0    # CWS TNR
accuracy   = (tp+tn)/(tp+tn+fp+fn)
# ROC-AUC is undefined unless the test set holds both classes; report NaN then.
has_pos_and_neg = (tp + fn) > 0 and (tn + fp) > 0
roc_auc    = roc_auc_score(y_test, y_scores) if has_pos_and_neg else float('nan')
pr_auc     = average_precision_score(y_test, y_scores)

print("\n=== Stage 2 — Bright (EX vs CWS) — XGBoost (simple) ===")
print(f"Sensitivity (EX+): {sensitivity:.4f}")
print(f"Specificity (CWS): {specificity:.4f}")
print(f"Accuracy:          {accuracy:.4f}")
print(f"ROC-AUC:           {roc_auc:.4f}")
print(f"PR-AUC (EX+):      {pr_auc:.4f}")
print("\nClassification report:")
print(classification_report(y_test_lbl, y_pred_lbl, digits=4))

# 10) Confusion matrices
labels = ['CWS','EX']
cm = confusion_matrix(y_test_lbl, y_pred_lbl, labels=labels)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(values_format='d')
plt.title('Stage 2 Bright — Confusion Matrix (counts, XGB)'); plt.show()

# normalize='true' divides each row by its total, i.e. per true class
cm_norm = confusion_matrix(y_test_lbl, y_pred_lbl, labels=labels, normalize='true')
ConfusionMatrixDisplay(confusion_matrix=cm_norm, display_labels=labels).plot(values_format='.2f')
plt.title('Stage 2 Bright — Confusion Matrix (row-normalized, XGB)'); plt.show()

# 11) ROC
fpr, tpr, _ = roc_curve(y_test, y_scores)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {roc_auc:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)   # chance diagonal
plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate')
plt.title('Stage 2 Bright — ROC (XGBoost)')
plt.legend(loc='lower right'); plt.grid(True, alpha=0.3); plt.show()

# 12) Save results
results_xgb_b2 = {
    'stage':'Stage 2 — Bright (EX vs CWS)', 'model':'XGBoost',
    'sensitivity':sensitivity,'specificity':specificity,'accuracy':accuracy,
    'roc_auc':roc_auc,'pr_auc':pr_auc,'tp':tp,'tn':tn,'fp':fp,'fn':fn
}


#### Model Comparison

In [None]:
import pandas as pd

def make_model_comparison(rows, stage_name, sort_by="ROC AUC"):
    """
    rows: list of results_* dicts like {'model','sensitivity','specificity','accuracy','roc_auc','pr_auc','tp','fp','tn','fn'}
    stage_name: string used for filenames/captions
    sort_by: which metric to sort by (e.g., 'ROC AUC', 'Accuracy', 'Balanced Accuracy')
    """
    tbl = pd.DataFrame([
        {
            "Model": r["model"],
            "Sensitivity": r["sensitivity"],
            "Specificity": r["specificity"],
            "Balanced Accuracy": 0.5 * (r["sensitivity"] + r["specificity"]),
            "Accuracy": r["accuracy"],
            "ROC AUC": r["roc_auc"],
            "PR AUC": r["pr_auc"],
            "TP": r["tp"], "FP": r["fp"], "TN": r["tn"], "FN": r["fn"]
        }
        for r in rows
    ]).round(4).sort_values(sort_by, ascending=False)

    print(tbl)

    # Optional pretty display in notebooks
    try:
        display(tbl.style.highlight_max(
            subset=["Sensitivity","Specificity","Balanced Accuracy","Accuracy","ROC AUC","PR AUC"],
            color="#d5f5e3"
        ))
    except Exception:
        pass

    # Save
    csv_name = stage_name.lower().replace(" ", "_").replace("/", "_") + "_model_comparison.csv"
    tbl.to_csv(csv_name, index=False)

    # LaTeX
    print("\nLaTeX (paste into your paper):\n")
    print(tbl.to_latex(index=False, float_format="%.4f",
                       caption=f"{stage_name} — Model comparison",
                       label="tab:" + stage_name.lower().replace(" ", "_").replace("/", "_")))
    return tbl

# ==== Usage examples ====
# Stage 1 Bright vs Background: LR / RF / XGB results from section 7.1.1
stage1_bright_results = [results_lr_b1, results_rf_b1, results_xgb_b1]
_ = make_model_comparison(stage1_bright_results, stage_name="Stage 1 (Bright vs Background)")

# Stage 2 Bright (EX vs CWS): LR / RF / XGB results from section 7.1.2
stage2_bright_results = [results_lr_b2, results_rf_b2, results_xgb_b2]
_ = make_model_comparison(stage2_bright_results, stage_name="Stage 2 Bright (EX vs CWS)")


## 7.2 Red

### 7.2.1 Red vs Background

#### Logistic Regression

In [None]:
# ===== Stage 1 — Red vs Background (Logistic Regression) =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (confusion_matrix, classification_report,
                             roc_auc_score, average_precision_score,
                             ConfusionMatrixDisplay, roc_curve, auc)

# Stage 1 (red branch), logistic regression: binary classifier separating
# 'red' rows of df_main from 'background' rows, evaluated on held-out images.
# The metrics end up in `results_lr_r1` for the model-comparison cell.
SEED = 42
TEST_SIZE = 0.20
N_BG_RED = 50_000  # undersample background (tweak for speed/balance)

# 1) Build dataset (undersample background + all red)
# NOTE(review): background is undersampled before the train/test split, so the
# test set carries the reduced background prevalence too; read the
# prevalence-sensitive metrics (accuracy, PR-AUC) with that in mind.
df_bg = df_main[df_main['lesion_group'] == 'background']
df_red = df_main[df_main['lesion_group'] == 'red']
df_bg_s = df_bg.sample(n=min(N_BG_RED, len(df_bg)), random_state=SEED)

df_s1_red = pd.concat([df_bg_s, df_red], ignore_index=True)
df_s1_red['target'] = np.where(df_s1_red['lesion_group']=='red', 'red', 'background')
# Shuffle so the rows are no longer ordered background-first
df_s1_red = df_s1_red.sample(frac=1, random_state=SEED).reset_index(drop=True)

# 2) Leakage-safe split by image: every row of an image lands on one side only
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, test_idx = next(gss.split(df_s1_red, groups=df_s1_red['image']))
train_df, test_df = df_s1_red.iloc[train_idx].copy(), df_s1_red.iloc[test_idx].copy()

# 3) Features / labels: every column that is not an identifier, coordinate or label
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in df_s1_red.columns if c not in drop_cols]
X_train, y_train = train_df[feat_cols], train_df['target']
X_test,  y_test  = test_df[feat_cols],  test_df['target']

# 4) Model: Impute + Scale + Logistic Regression (balanced)
# Imputer and scaler live inside the pipeline, so they are fitted on the
# training rows only.
lr_stage1_red = Pipeline([
    ('impute', SimpleImputer(strategy='median')),
    ('scale', StandardScaler()),
    ('clf', LogisticRegression(
        solver='lbfgs', max_iter=1000, class_weight='balanced',
        C=1.0, random_state=SEED
    ))
])
lr_stage1_red.fit(X_train, y_train)

# 5) Scores & predictions (red = positive)
# Look up the probability column of 'red' rather than assuming its position.
red_idx = list(lr_stage1_red.named_steps['clf'].classes_).index('red')
y_scores = lr_stage1_red.predict_proba(X_test)[:, red_idx]
y_pred   = np.where(y_scores >= 0.5, 'red', 'background')

# 6) Metrics
y_true_bin = (y_test == 'red').astype(int)
y_pred_bin = (y_pred == 'red').astype(int)
# labels=[0, 1] pins the matrix to 2x2 so the four-way unpack cannot fail
# when only one class shows up in the truth and the predictions.
tn, fp, fn, tp = confusion_matrix(y_true_bin, y_pred_bin, labels=[0, 1]).ravel()
sensitivity = tp/(tp+fn) if (tp+fn) else 0.0   # red recall
specificity = tn/(tn+fp) if (tn+fp) else 0.0   # background TNR
accuracy   = (tp+tn)/(tp+tn+fp+fn)
roc_auc    = roc_auc_score(y_true_bin, y_scores)
pr_auc     = average_precision_score(y_true_bin, y_scores)

print("\n=== Stage 1 — Red vs Background (LogReg) ===")
print(f"Sensitivity (red+): {sensitivity:.4f}")
print(f"Specificity (background): {specificity:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")
print(f"PR-AUC (red+): {pr_auc:.4f}")
print("\nClassification report:")
print(classification_report(y_test, y_pred, digits=4))

# 7) Confusion matrices (raw counts, then normalized per true class)
labels = ['background','red']
cm = confusion_matrix(y_test, y_pred, labels=labels)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(values_format='d')
plt.title('Stage 1 Red — Confusion Matrix (counts, LR)')
plt.show()

cm_norm = confusion_matrix(y_test, y_pred, labels=labels, normalize='true')
ConfusionMatrixDisplay(confusion_matrix=cm_norm, display_labels=labels).plot(values_format='.2f')
plt.title('Stage 1 Red — Confusion Matrix (row-normalized, LR)')
plt.show()

# 8) ROC curve with the chance diagonal for reference
fpr, tpr, _ = roc_curve(y_true_bin, y_scores)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {roc_auc:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Stage 1 Red — ROC (LR)')
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.show()

# 9) Save results
results_lr_r1 = {
    'stage': 'Stage 1 — Red vs Background', 'model': 'LogReg',
    'sensitivity': sensitivity, 'specificity': specificity, 'accuracy': accuracy,
    'roc_auc': roc_auc, 'pr_auc': pr_auc, 'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
}


#### Random Forest

In [None]:
# ===== Stage 1 — Red vs Background (Random Forest) =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (confusion_matrix, classification_report,
                             roc_auc_score, average_precision_score,
                             ConfusionMatrixDisplay, roc_curve, auc)

# Stage 1 (red branch), random forest: binary classifier separating 'red'
# rows of df_main from 'background' rows, evaluated on held-out images.
# The dataset and split use the same seeds as the other Stage 1 red cells;
# the metrics end up in `results_rf_r1`.
SEED = 42
TEST_SIZE = 0.20
N_BG_RED = 50_000  # cap on the number of background rows kept

# 1) Build dataset (undersampled background + all red)
# NOTE(review): background is undersampled before the train/test split, so the
# test set carries the reduced background prevalence too.
df_bg = df_main[df_main['lesion_group'] == 'background']
df_red = df_main[df_main['lesion_group'] == 'red']
df_bg_s = df_bg.sample(n=min(N_BG_RED, len(df_bg)), random_state=SEED)

df_s1_red = pd.concat([df_bg_s, df_red], ignore_index=True)
df_s1_red['target'] = np.where(df_s1_red['lesion_group']=='red', 'red', 'background')
# Shuffle so the rows are no longer ordered background-first
df_s1_red = df_s1_red.sample(frac=1, random_state=SEED).reset_index(drop=True)

# 2) Group-safe split: every row of an image lands on one side only
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, test_idx = next(gss.split(df_s1_red, groups=df_s1_red['image']))
train_df, test_df = df_s1_red.iloc[train_idx].copy(), df_s1_red.iloc[test_idx].copy()

# 3) Features / labels: every column that is not an identifier, coordinate or label
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in df_s1_red.columns if c not in drop_cols]
X_train, y_train = train_df[feat_cols], train_df['target']
X_test,  y_test  = test_df[feat_cols],  test_df['target']

# 4) Model: Impute -> RF (balanced subsample)
# The imputer lives inside the pipeline, so it is fitted on training rows only.
rf_stage1_red = Pipeline([
    ('impute', SimpleImputer(strategy='median')),
    ('clf', RandomForestClassifier(
        n_estimators=300, max_depth=None, min_samples_leaf=1,
        n_jobs=-1, class_weight='balanced_subsample', random_state=SEED
    ))
])
rf_stage1_red.fit(X_train, y_train)

# 5) Scores & predictions (red positive)
# Look up the probability column of 'red' rather than assuming its position.
red_idx = list(rf_stage1_red.named_steps['clf'].classes_).index('red')
y_scores = rf_stage1_red.predict_proba(X_test)[:, red_idx]
y_pred   = np.where(y_scores >= 0.5, 'red', 'background')

# 6) Metrics
y_true_bin = (y_test == 'red').astype(int)
y_pred_bin = (y_pred == 'red').astype(int)
# labels=[0, 1] pins the matrix to 2x2 so the four-way unpack cannot fail
# when only one class shows up in the truth and the predictions.
tn, fp, fn, tp = confusion_matrix(y_true_bin, y_pred_bin, labels=[0, 1]).ravel()
sensitivity = tp/(tp+fn) if (tp+fn) else 0.0   # red recall
specificity = tn/(tn+fp) if (tn+fp) else 0.0   # background TNR
accuracy   = (tp+tn)/(tp+tn+fp+fn)
roc_auc    = roc_auc_score(y_true_bin, y_scores)
pr_auc     = average_precision_score(y_true_bin, y_scores)

print("\n=== Stage 1 — Red vs Background (Random Forest) ===")
print(f"Sensitivity (red+): {sensitivity:.4f}")
print(f"Specificity (background): {specificity:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")
print(f"PR-AUC (red+): {pr_auc:.4f}")
print("\nClassification report:")
print(classification_report(y_test, y_pred, digits=4))

# 7) Confusion matrices (raw counts, then normalized per true class)
labels = ['background','red']
cm = confusion_matrix(y_test, y_pred, labels=labels)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(values_format='d')
plt.title('Stage 1 Red — Confusion Matrix (counts, RF)')
plt.show()

cm_norm = confusion_matrix(y_test, y_pred, labels=labels, normalize='true')
ConfusionMatrixDisplay(confusion_matrix=cm_norm, display_labels=labels).plot(values_format='.2f')
plt.title('Stage 1 Red — Confusion Matrix (row-normalized, RF)')
plt.show()

# 8) ROC curve with the chance diagonal for reference
fpr, tpr, _ = roc_curve(y_true_bin, y_scores)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {roc_auc:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Stage 1 Red — ROC (RF)')
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.show()

# 9) Save results
results_rf_r1 = {
    'stage': 'Stage 1 — Red vs Background', 'model': 'RandomForest',
    'sensitivity': sensitivity, 'specificity': specificity, 'accuracy': accuracy,
    'roc_auc': roc_auc, 'pr_auc': pr_auc, 'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
}


#### XGBoost

In [None]:
# ===== Stage 1 — Red vs Background (XGBoost — simple) =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit
from sklearn.impute import SimpleImputer
from sklearn.metrics import (confusion_matrix, classification_report,
                             roc_auc_score, average_precision_score,
                             ConfusionMatrixDisplay, roc_curve, auc)
from xgboost import XGBClassifier

# Stage 1 (red branch), XGBoost: binary classifier separating 'red' rows of
# df_main (label 1) from 'background' rows (label 0), evaluated on held-out
# images. The dataset and split use the same seeds as the other Stage 1 red
# cells; the metrics end up in `results_xgb_r1`.
SEED = 42
TEST_SIZE = 0.20
N_BG_RED = 50_000  # undersample background

# 1) Build dataset (undersampled background + all red)
# NOTE(review): background is undersampled before the train/test split, so the
# test set carries the reduced background prevalence too.
df_bg = df_main[df_main['lesion_group'] == 'background']
df_red = df_main[df_main['lesion_group'] == 'red']
df_bg_s = df_bg.sample(n=min(N_BG_RED, len(df_bg)), random_state=SEED)

df_s1_red = pd.concat([df_bg_s, df_red], ignore_index=True)
df_s1_red['target'] = np.where(df_s1_red['lesion_group']=='red', 'red', 'background')
# Shuffle so the rows are no longer ordered background-first
df_s1_red = df_s1_red.sample(frac=1, random_state=SEED).reset_index(drop=True)

# 2) Group-safe split: every row of an image lands on one side only
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, test_idx = next(gss.split(df_s1_red, groups=df_s1_red['image']))
train_df, test_df = df_s1_red.iloc[train_idx].copy(), df_s1_red.iloc[test_idx].copy()

# 3) Features / labels, binarize (red=1)
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in df_s1_red.columns if c not in drop_cols]
X_train, y_train_lbl = train_df[feat_cols], train_df['target']
X_test,  y_test_lbl  = test_df[feat_cols],  test_df['target']

# String labels are kept (`*_lbl`) for the report and confusion-matrix plots;
# the 0/1 arrays are what the model and the scalar metrics use.
y_train = (y_train_lbl == 'red').astype(int).values
y_test  = (y_test_lbl == 'red').astype(int).values

# 4) Impute (trees don't need scaling); medians are learned on TRAIN only
imp = SimpleImputer(strategy='median')
X_train_np = imp.fit_transform(X_train)
X_test_np  = imp.transform(X_test)

# 5) Imbalance weight from TRAIN: negatives per positive
pos = int(y_train.sum())
neg = int(len(y_train) - pos)
scale_pos_weight = (neg / pos) if pos > 0 else 1.0
print(f"scale_pos_weight (train): {scale_pos_weight:.3f}  (neg={neg}, pos={pos})")

# 6) Model (simple; no early stopping)
xgb_stage1_red = XGBClassifier(
    objective='binary:logistic',
    max_depth=6,
    n_estimators=400,
    learning_rate=0.10,
    subsample=0.8,
    colsample_bytree=0.8,
    reg_lambda=1.0,
    scale_pos_weight=scale_pos_weight,
    n_jobs=-1,
    random_state=SEED
)
xgb_stage1_red.fit(X_train_np, y_train)

# 7) Scores & predictions (red positive); column 1 is the label-1 (red) class
y_scores = xgb_stage1_red.predict_proba(X_test_np)[:, 1]
y_pred_bin = (y_scores >= 0.5).astype(int)
y_pred_lbl = np.where(y_pred_bin==1, 'red', 'background')

# 8) Metrics
# labels=[0, 1] pins the matrix to 2x2 so the four-way unpack cannot fail
# when only one class shows up in the truth and the predictions.
tn, fp, fn, tp = confusion_matrix(y_test, y_pred_bin, labels=[0, 1]).ravel()
sensitivity = tp/(tp+fn) if (tp+fn) else 0.0   # red recall
specificity = tn/(tn+fp) if (tn+fp) else 0.0   # background TNR
accuracy   = (tp+tn)/(tp+tn+fp+fn)
roc_auc    = roc_auc_score(y_test, y_scores)
pr_auc     = average_precision_score(y_test, y_scores)

print("\n=== Stage 1 — Red vs Background (XGBoost — simple) ===")
print(f"Sensitivity (red+): {sensitivity:.4f}")
print(f"Specificity (background): {specificity:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")
print(f"PR-AUC (red+): {pr_auc:.4f}")
print("\nClassification report:")
print(classification_report(y_test_lbl, y_pred_lbl, digits=4))

# 9) Confusion matrices (raw counts, then normalized per true class)
labels = ['background','red']
cm = confusion_matrix(y_test_lbl, y_pred_lbl, labels=labels)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(values_format='d')
plt.title('Stage 1 Red — Confusion Matrix (counts, XGB)')
plt.show()

cm_norm = confusion_matrix(y_test_lbl, y_pred_lbl, labels=labels, normalize='true')
ConfusionMatrixDisplay(confusion_matrix=cm_norm, display_labels=labels).plot(values_format='.2f')
plt.title('Stage 1 Red — Confusion Matrix (row-normalized, XGB)')
plt.show()

# 10) ROC curve with the chance diagonal for reference
fpr, tpr, _ = roc_curve(y_test, y_scores)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {roc_auc:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Stage 1 Red — ROC (XGBoost)')
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.show()

# 11) Save results
results_xgb_r1 = {
    'stage': 'Stage 1 — Red vs Background', 'model': 'XGBoost',
    'sensitivity': sensitivity, 'specificity': specificity, 'accuracy': accuracy,
    'roc_auc': roc_auc, 'pr_auc': pr_auc, 'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
}


#### Model Comparison

In [None]:
import pandas as pd

# Stage 1 (Red vs Background): side-by-side comparison of the three models.
# Prints the table, writes it to CSV and prints a LaTeX version.

# Collect result dicts (LogReg / RF / XGB)
rows = [results_lr_r1, results_rf_r1, results_xgb_r1]

# Build comparison table.
# Rank on the full-precision ROC AUC and round only for presentation:
# rounding first can turn near-equal scores into ties and misorder the models.
tbl_r1 = pd.DataFrame([
    {
        "Model": r["model"],
        "Sensitivity (red)": r["sensitivity"],
        "Specificity (background)": r["specificity"],
        "Balanced Accuracy": 0.5 * (r["sensitivity"] + r["specificity"]),
        "Accuracy": r["accuracy"],
        "ROC AUC": r["roc_auc"],
        "PR AUC (red)": r["pr_auc"],
        "TP": r["tp"], "FP": r["fp"], "TN": r["tn"], "FN": r["fn"]
    }
    for r in rows
]).sort_values("ROC AUC", ascending=False).round(4)

# Paper's AUC definition (avg of sensitivity & specificity); this is the same
# number as the "Balanced Accuracy" column, repeated under the paper's name.
tbl_r1["AUC (paper def)"] = tbl_r1["Balanced Accuracy"]

print(tbl_r1)

# Optional pretty display (Jupyter)
try:
    display(
        tbl_r1.style.highlight_max(
            subset=["Sensitivity (red)", "Specificity (background)",
                    "Balanced Accuracy", "Accuracy", "ROC AUC", "PR AUC (red)"],
            color="#d5f5e3"
        )
    )
except (NameError, ImportError):
    # NameError: `display` is not defined outside IPython/Jupyter.
    # ImportError: DataFrame.style needs the optional jinja2 dependency.
    # A bare `except:` here would also swallow KeyboardInterrupt/SystemExit.
    pass

# Export for your report
tbl_r1.to_csv("stage1_red_model_comparison.csv", index=False)
print("\nLaTeX (paste into your paper):\n")
print(tbl_r1.to_latex(
    index=False,
    float_format="%.4f",
    caption="Stage 1 (Red vs Background) — Model comparison",
    label="tab:s1_red_models"
))


### 7.2.2 Red: MA vs HM

#### Logistic Regression

In [None]:
# ===== Stage 2 — Red (MA vs HM) — Logistic Regression =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (confusion_matrix, classification_report,
                             roc_auc_score, average_precision_score,
                             ConfusionMatrixDisplay, roc_curve, auc)

# Stage 2 (red branch), logistic regression: MA vs HM on the red rows of
# df_main, with MA as the positive class, evaluated on held-out images.
# The metrics end up in `results_lr_r2` for the model-comparison cell.
SEED = 42
TEST_SIZE = 0.20
HM_TO_MA_RATIO_IN_TRAIN = 2.0   # set None to skip train-only rebalance
MAX_SPLIT_TRIES = 20            # seeds SEED, SEED+1, ... tried for a usable split

# 1) Data: red-only, normalize HE->HM, keep MA/HM
r2 = df_main[df_main['lesion_group']=='red'].copy()
r2['lesion_type'] = r2['lesion_type'].replace({'HE':'HM'})
r2 = r2[r2['lesion_type'].isin(['MA','HM'])].copy()
r2['target'] = r2['lesion_type']

# 2) Group-safe split by image (+ ensure both classes in train AND test)
# The first seed that yields both classes on both sides is used, so the usual
# case (SEED already works) is unchanged. A single unchecked retry could still
# leave a one-class side and fail later in fitting or in the ROC computation.
for split_seed in range(SEED, SEED + MAX_SPLIT_TRIES):
    gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=split_seed)
    train_idx, test_idx = next(gss.split(r2, groups=r2['image']))
    train_df, test_df = r2.iloc[train_idx].copy(), r2.iloc[test_idx].copy()
    if min(train_df['target'].nunique(), test_df['target'].nunique()) >= 2:
        break
else:
    raise RuntimeError(
        f"No image-grouped split with both MA and HM in train and test "
        f"after {MAX_SPLIT_TRIES} seeds starting at {SEED}."
    )

# 3) Train-only rebalance: undersample HM to ratio × MA (test set untouched)
if HM_TO_MA_RATIO_IN_TRAIN is not None:
    hm_tr = train_df[train_df['target']=='HM']
    ma_tr = train_df[train_df['target']=='MA']
    target_hm = int(min(len(hm_tr), HM_TO_MA_RATIO_IN_TRAIN * len(ma_tr))) if len(ma_tr)>0 else len(hm_tr)
    hm_down = hm_tr.sample(n=target_hm, random_state=SEED) if len(hm_tr) > target_hm else hm_tr
    train_df = pd.concat([hm_down, ma_tr], ignore_index=True).sample(frac=1, random_state=SEED).reset_index(drop=True)

# 4) Features / labels: every column that is not an identifier, coordinate or label
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in r2.columns if c not in drop_cols]
X_train, y_train = train_df[feat_cols], train_df['target']
X_test,  y_test  = test_df[feat_cols],  test_df['target']

# 5) Model: Impute + Scale + LogReg (balanced)
# Imputer and scaler live inside the pipeline, so they are fitted on the
# training rows only.
lr_stage2_red = Pipeline([
    ('impute', SimpleImputer(strategy='median')),
    ('scale', StandardScaler()),
    ('clf', LogisticRegression(
        solver='lbfgs', max_iter=1000, class_weight='balanced',
        C=1.0, random_state=SEED
    ))
]).fit(X_train, y_train)

# 6) Scores & preds (MA positive)
# Look up the probability column of 'MA' rather than assuming its position.
ma_idx = list(lr_stage2_red.named_steps['clf'].classes_).index('MA')
y_scores = lr_stage2_red.predict_proba(X_test)[:, ma_idx]
y_pred   = np.where(y_scores >= 0.5, 'MA', 'HM')

# 7) Metrics
y_true_bin = (y_test == 'MA').astype(int)
y_pred_bin = (y_pred == 'MA').astype(int)
# labels=[0, 1] pins the matrix to 2x2 so the four-way unpack is always valid.
tn, fp, fn, tp = confusion_matrix(y_true_bin, y_pred_bin, labels=[0, 1]).ravel()
sens = tp/(tp+fn) if (tp+fn) else 0.0   # MA recall
spec = tn/(tn+fp) if (tn+fp) else 0.0   # HM true-negative rate
acc  = (tp+tn)/(tp+tn+fp+fn)
rocA = roc_auc_score(y_true_bin, y_scores)
prA  = average_precision_score(y_true_bin, y_scores)

print("\n=== Stage 2 — Red (MA vs HM) — Logistic Regression ===")
print(f"Sensitivity (MA+): {sens:.4f}  Specificity (HM): {spec:.4f}  Acc: {acc:.4f}  ROC-AUC: {rocA:.4f}  PR-AUC: {prA:.4f}")
print("\nClassification report:\n", classification_report(y_test, y_pred, digits=4))

# 8) Plots: confusion matrices (counts, then row-normalized) and ROC curve
labels = ['HM','MA']
ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_test, y_pred, labels=labels),
                       display_labels=labels).plot(values_format='d')
plt.title('Stage 2 Red — Confusion Matrix (counts, LR)')
plt.show()

ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_test, y_pred, labels=labels, normalize='true'),
                       display_labels=labels).plot(values_format='.2f')
plt.title('Stage 2 Red — Confusion Matrix (row-normalized, LR)')
plt.show()

fpr, tpr, _ = roc_curve(y_true_bin, y_scores)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {rocA:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.title('Stage 2 Red — ROC (LR)')
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.show()

# 9) Save results
results_lr_r2 = {'stage':'Stage 2 — Red (MA vs HM)','model':'LogReg',
                 'sensitivity':sens,'specificity':spec,'accuracy':acc,
                 'roc_auc':rocA,'pr_auc':prA,'tp':tp,'tn':tn,'fp':fp,'fn':fn}


#### Random Forest

In [None]:
# ===== Stage 2 — Red (MA vs HM) — Random Forest =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (confusion_matrix, classification_report,
                             roc_auc_score, average_precision_score,
                             ConfusionMatrixDisplay, roc_curve, auc)

# Stage 2 (red branch), random forest: MA vs HM on the red rows of df_main,
# with MA as the positive class, evaluated on held-out images.
# The metrics end up in `results_rf_r2` for the model-comparison cell.
SEED = 42
TEST_SIZE = 0.20
HM_TO_MA_RATIO_IN_TRAIN = 2.0   # set None to skip train-only rebalance
MAX_SPLIT_TRIES = 20            # seeds SEED, SEED+1, ... tried for a usable split

# 1) Data: red-only, relabel HE as HM, keep MA/HM
r2 = df_main[df_main['lesion_group']=='red'].copy()
r2['lesion_type'] = r2['lesion_type'].replace({'HE':'HM'})
r2 = r2[r2['lesion_type'].isin(['MA','HM'])].copy()
r2['target'] = r2['lesion_type']

# 2) Split (group-safe by image, both classes required in train AND test)
# The first seed that yields both classes on both sides is used, so the usual
# case (SEED already works) is unchanged. A single unchecked retry could still
# leave a one-class side and fail later when looking up the 'MA' class or
# computing the ROC.
for split_seed in range(SEED, SEED + MAX_SPLIT_TRIES):
    gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=split_seed)
    train_idx, test_idx = next(gss.split(r2, groups=r2['image']))
    train_df, test_df = r2.iloc[train_idx].copy(), r2.iloc[test_idx].copy()
    if min(train_df['target'].nunique(), test_df['target'].nunique()) >= 2:
        break
else:
    raise RuntimeError(
        f"No image-grouped split with both MA and HM in train and test "
        f"after {MAX_SPLIT_TRIES} seeds starting at {SEED}."
    )

# 3) Train-only rebalance: undersample HM to ratio × MA (test set untouched)
if HM_TO_MA_RATIO_IN_TRAIN is not None:
    hm_tr = train_df[train_df['target']=='HM']
    ma_tr = train_df[train_df['target']=='MA']
    target_hm = int(min(len(hm_tr), HM_TO_MA_RATIO_IN_TRAIN * len(ma_tr))) if len(ma_tr)>0 else len(hm_tr)
    hm_down = hm_tr.sample(n=target_hm, random_state=SEED) if len(hm_tr) > target_hm else hm_tr
    train_df = pd.concat([hm_down, ma_tr], ignore_index=True).sample(frac=1, random_state=SEED).reset_index(drop=True)

# 4) Features / labels: every column that is not an identifier, coordinate or label
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in r2.columns if c not in drop_cols]
X_train, y_train = train_df[feat_cols], train_df['target']
X_test,  y_test  = test_df[feat_cols],  test_df['target']

# 5) Model: Impute -> RF (balanced subsample)
# The imputer lives inside the pipeline, so it is fitted on training rows only.
rf_stage2_red = Pipeline([
    ('impute', SimpleImputer(strategy='median')),
    ('clf', RandomForestClassifier(
        n_estimators=300, max_depth=None, min_samples_leaf=1,
        n_jobs=-1, class_weight='balanced_subsample', random_state=SEED
    ))
]).fit(X_train, y_train)

# 6) Scores & preds (MA positive)
# Look up the probability column of 'MA' rather than assuming its position.
ma_idx = list(rf_stage2_red.named_steps['clf'].classes_).index('MA')
y_scores = rf_stage2_red.predict_proba(X_test)[:, ma_idx]
y_pred   = np.where(y_scores >= 0.5, 'MA', 'HM')

# 7) Metrics
y_true_bin = (y_test == 'MA').astype(int)
y_pred_bin = (y_pred == 'MA').astype(int)
# labels=[0, 1] pins the matrix to 2x2 so the four-way unpack is always valid.
tn, fp, fn, tp = confusion_matrix(y_true_bin, y_pred_bin, labels=[0, 1]).ravel()
sens = tp/(tp+fn) if (tp+fn) else 0.0   # MA recall
spec = tn/(tn+fp) if (tn+fp) else 0.0   # HM true-negative rate
acc  = (tp+tn)/(tp+tn+fp+fn)
rocA = roc_auc_score(y_true_bin, y_scores)
prA  = average_precision_score(y_true_bin, y_scores)

print("\n=== Stage 2 — Red (MA vs HM) — Random Forest ===")
print(f"Sensitivity (MA+): {sens:.4f}  Specificity (HM): {spec:.4f}  Acc: {acc:.4f}  ROC-AUC: {rocA:.4f}  PR-AUC: {prA:.4f}")
print("\nClassification report:\n", classification_report(y_test, y_pred, digits=4))

# 8) Plots: confusion matrices (counts, then row-normalized) and ROC curve
labels = ['HM','MA']
ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_test, y_pred, labels=labels),
                       display_labels=labels).plot(values_format='d')
plt.title('Stage 2 Red — Confusion Matrix (counts, RF)')
plt.show()

ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_test, y_pred, labels=labels, normalize='true'),
                       display_labels=labels).plot(values_format='.2f')
plt.title('Stage 2 Red — Confusion Matrix (row-normalized, RF)')
plt.show()

fpr, tpr, _ = roc_curve(y_true_bin, y_scores)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {rocA:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.title('Stage 2 Red — ROC (RF)')
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.show()

# 9) Save results
results_rf_r2 = {'stage':'Stage 2 — Red (MA vs HM)','model':'RandomForest',
                 'sensitivity':sens,'specificity':spec,'accuracy':acc,
                 'roc_auc':rocA,'pr_auc':prA,'tp':tp,'tn':tn,'fp':fp,'fn':fn}


#### XGBoost

In [None]:
# ===== Stage 2 — Red (MA vs HM) — XGBoost (simple) =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit
from sklearn.impute import SimpleImputer
from sklearn.metrics import (confusion_matrix, classification_report,
                             roc_auc_score, average_precision_score,
                             ConfusionMatrixDisplay, roc_curve, auc)
from xgboost import XGBClassifier

# Stage 2 (red branch), XGBoost: MA (label 1) vs HM (label 0) on the red rows
# of df_main, evaluated on held-out images.
# The metrics end up in `results_xgb_r2` for the model-comparison cell.
SEED = 42
TEST_SIZE = 0.20
HM_TO_MA_RATIO_IN_TRAIN = 2.0  # set None to skip rebalance
MAX_SPLIT_TRIES = 20           # seeds SEED, SEED+1, ... tried for a usable split

# 1) Data: red-only, relabel HE as HM, keep MA/HM
r2 = df_main[df_main['lesion_group']=='red'].copy()
r2['lesion_type'] = r2['lesion_type'].replace({'HE':'HM'})
r2 = r2[r2['lesion_type'].isin(['MA','HM'])].copy()
r2['target'] = r2['lesion_type']

# 2) Split (group-safe by image, both classes required in train AND test)
# The first seed that yields both classes on both sides is used, so the usual
# case (SEED already works) is unchanged. A single unchecked retry could still
# leave a one-class side and fail later in the metric computation.
for split_seed in range(SEED, SEED + MAX_SPLIT_TRIES):
    gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=split_seed)
    train_idx, test_idx = next(gss.split(r2, groups=r2['image']))
    train_df, test_df = r2.iloc[train_idx].copy(), r2.iloc[test_idx].copy()
    if min(train_df['target'].nunique(), test_df['target'].nunique()) >= 2:
        break
else:
    raise RuntimeError(
        f"No image-grouped split with both MA and HM in train and test "
        f"after {MAX_SPLIT_TRIES} seeds starting at {SEED}."
    )

# 3) Train-only rebalance: undersample HM to ratio × MA (test set untouched)
if HM_TO_MA_RATIO_IN_TRAIN is not None:
    hm_tr = train_df[train_df['target']=='HM']
    ma_tr = train_df[train_df['target']=='MA']
    target_hm = int(min(len(hm_tr), HM_TO_MA_RATIO_IN_TRAIN * len(ma_tr))) if len(ma_tr)>0 else len(hm_tr)
    hm_down = hm_tr.sample(n=target_hm, random_state=SEED) if len(hm_tr) > target_hm else hm_tr
    train_df = pd.concat([hm_down, ma_tr], ignore_index=True).sample(frac=1, random_state=SEED).reset_index(drop=True)

# 4) Features / labels (MA=1, HM=0)
drop_cols = ['image','x','y','lesion','lesion_group','lesion_type','target']
feat_cols = [c for c in r2.columns if c not in drop_cols]
X_train, y_train_lbl = train_df[feat_cols], train_df['target']
X_test,  y_test_lbl  = test_df[feat_cols],  test_df['target']

# String labels are kept (`*_lbl`) for the report and confusion-matrix plots;
# the 0/1 arrays are what the model and the scalar metrics use.
y_train = (y_train_lbl == 'MA').astype(int).values
y_test  = (y_test_lbl == 'MA').astype(int).values

# 5) Impute (trees don't need scaling); medians are learned on TRAIN only
imp = SimpleImputer(strategy='median')
X_train_np = imp.fit_transform(X_train)
X_test_np  = imp.transform(X_test)

# 6) Imbalance weight from TRAIN: negatives per positive
pos = int(y_train.sum())
neg = int(len(y_train) - pos)
scale_pos_weight = (neg / pos) if pos > 0 else 1.0
print(f"scale_pos_weight (train): {scale_pos_weight:.3f}  (neg={neg}, pos={pos})")

# 7) Model (simple; no early stopping)
xgb_stage2_red = XGBClassifier(
    objective='binary:logistic',
    max_depth=6,
    n_estimators=400,
    learning_rate=0.10,
    subsample=0.8,
    colsample_bytree=0.8,
    reg_lambda=1.0,
    scale_pos_weight=scale_pos_weight,
    n_jobs=-1,
    random_state=SEED
).fit(X_train_np, y_train)

# 8) Scores & preds (MA positive); column 1 is the label-1 (MA) class
y_scores = xgb_stage2_red.predict_proba(X_test_np)[:, 1]
y_pred_bin = (y_scores >= 0.5).astype(int)
y_pred_lbl = np.where(y_pred_bin==1, 'MA', 'HM')

# 9) Metrics
# labels=[0, 1] pins the matrix to 2x2 so the four-way unpack is always valid.
tn, fp, fn, tp = confusion_matrix(y_test, y_pred_bin, labels=[0, 1]).ravel()
sens = tp/(tp+fn) if (tp+fn) else 0.0   # MA recall
spec = tn/(tn+fp) if (tn+fp) else 0.0   # HM true-negative rate
acc  = (tp+tn)/(tp+tn+fp+fn)
rocA = roc_auc_score(y_test, y_scores)
prA  = average_precision_score(y_test, y_scores)

print("\n=== Stage 2 — Red (MA vs HM) — XGBoost (simple) ===")
print(f"Sensitivity (MA+): {sens:.4f}  Specificity (HM): {spec:.4f}  Acc: {acc:.4f}  ROC-AUC: {rocA:.4f}  PR-AUC: {prA:.4f}")
print("\nClassification report:\n", classification_report(y_test_lbl, y_pred_lbl, digits=4))

# 10) Plots: confusion matrices (counts, then row-normalized) and ROC curve
labels = ['HM','MA']
ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_test_lbl, y_pred_lbl, labels=labels),
                       display_labels=labels).plot(values_format='d')
plt.title('Stage 2 Red — Confusion Matrix (counts, XGB)')
plt.show()

ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_test_lbl, y_pred_lbl, labels=labels, normalize='true'),
                       display_labels=labels).plot(values_format='.2f')
plt.title('Stage 2 Red — Confusion Matrix (row-normalized, XGB)')
plt.show()

fpr, tpr, _ = roc_curve(y_test, y_scores)
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, linewidth=2, label=f"AUC = {rocA:.3f}")
plt.plot([0,1],[0,1],'--', linewidth=1)
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.title('Stage 2 Red — ROC (XGBoost)')
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.show()

# 11) Save results
results_xgb_r2 = {'stage':'Stage 2 — Red (MA vs HM)','model':'XGBoost',
                  'sensitivity':sens,'specificity':spec,'accuracy':acc,
                  'roc_auc':rocA,'pr_auc':prA,'tp':tp,'tn':tn,'fp':fp,'fn':fn}


#### Model Comparison

In [None]:
import pandas as pd

# Stage 2 (Red: MA vs HM): side-by-side comparison of the three models.
# Prints the table, writes it to CSV and prints a LaTeX version.
rows = [results_lr_r2, results_rf_r2, results_xgb_r2]

# Rank on the full-precision ROC AUC and round only for presentation:
# rounding first can turn near-equal scores into ties and misorder the models.
tbl_r2 = pd.DataFrame([
    {
        "Model": r["model"],
        "Sensitivity (MA)": r["sensitivity"],
        "Specificity (HM)": r["specificity"],
        "Balanced Accuracy": 0.5 * (r["sensitivity"] + r["specificity"]),
        "Accuracy": r["accuracy"],
        "ROC AUC": r["roc_auc"],
        "PR AUC (MA)": r["pr_auc"],
        "TP": r["tp"], "FP": r["fp"], "TN": r["tn"], "FN": r["fn"]
    }
    for r in rows
]).sort_values("ROC AUC", ascending=False).round(4)

# Paper's AUC definition (avg of sensitivity & specificity); this is the same
# number as the "Balanced Accuracy" column, repeated under the paper's name.
tbl_r2["AUC (paper def)"] = tbl_r2["Balanced Accuracy"]

print(tbl_r2)

# Optional pretty display (Jupyter)
try:
    display(tbl_r2.style.highlight_max(
        subset=["Sensitivity (MA)","Specificity (HM)","Balanced Accuracy","Accuracy","ROC AUC","PR AUC (MA)"],
        color="#d5f5e3"
    ))
except (NameError, ImportError):
    # NameError: `display` is not defined outside IPython/Jupyter.
    # ImportError: DataFrame.style needs the optional jinja2 dependency.
    pass

# Export for your report
tbl_r2.to_csv("stage2_red_model_comparison.csv", index=False)
print("\nLaTeX (paste into your paper):\n")
print(tbl_r2.to_latex(index=False, float_format="%.4f",
                      caption="Stage 2 (Red: MA vs HM) — Model comparison",
                      label="tab:s2_red_models"))
