<a href="https://colab.research.google.com/github/tousifo/ml_notebooks/blob/main/MIL_Contrastive_learning_with_CNN_and_ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Preprocess

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
import zipfile
import glob
import os
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import f1_score
from tqdm import tqdm
import numpy as np

# Path & Extraction for dataset
ZIP_PATH = "/content/drive/MyDrive/oaisis.zip"  # Path to your dataset ZIP file
EXTRACT_DIR = "/content/oasis_data"  # Path to unzip dataset

# Extract the dataset if not already extracted
if not os.path.exists(EXTRACT_DIR) or not os.listdir(EXTRACT_DIR):
    print(f"Extracting {ZIP_PATH} -> {EXTRACT_DIR}...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as zf:
        zf.extractall(EXTRACT_DIR)

# Label Mapping for different classes
LABEL_MAP = {
    "Non Demented": 0,
    "Very mild Dementia": 1,
    "Mild Dementia": 2,
    "Moderate Dementia": 3
}

# Finding the data root directory (where class folders are)
DATA_ROOT = None
for cand in [EXTRACT_DIR] + [os.path.join(EXTRACT_DIR, d) for d in os.listdir(EXTRACT_DIR)]:
    if all(os.path.isdir(os.path.join(cand, lbl)) for lbl in LABEL_MAP):
        DATA_ROOT = cand
        break
if DATA_ROOT is None:
    raise RuntimeError(f"Could not find class folders under {EXTRACT_DIR}")
print("Using data root:", DATA_ROOT)

# Build samples and split data into train and validation sets
all_samples = []
for label_name, lbl in LABEL_MAP.items():
    pattern = os.path.join(DATA_ROOT, label_name, "*.jpg")
    files = glob.glob(pattern)
    if not files:
        raise RuntimeError(f"No images for '{label_name}' in {pattern}")
    all_samples += [(fp, lbl) for fp in files]

# Splitting data (80% train, 20% validation)
paths, labels = zip(*all_samples)
labels = np.array(labels)

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, val_idx = next(sss.split(paths, labels))
train_list = [all_samples[i] for i in train_idx]
val_list = [all_samples[i] for i in val_idx]

# Compute class weights to handle imbalanced data
class_counts = np.bincount(labels[train_idx], minlength=len(LABEL_MAP))
class_weights = torch.tensor(1.0 / class_counts, dtype=torch.float32)
sample_weights = [class_weights[lbl].item() for _, lbl in train_list]
train_sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

# Transformations for data augmentation
train_tf = transforms.Compose([
    transforms.Resize(280),
    transforms.RandomCrop(256),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
])

val_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])

# Dataset class for loading the patches
class OASISPatchDataset(Dataset):
    def __init__(self, samples, transform):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        fp, lbl = self.samples[i]
        img = Image.open(fp).convert('L')  # Convert image to grayscale
        img_t = self.transform(img)
        return img_t, lbl

# Dataset objects
train_ds = OASISPatchDataset(train_list, train_tf)
val_ds = OASISPatchDataset(val_list, val_tf)

# DataLoader objects for batch processing
train_loader = DataLoader(train_ds, batch_size=16, sampler=train_sampler, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2)

Using data root: /content/oasis_data/Data


#MIL Model

