# **Preparing the Initial Dataset | 初始数据集准备**

**Mount Drive and Unzip | 挂载并解压**

In [None]:
import os
import zipfile
import shutil
from pathlib import Path

# --- Mount Drive and check the input archives ---

# Step 1.1: Mount Google Drive so the project's zip files are reachable.
print("--> Step 1.1: Mounting Google Drive...")
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
print("   Google Drive mounted successfully!")

# Step 1.2: Define core paths for the zip files.
print("\n--> Step 1.2: Defining core file paths...")
GDRIVE_PROJECT_DIR = Path("/content/drive/MyDrive/glaucoma_project_v2")
BEH_ZIP_PATH = GDRIVE_PROJECT_DIR / "BEH.zip"
DATASET_ZIP_PATH = GDRIVE_PROJECT_DIR / "Dataset.zip"

# Local directory where all data will be unzipped.
RAW_DATA_DIR = Path("/content/raw_datasets")

# Stop early if an archive is missing. An explicit raise is used instead of
# `assert`, because assert statements are skipped when Python runs with -O.
for zip_path, description in (
    (BEH_ZIP_PATH, "BEH zip file"),
    (DATASET_ZIP_PATH, "Main dataset zip file"),
):
    if not zip_path.is_file():
        raise FileNotFoundError(f"Error: {description} not found at {zip_path}")
print("   All zip file paths are correct and files exist.")

# Step 1.3: Start from an empty local working directory so files left over
# from a previous run cannot mix with this one.
print("\n--> Step 1.3: Setting up local working directory...")
if RAW_DATA_DIR.exists():
    print(f"   - Deleting existing directory: {RAW_DATA_DIR}")
    shutil.rmtree(RAW_DATA_DIR)
print(f"   - Creating new directory: {RAW_DATA_DIR}")
RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
print("   Local working directory is ready.")

# Step 1.4: Extract both archives (BEH first) into the same local directory.
print("\n--> Step 1.4: Unzipping datasets...")
for archive_path in (BEH_ZIP_PATH, DATASET_ZIP_PATH):
    print(f"   - Unzipping {archive_path.name}...")
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(RAW_DATA_DIR)
    print("     Done.")
print("   All datasets have been unzipped.")

# Step 1.5: Summarize what was extracted into RAW_DATA_DIR.
print("\n--> Step 1.5: Verifying the unzipped directory structure...")
print(f"Listing contents of {RAW_DATA_DIR}:")
for dir_name in sorted(os.listdir(RAW_DATA_DIR)):
    # "__MACOSX" is macOS archive metadata, not a dataset. It is a directory,
    # so it must be skipped before the directory branch below.
    if dir_name == "__MACOSX":
        continue
    full_path = RAW_DATA_DIR / dir_name
    if full_path.is_dir():
        subdirs = [d for d in full_path.iterdir() if d.is_dir()]
        # Count regular files only. A '*.*' pattern would also count
        # directories whose names contain a dot and miss extension-less files.
        num_files = sum(1 for p in full_path.rglob('*') if p.is_file())
        print(f"   - Found directory: '{dir_name}' | Subfolders: {len(subdirs)} | Total files: {num_files}")
    else:
        print(f"   - Found file: '{dir_name}'")

print("\n--- Initial Setup Finished ---")

--> Step 1.1: Mounting Google Drive...
Mounted at /content/drive
   Google Drive mounted successfully!

--> Step 1.2: Defining core file paths...
   All zip file paths are correct and files exist.

--> Step 1.3: Setting up local working directory...
   - Creating new directory: /content/raw_datasets
   Local working directory is ready.

--> Step 1.4: Unzipping datasets...
   - Unzipping BEH.zip...
     Done.
   - Unzipping Dataset.zip...
     Done.
   All datasets have been unzipped.

--> Step 1.5: Verifying the unzipped directory structure...
Listing contents of /content/raw_datasets:
   - Found directory: 'ACRIMA' | Subfolders: 0 | Total files: 705
   - Found directory: 'BEH (Bangladesh Eye Hospital) Dataset' | Subfolders: 2 | Total files: 636
   - Found directory: 'Drishti-GS1' | Subfolders: 1 | Total files: 716
   - Found directory: 'RIM-ONE r3' | Subfolders: 3 | Total files: 3052
   - Found directory: '__MACOSX' | Subfolders: 4 | Total files: 5437

--- Initial Setup Finished ---


In [None]:
# ==============================================================================
# Part 2: Pre-organizing Raw Datasets into a Standard Structure
# Part 2: 将原始数据集预处理成标准结构
# ==============================================================================

import pandas as pd

print("--- Starting Stage 2: Pre-organizing Raw Datasets (Corrected Version) ---")

# --- Path Definitions ---
# Source is the raw extraction from Part 1. The target receives one folder per
# dataset, each holding Glaucoma/ and Normal/ subfolders.
BASE_SOURCE_DIR, BASE_TARGET_DIR = Path("/content/raw_datasets"), Path("/content/organized_datasets")

print(f"Source Directory: {BASE_SOURCE_DIR}")
print(f"Target Directory: {BASE_TARGET_DIR}")

# Rebuild the target tree from scratch on every run.
if BASE_TARGET_DIR.exists():
    print(f"Found existing directory at {BASE_TARGET_DIR}. Deleting it now...")
    shutil.rmtree(BASE_TARGET_DIR)
    print("   Old directory successfully deleted.")
BASE_TARGET_DIR.mkdir(parents=True)
print("Base target directory created.")

# Per-dataset counters and error messages, filled in by the sections below.
stats = dict()

# --- Process ACRIMA Dataset ---
# ACRIMA is one flat folder whose filenames encode the label: names containing
# "_g_" are glaucomatous; every other image is treated as normal.
print("\n--> Processing ACRIMA dataset...")
dataset_name = "ACRIMA"
stats[dataset_name] = {'g': 0, 'n': 0, 'total': 0, 'errors': []}
acrima_target_dir = BASE_TARGET_DIR / dataset_name
(acrima_target_dir / "Glaucoma").mkdir(parents=True, exist_ok=True)
(acrima_target_dir / "Normal").mkdir(parents=True, exist_ok=True)
acrima_source_dir = BASE_SOURCE_DIR / "ACRIMA"
# Only image files are organized. Stray files (e.g. OS metadata) would
# otherwise be counted, and copied, as "Normal".
ACRIMA_IMAGE_SUFFIXES = {'.bmp', '.gif', '.jpeg', '.jpg', '.png'}
if acrima_source_dir.is_dir():
    for src_path in sorted(acrima_source_dir.iterdir()):
        if not src_path.is_file() or src_path.suffix.lower() not in ACRIMA_IMAGE_SUFFIXES:
            continue
        if '_g_' in src_path.name.lower():
            class_key, class_dir = 'g', "Glaucoma"
        else:
            class_key, class_dir = 'n', "Normal"
        try:
            shutil.copy(src_path, acrima_target_dir / class_dir / src_path.name)
        except OSError as e:
            stats[dataset_name]['errors'].append(f"Error processing {src_path.name}: {e}")
        else:
            # Count only after a successful copy so totals match the files on disk.
            stats[dataset_name][class_key] += 1
else:
    stats[dataset_name]['errors'].append(f"Error: source directory not found at {acrima_source_dir}")
stats[dataset_name]['total'] = stats[dataset_name]['g'] + stats[dataset_name]['n']
print(f"   Finished. Glaucoma: {stats[dataset_name]['g']}, Normal: {stats[dataset_name]['n']}, Total: {stats[dataset_name]['total']}")


# --- Process Drishti-GS1 Dataset ---
# Labels come from the diagnosis spreadsheet. Images are looked up in the
# Training and Test image folders.
print("\n--> Processing Drishti-GS1 dataset...")
dataset_name = "Drishti-GS1"
stats[dataset_name] = {'g': 0, 'n': 0, 'total': 0, 'errors': []}
drishti_target_dir = BASE_TARGET_DIR / dataset_name
(drishti_target_dir / "Glaucoma").mkdir(parents=True, exist_ok=True)
(drishti_target_dir / "Normal").mkdir(parents=True, exist_ok=True)
drishti_main_dir = BASE_SOURCE_DIR / "Drishti-GS1"
# Spreadsheet diagnosis text -> (stats counter key, target class folder).
DRISHTI_LABELS = {'Glaucomatous': ('g', "Glaucoma"), 'Normal': ('n', "Normal")}
if drishti_main_dir.is_dir():
    try:
        xls_path = drishti_main_dir / "Drishti-GS1_diagnosis.xlsx"
        if not xls_path.exists():
            raise FileNotFoundError(f"Diagnosis file not found at {xls_path}")
        # The first 3 rows of the sheet are skipped before the header row is read.
        df = pd.read_excel(xls_path, skiprows=3, usecols=['Drishti-GS File', 'Total'])
        df.columns = ['filename', 'diagnosis']
        df.dropna(subset=['filename', 'diagnosis'], inplace=True)
        image_search_dirs = [
            drishti_main_dir / "Drishti-GS1_files/Training/Images",
            drishti_main_dir / "Drishti-GS1_files/Test/Images",
        ]
        for _, row in df.iterrows():
            # Spreadsheet names may carry quotes/whitespace and lack the extension.
            clean_filename = str(row['filename']).strip().replace("'", "") + '.png'
            diagnosis = str(row['diagnosis']).strip()
            label = DRISHTI_LABELS.get(diagnosis)
            if label is None:
                # Report unknown diagnoses instead of copying the image into
                # whichever class folder the previous row happened to use.
                stats[dataset_name]['errors'].append(
                    f"Warning: Unknown diagnosis '{diagnosis}' for '{clean_filename}'")
                continue
            class_key, class_dir = label
            src_path = next(
                (d / clean_filename for d in image_search_dirs if (d / clean_filename).exists()),
                None,
            )
            if src_path is None:
                stats[dataset_name]['errors'].append(f"Warning: Image not found for '{clean_filename}'")
                continue
            shutil.copy(src_path, drishti_target_dir / class_dir / clean_filename)
            stats[dataset_name][class_key] += 1  # counted only after a successful copy
    except Exception as e:
        stats[dataset_name]['errors'].append(f"Error: {e}")
