In [2]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
import os
from PIL import Image


# Custom Dataset for SSL (SimCLR)
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Get all image files in the root directory
        self.image_paths = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith(('.jpg', '.png'))]
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image


# Define the ContrastiveLoss
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2, 2)
        loss = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                          (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss


# Define the SimCLR model (assuming the base model is a ResNet)
class SimCLR(nn.Module):
    def __init__(self, base_model, projection_dim=128):
        super(SimCLR, self).__init__()
        self.base_model = base_model
        # Get the number of input features for the last fully connected layer before identity replacement
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Identity()  # Remove final classification layer
        
        # Improved projection head with batch normalization
        self.projection_head = nn.Sequential(
            nn.Linear(num_ftrs, 512),  # Use the number of features in the last layer
            nn.ReLU(),
            nn.BatchNorm1d(512),  # Batch normalization
            nn.Linear(512, projection_dim)
        )
    
    def forward(self, x):
        features = self.base_model(x)
        projections = self.projection_head(features)
        return projections


# Define the data augmentations for SimCLR
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.5, contrast=0.5),  # Increased strength of augmentations
    transforms.RandomRotation(45),  # Increased rotation angle
    transforms.RandomAffine(degrees=45, translate=(0.1, 0.1), scale=(0.8, 1.2)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load your unlabeled dataset (replace with your actual path)
train_dataset = CustomImageDataset(root_dir='/Users/ramanathanswaminathan/Downloads/glaucoma_exhaustive/acrima+drishti_ssl/Training', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)

# Initialize the SimCLR model and contrastive loss
base_model = models.resnet18(weights='IMAGENET1K_V1')  # Update to use correct pretrained weights
simclr_model = SimCLR(base_model)
contrastive_loss = ContrastiveLoss()

# Set up optimizer
optimizer = optim.Adam(simclr_model.parameters(), lr=1e-4)  # Adjusted learning rate

# Training loop for SimCLR pretraining
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
simclr_model = simclr_model.to(device)
simclr_model.train()

for epoch in range(50):  # Pretrain for 50 epochs
    running_loss = 0.0
    for images in train_loader:  # No labels are used here
        images = images.to(device)
        
        # SimCLR requires pairs of augmented images (positive pairs)
        augmented_images1 = images
        augmented_images2 = images  # For simplicity, we'll treat images as paired in this basic setup
        
        # Forward pass for both augmented views
        projections1 = simclr_model(augmented_images1)
        projections2 = simclr_model(augmented_images2)
        
        # Labels: 1 if they belong to the same class, 0 otherwise (for simplicity, use dummy labels for now)
        labels = torch.ones(images.size(0)).to(device)  # All pairs are assumed to be positive pairs
        
        # Compute contrastive loss
        loss = contrastive_loss(projections1, projections2, labels)
        
        # Backpropagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/50] - Loss: {running_loss/len(train_loader)}")


Epoch [1/50] - Loss: 0.9999773472547531
Epoch [2/50] - Loss: 0.9999773472547531
Epoch [3/50] - Loss: 0.9999773472547531
Epoch [4/50] - Loss: 0.9999773472547531
Epoch [5/50] - Loss: 0.9999773472547531
Epoch [6/50] - Loss: 0.9999773472547531
Epoch [7/50] - Loss: 0.9999773472547531


KeyboardInterrupt: 

In [None]:
# Fine-tuning the model on the labeled dataset (Glaucoma detection)
# Modify the model to include a final classification layer for 2 classes
num_ftrs = simclr_model.base_model.fc.in_features
simclr_model.base_model.fc = nn.Linear(num_ftrs, 2)  # 2 classes: glaucoma and normal

# Set the optimizer (fine-tuning)
optimizer = optim.Adam(simclr_model.parameters(), lr=1e-4)

# Load the labeled dataset (replace with the path to your actual labeled data)
train_dataset_finetune = datasets.ImageFolder(root='/Users/ramanathanswaminathan/Downloads/glaucoma_exhaustive/acrima+drishti/Training', transform=transform)
train_loader_finetune = DataLoader(train_dataset_finetune, batch_size=64, shuffle=True, num_workers=4)

# Fine-tuning loop
simclr_model.train()
for epoch in range(10):  # Fine-tune for 10 epochs
    running_loss = 0.0
    for images, labels in train_loader_finetune:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = simclr_model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Fine-tuning Epoch [{epoch+1}/10] - Loss: {running_loss/len(train_loader_finetune)}")

# Evaluation of the fine-tuned model
# Assuming you have a separate test set for evaluation
test_dataset = datasets.ImageFolder(root='/Users/ramanathanswaminathan/Downloads/glaucoma_exhaustive/acrima+drishti/Testing', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# Evaluate the model
simclr_model.eval()  # Set the model to evaluation mode
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = simclr_model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy of the model on the test images: {accuracy}%')

# Optional: Save the model for future use
torch.save(simclr_model.state_dict(), "simclr_glaucoma_model.pth")


In [None]:
#ssl plus ensemble 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from transformers import ViTModel
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc

# Device setup
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# Data augmentation pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.RandomAffine(degrees=30, translate=(0.2, 0.2), scale=(0.8, 1.2)),
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.2))
])

