In [None]:
import os
import re
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import seaborn as sns

from typing import Dict
from sklearn.model_selection import train_test_split

# Reproducibility: seed each RNG source used by this notebook.
SEED = 42
random.seed(SEED)         # Python stdlib RNG
np.random.seed(SEED)      # NumPy global RNG
tf.random.set_seed(SEED)  # TensorFlow global seed

# Kaggle-style dataset layout; every path below is derived from DATASET_ROOT.
DATASET_ROOT = "/kaggle/input/ocular-disease-recognition-odir5k"
IMAGES_DIR = os.path.join(DATASET_ROOT, "preprocessed_images")
CSV_PATH = os.path.join(DATASET_ROOT, "full_df.csv")

# Image and training hyperparameters.
IMAGE_SIZE: int = 224        # images are resized to IMAGE_SIZE x IMAGE_SIZE
NUM_CLASSES: int = 5         # G, C, A, H, M
BATCH_SIZE: int = 32
EPOCHS: int = 35
LEARNING_RATE: float = 2e-4

# Short label code -> display name. 'D' and 'N' are intentionally excluded.
# Insertion order defines the class index used throughout the notebook.
CLASS_SHORT_TO_FULL: Dict[str, str] = dict(
    G="Glaucoma",
    C="Cataract",
    A="AMD",
    H="Hypertension",
    M="Myopia",
)
# Forward and reverse lookups between short code and integer class index.
CLASS_TO_INDEX: Dict[str, int] = {code: position for position, code in enumerate(CLASS_SHORT_TO_FULL)}
INDEX_TO_CLASS: Dict[int, str] = dict(enumerate(CLASS_SHORT_TO_FULL))

print("Using dataset root:", DATASET_ROOT)


In [None]:
# Load CSV and filter classes
# The loaded frame keeps its own name so that re-running this cell never
# depends on a previously filtered/mutated frame.
full_df = pd.read_csv(CSV_PATH)

# Extract the alphabetic tokens from each `labels` value and join them with
# spaces. Values that yield more than one token (e.g. "A B") can never equal
# a single-letter key, so they are dropped by the filter below.
parsed_class = full_df["labels"].apply(
    lambda raw: " ".join(re.findall("[A-Za-z]+", str(raw))).strip()
)
keep_mask = parsed_class.isin(CLASS_SHORT_TO_FULL.keys())

# .copy() makes raw_df an independent frame; without it the column
# assignments below would write into a slice of the loaded frame
# (pandas SettingWithCopyWarning, ambiguous semantics).
raw_df = full_df.loc[keep_mask].copy()
raw_df["class"] = parsed_class.loc[keep_mask]
raw_df["class_idx"] = raw_df["class"].map(CLASS_TO_INDEX)

assert raw_df["filename"].notna().all(), "some kept rows have no filename"
assert raw_df["class_idx"].notna().all(), "some kept rows have an unmapped class"

print("Class distribution (kept):")
print(raw_df["class"].value_counts().sort_index())

# 70 / 15 / 15 stratified split: 30% is held out first, then halved.
# NOTE(review): the split is per row. If the CSV has several rows per patient
# (e.g. left and right eye), the same patient can land in both train and
# test — confirm and consider a group-aware split.
train_df, temp_df = train_test_split(
    raw_df, test_size=0.3, random_state=SEED, stratify=raw_df["class_idx"]
)
val_df, test_df = train_test_split(
    temp_df, test_size=0.5, random_state=SEED, stratify=temp_df["class_idx"]
)

for name, df in [("train", train_df), ("val", val_df), ("test", test_df)]:
    print(name, df.shape, df["class"].value_counts().sort_index().to_dict())


In [None]:
# TF Dataset pipeline

