## Attention-based ODIR Classifier (Hypertension-Priority)

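This notebook prioritizes Hypertension labeling: if a sample's diagnoses include Hypertension, it is labeled Hypertension (H), even when other target diagnoses are also present. Every other sample receives exactly one label from G, C, A, M; when several of these co-occur, the first in that order is used, and samples with none of the five target diagnoses are dropped.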

- Backbone: EfficientNetB0 (optionally B3)
- Attention: CBAM on CNN feature maps
- Split: stratified train/val/test split, re-drawn until all 5 classes appear in both val and test (falls back to a best-effort split with a warning)
- Metrics: accuracy, weighted/macro F1, ROC-AUC (OvR), PR-AUC
- Outputs saved to `/kaggle/working/`.


In [None]:
import os
import re
import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from tensorflow.keras.applications import EfficientNetB0, EfficientNetB3
from tensorflow.keras.applications.efficientnet import preprocess_input as effnet_preprocess

# Config
# Root of the Kaggle dataset input; the ODIR-5K folders are resolved relative to it.
DATA_DIR = "/kaggle/input/ocular-disease-recognition-odir5k"
# Checkpoints and figures are written here.
OUTPUT_DIR = "/kaggle/working"
# Images are resized to IMAGE_SIZE x IMAGE_SIZE pixels and the model input uses the same size.
IMAGE_SIZE = 224
# "b3" selects EfficientNetB3; any other value selects EfficientNetB0.
BACKBONE = "b0"
# Batch size used for training and for test-time prediction.
BATCH_SIZE = 16
# Upper bound on training epochs (early stopping may end training sooner).
EPOCHS = 40
# Seed passed to tf.keras.utils.set_random_seed and used as the default for the stratified split.
SEED = 42

# Optional dependency: tensorflow_addons provides the F1 metric added during training.
# USE_TFA records whether the import succeeded so later cells can skip that metric.
USE_TFA = False
try:
    import tensorflow_addons as tfa
except Exception as exc:
    # Deliberately broad: the add-on is optional, so any import-time failure must
    # not abort the notebook. Report it instead of hiding it.
    print(f"tensorflow_addons unavailable ({type(exc).__name__}: {exc}); F1 metric will be skipped.")
else:
    USE_TFA = True

# Seed + mixed precision
tf.keras.utils.set_random_seed(SEED)
try:
    from tensorflow.keras import mixed_precision
    mixed_precision.set_global_policy('mixed_float16')
except Exception as exc:
    # Best-effort: training still works without mixed precision, but say so,
    # because the numeric policy affects speed and memory use.
    print(f"Mixed precision not enabled ({type(exc).__name__}: {exc}); keeping the current dtype policy.")

# Augmentation
# Random flip / rotation / zoom / contrast pipeline, applied inside the model graph.
data_augmentation = tf.keras.Sequential(name="augment")
data_augmentation.add(layers.RandomFlip("horizontal"))
data_augmentation.add(layers.RandomRotation(0.05))
data_augmentation.add(layers.RandomZoom(0.1))
data_augmentation.add(layers.RandomContrast(0.1))

# CBAM
@tf.keras.utils.register_keras_serializable(package="custom")
class CBAM(layers.Layer):
    """Channel attention followed by spatial attention over a feature map.

    The input is treated as channels-last 4D feature maps: spatial pooling is
    done over axes 1 and 2 and the channel count is read from the last axis.
    The output has the same shape as the input.

    Args:
        reduction_ratio: Divisor applied to the channel count to size the
            hidden layer of the shared channel-attention MLP (at least 1 unit).
        kernel_size: Kernel size of the convolution that produces the
            spatial attention map.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, reduction_ratio: int = 16, kernel_size: int = 7, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio
        self.kernel_size = kernel_size

    def build(self, input_shape):
        channels = int(input_shape[-1])
        hidden = max(channels // self.reduction_ratio, 1)
        # One MLP shared by the average-pooled and max-pooled channel descriptors.
        self.mlp = tf.keras.Sequential([
            layers.Dense(hidden, activation="relu"),
            layers.Dense(channels)
        ])
        self.spatial_conv = layers.Conv2D(1, kernel_size=self.kernel_size, padding="same", activation="sigmoid")
        super().build(input_shape)

    def call(self, x):
        # Channel attention: pool away the spatial axes directly to (batch, channels)
        # instead of instantiating throw-away Flatten layers on every call.
        avg_desc = tf.reduce_mean(x, axis=[1, 2])
        max_desc = tf.reduce_max(x, axis=[1, 2])
        channel_attn = tf.nn.sigmoid(self.mlp(avg_desc) + self.mlp(max_desc))
        # Restore the two spatial axes so the weights broadcast over height and width.
        channel_attn = channel_attn[:, tf.newaxis, tf.newaxis, :]
        x = x * channel_attn

        # Spatial attention: per-pixel mean and max over channels -> single-channel sigmoid map.
        avg_pool_sp = tf.reduce_mean(x, axis=-1, keepdims=True)
        max_pool_sp = tf.reduce_max(x, axis=-1, keepdims=True)
        sp = tf.concat([avg_pool_sp, max_pool_sp], axis=-1)
        spatial_attn = self.spatial_conv(sp)
        return x * spatial_attn

    def get_config(self):
        """Return the constructor arguments so a saved model can rebuild this layer."""
        config = super().get_config()
        config.update({
            "reduction_ratio": self.reduction_ratio,
            "kernel_size": self.kernel_size,
        })
        return config

# Backbone
def build_backbone(x_in, backbone="b0", image_size=224):
    """Create an ImageNet-initialised EfficientNet feature extractor on top of ``x_in``.

    Args:
        x_in: Tensor used as the backbone's ``input_tensor``.
        backbone: ``"b3"`` builds EfficientNetB3; any other value builds EfficientNetB0.
        image_size: Accepted for signature compatibility; not used by this function.

    Returns:
        The backbone model without its classification top, with every layer trainable.
    """
    backbone_cls = EfficientNetB3 if backbone == "b3" else EfficientNetB0
    base = backbone_cls(include_top=False, weights="imagenet", input_tensor=x_in)
    # Fine-tune the whole network: mark each layer as trainable.
    for layer in base.layers:
        layer.trainable = True
    return base

# Model
def build_model_cbam(image_size=224, backbone="b0", num_classes=5, dropout=0.4):
    """Assemble the classifier: augmentation -> EfficientNet -> CBAM -> dense head.

    Args:
        image_size: Height and width of the square RGB input.
        backbone: Backbone selector forwarded to ``build_backbone``.
        num_classes: Number of softmax outputs.
        dropout: Dropout rate used after pooling and after the hidden dense layer.

    Returns:
        An uncompiled Keras model mapping images to class probabilities.
    """
    tf.keras.backend.clear_session()
    inputs = layers.Input(shape=(image_size, image_size, 3))

    # Augmentation and EfficientNet preprocessing are part of the graph.
    preprocessed = effnet_preprocess(data_augmentation(inputs))
    backbone_model = build_backbone(preprocessed, backbone=backbone, image_size=image_size)

    # Attention over the backbone feature maps, then a 1x1 conv before pooling.
    net = layers.BatchNormalization()(backbone_model.output)
    net = CBAM()(net)
    net = layers.Conv2D(192, 1, activation="relu")(net)

    # Classification head.
    net = layers.GlobalAveragePooling2D()(net)
    net = layers.Dropout(dropout)(net)
    net = layers.Dense(192, activation="relu")(net)
    net = layers.Dropout(dropout)(net)
    # The output layer is pinned to float32 regardless of the global dtype policy.
    probabilities = layers.Dense(num_classes, activation="softmax", dtype="float32")(net)
    return models.Model(inputs, probabilities)


In [None]:
# Parse ODIR-5K with Hypertension-priority labeling
# Dataset layout below DATA_DIR: a nested "ODIR-5K/ODIR-5K" folder holding the
# annotation spreadsheet and the two image folders.
ODIR_DIR = os.path.join(DATA_DIR, "ODIR-5K", "ODIR-5K")
EXCEL_PATH = os.path.join(ODIR_DIR, "data.xlsx")
TRAIN_IMG_DIR = os.path.join(ODIR_DIR, "Training Images")
TEST_IMG_DIR = os.path.join(ODIR_DIR, "Testing Images")

# Annotation table; each row is later read for a left and a right image
# filename together with the matching diagnostic text.
meta = pd.read_excel(EXCEL_PATH)

def find_col(df, substrings):
    """Return the first column whose lower-cased name contains every substring.

    Matching is case-insensitive. Column labels are converted with ``str`` before
    comparison, and the original label is returned. ``None`` is returned when no
    column matches.
    """
    needles = [s.lower() for s in substrings]
    return next(
        (col for col in df.columns if all(needle in str(col).lower() for needle in needles)),
        None,
    )

# Locate the image-filename and diagnostic-text columns by case-insensitive
# substring match on the column names. Each lookup falls back to an alternative
# name and ends up as None when neither variant is present.
left_img_col = find_col(meta, ["left","fundus"]) or find_col(meta, ["left","image"])
right_img_col = find_col(meta, ["right","fundus"]) or find_col(meta, ["right","image"])
left_diag_col = find_col(meta, ["left","diagn"]) or find_col(meta, ["left","keyword"]) 
right_diag_col = find_col(meta, ["right","diagn"]) or find_col(meta, ["right","keyword"]) 

# Diagnostic keyword -> one-letter class code. Keys are matched as substrings
# of the lower-cased diagnostic text, so several keys may hit the same text.
KEYWORD_TO_SHORT = {
    "glaucoma":"G","cataract":"C","amd":"A","age-related macular degeneration":"A","age related macular degeneration":"A",
    "hypertension":"H","hypertensive":"H","hypertensive retinopathy":"H","htn":"H",
    "myopia":"M","normal":"N","diabetic retinopathy":"D","dr":"D","other":"O","others":"O"
}

# Class codes kept for classification. The order doubles as the tie-break
# order when a sample matches several codes; codes outside this list are discarded.
TARGET = ["G","C","A","H","M"]

# Build one {"filename", "label"} record per eye image that matches a target class.
records = []

# Pair each eye's image column with its diagnosis column once, up front:
# eyes whose image column was not detected are skipped, and a diagnosis column
# that is absent from the sheet becomes None (treated as empty text below).
eye_columns = [
    (img_col, diag_col if diag_col in meta.columns else None)
    for img_col, diag_col in [(left_img_col, left_diag_col), (right_img_col, right_diag_col)]
    if img_col is not None
]
if len(eye_columns) < 2:
    print("Warning: could not locate both left and right image columns in the metadata sheet.")

for _, row in meta.iterrows():
    for img_col, diag_col in eye_columns:
        fname = row.get(img_col)
        if not isinstance(fname, str) or not fname:
            continue
        text = row.get(diag_col) if diag_col is not None else None
        # Non-string cells (e.g. NaN) are treated as "no diagnosis text".
        text_l = text.lower() if isinstance(text, str) else ""
        matched = {short for keyword, short in KEYWORD_TO_SHORT.items() if keyword in text_l}
        # Keep target classes only, listed in TARGET order so the choice is deterministic.
        labels = [code for code in TARGET if code in matched]
        if not labels:
            continue
        # Hypertension-priority: H wins whenever present; otherwise take the first match in TARGET order.
        final = "H" if "H" in labels else labels[0]
        records.append({"filename": fname, "label": final})

# Explicit columns keep df["filename"] / df["label"] valid even when no row matched.
df = pd.DataFrame(records, columns=["filename", "label"])

# Resolve paths
# Look in the training-image folder first and fall back to the testing-image folder.
paths = []
for fname in df["filename"].values:
    train_candidate = os.path.join(TRAIN_IMG_DIR, fname)
    paths.append(
        train_candidate
        if os.path.exists(train_candidate)
        else os.path.join(TEST_IMG_DIR, fname)
    )
df["path"] = paths

# Drop rows whose image exists in neither folder.
found_on_disk = df["path"].apply(os.path.exists)
df = df[found_on_disk].reset_index(drop=True)
print("Counts by label:\n", df["label"].value_counts())

# Robust stratified split
from sklearn.model_selection import StratifiedShuffleSplit

def stratified_split_with_min(df_in, label_col="label", val_size=0.15, test_size=0.15, seed=SEED, attempts=100):
    """Split ``df_in`` into stratified train/val/test frames covering every TARGET class.

    Up to ``attempts`` stratified train/hold-out splits are drawn. Each hold-out
    part is split again (stratified) into validation and test. The first candidate
    whose validation and test parts both contain exactly the TARGET classes is
    returned. If none qualifies, a warning is printed and the last candidate is
    returned instead.

    Returns:
        ``(train_df, val_df, test_df)`` with freshly reset indices.
    """
    holdout_fraction = val_size + test_size

    def _with_fresh_index(parts):
        return tuple(part.reset_index(drop=True) for part in parts)

    outer_splitter = StratifiedShuffleSplit(n_splits=attempts, test_size=holdout_fraction, random_state=seed)
    fallback = None
    for train_idx, holdout_idx in outer_splitter.split(df_in, df_in[label_col].values):
        train_part = df_in.iloc[train_idx]
        holdout_part = df_in.iloc[holdout_idx]
        inner_splitter = StratifiedShuffleSplit(
            n_splits=1, test_size=test_size / holdout_fraction, random_state=seed
        )
        for val_idx, test_idx in inner_splitter.split(holdout_part, holdout_part[label_col].values):
            val_part = holdout_part.iloc[val_idx]
            test_part = holdout_part.iloc[test_idx]
            fallback = (train_part, val_part, test_part)
            val_classes = set(val_part[label_col].unique())
            test_classes = set(test_part[label_col].unique())
            if val_classes == set(TARGET) and test_classes == set(TARGET):
                return _with_fresh_index(fallback)
    print("Warning: could not guarantee all classes in val/test; using best-effort stratified split.")
    return _with_fresh_index(fallback)

# Split with the default 70/15/15 proportions and report the resulting sizes.
train_df, val_df, test_df = stratified_split_with_min(df)
print("Split sizes:", *(len(part) for part in (train_df, val_df, test_df)))


In [None]:
# Build arrays

def load_and_preprocess(path):
    """Load an image as a float32 RGB array of shape (IMAGE_SIZE, IMAGE_SIZE, 3).

    Pixel values are left unscaled. If the file cannot be decoded, a warning is
    printed and an all-zero image is returned so the surrounding batch can still
    be built.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        # Surface it, because the placeholder still receives a class label.
        print(f"Warning: could not read image '{path}'; substituting an all-zero image.")
        return np.zeros((IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.float32)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    return img.astype("float32")

# Class code -> integer index, derived from TARGET so that the label parsing,
# the split check and the one-hot encoding all share a single class list.
class_to_idx = {code: idx for idx, code in enumerate(TARGET)}
num_classes = len(TARGET)

def df_to_arrays(df):
    """Turn a frame with ``path`` and ``label`` columns into model-ready arrays.

    Returns:
        ``(xs, ys)``: ``xs`` stacks the loaded images along a new first axis, and
        ``ys`` is the one-hot encoding of the labels over ``num_classes`` classes.
    """
    images = [load_and_preprocess(image_path) for image_path in df["path"].values]
    xs = np.stack(images, axis=0)
    class_ids = np.array([class_to_idx[code] for code in df["label"].values])
    return xs, tf.keras.utils.to_categorical(class_ids, num_classes=num_classes)

# Materialise the three splits in memory (train, then val, then test).
(x_train, y_train), (x_val, y_val), (x_test, y_test) = map(
    df_to_arrays, (train_df, val_df, test_df)
)
print(*(arr.shape for arr in (x_train, y_train, x_val, y_val, x_test, y_test)))


In [None]:
# Train
model = build_model_cbam(
    image_size=IMAGE_SIZE,
    backbone=BACKBONE,
    num_classes=num_classes,
    dropout=0.4,
)

# The weighted F1 metric comes from tensorflow_addons and is only tracked when
# that optional import succeeded.
METRICS = [
    tf.keras.metrics.CategoricalAccuracy(name="acc"),
    tf.keras.metrics.AUC(name="auc"),
    tf.keras.metrics.AUC(name="prc", curve="PR"),
    *([tfa.metrics.F1Score(num_classes=num_classes, average="weighted", name="f1")] if USE_TFA else []),
]

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
    loss="categorical_crossentropy",
    metrics=METRICS,
)

ckpt_path = os.path.join(OUTPUT_DIR, "best_attention_efficientnet_hprio.keras")
# Checkpointing and early stopping both follow validation accuracy;
# ReduceLROnPlateau is left on its default monitored quantity.
callbacks = [
    ModelCheckpoint(ckpt_path, save_best_only=True, monitor="val_acc", mode="max"),
    ReduceLROnPlateau(factor=0.5, patience=5, min_lr=1e-6, verbose=1),
    EarlyStopping(patience=10, restore_best_weights=True, monitor="val_acc", mode="max", verbose=1),
]

# Run one small forward pass before training starts.
_ = model.predict(x_train[:4], verbose=0)
history = model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1,
)


In [None]:
# Evaluate
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, average_precision_score

# Reload the best checkpoint and recompile it so evaluate() reports the same metric set.
best_model = tf.keras.models.load_model(ckpt_path, compile=False)
best_model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=METRICS)

# Report the compiled test metrics; previously they were computed with
# verbose=0 and discarded, so they never appeared in the output.
test_metrics = best_model.evaluate(x_test, y_test, verbose=0, return_dict=True)
for metric_name, metric_value in test_metrics.items():
    print(f"test_{metric_name}:", np.round(metric_value, 4))

# Class probabilities, plus predicted and true class indices for the reports below.
y_prob = best_model.predict(x_test, batch_size=BATCH_SIZE, verbose=0)
y_pred = np.argmax(y_prob, axis=1)
y_true = np.argmax(y_test, axis=1)

# Confusion matrix over all class indices, shown as row-normalised percentages.
labels = list(range(num_classes))
class_names_full = ["Glaucoma","Cataract","AMD","Hypertension","Myopia"]
cm = confusion_matrix(y_true, y_pred, labels=labels)
# A class missing from the test split has an all-zero row. Divide only where the
# row total is non-zero so that row stays 0 instead of turning into NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)
plt.figure(figsize=(7,6))
sns.heatmap(cm_norm*100, annot=True, fmt='.2f', cmap='Blues', xticklabels=class_names_full, yticklabels=class_names_full)
plt.title('Confusion Matrix (Test) - H Priority')
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'confusion_matrix_test_hprio.png'))
plt.show()

print(classification_report(y_true, y_pred, labels=labels, target_names=class_names_full, zero_division=0))

# Macro ROC-AUC / PR-AUC restricted to the classes that actually occur in y_true.
present = sorted(set(y_true))
y_true_bin = tf.keras.utils.to_categorical(y_true, num_classes=num_classes)
try:
    roc_auc = roc_auc_score(
        y_true_bin[:, present], y_prob[:, present], average='macro', multi_class='ovr'
    )
    pr_auc = average_precision_score(
        y_true_bin[:, present], y_prob[:, present], average='macro'
    )
except Exception as e:
    # Keep the notebook running if the scores cannot be computed.
    print('AUC error:', e)
else:
    print('ROC-AUC (macro, OvR):', round(roc_auc, 4))
    print('PR-AUC (macro):', round(pr_auc, 4))