else:
    stats[dataset_name]['errors'].append(f"Error: source directory not found at {drishti_main_dir}")
stats[dataset_name]['total'] = stats[dataset_name]['g'] + stats[dataset_name]['n']
print(f"   Finished. Glaucoma: {stats[dataset_name]['g']}, Normal: {stats[dataset_name]['n']}, Total: {stats[dataset_name]['total']}")


# --- Process RIM-ONE r3 Dataset ---
# Confirmed glaucoma ('G-' prefix) and healthy images are kept. Suspects
# ('S-' prefix) are counted but excluded.
print("\n--> Processing RIM-ONE r3 dataset (Corrected Version 2)...")
dataset_name = "RIM-ONE-r3"
# Reset stats for this dataset.
stats[dataset_name] = {'g': 0, 'n': 0, 's_excluded': 0, 'total': 0, 'errors': []}
rimone_target_dir = BASE_TARGET_DIR / dataset_name
(rimone_target_dir / "Glaucoma").mkdir(parents=True, exist_ok=True)
(rimone_target_dir / "Normal").mkdir(parents=True, exist_ok=True)
rimone_base_dir = BASE_SOURCE_DIR / "RIM-ONE r3"

if rimone_base_dir.is_dir():
    try:
        # The images live in the 'Stereo Images' subfolder of each class folder.
        glaucoma_suspects_dir = rimone_base_dir / "Glaucoma and suspects" / "Stereo Images"
        normal_dir = rimone_base_dir / "Healthy" / "Stereo Images"

        # --- Glaucoma folder: split by filename prefix ---
        if glaucoma_suspects_dir.is_dir():
            for src_path in glaucoma_suspects_dir.glob("*.jpg"):
                if not src_path.is_file():
                    continue
                name_upper = src_path.name.upper()
                if name_upper.startswith('G-'):
                    shutil.copy(src_path, rimone_target_dir / "Glaucoma" / src_path.name)
                    stats[dataset_name]['g'] += 1
                elif name_upper.startswith('S-'):
                    stats[dataset_name]['s_excluded'] += 1
                else:
                    # Surface unexpected files instead of dropping them silently.
                    stats[dataset_name]['errors'].append(
                        f"Warning: Unrecognized file prefix for '{src_path.name}'")
        else:
            stats[dataset_name]['errors'].append(f"Warning: Folder not found: {glaucoma_suspects_dir}")

        # --- Healthy folder: every image is Normal ---
        if normal_dir.is_dir():
            for src_path in normal_dir.glob("*.jpg"):
                if src_path.is_file():
                    shutil.copy(src_path, rimone_target_dir / "Normal" / src_path.name)
                    stats[dataset_name]['n'] += 1
        else:
            stats[dataset_name]['errors'].append(f"Warning: Folder not found: {normal_dir}")

    except Exception as e:
        stats[dataset_name]['errors'].append(f"Error: {e}")
else:
    stats[dataset_name]['errors'].append(f"Error: source directory not found at {rimone_base_dir}")

stats[dataset_name]['total'] = stats[dataset_name]['g'] + stats[dataset_name]['n']
print(f"   Finished. Glaucoma: {stats[dataset_name]['g']}, Normal: {stats[dataset_name]['n']}, Total: {stats[dataset_name]['total']}")
if stats[dataset_name]['s_excluded'] > 0:
    print(f"   Excluded {stats[dataset_name]['s_excluded']} Suspected cases for clarity.")


# --- Final Report ---
# Per-dataset counts, grand totals, then any errors/warnings collected above.
print("\n--- Data Organization Finished ---")
print("\nFinal Statistics:")
for ds_name, ds_stats in stats.items():
    print(f"  - {ds_name}:")
    print(f"    - Glaucoma: {ds_stats['g']}")
    print(f"    - Normal:   {ds_stats['n']}")
    print(f"    - Total:    {ds_stats['total']}")

total_g = sum(s['g'] for s in stats.values())
total_n = sum(s['n'] for s in stats.values())
total_all = sum(s['total'] for s in stats.values())
print("\n-------------------------------------")
print("Grand Total:")
print(f"  - Total Glaucoma images: {total_g}")
print(f"  - Total Normal images:   {total_n}")
print(f"  - GRAND TOTAL IMAGES:    {total_all}")
print("-------------------------------------")

# Only the datasets that recorded at least one error/warning.
problem_datasets = {name: s['errors'] for name, s in stats.items() if s['errors']}
if not problem_datasets:
    print("\nNo errors encountered. All data processed successfully!")
else:
    print("\nEncountered the following errors/warnings:")
    for ds_name, ds_errors in problem_datasets.items():
        print(f"  -- In {ds_name}: --")
        for err in ds_errors:
            print(f"     - {err}")
print(f"\nAll data is now organized in: {BASE_TARGET_DIR}")

--- Starting Stage 2: Pre-organizing Raw Datasets (Corrected Version) ---
Source Directory: /content/raw_datasets
Target Directory: /content/organized_datasets
Base target directory created.

--> Processing ACRIMA dataset...
   Finished. Glaucoma: 396, Normal: 309, Total: 705

--> Processing Drishti-GS1 dataset...
   Finished. Glaucoma: 70, Normal: 31, Total: 101

--> Processing RIM-ONE r3 dataset (Corrected Version 2)...
   Finished. Glaucoma: 39, Normal: 85, Total: 124
   Excluded 35 Suspected cases for clarity.

--- Data Organization Finished ---

Final Statistics:
  - ACRIMA:
    - Glaucoma: 396
    - Normal:   309
    - Total:    705
  - Drishti-GS1:
    - Glaucoma: 70
    - Normal:   31
    - Total:    101
  - RIM-ONE-r3:
    - Glaucoma: 39
    - Normal:   85
    - Total:    124

-------------------------------------
Grand Total:
  - Total Glaucoma images: 505
  - Total Normal images:   425
  - GRAND TOTAL IMAGES:    930
-------------------------------------

No errors encountere

**Final Dataset Preparation | 数据准备**

In [None]:
# ==============================================================================
# Part 3: Creating the Final Dataset for Training
# Part 3: 创建用于训练的最终数据集
# ==============================================================================

import pandas as pd
from sklearn.model_selection import train_test_split
import cv2
from tqdm import tqdm
import random

# Build the dataset used for all model training: sample fixed numbers of
# images from BEH (raw) and ACRIMA (organized in Part 2) following the paper's
# recipe, then split them into train/validation/test folders.
# Files are copied unchanged here. Resizing (chosen instead of cropping) to
# IMG_SIZE x IMG_SIZE happens later, when the training data loaders read them.

print("--- Starting Stage 3: Building the Final Dataset ---")

# --- Path Definitions ---
RAW_DATA_DIR = Path("/content/raw_datasets")
ORGANIZED_DATA_DIR = Path("/content/organized_datasets")
FINAL_DATASET_DIR = Path("/content/final_dataset")

print(f"Source of Raw BEH: {RAW_DATA_DIR}")
print(f"Source of Organized ACRIMA: {ORGANIZED_DATA_DIR}")
print(f"Target Final Dataset: {FINAL_DATASET_DIR}")

# Recreate an empty <split>/<class> folder tree.
if FINAL_DATASET_DIR.exists():
    shutil.rmtree(FINAL_DATASET_DIR)
FINAL_DATASET_DIR.mkdir(parents=True)
for split in ['train', 'validation', 'test']:
    for class_name in ('Glaucoma', 'Normal'):
        (FINAL_DATASET_DIR / split / class_name).mkdir(parents=True)
print("   Final dataset directory structure (train/validation/test) created.")

# --- Collect and Sample Source Images ---
print("\n--> Step 3.1: Collecting and sampling raw images based on the paper's recipe...")
source_files = []
labels = []
random.seed(42)  # For reproducibility.

# File types the Keras directory loader reads. Other files must not be sampled
# in place of a real image.
IMAGE_SUFFIXES = {'.bmp', '.gif', '.jpeg', '.jpg', '.png'}


def list_images(folder):
    """Return the image files directly inside ``folder``, sorted by path.

    Directory iteration order is filesystem-dependent, so the listing is sorted
    before sampling. Without sorting, ``random.seed(42)`` alone would not make
    the selection reproducible.
    """
    return sorted(
        path for path in folder.glob('*')
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
    )


def add_sample(folder, label, count):
    """Randomly pick ``count`` images from ``folder`` and record them under ``label``.

    Raises ValueError (from ``random.sample``) when the folder holds fewer than
    ``count`` images. Nothing is recorded in that case.
    """
    picked = random.sample(list_images(folder), count)
    source_files.extend(picked)
    labels.extend([label] * count)


