Classify Diabetic Retinopathy (DR) into five classes:

0 - No DR,
1 - Mild,
2 - Moderate,
3 - Severe,
4 - Proliferative DR

Create an updated .csv file that keeps only the rows whose `id_code` matches an image file name in the image folder; all other rows are removed.

In [None]:
import os
import pandas as pd

# Configuration
csv_file_path = '/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/trainLabels19.csv'  # Path to your CSV file
image_folder = '/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/ResizedTrain19-samples'  # Path to your image dataset
# Write the filtered CSV next to the source CSV: the training cell reads
# '<project folder>/updated_trainLabels19.csv', so saving it to the current
# working directory would leave training on a stale or missing file.
output_csv_path = os.path.join(os.path.dirname(csv_file_path), 'updated_trainLabels19.csv')

# Step 1: Load the CSV file.
# Force id_code to str so all-digit ids keep any leading zeros and compare
# equal to the (string) image file stems collected below.
df = pd.read_csv(csv_file_path, dtype={'id_code': str})

# Step 2: Collect the stems (file name without extension) of every image in
# the folder; the extension check is case-insensitive.
image_extensions = ('.png', '.jpg', '.jpeg')
image_files = {
    os.path.splitext(name)[0]
    for name in os.listdir(image_folder)
    if name.lower().endswith(image_extensions)
}

# Step 3: Keep only rows whose id_code has a matching image.
filtered_df = df[df['id_code'].astype(str).isin(image_files)]

# Step 4: Save the filtered DataFrame to a new CSV.
filtered_df.to_csv(output_csv_path, index=False)

print(f"Original rows: {len(df)}, Filtered rows: {len(filtered_df)}")
print(f"Updated CSV saved to: {output_csv_path}")

Original rows: 3662, Filtered rows: 728
Updated CSV saved to: updated_trainLabels19.csv


The following script covers:
(i) GAN-based OCT-A generation

(ii) Training ResNet50Fusion model

(iii) Saving the model

(iv) Running inference on test data

(v) Saving predictions to CSV

In [None]:
import os
import pandas as pd
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import tensorflow as tf

