## Attention-based ODIR Classifier: EfficientNet + Transformer Head

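This notebook trains an attention-augmented image classifier for the five single-label ODIR-5K classes (glaucoma, cataract, AMD, hypertension, myopia) using an EfficientNet backbone. By default the attention head is a CBAM (channel + spatial attention) module; a lightweight Transformer encoder over spatial tokens can additionally be trained by setting `RUN_TRANSFORMER = True`.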

- Backbone: EfficientNetB0 (switchable to B3)
- Head: CBAM attention (default) or a Transformer encoder with multi-head self-attention (optional)
- Metrics: accuracy, weighted F1 (when `tensorflow_addons` is available), ROC-AUC, PR-AUC
- Saves: best model, history, confusion matrices, classification report


In [None]:
# Augmentations and EfficientNet preprocessing (self-contained imports)
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.applications.efficientnet import preprocess_input as effnet_preprocess

# Light geometric + photometric augmentation applied inside the model graph.
# (Keras random-augmentation layers are documented to be inactive at inference.)
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip(mode="horizontal"),   # mirror left/right
        layers.RandomRotation(factor=0.05),     # up to +/-5% of a full turn
        layers.RandomZoom(height_factor=0.1),   # +/-10% zoom
        layers.RandomContrast(factor=0.1),      # +/-10% contrast
    ],
    name="augment",
)


In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
import cv2

try:
    import tensorflow_addons as tfa
    USE_TFA = True
except Exception:
    tfa = None
    USE_TFA = False

from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0, EfficientNetB3
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint

# Kaggle input directory (ODIR-5K structure)
DATA_DIR = "/kaggle/input/ocular-disease-recognition-odir5k"
OUTPUT_DIR = "/kaggle/working"

IMAGE_SIZE = 224
BACKBONE = "b0"  # "b0" or "b3"
BATCH_SIZE = 16
EPOCHS = 40
SEED = 42
RUN_TRANSFORMER = False  # set True to also train the (memory-heavy) Transformer head

# Seed Python / NumPy / TensorFlow RNGs for reproducibility.
tf.keras.utils.set_random_seed(SEED)

# Mixed precision reduces memory. It is best-effort: if the policy cannot be
# set, training continues in float32, but report it rather than failing
# silently because memory use and speed differ substantially.
try:
    from tensorflow.keras import mixed_precision
    mixed_precision.set_global_policy('mixed_float16')
except Exception as err:
    print(f"Warning: mixed precision not enabled ({err}); training in float32.")

# Disable XLA JIT to reduce GPU timer warnings on some Kaggle GPUs (best-effort).
try:
    tf.config.optimizer.set_jit(False)
except Exception as err:
    print(f"Warning: could not disable XLA JIT ({err}).")

NUM_WORKERS = 2  # NOTE(review): not referenced by the visible notebook cells
AUTOTUNE = tf.data.AUTOTUNE


In [None]:
# Robust stratified split ensuring each class appears in val/test
from sklearn.model_selection import StratifiedShuffleSplit

def stratified_split_with_min(df, label_col="label", val_size=0.15, test_size=0.15, seed=SEED, attempts=100, required_classes=None):
    """Split *df* into train/val/test, stratified on *label_col*.

    Draws up to *attempts* stratified train/holdout shuffles. Each holdout is
    split again (stratified) into val/test. The first candidate in which every
    required class occurs in both val and test is returned. If none qualifies,
    a warning is printed and the last candidate is returned.

    Args:
        df: DataFrame with one row per sample.
        label_col: column holding the class label.
        val_size: fraction of *df* used for validation.
        test_size: fraction of *df* used for testing.
        seed: ``random_state`` for both shuffle splitters.
        attempts: number of outer shuffles to try (must be >= 1).
        required_classes: classes that must appear in both val and test.
            Defaults to every label present in ``df[label_col]``, so the
            function does not depend on any global class list.

    Returns:
        ``(train_df, val_df, test_df)``, each with a fresh RangeIndex.

    Raises:
        ValueError: if *attempts* < 1.
    """
    if attempts < 1:
        raise ValueError(f"attempts must be >= 1, got {attempts}")
    labels = df[label_col].values
    required = set(labels) if required_classes is None else set(required_classes)
    holdout = val_size + test_size
    outer = StratifiedShuffleSplit(n_splits=attempts, test_size=holdout, random_state=seed)
    # With an int random_state, every .split() call replays the same RNG stream,
    # so reusing one inner splitter equals constructing a new one per attempt.
    inner = StratifiedShuffleSplit(n_splits=1, test_size=test_size / holdout, random_state=seed)

    candidate = None
    for train_idx, temp_idx in outer.split(df, labels):
        train_part = df.iloc[train_idx]
        temp_part = df.iloc[temp_idx]
        val_idx, test_idx = next(inner.split(temp_part, temp_part[label_col].values))
        val_part = temp_part.iloc[val_idx]
        test_part = temp_part.iloc[test_idx]
        candidate = (train_part, val_part, test_part)
        if required <= set(val_part[label_col]) and required <= set(test_part[label_col]):
            break
    else:
        print("Warning: could not guarantee all classes in val/test; using best-effort stratified split.")
    return tuple(part.reset_index(drop=True) for part in candidate)

# Robust split, only if the dataset cell has already built `df`. Run top to
# bottom, `df` is created in the next cell, so referencing it here would raise
# NameError; in that case the dataset cell below performs the split itself.
if "df" in globals():
    train_df, val_df, test_df = stratified_split_with_min(df, label_col="label", val_size=0.15, test_size=0.15, seed=SEED)
    print("Counts:", len(train_df), len(val_df), len(test_df))
    print("Val class counts:\n", val_df["label"].value_counts())
    print("Test class counts:\n", test_df["label"].value_counts())
else:
    print("Skipping robust split here: `df` is not defined yet (it is built in the dataset cell below).")


In [None]:
# Build dataset from ODIR-5K file structure
# We will use single-label subset for 5-class case: G, C, A, H, M

import re

ODIR_DIR = os.path.join(DATA_DIR, "ODIR-5K", "ODIR-5K")
EXCEL_PATH = os.path.join(ODIR_DIR, "data.xlsx")
TRAIN_IMG_DIR = os.path.join(ODIR_DIR, "Training Images")
TEST_IMG_DIR = os.path.join(ODIR_DIR, "Testing Images")

# Explicit check rather than `assert`, which is stripped under `python -O`.
if not os.path.exists(EXCEL_PATH):
    raise FileNotFoundError(f"data.xlsx not found at expected path: {EXCEL_PATH}")

# Read per-patient metadata (image filenames and diagnosis keywords per eye).
meta = pd.read_excel(EXCEL_PATH)

# Helper: find column containing all substrings (case-insensitive)
def find_col(df, substrings):
    """Return the first column of *df* whose name contains every substring.

    Matching is case-insensitive. Column labels are passed through ``str``
    first, so non-string labels are handled. Returns ``None`` when no column
    matches.
    """
    needles = [needle.lower() for needle in substrings]
    return next(
        (col for col in df.columns if all(needle in str(col).lower() for needle in needles)),
        None,
    )

# Locate image-filename and diagnosis columns by name fragments so minor
# header variations in data.xlsx still resolve.
left_img_col = find_col(meta, ["left", "fundus"]) or find_col(meta, ["left", "image"])
right_img_col = find_col(meta, ["right", "fundus"]) or find_col(meta, ["right", "image"])
left_diag_col = find_col(meta, ["left", "diagn"]) or find_col(meta, ["left", "keyword"])
right_diag_col = find_col(meta, ["right", "diagn"]) or find_col(meta, ["right", "keyword"])

# Explicit checks (not `assert`, which is stripped under `python -O`) that list
# the available headers so a mismatch is easy to diagnose.
if not (left_img_col and right_img_col):
    raise RuntimeError(
        f"Could not locate Left/Right Fundus columns in data.xlsx; available columns: {list(meta.columns)}"
    )
if not (left_diag_col and right_diag_col):
    print(
        "Warning: could not locate Left/Right diagnostic keyword columns; "
        f"eyes on a missing side will get no labels. Available columns: {list(meta.columns)}"
    )

# Map lower-case diagnosis keywords to ODIR short labels
# (the diagnosis text is lower-cased before matching).
KEYWORD_TO_SHORT = {
    "glaucoma": "G",
    "cataract": "C",
    "amd": "A",
    "age-related macular degeneration": "A",
    "age related macular degeneration": "A",
    "hypertension": "H",
    # ODIR keyword text describes this class as "hypertensive retinopathy",
    # which the "hypertension" key above does not match.
    # NOTE(review): confirm against the keyword columns in data.xlsx.
    "hypertensive": "H",
    "myopia": "M",
    "normal": "N",
    "diabetic retinopathy": "D",
    "dr": "D",
    "other": "O",
    "others": "O",
}

def text_to_labels(text, keyword_map=None):
    """Extract ODIR short class codes from a free-text diagnosis string.

    Two sources are combined:

    1. Keywords from *keyword_map* (default ``KEYWORD_TO_SHORT``), matched
       case-insensitively as whole words/phrases. Whole-word matching keeps
       "dr" from firing on "drusen"/"dry" and "normal" from firing on
       "abnormal".
    2. Single-letter codes written inside square brackets, e.g. "['G']" or
       "['G', 'C']". Only bracketed letters count. Scanning the raw text for
       any capital letter would tag "Moderate ..." as M or "Normal ..." as N.

    Args:
        text: diagnosis text; non-string values (e.g. empty Excel cells) yield [].
        keyword_map: mapping from lower-case keyword/phrase to short code.

    Returns:
        Sorted list of unique short codes.
    """
    if not isinstance(text, str):
        return []
    if keyword_map is None:
        keyword_map = KEYWORD_TO_SHORT
    text_l = text.lower()
    # Lookarounds instead of \b so keywords starting/ending in punctuation still work.
    labels = {
        code
        for keyword, code in keyword_map.items()
        if re.search(r"(?<!\w)" + re.escape(keyword) + r"(?!\w)", text_l)
    }
    # Bracketed short-label fallback: standalone capital codes inside [...] only.
    for inner in re.findall(r"\[([^\]]*)\]", text):
        labels.update(re.findall(r"\b([DGCAHMNO])\b", inner))
    return sorted(labels)

# Build one record per eye; eyes without an image filename or without any
# parsed diagnosis label are skipped.
records = []
for _, row in meta.iterrows():
    for side, img_col, diag_col in (
        ("L", left_img_col, left_diag_col),
        ("R", right_img_col, right_diag_col),
    ):
        fname = row.get(img_col)
        if not isinstance(fname, str) or not fname:
            continue
        diag_text = None
        if diag_col in meta.columns:
            diag_text = row.get(diag_col)
        short_labels = text_to_labels(diag_text)
        if not short_labels:
            continue
        records.append({"filename": fname, "labels": short_labels})

df = pd.DataFrame.from_records(records)

# Keep only single-label images from the 5 target classes
# (G=glaucoma, C=cataract, A=AMD, H=hypertension, M=myopia).
TARGET = ["G","C","A","H","M"]
if len(df) == 0:
    raise RuntimeError("No labeled images parsed from data.xlsx; please verify column names in your ODIR file.")

# Drop non-target codes, then keep only eyes left with exactly one target code.
df["labels"] = df["labels"].apply(lambda codes: [code for code in codes if code in TARGET])
df = df[df["labels"].apply(len) == 1].copy()
if df.empty:
    raise RuntimeError(f"No single-label images for target classes {TARGET} were parsed from data.xlsx.")
df["label"] = df["labels"].apply(lambda codes: codes[0])

class_to_idx = {c: i for i, c in enumerate(TARGET)}
num_classes = len(TARGET)


def _resolve_image_path(fname):
    """Return the Training Images path for *fname* if it exists, else the Testing Images path."""
    train_path = os.path.join(TRAIN_IMG_DIR, fname)
    return train_path if os.path.exists(train_path) else os.path.join(TEST_IMG_DIR, fname)


# Resolve file paths; the testing-path fallback may not exist either, so
# filter to existing files afterwards.
df["path"] = [_resolve_image_path(fname) for fname in df["filename"].values]
df = df[df["path"].apply(os.path.exists)].reset_index(drop=True)
if df.empty:
    raise RuntimeError(
        f"None of the labeled image files were found under {TRAIN_IMG_DIR!r} or {TEST_IMG_DIR!r}."
    )

print("Samples per class:")
print(df["label"].value_counts())

# Train/val/test split (70/15/15), stratified by label.
# Prefer the robust splitter defined in the earlier cell, which also retries
# until every class appears in both val and test. Otherwise its result would be
# silently overwritten here. A plain two-step stratified split is the fallback.
from sklearn.model_selection import train_test_split

if "stratified_split_with_min" in globals():
    train_df, val_df, test_df = stratified_split_with_min(
        df, label_col="label", val_size=0.15, test_size=0.15, seed=SEED
    )
else:
    train_df, temp_df = train_test_split(df, test_size=0.3, random_state=SEED, stratify=df["label"])
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=SEED, stratify=temp_df["label"])

print("Counts:", len(train_df), len(val_df), len(test_df))


In [None]:
# Image loader
def load_and_preprocess(path, image_size=IMAGE_SIZE):
    """Load an image as an RGB float32 array of shape (image_size, image_size, 3).

    Pixel values stay in [0, 255]; scaling is left to EfficientNet
    preprocessing. If OpenCV cannot read *path* (missing or corrupt file), a
    black placeholder is returned so array construction can continue. The
    failure is printed so placeholders do not silently end up in training data.
    """
    img = cv2.imread(path)
    if img is None:
        print(f"Warning: could not read image {path!r}; using a black placeholder.")
        img = np.zeros((image_size, image_size, 3), dtype=np.uint8)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
        img = cv2.resize(img, (image_size, image_size))
    return img.astype("float32")

# Build whole-split numpy arrays in memory (simple; note each float32
# 224x224x3 image occupies ~0.6 MB).

def df_to_arrays(df):
    """Materialise a split DataFrame as ``(images, one_hot_labels)`` numpy arrays.

    Images come from ``load_and_preprocess`` applied to ``df["path"]``. Labels
    are ``df["label"]`` mapped through ``class_to_idx`` and one-hot encoded to
    ``num_classes`` columns.
    """
    stacked = np.stack([load_and_preprocess(p) for p in df["path"].values], axis=0)
    one_hot = tf.keras.utils.to_categorical(
        np.array([class_to_idx[name] for name in df["label"].values]),
        num_classes=num_classes,
    )
    return stacked, one_hot

(x_train, y_train), (x_val, y_val), (x_test, y_test) = (
    df_to_arrays(split_df) for split_df in (train_df, val_df, test_df)
)

print("Array shapes:")
print(*(arr.shape for arr in (x_train, y_train, x_val, y_val, x_test, y_test)))


In [None]:
# Optionally override with pre-saved numpy arrays if present in this Kaggle dataset
npy_files = [
    os.path.join(DATA_DIR, f"{stem}.npy")
    for stem in ("x_train", "y_train", "x_val", "y_val", "x_test", "y_test")
]


def _images_to_model_range(arr):
    """Return image array *arr* as float32 in [0, 255], the range the model expects.

    Keras EfficientNet models rescale inputs internally (``effnet_preprocess``
    is a pass-through), so images must not be pre-divided by 255. The arrays
    built from files above are kept in [0, 255] as well. Arrays whose maximum
    is <= 1 are treated as already normalised to [0, 1] and scaled back up.
    NOTE(review): this heuristic assumes the saved arrays are either raw
    pixels or [0, 1]-normalised; confirm for the dataset in use.
    """
    arr = arr.astype("float32")
    if arr.size and float(arr.max()) <= 1.0:
        arr *= 255.0
    return arr


if all(os.path.exists(p) for p in npy_files):
    # Loaded one at a time to keep peak memory down.
    x_train = _images_to_model_range(np.load(npy_files[0]))
    y_train = np.load(npy_files[1])
    x_val = _images_to_model_range(np.load(npy_files[2]))
    y_val = np.load(npy_files[3])
    x_test = _images_to_model_range(np.load(npy_files[4]))
    y_test = np.load(npy_files[5])
    if y_train.ndim == 1:
        # Integer class ids: one-hot encode all splits to the same width,
        # inferred from the largest id across splits.
        n_saved_classes = int(max(y_train.max(), y_val.max(), y_test.max())) + 1
        y_train, y_val, y_test = (
            tf.keras.utils.to_categorical(y, num_classes=n_saved_classes)
            for y in (y_train, y_val, y_test)
        )
    num_classes = int(y_train.shape[1])
    print("Loaded pre-saved numpy arrays from Kaggle dataset.")
    print("Shapes:")
    print(" x_train:", x_train.shape)
    print(" y_train:", y_train.shape)
    print(" x_val:", x_val.shape)
    print(" y_val:", y_val.shape)
    print(" x_test:", x_test.shape)
    print(" y_test:", y_test.shape)
else:
    print("Using arrays built from ODIR file structure.")


In [None]:
def build_model_cbam(image_size=224, backbone="b0", num_classes=5, dropout=0.3):
    """Build the EfficientNet + CBAM classifier.

    Pipeline: augmentation -> ``effnet_preprocess`` -> EfficientNet feature map
    -> BatchNorm -> CBAM -> 1x1 Conv(192, ReLU) -> global average pool
    -> Dropout -> Dense(192, ReLU) -> Dropout -> softmax.

    ``clear_session()`` is called first to reset Keras global state between
    builds. The output layer is float32 so the softmax stays in full precision
    under the mixed_float16 policy.
    """
    tf.keras.backend.clear_session()
    inputs = layers.Input(shape=(image_size, image_size, 3))
    preprocessed = effnet_preprocess(data_augmentation(inputs))
    features = build_backbone(preprocessed, backbone=backbone, image_size=image_size).output

    # Head layers, constructed and applied in the same order.
    head = [
        layers.BatchNormalization(),
        CBAM(),
        layers.Conv2D(192, kernel_size=1, activation="relu"),
        layers.GlobalAveragePooling2D(),
        layers.Dropout(dropout),
        layers.Dense(192, activation="relu"),
        layers.Dropout(dropout),
        layers.Dense(num_classes, activation="softmax", dtype="float32"),
    ]
    x = features
    for head_layer in head:
        x = head_layer(x)
    return models.Model(inputs, x)


In [None]:
# CBAM attention module for CNN feature maps
@tf.keras.utils.register_keras_serializable(package="custom")
class CBAM(layers.Layer):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    Args:
        reduction_ratio: bottleneck factor of the shared channel-attention MLP
            (hidden width is ``channels // reduction_ratio``, at least 1).
        kernel_size: side length of the spatial-attention convolution.

    Expects channels-last 4-D input ``(batch, height, width, channels)``. It
    returns a tensor of the same shape, reweighted per channel and per spatial
    position.
    """

    def __init__(self, reduction_ratio: int = 16, kernel_size: int = 7, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio
        self.kernel_size = kernel_size

    def build(self, input_shape):
        channels = int(input_shape[-1])
        hidden = max(channels // self.reduction_ratio, 1)
        # Shared MLP applied to both the average- and max-pooled channel descriptors.
        self.mlp = tf.keras.Sequential([
            layers.Dense(hidden, activation="relu"),
            layers.Dense(channels)
        ])
        # Single-channel sigmoid map computed from the [mean, max] channel summaries.
        self.spatial_conv = layers.Conv2D(1, kernel_size=self.kernel_size, padding="same", activation="sigmoid")
        super().build(input_shape)

    def call(self, x):
        # Channel attention. Pooling without keepdims yields (batch, channels)
        # directly, so no Flatten layer is constructed inside call(). Keras
        # expects sublayers to be created in __init__/build so they are tracked.
        avg_desc = tf.reduce_mean(x, axis=[1, 2])
        max_desc = tf.reduce_max(x, axis=[1, 2])
        channel_attn = tf.nn.sigmoid(self.mlp(avg_desc) + self.mlp(max_desc))
        x = x * channel_attn[:, tf.newaxis, tf.newaxis, :]
        # Spatial attention over the channel-wise mean and max maps.
        avg_pool_sp = tf.reduce_mean(x, axis=-1, keepdims=True)
        max_pool_sp = tf.reduce_max(x, axis=-1, keepdims=True)
        spatial_attn = self.spatial_conv(tf.concat([avg_pool_sp, max_pool_sp], axis=-1))
        return x * spatial_attn

    def get_config(self):
        cfg = super().get_config()
        cfg.update({"reduction_ratio": self.reduction_ratio, "kernel_size": self.kernel_size})
        return cfg


In [None]:
def build_backbone(input_tensor, backbone="b0", image_size=224):
    """Return an ImageNet-pretrained EfficientNet (no top) wired to *input_tensor*.

    ``backbone == "b3"`` selects EfficientNetB3; any other value falls back to
    EfficientNetB0. Every layer is marked trainable (full fine-tuning).
    ``image_size`` is accepted for call-site symmetry but not used.
    """
    constructor = EfficientNetB3 if backbone == "b3" else EfficientNetB0
    base = constructor(include_top=False, weights="imagenet", input_tensor=input_tensor)
    for base_layer in base.layers:
        base_layer.trainable = True
    return base

@tf.keras.utils.register_keras_serializable(package="custom")
class PositionalEmbedding(layers.Layer):
    """Add a learned positional embedding to a token sequence.

    Holds one trainable tensor of shape ``(1, num_patches, projection_dim)``,
    initialised with ``random_normal`` and added to the input (broadcast over
    the batch dimension).
    """

    def __init__(self, num_patches: int, projection_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        # The weight is created eagerly here rather than in build().
        self.pos_emb = self.add_weight(
            name="pos_emb",
            shape=(1, num_patches, projection_dim),
            initializer="random_normal",
        )

    def call(self, x):  # x: (batch, num_patches, projection_dim)
        return x + self.pos_emb

    def get_config(self):
        return {
            **super().get_config(),
            "num_patches": self.num_patches,
            "projection_dim": self.projection_dim,
        }

@tf.keras.utils.register_keras_serializable(package="custom")
class TransformerEncoder(layers.Layer):
    """Pre-norm Transformer encoder block over a token sequence.

    Computes ``x + Dropout(MHA(LN(x)))``, then ``y + MLP(LN(y))``, where the MLP
    is Dense(mlp_dim, GELU) -> Dropout -> Dense(projection_dim) -> Dropout.

    NOTE(review): ``key_dim`` is per head, so ``key_dim=projection_dim`` makes
    the attention projections ``num_heads * projection_dim`` wide; the more
    common choice is ``projection_dim // num_heads``. Confirm this is intended.
    """

    def __init__(self, projection_dim: int, num_heads: int, mlp_dim: int, dropout: float=0.1, **kwargs):
        super().__init__(**kwargs)
        self.projection_dim = projection_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        # Sublayers, created in the order they are applied in call().
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=dropout)
        self.drop1 = layers.Dropout(dropout)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.mlp = tf.keras.Sequential([
            layers.Dense(mlp_dim, activation=tf.nn.gelu),
            layers.Dropout(dropout),
            layers.Dense(projection_dim),
            layers.Dropout(dropout),
        ])

    def call(self, x, training=False):
        normed = self.norm1(x)
        attended = self.attn(normed, normed, training=training)  # self-attention
        residual = x + self.drop1(attended, training=training)
        return residual + self.mlp(self.norm2(residual), training=training)

    def get_config(self):
        return {
            **super().get_config(),
            "projection_dim": self.projection_dim,
            "num_heads": self.num_heads,
            "mlp_dim": self.mlp_dim,
            "dropout": self.dropout,
        }

def build_model(image_size=224, backbone="b0", num_classes=5, transformer_layers=2, projection_dim=256, num_heads=4, mlp_dim=512, dropout=0.2):
    """Build an EfficientNet + Transformer-encoder classifier.

    Steps:
    - Batch-normalise the backbone feature map.
    - Project it to ``projection_dim`` channels with a 1x1 conv.
    - Flatten the spatial grid into tokens and add a learned positional embedding.
    - Run ``transformer_layers`` encoder blocks.
    - Layer-norm and average the tokens, then classify with a small MLP head.

    The softmax output layer is float32 so it stays in full precision under
    the mixed_float16 policy.

    NOTE(review): unlike ``build_model_cbam``, this builder applies neither
    data augmentation nor ``effnet_preprocess`` in-graph. Confirm this is intended.
    """
    inputs = layers.Input(shape=(image_size, image_size, 3))
    # CNN backbone
    backbone_model = build_backbone(inputs, backbone=backbone, image_size=image_size)
    # Feature map: (B, H', W', C')
    features = backbone_model.output
    features = layers.BatchNormalization()(features)

    # Project channel dimension to projection_dim
    proj = layers.Conv2D(projection_dim, kernel_size=1, padding="same")(features)
    # Flatten spatial to tokens
    tokens = layers.Reshape((-1, projection_dim))(proj)  # (B, N, D)

    # Positional embedding. The token count comes from the static feature-map
    # size. If that is unknown, fall back to EfficientNet's output stride of 32
    # (224 -> 7).
    grid_h = proj.shape[1] or (image_size // 32)
    grid_w = proj.shape[2] or (image_size // 32)
    num_patches = grid_h * grid_w
    tokens = PositionalEmbedding(num_patches=num_patches, projection_dim=projection_dim)(tokens)

    # Transformer encoder stack
    x = tokens
    for _ in range(transformer_layers):
        x = TransformerEncoder(projection_dim=projection_dim, num_heads=num_heads, mlp_dim=mlp_dim, dropout=dropout)(x)

    # Token pooling
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.GlobalAveragePooling1D()(x)

    # Classification head
    x = layers.Dropout(dropout)(x)
    x = layers.Dense(256, activation="relu")(x)
    x = layers.Dropout(dropout)(x)
    # float32 output under mixed precision, matching build_model_cbam.
    outputs = layers.Dense(num_classes, activation="softmax", dtype="float32")(x)

    model = models.Model(inputs, outputs)
    return model


In [None]:
# Give the custom layers a no-op build() to suppress Keras build warnings.

def _attach_trivial_build(cls):
    """Install a build() on *cls* that only sets ``self.built = True``.

    The patch is applied only when *cls* has no ``build`` attribute or still
    uses ``layers.Layer.build`` unchanged. Any other build() is left alone.
    """
    if hasattr(cls, "build") and cls.build is not layers.Layer.build:
        return

    def build(self, input_shape):
        self.built = True

    cls.build = build

_attach_trivial_build(PositionalEmbedding)
_attach_trivial_build(TransformerEncoder)


In [None]:
# Build and train the default model: EfficientNet backbone + CBAM attention head.
import pickle

model = build_model_cbam(
    image_size=IMAGE_SIZE,
    backbone=BACKBONE,
    num_classes=num_classes,
    dropout=0.4,
)
model.summary()

# Metrics; weighted F1 is added only when tensorflow_addons imported successfully.
METRICS = [
    tf.keras.metrics.CategoricalAccuracy(name="acc"),
    tf.keras.metrics.AUC(name="auc"),
    tf.keras.metrics.AUC(name="prc", curve="PR"),
] + ([tfa.metrics.F1Score(num_classes=num_classes, average="weighted", name="f1")] if USE_TFA else [])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
    loss="categorical_crossentropy",
    metrics=METRICS,
)

# Keep the best-val_acc checkpoint, halve the LR on plateaus, stop early.
ckpt_path = os.path.join(OUTPUT_DIR, "best_attention_efficientnet.keras")
callbacks = [
    ModelCheckpoint(ckpt_path, monitor="val_acc", mode="max", save_best_only=True),
    ReduceLROnPlateau(factor=0.5, patience=5, min_lr=1e-6, verbose=1),
    EarlyStopping(monitor="val_acc", mode="max", patience=10, restore_best_weights=True, verbose=1),
]

# Warm-up forward pass on a few samples before training.
_ = model.predict(x_train[:4], verbose=0)

history = model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1,
)

with open(os.path.join(OUTPUT_DIR, "attention_history.pkl"), "wb") as history_file:
    pickle.dump(history.history, history_file)


In [None]:
# Optionally build and train the Transformer-head model (memory heavy).
# Compile, fit and saving all sit inside the guard. With RUN_TRANSFORMER False
# this cell must not recompile and retrain the CBAM `model` from the previous
# cell, which would also overwrite that run's best checkpoint.
# This run uses its own checkpoint/history files. `ckpt_path` is reassigned,
# so the evaluation cell then loads the Transformer checkpoint.
if RUN_TRANSFORMER:
    import pickle

    tf.keras.backend.clear_session()  # release Keras global state from the CBAM run
    model = build_model(
        image_size=IMAGE_SIZE,
        backbone=BACKBONE,
        num_classes=num_classes,
        transformer_layers=3,
        projection_dim=384 if BACKBONE == "b3" else 256,
        num_heads=6 if BACKBONE == "b3" else 4,
        mlp_dim=768 if BACKBONE == "b3" else 512,
        dropout=0.3,
    )
    model.summary()

    # Metrics
    METRICS = [
        tf.keras.metrics.CategoricalAccuracy(name="acc"),
        tf.keras.metrics.AUC(name="auc"),
        tf.keras.metrics.AUC(name="prc", curve="PR"),
    ]
    if USE_TFA:
        METRICS.append(tfa.metrics.F1Score(num_classes=num_classes, average="weighted", name="f1"))

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        loss="categorical_crossentropy",
        metrics=METRICS,
    )

    # Callbacks
    ckpt_path = os.path.join(OUTPUT_DIR, "best_transformer_efficientnet.keras")
    callbacks = [
        ModelCheckpoint(ckpt_path, save_best_only=True, monitor="val_acc", mode="max"),
        ReduceLROnPlateau(factor=0.5, patience=5, min_lr=1e-6, verbose=1),
        EarlyStopping(patience=10, restore_best_weights=True, monitor="val_acc", mode="max", verbose=1),
    ]

    # Warm-up forward pass to stabilize cuDNN timers
    _ = model.predict(x_train[:4], verbose=0)

    history = model.fit(
        x_train, y_train,
        validation_data=(x_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=callbacks,
        verbose=1,
    )

    # Save history
    with open(os.path.join(OUTPUT_DIR, "transformer_history.pkl"), "wb") as f:
        pickle.dump(history.history, f)
else:
    print("RUN_TRANSFORMER is False: skipping Transformer-head training.")


In [None]:
# Evaluation
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import seaborn as sns

# Load the best checkpoint at `ckpt_path`. CBAM is listed with the Transformer
# layers so the CBAM-head model also loads without relying on the
# serialization registry.
custom_objects = {
    "PositionalEmbedding": PositionalEmbedding,
    "TransformerEncoder": TransformerEncoder,
    "CBAM": CBAM,
}
best_model = tf.keras.models.load_model(ckpt_path, custom_objects=custom_objects, compile=False)
best_model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=METRICS)

# Evaluate. return_dict=True pairs every value with its metric name. Zipping
# `best_model.metrics` with `eval_results[1:]` misaligns them, because
# `Model.metrics` also contains the loss tracker.
eval_results = best_model.evaluate(x_test, y_test, verbose=0, return_dict=True)
print(eval_results)

# Predictions
y_prob = best_model.predict(x_test, batch_size=BATCH_SIZE, verbose=0)
y_pred = np.argmax(y_prob, axis=1)
y_true = np.argmax(y_test, axis=1)

# Confusion matrix with full class names and row-normalised percentages
labels = list(range(num_classes))
cm = confusion_matrix(y_true, y_pred, labels=labels)
# Row-normalise (per-class recall). A class with no test samples would give
# 0/0 = NaN; its row is left at 0 instead.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm.astype("float"),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals > 0,
)
# Names follow TARGET order (G, C, A, H, M).
# NOTE(review): the order is unknown if pre-saved npy arrays were loaded.
class_names_full = ["Glaucoma","Cataract","AMD","Hypertension","Myopia"] if num_classes == 5 else [f"Class_{i}" for i in labels]
plt.figure(figsize=(7,6))
sns.heatmap(cm_norm*100, annot=True, fmt='.2f', cmap='Blues', xticklabels=class_names_full, yticklabels=class_names_full)
plt.title('Confusion Matrix (Test)')
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'confusion_matrix_test.png'))
plt.show()

# Per-class precision/recall/F1 over the fixed label set; zero_division=0 scores
# undefined ratios as 0 instead of warning.
report = classification_report(
    y_true,
    y_pred,
    labels=labels,
    target_names=class_names_full,
    zero_division=0,
)
print(report)
with open(os.path.join(OUTPUT_DIR, 'classification_report.txt'), 'w') as report_file:
    report_file.write(report)

# ROC-AUC / PR-AUC (macro) over the classes that occur in the test labels only;
# columns of absent classes are dropped. Best-effort: failures are reported, not raised.
present = sorted(set(y_true))
try:
    from sklearn.metrics import roc_auc_score, average_precision_score
    one_hot_true = tf.keras.utils.to_categorical(y_true, num_classes=num_classes)
    true_cols = one_hot_true[:, present]
    prob_cols = y_prob[:, present]
    roc_auc = roc_auc_score(true_cols, prob_cols, average='macro', multi_class='ovr')
    pr_auc = average_precision_score(true_cols, prob_cols, average='macro')
    print('ROC-AUC (macro, OvR):', round(roc_auc, 4))
    print('PR-AUC (macro):', round(pr_auc, 4))
except Exception as e:
    print('ROC/PR AUC could not be computed:', e)

# Plot each training metric against its validation counterpart, when both were logged.
hist = history.history
for tr in ("acc", "auc", "prc", "loss"):
    va = "val_" + tr
    if tr not in hist or va not in hist:
        continue
    plt.figure()
    for key in (tr, va):
        plt.plot(hist[key], label=key)
    plt.title(f"{tr} vs {va}")
    plt.xlabel('Epochs')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, f'{tr}_curve.png'))
    plt.show()
