In [None]:
import numpy as np
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import joblib 
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import (
    Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Input, Concatenate,
    GlobalAveragePooling2D, GlobalAveragePooling1D, BatchNormalization, Activation
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications import EfficientNetB4
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import transformers
import logging
logging.getLogger("transformers").setLevel(logging.ERROR)
from transformers import TFSwinModel


In [None]:
# --- 1. DATA LOADING AND PREPROCESSING ---
# Root folder of the dataset; load_images() expects one sub-directory per entry of `classes`.
dataset_path = "/kaggle/input/minor-dataset/Data_Minor"
# Class folder names. The position in this list is used as the integer label (HR=0, DR=1, RVO=2).
classes = ["HR", "DR", "RVO"]
# Target size handed to cv2.resize for every loaded image.
img_size = (224, 224)

def apply_clahe(image):
    """Enhance local contrast by running CLAHE on the lightness channel only.

    The image is converted from RGB to LAB, CLAHE (clip limit 2.0, 8x8 tile
    grid) is applied to the L channel, and the result is converted back to
    RGB. The a/b channels are passed through unchanged.

    Args:
        image: RGB image accepted by ``cv2.cvtColor`` with ``COLOR_RGB2LAB``.

    Returns:
        The contrast-enhanced image, converted back to RGB.
    """
    lightness, chroma_a, chroma_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2LAB))
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_lab = cv2.merge((equaliser.apply(lightness), chroma_a, chroma_b))
    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)

def adjust_gamma(image, gamma=1.0):
    """Apply power-law (gamma) correction through a 256-entry lookup table.

    Every level ``v`` in 0..255 is mapped to ``(v / 255) ** (1 / gamma) * 255``
    and truncated to ``uint8``; the table is then applied with ``cv2.LUT``.

    Args:
        image: Image to correct.
            # assumes an 8-bit image, since the table has 256 entries - TODO confirm
        gamma: Positive, finite gamma value. Values above 1 brighten the
            image, values below 1 darken it.

    Returns:
        The image returned by ``cv2.LUT`` after applying the table.

    Raises:
        ValueError: If ``gamma`` is zero, negative, infinite or NaN. These
            values previously produced a ZeroDivisionError or a meaningless
            lookup table.
    """
    if not np.isfinite(gamma) or gamma <= 0:
        raise ValueError(f"gamma must be a positive finite number, got {gamma!r}")

    inv_gamma = 1.0 / gamma
    # Build the whole table in one vectorised step instead of a Python loop.
    levels = np.arange(256, dtype=np.float64) / 255.0
    table = (levels ** inv_gamma * 255).astype("uint8")
    return cv2.LUT(image, table)

def adaptive_gamma_correction(image):
    """Choose a gamma from the image's mean brightness and apply it.

    The mean pixel value is scaled by 1/255. Images whose scaled mean is
    below 0.5 are corrected with gamma 1.2; all others use gamma 0.9.

    Args:
        image: Image forwarded to ``adjust_gamma``.

    Returns:
        The gamma-corrected image produced by ``adjust_gamma``.
    """
    brightness = np.mean(image) / 255.0
    if brightness < 0.5:
        chosen_gamma = 1.2
    else:
        chosen_gamma = 0.9
    return adjust_gamma(image, chosen_gamma)

def preprocess_image(image):
    """Run the full per-image pipeline: CLAHE, adaptive gamma, then scaling.

    Args:
        image: RGB image as accepted by ``apply_clahe``.

    Returns:
        A float32 array holding the processed image divided by 255.
    """
    contrast_enhanced = apply_clahe(image)
    gamma_corrected = adaptive_gamma_correction(contrast_enhanced)
    as_float = gamma_corrected.astype(np.float32)
    return as_float / 255.0

def load_images():
    """Load, resize and preprocess every image found under ``dataset_path``.

    One sub-directory per entry of ``classes`` is read in sorted filename
    order; the index of the class in ``classes`` becomes the integer label.
    Each image is converted from BGR to RGB, resized to ``img_size`` and
    passed through ``preprocess_image``.

    Files that cannot be decoded or processed are skipped (as before), but
    they are now counted and reported instead of being dropped silently.

    Returns:
        tuple: ``(images, labels)`` as NumPy arrays, where ``labels[i]`` is
        the class index of ``images[i]``.

    Raises:
        OSError: If a class directory does not exist or cannot be listed.
    """
    images, labels = [], []
    skipped = []  # (path, reason) for every file that could not be used

    for label, cls in enumerate(classes):
        class_dir = os.path.join(dataset_path, cls)
        for img_name in sorted(os.listdir(class_dir)):
            img_path = os.path.join(class_dir, img_name)
            try:
                image = cv2.imread(img_path)
                if image is None:
                    # cv2.imread reports an unreadable file by returning None
                    # rather than raising, so turn that into an explicit error.
                    raise ValueError("cv2.imread could not decode the file")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image = cv2.resize(image, img_size)
                image = preprocess_image(image)
            except Exception as exc:
                # Best-effort loading: keep going, but remember what was lost.
                skipped.append((img_path, repr(exc)))
                continue
            images.append(image)
            labels.append(label)

    if skipped:
        print(f"Warning: skipped {len(skipped)} file(s) that could not be loaded:")
        for path, reason in skipped[:10]:
            print(f"  {path}: {reason}")
        if len(skipped) > 10:
            print(f"  ... and {len(skipped) - 10} more")

    return np.array(images), np.array(labels)