# Dataset paths
train_dir = "/Users/ramanathanswaminathan/Downloads/glaucoma_exhaustive/acrima+drishti/Training"
test_dir = "/Users/ramanathanswaminathan/Downloads/glaucoma_exhaustive/acrima+drishti/Testing"

# Load datasets
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
test_dataset = datasets.ImageFolder(test_dir, transform=transform)

# DataLoader
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# SimCLR Model (Pretrained)
class SimCLR(nn.Module):
    def __init__(self, base_model, projection_dim=128):
        super(SimCLR, self).__init__()
        self.base_model = base_model
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Identity()
        self.projection_head = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )
    
    def forward(self, x):
        features = self.base_model(x)
        projections = self.projection_head(features)
        return projections

# Load pre-trained SimCLR model (e.g., ResNet-18)
base_model = models.resnet18(weights='IMAGENET1K_V1')
simclr_model = SimCLR(base_model).to(device)

# Hybrid Model 1 (EfficientNet + ViT)
class HybridModel1(nn.Module):
    def __init__(self, feature_dim=768, num_classes=2):
        super().__init__()
        self.cnn = models.efficientnet_b0(pretrained=True)
        self.cnn = nn.Sequential(*list(self.cnn.children())[:-1])
        
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224", add_pooling_layer=False)
        self.projection = nn.Linear(1280 + feature_dim, feature_dim)
        
        self.cross_attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=8, batch_first=True)
        
        self.attention_weights = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        
        self.fc = nn.Linear(feature_dim, num_classes)
        self.batch_norm = nn.BatchNorm1d(feature_dim)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        cnn_features = self.cnn(x).flatten(1)
        vit_features = self.vit(x).last_hidden_state[:, 0, :]
        features = torch.cat([cnn_features, vit_features], dim=-1)
        fused_features = self.projection(features)
        
        attn_output, _ = self.cross_attention(fused_features.unsqueeze(1), fused_features.unsqueeze(1), fused_features.unsqueeze(1))
        attn_output = attn_output.squeeze(1)
        
        attention_scores = torch.sigmoid(self.attention_weights(attn_output))
        weighted_features = attn_output * attention_scores
        
        weighted_features = self.batch_norm(weighted_features)
        weighted_features = self.dropout(weighted_features)
        return self.fc(weighted_features)

# Hybrid Model 2 (ResNet + ViT)
class HybridModel2(nn.Module):
    def __init__(self, feature_dim=768, num_classes=2):
        super().__init__()
        self.cnn = models.resnet150(pretrained=True)
        self.cnn = nn.Sequential(*list(self.cnn.children())[:-1])
        
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224", add_pooling_layer=False)
        self.projection = nn.Linear(512 + feature_dim, feature_dim)
        
        self.cross_attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=8, batch_first=True)
        
        self.attention_weights = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        
        self.fc = nn.Linear(feature_dim, num_classes)
        self.batch_norm = nn.BatchNorm1d(feature_dim)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        cnn_features = self.cnn(x).flatten(1)
        vit_features = self.vit(x).last_hidden_state[:, 0, :]
        features = torch.cat([cnn_features, vit_features], dim=-1)
        fused_features = self.projection(features)
        
        attn_output, _ = self.cross_attention(fused_features.unsqueeze(1), fused_features.unsqueeze(1), fused_features.unsqueeze(1))
        attn_output = attn_output.squeeze(1)
        
        attention_scores = torch.sigmoid(self.attention_weights(attn_output))
        weighted_features = attn_output * attention_scores
        
        weighted_features = self.batch_norm(weighted_features)
        weighted_features = self.dropout(weighted_features)
        return self.fc(weighted_features)

