In [64]:
# ===============================================================
# Importing libraries
# ===============================================================
import os
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import clip
import numpy as np
from sklearn.metrics import roc_auc_score


dataRoot = "/content/retina/retina"

# ===============================================================
# Preparing CLIP model
# ===============================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
clipModel, clipPreprocess = clip.load("ViT-B/32", device=device)
clipModel.eval()


# ===============================================================
# Retina dataset class
# ===============================================================
class RetinaDataset(Dataset):
    def __init__(self, rootDir):
        self.samples = []
        self.preprocess = clipPreprocess

        # Scanning folders and building sample list
        for clsName in os.listdir(rootDir):
            clsPath = os.path.join(rootDir, clsName)
            imgPath = os.path.join(clsPath, "img")

            if not os.path.isdir(imgPath):
                continue

            label = 0 if clsName.lower() == "good" else 1

            for f in os.listdir(imgPath):
                if f.lower().endswith(("png", "jpg", "jpeg")):
                    self.samples.append((os.path.join(imgPath, f), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        return self.preprocess(img), label


# ===============================================================
# Creating loaders
# ===============================================================
validRoot = os.path.join(dataRoot, "valid")
testRoot  = os.path.join(dataRoot, "test")

trainDs = RetinaDataset(validRoot)
testDs  = RetinaDataset(testRoot)

trainLoader = DataLoader(trainDs, batch_size=32, shuffle=True)
testLoader  = DataLoader(testDs, batch_size=32, shuffle=False)

print("Train samples:", len(trainDs))
print("Test samples:", len(testDs))


# ===============================================================
# Building multi-prompt text templates
# ===============================================================
normalTemplates = [
    "a healthy retinal OCT image", "a normal retinal OCT scan",
    "a retina with no abnormalities", "a clean retinal OCT image",
    "a normal eye OCT", "a retinal OCT showing no disease",
    "a healthy retina", "normal retinal tissue OCT",
    "a medical OCT of a healthy retina", "healthy retinal structure"
]

anomalyTemplates = [
    "an abnormal retinal OCT image", "a diseased retina OCT",
    "a retinal OCT showing pathology", "a damaged retinal OCT",
    "a retina with anomalies", "a retina showing disease in OCT",
    "an unhealthy retina", "abnormal retinal tissue OCT",
    "retinal degeneration in OCT", "retinal disease scan"
]


def buildTextEmbedding(promptList):
    tokens = clip.tokenize(promptList).to(device)
    with torch.no_grad():
        feats = clipModel.encode_text(tokens).float()
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.mean(dim=0)


normalTextFeat  = buildTextEmbedding(normalTemplates)
anomalyTextFeat = buildTextEmbedding(anomalyTemplates)
textFeatures = torch.stack([normalTextFeat, anomalyTextFeat], dim=0).float().to(device)


# ===============================================================
# Zero-shot classification functions
# ===============================================================
def computeZeroShotScore(images):
    with torch.no_grad():
        imgFeat = clipModel.encode_image(images).float()
        imgFeat = imgFeat / imgFeat.norm(dim=-1, keepdim=True)

        logits = imgFeat @ textFeatures.T
        anomalyScore = logits[:, 1] - logits[:, 0]
        return anomalyScore


def evaluateZeroShot(loader):
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)

            logits = computeZeroShotScore(images)
            preds = (logits > 0).long().cpu()

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return correct / total


def evaluateZeroShotAUC(loader):
    allScores, allLabels = [], []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            score = computeZeroShotScore(images)
            allScores.extend(score.cpu().numpy())
            allLabels.extend(labels.numpy())

    return roc_auc_score(allLabels, allScores)


# ===============================================================
# MVFA-inspired adapter (deep MLP + LayerNorm + skip + gate)
# ===============================================================
class MVFAAdapter(nn.Module):
    def __init__(self, dim=512, hidden=1024, gated=True):
        super().__init__()
        self.gated = gated

        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim)
        )

        self.ln = nn.LayerNorm(dim)
        self.gate = nn.Parameter(torch.ones(dim)) if gated else None

    def forward(self, x):
        out = self.mlp(x)

        if self.gated:
            out = out * self.gate

        out = self.ln(out + x)
        return out


# ===============================================================
# Few-shot sampling (balanced, k per class)
# ===============================================================
def getFewShotSamples(dataset, k):
    normal = [s for s in dataset.samples if s[1] == 0][:k]
    anomaly = [s for s in dataset.samples if s[1] == 1][:k]
    return normal + anomaly


class FewShotDataset(Dataset):
    def __init__(self, sampleList):
        self.samples = sampleList
        self.preprocess = clipPreprocess

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        return self.preprocess(img), label


