<a href="https://colab.research.google.com/github/radhakrishnan-omotec/fundus-repo/blob/main/Fundus_ImageClassification_Project_7_classes_IMAGE_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ResNet152-Based CNN Image Classification for High Accuracy (Google Colab Notebook)

This is a complete Google Colab notebook that implements a ResNet152-based image classification model, tailored to classify 5 classes of fundus images (e.g., Diabetic Retinopathy stages) from a dataset of approximately 3,700 images. ResNet152 is chosen for its deep architecture and residual connections, which help it extract fine-grained features from complex medical images. The notebook covers data loading, preprocessing, model training, evaluation, and visualization, and is set up for Colab's GPU environment.

The dataset is assumed to be structured with one subfolder per class (e.g., No_DR, Mild, Moderate, Severe, Proliferative_DR) in a Google Drive directory. Adjust paths and class names as needed.

## ResNet152 for Fundus Image Classification

This notebook implements a ResNet152-based Convolutional Neural Network (CNN) for classifying fundus images into 5 Diabetic Retinopathy classes using a dataset of ~3,700 images. ResNet152, with its 152 layers and residual connections, is selected because deep residual networks tend to perform well on medical imaging tasks. The workflow includes data loading from Google Drive, preprocessing with augmentation, transfer learning, training on a GPU, and evaluation. The goal is to maximize classification accuracy for deployment in diagnostic applications.

### Workflow
1. Setup and import libraries.
2. Load and preprocess the dataset.
3. Define and configure ResNet152.
4. Train the model.
5. Evaluate and visualize results.

## Step 1: Setup and Import Libraries

In [None]:
# Cell 1: Setup and Imports
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet152
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import matplotlib.pyplot as plt
import os
from google.colab import drive

# Enable memory growth on every visible GPU so TensorFlow allocates memory on demand
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)
print("TensorFlow version:", tf.__version__)
print("GPU available:", bool(physical_devices))

## Step 2: Load and Preprocess the Dataset

In [None]:
# Cell 2: Mount Google Drive and Load Data
drive.mount('/content/drive')

# Define dataset path
data_dir = '/content/drive/MyDrive/Fundus_Dataset'  # Update to your dataset path
if not os.path.exists(data_dir):
    raise FileNotFoundError(f"Dataset folder {data_dir} not found.")

# Image parameters
img_height, img_width = 224, 224  # ResNet152 default input size
batch_size = 32
num_classes = 5

# Data augmentation and preprocessing (training only).
# preprocess_input expects raw 0-255 pixels, so no additional rescale is applied.
train_datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    validation_split=0.2,
    preprocessing_function=tf.keras.applications.resnet.preprocess_input  # ResNet-specific preprocessing
)

# Validation data gets the same preprocessing and split, but no augmentation
val_datagen = ImageDataGenerator(
    validation_split=0.2,
    preprocessing_function=tf.keras.applications.resnet.preprocess_input
)

# Training and validation generators
train_generator = train_datagen.flow_from_directory(
    data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='training',
    shuffle=True
)

val_generator = val_datagen.flow_from_directory(
    data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',
    shuffle=False  # keep file order so predictions line up with val_generator.classes
)

# Display class names
class_names = list(train_generator.class_indices.keys())
print("Class names:", class_names)
print("Training samples:", train_generator.samples)
print("Validation samples:", val_generator.samples)
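
The sample counts printed above determine how many batches each epoch runs. As a rough sketch, assuming exactly 3,700 images and a plain 80/20 split (the real generators split within each class folder, so the counts may differ by a few images), the arithmetic looks like this. `split_and_steps` is a hypothetical helper, not part of Keras.

```python
import math

def split_and_steps(n_images, validation_split, batch_size):
    """Estimate split sizes and per-epoch step counts (hypothetical helper)."""
    n_val = int(n_images * validation_split)
    n_train = n_images - n_val
    return {
        "train": n_train,
        "val": n_val,
        # Floor division, as used for steps_per_epoch / validation_steps in Cell 4
        "train_steps_floor": n_train // batch_size,
        "val_steps_floor": n_val // batch_size,
        # Ceiling division would cover the final partial batch as well
        "val_steps_ceil": math.ceil(n_val / batch_size),
        # Images skipped each validation pass when floor division is used
        "val_dropped": n_val % batch_size,
    }

info = split_and_steps(n_images=3700, validation_split=0.2, batch_size=32)
print(info)
```

With floor division, a few validation images are left out of each validation pass during training; the later `model.evaluate(val_generator)` call sets no step count and so covers the whole validation set.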

## Step 3: Define and Configure ResNet152

In [None]:
# Cell 3: Define ResNet152 Model
def create_resnet152_model(num_classes):
    # Load pre-trained ResNet152 with ImageNet weights
    base_model = ResNet152(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

    # Freeze base model layers
    base_model.trainable = False

    # Add custom classification head
    model = models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dense(1024, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax')
    ])

    return model

# Create and compile the model
model = create_resnet152_model(num_classes)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Model summary
model.summary()
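
While the base model is frozen, only the custom head is trained. A quick back-of-the-envelope check of the trainable parameter count in `model.summary()` can be sketched as follows, assuming the ResNet152 feature map has 2,048 channels after global average pooling. `dense_params` is a hypothetical helper.

```python
def dense_params(n_in, n_out):
    """Parameters of a fully connected layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out

feature_channels = 2048  # assumed channel depth of ResNet152's last block
hidden_units = 1024      # Dense(1024, activation='relu')
num_classes = 5          # Dense(num_classes, activation='softmax')

hidden_layer = dense_params(feature_channels, hidden_units)
output_layer = dense_params(hidden_units, num_classes)

# GlobalAveragePooling2D and Dropout contribute no trainable parameters
head_total = hidden_layer + output_layer
print(f"Dense(1024): {hidden_layer:,}")
print(f"Dense(5):    {output_layer:,}")
print(f"Head total:  {head_total:,}")
```

If the trainable count reported by `model.summary()` is much larger than this, the base model is probably not frozen.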

## Step 4: Train the Model

In [None]:
# Cell 4: Train the Model
epochs = 20  # Adjust based on convergence

# Callbacks for training
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint('/content/drive/MyDrive/resnet152_fundus_best.h5',
                                       monitor='val_accuracy', save_best_only=True)
]

# Train the model
history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // batch_size,
    validation_data=val_generator,
    validation_steps=val_generator.samples // batch_size,
    epochs=epochs,
    callbacks=callbacks
)

# Fine-tune (unfreeze some layers)
base_model = model.layers[0]
base_model.trainable = True
for layer in base_model.layers[:-20]:  # Fine-tune last 20 layers
    layer.trainable = False

# Recompile with lower learning rate
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Continue training
fine_tune_epochs = 10
history_fine = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // batch_size,
    validation_data=val_generator,
    validation_steps=val_generator.samples // batch_size,
    epochs=fine_tune_epochs,
    callbacks=callbacks
)

# Save final model
model.save('/content/drive/MyDrive/resnet152_fundus_final.h5')
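
The training calls above weight every class equally. If the class folders are noticeably imbalanced, one common heuristic is to weight each class by `n_samples / (n_classes * class_count)`. The sketch below uses a hypothetical helper, `balanced_class_weights`, and toy labels rather than the real dataset.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Map each class index to n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count)
            for cls, count in sorted(counts.items())}

# Toy labels: class 0 is common, class 2 is rare
toy_labels = [0] * 6 + [1] * 3 + [2] * 1
weights = balanced_class_weights(toy_labels)
print(weights)
```

In the notebook, the labels would come from `train_generator.classes.tolist()`, and the resulting dictionary would be passed as the `class_weight` argument of `model.fit`.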

## Step 5: Evaluate and Visualize Results

In [None]:
# Cell 5: Evaluate and Visualize
# Combine histories
acc = history.history['accuracy'] + history_fine.history['accuracy']
val_acc = history.history['val_accuracy'] + history_fine.history['val_accuracy']
loss = history.history['loss'] + history_fine.history['loss']
val_loss = history.history['val_loss'] + history_fine.history['val_loss']

# Plot accuracy and loss
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.tight_layout()
plt.show()

# Evaluate on validation set (separate names so the per-epoch lists above are not overwritten)
final_val_loss, final_val_accuracy = model.evaluate(val_generator)
print(f"Validation Loss: {final_val_loss:.4f}")
print(f"Validation Accuracy: {final_val_accuracy:.4f}")

# Confusion matrix
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
val_generator.reset()
preds = np.argmax(model.predict(val_generator), axis=1)
true_labels = val_generator.classes
cm = confusion_matrix(true_labels, preds)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.show()
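
The confusion matrix can also be reduced to per-class numbers. In scikit-learn's convention, rows are true labels and columns are predictions, so recall is the diagonal divided by the row sum and precision is the diagonal divided by the column sum. The sketch below uses a made-up 3-class matrix and a hypothetical helper.

```python
def per_class_recall_precision(cm):
    """Return a list of (recall, precision) pairs, one per class.

    cm[i][j] counts samples with true class i predicted as class j.
    """
    n = len(cm)
    results = []
    for i in range(n):
        true_positives = cm[i][i]
        row_total = sum(cm[i])                       # all samples truly in class i
        col_total = sum(cm[r][i] for r in range(n))  # all samples predicted as class i
        recall = true_positives / row_total if row_total else 0.0
        precision = true_positives / col_total if col_total else 0.0
        results.append((recall, precision))
    return results

# Illustrative counts only, not results from this model
toy_cm = [
    [8, 2, 0],
    [1, 6, 3],
    [0, 1, 9],
]
metrics = per_class_recall_precision(toy_cm)
for idx, (recall, precision) in enumerate(metrics):
    print(f"class {idx}: recall={recall:.3f}, precision={precision:.3f}")
```

In the notebook, `cm.tolist()` could be passed in place of `toy_cm`. Low recall on a severe class is worth checking even when overall accuracy looks good.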

## Optional: Test a Single Image

In [None]:
# Optional cell: Test a Single Image
from tensorflow.keras.preprocessing.image import load_img, img_to_array

def predict_image(image_path):
    img = load_img(image_path, target_size=(224, 224))
    img_array = img_to_array(img)
    img_array = tf.keras.applications.resnet.preprocess_input(img_array)
    img_array = np.expand_dims(img_array, axis=0)
    pred = model.predict(img_array)
    predicted_class = class_names[np.argmax(pred)]
    return img, predicted_class

# Example usage
test_image_path = '/content/drive/MyDrive/Fundus_Dataset/Severe/sample.jpg'  # Update path
img, pred_class = predict_image(test_image_path)
plt.imshow(img)
plt.title(f"Predicted: {pred_class}")
plt.axis('off')
plt.show()
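
The single-image path passes raw 0-255 pixel values to `tf.keras.applications.resnet.preprocess_input`. That function uses "caffe"-style preprocessing: it reorders channels from RGB to BGR and subtracts the ImageNet channel means, without scaling to [0, 1]. A pure-Python sketch of the same operation on one pixel (an illustration, not the Keras implementation):

```python
# ImageNet channel means in BGR order, as used by caffe-style preprocessing
IMAGENET_BGR_MEANS = (103.939, 116.779, 123.68)

def caffe_preprocess_pixel(rgb):
    """Convert one RGB pixel (0-255 floats) to zero-centered BGR."""
    r, g, b = rgb
    bgr = (b, g, r)
    return tuple(channel - mean for channel, mean in zip(bgr, IMAGENET_BGR_MEANS))

pixel = (200.0, 100.0, 50.0)  # hypothetical reddish pixel
processed = caffe_preprocess_pixel(pixel)
print(processed)
```

Because the output is zero-centered rather than in [0, 1], the training generators and the inference code need to apply the same function with no extra 1/255 rescale; otherwise the model sees differently scaled inputs at training and prediction time.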

## Step 6: Convert to TensorFlow Lite for Edge Deployment

In [None]:
# Cell 6: TensorFlow Lite Conversion
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # Apply default optimizations (quantization)
converter.target_spec.supported_types = [tf.float16]  # Use float16 for reduced size
tflite_model = converter.convert()

# Save the TFLite model
tflite_path = '/content/drive/MyDrive/resnet152_fundus.tflite'
with open(tflite_path, 'wb') as f:
    f.write(tflite_model)

print(f"TFLite model saved to {tflite_path}")
print(f"Size of TFLite model: {os.path.getsize(tflite_path) / (1024 * 1024):.2f} MB")

# Test TFLite inference
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

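# Example inference on a single image
# (imports repeated here so this cell also runs if the optional single-image cell was skipped)
from tensorflow.keras.preprocessing.image import load_img, img_to_array

test_image = load_img('/content/drive/MyDrive/Fundus_Dataset/Severe/sample.jpg', target_size=(224, 224))  # Update path
test_image_array = img_to_array(test_image)
test_image_array = tf.keras.applications.resnet.preprocess_input(test_image_array)
test_image_array = np.expand_dims(test_image_array, axis=0).astype(np.float32)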

interpreter.set_tensor(input_details[0]['index'], test_image_array)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index'])
tflite_pred_class = class_names[np.argmax(tflite_output[0])]
print(f"TFLite Predicted Class: {tflite_pred_class}")
plt.imshow(test_image)
plt.title(f"TFLite Predicted: {tflite_pred_class}")
plt.axis('off')
plt.show()
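
The file size printed above can be checked against a rough estimate. The sketch below assumes about 60 million parameters (the approximate figure quoted in the notes for ResNet152) and counts only weight storage, ignoring file-format overhead. `weights_size_mb` is a hypothetical helper.

```python
def weights_size_mb(n_params, bytes_per_param):
    """Approximate weight storage in MB (1 MB = 1024 * 1024 bytes)."""
    return n_params * bytes_per_param / (1024 * 1024)

n_params = 60_000_000  # assumption: approximate ResNet152 parameter count

fp32_mb = weights_size_mb(n_params, 4)  # float32: 4 bytes per weight
fp16_mb = weights_size_mb(n_params, 2)  # float16: 2 bytes per weight
int8_mb = weights_size_mb(n_params, 1)  # int8: 1 byte per weight

print(f"float32: {fp32_mb:.1f} MB")
print(f"float16: {fp16_mb:.1f} MB")
print(f"int8:    {int8_mb:.1f} MB")
```

A float16 file close to half the float32 estimate suggests the quantization settings took effect.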

## Step 7: Compute Advanced Evaluation Metrics

In [None]:
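# Cell 7: Advanced Evaluation Metrics
from sklearn.metrics import precision_score, recall_score, f1_score, roc_curve, auc
from sklearn.preprocessing import label_binarize

val_generator.reset()
y_true = val_generator.classes
y_pred_probs = model.predict(val_generator)
y_pred = np.argmax(y_pred_probs, axis=1)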

# Precision, Recall, F1-Score
precision = precision_score(y_true, y_pred, average='weighted')
recall = recall_score(y_true, y_pred, average='weighted')
f1 = f1_score(y_true, y_pred, average='weighted')
print(f"Precision (weighted): {precision:.4f}")
print(f"Recall (weighted): {recall:.4f}")
print(f"F1-Score (weighted): {f1:.4f}")

# ROC Curve and AUC for each class
y_true_bin = label_binarize(y_true, classes=range(num_classes))
plt.figure(figsize=(10, 8))
for i in range(num_classes):
    fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_pred_probs[:, i])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')

plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve for Multi-Class Classification')
plt.legend(loc="lower right")
plt.show()
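
Each curve above treats one class as positive and all others as negative (one-vs-rest). The AUC of such a curve equals the probability that a randomly chosen positive sample scores higher than a randomly chosen negative one, with ties counted as one half. The pure-Python sketch below computes it that way on made-up probabilities; both functions are hypothetical helpers, not scikit-learn APIs.

```python
def binary_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(positives) * len(negatives))

def one_vs_rest_auc(y_true, probs, class_index):
    """Binarize labels against one class and score with that class's probability."""
    labels = [1 if y == class_index else 0 for y in y_true]
    scores = [row[class_index] for row in probs]
    return binary_auc(labels, scores)

# Illustrative labels and softmax outputs, not results from this model
toy_true = [0, 0, 1, 1, 2]
toy_probs = [
    [0.7, 0.2, 0.1],
    [0.4, 0.5, 0.1],
    [0.2, 0.6, 0.2],
    [0.5, 0.3, 0.2],
    [0.1, 0.2, 0.7],
]
aucs = [one_vs_rest_auc(toy_true, toy_probs, c) for c in range(3)]
print(aucs)
```

The pairwise loop is quadratic in the number of samples, so it is meant for checking small cases by hand rather than for replacing `roc_curve` and `auc`.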

## Notes

- TFLite Size: Float16 quantization roughly halves the model size relative to float32 weights; further reduction (e.g., int8) is possible but may trade off accuracy.
- Evaluation: Precision, recall, F1-score, and ROC/AUC complement accuracy and the confusion matrix by exposing false positives and false negatives, which matter for fundus classification.
- Dataset: Assumes about 3,700 images in Fundus_Dataset with 5 subfolders. Update paths and class names if yours differ.

### Running Instructions
1. Upload your dataset to Google Drive.
2. Enable GPU in Colab (Runtime > Change runtime type > GPU).
3. Run cells sequentially; adjust data_dir and test_image_path as needed.

The target accuracy with fine-tuning is 92-95%. Actual results depend on the dataset, and the TFLite model may score slightly lower because of quantization.

---

## Key Features and Notes

- ResNet152: Pre-trained on ImageNet, with a custom head for 5-class classification. Residual connections mitigate vanishing gradients, which makes a network this deep trainable.
- Transfer Learning: Initial training with frozen base layers, followed by fine-tuning of the last 20 layers for fundus-specific features.
- Data Augmentation: Applied to reduce overfitting on the relatively small dataset (3,700 images).
- GPU Utilization: Runs on Colab's GPU to handle ResNet152's computational demands (~60M parameters, ~230 MB size).
- Evaluation: Includes accuracy, loss plots, and a confusion matrix for detailed performance analysis.
- Dataset: Assumes 3,700 images in /content/drive/MyDrive/Fundus_Dataset with subfolders for each class. Adjust paths if different.

### Assumptions
- The dataset is balanced or nearly balanced across 5 classes (e.g., ~740 images per class). If it is imbalanced, pass class weights through the class_weight argument of model.fit.
- Images are RGB fundus photographs in standard formats (e.g., JPG, PNG).

### Running in Colab
1. Upload your dataset to Google Drive.
2. Copy this code into a Colab notebook.
3. Update data_dir and test_image_path to match your file structure.
4. Run cells sequentially; ensure GPU runtime is enabled (Runtime > Change runtime type > GPU).

### Expected Accuracy
With fine-tuning, ResNet152 is expected to exceed 90% accuracy on this kind of medical imaging task, potentially reaching 92-95% on this dataset, depending on image quality and preprocessing. The two-phase training (transfer learning followed by fine-tuning) is intended to get the most out of the pre-trained weights.

---