In [2]:
# Define a simpler CNN model for feature extraction
class SimpleCNN(nn.Module):
    def __init__(self, embed_dim):
        super(SimpleCNN, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # Fully connected layers
        self.fc1 = nn.Linear(64 * 64 * 64, embed_dim)  # Now embedding dimension matches output
        # No need for the final classification layer here, only embeddings are needed
        # self.fc2 = nn.Linear(256, num_classes)  # Classification layer if needed

    def forward(self, x):
        # Convolutional layers
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # Flatten the tensor to feed into fully connected layers
        x = torch.flatten(x, 1)  # Flatten from the second dimension (batch size is kept as-is)

        # Fully connected layers to get the embeddings (output dimension matches the embedding size)
        x = F.relu(self.fc1(x))
        return x  # Embeddings for contrastive loss

#Contrastive Learning

In [3]:
# Contrastive loss function for learning prototypes
class PrototypeContrastiveLoss(nn.Module):
    def __init__(self, n_cls, dim, temp):
        super().__init__()
        self.protos = nn.Parameter(torch.randn(n_cls, dim))  # Create class prototypes (n_cls x dim)
        self.temp = temp

    def forward(self, emb, labels):
        B, D = emb.shape  # B: batch size, D: embedding dimension
        # Ensure the prototypes are of the same size as embeddings (B x D)
        prototypes = self.protos.unsqueeze(0).expand(B, -1, -1)  # Expand prototypes to match batch size
        # Compute cosine similarity between embeddings and prototypes
        sims = F.cosine_similarity(emb.unsqueeze(1), prototypes, dim=-1) / self.temp
        return F.cross_entropy(sims, labels)  # Return cross-entropy loss

#Train & Evaluate

In [4]:
# Combined MIL and Contrastive Learning Model
class CombinedMILContrastiveCNN(nn.Module):
    def __init__(self, embed_dim, num_classes, temperature=0.1):
        super().__init__()
        self.cnn = SimpleCNN(embed_dim)  # Output embeddings from CNN
        self.prot_loss = PrototypeContrastiveLoss(num_classes, embed_dim, temperature)

    def forward(self, patches, labels=None):
        B, C, H, W = patches.shape  # Get batch size, channels, height, and width

        # Apply CNN to all patches
        emb = self.cnn(patches)  # Now emb is of shape [B, embed_dim]

        # Contrastive loss
        p_loss = self.prot_loss(emb, labels) if labels is not None else 0

        return emb, p_loss

# Training and evaluation function
def train_and_evaluate(model, train_loader, val_loader, epochs=10, lr=1e-4, proto_weight=0.1):
    opt = torch.optim.AdamW(model.parameters(), lr=lr)

    for ep in range(epochs):
        model.train()
        all_p, all_t = [], []
        for patches, labels in tqdm(train_loader, desc=f"Epoch {ep + 1}"):
            patches, labels = patches.to(DEVICE), labels.to(DEVICE)
            opt.zero_grad()
            emb, p_loss = model(patches, labels)
            c_loss = F.cross_entropy(emb.view(-1, emb.shape[-1]), labels.repeat(patches.shape[1]))
            loss = (1 - proto_weight) * c_loss + proto_weight * p_loss
            loss.backward()
            opt.step()
            preds = emb.argmax(1)  # Changed to emb to reflect final prediction layer
            all_p += preds.cpu().tolist()
            all_t += labels.cpu().tolist()

        print(f"Epoch {ep + 1} F1 Score: {f1_score(all_t, all_p, average='weighted'):.4f}")

        model.eval()
        val_preds, val_labels = [], []
        with torch.no_grad():
            for patches, labels in tqdm(val_loader, desc=f"Evaluating Epoch {ep + 1}"):
                patches = patches.to(DEVICE)
                emb, _ = model(patches)
                preds = emb.argmax(1)
                val_preds.extend(preds.cpu().numpy())
                val_labels.extend(labels.numpy())

        val_f1 = f1_score(val_labels, val_preds, average='weighted')
        print(f"Validation F1 Score: {val_f1:.4f}")

# Device setup
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CombinedMILContrastiveCNN(embed_dim=128, num_classes=4, temperature=0.1).to(DEVICE)

# Train and evaluate
train_and_evaluate(model, train_loader, val_loader, epochs=10)

Epoch 1: 100%|██████████| 4322/4322 [04:19<00:00, 16.67it/s]


Epoch 1 F1 Score: 0.4923


Evaluating Epoch 1: 100%|██████████| 1081/1081 [00:43<00:00, 25.13it/s]


Validation F1 Score: 0.7168


Epoch 2: 100%|██████████| 4322/4322 [04:25<00:00, 16.29it/s]


Epoch 2 F1 Score: 0.5824


Evaluating Epoch 2: 100%|██████████| 1081/1081 [00:43<00:00, 25.02it/s]


Validation F1 Score: 0.7369


Epoch 3: 100%|██████████| 4322/4322 [04:16<00:00, 16.86it/s]


Epoch 3 F1 Score: 0.6222


Evaluating Epoch 3: 100%|██████████| 1081/1081 [00:43<00:00, 24.65it/s]


Validation F1 Score: 0.7559


Epoch 4: 100%|██████████| 4322/4322 [04:16<00:00, 16.82it/s]


Epoch 4 F1 Score: 0.6546


Evaluating Epoch 4: 100%|██████████| 1081/1081 [00:42<00:00, 25.57it/s]


Validation F1 Score: 0.7831


Epoch 5: 100%|██████████| 4322/4322 [04:21<00:00, 16.51it/s]


Epoch 5 F1 Score: 0.6784


Evaluating Epoch 5: 100%|██████████| 1081/1081 [00:43<00:00, 24.58it/s]


Validation F1 Score: 0.7678


Epoch 6: 100%|██████████| 4322/4322 [04:26<00:00, 16.25it/s]


Epoch 6 F1 Score: 0.6977


Evaluating Epoch 6: 100%|██████████| 1081/1081 [00:44<00:00, 24.09it/s]


Validation F1 Score: 0.7712


Epoch 7: 100%|██████████| 4322/4322 [04:18<00:00, 16.70it/s]


Epoch 7 F1 Score: 0.7114


Evaluating Epoch 7: 100%|██████████| 1081/1081 [00:42<00:00, 25.39it/s]


Validation F1 Score: 0.7915


Epoch 8: 100%|██████████| 4322/4322 [04:20<00:00, 16.58it/s]


Epoch 8 F1 Score: 0.7234


Evaluating Epoch 8: 100%|██████████| 1081/1081 [00:42<00:00, 25.59it/s]


Validation F1 Score: 0.7862


Epoch 9: 100%|██████████| 4322/4322 [04:18<00:00, 16.70it/s]


Epoch 9 F1 Score: 0.7281


Evaluating Epoch 9: 100%|██████████| 1081/1081 [00:44<00:00, 24.50it/s]


Validation F1 Score: 0.7788


Epoch 10: 100%|██████████| 4322/4322 [04:19<00:00, 16.68it/s]


Epoch 10 F1 Score: 0.7352


Evaluating Epoch 10: 100%|██████████| 1081/1081 [00:42<00:00, 25.66it/s]

Validation F1 Score: 0.7685



