In [None]:
#training code
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.cuda.amp import autocast, GradScaler
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import time
import copy
import random
import warnings
from glob import glob
from sklearn.utils import class_weight
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Suppress every Python warning for the rest of the session. This also hides
# deprecation notices raised by the imported libraries.
warnings.filterwarnings("ignore")

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices) and configures cuDNN for deterministic execution.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick convolution algorithms at
    # runtime, which can differ between runs. It has to be off for the
    # deterministic flag above to give reproducible results.
    torch.backends.cudnn.benchmark = False

# Seed all generators once at module load, before any model or data loader
# is created.
set_seed()

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
# `device` is read as a module-level global by the dataset, training loop and
# ensemble defined below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Custom Dataset with Augmented Images and Oversampling
class ODIRDataset(Dataset):
    """Left/right fundus image pairs with demographics and multi-label targets.

    Samples come from two sources:

    * every spreadsheet row whose left and right image files both exist in
      ``img_dir``;
    * every ``left_<ID>*.png`` file in ``aug_dir``, paired with the original
      right image of the spreadsheet row whose ``ID`` matches. Augmented
      samples positive for A, H or D are repeated ``oversample_factor`` times.

    Each item is ``(left_image, right_image, demographics, labels)`` where
    ``demographics`` is ``[age / 100, is_male]`` and ``labels`` is a float
    vector ordered as ``label_columns``.

    Args:
        excel_file: Path to the spreadsheet with the label columns plus
            ``ID``, ``Left-Fundus``, ``Right-Fundus``, ``Patient Sex`` and
            ``Patient Age``.
        img_dir: Directory holding the original images.
        aug_dir: Directory holding the pre-generated augmented left images.
        transform: Optional callable invoked as ``transform(image=array)``
            and returning a mapping with an ``'image'`` entry.
        oversample_factor: Repeat count for augmented minority-class samples.
    """

    def __init__(self, excel_file, img_dir, aug_dir, transform=None, oversample_factor=2):
        self.data = pd.read_excel(excel_file)
        self.img_dir = img_dir
        self.aug_dir = aug_dir
        self.transform = transform
        self.label_columns = ['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']
        self.oversample_factor = oversample_factor

        # Inverse-frequency class weights: total / (num_classes * positives).
        class_counts = self.data[self.label_columns].sum()
        total_samples = len(self.data)
        # A label with no positive rows would otherwise get an infinite
        # weight, which turns the weighted loss into NaN; treat it as one.
        safe_counts = class_counts.clip(lower=1)
        class_weights = total_samples / (len(self.label_columns) * safe_counts)
        self.class_weights = torch.FloatTensor(class_weights.values).to(device)

        print(f"Class distribution: {class_counts.to_dict()}")
        print(f"Class weights: {class_weights.to_dict()}")

        # Parallel lists, one entry per sample.
        self.image_pairs = []
        self.labels = []
        self.demographics = []

        self._add_original_pairs()
        self._add_augmented_pairs()

        print(f"Total image pairs (original + augmented): {len(self.image_pairs)}")

    def _row_to_sample(self, row):
        """Return ``(labels, demographics)`` tensors for one spreadsheet row."""
        labels = torch.FloatTensor(row[self.label_columns].values.astype(float))
        gender = 1 if row['Patient Sex'] == 'Male' else 0
        age = row['Patient Age'] / 100.0
        demo = torch.tensor([age, gender], dtype=torch.float32)
        return labels, demo

    def _add_original_pairs(self):
        """Register every row whose left and right image files both exist."""
        for _, row in self.data.iterrows():
            left_img_path = os.path.join(self.img_dir, row['Left-Fundus'])
            right_img_path = os.path.join(self.img_dir, row['Right-Fundus'])
            if not (os.path.exists(left_img_path) and os.path.exists(right_img_path)):
                continue
            labels, demo = self._row_to_sample(row)
            self.image_pairs.append((left_img_path, right_img_path))
            self.labels.append(labels)
            self.demographics.append(demo)

    def _add_augmented_pairs(self):
        """Register augmented left images, oversampling A/H/D positives."""
        # Built on first use so a spreadsheet without an 'ID' column only
        # fails when an augmented file actually needs the lookup.
        first_row_for_id = None

        for aug_img in glob(os.path.join(self.aug_dir, "left_*.png")):
            filename = os.path.basename(aug_img)
            try:
                # File names look like 'left_<ID>[...].png'.
                id_num = int(filename.split('_')[1].split('.')[0])

                if first_row_for_id is None:
                    # One pass over the ID column instead of a full DataFrame
                    # scan per augmented file; the first row wins on duplicates.
                    first_row_for_id = {}
                    for position, patient_id in enumerate(self.data['ID']):
                        first_row_for_id.setdefault(patient_id, position)

                position = first_row_for_id.get(id_num)
                if position is None:
                    continue
                row = self.data.iloc[position]
                orig_right = os.path.join(self.img_dir, row['Right-Fundus'])
                labels, demo = self._row_to_sample(row)
            except (ValueError, IndexError):
                # Unparseable file name or non-numeric row values: skip it.
                continue

            # Oversample minority classes: D (index 1), A (4) and H (5).
            is_minority = labels[4] == 1 or labels[5] == 1 or labels[1] == 1
            repeats = self.oversample_factor if is_minority else 1
            for _ in range(repeats):
                self.image_pairs.append((aug_img, orig_right))
                self.labels.append(labels)
                self.demographics.append(demo)

    def __len__(self):
        """Return the number of registered samples (originals + augmented)."""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Load and transform one sample.

        Returns:
            ``(left_img, right_img, demographics, labels)``, or ``None`` when
            loading or transforming fails. The failure is printed, and
            ``None`` entries are expected to be filtered by the collate
            function.
        """
        try:
            left_img_path, right_img_path = self.image_pairs[idx]
            left_img = np.array(Image.open(left_img_path).convert('RGB'))
            right_img = np.array(Image.open(right_img_path).convert('RGB'))

            if self.transform:
                # The transform is applied to each eye independently.
                left_img = self.transform(image=left_img)['image']
                right_img = self.transform(image=right_img)['image']

            return left_img, right_img, self.demographics[idx], self.labels[idx]
        except Exception as e:
            # Deliberate best effort: one unreadable file must not abort an
            # epoch, so report it and let the collate step drop the sample.
            print(f"Error at index {idx}: {e}")
            return None

# Data transforms with more aggressive augmentation for minority classes
def _normalize_and_to_tensor():
    """Return fresh normalisation + tensor-conversion steps for a pipeline.

    The mean/std values are presumably the ImageNet statistics expected by
    the pretrained backbones - confirm before changing them.
    """
    return [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]


# Training pipeline: resize, random 224x224 crop, then geometric and
# photometric augmentation before normalisation.
_train_steps = [
    A.Resize(256, 256),
    A.RandomCrop(224, 224),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=30, p=0.8),  # Increased rotation
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),  # More aggressive
    A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2, p=0.7),
    A.GaussNoise(p=0.3),  # Add noise for robustness
]

# Validation pipeline: a plain resize to the network input size, with no
# random augmentation.
_val_steps = [A.Resize(224, 224)]

data_transforms = {
    'train': A.Compose(_train_steps + _normalize_and_to_tensor()),
    'val': A.Compose(_val_steps + _normalize_and_to_tensor()),
}

# Custom collate function to handle None values
def custom_collate(batch):
    """Collate a batch after discarding samples that are ``None``.

    Args:
        batch: List of samples; entries may be ``None``.

    Returns:
        The default-collated batch of the remaining samples, or ``None`` when
        no sample is left.
    """
    usable = list(filter(lambda sample: sample is not None, batch))
    if len(usable) == 0:
        return None
    return torch.utils.data.dataloader.default_collate(usable)

# Model Definitions with Attention Mechanism
class AttentionModule(nn.Module):
    """Spatial gate: scales a feature map by a learned per-pixel weight.

    A 1x1 convolution reduces the input to a single channel, a sigmoid maps
    it into (0, 1), and the result multiplies every input channel.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` weighted by its single-channel attention map."""
        gate = self.sigmoid(self.conv(x))
        return x * gate

class ResNetModel(nn.Module):
    """Two-eye classifier on a shared, pretrained ResNet-50 trunk.

    Both eyes go through the same backbone, attention gate and global
    pooling; the two 2048-d embeddings are concatenated with a 32-d
    demographic embedding and classified by an MLP that outputs raw logits.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=8):
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        # Remove the last two child modules, keeping the convolutional trunk.
        trunk = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk)
        self.attention = AttentionModule(2048)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.demo_fc = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Dropout(0.6))
        self.classifier = nn.Sequential(
            nn.Linear(2048 * 2 + 32, 512), nn.ReLU(), nn.Dropout(0.7),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, left_img, right_img, demo):
        """Return logits for a batch of (left, right, demographics) inputs."""
        # Run both eyes through the shared backbone first (left, then right).
        left_map = self.backbone(left_img)
        right_map = self.backbone(right_img)
        eye_vectors = [
            self.pool(self.attention(feature_map)).flatten(1)
            for feature_map in (left_map, right_map)
        ]
        fused = torch.cat((*eye_vectors, self.demo_fc(demo)), dim=1)
        return self.classifier(fused)

class EfficientNetModel(nn.Module):
    """Two-eye classifier on a shared, pretrained EfficientNet-B0 trunk.

    Both eyes go through the same backbone, attention gate and global
    pooling; the two 1280-d embeddings are concatenated with a 32-d
    demographic embedding and classified by an MLP that outputs raw logits.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=8):
        super().__init__()
        effnet = models.efficientnet_b0(pretrained=True)
        # Remove the last two child modules, keeping the feature extractor.
        trunk = list(effnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk)
        self.attention = AttentionModule(1280)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.demo_fc = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Dropout(0.6))
        self.classifier = nn.Sequential(
            nn.Linear(1280 * 2 + 32, 512), nn.ReLU(), nn.Dropout(0.7),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, left_img, right_img, demo):
        """Return logits for a batch of (left, right, demographics) inputs."""
        # Run both eyes through the shared backbone first (left, then right).
        left_map = self.backbone(left_img)
        right_map = self.backbone(right_img)
        eye_vectors = [
            self.pool(self.attention(feature_map)).flatten(1)
            for feature_map in (left_map, right_map)
        ]
        fused = torch.cat((*eye_vectors, self.demo_fc(demo)), dim=1)
        return self.classifier(fused)

class DenseNetModel(nn.Module):
    """Two-eye classifier on a shared, pretrained DenseNet-121 trunk.

    Both eyes go through the same backbone, attention gate and global
    pooling; the two 1024-d embeddings are concatenated with a 32-d
    demographic embedding and classified by an MLP that outputs raw logits.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=8):
        super().__init__()
        densenet = models.densenet121(pretrained=True)
        # Remove only the last child module, keeping the feature extractor.
        trunk = list(densenet.children())[:-1]
        self.backbone = nn.Sequential(*trunk)
        self.attention = AttentionModule(1024)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.demo_fc = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Dropout(0.6))
        self.classifier = nn.Sequential(
            nn.Linear(1024 * 2 + 32, 512), nn.ReLU(), nn.Dropout(0.7),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, left_img, right_img, demo):
        """Return logits for a batch of (left, right, demographics) inputs."""
        # Run both eyes through the shared backbone first (left, then right).
        left_map = self.backbone(left_img)
        right_map = self.backbone(right_img)
        eye_vectors = [
            self.pool(self.attention(feature_map)).flatten(1)
            for feature_map in (left_map, right_map)
        ]
        fused = torch.cat((*eye_vectors, self.demo_fc(demo)), dim=1)
        return self.classifier(fused)

# Focal Loss with Adjusted Class Weights
class FocalLoss(nn.Module):
    """Focal loss on raw logits for multi-label (per-class binary) targets.

    Computes ``alpha * (1 - p_t) ** gamma * BCE`` element-wise, where ``p_t``
    is the predicted probability of the true label.

    Args:
        alpha: Optional weighting. A 1-D vector with exactly two entries is
            read as ``[weight_for_negatives, weight_for_positives]`` and
            looked up through the 0/1 targets. Any other shape (for example
            one weight per class) is broadcast against the element-wise loss,
            so a length-``C`` vector weights the ``C`` class columns.
            ``None`` disables weighting.
        gamma: Focusing exponent; larger values down-weight easy examples.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced element-wise loss.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def _alpha_weights(self, targets, loss):
        """Return the alpha factor, broadcastable against ``loss``."""
        alpha = torch.as_tensor(self.alpha, dtype=loss.dtype, device=loss.device)
        if alpha.dim() == 1 and alpha.numel() == 2:
            # Binary form: pick the negative/positive weight per element.
            return alpha[targets.long()]
        # Per-class form. Indexing a per-class vector by the 0/1 targets
        # would only ever read its first two entries, so broadcast instead.
        return alpha

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (logits) against ``targets``."""
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # exp(-BCE) recovers the probability assigned to the true label.
        pt = torch.exp(-bce)
        focal = (1 - pt) ** self.gamma * bce
        if self.alpha is not None:
            focal = self._alpha_weights(targets, focal) * focal
        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal

# Optimize Thresholds with Class-Specific Ranges
def optimize_thresholds(outputs, targets):
    """Pick, per class, the decision threshold with the best F1 score.

    Args:
        outputs: Logits of shape ``(samples, classes)``; a sigmoid is applied
            here before thresholding.
        targets: Binary ground truth with the same shape.

    Returns:
        A list with one threshold per class. The first candidate reaching the
        highest F1 wins; when no candidate scores above zero the threshold
        stays at ``0.5``.
    """
    probabilities = torch.sigmoid(outputs).cpu().numpy()
    truth = targets.cpu().numpy()

    # D, A, H (low precision, high recall): search higher thresholds to
    # improve precision.
    high_range_classes = (1, 4, 5)

    chosen = []
    for column in range(truth.shape[1]):
        if column in high_range_classes:
            candidates = np.arange(0.5, 0.9, 0.05)
        else:
            candidates = np.arange(0.3, 0.7, 0.05)

        scores = [
            f1_score(
                truth[:, column],
                (probabilities[:, column] >= cut).astype(truth.dtype),
                zero_division=0,
            )
            for cut in candidates
        ]
        # argmax returns the first maximum, matching a strict ">" scan.
        winner = int(np.argmax(scores))
        chosen.append(candidates[winner] if scores[winner] > 0 else 0.5)

    return chosen

# Metrics Calculation with Confusion Matrix
def calculate_metrics(outputs, targets, thresholds=None, class_names=None, phase='val'):
    """Compute per-class and macro metrics and, for validation, save plots.

    Args:
        outputs: Logits of shape ``(samples, classes)``; a sigmoid is applied
            here before thresholding.
        targets: Binary ground truth with the same shape.
        thresholds: One decision threshold per class; defaults to ``0.5``.
        class_names: Names used in the confusion-matrix titles and file
            names; defaults to the class indices as strings.
        phase: When ``'val'``, one ``confusion_matrix_<name>.png`` per class
            is written to the current working directory.

    Returns:
        Dict with ``macro_*`` averages, ``per_class_*`` values and the
        ``thresholds`` that were applied.
    """
    outputs_np = torch.sigmoid(outputs).cpu().numpy()
    targets_np = targets.cpu().numpy()
    num_classes = targets_np.shape[1]

    if thresholds is None:
        thresholds = [0.5] * num_classes
    if class_names is None:
        # The defaults (class_names=None, phase='val') used to iterate over
        # None below; fall back to index-based names instead.
        class_names = [str(i) for i in range(num_classes)]

    preds = np.zeros_like(outputs_np, dtype=targets_np.dtype)
    for i, threshold in enumerate(thresholds):
        preds[:, i] = (outputs_np[:, i] >= threshold).astype(targets_np.dtype)

    precision = precision_score(targets_np, preds, average=None, zero_division=0)
    recall = recall_score(targets_np, preds, average=None, zero_division=0)
    f1 = f1_score(targets_np, preds, average=None, zero_division=0)
    accuracy = [accuracy_score(targets_np[:, i], preds[:, i]) for i in range(num_classes)]

    # Confusion matrix for validation phase
    if phase == 'val':
        for i, name in enumerate(class_names):
            # Fix the label set so the matrix is always 2x2, even when a
            # class has only one label value in this evaluation set.
            cm = confusion_matrix(targets_np[:, i], preds[:, i], labels=[0, 1])
            plt.figure(figsize=(6, 4))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
            plt.title(f'Confusion Matrix for Class {name}')
            plt.ylabel('True')
            plt.xlabel('Predicted')
            plt.savefig(f'confusion_matrix_{name}.png')
            plt.close()

    return {
        'macro_precision': np.mean(precision),
        'macro_recall': np.mean(recall),
        'macro_f1': np.mean(f1),
        'macro_accuracy': np.mean(accuracy),
        'per_class_precision': precision,
        'per_class_recall': recall,
        'per_class_f1': f1,
        'per_class_accuracy': accuracy,
        'thresholds': thresholds
    }

# Training Function with Improved Early Stopping
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=15, model_name="model", patience=2):
    """Train ``model`` with mixed precision and F1-based early stopping.

    Validation runs on even-indexed epochs only. Each validation pass tunes
    the per-class thresholds, steps ``scheduler`` with the validation loss,
    and saves the weights to ``<model_name>_best.pth`` when macro F1
    improves. Training stops after ``patience`` consecutive validation
    passes without improvement.

    Args:
        model: Network called as ``model(left, right, demographics)``.
        dataloaders: Dict with ``'train'`` and ``'val'`` loaders; a batch may
            be ``None`` and is then skipped.
        criterion: Loss called as ``criterion(logits, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler stepped with the validation loss.
        num_epochs: Maximum number of epochs.
        model_name: Prefix of the checkpoint file.
        patience: Non-improving validation passes tolerated before stopping.

    Returns:
        ``(model, summary)``: ``model`` holds the best weights and
        ``summary`` has the best ``'macro_f1'`` and the ``'thresholds'``
        tuned on the validation pass that produced it.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_macro_f1 = 0.0
    best_thresholds = None
    scaler = GradScaler()
    class_names = ['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']
    thresholds = None
    patience_counter = 0

    for epoch in range(num_epochs):
        epoch_start = time.time()
        print(f'Epoch {epoch+1}/{num_epochs}\n{"-"*10}')

        for phase in ['train', 'val']:
            # Validate every other epoch only.
            if phase == 'val' and epoch % 2 != 0:
                continue
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            seen_samples = 0
            all_outputs, all_targets = [], []

            for batch in dataloaders[phase]:
                if batch is None:
                    continue
                left_imgs, right_imgs, demos, targets = [x.to(device, non_blocking=True) for x in batch]

                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    with autocast():
                        outputs = model(left_imgs, right_imgs, demos)
                        loss = criterion(outputs, targets)
                    if phase == 'train':
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                batch_size = left_imgs.size(0)
                running_loss += loss.item() * batch_size
                seen_samples += batch_size
                all_outputs.append(outputs.detach())
                all_targets.append(targets)

            if not all_outputs:
                continue
            all_outputs = torch.cat(all_outputs)
            all_targets = torch.cat(all_targets)
            # Average over the samples actually processed: samples dropped by
            # the collate function must not dilute the loss.
            epoch_loss = running_loss / seen_samples

            if phase == 'val':
                thresholds = optimize_thresholds(all_outputs, all_targets)
            # Training metrics reuse the thresholds of the latest validation
            # pass (None, i.e. 0.5 everywhere, before the first one).
            metrics = calculate_metrics(all_outputs, all_targets, thresholds, class_names, phase)

            print(f'{phase} Loss: {epoch_loss:.4f}')
            print(f'{phase} Macro Precision: {metrics["macro_precision"]:.4f}')
            print(f'{phase} Macro Recall: {metrics["macro_recall"]:.4f}')
            print(f'{phase} Macro F1: {metrics["macro_f1"]:.4f}')
            print(f'{phase} Macro Accuracy: {metrics["macro_accuracy"]:.4f}')
            print("Per-class metrics:")
            for i, name in enumerate(class_names):
                print(f"{name}: Precision={metrics['per_class_precision'][i]:.4f}, "
                      f"Recall={metrics['per_class_recall'][i]:.4f}, "
                      f"F1={metrics['per_class_f1'][i]:.4f}, "
                      f"Accuracy={metrics['per_class_accuracy'][i]:.4f}")
            if phase == 'val':
                print("Optimized Thresholds:", {name: thresh for name, thresh in zip(class_names, thresholds)})

            if phase == 'val':
                scheduler.step(epoch_loss)
                if metrics['macro_f1'] > best_macro_f1:
                    best_macro_f1 = metrics['macro_f1']
                    best_model_wts = copy.deepcopy(model.state_dict())
                    # Keep the thresholds that belong to these weights, so
                    # the caller does not get the ones from a later, worse pass.
                    best_thresholds = list(thresholds)
                    torch.save(model.state_dict(), f'{model_name}_best.pth')
                    print(f"✅ Saved best model (F1: {best_macro_f1:.4f})")
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print("Early stopping triggered")
                        break

            # memory_allocated() only accepts CUDA devices, so skip the
            # report on CPU.
            if torch.cuda.is_available() and device.type == 'cuda':
                vram = torch.cuda.memory_allocated(device) / 1024**3
                print(f"VRAM Usage: {vram:.2f} GB")

        if patience_counter >= patience:
            break

        epoch_time = time.time() - epoch_start
        print(f"Epoch time: {epoch_time:.2f} seconds\n")

    model.load_state_dict(best_model_wts)
    return model, {
        'macro_f1': best_macro_f1,
        # Fall back to the latest thresholds when no pass ever improved.
        'thresholds': best_thresholds if best_thresholds is not None else thresholds
    }

# Ensemble Model
class EnsembleModel(nn.Module):
    """Weighted average of the sigmoid probabilities of several models.

    Args:
        models: Networks called as ``model(left, right, demographics)`` and
            returning logits.
        weights: One scalar weight per model, in the same order.
    """

    def __init__(self, models, weights):
        super().__init__()
        self.models = nn.ModuleList(models)
        # Plain tensor attribute placed on the module-level `device`.
        self.weights = torch.tensor(weights, dtype=torch.float32).to(device)

    def forward(self, left_img, right_img, demo):
        """Return the weighted sum of member probabilities (not logits)."""
        member_probs = torch.stack([
            torch.sigmoid(member(left_img, right_img, demo))
            for member in self.models
        ])
        # Reshape the weights to (models, 1, 1) so they broadcast over the
        # batch and class axes.
        scaled = member_probs * self.weights[:, None, None]
        return torch.sum(scaled, dim=0)

# Main Execution
def main():
    """Train the three backbones, then evaluate their F1-weighted ensemble.

    Builds the dataset, splits it 80/20 with a fixed seed, trains ResNet-50,
    EfficientNet-B0 and DenseNet-121 in turn, and reports validation metrics
    of the ensemble whose weights are proportional to each model's best
    macro F1.
    """
    excel_path = r"C:\Users\OMEN\Saved Programs\Disease prediction\fundus_disease_prediction\dataset\data.xlsx"
    img_dir = r"C:\Users\OMEN\Saved Programs\Disease prediction\fundus_disease_prediction\dataset\images"
    aug_dir = r"C:\Users\OMEN\Saved Programs\Disease prediction\fundus_disease_prediction\dataset\images\augmented"
    class_names = ['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']

    dataset = ODIRDataset(excel_path, img_dir, aug_dir, data_transforms['train'], oversample_factor=3)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    # NOTE(review): the split is per sample, so augmented copies of one
    # patient can land in both subsets; consider splitting by patient ID.
    train_data, val_data = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

    # Both subsets wrap the same dataset object, so setting the validation
    # transform on it would also switch off training augmentation. Give the
    # validation subset its own shallow copy (sample lists stay shared).
    val_dataset = copy.copy(dataset)
    val_dataset.transform = data_transforms['val']
    val_data.dataset = val_dataset

    dataloaders = {
        # num_workers=0 loads data in the main process.
        'train': DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0,
                            pin_memory=True, collate_fn=custom_collate),
        'val': DataLoader(val_data, batch_size=64, num_workers=0,
                          pin_memory=True, collate_fn=custom_collate)
    }

    criterion = FocalLoss(alpha=dataset.class_weights * 1.5, gamma=2.0)  # Increased weight for minority classes
    models_to_train = [
        (ResNetModel, "resnet50"),
        (EfficientNetModel, "efficientnet"),
        (DenseNetModel, "densenet")
    ]

    trained_models, metrics_list = [], []
    for model_class, name in models_to_train:
        print(f"\n===== Training {name} =====")
        model = model_class(num_classes=8).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-3)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
        model, metrics = train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=15, model_name=name, patience=2)
        trained_models.append(model)
        metrics_list.append(metrics)

    # Use F1-score for ensemble weights; fall back to a uniform average when
    # every model scored 0, which would otherwise divide by zero.
    total_f1 = sum(m['macro_f1'] for m in metrics_list)
    if total_f1 > 0:
        weights = [m['macro_f1'] / total_f1 for m in metrics_list]
    else:
        weights = [1.0 / len(metrics_list)] * len(metrics_list)
    ensemble = EnsembleModel(trained_models, weights).to(device)

    # Evaluate Ensemble
    ensemble.eval()
    all_outputs, all_targets = [], []
    with torch.no_grad():
        for batch in dataloaders['val']:
            if batch is None:
                continue
            inputs = [x.to(device, non_blocking=True) for x in batch]
            with autocast():
                outputs = ensemble(*inputs[:-1])
            all_outputs.append(outputs)
            all_targets.append(inputs[-1])

    # The ensemble returns probabilities, while the metric helpers apply a
    # sigmoid themselves. Convert back to logits so the probabilities are not
    # squashed twice (eps keeps exact 0/1 values finite).
    all_outputs = torch.logit(torch.cat(all_outputs).float(), eps=1e-6)
    all_targets = torch.cat(all_targets)
    thresholds = optimize_thresholds(all_outputs, all_targets)
    metrics = calculate_metrics(all_outputs, all_targets, thresholds, class_names, phase='val')
    print("\n===== Ensemble Performance =====")
    print(f"Macro Precision: {metrics['macro_precision']:.4f}")
    print(f"Macro Recall: {metrics['macro_recall']:.4f}")
    print(f"Macro F1: {metrics['macro_f1']:.4f}")
    print(f"Macro Accuracy: {metrics['macro_accuracy']:.4f}")
    print("Per-class metrics:")
    for i, name in enumerate(class_names):
        print(f"{name}: Precision={metrics['per_class_precision'][i]:.4f}, "
              f"Recall={metrics['per_class_recall'][i]:.4f}, "
              f"F1={metrics['per_class_f1'][i]:.4f}, "
              f"Accuracy={metrics['per_class_accuracy'][i]:.4f}")
    print("Optimized Thresholds:", {name: thresh for name, thresh in zip(class_names, thresholds)})

# Run the training pipeline only when this code is executed directly, not
# when it is imported.
if __name__ == "__main__":
    main()

In [5]:
#testing code

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.cuda.amp import autocast, GradScaler
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import time
import copy
import random
import warnings
from glob import glob
from sklearn.utils import class_weight
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.cuda.amp import autocast

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
# Use the GPU when CUDA is available, otherwise the CPU. `device` is read as
# a module-level global by EnsembleModel below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Define model architectures (copied from your training code)
class AttentionModule(nn.Module):
    """Spatial gate: scales a feature map by a learned per-pixel weight.

    A 1x1 convolution reduces the input to a single channel, a sigmoid maps
    it into (0, 1), and the result multiplies every input channel.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` weighted by its single-channel attention map."""
        gate = self.sigmoid(self.conv(x))
        return x * gate

class ResNetModel(nn.Module):
    """Inference copy of the two-eye ResNet-50 classifier.

    Built without pretrained weights; the trained parameters are expected to
    be loaded from a checkpoint afterwards.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=8):
        super().__init__()
        from torchvision import models
        resnet = models.resnet50(pretrained=False)
        # Remove the last two child modules, keeping the convolutional trunk.
        trunk = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk)
        self.attention = AttentionModule(2048)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.demo_fc = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Dropout(0.6))
        self.classifier = nn.Sequential(
            nn.Linear(2048 * 2 + 32, 512), nn.ReLU(), nn.Dropout(0.7),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, left_img, right_img, demo):
        """Return logits for a batch of (left, right, demographics) inputs."""
        # Run both eyes through the shared backbone first (left, then right).
        left_map = self.backbone(left_img)
        right_map = self.backbone(right_img)
        eye_vectors = [
            self.pool(self.attention(feature_map)).flatten(1)
            for feature_map in (left_map, right_map)
        ]
        fused = torch.cat((*eye_vectors, self.demo_fc(demo)), dim=1)
        return self.classifier(fused)

