# In [None]:
import torch
import torch.nn as nn
import timm
import numpy as np
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2
import random

def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's RNGs so runs are repeatable.

    When CUDA is available, every GPU generator is seeded too, and cuDNN is
    switched to deterministic kernels with its benchmark auto-tuner disabled.

    Args:
        seed (int): Seed value applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class EfficientNetModel(nn.Module):
    """EfficientNet-B0 backbone (from timm) topped with a deeper MLP head.

    The backbone is created with ``pretrained=False``. Its stock ``classifier``
    layer is replaced by three Linear -> ReLU -> Dropout blocks
    (1024 -> 512 -> 256 units), followed by a final Linear layer that emits
    ``num_classes`` logits.
    """

    def __init__(self, num_classes=5):
        """
        Build the backbone and attach the custom classification head.

        Args:
            num_classes (int): Number of logits produced by the final layer.
        """
        super().__init__()
        backbone = timm.create_model('efficientnet_b0', pretrained=False)

        # (hidden width, dropout probability) per hidden block. The order fixes
        # the module indices inside nn.Sequential, which are the state_dict
        # keys saved checkpoints are matched against.
        head_spec = ((1024, 0.3), (512, 0.3), (256, 0.2))

        layers = []
        width = backbone.classifier.in_features
        for hidden, drop in head_spec:
            layers += [nn.Linear(width, hidden), nn.ReLU(), nn.Dropout(drop)]
            width = hidden
        layers.append(nn.Linear(width, num_classes))

        backbone.classifier = nn.Sequential(*layers)
        self.model = backbone

    def forward(self, x):
        """Return the classifier logits computed by the wrapped backbone for ``x``."""
        logits = self.model(x)
        return logits

def load_models(model_paths, num_classes, device):
    """
    Build one ``EfficientNetModel`` per checkpoint, load its weights, and
    switch it to eval mode.

    Args:
        model_paths (list): Paths to saved ``state_dict`` files, one per model.
        num_classes (int): Number of output classes the models were built with.
        device (torch.device): Device the models and their weights are placed on.

    Returns:
        list: Models in the same order as ``model_paths``, already in eval mode.

    Raises:
        Any error from ``torch.load`` (e.g. a missing file, or a checkpoint
        containing non-tensor objects) or from ``load_state_dict`` (key/shape
        mismatch with the architecture) propagates to the caller.
    """
    models = []
    for model_path in model_paths:
        model = EfficientNetModel(num_classes=num_classes).to(device)
        # weights_only=True limits unpickling to tensors and plain containers,
        # so a tampered checkpoint cannot run arbitrary code during load. A
        # bare state_dict (what load_state_dict expects below) loads fine.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()  # disable dropout for inference
        models.append(model)
    return models