def decode_image(path: tf.Tensor) -> tf.Tensor:
    """Load one image file and return it as a resized float32 RGB tensor.

    Args:
        path: scalar string tensor holding the file path to read.

    Returns:
        Tensor of shape (IMAGE_SIZE, IMAGE_SIZE, 3), dtype float32.
    """
    raw_bytes = tf.io.read_file(path)
    # NOTE(review): the files are decoded with the PNG decoder; if the dataset
    # images are JPEGs this relies on decode_png accepting them — confirm.
    pixels = tf.image.convert_image_dtype(
        tf.image.decode_png(raw_bytes, channels=3), tf.float32
    )
    return tf.image.resize(pixels, (IMAGE_SIZE, IMAGE_SIZE), antialias=True)

@tf.function
def build_path(filename: tf.Tensor) -> tf.Tensor:
    """Return ``IMAGES_DIR + "/" + filename`` as a string tensor."""
    return tf.strings.join([IMAGES_DIR, filename], separator="/")

@tf.function
def one_hot(label_idx: tf.Tensor) -> tf.Tensor:
    """Encode an integer class index as a float32 one-hot vector of length NUM_CLASSES."""
    return tf.one_hot(indices=label_idx, depth=NUM_CLASSES, dtype=tf.float32)

@tf.function
def augment(image: tf.Tensor) -> tf.Tensor:
    """Apply light random augmentation to one image.

    Steps, in order: random horizontal flip, brightness shift of at most
    0.05, contrast factor drawn from [0.9, 1.1]. The result is clipped to
    [0, 1] because the brightness/contrast jitter can push values outside it.
    """
    flipped = tf.image.random_flip_left_right(image)
    jittered = tf.image.random_contrast(
        tf.image.random_brightness(flipped, max_delta=0.05),
        lower=0.9,
        upper=1.1,
    )
    return tf.clip_by_value(jittered, 0.0, 1.0)


def make_dataset(df: pd.DataFrame, training: bool) -> tf.data.Dataset:
    """Build a batched, prefetched ``tf.data`` pipeline from a split DataFrame.

    Args:
        df: frame with a ``filename`` column (joined onto IMAGES_DIR) and an
            integer ``class_idx`` column.
        training: when True the samples are shuffled (buffer = whole split,
            seeded with SEED) and each image is passed through ``augment``.

    Returns:
        Dataset yielding ``(images, one_hot_labels)`` batches of BATCH_SIZE.
    """
    names = df["filename"].values.astype(str)
    targets = df["class_idx"].values.astype(np.int32)

    pipeline = tf.data.Dataset.from_tensor_slices((names, targets))
    if training:
        # Shuffle the cheap (filename, label) pairs before any image is decoded.
        pipeline = pipeline.shuffle(len(df), seed=SEED)

    def _to_example(name, target):
        pixels = decode_image(build_path(name))
        # `training` is a Python bool, so this branch is resolved at trace time.
        pixels = augment(pixels) if training else pixels
        return pixels, one_hot(target)

    pipeline = pipeline.map(_to_example, num_parallel_calls=tf.data.AUTOTUNE)
    pipeline = pipeline.batch(BATCH_SIZE)
    return pipeline.prefetch(tf.data.AUTOTUNE)

# Only the training split is shuffled and augmented; val/test keep the
# DataFrame row order.
train_ds, val_ds, test_ds = (
    make_dataset(split_df, training=is_training)
    for split_df, is_training in ((train_df, True), (val_df, False), (test_df, False))
)

train_ds, val_ds, test_ds


In [None]:
# CBAM attention + ResNet50
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import ResNet50