class EfficientNetModel(nn.Module):
    """Inference copy of the two-eye EfficientNet-B0 classifier.

    Built without pretrained weights; the trained parameters are expected to
    be loaded from a checkpoint afterwards.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=8):
        super().__init__()
        from torchvision import models
        effnet = models.efficientnet_b0(pretrained=False)
        # Remove the last two child modules, keeping the feature extractor.
        trunk = list(effnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk)
        self.attention = AttentionModule(1280)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.demo_fc = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Dropout(0.6))
        self.classifier = nn.Sequential(
            nn.Linear(1280 * 2 + 32, 512), nn.ReLU(), nn.Dropout(0.7),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, left_img, right_img, demo):
        """Return logits for a batch of (left, right, demographics) inputs."""
        # Run both eyes through the shared backbone first (left, then right).
        left_map = self.backbone(left_img)
        right_map = self.backbone(right_img)
        eye_vectors = [
            self.pool(self.attention(feature_map)).flatten(1)
            for feature_map in (left_map, right_map)
        ]
        fused = torch.cat((*eye_vectors, self.demo_fc(demo)), dim=1)
        return self.classifier(fused)

class DenseNetModel(nn.Module):
    """Inference copy of the two-eye DenseNet-121 classifier.

    Built without pretrained weights; the trained parameters are expected to
    be loaded from a checkpoint afterwards.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=8):
        super().__init__()
        from torchvision import models
        densenet = models.densenet121(pretrained=False)
        # Remove only the last child module, keeping the feature extractor.
        trunk = list(densenet.children())[:-1]
        self.backbone = nn.Sequential(*trunk)
        self.attention = AttentionModule(1024)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.demo_fc = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Dropout(0.6))
        self.classifier = nn.Sequential(
            nn.Linear(1024 * 2 + 32, 512), nn.ReLU(), nn.Dropout(0.7),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, left_img, right_img, demo):
        """Return logits for a batch of (left, right, demographics) inputs."""
        # Run both eyes through the shared backbone first (left, then right).
        left_map = self.backbone(left_img)
        right_map = self.backbone(right_img)
        eye_vectors = [
            self.pool(self.attention(feature_map)).flatten(1)
            for feature_map in (left_map, right_map)
        ]
        fused = torch.cat((*eye_vectors, self.demo_fc(demo)), dim=1)
        return self.classifier(fused)

