In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageEnhance, ImageOps
import pandas as pd
import numpy as np
import os
import cv2
from sklearn.model_selection import train_test_split
from keras._tf_keras.keras.utils import to_categorical
import matplotlib.pyplot as plt

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        
    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        out = F.relu(out)
        return out

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),  # Changed back to 64 channels
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, padding=1),  # Changed back to 128 channels
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )
        
        self.residual_blocks = nn.Sequential(
            ResidualBlock(128),  # Changed back to 128 channels
            ResidualBlock(128),
            ResidualBlock(128)   # Added back the third residual block
        )
        
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, 3, padding=1),  # Changed to match checkpoint
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),   # Changed to match checkpoint
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 3, padding=1),    # Changed to match checkpoint
            nn.Tanh()
        )
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.residual_blocks(x)
        x = self.decoder(x)
        return x

In [3]:
class EnhancedDRDataset(Dataset):
    def __init__(self, df, base_image_dir, label_mapping, generator_model, device, img_size=(224, 224)):
        self.df = df
        self.base_image_dir = base_image_dir
        self.label_mapping = label_mapping
        self.img_size = img_size
        self.generator = generator_model
        self.device = device
        
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
        
    def enhance_image(self, image):
        # Convert to PIL Image if it's not already
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
            
        # Apply GAN enhancement
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            enhanced = self.generator(img_tensor)
        
        # Convert back to image
        enhanced = enhanced.squeeze(0).cpu()
        enhanced = enhanced * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + \
                  torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        enhanced = enhanced.clamp(0, 1)
        enhanced = transforms.ToPILImage()(enhanced)
        
        # Adjust enhancement parameters
        enhanced = ImageEnhance.Brightness(enhanced).enhance(1.1)  # Reduced enhancement
        enhanced = ImageEnhance.Contrast(enhanced).enhance(1.6)     # Reduced enhancement
        enhanced = ImageEnhance.Sharpness(enhanced).enhance(1.7)    # Reduced enhancement

        return enhanced
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        label = row['diagnosis']
        
        # Find the image in the appropriate directory
        for label_name, label_num in self.label_mapping.items():
            img_dir = os.path.join(self.base_image_dir, label_name)
            img_path = os.path.join(img_dir, row['id_code'] + '.png')
            if os.path.exists(img_path):
                # Load and enhance the image
                image = Image.open(img_path).convert('RGB')
                enhanced_image = self.enhance_image(image)
                
                # Convert to numpy array and normalize
                enhanced_array = np.array(enhanced_image)
                enhanced_array = enhanced_array / 255.0
                
                return enhanced_array, label
                
        raise FileNotFoundError(f"Image not found for id_code: {row['id_code']}")

In [4]:
def preprocess_data_with_gan(df, base_image_dir, label_mapping, generator_model, device, img_size=(224, 224)):
    images = []
    labels = []
    
    # Create an instance of EnhancedDRDataset for image enhancement
    dataset = EnhancedDRDataset(df, base_image_dir, label_mapping, generator_model, device)
    
    for index, row in df.iterrows():
        for label_name, label_num in label_mapping.items():
            img_dir = os.path.join(base_image_dir, label_name)
            img_path = os.path.join(img_dir, row['id_code'] + '.png')
            if os.path.exists(img_path):
                # Get enhanced image using GAN
                image, _ = dataset[index]
                
                # Convert to cv2 format and resize
                image = (image * 255).astype(np.uint8)
                image = cv2.resize(image, img_size)
                
                images.append(image)
                labels.append(label_num)

    images = np.array(images)
    labels = np.array(labels)
    labels = to_categorical(labels, num_classes=5)

    return images, labels

In [5]:
from keras._tf_keras.keras.models import Sequential, Model
from keras._tf_keras.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout
from keras._tf_keras.keras.layers import (
    Conv2D, 
    MaxPooling2D, 
    Dense, 
    Flatten, 
    Dropout,
    BatchNormalization
)
from keras._tf_keras.keras.models import Sequential