# ----------------------------
# Load BVAC GAN generator
# ----------------------------
class InstanceNormalization(tf.keras.layers.Layer):
    """Per-sample, per-channel normalization over axes 1 and 2.

    For every sample and channel, the values are normalized to zero mean and
    unit variance over axes 1 and 2. The result is then multiplied by a
    learnable per-channel ``scale`` (initialised to ones) and shifted by a
    learnable per-channel ``offset`` (initialised to zeros). The weight shape
    uses the last input dimension, so inputs are channels-last.
    NOTE(review): presumably (batch, H, W, C) image tensors; confirm against
    the saved generator.

    Args:
        epsilon: small constant added to the variance to avoid division by zero.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        channels = input_shape[-1]
        self.scale = self.add_weight(name="scale", shape=(channels,), initializer="ones", trainable=True)
        self.offset = self.add_weight(name="offset", shape=(channels,), initializer="zeros", trainable=True)

    def call(self, inputs):
        mean, var = tf.nn.moments(inputs, [1, 2], keepdims=True)
        normalized = (inputs - mean) / tf.sqrt(var + self.epsilon)
        return self.scale * normalized + self.offset

    def get_config(self):
        """Return the constructor config so Keras can save and re-create this layer."""
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

class ThresholdSEBlock(tf.keras.layers.Layer):
    """Squeeze-and-excitation style channel gating with a hard threshold.

    Processing steps:

    1. Global average pooling reduces the input to one value per channel.
    2. Dense(channels // reduction, relu) then Dense(channels, sigmoid)
       produce per-channel gates in (0, 1).
    3. Gates that are not strictly greater than ``threshold`` are set to
       zero.
    4. The input is multiplied by the gates, broadcast over axes 1 and 2.

    The reshape to ``[-1, 1, 1, channels]`` means the input is expected to be
    a channels-last 4-D tensor whose last dimension equals ``channels``.

    Args:
        channels: number of channels in the input (and of the gate vector).
        reduction: bottleneck factor for the first Dense layer.
        threshold: gates at or below this value are zeroed.
    """

    def __init__(self, channels, reduction=16, threshold=0.5, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.reduction = reduction
        self.threshold = threshold
        self.global_avg_pool = tf.keras.layers.GlobalAveragePooling2D()
        self.fc1 = tf.keras.layers.Dense(channels // reduction, activation='relu')
        self.fc2 = tf.keras.layers.Dense(channels, activation='sigmoid')

    def call(self, inputs):
        x = self.global_avg_pool(inputs)
        x = self.fc1(x)
        x = self.fc2(x)
        # Hard gating: weak channel responses are switched off entirely.
        x = tf.where(x > self.threshold, x, tf.zeros_like(x))
        x = tf.reshape(x, [-1, 1, 1, self.channels])
        return inputs * x

    def get_config(self):
        """Return the constructor config so Keras can save and re-create this layer.

        ``channels`` is a required constructor argument, so without this the
        layer cannot be rebuilt from a saved model config.
        """
        config = super().get_config()
        config.update({
            "channels": self.channels,
            "reduction": self.reduction,
            "threshold": self.threshold,
        })
        return config

# Load the pretrained fundus -> OCT-A generator once, at import time.
# The custom layer classes defined above are registered by name so Keras can
# rebuild them when deserializing the .h5 file.
_GENERATOR_PATH = '/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/generator_g.h5'
_GENERATOR_CUSTOM_OBJECTS = {
    'InstanceNormalization': InstanceNormalization,
    'ThresholdSEBlock': ThresholdSEBlock,
}
generator = tf.keras.models.load_model(_GENERATOR_PATH, custom_objects=_GENERATOR_CUSTOM_OBJECTS)

# ----------------------------
# Image Pre/Post-Processing
# ----------------------------
def preprocess_tf(img_path, size=(256, 256)):
    """Load an image file as a batched float32 array scaled to [-1, 1].

    The image is converted to RGB and resized to ``size`` (width, height).
    Pixel values are mapped from [0, 255] to [-1, 1] via ``x / 127.5 - 1``,
    and a leading batch axis is added. With the default size the result has
    shape (1, 256, 256, 3).

    Args:
        img_path: path to the input image file.
        size: (width, height) passed to ``PIL.Image.resize``.

    Returns:
        numpy.ndarray of dtype float32 with shape (1, height, width, 3).
    """
    # Context manager closes the underlying file handle; otherwise one handle
    # stays open per processed image until garbage collection.
    with Image.open(img_path) as img:
        rgb = img.convert("RGB").resize(size)
    arr = np.asarray(rgb, dtype=np.float32) / 127.5 - 1.0
    return np.expand_dims(arr, 0)

def postprocess_tf(tensor):
    """Turn a batched generator output in [-1, 1] into an 8-bit PIL image.

    Only the first element of the batch is used. Values are mapped from
    [-1, 1] to [0, 255], clipped to that range and cast to uint8.
    """
    first_sample = tensor[0]
    rescaled = (first_sample * 0.5 + 0.5) * 255.0
    pixels = np.clip(rescaled, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)

def generate_synthetic_oct(images_dir, synthetic_dir, overwrite=True):
    """Translate every image in ``images_dir`` with the GAN ``generator``.

    Each .png/.jpg/.jpeg file (extension check is case-insensitive) goes
    through three steps:

    1. It is preprocessed with ``preprocess_tf``.
    2. It is passed through ``generator``.
    3. The result is saved to ``synthetic_dir`` under the same file name.

    Keeping the same name is how the dataset classes pair each fundus image
    with its synthetic counterpart.

    Args:
        images_dir: directory containing the input images.
        synthetic_dir: output directory; created if it does not exist.
        overwrite: when False, inputs whose output file already exists are
            skipped, avoiding a full GAN pass on every run.
    """
    os.makedirs(synthetic_dir, exist_ok=True)
    # sorted() gives a deterministic processing order across runs.
    for img_name in sorted(os.listdir(images_dir)):
        if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        out_path = os.path.join(synthetic_dir, img_name)
        if not overwrite and os.path.exists(out_path):
            continue
        input_tensor = preprocess_tf(os.path.join(images_dir, img_name))
        # Call the model directly instead of generator.predict(). predict() is
        # meant for large batched inputs; invoked once per single image, its
        # setup overhead (and per-call progress bar) dominates the runtime.
        output_tensor = generator(input_tensor, training=False).numpy()
        postprocess_tf(output_tensor).save(out_path)

# ----------------------------
# Dataset Classes
# ----------------------------
class FusedDataset(Dataset):
    """Labelled dataset pairing each fundus image with its synthetic OCT-A image.

    Rows come from ``csv_file``, which must contain ``id_code`` and
    ``diagnosis`` columns. For each row, the files ``<id_code><ext>`` are
    looked up in both ``fundus_dir`` and ``synthetic_dir``; the first
    extension present in both directories is used. Each image is converted to
    RGB, resized to 224x224, turned into a tensor and normalized with the
    ImageNet mean/std. The two 3-channel tensors are then concatenated along
    the channel axis, so each item is a (6, 224, 224) tensor plus an int label.
    """

    # Extensions tried in order; the first one present in BOTH directories wins.
    _EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, csv_file, fundus_dir, synthetic_dir):
        # Read id_code as str so all-digit ids keep leading zeros and can be
        # concatenated with a file extension.
        self.data = pd.read_csv(csv_file, dtype={'id_code': str})
        self.fundus_dir = fundus_dir
        self.synthetic_dir = synthetic_dir
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def _load_rgb(self, path):
        """Open ``path`` as RGB, apply ``self.transform``, and close the file."""
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def _load_fused(self, img_id):
        """Return the channel-concatenated fundus+synthetic tensor for ``img_id``.

        Raises:
            FileNotFoundError: if no extension exists in both directories.
        """
        for ext in self._EXTENSIONS:
            fp = os.path.join(self.fundus_dir, img_id + ext)
            sp = os.path.join(self.synthetic_dir, img_id + ext)
            if os.path.exists(fp) and os.path.exists(sp):
                return torch.cat((self._load_rgb(fp), self._load_rgb(sp)), dim=0)
        raise FileNotFoundError(f"{img_id} not found.")

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # str() guards against numeric ids (e.g. a DataFrame not built by
        # __init__), for which `id + ext` would raise TypeError.
        img_id = str(row['id_code'])
        label = int(row['diagnosis'])
        return self._load_fused(img_id), label

class TestDataset(Dataset):
    """Unlabelled dataset pairing each fundus image with its synthetic OCT-A image.

    Same loading and preprocessing as the labelled training dataset. Rows come
    from ``csv_file`` (only ``id_code`` is read). For each row, the files
    ``<id_code><ext>`` are looked up in both ``fundus_dir`` and
    ``synthetic_dir``; the first extension present in both is used. Each image
    is converted to RGB, resized to 224x224, turned into a tensor and
    normalized with the ImageNet mean/std. The two tensors are concatenated
    along the channel axis, and each item is ``(fused_tensor, img_id)``.
    """

    # Extensions tried in order; the first one present in BOTH directories wins.
    _EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, csv_file, fundus_dir, synthetic_dir):
        # Read id_code as str so all-digit ids keep leading zeros and can be
        # concatenated with a file extension.
        self.data = pd.read_csv(csv_file, dtype={'id_code': str})
        self.fundus_dir = fundus_dir
        self.synthetic_dir = synthetic_dir
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def _load_rgb(self, path):
        """Open ``path`` as RGB, apply ``self.transform``, and close the file."""
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def _load_fused(self, img_id):
        """Return the channel-concatenated fundus+synthetic tensor for ``img_id``.

        Raises:
            FileNotFoundError: if no extension exists in both directories.
        """
        for ext in self._EXTENSIONS:
            fp = os.path.join(self.fundus_dir, img_id + ext)
            sp = os.path.join(self.synthetic_dir, img_id + ext)
            if os.path.exists(fp) and os.path.exists(sp):
                return torch.cat((self._load_rgb(fp), self._load_rgb(sp)), dim=0)
        raise FileNotFoundError(f"{img_id} not found.")

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # str() guards against numeric ids, for which `id + ext` would raise TypeError.
        img_id = str(row['id_code'])
        return self._load_fused(img_id), img_id

# ----------------------------
# Model Definition
# ----------------------------
class ResNet50Fusion(nn.Module):
    """ResNet-50 classifier taking a 6-channel (fundus + synthetic) input.

    Construction steps:

    1. Build torchvision's ResNet-50 with its default pretrained weights.
    2. Replace the stem convolution with a 6-input-channel 7x7 conv. Input
       channels 0-2 and 3-5 are both initialised with the pretrained RGB
       stem kernel.
    3. Replace the final fully-connected layer with a fresh
       ``num_classes``-way linear head.

    The stem conv is registered twice, as ``self.conv1`` and as
    ``self.model.conv1`` (the same module object). The state_dict therefore
    contains both ``conv1.weight`` and ``model.conv1.weight``; saved
    checkpoints rely on these keys.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        rgb_kernel = backbone.conv1.weight.data
        self.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Duplicate the pretrained RGB kernel into both halves of the input channels.
        with torch.no_grad():
            self.conv1.weight[:, :3].copy_(rgb_kernel)
            self.conv1.weight[:, 3:].copy_(rgb_kernel)
        backbone.conv1 = self.conv1
        backbone.fc = nn.Linear(2048, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input ``x``."""
        logits = self.model(x)
        return logits

# ----------------------------
# Training Function
# ----------------------------
def train_model(
    *,
    fundus_dir="/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/ResizedTrain19-samples",
    synthetic_dir="/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/synthetic_octa",
    csv_file="/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/updated_trainLabels19.csv",
    epochs=10,
    batch_size=16,
    lr=1e-4,
    model_path="resnet50_fused_model.pth",
):
    """Generate synthetic OCT-A images, train ``ResNet50Fusion``, and save its weights.

    Steps:

    1. Run the GAN over ``fundus_dir`` into ``synthetic_dir``.
    2. Build a ``FusedDataset`` from ``csv_file``.
    3. Train with Adam and cross-entropy for ``epochs`` epochs, printing the
       mean per-sample loss and training accuracy after each epoch.
    4. Save the model ``state_dict`` to ``model_path``.

    All arguments are keyword-only and default to the original hard-coded values.

    Raises:
        ValueError: if the CSV yields no training rows.
    """
    generate_synthetic_oct(fundus_dir, synthetic_dir)

    dataset = FusedDataset(csv_file, fundus_dir, synthetic_dir)
    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError(f"No training rows found in {csv_file}")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet50Fusion(num_classes=5).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss, correct = 0.0, 0
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss returns a batch mean; weight by batch size so the
            # epoch figure is a true per-sample mean (the last batch may be smaller).
            total_loss += loss.item() * inputs.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / num_samples:.4f}, Acc: {correct / num_samples:.4f}")

    torch.save(model.state_dict(), model_path)
    print("Model saved.")

# ----------------------------
# Inference Function
# ----------------------------
def predict(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect ``(id, predicted_class)`` pairs.

    The model is switched to eval mode and run without gradient tracking.
    Each batch from ``dataloader`` must be ``(inputs, ids)``. The prediction
    for each row is the argmax over dimension 1 of the model output, returned
    as a NumPy integer.

    Returns:
        list of ``(id, predicted_class)`` tuples, in dataloader order.
    """
    model.eval()
    collected = []
    with torch.no_grad():
        for batch, ids in dataloader:
            logits = model(batch.to(device))
            classes = logits.argmax(dim=1).cpu().numpy()
            collected.extend(zip(ids, classes))
    return collected

def run_inference(
    *,
    test_csv="/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/trainLabels19_test.csv",
    fundus_dir="/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/ResizedTrain19_test",
    synthetic_dir="/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/synthetic_test_octa",
    model_path="resnet50_fused_model.pth",
    output_csv="DR_predictions.csv",
    batch_size=16,
):
    """Predict DR grades for the test set and write them to a CSV file.

    Steps:

    1. Generate synthetic OCT-A images for ``fundus_dir``.
    2. Load the weights from ``model_path`` into a fresh ``ResNet50Fusion``.
    3. Run ``predict`` over a ``TestDataset`` built from ``test_csv``.
    4. Write ``id_code, predicted_label`` rows to ``output_csv``.

    All arguments are keyword-only and default to the original hard-coded values.
    """
    generate_synthetic_oct(fundus_dir, synthetic_dir)

    test_dataset = TestDataset(test_csv, fundus_dir, synthetic_dir)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet50Fusion(num_classes=5).to(device)
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only
    # runtime; without it torch.load tries to restore tensors onto CUDA.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    predictions = predict(model, test_loader, device)
    pd.DataFrame(predictions, columns=["id_code", "predicted_label"]).to_csv(output_csv, index=False)
    print(f"Predictions saved to {output_csv}")

# ----------------------------
# Run Training and Inference
# ----------------------------
if __name__ == "__main__":
    # Order matters: run_inference loads the checkpoint file that train_model writes.
    for stage in (train_model, run_inference):
        stage()