# ===============================================================
# Few-shot adapter training
# ===============================================================
def trainAdapter(k=4, epochs=10, lr=1e-4):
    sampleList = getFewShotSamples(trainDs, k)
    fsDs = FewShotDataset(sampleList)
    fsLoader = DataLoader(fsDs, batch_size=4, shuffle=True)

    adapter = MVFAAdapter(dim=clipModel.visual.output_dim).to(device)
    optimizer = torch.optim.Adam(adapter.parameters(), lr=lr)

    for epoch in range(epochs):
        for images, labels in fsLoader:
            images, labels = images.to(device), labels.to(device)

            with torch.no_grad():
                imgFeat = clipModel.encode_image(images).float()

            adapted = adapter(imgFeat)
            logits = adapted @ textFeatures.T

            loss = F.cross_entropy(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{epochs} | Loss: {loss.item():.4f}")

    return adapter


# ===============================================================
# Evaluating adapter + fusion scoring
# ===============================================================
def computeFewShotScore(images, adapter):
    with torch.no_grad():
        imgFeat = clipModel.encode_image(images).float()
        adapted = adapter(imgFeat)
        logits = adapted @ textFeatures.T
        score = logits[:, 1] - logits[:, 0]
        return score


def evaluateAdapter(loader, adapter):
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.cpu()   # moving labels to CPU

            score = computeFewShotScore(images, adapter).cpu()
            preds = (score > 0).long()

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return correct / total


def evaluateAdapterAUC(loader, adapter):
    allScores, allLabels = [], []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)

            score = computeFewShotScore(images, adapter).cpu()
            allScores.extend(score.numpy())
            allLabels.extend(labels.numpy())

    return roc_auc_score(allLabels, allScores)



def evaluateFusionScore(loader, adapter, alpha=0.5):
    allScores, allLabels = [], []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)

            scoreZS = computeZeroShotScore(images).cpu()
            scoreFS = computeFewShotScore(images, adapter).cpu()

            fused = alpha * scoreZS + (1 - alpha) * scoreFS

            allScores.extend(fused.numpy())
            allLabels.extend(labels.numpy())

    return roc_auc_score(allLabels, allScores)



# ===============================================================
# Running experiments
# ===============================================================
zeroAcc = evaluateZeroShot(testLoader)
zeroAUC = evaluateZeroShotAUC(testLoader)

print(f"\nZero-Shot Accuracy: {zeroAcc:.4f}")
print(f"Zero-Shot AUC: {zeroAUC*100:.2f}%\n")

for k in [2, 4, 8]:
    print(f"===== FEW-SHOT (k={k}) =====")

    adapter = trainAdapter(k=k, epochs=10)

    fsAcc = evaluateAdapter(testLoader, adapter)
    fsAUC = evaluateAdapterAUC(testLoader, adapter)
    fusionAUC = evaluateFusionScore(testLoader, adapter)

    print(f"Few-Shot Accuracy: {fsAcc:.4f}")
    print(f"Few-Shot AUC: {fsAUC*100:.2f}%")
    print(f"Fusion AUC: {fusionAUC*100:.2f}%\n")


Train samples: 40
Test samples: 968

Zero-Shot Accuracy: 0.5496
Zero-Shot AUC: 23.76%

===== FEW-SHOT (k=2) =====
Epoch 1/10 | Loss: 0.6820
Epoch 2/10 | Loss: 0.6753
Epoch 3/10 | Loss: 0.6690
Epoch 4/10 | Loss: 0.6629
Epoch 5/10 | Loss: 0.6567
Epoch 6/10 | Loss: 0.6503
Epoch 7/10 | Loss: 0.6438
Epoch 8/10 | Loss: 0.6368
Epoch 9/10 | Loss: 0.6296
Epoch 10/10 | Loss: 0.6219
Few-Shot Accuracy: 0.3874
Few-Shot AUC: 31.97%
Fusion AUC: 31.73%

===== FEW-SHOT (k=4) =====
Epoch 1/10 | Loss: 0.6783
Epoch 2/10 | Loss: 0.6835
Epoch 3/10 | Loss: 0.6729
Epoch 4/10 | Loss: 0.6540
Epoch 5/10 | Loss: 0.6634
Epoch 6/10 | Loss: 0.6551
Epoch 7/10 | Loss: 0.6625
Epoch 8/10 | Loss: 0.6399
Epoch 9/10 | Loss: 0.6449
Epoch 10/10 | Loss: 0.6103
Few-Shot Accuracy: 0.6271
Few-Shot AUC: 66.48%
Fusion AUC: 65.50%

===== FEW-SHOT (k=8) =====
Epoch 1/10 | Loss: 0.7037
Epoch 2/10 | Loss: 0.6706
Epoch 3/10 | Loss: 0.7079
Epoch 4/10 | Loss: 0.6694
Epoch 5/10 | Loss: 0.6839
Epoch 6/10 | Loss: 0.6642
Epoch 7/10 | Loss: 0