# Sample from the BEH dataset (counts follow the paper's recipe).
beh_glaucoma_source = RAW_DATA_DIR / "BEH (Bangladesh Eye Hospital) Dataset" / "glaucoma"
beh_normal_source = RAW_DATA_DIR / "BEH (Bangladesh Eye Hospital) Dataset" / "normal"
try:
    add_sample(beh_glaucoma_source, 'Glaucoma', 69)
    add_sample(beh_normal_source, 'Normal', 319)
except ValueError as e:
    print(f"   [WARNING] Could not sample from BEH, not enough images? Error: {e}")


# Sample from the ACRIMA dataset.
acrima_glaucoma_source = ORGANIZED_DATA_DIR / "ACRIMA" / "Glaucoma"
acrima_normal_source = ORGANIZED_DATA_DIR / "ACRIMA" / "Normal"
try:
    add_sample(acrima_glaucoma_source, 'Glaucoma', 141)
    add_sample(acrima_normal_source, 'Normal', 50)
except ValueError as e:
    print(f"   [WARNING] Could not sample from ACRIMA, not enough images? Error: {e}")

print(f"   Total source files collected: {len(source_files)}")
print(f"   - Glaucoma: {labels.count('Glaucoma')}")
print(f"   - Normal: {labels.count('Normal')}")

# --- Split the Dataset ---
# The paper fixes the set sizes: 115 test images, 40 validation images and the
# remainder for training. Both splits are stratified by label.
print("\n--> Step 3.2: Splitting the dataset according to the paper's ratio...")
df = pd.DataFrame({'filepath': source_files, 'label': labels})


def copy_files(dataframe, split_name):
    """Copy each row's file into FINAL_DATASET_DIR/<split_name>/<label>/, keeping its filename."""
    rows = zip(dataframe['filepath'], dataframe['label'])
    for source_path, label in tqdm(rows, total=len(dataframe), desc=f"   Copying {split_name} set"):
        shutil.copy(source_path, FINAL_DATASET_DIR / split_name / label / source_path.name)


if df.empty:
    print("   [ERROR] No files were collected for sampling. Cannot proceed with splitting.")
else:
    train_val_df, test_df = train_test_split(df, test_size=115, random_state=42, stratify=df['label'])
    train_df, val_df = train_test_split(train_val_df, test_size=40, random_state=42, stratify=train_val_df['label'])
    print("   Dataset split completed:")
    for title, part in (("Training", train_df), ("Validation", val_df), ("Test", test_df)):
        print(f"   - {title} set size: {len(part)}")

    # --- Copy Files to Final Destination ---
    print("\n--> Step 3.3: Copying images to the final dataset directories...")
    for part, split_name in ((train_df, 'train'), (val_df, 'validation'), (test_df, 'test')):
        copy_files(part, split_name)


print("\n\n--- Final Dataset Creation Completed ---")
print(f"The final dataset is now organized in: {FINAL_DATASET_DIR}")
print("Ready for model training.")

--- Starting Stage 3: Building the Final Dataset ---
Source of Raw BEH: /content/raw_datasets
Source of Organized ACRIMA: /content/organized_datasets
Target Final Dataset: /content/final_dataset
   Final dataset directory structure (train/validation/test) created.

--> Step 3.1: Collecting and sampling raw images based on the paper's recipe...
   Total source files collected: 579
   - Glaucoma: 210
   - Normal: 369

--> Step 3.2: Splitting the dataset according to the paper's ratio...
   Dataset split completed:
   - Training set size: 424
   - Validation set size: 40
   - Test set size: 115

--> Step 3.3: Copying images to the final dataset directories...


   Copying train set: 100%|██████████| 424/424 [00:00<00:00, 564.68it/s]
   Copying validation set: 100%|██████████| 40/40 [00:00<00:00, 458.06it/s]
   Copying test set: 100%|██████████| 115/115 [00:00<00:00, 477.60it/s]



--- Final Dataset Creation Completed ---
The final dataset is now organized in: /content/final_dataset
Ready for model training.





# **Training the Baseline | 基线训练**


In [None]:
# ==============================================================================
# Part 4: Training Benchmark with Class Weights & Lower LR
# Part 4: 使用类别权重和更低的学习率训练基准模型
# ==============================================================================

import os
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

print("--- Starting Benchmark Training Script (Final Version: Weighted + Low LR) ---")
print("--- Objective: Train a stable and unbiased benchmark model ---")

# ==============================================================================
# 1. Configuration
# ==============================================================================
print("\n--> Step 4.1: Configuring hyperparameters...")

# Optimization settings.
LEARNING_RATE = 0.0001
EPOCHS = 70
BATCH_SIZE = 6
# Side length (pixels) that the data loader resizes every image to.
IMG_SIZE = 300

# Input comes from Part 3. This final benchmark run writes to its own Drive folder.
BASE_DATA_DIR = "/content/final_dataset"
GDRIVE_RESULTS_DIR = "/content/drive/MyDrive/glaucoma_project_results/model_A_benchmark_final"
MODEL_NAME = "A_Benchmark_Final_EfficientNetB3"

os.makedirs(GDRIVE_RESULTS_DIR, exist_ok=True)
BEST_MODEL_SAVE_PATH = os.path.join(GDRIVE_RESULTS_DIR, MODEL_NAME + "_best_model.keras")
HISTORY_CSV_PATH = os.path.join(GDRIVE_RESULTS_DIR, "history_" + MODEL_NAME + ".csv")

print(f"   - Training new model: {MODEL_NAME}")
print(f"   - Using Learning Rate: {LEARNING_RATE}")
print(f"   - Results will be saved to: {GDRIVE_RESULTS_DIR}")

# ==============================================================================
# 2. Data Loading
# ==============================================================================
print("\n--> Step 4.2: Preparing data loaders...")

AUTOTUNE = tf.data.AUTOTUNE


def create_dataset(directory, is_training=True, seed=None):
    """Load a batched (image, label) dataset from ``directory/<class>/`` subfolders.

    Labels are binary and follow the ``class_names`` order: Glaucoma -> 0,
    Normal -> 1. Images are resized with bilinear interpolation to
    IMG_SIZE x IMG_SIZE. Shuffling is enabled only when ``is_training`` is true.
    """
    loader_options = {
        'labels': 'inferred',
        'label_mode': 'binary',
        'class_names': ['Glaucoma', 'Normal'],
        'image_size': (IMG_SIZE, IMG_SIZE),
        'interpolation': 'bilinear',
        'batch_size': BATCH_SIZE,
        'shuffle': is_training,
        'seed': seed,
    }
    return tf.keras.utils.image_dataset_from_directory(directory, **loader_options)


def process_image(image, label):
    """Cast images to float32 and apply EfficientNet's ``preprocess_input``; labels pass through."""
    return tf.keras.applications.efficientnet.preprocess_input(tf.cast(image, tf.float32)), label


# Benchmark augmentation: random horizontal flips only.
data_augmentation_layer = tf.keras.Sequential([tf.keras.layers.RandomFlip("horizontal")])


def augment_and_process(image, label):
    """Randomly flip, then preprocess.

    ``training=True`` is passed explicitly because the layer runs inside
    ``Dataset.map`` rather than inside ``model.fit``.
    """
    flipped = data_augmentation_layer(image, training=True)
    return process_image(flipped, label)

train_dir, val_dir = (os.path.join(BASE_DATA_DIR, sub) for sub in ("train", "validation"))

train_ds_raw = create_dataset(train_dir, is_training=True, seed=42)
val_ds_raw = create_dataset(val_dir, is_training=False)

