## Data Loader

In [None]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import random

class SimCLRDataset(Dataset):
    """Wrap an ``ImageFolder`` so every sample yields two views of one image.

    Args:
        root_dir: Directory handed to ``datasets.ImageFolder``.
        transform: Optional callable applied independently to the image once
            per view. When ``None`` both views are the untransformed image.
    """

    def __init__(self, root_dir, transform=None):
        self.dataset = datasets.ImageFolder(root_dir)
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``((view1, view2), label)`` for the sample at ``idx``."""
        img, label = self.dataset[idx]
        if self.transform is None:
            # No transform configured: hand back the raw image for both views
            # rather than failing on unbound locals.
            return (img, img), label
        # Two independent calls, so a stochastic transform produces two
        # different views of the same image.
        img1 = self.transform(img)
        img2 = self.transform(img)
        return (img1, img2), label

def create_train_val_test_dataloaders(data_dir, batch_size, num_workers=0):
    """Build the two-view dataloaders for the ``train`` and ``test`` splits.

    Despite the name, only two loaders are created: no separate validation
    split is produced here.

    Args:
        data_dir: Directory containing ``train`` and ``test`` sub-directories.
        batch_size: Batch size used for both loaders.
        num_workers: Worker count forwarded to both ``DataLoader`` instances.

    Returns:
        Dict with keys ``'train'`` (shuffled, randomly augmented) and
        ``'test'`` (unshuffled, resize only), each mapping to a ``DataLoader``
        over a ``SimCLRDataset``.
    """
    # Tail shared by both pipelines: tensor conversion + channel normalisation
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    # Random augmentations for training, plain resize for evaluation
    pipelines = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
            *to_normalized_tensor,
        ]),
        'test': transforms.Compose([
            transforms.Resize((224, 224)),
            *to_normalized_tensor,
        ]),
    }

    # All datasets are created before any loader, one per split directory
    split_datasets = {
        split: SimCLRDataset(os.path.join(data_dir, split), transform=pipeline)
        for split, pipeline in pipelines.items()
    }

    # Only the training split is shuffled
    return {
        split: DataLoader(
            split_dataset,
            batch_size=batch_size,
            shuffle=(split == 'train'),
            num_workers=num_workers,
        )
        for split, split_dataset in split_datasets.items()
    }

# Example usage: build both loaders for the dataset rooted at 'Dataset/'
data_dir, batch_size = 'Dataset/', 32
dataloaders = create_train_val_test_dataloaders(data_dir=data_dir, batch_size=batch_size)

## Switch to GPU

In [None]:
import torch
# Use CUDA when PyTorch reports an available GPU, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Bare expression so the notebook cell displays the selected device
device

## Define SimCLR

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class SimCLRClassifier(nn.Module):
    """Feature backbone followed by a single-logit linear head.

    Args:
        base_model: Module used as the feature extractor. Its output, once
            flattened per sample, must contain ``feature_dim`` values.
        feature_dim: Input width of the linear head (defaults to 512).
    """

    def __init__(self, base_model, feature_dim=512):
        super().__init__()
        self.base_model = base_model
        self.fc = nn.Linear(feature_dim, 1)  # One logit for binary classification

    def forward(self, x):
        """Return one raw logit per sample (no sigmoid applied)."""
        features = self.base_model(x)
        # reshape, unlike view, also accepts non-contiguous backbone outputs
        features = features.reshape(features.size(0), -1)
        return self.fc(features)

# Backbone: ResNet-18 initialised from the library's default pre-trained weights
base_model = models.resnet18(weights='ResNet18_Weights.DEFAULT')
# Replace the final fully connected layer with a pass-through so the backbone
# outputs features instead of class scores
base_model.fc = nn.Identity()

# Wrap the backbone with the binary head, then move everything to the device.
# NOTE(review): the head's default feature_dim=512 must match the backbone's
# output width - confirm if a different backbone is substituted.
model = SimCLRClassifier(base_model)
model = model.to(device)

## Define Loss and Optimizer

In [None]:
import torch.optim as optim

# Adam over every parameter of the model (backbone and head) at a 1e-4 step size
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Binary cross-entropy that takes raw logits, matching the model's un-activated output
criterion = nn.BCEWithLogitsLoss()

## Training and Validation

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one training pass over ``dataloader`` and report loss and accuracy.

    Each batch must unpack as ``((img1, img2), labels)``. Both views go through
    the model and their losses are summed before the backward pass. Accuracy
    is computed from the first view only.

    Relies on the module-level names ``device`` and ``tqdm``.

    Args:
        model: Module producing one logit per sample.
        dataloader: Sized iterable of ``((img1, img2), labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model's parameters.

    Returns:
        ``(epoch_loss, accuracy)``: the summed two-view loss averaged over
        batches, and the accuracy in percent. Both are ``0.0`` when the
        dataloader yields no batches.
    """
    model.train()
    # Moving the model is loop-invariant, so do it once instead of per batch
    model.to(device)

    running_loss = 0.0
    correct = 0
    total = 0

    for (img1, img2), labels in tqdm(dataloader, desc="Training"):
        img1, img2 = img1.to(device), img2.to(device)
        # Float column vector so the labels line up with the (batch, 1) logits
        labels = labels.float().to(device).view(-1, 1)

        optimizer.zero_grad()

        # Forward pass for both views
        outputs1 = model(img1)
        outputs2 = model(img2)

        # Sum the losses of both views, then update the weights
        loss = criterion(outputs1, labels) + criterion(outputs2, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Predictions come from the first view only
        preds = (torch.sigmoid(outputs1) > 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    # Guard the averages so an empty dataloader does not divide by zero
    num_batches = len(dataloader)
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = correct / total * 100 if total else 0.0
    return epoch_loss, accuracy

def validate_one_epoch(model, dataloader, criterion):
    """Evaluate ``model`` over ``dataloader`` without updating its weights.

    Each batch must unpack as ``((img1, img2), labels)``. The reported loss is
    the sum of the losses on both views. Accuracy is computed from the first
    view only.

    Relies on the module-level names ``device`` and ``tqdm``.

    Args:
        model: Module producing one logit per sample.
        dataloader: Sized iterable of ``((img1, img2), labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        ``(epoch_loss, accuracy)``: the summed two-view loss averaged over
        batches, and the accuracy in percent. Both are ``0.0`` when the
        dataloader yields no batches.
    """
    model.eval()
    # Moving the model is loop-invariant, so do it once instead of per batch
    model.to(device)

    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for (img1, img2), labels in tqdm(dataloader, desc="Testing"):
            img1, img2 = img1.to(device), img2.to(device)
            # Float column vector so the labels line up with the (batch, 1) logits
            labels = labels.float().to(device).view(-1, 1)

            # Forward pass for both views
            outputs1 = model(img1)
            outputs2 = model(img2)

            # Summed loss of both views, mirroring the training objective
            loss = criterion(outputs1, labels) + criterion(outputs2, labels)
            running_loss += loss.item()

            # Predictions come from the first view only
            preds = (torch.sigmoid(outputs1) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # Guard the averages so an empty dataloader does not divide by zero
    num_batches = len(dataloader)
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = correct / total * 100 if total else 0.0
    return epoch_loss, accuracy

## Run the Training and Validation Loop

In [None]:
from tqdm import tqdm

num_epochs = 10

# One checkpoint per epoch is written below this directory (created if missing)
save_dir = 'Checkpoints_SimCLR_10ep/'
os.makedirs(save_dir, exist_ok=True)

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    # Optimisation pass over the training split
    train_loss, train_accuracy = train_one_epoch(
        model=model,
        dataloader=dataloaders['train'],
        criterion=criterion,
        optimizer=optimizer,
    )
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")

    # Evaluation pass; the 'test' split doubles as the validation set here
    val_loss, val_accuracy = validate_one_epoch(
        model=model,
        dataloader=dataloaders['test'],
        criterion=criterion,
    )
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

    # Persist the weights after every epoch, numbered from 1
    model_save_path = os.path.join(save_dir, f'model_epoch_{epoch + 1}.pth')
    torch.save(model.state_dict(), model_save_path)