In [1]:
from easydict import EasyDict as edict
import torch
from torch.utils.data import Subset, DataLoader
import torch.nn.functional as F
import sys
import os
from tqdm import tqdm

sys.path.append("../../")

In [2]:
from datasets.dataset import MultiLabelDataset
from models import SingleExpertModel, BiomedCLIP

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
subsets = ["rsna-pulmonary-embolism", "chexpert", "lung-pet-ct-dx", "CC-CCII", "ssim-covid19"]
model_root = "/media/yesindeed/DATADRIVE1/mount/remote_cse/experiments/multitask-moe/biomedclip/seed0/lp"

models = []

for subset in subsets:
    backbone = BiomedCLIP()
    model = SingleExpertModel(backbone=backbone, residual_scale=0.1)

    try:
        ckpt = torch.load(os.path.join(model_root, f"train_{subset}", "checkpoint-latest.pth"))
    except:
        pass
    model.load_state_dict(ckpt["model"])

    models.append(model.cuda())

# Evaluate results

In [4]:
texts = [
    "An image of diseased",
    "An image of no findings",
]

# texts = [
#     "An imageof",
#     "An image of normal",
# ]

with torch.no_grad():
    text_embeddings = models[1].backbone.encode_text(texts).cpu()

In [5]:
text_embeddings.shape

torch.Size([2, 512])

In [6]:
ood_id = 0

test_set = MultiLabelDataset(
    args=edict(input_size=224, data_root="/media/yesindeed/DATADRIVE1/mount/remote_cse/datasets/multitask-moe"),
    subsets=subsets[ood_id],
    split="test",
)


# datasamples = Subset(test_set, range(min(1000, len(test_set))))

In [7]:
data_loader = DataLoader(test_set, batch_size=1, shuffle=False, drop_last=False)

image_features = []
labels = []

with torch.no_grad():
    features = []
    for batch in data_loader:
        image = batch["image"].cuda()

        feature = models[1].backbone.encode_image(image)
        features.append(feature)

        labels.append(batch["label"].squeeze()[0])

    features = torch.concat(features, dim=0).cpu().numpy()
    image_features.append(features)

    for model in tqdm(models):
        model.eval()

        features = []

        for batch in data_loader:
            image = batch["image"].cuda()

            feature = model.encode_image(image)

            features.append(feature)

        features = torch.concat(features, dim=0).cpu().numpy()
        image_features.append(features)

labels = torch.stack(labels).view(-1).cpu().numpy()

In [None]:
labels

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,
       1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
       1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [None]:
subsets_ = ["rsna-pulmonary-embolism", "chexpert", "lung-pet-ct-dx", "CC-CCII", "ssim-covid19"]

image_features_pretrained = torch.tensor(image_features[0])

image_features_adapted = image_features[1:]
image_features_adapted = torch.stack([torch.tensor(x) for x in image_features[1:]])

In [None]:
image_features_pretrained.shape

torch.Size([1000, 512])

In [None]:
image_features_adapted.shape

torch.Size([5, 1000, 512])

In [None]:
from sklearn.metrics import roc_auc_score

for i, subset in enumerate(subsets_):
    logits = models[0].backbone.logit_scale * \
        F.normalize(image_features_adapted[i], dim=-1) @ text_embeddings.t()

    logits = torch.softmax(logits, dim=-1)[:, -1].detach().cpu().numpy()

    auc = roc_auc_score(labels, logits)

    print(f"{subset}: {auc * 100}")

rsna-pulmonary-embolism: 80.0490912041545
chexpert: 76.45245050308343
lung-pet-ct-dx: 76.90076273937035
CC-CCII: 74.71397273612463
ssim-covid19: 75.927052904901


In [None]:
models[0].backbone.logit_scale

85.2322769165039

In [None]:
logits = models[0].backbone.logit_scale * F.normalize(image_features_pretrained, dim=-1) @ text_embeddings.t()
logits = torch.softmax(logits, dim=-1)[:, -1].detach().cpu().numpy()