# Training batches are augmented and preprocessed. Validation batches are only preprocessed.
train_ds = (
    train_ds_raw
    .map(augment_and_process, num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)
val_ds = (
    val_ds_raw
    .map(process_image, num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)

print("   Data loaders created successfully.")

# ==============================================================================
# 3. Model Building
# ==============================================================================
print("\n--> Step 4.3: Building the EfficientNetB3 model...")


def build_model(input_shape=(IMG_SIZE, IMG_SIZE, 3)):
    """Build an ImageNet-pretrained EfficientNetB3 with one sigmoid output unit.

    The backbone starts frozen. Then every non-BatchNormalization layer among
    its last 20 layers is unfrozen for fine-tuning.
    """
    backbone = EfficientNetB3(include_top=False, weights='imagenet', input_shape=input_shape)
    backbone.trainable = False
    pooled = GlobalAveragePooling2D(name='avg_pool')(backbone.output)
    probability = Dense(1, activation='sigmoid', name='predictions')(pooled)
    model = Model(inputs=backbone.input, outputs=probability)
    tail_layers = [
        layer for layer in backbone.layers[-20:]
        if not isinstance(layer, tf.keras.layers.BatchNormalization)
    ]
    for layer in tail_layers:
        layer.trainable = True
    return model


model = build_model()
optimizer = Adam(learning_rate=LEARNING_RATE)
model.compile(
    optimizer=optimizer,
    loss=BinaryCrossentropy(),
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')],
)

print("   Model built and compiled successfully.")

# ==============================================================================
# 4. Calculate and Apply Class Weights
# ==============================================================================
print("\n--> Step 4.4: Calculating class weights to handle data imbalance...")

# Count the training images per class on disk instead of hard-coding them, so
# the weights stay correct if the split produced in Part 3 changes.
# Class indices follow create_dataset's class_names order: Glaucoma=0, Normal=1.
# Only file types that image_dataset_from_directory loads are counted.
LOADER_IMAGE_SUFFIXES = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')


def count_class_images(split_dir, class_name):
    """Return the number of loadable image files directly inside ``split_dir/class_name``."""
    class_dir = os.path.join(split_dir, class_name)
    return sum(
        1 for fname in os.listdir(class_dir)
        if fname.lower().endswith(LOADER_IMAGE_SUFFIXES)
        and os.path.isfile(os.path.join(class_dir, fname))
    )


train_split_dir = os.path.join(BASE_DATA_DIR, "train")
num_glaucoma_train = count_class_images(train_split_dir, "Glaucoma")
num_normal_train = count_class_images(train_split_dir, "Normal")
if num_glaucoma_train == 0 or num_normal_train == 0:
    raise ValueError(
        f"Cannot compute class weights: found {num_glaucoma_train} Glaucoma and "
        f"{num_normal_train} Normal training images in {train_split_dir}")
total_train = num_glaucoma_train + num_normal_train
print(f"   - Training images: Glaucoma={num_glaucoma_train}, Normal={num_normal_train}")

# Balanced weighting: weight_c = total / (num_classes * count_c), with num_classes = 2.
weight_for_0 = (1 / num_glaucoma_train) * (total_train / 2.0)
weight_for_1 = (1 / num_normal_train) * (total_train / 2.0)

class_weights = {0: weight_for_0, 1: weight_for_1}

print(f"   - Calculated Class Weights:")
print(f"     - Weight for Glaucoma (class 0): {class_weights[0]:.4f}")
print(f"     - Weight for Normal (class 1):   {class_weights[1]:.4f}")

# ==============================================================================
# 5. Model Training
# ==============================================================================
print(f"\n--> Step 4.5: Starting training for {MODEL_NAME} with class weights and lower learning rate...")

# Keep the checkpoint with the best validation AUC. Stop after 15 epochs
# without improvement (restoring the best weights), and log every epoch to CSV.
checkpoint_cb = ModelCheckpoint(filepath=BEST_MODEL_SAVE_PATH, monitor='val_auc', mode='max', save_best_only=True, verbose=1)
early_stop_cb = EarlyStopping(monitor='val_auc', mode='max', patience=15, restore_best_weights=True, verbose=1)
csv_logger_cb = CSVLogger(HISTORY_CSV_PATH)
callbacks = [checkpoint_cb, early_stop_cb, csv_logger_cb]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    class_weight=class_weights,
    callbacks=callbacks,
)

print(f"\n\n===== FINAL BENCHMARK TRAINING COMPLETED FOR: {MODEL_NAME} =====")
print(f"The best model has been saved to: {BEST_MODEL_SAVE_PATH}")

--- Starting Benchmark Training Script (Final Version: Weighted + Low LR) ---
--- Objective: Train a stable and unbiased benchmark model ---

--> Step 4.1: Configuring hyperparameters...
   - Training new model: A_Benchmark_Final_EfficientNetB3
   - Using Learning Rate: 0.0001
   - Results will be saved to: /content/drive/MyDrive/glaucoma_project_results/model_A_benchmark_final

--> Step 4.2: Preparing data loaders...
Found 424 files belonging to 2 classes.
Found 40 files belonging to 2 classes.
   Data loaders created successfully.

--> Step 4.3: Building the EfficientNetB3 model...
   Model built and compiled successfully.

--> Step 4.4: Calculating class weights to handle data imbalance...
   - Calculated Class Weights:
     - Weight for Glaucoma (class 0): 1.3766
     - Weight for Normal (class 1):   0.7852

--> Step 4.5: Starting training for A_Benchmark_Final_EfficientNetB3 with class weights and lower learning rate...
Epoch 1/70
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[

# **Training Baseline + SRA | SRA模型训练**

In [None]:
# ==============================================================================
# Part 5 : Training Model B (SRA)
# Part 5 : 训练模型B (SRA)
# ==============================================================================

import os
import tensorflow as tf
import albumentations as A
import numpy as np
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

print("--- Starting SRA Training Script (Final Version) ---")
print("--- Objective: Train the SRA model with the same robust settings as the final benchmark ---")

# ==============================================================================
# 1. Configuration
# ==============================================================================
print("\n--> Step 5.1: Configuring hyperparameters for the SRA model...")
# Same hyperparameter values as the benchmark model in Part 4.
LEARNING_RATE, EPOCHS, BATCH_SIZE, IMG_SIZE = 0.0001, 70, 6, 300

BASE_DATA_DIR = "/content/final_dataset"
# Separate Drive folder for this final SRA run.
GDRIVE_RESULTS_DIR = "/content/drive/MyDrive/glaucoma_project_results/model_B_SRA_final"
MODEL_NAME = "B_SRA_Final_EfficientNetB3"

os.makedirs(GDRIVE_RESULTS_DIR, exist_ok=True)
BEST_MODEL_SAVE_PATH, HISTORY_CSV_PATH = (
    os.path.join(GDRIVE_RESULTS_DIR, fname)
    for fname in (f"{MODEL_NAME}_best_model.keras", f"history_{MODEL_NAME}.csv")
)
print(f"   - Training new model: {MODEL_NAME}")
print(f"   - Results will be saved to: {GDRIVE_RESULTS_DIR}")

# ==============================================================================
# 2. Data Loading with SRA Augmentation
# ==============================================================================
print("\n--> Step 5.2: Preparing data loaders with the SRA strategy...")

# albumentations 2.x no longer accepts GaussNoise(var_limit=...) or
# ImageCompression(quality_lower=..., quality_upper=...). There, those
# arguments are not applied and the transforms run with their default
# parameters instead of the intended ranges. Build each transform with the
# argument names the installed major version understands.
# NOTE(review): confirm against the albumentations version installed in the runtime.
_ALBU_MAJOR = int(A.__version__.split('.')[0])
if _ALBU_MAJOR >= 2:
    # var_limit was a variance range in 0-255 pixel units. std_range is a
    # standard-deviation range expressed as a fraction of the max value (255),
    # so sqrt(5)..sqrt(30) pixels maps to these bounds (approximately equivalent).
    _gauss_noise = A.GaussNoise(std_range=(5.0 ** 0.5 / 255.0, 30.0 ** 0.5 / 255.0), p=1.0)
    _jpeg_compression = A.ImageCompression(quality_range=(75, 95), p=0.5)
else:
    _gauss_noise = A.GaussNoise(var_limit=(5.0, 30.0), p=1.0)
    _jpeg_compression = A.ImageCompression(quality_lower=75, quality_upper=95, p=0.5)

# The SRA strategy: flip, optional blur, optional noise, brightness/contrast
# jitter and JPEG compression, each applied with probability 0.5.
sra_strategy = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.OneOf([
        A.MotionBlur(blur_limit=5, p=1.0),
        A.GaussianBlur(blur_limit=5, p=1.0),
    ], p=0.5),
    A.OneOf([
        _gauss_noise,
        A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0),
    ], p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    _jpeg_compression,
])
AUTOTUNE = tf.data.AUTOTUNE

# Loader helpers adapted for Albumentations. Images are loaded unbatched so the
# NumPy-based augmentation runs on one image at a time.
def create_dataset_unbatched(directory, is_training=True, seed=None):
    """Load an unbatched (image, label) dataset from ``directory/<class>/`` subfolders.

    Labels follow the ``class_names`` order (Glaucoma -> 0, Normal -> 1). Images
    are resized with bilinear interpolation to IMG_SIZE x IMG_SIZE and shuffled
    only when ``is_training`` is true.
    """
    return tf.keras.utils.image_dataset_from_directory(
        directory,
        labels='inferred',
        label_mode='binary',
        class_names=['Glaucoma', 'Normal'],
        image_size=(IMG_SIZE, IMG_SIZE),
        interpolation='bilinear',
        batch_size=None,
        shuffle=is_training,
        seed=seed,
    )


def process_image(image, label):
    """Cast to float32 and apply EfficientNet's ``preprocess_input``; labels pass through."""
    as_float = tf.cast(image, tf.float32)
    return tf.keras.applications.efficientnet.preprocess_input(as_float), label


def aug_fn(image_tensor):
    """Run ``sra_strategy`` on one eager image tensor and return a float32 tensor.

    Pixels are converted to uint8 first; ``astype`` truncates fractional values.
    """
    augmented = sra_strategy(image=image_tensor.numpy().astype(np.uint8))
    return tf.convert_to_tensor(augmented['image'], dtype=tf.float32)


@tf.function
def process_with_sra(image, label):
    """Apply the SRA augmentation through ``tf.py_function``, then the standard preprocessing."""
    augmented = tf.py_function(func=aug_fn, inp=[image], Tout=tf.float32)
    # py_function output has no static shape; restore it for downstream layers.
    augmented.set_shape([IMG_SIZE, IMG_SIZE, 3])
    return process_image(augmented, label)

train_dir, val_dir = (os.path.join(BASE_DATA_DIR, sub) for sub in ("train", "validation"))
train_ds_unbatched = create_dataset_unbatched(train_dir, is_training=True, seed=42)
val_ds_unbatched = create_dataset_unbatched(val_dir, is_training=False)