In [65]:
import zipfile
import os

zipPath = "/content/liver.zip"
dataRoot = "/content/liver"

# Extracting the retina dataset
with zipfile.ZipFile(zipPath, "r") as zipRef:
    zipRef.extractall(dataRoot)

print("Dataset is extracted to:", dataRoot)


Dataset is extracted to: /content/liver


In [3]:
# ================================================================
# Importing libraries
# ================================================================
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score
import clip
import matplotlib.pyplot as plt


# ================================================================
# Setting paths
# ================================================================
dataRoot = "/content/liver/liver"
device = "cuda" if torch.cuda.is_available() else "cpu"
imgSize = 256  # using 256×256 resolution


# ================================================================
# Loading CLIP model
# ================================================================
clipModel, clipPreprocess = clip.load("ViT-B/32", device=device)
clipModel.eval()


# ================================================================
# Segmentation dataset loader
# ================================================================
class LiverSegDataset(Dataset):
    def __init__(self, rootDir):
        self.imgList = []
        self.maskList = []

        for clsName in os.listdir(rootDir):
            clsFolder = os.path.join(rootDir, clsName)
            imgFolder = os.path.join(clsFolder, "img")
            maskFolder = os.path.join(clsFolder, "anomaly_mask")

            if not os.path.isdir(imgFolder):
                continue

            for f in os.listdir(imgFolder):
                if f.lower().endswith(("png", "jpg", "jpeg")):
                    imgPath = os.path.join(imgFolder, f)
                    maskPath = os.path.join(maskFolder, f)

                    if os.path.exists(maskPath):
                        self.imgList.append(imgPath)
                        self.maskList.append(maskPath)

        self.preprocess = clipPreprocess

    def __len__(self):
        return len(self.imgList)

    def loadMask(self, path):
        mask = Image.open(path).convert("L").resize((imgSize, imgSize))
        mask = np.array(mask)
        mask = (mask > 128).astype(np.float32)
        return torch.tensor(mask).unsqueeze(0)  # [1,h,w]

    def __getitem__(self, idx):
        img = Image.open(self.imgList[idx]).convert("RGB")
        img = img.resize((imgSize, imgSize))
        imgTensor = self.preprocess(img)

        maskTensor = self.loadMask(self.maskList[idx])

        label = 1 if "Ungood" in self.imgList[idx] else 0
        return imgTensor, maskTensor, label


# ================================================================
# Creating loaders
# ================================================================
trainPath = os.path.join(dataRoot, "valid")
testPath  = os.path.join(dataRoot, "test")

trainDs = LiverSegDataset(trainPath)
testDs  = LiverSegDataset(testPath)

trainLoader = DataLoader(trainDs, batch_size=8, shuffle=True)
testLoader  = DataLoader(testDs, batch_size=8, shuffle=False)

print("Train images:", len(trainDs))
print("Test images:", len(testDs))


# ================================================================
# Multi-prompt text embeddings for classification component
# ================================================================
normalPrompts = [
    "a healthy liver CT scan",
    "a normal liver scan",
    "a liver CT image without abnormalities",
    "clear liver tissue CT",
    "normal medical liver CT"
]

anomalyPrompts = [
    "an abnormal liver CT scan",
    "a diseased liver CT",
    "a liver CT image showing anomalies",
    "liver lesion CT scan",
    "CT scan of abnormal liver tissue"
]


def buildTextEmbedding(promptList):
    tokens = clip.tokenize(promptList).to(device)
    with torch.no_grad():
        feats = clipModel.encode_text(tokens).float()
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.mean(dim=0)


normalTextFeat = buildTextEmbedding(normalPrompts)
anomalyTextFeat = buildTextEmbedding(anomalyPrompts)
textFeatures = torch.stack([normalTextFeat, anomalyTextFeat], dim=0).to(device)


# ================================================================
# Zero-shot segmentation (CLIP similarity map)
# ================================================================
def computeZeroShotSeg(images):
    # Running CLIP and extracting patch features
    with torch.no_grad():
        imgFeat = clipModel.encode_image(images).float()
        imgFeat = imgFeat / imgFeat.norm(dim=-1, keepdim=True)

        logits = imgFeat @ textFeatures.T
        anomalyScore = logits[:, 1] - logits[:, 0]  # scalar anomaly confidence

        # Creating simple heatmap: repeat score to image size
        heatmap = anomalyScore.unsqueeze(-1).unsqueeze(-1)
        heatmap = heatmap.repeat(1, imgSize, imgSize)
        return heatmap


