In [None]:
# Importing necessary libraries for ResNet synthetic data classification
import os
import shutil
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dropout, GlobalAveragePooling2D, Dense
from tensorflow.keras import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# Try to import tensorflow_addons; if it is unavailable, fall back to a local F1 metric
try:
    import tensorflow_addons as tfa
    from tensorflow_addons.metrics import F1Score
    HAS_TFA = True
    print("Using tensorflow_addons")
except ImportError:
    print("Warning: tensorflow_addons not available. Using alternative F1 score implementation.")
    HAS_TFA = False

    class F1Score(tf.keras.metrics.Metric):
        """Multi-class F1 score computed from per-class confusion counts.

        Both ``y_true`` and ``y_pred`` are reduced with ``argmax`` over axis 1,
        so they are expected to be one-hot labels / per-class scores.

        Args:
            num_classes: Number of classes; sizes the per-class counters.
            average: ``'weighted'`` (per-class F1 weighted by support),
                ``'macro'`` (unweighted mean), ``'micro'`` (global counts);
                any other value returns the per-class F1 vector.
            name: Metric name.
            **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
        """

        def __init__(self, num_classes, average='weighted', name='f1', **kwargs):
            super().__init__(name=name, **kwargs)
            self.num_classes = num_classes
            self.average = average
            # One counter per class. Binary Precision/Recall cannot be used here:
            # they interpret their inputs as 0/1 scores, not as class indices.
            self.true_positives = self.add_weight(
                name='true_positives', shape=(num_classes,), initializer='zeros')
            self.false_positives = self.add_weight(
                name='false_positives', shape=(num_classes,), initializer='zeros')
            self.false_negatives = self.add_weight(
                name='false_negatives', shape=(num_classes,), initializer='zeros')

        def update_state(self, y_true, y_pred, sample_weight=None):
            """Accumulate per-class TP / FP / FN counts for one batch."""
            true_hot = tf.one_hot(tf.argmax(y_true, axis=1), self.num_classes, dtype=tf.float32)
            pred_hot = tf.one_hot(tf.argmax(y_pred, axis=1), self.num_classes, dtype=tf.float32)
            hits = true_hot * pred_hot

            if sample_weight is None:
                weights = 1.0
            else:
                # Column vector so each sample's weight broadcasts across classes
                weights = tf.reshape(tf.cast(sample_weight, tf.float32), (-1, 1))

            self.true_positives.assign_add(tf.reduce_sum(hits * weights, axis=0))
            self.false_positives.assign_add(tf.reduce_sum((pred_hot - hits) * weights, axis=0))
            self.false_negatives.assign_add(tf.reduce_sum((true_hot - hits) * weights, axis=0))

        def result(self):
            """Return the F1 score aggregated according to ``self.average``."""
            eps = tf.keras.backend.epsilon()
            tp = self.true_positives
            fp = self.false_positives
            fn = self.false_negatives

            if self.average == 'micro':
                tp_total = tf.reduce_sum(tp)
                precision = tp_total / (tp_total + tf.reduce_sum(fp) + eps)
                recall = tp_total / (tp_total + tf.reduce_sum(fn) + eps)
                return 2 * precision * recall / (precision + recall + eps)

            precision = tp / (tp + fp + eps)
            recall = tp / (tp + fn + eps)
            per_class_f1 = 2 * precision * recall / (precision + recall + eps)

            if self.average == 'weighted':
                support = tp + fn  # number of true samples seen per class
                return tf.reduce_sum(per_class_f1 * support) / (tf.reduce_sum(support) + eps)
            if self.average == 'macro':
                return tf.reduce_mean(per_class_f1)
            return per_class_f1

        def reset_state(self):
            """Zero all counters (called by Keras between epochs / evaluations)."""
            for counter in (self.true_positives, self.false_positives, self.false_negatives):
                counter.assign(tf.zeros((self.num_classes,), dtype=tf.float32))

        def get_config(self):
            """Include constructor arguments so a saved model can rebuild the metric."""
            config = super().get_config()
            config.update({'num_classes': self.num_classes, 'average': self.average})
            return config