def preprocess_image(image_path, image_size=384):
    """
    Read an image from disk and turn it into a model-ready batch of one.

    The image is loaded with OpenCV and converted from BGR to RGB. It is then
    resized to ``image_size`` x ``image_size``, normalized with the ImageNet
    mean/std, and converted to a CHW tensor with a leading batch dimension.

    Args:
        image_path (str): Location of the image file.
        image_size (int): Target height and width after resizing.

    Returns:
        torch.Tensor: Tensor of shape (1, 3, image_size, image_size).

    Raises:
        ValueError: If OpenCV cannot read the file.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Image not found at {image_path}")

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    pipeline = Compose([
        Resize(image_size, image_size),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    tensor = pipeline(image=rgb)['image']
    return tensor[None]

def predict_with_models(models, image_tensor, device):
    """
    Run every model on the same input and summarize each model's output.

    Args:
        models (list): Models (any callables) mapping the input to logits of
            shape (batch, num_classes).
        image_tensor (torch.Tensor): Input batch. It must hold a single image,
            because ``.item()`` is called on the per-row argmax/max.
        device (torch.device): Device the input is moved to before inference.

    Returns:
        list: One ``(predicted_class, confidence, probabilities)`` tuple per
        model, in model order. ``probabilities`` is the NumPy array of softmax
        scores for the first (only) sample.
    """
    inputs = image_tensor.to(device)

    def _summarize(net):
        # Gradients are never needed at inference time.
        with torch.no_grad():
            probs = torch.softmax(net(inputs), dim=1)
            top_class = torch.argmax(probs, dim=1).item()
            top_prob = torch.amax(probs, dim=1).item()
            return top_class, top_prob, probs.cpu().numpy()[0]

    return [_summarize(net) for net in models]

def visualize_multi_model_predictions(image_path, predictions, class_labels):
    """
    Plot the input image next to one probability bar chart per model.

    Panels go on a grid three columns wide: the original image first, then one
    chart per entry of ``predictions``. Five models fill a 2x3 grid exactly;
    more models add rows instead of overflowing a fixed grid.

    Args:
        image_path (str): Path to the image that was classified.
        predictions (list): ``(predicted_class, confidence, probabilities)``
            tuples, one per model.
        class_labels (dict): Mapping of class index -> human-readable label.
            Indices missing from the dict are shown as the bare index.

    Raises:
        ValueError: If the image at ``image_path`` cannot be read.
    """
    # Read the image before creating a figure so a bad path leaves no empty
    # figure behind and fails with a clear message instead of a cv2 error.
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Image not found at {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    n_cols = 3
    n_panels = len(predictions) + 1  # original image + one chart per model
    n_rows = -(-n_panels // n_cols)  # ceiling division
    plt.figure(figsize=(16, 5 * n_rows))

    # Original Image
    plt.subplot(n_rows, n_cols, 1)
    plt.imshow(img)
    plt.title('Original Fundus Image')
    plt.axis('off')

    # Individual Model Predictions
    for i, (predicted_class, confidence, probabilities) in enumerate(predictions, start=1):
        plt.subplot(n_rows, n_cols, i + 1)
        # Order tick labels by class index so bar k is always labelled with
        # class k, regardless of the dict's insertion order.
        tick_labels = [class_labels.get(idx, str(idx)) for idx in range(len(probabilities))]
        sns.barplot(x=tick_labels, y=probabilities)
        predicted_label = class_labels.get(predicted_class, predicted_class)
        plt.title(f'Model {i}\nPrediction: {predicted_label}\nConfidence: {confidence:.2%}')
        plt.xlabel('Diabetic Retinopathy Severity')
        plt.ylabel('Probability')
        plt.xticks(rotation=45)

    plt.tight_layout()
    plt.show()

def main():
    """Run every fold model on a few sample fundus images and plot the results.

    An error while processing one image is printed and that image is skipped,
    so the remaining images are still processed.
    """
    # Configuration
    set_seed(42)
    model_dir = '/kaggle/input/blindness_efficentnetb0/pytorch/default/3'
    # One checkpoint per fold (best_model_fold_1 ... best_model_fold_5). Deriving
    # every path from a single directory keeps them consistent with each other.
    model_paths = [f'{model_dir}/best_model_fold_{fold} (1).pth' for fold in range(1, 6)]

    test_images = [
        '/kaggle/input/aptos2019-blindness-detection/train_images/000c1434d8d7.png',
        '/kaggle/input/aptos2019-blindness-detection/train_images/0024cdab0c1e.png',
        '/kaggle/input/aptos2019-blindness-detection/train_images/0104b032c141.png'
    ]

    class_labels = {
        0: "No DR",
        1: "Mild",
        2: "Moderate",
        3: "Severe",
        4: "Proliferative DR"
    }

    # Device and Model Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models = load_models(model_paths, num_classes=len(class_labels), device=device)

    # Inference and Visualization
    for image_path in test_images:
        try:
            image_tensor = preprocess_image(image_path)
            predictions = predict_with_models(models, image_tensor, device)

            print(f"Image: {image_path}")
            for i, (predicted_class, confidence, _) in enumerate(predictions, start=1):
                print(f"Model {i} - Predicted Class: {predicted_class} ({class_labels[predicted_class]})")
                print(f"Model {i} - Confidence: {confidence:.2%}")

            visualize_multi_model_predictions(image_path, predictions, class_labels)

        # Top-level boundary: report the failure and move on to the next image.
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")

if __name__ == "__main__":
    main()