# SRA is applied to the training set only; validation images are just
# preprocessed. Mapping happens per image, and batching comes afterwards.
train_ds = (
    train_ds_unbatched
    .map(process_with_sra, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)
val_ds = (
    val_ds_unbatched
    .map(process_image, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)
print("   SRA-infused data loaders created successfully.")

# ==============================================================================
# 3. Model Building (identical to the benchmark)
# ==============================================================================
print("\n--> Step 5.3: Building the EfficientNetB3 model (identical architecture)...")


def build_model(input_shape=(IMG_SIZE, IMG_SIZE, 3)):
    """Build an ImageNet-pretrained EfficientNetB3 with one sigmoid output unit.

    The backbone starts frozen. Then every non-BatchNormalization layer among
    its last 20 layers is unfrozen for fine-tuning.
    """
    backbone = EfficientNetB3(include_top=False, weights='imagenet', input_shape=input_shape)
    backbone.trainable = False
    pooled = GlobalAveragePooling2D(name='avg_pool')(backbone.output)
    probability = Dense(1, activation='sigmoid', name='predictions')(pooled)
    model = Model(inputs=backbone.input, outputs=probability)
    for layer in backbone.layers[-20:]:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            continue
        layer.trainable = True
    return model


model = build_model()
model.compile(
    optimizer=Adam(learning_rate=LEARNING_RATE),
    loss=BinaryCrossentropy(),
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')],
)
print("   Model built and compiled successfully.")

# ==============================================================================
# 4. Class Weights (Identical to Benchmark)
# ==============================================================================
print("\n--> Step 5.4: Calculating class weights...")
# Count training images per class from the files the loader actually indexed
# (`file_paths` is attached by image_dataset_from_directory) instead of
# hard-coded totals, so the weights cannot drift from the data. A file's class
# is its first path component below `train_dir`.
train_file_classes = [
    os.path.relpath(path, train_dir).split(os.sep)[0]
    for path in train_ds_unbatched.file_paths
]
num_glaucoma_train = train_file_classes.count('Glaucoma')
num_normal_train = train_file_classes.count('Normal')
if num_glaucoma_train == 0 or num_normal_train == 0:
    raise ValueError(
        f"Cannot compute class weights: found {num_glaucoma_train} Glaucoma and "
        f"{num_normal_train} Normal training images in {train_dir}"
    )
total_train = num_glaucoma_train + num_normal_train
# Balanced weighting, weight_c = total / (2 * count_c). Label 0 is Glaucoma and
# label 1 is Normal, following the class_names order used by the loader.
weight_for_0 = (1 / num_glaucoma_train) * (total_train / 2.0)
weight_for_1 = (1 / num_normal_train) * (total_train / 2.0)
class_weights = {0: weight_for_0, 1: weight_for_1}
print(f"   - Class weights applied: G={class_weights[0]:.2f}, N={class_weights[1]:.2f}")

# ==============================================================================
# 5. Model Training
# ==============================================================================
print(f"\n--> Step 5.5: Starting training for {MODEL_NAME}...")

# Checkpointing and early stopping both track validation AUC (higher is
# better); the CSV logger records the per-epoch history.
callbacks = [
    ModelCheckpoint(
        filepath=BEST_MODEL_SAVE_PATH,
        monitor='val_auc',
        mode='max',
        save_best_only=True,
        verbose=1,
    ),
    EarlyStopping(
        monitor='val_auc',
        mode='max',
        patience=15,
        restore_best_weights=True,
        verbose=1,
    ),
    CSVLogger(HISTORY_CSV_PATH),
]

history = model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds,
                    callbacks=callbacks, class_weight=class_weights)

print(f"\n\n===== SRA TRAINING COMPLETED FOR: {MODEL_NAME} =====")
print(f"The best SRA model has been saved to: {BEST_MODEL_SAVE_PATH}")

--- Starting SRA Training Script (Final Version) ---
--- Objective: Train the SRA model with the same robust settings as the final benchmark ---

--> Step 5.1: Configuring hyperparameters for the SRA model...
   - Training new model: B_SRA_Final_EfficientNetB3
   - Results will be saved to: /content/drive/MyDrive/glaucoma_project_results/model_B_SRA_final

--> Step 5.2: Preparing data loaders with the SRA strategy...
Found 424 files belonging to 2 classes.
Found 40 files belonging to 2 classes.
   SRA-infused data loaders created successfully.

--> Step 5.3: Building the EfficientNetB3 model (identical architecture)...


  A.GaussNoise(var_limit=(5.0, 30.0), p=1.0),
  A.ImageCompression(quality_lower=75, quality_upper=95, p=0.5)


   Model built and compiled successfully.

--> Step 5.4: Calculating class weights...
   - Class weights applied: G=1.38, N=0.79

--> Step 5.5: Starting training for B_SRA_Final_EfficientNetB3...
Epoch 1/70
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 448ms/step - accuracy: 0.7532 - auc: 0.7029 - loss: 0.6115
Epoch 1: val_auc improved from -inf to 0.80907, saving model to /content/drive/MyDrive/glaucoma_project_results/model_B_SRA_final/B_SRA_Final_EfficientNetB3_best_model.keras
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 771ms/step - accuracy: 0.7534 - auc: 0.7037 - loss: 0.6109 - val_accuracy: 0.7250 - val_auc: 0.8091 - val_loss: 0.5758
Epoch 2/70
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 158ms/step - accuracy: 0.8258 - auc: 0.8630 - loss: 0.4489
Epoch 2: val_auc improved from 0.80907 to 0.85165, saving model to /content/drive/MyDrive/glaucoma_project_results/model_B_SRA_final/B_SRA_Final_EfficientNetB3_best_model.keras


# **Train Model C: Baseline + SRA + Attention | 训练模型C (SRA + Attention)**

In [None]:
# ==============================================================================
# Part 6 : Training Model C (SRA + Attention)
# Part 6 : 训练模型C (SRA + Attention)
# ==============================================================================

import os
import sys
import tensorflow as tf
import albumentations as A
import numpy as np
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

# --- Dependency Check for CBAM ---
# shutil is imported explicitly so this cell does not rely on an earlier cell
# having imported it (the cell's own import list only covers os and sys).
import shutil

print("--- Checking for the essential 'cbam.py' module... ---")
CBAM_FILE_PATH = '/content/drive/MyDrive/cbam.py'
if not os.path.exists(CBAM_FILE_PATH):
    sys.exit(f"\n[FATAL ERROR] cbam.py not found at: {CBAM_FILE_PATH}")
else:
    # Copy the module into the working directory (presumably on sys.path)
    # so that it can be imported by name.
    shutil.copy(CBAM_FILE_PATH, os.path.join(os.getcwd(), 'cbam.py'))
    from cbam import cbam_block
    print("--- Successfully loaded the CBAM module. ---")

print("\n--- Starting SRA + Attention Training Script (Final Version) ---")
print("--- Objective: Train the ultimate champion with both SRA and Attention ---")

# ==============================================================================
# 1. Configuration
# ==============================================================================
print("\n--> Step 6.1: Configuring hyperparameters for the Attention model...")
# Same hyperparameters as the SRA model, so only the architecture differs.
LEARNING_RATE = 0.0001  # Adam learning rate
EPOCHS = 70             # upper bound; early stopping may end training sooner
BATCH_SIZE = 6
IMG_SIZE = 300          # square input resolution, in pixels

BASE_DATA_DIR = "/content/final_dataset"
MODEL_NAME = "C_SRA_Attention_Final_EfficientNetB3"
# A dedicated results folder for the final version of this model.
GDRIVE_RESULTS_DIR = "/content/drive/MyDrive/glaucoma_project_results/model_C_SRA_Attention_final"
os.makedirs(GDRIVE_RESULTS_DIR, exist_ok=True)

# Checkpoint and history file names are derived from MODEL_NAME.
BEST_MODEL_SAVE_PATH = os.path.join(GDRIVE_RESULTS_DIR, f"{MODEL_NAME}_best_model.keras")
HISTORY_CSV_PATH = os.path.join(GDRIVE_RESULTS_DIR, f"history_{MODEL_NAME}.csv")
print(f"   - Training new model: {MODEL_NAME}")
print(f"   - Results will be saved to: {GDRIVE_RESULTS_DIR}")

# ==============================================================================
# 2. Data Loading with SRA Augmentation (Identical to Block 5)
# ==============================================================================
print("\n--> Step 6.2: Preparing SRA-infused data loaders...")
# The same SRA augmentation pipeline as the SRA-only model, so the attention
# block is the only difference between the two models.
#
# NOTE: GaussNoise(var_limit=...) and ImageCompression(quality_lower=...,
# quality_upper=...) emitted warnings when this notebook was run. Those
# arguments are deprecated or removed in recent Albumentations releases;
# depending on the release they are either converted or ignored, and an
# ignored argument leaves the transform at its default strength. The current
# argument names are used here with the same intended ranges:
#   - noise variance 5..30 (pixel^2) -> standard deviation sqrt(5)..sqrt(30)
#     pixels, passed to `std_range` as a fraction of 255 (approximately
#     equivalent: the std rather than the variance is now sampled);
#   - JPEG quality 75..95 -> `quality_range=(75, 95)`.
# The SRA model's pipeline should use the same arguments so the comparison
# between the two models stays like-for-like.
sra_strategy = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.OneOf([A.MotionBlur(blur_limit=5, p=1.0), A.GaussianBlur(blur_limit=5, p=1.0)], p=0.5),
    A.OneOf([
        A.GaussNoise(std_range=(5.0 ** 0.5 / 255, 30.0 ** 0.5 / 255), p=1.0),
        A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0),
    ], p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.ImageCompression(quality_range=(75, 95), p=0.5),
])
AUTOTUNE = tf.data.AUTOTUNE

