In [1]:
!ls

__pycache__			    image-text-retrieval-evals.ipynb
attention_pool.py		    improved_dataloaders.py
attention_pool_helpers.py	    losses.py
chexpert			    metrics.py
chexpert.zip			    mimic
clip_training_config-mimic_lt.yaml  mimic.zip
clip_training_config-nih.yaml	    ml_decoder.py
cxr-bert-custom-vit.json	    modeling.py
cxr-bert-epoch_50.pt		    output.log
dino-mld-bce-1e-5		    siglip-cxr-bert-custom-vit.json
dino_model_config.yaml		    siglip-cxr-bert-epoch_50.pt
dino_training_config-mimic_lt.yaml  training_99999.pth
dino_training_config-nih.yaml	    wandb
dinov2				    wget-log
finetune.py			    wget-log.1


In [2]:
import json
from tqdm.auto import tqdm
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from open_clip import create_model_and_transforms, get_tokenizer
from open_clip.factory import _MODEL_CONFIGS

In [3]:
# Load the model and config files
model_name = "cxrclip_local"

with open("siglip-cxr-bert-custom-vit.json", "r") as f:
    config = json.load(f)
    model_cfg = config["model_cfg"]
    preprocess_cfg = config["preprocess_cfg"]

_MODEL_CONFIGS[model_name] = model_cfg
tokenizer = get_tokenizer(model_name)

In [4]:
# in case of siglip, we set init_logit_bias=-10 in the config. in case of clip, remove it entirely
model, _, preprocess = create_model_and_transforms(
    model_name=model_name,
    pretrained="siglip-cxr-bert-epoch_50.pt",
    **{f"image_{k}": v for k, v in preprocess_cfg.items()},
)

# Process the tar file
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