# Ensemble Model: Combine Predictions of Two Models and SimCLR
class EnsembleModel(nn.Module):
    def __init__(self, model1, model2, ssl_model):
        super(EnsembleModel, self).__init__()
        self.model1 = model1
        self.model2 = model2
        self.ssl_model = ssl_model  # SimCLR Model
    
    def forward(self, x):
        # Get predictions from both models
        pred1 = self.model1(x)
        pred2 = self.model2(x)
        
        # Get features from the SimCLR model (projection output)
        ssl_features = self.ssl_model(x)
        
        # Combine predictions: Averaging the outputs and concatenating SSL features
        combined_pred = (pred1 + pred2) / 2
        combined_pred = torch.cat([combined_pred, ssl_features], dim=-1)  # Concatenate SSL features for final decision
        return combined_pred

# Training function with Cyclical Learning Rate, Gradual Layer Freezing, and Label Smoothing
def train_model(model, train_loader, test_loader, optimizer, scheduler, criterion, device, epochs=150):
    train_losses, test_losses, train_accuracies, test_accuracies = [], [], [], []
    best_test_accuracy = 0.0
    best_model_state_dict = None
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct_train, total_train = 0, 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            running_loss += loss.item()
            correct_train += (outputs.argmax(1) == labels).sum().item()
            total_train += labels.size(0)
        
        train_accuracy = correct_train / total_train * 100
        avg_loss = running_loss / len(train_loader)
        train_losses.append(avg_loss)
        train_accuracies.append(train_accuracy)
        
        # Testing Phase
        model.eval()
        correct_test, total_test = 0, 0
        running_test_loss = 0.0
        all_preds, all_labels = [], []
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_test_loss += loss.item()
                correct_test += (outputs.argmax(1) == labels).sum().item()
                total_test += labels.size(0)
                all_preds.extend(outputs.argmax(1).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        test_accuracy = correct_test / total_test * 100
        avg_test_loss = running_test_loss / len(test_loader)
        test_losses.append(avg_test_loss)
        test_accuracies.append(test_accuracy)
        
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f} - Train Acc: {train_accuracy:.2f}% - Test Loss: {avg_test_loss:.4f} - Test Acc: {test_accuracy:.2f}%")
        scheduler.step()

        # Save the model with the best test accuracy
        if test_accuracy > best_test_accuracy:
            best_test_accuracy = test_accuracy
            best_model_state_dict = model.state_dict()
            print(f"New best model with Test Acc: {best_test_accuracy:.2f}%")
    
    # Save the best model after training completes
    if best_model_state_dict is not None:
        torch.save(best_model_state_dict, "best_ensemble_model_with_ssl.pth")
        print("Best model saved with Test Accuracy: {:.2f}%".format(best_test_accuracy))
    
    # Plotting training and testing curves
    plt.figure()
    plt.plot(range(epochs), train_losses, label='Train Loss')
    plt.plot(range(epochs), test_losses, label='Test Loss')
    plt.legend()
    plt.title("Loss Curve")
    plt.show()
    
    plt.figure()
    plt.plot(range(epochs), train_accuracies, label='Train Accuracy')
    plt.plot(range(epochs), test_accuracies, label='Test Accuracy')
    plt.legend()
    plt.title("Accuracy Curve")
    plt.show()
    
    # Confusion Matrix and other metrics
    cm = confusion_matrix(all_labels, all_preds)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.show()
    
    print(classification_report(all_labels, all_preds))
    
    # ROC and AUC
    fpr, tpr, thresholds = roc_curve(all_labels, all_preds)
    roc_auc = auc(fpr, tpr)
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.show()

# Initialize models
model1 = HybridModel1().to(device)
model2 = HybridModel2().to(device)

# Optimizer, loss, and scheduler
optimizer = optim.AdamW(list(model1.parameters()) + list(model2.parameters()) + list(simclr_model.parameters()), lr=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=10)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Create and train ensemble model
ensemble_model = EnsembleModel(model1, model2, simclr_model).to(device)
train_model(ensemble_model, train_loader, test_loader, optimizer, scheduler, criterion, device)