def create_dataset_unbatched(directory, is_training=True, seed=None):
    """Build an unbatched (image, label) dataset from a class-per-folder directory.

    Labels are inferred from the sub-folder names in the fixed order
    ['Glaucoma', 'Normal'] (Glaucoma -> 0, Normal -> 1) as binary labels.
    Images are resized to (IMG_SIZE, IMG_SIZE) with bilinear interpolation,
    and files are shuffled only when ``is_training`` is true.

    Args:
        directory: Root folder containing 'Glaucoma' and 'Normal' sub-folders.
        is_training: Whether to shuffle the files.
        seed: Optional shuffle seed.

    Returns:
        An unbatched ``tf.data.Dataset`` of ``(image, label)`` pairs.
    """
    loader_options = dict(
        labels='inferred',
        label_mode='binary',
        class_names=['Glaucoma', 'Normal'],
        image_size=(IMG_SIZE, IMG_SIZE),
        interpolation='bilinear',
        batch_size=None,
        shuffle=is_training,
        seed=seed,
    )
    return tf.keras.utils.image_dataset_from_directory(directory, **loader_options)

def process_image(image, label):
    """Cast an image to float32 and apply EfficientNet input preprocessing.

    Args:
        image: Image tensor of any numeric dtype.
        label: Label tensor, returned unchanged.

    Returns:
        Tuple ``(preprocessed_image, label)``.
    """
    preprocess = tf.keras.applications.efficientnet.preprocess_input
    return preprocess(tf.cast(image, tf.float32)), label

def aug_fn(image_tensor):
    """Run the ``sra_strategy`` Albumentations pipeline on a single image.

    Intended to be called eagerly through ``tf.py_function`` so the tensor can
    be converted to NumPy. Pixel values are cast to uint8 (fractional parts
    are truncated) before augmentation.

    Args:
        image_tensor: Eager tensor holding one image.

    Returns:
        The augmented image as a float32 tensor.
    """
    pixels = image_tensor.numpy().astype(np.uint8)
    augmented = sra_strategy(image=pixels)['image']
    return tf.convert_to_tensor(augmented, dtype=tf.float32)

@tf.function
def process_with_sra(image, label):
    """Augment one training image with SRA, then apply ``process_image``.

    The NumPy-based augmentation runs inside ``tf.py_function``, whose output
    carries no static shape, so the (IMG_SIZE, IMG_SIZE, 3) shape is restored
    before preprocessing.
    """
    augmented = tf.py_function(func=aug_fn, inp=[image], Tout=tf.float32)
    augmented.set_shape((IMG_SIZE, IMG_SIZE, 3))
    return process_image(augmented, label)

# Split directories under the prepared dataset root.
train_dir = os.path.join(BASE_DATA_DIR, "train")
val_dir = os.path.join(BASE_DATA_DIR, "validation")

# Training files are shuffled with a fixed seed; validation keeps file order.
train_ds_unbatched = create_dataset_unbatched(train_dir, is_training=True, seed=42)
val_ds_unbatched = create_dataset_unbatched(val_dir, is_training=False)

# SRA augmentation is applied to the training split only; validation stays clean.
train_ds = (
    train_ds_unbatched
    .map(process_with_sra, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)
val_ds = (
    val_ds_unbatched
    .map(process_image, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)
print("   SRA-infused data loaders are ready.")

# ==============================================================================
# 3. Model Building with Attention
# ==============================================================================
print("\n--> Step 6.3: Building the Attention-Infused EfficientNetB3 model...")


def build_attention_model(input_shape=(IMG_SIZE, IMG_SIZE, 3)):
    """Build the (uncompiled) EfficientNetB3 + CBAM binary classifier.

    Architecture: ImageNet-initialised EfficientNetB3 backbone without its
    top, a CBAM block applied to the backbone's final feature map, global
    average pooling, and a single sigmoid unit. The backbone is frozen first,
    then its last 20 layers are made trainable for fine-tuning, except
    BatchNormalization layers.

    Args:
        input_shape: Input image shape (height, width, channels).

    Returns:
        A ``tf.keras.Model`` mapping images to a probability in [0, 1].
    """
    image_input = Input(shape=input_shape)
    backbone = EfficientNetB3(include_top=False, weights='imagenet', input_tensor=image_input)
    backbone.trainable = False

    # CBAM attention sits directly on the backbone features, before pooling.
    attended = cbam_block(backbone.output)
    pooled = GlobalAveragePooling2D(name='avg_pool')(attended)
    probability = Dense(1, activation='sigmoid', name='predictions')(pooled)
    model = Model(inputs=image_input, outputs=probability)

    # Fine-tune the top of the backbone; BatchNormalization layers stay frozen.
    for layer in backbone.layers[-20:]:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            continue
        layer.trainable = True
    return model


model = build_attention_model()
model.compile(
    optimizer=Adam(learning_rate=LEARNING_RATE),
    loss=BinaryCrossentropy(),
    # Named 'auc' so the callbacks can monitor 'val_auc'.
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')],
)
model.summary()
print("   Attention-infused model built and compiled successfully.")

# ==============================================================================
# 4. Class Weights (Identical to Benchmark & SRA)
# ==============================================================================
print("\n--> Step 6.4: Calculating class weights...")
# Count training images per class from the files the loader actually indexed
# (`file_paths` is attached by image_dataset_from_directory) instead of
# hard-coded totals, so the weights cannot drift from the data. A file's class
# is its first path component below `train_dir`.
train_file_classes = [
    os.path.relpath(path, train_dir).split(os.sep)[0]
    for path in train_ds_unbatched.file_paths
]
num_glaucoma_train = train_file_classes.count('Glaucoma')
num_normal_train = train_file_classes.count('Normal')
if num_glaucoma_train == 0 or num_normal_train == 0:
    raise ValueError(
        f"Cannot compute class weights: found {num_glaucoma_train} Glaucoma and "
        f"{num_normal_train} Normal training images in {train_dir}"
    )
total_train = num_glaucoma_train + num_normal_train
# Balanced weighting, weight_c = total / (2 * count_c). Label 0 is Glaucoma and
# label 1 is Normal, following the class_names order used by the loader.
weight_for_0 = (1 / num_glaucoma_train) * (total_train / 2.0)
weight_for_1 = (1 / num_normal_train) * (total_train / 2.0)
class_weights = {0: weight_for_0, 1: weight_for_1}
print(f"   - Class weights applied: G={class_weights[0]:.2f}, N={class_weights[1]:.2f}")

# ==============================================================================
# 5. Model Training
# ==============================================================================
print(f"\n--> Step 6.5: Starting training for the ultimate champion: {MODEL_NAME}...")

# Checkpointing and early stopping both track validation AUC (higher is
# better); the CSV logger records the per-epoch history.
callbacks = [
    ModelCheckpoint(
        filepath=BEST_MODEL_SAVE_PATH,
        monitor='val_auc',
        mode='max',
        save_best_only=True,
        verbose=1,
    ),
    EarlyStopping(
        monitor='val_auc',
        mode='max',
        patience=15,
        restore_best_weights=True,
        verbose=1,
    ),
    CSVLogger(HISTORY_CSV_PATH),
]

history = model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds,
                    callbacks=callbacks, class_weight=class_weights)

print(f"\n\n===== SRA + ATTENTION TRAINING COMPLETED FOR: {MODEL_NAME} =====")
print(f"The ultimate champion model has been saved to: {BEST_MODEL_SAVE_PATH}")

--- Checking for the essential 'cbam.py' module... ---
--- Successfully loaded the CBAM module. ---

--- Starting SRA + Attention Training Script (Final Version) ---
--- Objective: Train the ultimate champion with both SRA and Attention ---

--> Step 6.1: Configuring hyperparameters for the Attention model...
   - Training new model: C_SRA_Attention_Final_EfficientNetB3
   - Results will be saved to: /content/drive/MyDrive/glaucoma_project_results/model_C_SRA_Attention_final

--> Step 6.2: Preparing SRA-infused data loaders...
Found 424 files belonging to 2 classes.
Found 40 files belonging to 2 classes.


  A.OneOf([A.GaussNoise(var_limit=(5.0, 30.0), p=1.0), A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0)], p=0.5),
  A.ImageCompression(quality_lower=75, quality_upper=95, p=0.5)


   SRA-infused data loaders are ready.

--> Step 6.3: Building the Attention-Infused EfficientNetB3 model...


   Attention-infused model built and compiled successfully.

--> Step 6.4: Calculating class weights...
   - Class weights applied: G=1.38, N=0.79

