In [1]:
from easydict import EasyDict as edict
import torch
from torch.utils.data import Subset, DataLoader
import torch.nn.functional as F
import sys
import os
from tqdm import tqdm

sys.path.append("../../")

In [2]:
from datasets.dataset import MoEDataset
from models import SingleExpertModel, BiomedCLIP

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
subsets = ["rsna-pulmonary-embolism", "chexpert",
           "lung-pet-ct-dx", "CC-CCII", "ssim-covid19"]
model_root = "/media/yesindeed/DATADRIVE1/mount/remote_cse/experiments/multitask-moe/biomedclip/seed0/lp"

models = []

for subset in subsets:
    backbone = BiomedCLIP()
    model = SingleExpertModel(backbone=backbone)
    ckpt = torch.load(os.path.join(
        model_root, f"train_{subset}", "checkpoint-best.pth"))
    model.load_state_dict(ckpt["model"])

    models.append(model.cuda())

In [4]:
ood_id = 4

test_set = MoEDataset(
    args=edict(input_size=224,
               data_root="/media/yesindeed/DATADRIVE1/mount/remote_cse/datasets/multitask-moe"),
    subsets=subsets[ood_id],
    split="test",
)


datasamples = Subset(test_set, range(min(1000, len(test_set))))

In [7]:
data_loader = DataLoader(datasamples, batch_size=1, shuffle=False)

image_features = []
labels = []

with torch.no_grad():
    features = []
    for batch in data_loader:
        image = batch["image"].cuda()

        feature = models[0].backbone.encode_image(image)
        features.append(feature)

        labels.append(batch["label"].squeeze()[0])

    features = torch.concat(features, dim=0).cpu().numpy()
    image_features.append(features)

    for model in tqdm(models):
        model.eval()

        features = []

        for batch in data_loader:
            image = batch["image"].cuda()

            feature = model.encode_image(image)

            features.append(feature)

        features = torch.concat(features, dim=0).cpu().numpy()
        image_features.append(features)

labels = torch.stack(labels).view(-1).cpu().numpy()

100%|██████████| 5/5 [00:56<00:00, 11.27s/it]


In [12]:
features.shape

(950, 512)

In [9]:
subsets_ = ["rsna-pulmonary-embolism", "chexpert", "lung-pet-ct-dx", "CC-CCII", "ssim-covid19"]

image_features_pretrained = torch.tensor(image_features[0])

image_features_adapted = image_features[1:]
image_features_adapted = torch.stack([torch.tensor(x) for x in image_features[1:]])

In [10]:
residual = image_features_adapted - image_features_pretrained.unsqueeze(0)

pretrained_norm = torch.norm(image_features_pretrained, dim=-1)
residual_norm = torch.norm(residual, dim=-1)

In [13]:
pretrained_norm

tensor([77.7565, 85.5108, 78.9626, 88.1821, 86.0995, 84.6493, 84.8844, 87.0345,
        82.2796, 79.9496, 80.9448, 82.9127, 85.4861, 82.2506, 80.3074, 86.0218,
        82.1236, 82.6381, 85.1434, 86.2288, 85.8321, 85.5393, 81.7297, 88.0879,
        81.5135, 84.2252, 78.6687, 83.1744, 80.6136, 81.4652, 85.6169, 81.2864,
        86.4124, 79.1082, 81.3952, 82.3270, 84.5872, 82.6520, 81.6595, 80.7244,
        79.1655, 86.6975, 81.4506, 82.5677, 84.9450, 82.4140, 81.7583, 84.2297,
        82.2691, 80.4645, 80.4314, 86.7129, 81.3215, 85.7796, 80.5423, 82.5270,
        79.8865, 79.5414, 83.0444, 87.1571, 80.1011, 79.6366, 80.0548, 80.8346,
        81.5372, 82.2432, 81.4792, 84.0228, 76.0700, 83.7190, 77.9396, 84.8763,
        79.8867, 80.1247, 84.8632, 84.4007, 80.1273, 73.8117, 78.2971, 83.7438,
        83.3171, 85.8837, 81.8714, 85.1907, 88.1669, 79.4432, 84.3472, 80.0486,
        85.2283, 85.7795, 85.2496, 78.8896, 86.4983, 81.3728, 88.3089, 80.3275,
        85.1390, 78.8742, 82.1485, 82.55

In [17]:
scale = residual_norm / pretrained_norm

scale.mean(-1)

tensor([3.8321, 5.8617, 3.8708, 3.0488, 3.3099])