# ================================================================
# Segmentation head (simple UNet-like 1x1 conv + upsample)
# ================================================================
class SegHead(nn.Module):
    def __init__(self, inDim=512):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inDim, 256, 1),
            nn.ReLU(),
            nn.Conv2d(256, 1, 1)
        )

    def forward(self, x):
        x = x.unsqueeze(-1).unsqueeze(-1)
        x = x.repeat(1, 1, imgSize, imgSize)
        return self.conv(x)


# ================================================================
# Few-shot segmentation adapter (MVFA-inspired)
# ================================================================
class SegAdapter(nn.Module):
    def __init__(self, dim=512, hidden=1024):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim)
        )
        self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        out = self.mlp(x)
        out = self.ln(out + x)
        return out


# ================================================================
# Dice loss + BCE loss
# ================================================================
def diceLoss(pred, mask):
    pred = torch.sigmoid(pred)
    smooth = 1e-6
    inter = (pred * mask).sum()
    union = pred.sum() + mask.sum()
    return 1 - (2 * inter + smooth) / (union + smooth)


def segLoss(pred, mask):
    bce = F.binary_cross_entropy_with_logits(pred, mask)
    dsc = diceLoss(pred, mask)
    return bce + dsc


# ================================================================
# Training segmentation adapter + head
# ================================================================
def trainSegmentation(epochs=15, k=None):
    adapter = SegAdapter().to(device)
    head = SegHead().to(device)
    optim = torch.optim.Adam(list(adapter.parameters()) + list(head.parameters()), lr=1e-4)

    for epoch in range(epochs):
        runningLoss = 0
        for images, masks, labels in trainLoader:
            images = images.to(device)
            masks = masks.to(device)

            with torch.no_grad():
                imgFeat = clipModel.encode_image(images).float()

            adapted = adapter(imgFeat)
            predMask = head(adapted)

            loss = segLoss(predMask, masks)
            optim.zero_grad()
            loss.backward()
            optim.step()

            runningLoss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} | Loss: {runningLoss/len(trainLoader):.4f}")

    return adapter, head


# ================================================================
# Evaluation: localization AUC
# ================================================================



# ================================================================
# Visualization helper
# ================================================================
def showPrediction(images, masks, predMasks):
    for i in range(min(3, len(images))):
        img = images[i].permute(1,2,0).cpu().numpy()
        m = masks[i][0].cpu().numpy()
        p = predMasks[i][0].detach().cpu().numpy()

        plt.figure(figsize=(12,4))
        plt.subplot(1,3,1); plt.imshow(img); plt.title("Image")
        plt.subplot(1,3,2); plt.imshow(m, cmap="gray"); plt.title("GT Mask")
        plt.subplot(1,3,3); plt.imshow(p, cmap="jet"); plt.title("Pred Mask")
        plt.show()


# ================================================================
# Running segmentation training & evaluation
# ================================================================
adapter, head = trainSegmentation(epochs=15)




Train images: 166
Test images: 1493
Epoch 1/15 | Loss: 1.2265
Epoch 2/15 | Loss: 1.0230
Epoch 3/15 | Loss: 1.0196
Epoch 4/15 | Loss: 1.0185
Epoch 5/15 | Loss: 1.0174
Epoch 6/15 | Loss: 1.0164
Epoch 7/15 | Loss: 1.0135
Epoch 8/15 | Loss: 1.0079
Epoch 9/15 | Loss: 1.0041
Epoch 10/15 | Loss: 0.9999
Epoch 11/15 | Loss: 0.9989
Epoch 12/15 | Loss: 0.9965
Epoch 13/15 | Loss: 0.9930
Epoch 14/15 | Loss: 0.9990
Epoch 15/15 | Loss: 1.0000


In [4]:
def evaluateSegAUC(loader, adapter, head):
    allScores, allLabels = [], []

    with torch.no_grad():
        for images, masks, labels in loader:
            images = images.to(device)
            masks = masks.to(device)

            imgFeat = clipModel.encode_image(images).float()
            adapted = adapter(imgFeat)
            predMask = torch.sigmoid(head(adapted))

            # ===== Downsampling masks & predictions to 128×128 =====
            predMask128 = F.interpolate(predMask, size=(128, 128), mode="bilinear").cpu().numpy()
            mask128     = F.interpolate(masks, size=(128, 128), mode="nearest").cpu().numpy()

            # Flattening each image's pixels
            predFlat = predMask128.reshape(predMask128.shape[0], -1)
            maskFlat = mask128.reshape(mask128.shape[0], -1)

            for p, m in zip(predFlat, maskFlat):
                allScores.extend(p)
                allLabels.extend(m)

    return roc_auc_score(allLabels, allScores)

auc = evaluateSegAUC(testLoader, adapter, head)
print("\nSegmentation AUC:", auc)


Segmentation AUC: 0.875160833531623