def channel_attention(inputs: tf.Tensor, reduction_ratio: int = 8) -> tf.Tensor:
    """CBAM channel attention: re-weight each channel of a feature map.

    The map is pooled over axes 1 and 2 with both mean and max. Both pooled
    descriptors go through the same two-layer MLP, the two MLP outputs are
    summed, and a single sigmoid turns the sum into a per-channel gate in
    (0, 1).

    Args:
        inputs: feature map with the channel dimension last; the channel
            count must be statically known.
        reduction_ratio: bottleneck factor of the shared MLP.

    Returns:
        ``inputs`` multiplied by the per-channel gate (same shape as ``inputs``).

    Raises:
        ValueError: if the channel dimension of ``inputs`` is unknown.
    """
    channels = inputs.shape[-1]
    if channels is None:
        raise ValueError("channel_attention requires a statically known channel dimension")
    # Never let the bottleneck collapse to zero units when channels < reduction_ratio.
    hidden_units = max(1, channels // reduction_ratio)

    avg_pool = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)
    max_pool = tf.reduce_max(inputs, axis=[1, 2], keepdims=True)

    # Shared MLP: the same two Dense layers are applied to both descriptors.
    shared = layers.Dense(hidden_units, activation="relu", use_bias=True)
    # The second layer is linear: the sigmoid must be applied AFTER the two
    # branches are summed. Applying it per branch and then adding would give
    # a gate in (0, 2) instead of (0, 1).
    shared2 = layers.Dense(channels, activation=None, use_bias=True)

    avg_out = shared2(shared(avg_pool))
    max_out = shared2(shared(max_pool))
    scale = layers.Activation("sigmoid")(avg_out + max_out)
    return inputs * scale


def spatial_attention(inputs: tf.Tensor, kernel_size: int = 7) -> tf.Tensor:
    """CBAM spatial attention: re-weight each spatial location of a feature map.

    The mean and max over the last (channel) axis are concatenated into a
    2-channel map, which a single sigmoid-activated convolution turns into a
    1-channel gate that multiplies ``inputs``.

    Args:
        inputs: feature map with the channel dimension last.
        kernel_size: kernel size of the gating convolution.

    Returns:
        ``inputs`` multiplied by the spatial gate (same shape as ``inputs``).
    """
    pooled = [
        tf.reduce_mean(inputs, axis=-1, keepdims=True),
        tf.reduce_max(inputs, axis=-1, keepdims=True),
    ]
    descriptor = tf.concat(pooled, axis=-1)
    gate_conv = layers.Conv2D(1, kernel_size=kernel_size, padding="same", activation="sigmoid")
    return inputs * gate_conv(descriptor)


def cbam_block(inputs: tf.Tensor, reduction_ratio: int = 8, kernel_size: int = 7) -> tf.Tensor:
    """Apply channel attention followed by spatial attention to ``inputs``.

    Args:
        inputs: feature map with the channel dimension last.
        reduction_ratio: forwarded to ``channel_attention``.
        kernel_size: forwarded to ``spatial_attention``.

    Returns:
        The attention-refined feature map.
    """
    channel_refined = channel_attention(inputs, reduction_ratio=reduction_ratio)
    return spatial_attention(channel_refined, kernel_size=kernel_size)


# ImageNet-pretrained ResNet50 backbone, fine-tuned end to end.
base = ResNet50(include_top=False, weights="imagenet", input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
base.trainable = True

inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))

# The input pipeline yields RGB floats in [0, 1] (convert_image_dtype + clip),
# while the ImageNet ResNet50 weights were trained on inputs run through
# resnet50.preprocess_input, which expects 0-255 pixel values. Rescale and
# apply it here so the model keeps accepting the pipeline's [0, 1] images.
x = tf.keras.applications.resnet50.preprocess_input(inputs * 255.0)

# Do NOT pass training=True here: that would pin the backbone's
# BatchNormalization layers to batch statistics even in evaluate()/predict(),
# making validation and test outputs depend on batch composition. Leaving
# the argument unset lets Keras supply the correct mode for fit vs. inference.
x = base(x)
x = cbam_block(x, reduction_ratio=8, kernel_size=7)

# Classification head.
x = layers.GlobalAveragePooling2D()(x)
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.4)(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(128, activation="relu")(x)
outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)

model = Model(inputs, outputs)
model.summary()