# Whichever implementation was selected, expose it for model deserialization
custom_objects = {'F1Score': F1Score}


## ResNet50 Synthetic Data Classification


In [None]:
# Size that every image is resized to before it is fed to the model
image_size = (224, 224)

# Class names; a predicted index is mapped to a name by position in this list
# (assumed to match the order used when the model was trained - TODO confirm)
class_labels = ['glaucoma', 'cataract', 'hyper', 'myopia', 'amd']

# Root folder that holds the images to classify
test_data_dir = '/content/gdrive/MyDrive/Ocular_Disease/sd_outputs'

# Make sure <test_data_dir>/resnet_results/<label> exists for every class label
results_root = os.path.join(test_data_dir, 'resnet_results')
for label in class_labels:
    label_dir = os.path.join(results_root, label)
    tf.io.gfile.makedirs(label_dir)


In [None]:
# Load the saved ResNet model (replace with your actual model path).
# If loading fails for any reason, fall back to building a fresh ResNet50 classifier.
try:
    with tf.keras.utils.custom_object_scope(custom_objects):
        model = load_model('./resnet_ocular_disease_model.h5')
    print("ResNet model loaded successfully!")
except Exception as load_error:  # `Exception` (not bare except) so Ctrl-C / SystemExit still propagate
    # Report the real cause: a missing file is only one of several possible failures
    print(f"Could not load saved model ({load_error!r}). Creating a new ResNet50 model...")
    print("Warning: the new model's classification head is untrained; "
          "train it before relying on its predictions.")

    num_classes = len(class_labels)

    # ImageNet-initialised backbone without its original classification head
    resnet = ResNet50(weights="imagenet", include_top=False, input_shape=image_size + (3,))

    # Leave the whole backbone trainable (full fine-tuning)
    for layer in resnet.layers:
        layer.trainable = True

    model = Sequential([
        resnet,
        Dropout(0.5),
        GlobalAveragePooling2D(),
        Dense(512, activation="relu"),
        Dense(256, activation="relu"),
        Dense(128, activation="relu"),
        Dense(num_classes, activation="softmax"),
    ])

    # Compile the model. `F1Score` is bound in the imports cell to either the
    # tensorflow_addons metric or the local fallback, so this works even when
    # `tfa` itself could not be imported.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss="categorical_crossentropy",
        metrics=[
            tf.keras.metrics.CategoricalAccuracy(name="acc"),
            F1Score(num_classes=num_classes, average="weighted", name="f1"),
            tf.keras.metrics.AUC(name="auc"),
        ]
    )


## Classification of Synthetic Data


In [None]:
# Classify every image in each class folder and copy it into the folder of its predicted class
for class_label in class_labels:
    class_dir = os.path.join(test_data_dir, class_label)

    # Keep regular files only (a stray sub-directory would make load_img fail);
    # sorted so the "first few" printed below are the same on every run
    test_images = sorted(
        entry for entry in os.listdir(class_dir)
        if os.path.isfile(os.path.join(class_dir, entry))
    )

    print(f"Processing {class_label} images...")

    for image_index, image_file in enumerate(test_images):
        # Load and resize the image, then add a leading batch dimension.
        # Named `pil_image` rather than `image` so it does not clash with the
        # `image` module imported in the confusion-matrix cell.
        image_path = os.path.join(class_dir, image_file)
        pil_image = load_img(image_path, target_size=image_size)
        image_array = img_to_array(pil_image)
        image_tensor = tf.expand_dims(image_array, 0)
        # NOTE(review): raw pixel values are passed to the model; confirm this
        # matches the preprocessing that was used when the model was trained.

        # verbose=0 avoids printing one progress bar per image
        outputs = model.predict(image_tensor, verbose=0)
        predicted = tf.argmax(outputs, axis=1)  # index of the highest-scoring class
        predicted_label = class_labels[predicted.numpy()[0]]

        # Print only the first few results per class to avoid flooding the output
        if len(test_images) <= 10 or image_index < 5:
            print(f'Image: {image_file}, Predicted class: {predicted_label}')

        # Copy (not move) the image into the folder of its predicted class.
        # NOTE(review): files with the same name in different source folders
        # overwrite each other if they receive the same predicted class.
        destination_dir = os.path.join(test_data_dir, 'resnet_results', predicted_label)
        destination_path = os.path.join(destination_dir, image_file)
        shutil.copyfile(image_path, destination_path)