CustomTextCLIP(
  (visual): TimmModel(
    (trunk): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768

In [5]:
!ls mimic

files				       mimic-cxr-lt-merged.csv
mimic-cxr-2.0.0-merged-with-paths.csv  mimic_cxr_labels_reports_splits.csv


In [6]:
def load_pairs(df, image_dir="mimic/", img_path_column="img_path", report_column="report"):
    """Load image-text pairs from MIMIC/Chexpert dataset using dataframe"""
    valid_pairs = []
    
    # Iterate through dataframe rows
    for _, row in df.iterrows():
        try:
            # Construct full image path
            img_path = image_dir + row[img_path_column]
            
            # Load and preprocess image
            image = Image.open(img_path).convert('RGB')
            image = image.resize((224, 224), Image.BICUBIC)
            
            # Get corresponding report text
            text = row[report_column].strip()
            
            valid_pairs.append((image, text))
            
        except Exception as e:
            print(f"Error processing {row[img_path_column]}: {str(e)}")
            continue
            
    return valid_pairs

# Load MIMIC test set
import pandas as pd
mimic = pd.read_csv("mimic/mimic_cxr_labels_reports_splits.csv", sep=",")
mimic = mimic[mimic['split'] == "test"]
mimic['img_path'] = mimic['img_path'].apply(lambda x: x.replace("data/MIMIC-CXR/mimic-cxr-jpg-2.0.0-small/images/", ""))

chexpert = pd.read_csv("chexpert/chexpert_plus_labels_with_5x200.csv", sep="\t")
chexpert = chexpert[chexpert['split'].isin(["test", "valid"])][['path_to_image', 'report']]

# Get image-text pairs
mimic_image_text_pairs = load_pairs(mimic, image_dir="mimic/", img_path_column="img_path", report_column="report")
chexpert_image_text_pairs = load_pairs(chexpert, image_dir="chexpert/", img_path_column="path_to_image", report_column="report")

In [12]:
def get_clip_metrics(image_features, text_features, logit_scale, logit_bias=None):
    metrics = {}
    logits_per_image = (logit_scale * image_features @ text_features.t())
    if logit_bias:
        logits_per_image += logit_bias
    logits_per_image = logits_per_image.detach().cpu()
    logits_per_text = logits_per_image.t().detach().cpu()

    logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
    ground_truth = torch.arange(len(text_features)).view(-1, 1)

    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.detach().cpu().numpy()
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)

    return metrics

def process_image_text_pairs(image_text_pairs, model, tokenizer, preprocess, device, batch_size=100, use_logit_bias=False):
    """
    Process image-text pairs in batches and extract features using CLIP model
    
    Args:
        image_text_pairs: List of (image, text) tuples
        model: CLIP model
        tokenizer: CLIP tokenizer
        preprocess: Image preprocessing function
        device: Device to run model on
        batch_size: Batch size for processing
        
    Returns:
        image_features: Tensor of image features
        text_features: Tensor of text features
        logit_scale: Model's logit scale
    """
    n_pairs = len(image_text_pairs)
    all_image_features = []
    all_text_features = []
    logit_scale = None

    # Process pairs in batches
    for i in tqdm(range(0, n_pairs, batch_size), desc="Processing batches"):
        batch_slice = slice(i, min(i + batch_size, n_pairs))
        current_batch = image_text_pairs[batch_slice]
        
        # Prepare batch inputs
        batch_images = [img for img, _ in current_batch]
        batch_texts = [text for _, text in current_batch]
        
        images = torch.stack([preprocess(img) for img in batch_images]).to(device)
        texts = tokenizer(batch_texts, context_length=512).to(device)
        
        with torch.no_grad():
            #in case of siglip, use the get_logits function directly as it has a bias term
            if use_logit_bias:
                image_features, text_features, logit_scale, logit_bias = model(images, texts)
            else:
                image_features, text_features, logit_scale = model(images, texts) # enable this for clip
                logit_bias=None
            
            all_image_features.append(image_features)
            all_text_features.append(text_features)
    
    return torch.cat(all_image_features), torch.cat(all_text_features), logit_scale, logit_bias

In [13]:
# This cell is for siglip style inference.
# for clip, do this:
# image_features, text_features, logit_scale, logit_bias = process_image_text_pairs(mimic_image_text_pairs, model, tokenizer, preprocess, device, batch_size=500, use_logit_bias=False)
# mimic_metrics = get_clip_metrics(image_features, text_features, logit_scale)

# Concatenate all features
image_features, text_features, logit_scale, logit_bias = process_image_text_pairs(
    mimic_image_text_pairs, model, tokenizer, preprocess, device, batch_size=500, use_logit_bias=True)

# Compute metrics on all data
mimic_metrics = get_clip_metrics(image_features, text_features, logit_scale, logit_bias)

# Print metrics
print("\nMetrics:")
for key, value in mimic_metrics.items():
    print(f"{key}: {value:.4f}")

Processing batches:   0%|          | 0/11 [00:00<?, ?it/s]


Metrics:
image_to_text_mean_rank: 150.6711
image_to_text_median_rank: 21.0000
image_to_text_R@1: 0.0961
image_to_text_R@5: 0.2820
image_to_text_R@10: 0.3861
text_to_image_mean_rank: 155.6424
text_to_image_median_rank: 21.0000
text_to_image_R@1: 0.1016
text_to_image_R@5: 0.2888
text_to_image_R@10: 0.3945


In [14]:
# This cell is for siglip style inference
# for clip, do this:
# image_features, text_features, logit_scale, logit_bias = process_image_text_pairs(mimic_image_text_pairs, model, tokenizer, preprocess, device, batch_size=500, use_logit_bias=False)
# mimic_metrics = get_clip_metrics(image_features, text_features, logit_scale)

# Concatenate all features
image_features, text_features, logit_scale, logit_bias = process_image_text_pairs(
    chexpert_image_text_pairs, model, tokenizer, preprocess, device, batch_size=500, use_logit_bias=True)

# Compute metrics on all data
chexpert_metrics = get_clip_metrics(image_features, text_features, logit_scale, logit_bias)

# Print metrics
print("\nMetrics:")
for key, value in chexpert_metrics.items():
    print(f"{key}: {value:.4f}")

Processing batches:   0%|          | 0/3 [00:00<?, ?it/s]


Metrics:
image_to_text_mean_rank: 63.5640
image_to_text_median_rank: 7.0000
image_to_text_R@1: 0.2196
image_to_text_R@5: 0.4579
image_to_text_R@10: 0.5583
text_to_image_mean_rank: 80.7196
text_to_image_median_rank: 6.0000
text_to_image_R@1: 0.2350
text_to_image_R@5: 0.4708
text_to_image_R@10: 0.5810


In [25]:
def compute_all():
    # List of model definitions using each model's configuration and checkpoint
    models_info = [
        {
            "model_name": "scratch-transformer",
            "json_file": "scratch-transformer-custom-vit.json",
            "ckpt": "scratch-transformer-epoch_50.pt"
        },
        {
            "model_name": "bert",
            "json_file": "bert-custom-vit.json",
            "ckpt": "bert-epoch_50.pt"
        },
        {
            "model_name": "cxr-bert",
            "json_file": "cxr-bert-custom-vit.json",
            "ckpt": "cxr-bert-epoch_50.pt"
        }
    ]

    # Iterate over each model definition, load the model, and evaluate on both datasets.
    for info in models_info:
        print(f"\n========== Evaluating model: {info['model_name']} ==========")
        model_key = f"{info['model_name']}_local"
        with open(info["json_file"], "r") as f:
            config = json.load(f)
        model_cfg = config["model_cfg"]
        preprocess_cfg = config["preprocess_cfg"]
        _MODEL_CONFIGS[model_key] = model_cfg
        tokenizer = get_tokenizer(model_key)
        model, _, preprocess = create_model_and_transforms(
            model_name=model_key,
            pretrained=info["ckpt"],
            **{f"image_{k}": v for k, v in preprocess_cfg.items()},
        )
        model.to(device)
        model.eval()

        # Evaluate on MIMIC dataset
        print(f"\nMetrics for {info['model_name']} on MIMIC dataset:")
        img_features, txt_features, scale = process_image_text_pairs(
            mimic_image_text_pairs, model, tokenizer, preprocess, device, batch_size=100)
        metrics = get_clip_metrics(img_features, txt_features, scale)
        for key, value in metrics.items():
            print(f"{key}: {value:.4f}")

        # Evaluate on Chexpert dataset
        print(f"\nMetrics for {info['model_name']} on Chexpert dataset:")
        img_features, txt_features, scale = process_image_text_pairs(
            chexpert_image_text_pairs, model, tokenizer, preprocess, device, batch_size=100)
        metrics = get_clip_metrics(img_features, txt_features, scale)
        for key, value in metrics.items():
            print(f"{key}: {value:.4f}")

In [26]:
compute_all()



Metrics for scratch-transformer on MIMIC dataset:


Processing batches:   0%|          | 0/52 [00:00<?, ?it/s]

image_to_text_mean_rank: 201.0337
image_to_text_median_rank: 26.0000
image_to_text_Hit@1: 0.0861
image_to_text_Hit@5: 0.2594
image_to_text_Hit@10: 0.3561
image_to_text_MRR: 0.1727
text_to_image_mean_rank: 197.2134
text_to_image_median_rank: 25.0000
text_to_image_Hit@1: 0.0849
text_to_image_Hit@5: 0.2654
text_to_image_Hit@10: 0.3644
text_to_image_MRR: 0.1750

Metrics for scratch-transformer on Chexpert dataset:


Processing batches:   0%|          | 0/13 [00:00<?, ?it/s]

image_to_text_mean_rank: 110.7950
image_to_text_median_rank: 20.0000
image_to_text_Hit@1: 0.1718
image_to_text_Hit@5: 0.3541
image_to_text_Hit@10: 0.4222
image_to_text_MRR: 0.2629
text_to_image_mean_rank: 98.4279
text_to_image_median_rank: 11.0000
text_to_image_Hit@1: 0.1985
text_to_image_Hit@5: 0.3955
text_to_image_Hit@10: 0.4838
text_to_image_MRR: 0.2968



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]