--> Step 6.5: Starting training for the ultimate champion: C_SRA_Attention_Final_EfficientNetB3...
Epoch 1/70
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 464ms/step - accuracy: 0.6707 - auc: 0.6881 - loss: 0.6477
Epoch 1: val_auc improved from -inf to 0.82143, saving model to /content/drive/MyDrive/glaucoma_project_results/model_C_SRA_Attention_final/C_SRA_Attention_Final_EfficientNetB3_best_model.keras
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m101s[0m 812ms/step - accuracy: 0.6716 - auc: 0.6891 - loss: 0.6474 - val_accuracy: 0.7250 - val_auc: 0.8214 - val_loss: 0.5711
Epoch 2/70
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 156ms/step - accuracy: 0.8286 - auc: 0.8271 - loss: 0.5015
Epoch 2: val_auc improved from 0.82143 to 0.83379, saving model to /content/drive/MyDrive/glaucoma_project

#  **Evaluate | 测试评估**

**Create Degraded Test Sets | 降质测试集生成**

In [None]:
# ==============================================================================
# Part 7: Creating Degraded Test Sets for Robustness Evaluation
# Part 7: 创建降质测试集用于鲁棒性评估
# ==============================================================================
import os
import shutil
from pathlib import Path
import albumentations as A
import cv2
from tqdm import tqdm

print("--- Starting creation of the degraded test sets (from TRUE in-domain test set) ---")

# --- Configuration ---
# The source is the clean test split of the final dataset.
SOURCE_TEST_DIR = Path("/content/final_dataset/test")
DEGRADED_TEST_SETS_DIR = Path("/content/degraded_test_sets")

print(f"Source of clean test data: {SOURCE_TEST_DIR}")
print(f"Target directory for degraded test sets: {DEGRADED_TEST_SETS_DIR}")

# --- Degradation strategies ---
# Every transform uses p=1.0 so that EVERY image of a degraded set is actually
# degraded. Do not use `always_apply=True` here: it emitted warnings when this
# notebook was run, because it is deprecated or removed in recent
# Albumentations releases. Where it is ignored, the transform keeps its default
# probability (typically 0.5), which would leave part of a "degraded" set
# untouched.
# GaussNoise and ImageCompression use the current argument names with the same
# intended strength:
#   - noise variance 5..30 (pixel^2) -> std sqrt(5)..sqrt(30) pixels, given as
#     a fraction of 255;
#   - a fixed JPEG quality of 75.
degradation_strategies = {
    "Blur": A.Compose([A.GaussianBlur(blur_limit=(5, 5), p=1.0)]),
    "Noise": A.Compose([A.GaussNoise(std_range=(5.0 ** 0.5 / 255, 30.0 ** 0.5 / 255), p=1.0)]),
    "Lighting": A.Compose([A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1.0)]),
    "Artifacts": A.Compose([A.ImageCompression(quality_range=(75, 75), p=1.0)]),
}
print(f"Defined {len(degradation_strategies)} degradation types.")

# Start from an empty output directory so files from earlier runs cannot mix in.
if DEGRADED_TEST_SETS_DIR.exists():
    shutil.rmtree(DEGRADED_TEST_SETS_DIR)
DEGRADED_TEST_SETS_DIR.mkdir(parents=True)

# --- Create Degraded Sets ---
print("\n--> Creating degraded test sets...")
# The clean reference set is an unmodified copy of the source test set.
clean_target_dir = DEGRADED_TEST_SETS_DIR / "test_Clean"
shutil.copytree(SOURCE_TEST_DIR, clean_target_dir)
print(f"   - Copied 'test_Clean' set to: {clean_target_dir}")

# NOTE(review): degraded images are re-encoded by cv2.imwrite in the source
# file's format, while test_Clean is copied unchanged. For lossy formats
# (e.g. .jpg) every degraded image therefore gets one extra encode; confirm
# this is acceptable for the clean-vs-degraded comparison.
for name, strategy in degradation_strategies.items():
    degraded_dir = DEGRADED_TEST_SETS_DIR / f"test_{name}"
    print(f"   - Creating '{degraded_dir.name}' set...")
    for class_name in ["Glaucoma", "Normal"]:
        source_class_dir = SOURCE_TEST_DIR / class_name
        target_class_dir = degraded_dir / class_name
        target_class_dir.mkdir(parents=True, exist_ok=True)
        # Sorted so the processing order does not depend on the file system.
        image_files = sorted(source_class_dir.glob("*.*"))
        for img_path in tqdm(image_files, desc=f"     Processing {class_name}", unit="img"):
            target_img_path = target_class_dir / img_path.name
            try:
                image = cv2.imread(str(img_path))
                # cv2.imread returns None instead of raising for unreadable or
                # non-image files (e.g. stray metadata files matched by "*.*").
                if image is None:
                    raise ValueError("cv2.imread could not read the file as an image")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                augmented_image = strategy(image=image)['image']
                # cv2.imwrite reports failure only through its return value.
                if not cv2.imwrite(str(target_img_path), cv2.cvtColor(augmented_image, cv2.COLOR_RGB2BGR)):
                    raise OSError(f"cv2.imwrite could not write {target_img_path}")
            except Exception as e:
                # Best effort: report the failing file and continue with the rest.
                print(f"     [ERROR] Failed to process {img_path}: {e}")

print("\n\n--- Degraded Test Set Creation Completed ---")
print(f"All test sets (Clean + Degraded) are now ready in: {DEGRADED_TEST_SETS_DIR}")

  "Blur": A.Compose([A.GaussianBlur(blur_limit=(5, 5), always_apply=True)]),
  "Noise": A.Compose([A.GaussNoise(var_limit=(5.0, 30.0), always_apply=True)]),
  "Lighting": A.Compose([A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, always_apply=True)]),
  "Artifacts": A.Compose([A.ImageCompression(quality_lower=75, quality_upper=75, always_apply=True)])


--- Starting creation of the degraded test sets (from TRUE in-domain test set) ---
Source of clean test data: /content/final_dataset/test
Target directory for degraded test sets: /content/degraded_test_sets
Defined 4 degradation types.

--> Creating degraded test sets...
   - Copied 'test_Clean' set to: /content/degraded_test_sets/test_Clean
   - Creating 'test_Blur' set...


     Processing Glaucoma: 100%|██████████| 42/42 [00:01<00:00, 31.16img/s]
     Processing Normal: 100%|██████████| 73/73 [00:05<00:00, 13.40img/s]


   - Creating 'test_Noise' set...


     Processing Glaucoma: 100%|██████████| 42/42 [00:04<00:00,  9.54img/s]
     Processing Normal: 100%|██████████| 73/73 [00:16<00:00,  4.44img/s]


   - Creating 'test_Lighting' set...


     Processing Glaucoma: 100%|██████████| 42/42 [00:01<00:00, 35.99img/s]
     Processing Normal: 100%|██████████| 73/73 [00:05<00:00, 14.09img/s]


   - Creating 'test_Artifacts' set...


     Processing Glaucoma: 100%|██████████| 42/42 [00:01<00:00, 23.72img/s]
     Processing Normal: 100%|██████████| 73/73 [00:05<00:00, 13.65img/s]



--- Degraded Test Set Creation Completed ---
All test sets (Clean + Degraded) are now ready in: /content/degraded_test_sets





**General Evaluation | 普通评估**

In [None]:
# =============================================================================
# Part 9: Evaluation
# Part 9: 评估
# =============================================================================

import os
import sys
import tensorflow as tf
import numpy as np
import pandas as pd
import time
import shutil
from tensorflow.keras.models import load_model
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from tabulate import tabulate
from tqdm import tqdm

# --- Dependency Check for CBAM ---
print("--- Checking for 'cbam.py' module for model loading... ---")
CBAM_FILE_PATH = '/content/drive/MyDrive/cbam.py'
custom_objects_dict = {}

if not os.path.exists(CBAM_FILE_PATH):
    print(f"\n[WARNING] 'cbam.py' not found at {CBAM_FILE_PATH}. Attention model will fail to load.")
else:
    # Copy the module into the working directory (presumably on sys.path) so it
    # can be imported, then collect its custom components for load_model.
    shutil.copy(CBAM_FILE_PATH, os.path.join(os.getcwd(), 'cbam.py'))
    from cbam import cbam_block, MaxAcrossChannel, MeanAcrossChannel

    custom_objects_dict = dict(
        cbam_block=cbam_block,
        MaxAcrossChannel=MaxAcrossChannel,
        MeanAcrossChannel=MeanAcrossChannel,
    )
    print("--- Successfully loaded the CBAM module and prepared custom objects. ---")

print("\n--- Starting FINAL, COMPREHENSIVE Evaluation Platform (V2) - CORRECTED ---")

# ==============================================================================
# 1. Configuration: Define Models and Test Sets
# ==============================================================================
print("\n--> Step 9.1: Configuring models and test sets...")

# Input settings; these must match the values used when training the models.
IMG_SIZE = 300
AUTOTUNE = tf.data.AUTOTUNE

# Report name -> best checkpoint saved by the corresponding training script.
MODELS_TO_EVALUATE = {
    "A_Benchmark_Final":
        "/content/drive/MyDrive/glaucoma_project_results/model_A_benchmark_final/A_Benchmark_Final_EfficientNetB3_best_model.keras",
    "B_SRA_Final":
        "/content/drive/MyDrive/glaucoma_project_results/model_B_SRA_final/B_SRA_Final_EfficientNetB3_best_model.keras",
    "C_SRA_Attention_Final":
        "/content/drive/MyDrive/glaucoma_project_results/model_C_SRA_Attention_final/C_SRA_Attention_Final_EfficientNetB3_best_model.keras",
}

# Test-set name -> directory expected to contain 'Glaucoma' and 'Normal' sub-folders.
TEST_SETS = {
    # In-domain robustness tests: the clean copy plus the degraded variants.
    "In-Domain_Clean": "/content/degraded_test_sets/test_Clean",
    "In-Domain_Blur": "/content/degraded_test_sets/test_Blur",
    "In-Domain_Noise": "/content/degraded_test_sets/test_Noise",
    "In-Domain_Lighting": "/content/degraded_test_sets/test_Lighting",
    "In-Domain_Artifacts": "/content/degraded_test_sets/test_Artifacts",
    # Cross-domain generalization tests.
    "Cross-Domain_Drishti-GS1": "/content/organized_datasets/Drishti-GS1",
    "Cross-Domain_RIM-ONE-r3": "/content/organized_datasets/RIM-ONE-r3",
}

# Every run writes a new, timestamped CSV report.
GDRIVE_RESULTS_DIR = "/content/drive/MyDrive/glaucoma_project_results/final_comprehensive_report"
os.makedirs(GDRIVE_RESULTS_DIR, exist_ok=True)
timestamp = time.strftime("%Y%m%d-%H%M%S")
OUTPUT_CSV_PATH = os.path.join(GDRIVE_RESULTS_DIR, f"comprehensive_report_{timestamp}.csv")