class EnsembleModel(nn.Module):
    """Weighted average of the sigmoid probabilities of several models.

    Args:
        models: Networks called as ``model(left, right, demographics)`` and
            returning logits.
        weights: One scalar weight per model, in the same order.
    """

    def __init__(self, models, weights):
        super().__init__()
        self.models = nn.ModuleList(models)
        # Plain tensor attribute placed on the module-level `device`.
        self.weights = torch.tensor(weights, dtype=torch.float32).to(device)

    def forward(self, left_img, right_img, demo):
        """Return the weighted sum of member probabilities (not logits)."""
        member_probs = torch.stack([
            torch.sigmoid(member(left_img, right_img, demo))
            for member in self.models
        ])
        # Reshape the weights to (models, 1, 1) so they broadcast over the
        # batch and class axes.
        scaled = member_probs * self.weights[:, None, None]
        return torch.sum(scaled, dim=0)

def load_ensemble_model(model_paths, weights):
    """Build the three-model ensemble from saved checkpoints.

    Args:
        model_paths: Checkpoint paths ordered as ResNet, EfficientNet,
            DenseNet.
        weights: Ensemble weight per model, in the same order.

    Returns:
        An ``EnsembleModel`` in eval mode on the selected device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Instantiate every architecture first, in checkpoint order.
    architectures = (ResNetModel, EfficientNetModel, DenseNetModel)
    members = [arch(num_classes=8).to(device) for arch in architectures]

    # Then restore the trained parameters of each member.
    for position, member in enumerate(members):
        state = torch.load(model_paths[position], map_location=device)
        member.load_state_dict(state)

    ensemble = EnsembleModel(members, weights).to(device)
    ensemble.eval()
    return ensemble

def predict_disease(left_img_path, right_img_path, model_dir, age=50, gender="Male", thresholds=None):
    """Predict eye diseases for one patient from a pair of fundus images.

    Args:
        left_img_path: Path to the left-eye image.
        right_img_path: Path to the right-eye image.
        model_dir: Directory containing ``resnet50_best.pth``,
            ``efficientnet_best.pth`` and ``densenet_best.pth``.
        age: Patient age in years (divided by 100 before being fed in).
        gender: ``"Male"`` (case-insensitive) maps to 1, anything else to 0.
        thresholds: Optional ``{class_code: threshold}`` mapping. Codes that
            are missing fall back to the built-in defaults.

    Returns:
        A dict with ``'predictions'`` (class codes), ``'probabilities'``
        (code -> probability), ``'is_normal'`` and ``'diagnosis'`` (list of
        sentences), or ``None`` when the models or images cannot be loaded.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Default thresholds (use your optimized thresholds from training).
    default_thresholds = {
        'N': 0.5, 'D': 0.6, 'G': 0.5,
        'C': 0.5, 'A': 0.65, 'H': 0.7,
        'M': 0.5, 'O': 0.5
    }
    # Merge caller values over the defaults so that a partial (or empty)
    # mapping no longer raises KeyError for the missing classes.
    thresholds = {**default_thresholds, **(thresholds or {})}

    # Class names
    class_names = ['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']
    class_full_names = {
        'N': 'Normal',
        'D': 'Diabetic Retinopathy',
        'G': 'Glaucoma',
        'C': 'Cataract',
        'A': 'Age-related Macular Degeneration',
        'H': 'Hypertensive Retinopathy',
        'M': 'Myopia',
        'O': 'Other'
    }

    # Load models
    model_paths = [
        os.path.join(model_dir, "resnet50_best.pth"),
        os.path.join(model_dir, "efficientnet_best.pth"),
        os.path.join(model_dir, "densenet_best.pth")
    ]

    # Ensemble weights (use your F1-score weights from training)
    # These are sample weights - replace with your actual weights
    weights = [0.33, 0.33, 0.34]  # Example weights

    # Load ensemble model
    try:
        ensemble = load_ensemble_model(model_paths, weights)
        print("Models loaded successfully.")
    except Exception as e:
        # Best effort: report the problem and let the caller handle None.
        print(f"Error loading models: {e}")
        return None

    # Image preprocessing: resize, normalise, convert to tensor. No random
    # augmentation is applied at inference time.
    transform = A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    # Demographic data, encoded as [age / 100, is_male].
    gender_value = 1 if gender.lower() == "male" else 0
    demo_tensor = torch.tensor([[age / 100.0, gender_value]], dtype=torch.float32).to(device)

    # Load and preprocess images
    try:
        left_img = np.array(Image.open(left_img_path).convert('RGB'))
        right_img = np.array(Image.open(right_img_path).convert('RGB'))

        # unsqueeze(0) adds the batch dimension expected by the models.
        left_img = transform(image=left_img)['image'].unsqueeze(0).to(device)
        right_img = transform(image=right_img)['image'].unsqueeze(0).to(device)

        print("Images loaded and preprocessed successfully.")
    except Exception as e:
        print(f"Error processing images: {e}")
        return None

    # Run prediction; the ensemble already returns probabilities.
    with torch.no_grad():
        with autocast():
            outputs = ensemble(left_img, right_img, demo_tensor)

    # Apply thresholds and get predictions
    predictions = []
    probabilities = {}

    for i, class_name in enumerate(class_names):
        prob = outputs[0, i].item()
        probabilities[class_name] = prob
        if prob >= thresholds[class_name]:
            predictions.append(class_name)

    # If no disease is above threshold but there are probabilities, take highest one
    if not predictions and len(probabilities) > 0:
        max_class = max(probabilities, key=probabilities.get)
        predictions.append(max_class)

    # Results. 'predictions' shares the list object used below, so the
    # removal of 'N' is visible in the returned dict as well.
    result = {
        'predictions': predictions,
        'probabilities': probabilities,
        'is_normal': 'N' in predictions and len(predictions) == 1,
        'diagnosis': []
    }

    # Generate diagnosis text
    if result['is_normal']:
        result['diagnosis'].append("Normal eye condition detected.")
    else:
        if 'N' in predictions:
            predictions.remove('N')  # Remove normal if other diseases are present

        for disease in predictions:
            result['diagnosis'].append(f"{class_full_names[disease]} detected with {probabilities[disease]*100:.1f}% confidence.")

    return result

