## EfficientNet vs CBAM Attention: ODIR-5K Comparison

This notebook compares a plain EfficientNet classifier to an EfficientNet+CBAM attention variant on ODIR-5K.

- Baseline: EfficientNetB0/B3 (no attention)
- Variant: EfficientNet + CBAM block on feature maps
- Same data splits, preprocessing, and hyperparameters
- Outputs: metrics table (accuracy, weighted/macro F1, ROC-AUC, PR-AUC), confusion matrices, and training curves


In [None]:
import os, re, cv2, numpy as np, pandas as pd, matplotlib.pyplot as plt, seaborn as sns
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from tensorflow.keras.applications import EfficientNetB0, EfficientNetB3
from tensorflow.keras.applications.efficientnet import preprocess_input as effnet_preprocess
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, average_precision_score
from sklearn.model_selection import StratifiedShuffleSplit

# Config
DATA_DIR = "/kaggle/input/ocular-disease-recognition-odir5k"  # dataset root
OUTPUT_DIR = "/kaggle/working"  # checkpoints, plots and CSV/TXT reports are written here
IMAGE_SIZE = 224  # images are resized to IMAGE_SIZE x IMAGE_SIZE
BACKBONE = "b0"  # "b3" selects EfficientNetB3; any other value selects EfficientNetB0
BATCH_SIZE = 16
EPOCHS = 40
SEED = 42

# tensorflow_addons is optional; in the code shown here it only supplies an
# extra weighted-F1 training metric.
USE_TFA = False
try:
    import tensorflow_addons as tfa
    USE_TFA = True
except Exception as exc:
    # Deliberately broad and best-effort: a missing or incompatible install
    # must not abort the notebook, but the reason should be visible.
    print(f"tensorflow_addons unavailable ({exc!r}); F1 training metric disabled.")

# Seed + mixed precision
tf.keras.utils.set_random_seed(SEED)
try:
    from tensorflow.keras import mixed_precision
    mixed_precision.set_global_policy('mixed_float16')
except Exception as exc:
    # Best-effort as well: training continues under the default dtype policy.
    print(f"Could not enable mixed precision ({exc!r}); keeping the default dtype policy.")

# Augmentation
# One shared augmentation pipeline; both model builders place it directly
# after their Input layer.
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.05),
    layers.RandomZoom(0.1),
    layers.RandomContrast(0.1),
], name="augment")