# ==============================================================================
# 2. Core Evaluation Functions
# ==============================================================================
print("\n--> Step 9.2: Preparing the comprehensive evaluation engine...")


def create_test_dataset(directory):
    """Load a test directory as a batched, preprocessed, unshuffled dataset.

    The directory must contain 'Glaucoma' and 'Normal' sub-folders; labels are
    0 for Glaucoma and 1 for Normal. Images are resized to (IMG_SIZE, IMG_SIZE),
    cast to float32 and passed through EfficientNet's ``preprocess_input``.
    File order is fixed (no shuffling), so labels and predictions line up.

    Args:
        directory: Test-set root directory.

    Returns:
        A ``tf.data.Dataset`` of ``(images, labels)`` batches of up to 32, or
        ``None`` (after printing a warning) if a class sub-folder is missing.
    """
    class_dirs = ("Glaucoma", "Normal")
    if not all(os.path.isdir(os.path.join(directory, sub)) for sub in class_dirs):
        print(f"   [WARNING] Skipping '{os.path.basename(directory)}': missing class subdirectories.")
        return None

    raw_batches = tf.keras.utils.image_dataset_from_directory(
        directory,
        labels='inferred',
        label_mode='binary',
        class_names=['Glaucoma', 'Normal'],
        image_size=(IMG_SIZE, IMG_SIZE),
        interpolation='bilinear',
        batch_size=32,
        shuffle=False,
    )

    def _to_model_input(images, labels):
        float_images = tf.cast(images, tf.float32)
        return tf.keras.applications.efficientnet.preprocess_input(float_images), labels

    return raw_batches.map(_to_model_input, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

def find_optimal_threshold(model, val_dir):
    """Pick the decision threshold that maximises Glaucoma F1 on validation data.

    The model's sigmoid output is treated as the probability of label 1
    (Normal); an image counts as Normal when it exceeds the threshold.
    Candidates form a 0.01-step grid from 0.01 up to (excluding) 1.0, and the
    first candidate reaching the highest Glaucoma (label 0) F1 wins.

    Args:
        model: Trained Keras model with one sigmoid output.
        val_dir: Clean validation directory with class sub-folders.

    Returns:
        The chosen threshold; 0.5 if the validation set cannot be loaded or
        no candidate achieves a positive Glaucoma F1.
    """
    print("   -   Finding optimal threshold on clean validation set...")
    val_ds = create_test_dataset(val_dir)
    if val_ds is None:
        return 0.5

    labels = np.concatenate([y for _, y in val_ds], axis=0).flatten()
    normal_probs = model.predict(val_ds, verbose=0).flatten()

    best_threshold, best_f1_g = 0.5, 0
    for candidate in np.arange(0.01, 1.0, 0.01):
        predictions = (normal_probs > candidate).astype(int)
        # pos_label=0 scores the Glaucoma class.
        score = f1_score(labels, predictions, pos_label=0, zero_division=0)
        if score > best_f1_g:
            best_threshold, best_f1_g = candidate, score
    print(f"   -   Optimal threshold found: {best_threshold:.4f} (Max Glaucoma F1 on Val: {best_f1_g:.4f})")
    return best_threshold

def evaluate_model(model, dataset_path, optimal_threshold):
    """Compute classification metrics for ``model`` on one test directory.

    Labels follow the loader's class order: 0 = Glaucoma, 1 = Normal. The
    model's single sigmoid output is the probability of label 1 (Normal); an
    image is predicted Normal when that probability is strictly greater than
    ``optimal_threshold``.

    Args:
        model: Trained Keras model with one sigmoid output.
        dataset_path: Directory with 'Glaucoma' and 'Normal' sub-folders.
        optimal_threshold: Decision threshold on the Normal probability.

    Returns:
        Dict with accuracy, ROC AUC and per-class F1/recall/precision. Every
        value is NaN when the dataset cannot be built; AUC alone is NaN when
        the test set contains only one class.
    """
    metric_names = [
        'Test Accuracy', 'AUC',
        'F1-Score(G)', 'Recall(G)', 'Precision(G)',
        'F1-Score(N)', 'Recall(N)', 'Precision(N)',
    ]
    dataset = create_test_dataset(dataset_path)
    if dataset is None:
        return {metric: np.nan for metric in metric_names}

    # The dataset is not shuffled, so labels and predictions are aligned.
    y_true = np.concatenate([y for x, y in dataset], axis=0).flatten()
    y_pred_probs = model.predict(dataset, verbose=0).flatten()
    y_pred_binary = (y_pred_probs > optimal_threshold).astype(int)

    # roc_auc_score treats label 1 (Normal) as the positive class, so the score
    # must be P(Normal) == y_pred_probs. Scoring with 1 - y_pred_probs ranks in
    # the opposite direction and yields 1 - AUC. ROC AUC is symmetric, so this
    # value equals the Glaucoma-positive AUC (labels 1 - y_true scored with
    # 1 - y_pred_probs).
    if len(np.unique(y_true)) > 1:
        auc = roc_auc_score(y_true, y_pred_probs)
    else:
        auc = np.nan

    return {
        'Test Accuracy': accuracy_score(y_true, y_pred_binary),
        'AUC': auc,

        # Glaucoma (class 0) metrics
        'F1-Score(G)': f1_score(y_true, y_pred_binary, pos_label=0, zero_division=0),
        'Recall(G)': recall_score(y_true, y_pred_binary, pos_label=0, zero_division=0),
        'Precision(G)': precision_score(y_true, y_pred_binary, pos_label=0, zero_division=0),

        # Normal (class 1) metrics
        'F1-Score(N)': f1_score(y_true, y_pred_binary, pos_label=1, zero_division=0),
        'Recall(N)': recall_score(y_true, y_pred_binary, pos_label=1, zero_division=0),
        'Precision(N)': precision_score(y_true, y_pred_binary, pos_label=1, zero_division=0),
    }

# ==============================================================================
# 3. Run the Evaluation Gauntlet
# ==============================================================================
print("\n--> Step 9.3: Starting the final, comprehensive evaluation...")

CLEAN_VAL_DIR = "/content/final_dataset/validation"
all_results = []

for model_name, model_path in MODELS_TO_EVALUATE.items():
    print(f"\n===== Evaluating Model: {model_name} =====")

    # Missing checkpoints and load failures are reported and skipped, so the
    # remaining models are still evaluated.
    if not os.path.exists(model_path):
        print(f"   [ERROR] Model file not found at {model_path}. Skipping.")
        continue
    try:
        # custom_objects supplies the CBAM components needed by the attention model.
        model = load_model(model_path, custom_objects=custom_objects_dict)
    except Exception as e:
        print(f"   [ERROR] Failed to load model {model_name}: {e}")
        continue
    print("   -   Model loaded successfully.")

    # One threshold per model, tuned on the clean validation split, is reused
    # unchanged for every test set.
    optimal_threshold = find_optimal_threshold(model, CLEAN_VAL_DIR)

    for test_name, test_path in TEST_SETS.items():
        print(f"   -   Testing on: '{test_name}'...")
        metrics = evaluate_model(model, test_path, optimal_threshold)
        result_row = {'Model': model_name, 'Test Set': test_name, 'Threshold': optimal_threshold}
        result_row.update(metrics)
        all_results.append(result_row)

# ==============================================================================
# 4. Final Report
# ==============================================================================
print("\n\n===== FINAL COMPREHENSIVE REPORT =====")
if not all_results:
    print("No results were generated. Please check for errors in the evaluation loop.")
else:
    # A fixed column order is used for both the console table and the CSV.
    column_order = [
        'Model', 'Test Set', 'Test Accuracy', 'AUC',
        'F1-Score(G)', 'Recall(G)', 'Precision(G)',
        'F1-Score(N)', 'Recall(N)', 'Precision(N)',
        'Threshold',
    ]
    results_df = pd.DataFrame(all_results)[column_order]

    print(tabulate(results_df, headers='keys', tablefmt='psql', showindex=False, floatfmt=".4f"))

    results_df.to_csv(OUTPUT_CSV_PATH, index=False)
    print(f"\n--- Detailed comprehensive report saved to CSV: {OUTPUT_CSV_PATH} ---")

print("\n\n===== EVALUATION PLATFORM FINISHED =====")

--- Checking for 'cbam.py' module for model loading... ---
--- Successfully loaded the CBAM module and prepared custom objects. ---

--- Starting FINAL, COMPREHENSIVE Evaluation Platform (V2) - CORRECTED ---

--> Step 9.1: Configuring models and test sets...

--> Step 9.2: Preparing the comprehensive evaluation engine...

--> Step 9.3: Starting the final, comprehensive evaluation...

===== Evaluating Model: A_Benchmark_Final =====
   -   Model loaded successfully.
   -   Finding optimal threshold on clean validation set...
Found 40 files belonging to 2 classes.
   -   Optimal threshold found: 0.0300 (Max Glaucoma F1 on Val: 0.9286)
   -   Testing on: 'In-Domain_Clean'...
Found 115 files belonging to 2 classes.
   -   Testing on: 'In-Domain_Blur'...
Found 115 files belonging to 2 classes.
   -   Testing on: 'In-Domain_Noise'...
Found 115 files belonging to 2 classes.
   -   Testing on: 'In-Domain_Lighting'...
Found 115 files belonging to 2 classes.
   -   Testing on: 'In-Domain_Artifact