def main():
    """Run a single hard-coded prediction and print a readable report."""
    # Configuration
    model_dir = r"C:\Users\OMEN\Saved Programs\Disease prediction"  # Path to directory with saved model files
    left_img_path = r"C:\Users\OMEN\Saved Programs\Disease prediction\fundus_disease_prediction\dataset\images\171_left.jpg" # Replace with actual path
    right_img_path = r"C:\Users\OMEN\Saved Programs\Disease prediction\fundus_disease_prediction\dataset\images\171_right.jpg" # Replace with actual path

    # Patient demographics (can be hardcoded or entered by user)
    age = 60  # Patient age
    gender = "Male"  # Patient gender: "Male" or "Female"

    # Custom thresholds from your training (replace with your optimized values)
    thresholds = dict(
        N=0.5,   # Normal
        D=0.65,  # Diabetic Retinopathy
        G=0.45,  # Glaucoma
        C=0.5,   # Cataract
        A=0.7,   # Age-related Macular Degeneration
        H=0.7,   # Hypertensive Retinopathy
        M=0.5,   # Myopia
        O=0.5,   # Other
    )

    print("Running fundus image disease prediction...")
    print(f"Analyzing left image: {os.path.basename(left_img_path)}")
    print(f"Analyzing right image: {os.path.basename(right_img_path)}")
    print(f"Patient: Age {age}, Gender: {gender}")

    # Run prediction
    result = predict_disease(left_img_path, right_img_path, model_dir, age, gender, thresholds)

    # A falsy result means the models or images could not be loaded.
    if not result:
        print("❌ Error during prediction. Please check image paths and model files.")
        return

    print("\n----- RESULTS -----")
    if result['is_normal']:
        print("✅ No diseases detected. Eyes appear normal.")
    else:
        print("🔍 Analysis complete. Findings:")
        for finding in result['diagnosis']:
            print(f"  - {finding}")

    # List every class, most probable first.
    print("\nDetailed probabilities:")
    ranked = sorted(result['probabilities'].items(), key=lambda item: item[1], reverse=True)
    for code, probability in ranked:
        print(f"  {code}: {probability*100:.1f}%")

# Run the demo prediction only when this code is executed directly, not when
# it is imported.
if __name__ == "__main__":
    main()

Using device: cuda
Running fundus image disease prediction...
Analyzing left image: 171_left.jpg
Analyzing right image: 171_right.jpg
Patient: Age 60, Gender: Male
Using device: cuda
Models loaded successfully.
Images loaded and preprocessed successfully.

----- RESULTS -----
🔍 Analysis complete. Findings:
  - Diabetic Retinopathy detected with 89.0% confidence.
  - Other detected with 69.1% confidence.

Detailed probabilities:
  D: 89.0%
  O: 69.1%
  C: 10.8%
  G: 9.2%
  H: 9.2%
  A: 8.6%
  M: 8.5%
  N: 5.3%