images, labels = load_images()
# Fail fast with a clear message: an empty dataset would otherwise only
# surface later as an obscure error inside train_test_split.
if len(images) == 0:
    raise RuntimeError(
        f"No images could be loaded from {dataset_path}; "
        "check the dataset path and the class sub-directories."
    )
print(f"Dataset Loaded: {images.shape}, Labels: {labels.shape}")


In [None]:
# --- 2. 70/10/20 TRAIN/VAL/TEST SPLIT ---
# Settings shared by every model trained below.
num_classes = len(classes)
input_shape = (224, 224, 3)
batch_size = 8
epochs = 30

# Stage 1: hold out 20% of the data as the test set (stratified by class).
train_val_images, test_images, train_val_labels, test_labels = train_test_split(
    images,
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=42,
)

# Stage 2: 12.5% of the remaining 80% is 10% of the full dataset -> validation.
train_images, val_images, train_labels, val_labels = train_test_split(
    train_val_images,
    train_val_labels,
    test_size=0.125,
    stratify=train_val_labels,
    random_state=42,
)
print(f"Training: {len(train_images)}, Validation: {len(val_images)}, Test: {len(test_images)}")

# One-hot encode the integer labels of each split for categorical cross-entropy.
train_labels_cat, val_labels_cat, test_labels_cat = (
    to_categorical(split_labels, num_classes=num_classes)
    for split_labels in (train_labels, val_labels, test_labels)
)


In [None]:
# --- 3. MODEL DEFINITIONS ---
# Global switch: make every tf.function run eagerly instead of as a traced graph.
# NOTE(review): presumably required so the Hugging Face Swin model can be called
# from inside a Keras layer; it applies to ALL models built below and eager
# execution is typically slower - confirm it is still needed.
tf.config.run_functions_eagerly(True)

def build_cnn_feature_extractor(input_shape):
    """Build the small convolutional branch used as a feature extractor.

    The network is three Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with
    32, 64 and 128 filters, followed by Flatten, Dense(128, ReLU) and
    Dropout(0.5).

    Args:
        input_shape: Shape of one input image, given to the first Conv2D.

    Returns:
        A ``Sequential`` model named ``"cnn_feature_extractor"``.
    """
    stack = []
    for position, filters in enumerate((32, 64, 128)):
        # Only the very first layer declares the input shape.
        first_layer_kwargs = {"input_shape": input_shape} if position == 0 else {}
        stack.append(Conv2D(filters, (3, 3), activation='relu', **first_layer_kwargs))
        stack.append(MaxPooling2D(pool_size=(2, 2)))

    stack.extend([Flatten(), Dense(128, activation='relu'), Dropout(0.5)])
    return Sequential(stack, name="cnn_feature_extractor")

class SwinWrapper(tf.keras.layers.Layer):
    """Keras layer that wraps the pretrained Hugging Face Swin-Tiny backbone.

    The backbone ("microsoft/swin-tiny-patch4-window7-224") is loaded when the
    layer is built. ``call`` moves the last input axis to position 1 before
    invoking the backbone and returns its ``last_hidden_state``.

    NOTE(review): ``call`` does not forward a ``training`` flag to the
    backbone - confirm that this is intended.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Populated in build(), once the layer is first used.
        self.swin_model = None

    def build(self, input_shape):
        """Load the pretrained Swin weights, then mark the layer as built."""
        self.swin_model = TFSwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
        super().build(input_shape)

    def call(self, inputs):
        """Return the backbone's last hidden state for a batch of images."""
        # Axis order [0, 3, 1, 2]: the last axis (channels, given the
        # (224, 224, 3) inputs used in this notebook) moves to position 1.
        channels_first = tf.transpose(inputs, perm=[0, 3, 1, 2])
        backbone_output = self.swin_model(channels_first)
        return backbone_output.last_hidden_state

def build_swin_feature_extractor(input_shape):
    """Build the Swin branch: backbone output averaged into one feature vector.

    Args:
        input_shape: Shape of one input image.

    Returns:
        A functional ``Model`` named ``"swin_feature_extractor"`` whose output
        is the Swin hidden state after ``GlobalAveragePooling1D``.
    """
    image_in = Input(shape=input_shape)
    hidden_states = SwinWrapper()(image_in)
    # Collapse the sequence axis of the hidden state into a single vector.
    pooled = GlobalAveragePooling1D()(hidden_states)
    return Model(inputs=image_in, outputs=pooled, name="swin_feature_extractor")

