In [1]:
from mimic_dataset import MIMIC_CXR_Dataset
from torch.utils.data import DataLoader

# Create dataset and dataloader
dataset = MIMIC_CXR_Dataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [2]:
from open_clip import create_model_from_pretrained, get_tokenizer

model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()

CustomTextCLIP(
  (visual): TimmModel(
    (trunk): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768

In [24]:
context_length = 256
num_pairs_needed = 100

# Lists to store features and metadata
all_image_features = []
all_text_features = []
all_paths = []
all_reports = []


with torch.no_grad():
    for batch_images, batch_paths, batch_reports in dataloader:
        
        batch_images = batch_images.to(device)
        # from PIL import Image
        # batch_images = torch.stack([preprocess(Image.open(path)) for path in batch_paths]).to(device)
        
        texts = tokenizer(batch_reports, context_length=context_length).to(device)
        
        image_features, text_features, logit_scale = model(batch_images, texts)
        
        # Store features and metadata
        all_image_features.append(image_features)
        all_text_features.append(text_features)
        all_paths.extend(batch_paths)
        all_reports.extend(batch_reports)
        
        if len(all_paths) >= num_pairs_needed:
            break

image_features = torch.cat(all_image_features, dim=0)
text_features = torch.cat(all_text_features, dim=0)

raw_similarity = logit_scale * image_features @ text_features.t()
similarity = raw_similarity.softmax(dim=-1)


max_sim_values, max_sim_indices = torch.max(similarity, dim=1)
max_raw_sim_values = torch.max(raw_similarity, dim=1)[0]


# Process results
results = []
for i, (path, pred_idx, sim_score, raw_sim_score) in enumerate(zip(all_paths, max_sim_indices, max_sim_values, max_raw_sim_values)):
    # Get the actual pair similarity (diagonal)
    actual_pair_sim = similarity[i, i].item()
    actual_pair_raw_sim = raw_similarity[i, i].item()
    actual_pair_text = all_reports[i]
    
    # Get highest similarity match
    highest_sim_text = all_reports[pred_idx]
    
    # Get ranking position of actual pair
    # Sort similarities for this image and find position of actual pair
    sorted_sims, _ = torch.sort(raw_similarity[i], descending=True)
    actual_pair_rank = torch.where(sorted_sims == actual_pair_raw_sim)[0][0].item() + 1
    
    # Create and store the tuple
    result_tuple = (
        path,
        sim_score.item(),
        raw_sim_score.item(),
        highest_sim_text,
        actual_pair_sim,
        actual_pair_raw_sim,
        actual_pair_text,
        actual_pair_rank
    )
    results.append(result_tuple)

# Print summary
print(f"\nTotal pairs processed: {len(results)}")


Total pairs processed: 128


In [None]:
print("\nFirst 3 results:")
for i, result in enumerate(results[:3]):
    path, sim_score, raw_sim_score, highest_text, actual_sim, actual_raw_sim, actual_text, actual_rank = result
    print(f"\nPair {i+1}:")
    print(f"Image: {path}")
    print(f"Highest similarity score (softmax): {sim_score:.3f}")
    print(f"Highest similarity score (raw): {raw_sim_score:.3f}")
    print(f"Highest similarity text: {highest_text}...")
    print(f"Actual pair similarity (softmax): {actual_sim:.3f}")
    print(f"Actual pair similarity (raw): {actual_raw_sim:.3f}")
    print(f"Actual pair rank: {actual_rank}")
    print(f"Actual pair text: {actual_text}...")


In [20]:
import pandas as pd

df = pd.DataFrame(results, columns=['image_path', 'highest_sim_score', 'highest_raw_sim_score', 'highest_sim_text', 'actual_pair_sim', 'actual_pair_raw_sim', 'actual_pair_text', 'actual_pair_rank'])
df.to_csv('clip_results.csv', index=False)

In [25]:
df = pd.DataFrame(results, columns=['image_path', 'highest_sim_score', 'highest_raw_sim_score', 'highest_sim_text', 'actual_pair_sim', 'actual_pair_raw_sim', 'actual_pair_text', 'actual_pair_rank'])
df['actual_pair_rank'].mean()

35.546875