def create_cnn_model(input_shape=(224, 224, 3)):
    model = Sequential([
        # First Convolutional Block
        Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=input_shape),
        Conv2D(32, (3, 3), activation='relu', padding='same'),
        MaxPooling2D((2, 2)),
        
        # Second Convolutional Block
        Conv2D(64, (3, 3), activation='relu', padding='same'),
        Conv2D(64, (3, 3), activation='relu', padding='same'),
        MaxPooling2D((2, 2)),
        
        # Third Convolutional Block
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        MaxPooling2D((2, 2)),
        
        # Dense layers
        Flatten(),
        Dense(256, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
        Dense(128, activation='relu'),
        BatchNormalization(),
        Dropout(0.3),
        Dense(5, activation='softmax')
    ])
    
    return model

In [10]:
import tensorflow as tf
def train_combined_models(df, base_image_dir, label_mapping, gan_model_path='enhanced_gan_models.pth'):
    # Set up device for GAN
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load and initialize the GAN generator
    generator = Generator().to(device)
    checkpoint = torch.load(gan_model_path, map_location=device)
    generator.load_state_dict(checkpoint['model_state_dict'])
    generator.eval()
    
    # Preprocess data with GAN enhancement
    print("Preprocessing data with GAN enhancement...")
    images, labels = preprocess_data_with_gan(df, base_image_dir, label_mapping, generator, device)
    
    # Split the data
    X_train, X_val, y_train, y_val = train_test_split(images, labels, test_size=0.2, random_state=42)
    
    # Create and compile CNN model
    cnn_model = create_cnn_model()
    cnn_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),  # Lower learning rate
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    
    # Add data augmentation
    data_augmentation = tf.keras.Sequential([
        tf.keras.layers.RandomRotation(0.1),
        tf.keras.layers.RandomZoom(0.1),
        tf.keras.layers.RandomFlip("horizontal")
    ])
    
    # Train with more epochs and adjusted batch size
    history = cnn_model.fit(
        data_augmentation(X_train), y_train,
        validation_data=(X_val, y_val),
        epochs=10,  # Increased epochs
        # batch_size=32,  # Adjusted batch size
    )
    
    return cnn_model, history

In [11]:
def plot_training_history(history):
    plt.figure(figsize=(12, 4))
    
    # Plot accuracy
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'], label='Training Accuracy')
    plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    # Plot loss
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title('Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.tight_layout()
    plt.show()

In [12]:
# Load your data
data_path = 'FGADR/train.csv'
df = pd.read_csv(data_path)
base_image_dir = 'FGADR'

# Define label mapping
label_mapping = {
    '0': 0,
    '1': 1,
    '2': 2,
    '3': 3,
    '4': 4
}

In [13]:

# Train the models
cnn_model, history = train_combined_models(
    df=df,
    base_image_dir=base_image_dir,
    label_mapping=label_mapping,
    gan_model_path='enhanced_gan_models.pth'
)

Preprocessing data with GAN enhancement...


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/10
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m184s[0m 2s/step - accuracy: 0.3536 - loss: 1.9339 - val_accuracy: 0.6385 - val_loss: 1.4845
Epoch 2/10
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m167s[0m 2s/step - accuracy: 0.5362 - loss: 1.3354 - val_accuracy: 0.6739 - val_loss: 1.3858
Epoch 3/10
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m171s[0m 2s/step - accuracy: 0.5573 - loss: 1.2449 - val_accuracy: 0.5375 - val_loss: 1.6564
Epoch 4/10
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m167s[0m 2s/step - accuracy: 0.5976 - loss: 1.1237 - val_accuracy: 0.6971 - val_loss: 0.9224
Epoch 5/10
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m168s[0m 2s/step - accuracy: 0.6602 - loss: 0.9905 - val_accuracy: 0.4379 - val_loss: 1.5341
Epoch 6/10
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m165s[0m 2s/step - accuracy: 0.6728 - loss: 0.9235 - val_accuracy: 0.6398 - val_loss: 0.9291
Epoch 7/10
[1m92/92[0m [32m━━━━