def build_hybrid_model(input_shape, num_classes):
    """Build HyRetNet: a CNN branch and a Swin branch fused by concatenation.

    Both branches receive the same input image. Their feature vectors are
    concatenated and passed through Dense(128, ReLU), Dropout(0.5) and a
    softmax output layer.

    Args:
        input_shape: Shape of one input image.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled functional ``Model`` named ``"HyRetNet"``.
    """
    image_in = Input(shape=input_shape, name="input_image")

    # Branch 1: small convolutional network.
    cnn_branch = build_cnn_feature_extractor(input_shape)
    cnn_vector = cnn_branch(image_in)

    # Branch 2: pretrained Swin backbone with average pooling.
    swin_branch = build_swin_feature_extractor(input_shape)
    swin_vector = swin_branch(image_in)

    # Fusion and classification head.
    fused = Concatenate()([cnn_vector, swin_vector])
    hidden = Dense(128, activation='relu')(fused)
    hidden = Dropout(0.5)(hidden)
    class_probs = Dense(num_classes, activation='softmax', name='output_layer')(hidden)

    return Model(inputs=image_in, outputs=class_probs, name="HyRetNet")

def build_resnet(input_shape, num_classes):
    """Build a frozen ImageNet ResNet50 backbone with a small softmax head.

    The images in this notebook are floats in [0, 1] (``preprocess_image``
    divides by 255), whereas the ImageNet ResNet50 weights expect inputs that
    went through ``resnet50.preprocess_input`` on 0-255 pixel values. The
    rescaling and that preprocessing are therefore done inside the model, so
    callers keep feeding the same [0, 1] arrays.

    Args:
        input_shape: Shape of one input image.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled functional ``Model`` named ``"ResNet50"``.
    """
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
    base_model.trainable = False

    inputs = Input(shape=input_shape)
    # Undo the /255 scaling, then apply the preprocessing the pretrained
    # weights were trained with.
    x = tf.keras.layers.Rescaling(255.0)(inputs)
    x = tf.keras.applications.resnet50.preprocess_input(x)
    # training=False keeps the frozen backbone in inference mode.
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs, outputs, name="ResNet50")
    return model


def build_efficientnet(input_shape, num_classes):
    """Build an ImageNet EfficientNetB4 with its last 40 layers fine-tunable.

    Two fixes compared with the previous version:

    * The backbone is called with ``training=False``. Hard-coding
      ``training=True`` kept the unfrozen BatchNormalization/Dropout layers
      in training mode during ``evaluate``/``predict`` as well, which makes
      inference depend on the batch and on random dropout. The unfrozen
      weights are still updated by the optimizer.
    * The Keras EfficientNet models expect pixel values in the 0-255 range,
      while this notebook feeds images scaled to [0, 1]. The input is
      rescaled inside the model, so callers keep passing the same arrays.

    Args:
        input_shape: Shape of one input image.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled functional ``Model`` named ``"EfficientNetB4"``.
    """
    base_model = EfficientNetB4(weights='imagenet', include_top=False, input_shape=input_shape)
    base_model.trainable = True
    # Freeze everything except the last 40 layers of the backbone.
    for layer in base_model.layers[:-40]:
        layer.trainable = False

    inputs = Input(shape=input_shape)
    # Undo the /255 scaling applied during image loading.
    x = tf.keras.layers.Rescaling(255.0)(inputs)
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = BatchNormalization()(x)  # Added stability
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.4)(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs, outputs, name="EfficientNetB4")
    return model


# --- 4. EXPERIMENT EXECUTION ---
# Every model is trained once. Its metrics are cached in a joblib file, so
# re-running this cell skips models whose results already exist.

models_to_run = {
    "HyRetNet": build_hybrid_model,
    "ResNet50": build_resnet,
    "EfficientNetB4": build_efficientnet
}

results_dir = "/kaggle/working/model_results/"
os.makedirs(results_dir, exist_ok=True)

# Seed applied before each model is built, so a model's initialisation and
# shuffling do not depend on which other models were trained or skipped.
training_seed = 42