In [None]:
# Compile and train
# Metrics tracked each epoch. The `name` values are the keys used in the
# training logs, so "val_acc" below refers to CategoricalAccuracy on val_ds.
METRICS = [
    tf.keras.metrics.CategoricalAccuracy(name="acc"),
    tf.keras.metrics.AUC(name="auc", multi_label=False, num_thresholds=200),
    tfa.metrics.F1Score(num_classes=NUM_CLASSES, average="weighted", name="f1"),
]

adam = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=adam, loss="categorical_crossentropy", metrics=METRICS)

checkpoint_path = "/kaggle/working/resnet50_cbam_best.h5"

# Keep only the checkpoint with the highest validation accuracy.
best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, save_best_only=True, monitor="val_acc", mode="max"
)
# No `monitor` is passed, so the callback's default monitored quantity applies
# (not "val_acc" like the other two callbacks).
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    factor=0.5, patience=5, verbose=1, min_lr=1e-6
)
# Stop after 8 epochs without val_acc improvement and roll back to the best weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    patience=8, restore_best_weights=True, monitor="val_acc", mode="max"
)
callbacks = [best_checkpoint, lr_on_plateau, early_stop]

history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=callbacks)



In [None]:
# Reusable confusion matrix plotting + evaluation
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

def plot_confusion_matrix(cm: np.ndarray, class_names: list, title: str) -> None:
    """Draw a confusion matrix as an annotated heatmap and show it.

    Args:
        cm: square matrix of integer counts; rows are labelled "True" and
            columns "Predicted".
        class_names: tick labels, used for both axes.
        title: figure title.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

# Best value per logged quantity: minimum for losses, maximum for everything else.
hist = history.history
print({
    key: (float(np.min(values)) if 'loss' in key else float(np.max(values)))
    for key, values in hist.items()
})

# Evaluate on test
metrics = model.evaluate(test_ds, return_dict=True)
print("Test metrics:", metrics)

# Predictions for CM and report.
# Call the model directly in inference mode instead of model.predict() per
# batch: predict() is designed for whole datasets and adds per-call overhead
# when invoked inside a Python loop.
true_batches = []
prob_batches = []
for images, labels in test_ds:
    true_batches.append(np.argmax(labels.numpy(), axis=1))
    prob_batches.append(model(images, training=False).numpy())

y_true = np.concatenate(true_batches, axis=0)
y_prob = np.concatenate(prob_batches, axis=0)
y_pred = np.argmax(y_prob, axis=1)

# Pass the full label set explicitly so the matrix and report always have
# NUM_CLASSES rows in CLASS_SHORT_TO_FULL order, even if some class never
# appears in y_true / y_pred (otherwise the tick labels would not line up
# and classification_report would reject target_names).
class_indices = list(range(NUM_CLASSES))
class_names = [CLASS_SHORT_TO_FULL[k] for k in CLASS_SHORT_TO_FULL.keys()]

cm = confusion_matrix(y_true, y_pred, labels=class_indices)
plot_confusion_matrix(cm, class_names, title='Confusion Matrix (Test)')

report = classification_report(
    y_true, y_pred, labels=class_indices, target_names=class_names, digits=4, zero_division=0
)
print(report)

# roc_auc_score raises ValueError when a class is missing from y_true; only
# that expected failure is tolerated, anything else should surface.
try:
    auc_ovr = roc_auc_score(y_true, y_prob, multi_class='ovr', labels=class_indices)
    print("ROC-AUC (OvR):", auc_ovr)
except ValueError as e:
    print("ROC-AUC calculation skipped:", e)

# Plot training curves
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(14, 4))

acc_ax.plot(hist['acc'], label='train acc')
acc_ax.plot(hist['val_acc'], label='val acc')
acc_ax.set(title='Accuracy', xlabel='Epoch', ylabel='Accuracy')
acc_ax.legend()

loss_ax.plot(hist['loss'], label='train loss')
loss_ax.plot(hist['val_loss'], label='val loss')
loss_ax.set(title='Loss', xlabel='Epoch', ylabel='Loss')
loss_ax.legend()

plt.show()