print("Classification completed!")


## Confusion Matrix and Performance Analysis


In [None]:
from tensorflow.keras.preprocessing import image
from sklearn.metrics import confusion_matrix, classification_report
import itertools

# Directory that holds the test images listed in the labels CSV
test_dir = '/content/gdrive/MyDrive/Ocular_Disease/sd_outputs/try'

# CSV with one row per test image: an 'image' filename column and a 'true_label' column
dataset_file = '/content/gdrive/MyDrive/Ocular_Disease/true_labels.csv'
df = pd.read_csv(dataset_file)

image_files = df['image'].values      # image filenames, relative to test_dir
true_labels = df['true_label'].values  # presumably the same strings as class_labels - TODO confirm

predictions = []       # raw model outputs, one entry per image
predicted_labels = []  # predicted class names, one entry per image

print("Performing inference on test images...")

for i, image_file in enumerate(image_files):
    if i % 100 == 0:
        print(f"Processing image {i+1}/{len(image_files)}")

    # Use load_img / img_to_array from the imports cell instead of the `image`
    # module: the classification cell rebinds the name `image` to a PIL image,
    # so relying on it breaks when cells are re-run.
    image_path = os.path.join(test_dir, image_file)
    img = load_img(image_path, target_size=(224, 224))
    x = img_to_array(img)
    x = np.expand_dims(x, axis=0)  # add batch dimension

    # verbose=0 avoids printing one progress bar per image
    prediction = model.predict(x, verbose=0)
    predicted_label_index = np.argmax(prediction)  # position of the highest score
    predicted_label = class_labels[predicted_label_index]

    predictions.append(prediction)
    predicted_labels.append(predicted_label)

# Stack results; each raw prediction keeps the leading batch axis that predict() returned
predictions = np.array(predictions)
predicted_labels = np.array(predicted_labels)

print("Inference completed!")


In [None]:
# Define a function to plot the confusion matrix
def plot_confusion_matrix(cm, classes, normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on the current figure.

    Args:
        cm: 2-D array indexed as ``cm[true, predicted]`` (rows are drawn on
            the 'True label' axis, columns on the 'Predicted label' axis).
        classes: Tick labels, in the same order as the rows/columns of ``cm``.
        normalize: If True, each row is divided by its sum before plotting.
            Rows that sum to zero are shown as zeros instead of NaN.
        title: Plot title.
        cmap: Matplotlib colormap used for the heat map.
    """
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Divide only where the row has samples; empty rows stay at 0
        cm = np.divide(
            cm.astype('float'), row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'  # fractions when normalized, integer counts otherwise
    thresh = cm.max() / 2.  # cells above half the maximum get white text for contrast
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout pass last so the axis labels are taken into account
    plt.tight_layout()

# Compute the confusion matrix of true vs predicted labels.
# `labels=class_labels` pins the row/column order; without it scikit-learn sorts
# the labels alphabetically, which would not match the tick labels drawn below.
conf_mat = confusion_matrix(true_labels, predicted_labels, labels=class_labels)

# Plot the confusion matrix with ticks in the same order as its rows/columns
plt.figure(figsize=(10, 5))
plt.grid(False)
plot_confusion_matrix(conf_mat, classes=class_labels, normalize=False)
plt.title('ResNet50 Synthetic Data Classification - Confusion Matrix')
plt.show()


In [None]:
# Generate a classification report (precision, recall, F1-score per class).
# `labels` fixes the class order so that each entry of `target_names` is attached
# to the right class; by default scikit-learn orders the labels alphabetically.
report = classification_report(
    true_labels, predicted_labels,
    labels=class_labels,
    target_names=class_labels,
)

print("ResNet50 Synthetic Data Classification Report:")
print(report)