for model_name, model_builder in models_to_run.items():
    print("\n" + "="*50)
    print(f"--- Checking Model: {model_name} ---")
    print("="*50)

    result_file = os.path.join(results_dir, f"results_{model_name}.joblib")
    model_file = os.path.join(results_dir, f"best_model_{model_name}.h5")

    if os.path.exists(result_file):
        print(f"Results file found for {model_name}. Skipping training.")
        continue

    print(f"No results found. Training {model_name}...")

    # Release the Keras state of the previously trained model before building
    # the next one, so memory does not pile up across the three models.
    tf.keras.backend.clear_session()
    tf.keras.utils.set_random_seed(training_seed)

    model = model_builder(input_shape, num_classes)

    # The partially unfrozen EfficientNet is fine-tuned with a smaller step.
    lr = 0.00005 if "EfficientNet" in model_name else 0.0001

    model.compile(
        optimizer=Adam(learning_rate=lr),
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )

    # Keep only the weights of the epoch with the lowest validation loss.
    checkpoint = ModelCheckpoint(model_file, monitor='val_loss', save_best_only=True, mode='min', verbose=0)

    # Train
    history = model.fit(
        train_images, train_labels_cat,
        epochs=epochs,
        validation_data=(val_images, val_labels_cat),
        callbacks=[checkpoint],
        batch_size=batch_size,
        verbose=1
    )

    # Load best model for TEST evaluation
    custom_objs = {"SwinWrapper": SwinWrapper} if "Swin" in model_name or "HyRetNet" in model_name else {}
    best_model = tf.keras.models.load_model(model_file, custom_objects=custom_objs)

    print(f"\nEvaluating {model_name} on the TEST SET...")
    test_loss, test_accuracy = best_model.evaluate(test_images, test_labels_cat, verbose=1)

    print(f"\nFINAL TEST ACCURACY for {model_name}: {test_accuracy*100:.2f}%")

    predictions = best_model.predict(test_images)
    predicted_labels = np.argmax(predictions, axis=1)

    # Everything the plotting cell needs, stored so training is not repeated.
    model_result_data = {
        "history": history.history,
        "test_loss": test_loss,
        "test_accuracy": test_accuracy,
        "predicted_labels": predicted_labels,
        "report": classification_report(test_labels, predicted_labels, target_names=classes, output_dict=True),
        "params": model.count_params()
    }

    joblib.dump(model_result_data, result_file)
    print(f"--- {model_name} Complete. ---")




In [None]:
# --- 5. PLOTTING (CLEAN STYLE) ---
print("\n\n" + "="*60)
print("--- EXPERIMENT COMPLETE: LOADING ALL RESULTS ---")
print("="*60)

# Collect the cached results of every model that finished training.
results = {}
for model_name in models_to_run.keys():
    result_file = os.path.join(results_dir, f"results_{model_name}.joblib")
    if os.path.exists(result_file):
        results[model_name] = joblib.load(result_file)

# Plot everything: one figure per model with loss and accuracy side by side.
for model_name, data in results.items():
    history = data['history']

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Subplot 1: Loss
    loss_ax.plot(history['loss'], label='Train Loss')
    loss_ax.plot(history['val_loss'], label='Val Loss')
    loss_ax.set_title('Loss Curve', fontsize=14)
    loss_ax.set_ylabel('Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.legend()
    loss_ax.set_ylim(bottom=0)

    # Subplot 2: Accuracy
    acc_ax.plot(history['accuracy'], label='Train Accuracy')
    acc_ax.plot(history['val_accuracy'], label='Val Accuracy')
    acc_ax.set_title('Accuracy Curve', fontsize=14)
    acc_ax.set_ylabel('Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.legend()
    # Keep the 0.65 floor for a tight view, but lower it whenever a curve
    # dips below it so that no part of the history is cut off.
    accuracy_values = list(history['accuracy']) + list(history['val_accuracy'])
    acc_ax.set_ylim([min([0.65, *accuracy_values]), 1.0])

    fig.suptitle(f'Training History for {model_name}', fontsize=16, y=1.03)
    fig.tight_layout()
    plt.show()

if not results:
    print(f"No result files found in {results_dir}; run the training cell first.")
else:
    # Final Bar Chart
    summary_data = {
        "Model": list(results),
        "Test Accuracy": [data['test_accuracy'] for data in results.values()]
    }
    summary_df = pd.DataFrame(summary_data).sort_values(by="Test Accuracy", ascending=False)

    fig, ax = plt.subplots(figsize=(12, 7))
    bars = ax.bar(
        summary_df["Model"],
        summary_df["Test Accuracy"],
        color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'],
    )
    ax.set_ylabel('Test Accuracy Score')
    ax.set_title('FINAL Model Test Accuracy Comparison', fontsize=16)
    plt.setp(ax.get_xticklabels(), rotation=15)
    ax.bar_label(bars, fmt='%.4f')
    # Leave headroom above the tallest bar for its label.
    ax.set_ylim(top=ax.get_ylim()[1] * 1.05)
    plt.show()

    print("\n\n--- FINAL SUMMARY TABLE ---")
    try:
        print(summary_df.to_markdown(index=False))
    except ImportError:
        # DataFrame.to_markdown needs the optional 'tabulate' package.
        print(summary_df.to_string(index=False))