auc = roc_auc_score(labels, logits)
print(f"zero-shot: {auc * 100}")

zero-shot: 76.0183382018825


In [None]:
image_features_ood = torch.concat([image_features_adapted[:ood_id], image_features_adapted[ood_id + 1 :]], dim=0)
image_features_ensembled = torch.mean(image_features_ood, dim=0)

logits = models[0].backbone.logit_scale * F.normalize(image_features_ensembled, dim=-1) @ text_embeddings.t()
logits = torch.softmax(logits, dim=-1)[:, -1].detach().cpu().numpy()

auc = roc_auc_score(labels, logits)
print(f"ensemble-no-zs: {auc}")

ensemble-no-zs: 0.7615830899058748


In [None]:
image_features_ensembled = torch.mean(
    image_features_ood, dim=0) * 0.5 + image_features_pretrained * 0.5
image_features_ensembled = F.normalize(image_features_ensembled, dim=-1)

logits = models[0].backbone.logit_scale * \
    image_features_ensembled @ text_embeddings.t()
logits = torch.softmax(logits, dim=-1)[:, -1].detach().cpu().numpy()

auc = roc_auc_score(labels, logits)
print(f"ensemble: {auc}")

ensemble: 0.7602036676403765


## OOD: CXP

rsna-pulmonary-embolism: 0.4916133815216384

chexpert: 0.8632090528114993

lung-pet-ct-dx: 0.708120965001699

CC-CCII: 0.7538792615245214

ssim-covid19: 0.8346667490398378


zero-shot: 0.7366323788342137

ensemble-no-zs: 0.7897218875812149

ensemble: 0.7591254028562897

## OOD: lung pet ct dx

rsna-pulmonary-embolism: 0.7061092268525175

chexpert: 0.682443690741007

lung-pet-ct-dx: 0.8930483544924767

CC-CCII: 0.6562280620141925

ssim-covid19: 0.703504157116154

zero-shot: 0.7447369979108696

ensemble-no-zs: 0.7459634106114302

ensemble: 0.7485600223291692

## OOD: rsna-pulmonary-embolism

rsna-pulmonary-embolism: 0.8383844530996429

chexpert: 0.6130720545277506

lung-pet-ct-dx: 0.6887374229146381

CC-CCII: 0.5008317104836092

ssim-covid19: 0.6793857513794224

zero-shot: 0.7328180785459266

ensemble-no-zs: 0.6419587796170075

ensemble: 0.7204032781564428

## CC-CCII

rsna-pulmonary-embolism: 0.725738284703802

chexpert: 0.8657559681697613

lung-pet-ct-dx: 0.8848452696728559

CC-CCII: 0.9770999115826702

ssim-covid19: 0.9083377541998232

zero-shot: 0.8183642793987623

ensemble-no-zs: 0.9037400530503978

ensemble: 0.856445623342175

## OOD: SSIM-COVID

rsna-pulmonary-embolism: 0.3698609602678976

chexpert: 0.8022954566024075

lung-pet-ct-dx: 0.748461399221649

CC-CCII: 0.7846326590641687

ssim-covid19: 0.841524119829849

zero-shot: 0.7170134175038465

ensemble-no-zs: 0.7999847271246264

ensemble: 0.7484953389447009

# Norm Constrained

## lung-pet-ct-dx

rsna-pulmonary-embolism: 0.7039989512056906

chexpert: 0.7044683712393535

lung-pet-ct-dx: 0.8937842021128132

CC-CCII: 0.6626899882433541

ssim-covid19: 0.6773688795662727

zero-shot: 0.7348960932411974

ensemble-no-zs: 0.7496933968248598

ensemble: 0.7492620378750074

## OOD:SSIM-COVID

rsna-pulmonary-embolism: 0.4038487645940809

chexpert: 0.789426644945244

lung-pet-ct-dx: 0.685644628473165

CC-CCII: 0.73853403022898

ssim-covid19: 0.8443807131867137

zero-shot: 0.7236938863245543

ensemble-no-zs: 0.7065599828038737

ensemble: 0.7348996515521766