# CBAM
@tf.keras.utils.register_keras_serializable(package="custom")
class CBAM(layers.Layer):
    """Channel attention followed by spatial attention on a 4-D feature map.

    Channel attention: the spatial mean and spatial max of the input are each
    passed through one shared two-layer MLP, summed, and squashed with a
    sigmoid to give one weight per channel. Spatial attention: the
    channel-wise mean and max of the re-weighted features are concatenated and
    fed to a single-filter convolution with sigmoid activation, giving one
    weight per spatial position.

    Args:
        reduction_ratio: Divisor applied to the channel count to size the MLP
            hidden layer (the hidden size is never smaller than 1).
        kernel_size: Kernel size of the spatial-attention convolution.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, reduction_ratio: int = 16, kernel_size: int = 7, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio
        self.kernel_size = kernel_size

    def build(self, input_shape):
        """Create the shared MLP and the spatial convolution for ``input_shape``."""
        channels = int(input_shape[-1])
        hidden = max(channels // self.reduction_ratio, 1)
        self.mlp = tf.keras.Sequential([
            layers.Dense(hidden, activation="relu"),
            layers.Dense(channels),
        ])
        self.spatial_conv = layers.Conv2D(
            1, kernel_size=self.kernel_size, padding="same", activation="sigmoid"
        )
        super().build(input_shape)

    def call(self, x):
        """Return ``x`` scaled by channel attention, then by spatial attention."""
        # Reducing without keepdims yields (batch, channels) directly, so no
        # Flatten layer has to be instantiated on every call.
        avg_desc = tf.reduce_mean(x, axis=[1, 2])
        max_desc = tf.reduce_max(x, axis=[1, 2])
        channel_attn = tf.nn.sigmoid(self.mlp(avg_desc) + self.mlp(max_desc))
        # Back to (batch, 1, 1, channels) so it broadcasts over both spatial axes.
        channel_attn = channel_attn[:, tf.newaxis, tf.newaxis, :]
        x = x * channel_attn

        avg_pool_sp = tf.reduce_mean(x, axis=-1, keepdims=True)
        max_pool_sp = tf.reduce_max(x, axis=-1, keepdims=True)
        spatial_attn = self.spatial_conv(tf.concat([avg_pool_sp, max_pool_sp], axis=-1))
        return x * spatial_attn

    def get_config(self):
        """Return the constructor arguments so a saved model can rebuild the layer."""
        config = super().get_config()
        config.update({
            "reduction_ratio": self.reduction_ratio,
            "kernel_size": self.kernel_size,
        })
        return config


In [None]:
# Ensure output directory exists
import os  # already imported by the first cell; repeated so this cell can run on its own
os.makedirs(OUTPUT_DIR, exist_ok=True)  # exist_ok=True: no error if the directory is already there


In [None]:
# Parse ODIR-5K (single-label 5-class: G, C, A, H, M)
# Paths assumed below: <DATA_DIR>/ODIR-5K/ODIR-5K/ holds data.xlsx plus the
# "Training Images" and "Testing Images" folders.
ODIR_DIR = os.path.join(DATA_DIR, "ODIR-5K", "ODIR-5K")
EXCEL_PATH = os.path.join(ODIR_DIR, "data.xlsx")
TRAIN_IMG_DIR = os.path.join(ODIR_DIR, "Training Images")
TEST_IMG_DIR = os.path.join(ODIR_DIR, "Testing Images")

# Spreadsheet with the image-filename and diagnosis columns used further down.
meta = pd.read_excel(EXCEL_PATH)

def find_col(df, substrings):
    """Return the first column of ``df`` whose name contains every substring.

    Matching is case-insensitive and column names are compared via ``str()``.
    Returns ``None`` when no column matches.
    """
    needles = tuple(part.lower() for part in substrings)

    def _matches(column):
        lowered = str(column).lower()
        return all(needle in lowered for needle in needles)

    return next((column for column in df.columns if _matches(column)), None)

def _lookup_col(primary, fallback):
    """Find a column in ``meta`` by the primary pattern, else by the fallback."""
    return find_col(meta, primary) or find_col(meta, fallback)


# Per-eye filename and diagnosis columns, located by header substrings.
left_img_col = _lookup_col(["left", "fundus"], ["left", "image"])
right_img_col = _lookup_col(["right", "fundus"], ["right", "image"])
left_diag_col = _lookup_col(["left", "diagn"], ["left", "keyword"])
right_diag_col = _lookup_col(["right", "diagn"], ["right", "keyword"])

# Diagnosis keyword (matched as a lower-case substring) -> one-letter class code.
KEYWORD_TO_SHORT = {
    "glaucoma": "G",
    "cataract": "C",
    "amd": "A",
    "age-related macular degeneration": "A",
    "age related macular degeneration": "A",
    "hypertension": "H",
    "hypertensive": "H",
    "hypertensive retinopathy": "H",
    "htn": "H",
    "myopia": "M",
    "normal": "N",
    "diabetic retinopathy": "D",
    "dr": "D",
    "other": "O",
    "others": "O",
}

# Only these classes are kept; the list order doubles as the tie-break priority.
TARGET = ["G", "C", "A", "H", "M"]

# One record per eye that has a filename and at least one TARGET diagnosis.
records = []
eye_columns = ((left_img_col, left_diag_col), (right_img_col, right_diag_col))
for _, row in meta.iterrows():
    for img_col, diag_col in eye_columns:
        fname = row.get(img_col)
        if not (isinstance(fname, str) and fname):
            continue
        raw_text = row.get(diag_col) if diag_col in meta.columns else None
        text_l = raw_text.lower() if isinstance(raw_text, str) else ""
        found = {code for keyword, code in KEYWORD_TO_SHORT.items() if keyword in text_l}
        candidates = [code for code in TARGET if code in found]
        if not candidates:
            continue
        # single-label preference: H wins when present, otherwise the first hit in TARGET order
        final = "H" if "H" in candidates else candidates[0]
        records.append({"filename": fname, "label": final})

df = pd.DataFrame.from_records(records)

# Resolve paths: prefer the training folder, otherwise assume the testing
# folder; rows whose file exists in neither are dropped.
resolved_paths = []
for fname in df["filename"].values:
    train_candidate = os.path.join(TRAIN_IMG_DIR, fname)
    if os.path.exists(train_candidate):
        resolved_paths.append(train_candidate)
    else:
        resolved_paths.append(os.path.join(TEST_IMG_DIR, fname))
df["path"] = resolved_paths
on_disk = df["path"].apply(os.path.exists)
df = df[on_disk].reset_index(drop=True)
print("Counts by label:\n", df["label"].value_counts())

# Robust stratified split
# Draw up to 200 candidate 70/30 splits; each 30% part is halved into
# validation and test. The first candidate whose validation AND test parts
# both contain every TARGET class is kept.
outer_splitter = StratifiedShuffleSplit(n_splits=200, test_size=0.3, random_state=SEED)
labels = df["label"].values
wanted_classes = set(TARGET)
chosen = None
for tr_idx, tmp_idx in outer_splitter.split(df, labels):
    tr, tmp = df.iloc[tr_idx], df.iloc[tmp_idx]
    inner_splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=SEED)
    # n_splits=1, so the inner splitter yields exactly one (val, test) pair.
    va_idx, te_idx = next(inner_splitter.split(tmp, tmp["label"].values))
    va, te = tmp.iloc[va_idx], tmp.iloc[te_idx]
    if all(set(part["label"].unique()) == wanted_classes for part in (va, te)):
        chosen = tuple(part.reset_index(drop=True) for part in (tr, va, te))
        break
if chosen is None:
    # No candidate satisfied the coverage check: keep the last one drawn.
    print("Warning: fallback split used.")
    chosen = tuple(part.reset_index(drop=True) for part in (tr, va, te))
train_df, val_df, test_df = chosen
print("Split sizes:", len(train_df), len(val_df), len(test_df))


In [None]:
# Arrays

def load_and_preprocess(path):
    """Load an image as an RGB ``float32`` array resized to IMAGE_SIZE x IMAGE_SIZE.

    Args:
        path: Location of the image file on disk.

    Returns:
        Array of shape ``(IMAGE_SIZE, IMAGE_SIZE, 3)`` with dtype ``float32``.
        Pixel values are not rescaled here. If the file cannot be decoded, an
        all-zero array of the same shape is returned and a ``RuntimeWarning``
        is emitted, so a corrupt file does not silently become a black
        training sample.
    """
    import warnings  # local import keeps this cell self-contained

    img = cv2.imread(path)
    if img is None:
        warnings.warn(
            f"Could not read image {path!r}; substituting an all-zero image.",
            RuntimeWarning,
            stacklevel=2,
        )
        return np.zeros((IMAGE_SIZE, IMAGE_SIZE, 3), dtype="float32")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR channel order
    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    return img.astype("float32")

# Class code -> integer index; the index order fixes the one-hot column order.
class_to_idx = {"G": 0, "C": 1, "A": 2, "H": 3, "M": 4}
num_classes = 5

def df_to_arrays(df):
    """Turn a split dataframe into ``(images, one_hot_labels)`` arrays.

    Reads every file listed in ``df["path"]`` with ``load_and_preprocess`` and
    one-hot encodes ``df["label"]`` through ``class_to_idx``.
    """
    images = np.stack([load_and_preprocess(p) for p in df["path"].values], axis=0)
    class_ids = np.array([class_to_idx[code] for code in df["label"].values])
    one_hot = tf.keras.utils.to_categorical(class_ids, num_classes=num_classes)
    return images, one_hot

# Materialise each split in memory as (images, one-hot labels).
x_train, y_train = df_to_arrays(train_df)
x_val, y_val = df_to_arrays(val_df)
x_test, y_test = df_to_arrays(test_df)
# Shape sanity check for all six arrays.
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape, x_test.shape, y_test.shape)


In [None]:
# Class weights to rebalance Hypertension
# Weight for class i = n_samples / (num_classes * count_i), so rarer classes
# get larger weights. A class missing from the training split is counted as 1
# to avoid dividing by zero.
# NOTE(review): in the code visible here, the model.fit call inside
# train_and_eval is not given `class_weight`, so these weights have no effect
# unless they are passed to training explicitly - confirm intent.
from collections import Counter

APPLY_CLASS_WEIGHTS = True

train_labels = [class_to_idx[code] for code in train_df["label"].values]
cnt = Counter(train_labels)
n_train_samples = len(train_labels)
class_weight = {
    idx: n_train_samples / (num_classes * cnt.get(idx, 1))
    for idx in range(num_classes)
}
print("Class weights:", class_weight)


In [None]:
# Model builders

def build_effnet_baseline(image_size=224, backbone="b0", num_classes=5, dropout=0.4):
    """Build the attention-free EfficientNet classifier.

    Args:
        image_size: Height and width of the square RGB input.
        backbone: ``"b3"`` selects EfficientNetB3; any other value selects EfficientNetB0.
        num_classes: Width of the softmax output.
        dropout: Rate used by both Dropout layers in the head.

    Returns:
        An uncompiled ``Model`` mapping images to class probabilities.
    """
    inputs = layers.Input(shape=(image_size, image_size, 3))
    preprocessed = effnet_preprocess(data_augmentation(inputs))

    backbone_cls = EfficientNetB3 if backbone == "b3" else EfficientNetB0
    base = backbone_cls(include_top=False, weights="imagenet", input_tensor=preprocessed)
    # Every backbone layer is marked trainable (full fine-tuning).
    for backbone_layer in base.layers:
        backbone_layer.trainable = True

    # Classification head; the final Dense is pinned to float32.
    features = layers.BatchNormalization()(base.output)
    features = layers.Conv2D(192, 1, activation="relu")(features)
    features = layers.GlobalAveragePooling2D()(features)
    features = layers.Dropout(dropout)(features)
    features = layers.Dense(192, activation="relu")(features)
    features = layers.Dropout(dropout)(features)
    outputs = layers.Dense(num_classes, activation="softmax", dtype="float32")(features)
    return models.Model(inputs, outputs)

@tf.keras.utils.register_keras_serializable(package="custom")
class CBAM(layers.Layer):
    """Channel attention followed by spatial attention on a 4-D feature map.

    Channel attention: the spatial mean and spatial max of the input are each
    passed through one shared two-layer MLP, summed, and squashed with a
    sigmoid to give one weight per channel. Spatial attention: the
    channel-wise mean and max of the re-weighted features are concatenated and
    fed to a single-filter convolution with sigmoid activation, giving one
    weight per spatial position.

    Args:
        reduction_ratio: Divisor applied to the channel count to size the MLP
            hidden layer (the hidden size is never smaller than 1).
        kernel_size: Kernel size of the spatial-attention convolution.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, reduction_ratio: int = 16, kernel_size: int = 7, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio
        self.kernel_size = kernel_size

    def build(self, input_shape):
        """Create the shared MLP and the spatial convolution for ``input_shape``."""
        channels = int(input_shape[-1])
        hidden = max(channels // self.reduction_ratio, 1)
        self.mlp = tf.keras.Sequential([
            layers.Dense(hidden, activation="relu"),
            layers.Dense(channels),
        ])
        self.spatial_conv = layers.Conv2D(
            1, kernel_size=self.kernel_size, padding="same", activation="sigmoid"
        )
        super().build(input_shape)

    def call(self, x):
        """Return ``x`` scaled by channel attention, then by spatial attention."""
        # Reducing without keepdims yields (batch, channels) directly, so no
        # Flatten layer has to be instantiated on every call.
        avg_desc = tf.reduce_mean(x, axis=[1, 2])
        max_desc = tf.reduce_max(x, axis=[1, 2])
        channel_attn = tf.nn.sigmoid(self.mlp(avg_desc) + self.mlp(max_desc))
        # Back to (batch, 1, 1, channels) so it broadcasts over both spatial axes.
        channel_attn = channel_attn[:, tf.newaxis, tf.newaxis, :]
        x = x * channel_attn

        avg_pool_sp = tf.reduce_mean(x, axis=-1, keepdims=True)
        max_pool_sp = tf.reduce_max(x, axis=-1, keepdims=True)
        spatial_attn = self.spatial_conv(tf.concat([avg_pool_sp, max_pool_sp], axis=-1))
        return x * spatial_attn

    def get_config(self):
        """Return the constructor arguments so a saved model can rebuild the layer."""
        config = super().get_config()
        config.update({
            "reduction_ratio": self.reduction_ratio,
            "kernel_size": self.kernel_size,
        })
        return config

def build_effnet_cbam(image_size=224, backbone="b0", num_classes=5, dropout=0.4):
    """Build the EfficientNet classifier with a CBAM block on the backbone features.

    Same as the baseline builder except that a ``CBAM`` layer (default
    arguments) sits between the BatchNormalization and the 1x1 convolution.

    Args:
        image_size: Height and width of the square RGB input.
        backbone: ``"b3"`` selects EfficientNetB3; any other value selects EfficientNetB0.
        num_classes: Width of the softmax output.
        dropout: Rate used by both Dropout layers in the head.

    Returns:
        An uncompiled ``Model`` mapping images to class probabilities.
    """
    inputs = layers.Input(shape=(image_size, image_size, 3))
    preprocessed = effnet_preprocess(data_augmentation(inputs))

    backbone_cls = EfficientNetB3 if backbone == "b3" else EfficientNetB0
    base = backbone_cls(include_top=False, weights="imagenet", input_tensor=preprocessed)
    # Every backbone layer is marked trainable (full fine-tuning).
    for backbone_layer in base.layers:
        backbone_layer.trainable = True

    # Head: normalise, apply attention, then the same classifier as the baseline.
    features = layers.BatchNormalization()(base.output)
    features = CBAM()(features)
    features = layers.Conv2D(192, 1, activation="relu")(features)
    features = layers.GlobalAveragePooling2D()(features)
    features = layers.Dropout(dropout)(features)
    features = layers.Dense(192, activation="relu")(features)
    features = layers.Dropout(dropout)(features)
    outputs = layers.Dense(num_classes, activation="softmax", dtype="float32")(features)
    return models.Model(inputs, outputs)


In [None]:
# Train/eval utilities

def train_and_eval(model, name, x_train, y_train, x_val, y_val, x_test, y_test, lr=3e-4, class_weight=None):
    """Compile, train and evaluate ``model``, writing artefacts to ``OUTPUT_DIR``.

    Files written (all tagged with ``name``): the best checkpoint, training
    curves, the confusion matrix (PNG and CSV), the classification report
    (TXT) and per-class ROC / PR curves.

    Args:
        model: Keras model producing one probability column per class.
        name: Tag used in plot titles and output file names.
        x_train, y_train: Training images and one-hot labels.
        x_val, y_val: Validation images and one-hot labels.
        x_test, y_test: Test images and one-hot labels.
        lr: Initial learning rate for Adam.
        class_weight: Optional ``{class_index: weight}`` mapping forwarded to
            ``model.fit``. The default ``None`` trains without class weighting.

    Returns:
        ``(res, h)``: ``res`` is a dict holding ``name``, accuracy, macro and
        weighted F1, and macro ROC-AUC / PR-AUC on the test set; ``h`` is the
        Keras training history dict.
    """
    METRICS = [
        tf.keras.metrics.CategoricalAccuracy(name="acc"),
        tf.keras.metrics.AUC(name="auc"),
        tf.keras.metrics.AUC(name="prc", curve="PR"),
    ]
    if USE_TFA:
        METRICS.append(tfa.metrics.F1Score(num_classes=y_train.shape[1], average="weighted", name="f1"))
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), loss="categorical_crossentropy", metrics=METRICS)
    ckpt_path = os.path.join(OUTPUT_DIR, f"best_{name}.keras")
    callbacks = [
        ModelCheckpoint(ckpt_path, save_best_only=True, monitor="val_acc", mode="max"),
        # No `monitor` given, so this callback uses its own default quantity.
        ReduceLROnPlateau(factor=0.5, patience=5, min_lr=1e-6, verbose=1),
        EarlyStopping(patience=10, restore_best_weights=True, monitor="val_acc", mode="max", verbose=1),
    ]
    # One small forward pass before fitting.
    _ = model.predict(x_train[:4], verbose=0)
    hist = model.fit(
        x_train, y_train,
        validation_data=(x_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=callbacks,
        class_weight=class_weight,
        verbose=1,
    )

    # Save training curves
    h = hist.history

    def plot_curve(tr, va, title):
        """Plot one train/validation metric pair, if both were recorded."""
        if tr in h and va in h:
            plt.figure()
            plt.plot(h[tr], label=tr)
            plt.plot(h[va], label=va)
            plt.title(f'{name}: {title}')
            plt.xlabel('Epochs')
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(OUTPUT_DIR, f'{name}_{tr}.png'))
            plt.show()
            plt.close()

    plot_curve('acc', 'val_acc', 'Accuracy')
    plot_curve('loss', 'val_loss', 'Loss')
    plot_curve('auc', 'val_auc', 'ROC-AUC')
    plot_curve('prc', 'val_prc', 'PR-AUC')

    # Evaluate
    y_prob = model.predict(x_test, batch_size=BATCH_SIZE, verbose=0)
    y_pred = np.argmax(y_prob, axis=1)
    y_true = np.argmax(y_test, axis=1)
    labels = list(range(y_test.shape[1]))
    # Display names, listed in class-index order.
    names_full = ["Glaucoma", "Cataract", "AMD", "Hypertension", "Myopia"]

    # Confusion matrix (counts)
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    plt.figure(figsize=(7, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=names_full, yticklabels=names_full)
    plt.title(f'Confusion Matrix (Test) - {name}')
    plt.xlabel('Predicted label')
    plt.ylabel('True label')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, f'cm_{name}_counts.png'))
    plt.show()
    plt.close()
    # Save confusion matrix as CSV
    pd.DataFrame(cm, index=names_full, columns=names_full).to_csv(os.path.join(OUTPUT_DIR, f'cm_{name}.csv'))

    # Classification metrics
    report = classification_report(y_true, y_pred, labels=labels, target_names=names_full, zero_division=0, output_dict=True)
    # Explicit encoding so the report does not depend on the platform default.
    with open(os.path.join(OUTPUT_DIR, f'classification_report_{name}.txt'), 'w', encoding='utf-8') as f:
        for k, v in report.items():
            f.write(f'{k}: {v}\n')

    # ROC/PR curves per class
    from sklearn.metrics import roc_curve, auc, precision_recall_curve
    # Only classes that actually occur in the test labels are plotted and scored.
    present = sorted(set(y_true))
    y_true_bin = tf.keras.utils.to_categorical(y_true, num_classes=y_test.shape[1])

    plt.figure(figsize=(7, 6))
    for c in present:
        fpr, tpr, _ = roc_curve(y_true_bin[:, c], y_prob[:, c])
        plt.plot(fpr, tpr, label=names_full[c])
    plt.plot([0, 1], [0, 1], 'k--', alpha=0.5)
    plt.title(f'ROC Curves - {name}')
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, f'roc_curves_{name}.png'))
    plt.close()

    plt.figure(figsize=(7, 6))
    for c in present:
        precision, recall, _ = precision_recall_curve(y_true_bin[:, c], y_prob[:, c])
        plt.plot(recall, precision, label=names_full[c])
    plt.title(f'PR Curves - {name}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, f'pr_curves_{name}.png'))
    plt.close()

    # Macro AUCs
    roc_auc = roc_auc_score(y_true_bin[:, present], y_prob[:, present], average='macro', multi_class='ovr')
    pr_auc = average_precision_score(y_true_bin[:, present], y_prob[:, present], average='macro')
    res = {
        "name": name,
        "acc": report["accuracy"],
        "macro_f1": report["macro avg"]["f1-score"],
        "weighted_f1": report["weighted avg"]["f1-score"],
        "roc_auc_macro": float(roc_auc),
        "pr_auc_macro": float(pr_auc),
    }
    return res, h


In [None]:
# Run comparison
# Both models are built with the same image size, backbone and dropout, and
# are trained and evaluated on the same arrays, so the CBAM block is the only
# architectural difference between them.

# Baseline
baseline = build_effnet_baseline(image_size=IMAGE_SIZE, backbone=BACKBONE, num_classes=num_classes, dropout=0.4)
res_base, hist_base = train_and_eval(baseline, "effnet_baseline", x_train, y_train, x_val, y_val, x_test, y_test)

# CBAM
cbam = build_effnet_cbam(image_size=IMAGE_SIZE, backbone=BACKBONE, num_classes=num_classes, dropout=0.4)
res_cbam, hist_cbam = train_and_eval(cbam, "effnet_cbam", x_train, y_train, x_val, y_val, x_test, y_test)

# Compare
# One row per model; the same table is printed and saved as CSV.
import pandas as pd  # already imported by the first cell; repeated so this cell can run on its own
comp = pd.DataFrame([res_base, res_cbam])
print(comp)
comp.to_csv(os.path.join(OUTPUT_DIR, "effnet_vs_cbam_metrics.csv"), index=False)


In [None]:
# Baseline EfficientNet model (no attention)

def build_effnet_baseline(image_size=224, backbone="b0", num_classes=5, dropout=0.4):
    """Build the attention-free EfficientNet classifier.

    Args:
        image_size: Height and width of the square RGB input.
        backbone: ``"b3"`` selects EfficientNetB3; any other value selects EfficientNetB0.
        num_classes: Width of the softmax output.
        dropout: Rate used by both Dropout layers in the head.

    Returns:
        An uncompiled ``Model`` mapping images to class probabilities.
    """
    inputs = layers.Input(shape=(image_size, image_size, 3))
    preprocessed = effnet_preprocess(data_augmentation(inputs))

    backbone_cls = EfficientNetB3 if backbone == "b3" else EfficientNetB0
    base = backbone_cls(include_top=False, weights="imagenet", input_tensor=preprocessed)
    # Every backbone layer is marked trainable (full fine-tuning).
    for backbone_layer in base.layers:
        backbone_layer.trainable = True

    # Classification head; the final Dense is pinned to float32.
    features = layers.BatchNormalization()(base.output)
    features = layers.Conv2D(192, 1, activation="relu")(features)
    features = layers.GlobalAveragePooling2D()(features)
    features = layers.Dropout(dropout)(features)
    features = layers.Dense(192, activation="relu")(features)
    features = layers.Dropout(dropout)(features)
    outputs = layers.Dense(num_classes, activation="softmax", dtype="float32")(features)
    return models.Model(inputs, outputs)

# CBAM variant
def build_effnet_cbam(image_size=224, backbone="b0", num_classes=5, dropout=0.4):
    """Build the EfficientNet classifier with a CBAM block on the backbone features.

    Same as the baseline builder except that a ``CBAM`` layer (default
    arguments) sits between the BatchNormalization and the 1x1 convolution.

    Args:
        image_size: Height and width of the square RGB input.
        backbone: ``"b3"`` selects EfficientNetB3; any other value selects EfficientNetB0.
        num_classes: Width of the softmax output.
        dropout: Rate used by both Dropout layers in the head.

    Returns:
        An uncompiled ``Model`` mapping images to class probabilities.
    """
    inputs = layers.Input(shape=(image_size, image_size, 3))
    preprocessed = effnet_preprocess(data_augmentation(inputs))

    backbone_cls = EfficientNetB3 if backbone == "b3" else EfficientNetB0
    base = backbone_cls(include_top=False, weights="imagenet", input_tensor=preprocessed)
    # Every backbone layer is marked trainable (full fine-tuning).
    for backbone_layer in base.layers:
        backbone_layer.trainable = True

    # Head: normalise, apply attention, then the same classifier as the baseline.
    features = layers.BatchNormalization()(base.output)
    features = CBAM()(features)
    features = layers.Conv2D(192, 1, activation="relu")(features)
    features = layers.GlobalAveragePooling2D()(features)
    features = layers.Dropout(dropout)(features)
    features = layers.Dense(192, activation="relu")(features)
    features = layers.Dropout(dropout)(features)
    outputs = layers.Dense(num_classes, activation="softmax", dtype="float32")(features)
    return models.Model(inputs, outputs)