Metrics for bert on MIMIC dataset:


Processing batches:   0%|          | 0/52 [00:00<?, ?it/s]

image_to_text_mean_rank: 157.9971
image_to_text_median_rank: 26.0000
image_to_text_Hit@1: 0.0816
image_to_text_Hit@5: 0.2469
image_to_text_Hit@10: 0.3460
image_to_text_MRR: 0.1699
text_to_image_mean_rank: 168.1653
text_to_image_median_rank: 26.0000
text_to_image_Hit@1: 0.0826
text_to_image_Hit@5: 0.2452
text_to_image_Hit@10: 0.3470
text_to_image_MRR: 0.1690

Metrics for bert on Chexpert dataset:


Processing batches:   0%|          | 0/13 [00:00<?, ?it/s]

image_to_text_mean_rank: 63.8979
image_to_text_median_rank: 7.0000
image_to_text_Hit@1: 0.2472
image_to_text_Hit@5: 0.4554
image_to_text_Hit@10: 0.5478
image_to_text_MRR: 0.3489
text_to_image_mean_rank: 82.2585
text_to_image_median_rank: 7.0000
text_to_image_Hit@1: 0.2131
text_to_image_Hit@5: 0.4627
text_to_image_Hit@10: 0.5648
text_to_image_MRR: 0.3312


Metrics for cxr-bert on MIMIC dataset:


Processing batches:   0%|          | 0/52 [00:00<?, ?it/s]

image_to_text_mean_rank: 143.9335
image_to_text_median_rank: 22.0000
image_to_text_Hit@1: 0.0950
image_to_text_Hit@5: 0.2760
image_to_text_Hit@10: 0.3819
image_to_text_MRR: 0.1872
text_to_image_mean_rank: 141.4169
text_to_image_median_rank: 21.0000
text_to_image_Hit@1: 0.0942
text_to_image_Hit@5: 0.2694
text_to_image_Hit@10: 0.3778
text_to_image_MRR: 0.1866

Metrics for cxr-bert on Chexpert dataset:


Processing batches:   0%|          | 0/13 [00:00<?, ?it/s]

image_to_text_mean_rank: 72.6442
image_to_text_median_rank: 8.0000
image_to_text_Hit@1: 0.2261
image_to_text_Hit@5: 0.4352
image_to_text_Hit@10: 0.5389
image_to_text_MRR: 0.3287
text_to_image_mean_rank: 84.6823
text_to_image_median_rank: 7.0000
text_to_image_Hit@1: 0.2212
text_to_image_Hit@5: 0.4538
text_to_image_Hit@10: 0.5527
text_to_image_MRR: 0.3301
