In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [78]:
import torch
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from peft import get_peft_model, LoraConfig, TaskType
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch import nn, optim
from tqdm import tqdm
import numpy as np

In [79]:
dataset = load_dataset("itsanmolgupta/mimic-cxr-dataset", split="train")
print(dataset[0])

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7FAD87A07ED0>, 'findings': 'The lungs are clear of focal consolidation, pleural effusion or pneumothorax. The heart size is normal. The mediastinal contours are normal. Multiple surgical clips project over the left breast, and old left rib fractures are noted. ', 'impression': 'No acute cardiopulmonary process.'}


In [80]:
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

In [81]:
lora_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=[
        "text_model.encoder.layers.11.self_attn.q_proj",
        "text_model.encoder.layers.11.self_attn.v_proj",
        "vision_model.encoder.layers.11.self_attn.query",
        "vision_model.encoder.layers.11.self_attn.value"
    ]
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 16,384 || all params: 151,293,697 || trainable%: 0.0108


In [88]:
named_modules = list(model.named_modules())
for i, (name, module) in enumerate(named_modules):
    print(f"{i:3d}: {name}")

  0: 
  1: base_model
  2: base_model.model
  3: base_model.model.text_model
  4: base_model.model.text_model.embeddings
  5: base_model.model.text_model.embeddings.token_embedding
  6: base_model.model.text_model.embeddings.position_embedding
  7: base_model.model.text_model.encoder
  8: base_model.model.text_model.encoder.layers
  9: base_model.model.text_model.encoder.layers.0
 10: base_model.model.text_model.encoder.layers.0.self_attn
 11: base_model.model.text_model.encoder.layers.0.self_attn.k_proj
 12: base_model.model.text_model.encoder.layers.0.self_attn.v_proj
 13: base_model.model.text_model.encoder.layers.0.self_attn.q_proj
 14: base_model.model.text_model.encoder.layers.0.self_attn.out_proj
 15: base_model.model.text_model.encoder.layers.0.layer_norm1
 16: base_model.model.text_model.encoder.layers.0.mlp
 17: base_model.model.text_model.encoder.layers.0.mlp.activation_fn
 18: base_model.model.text_model.encoder.layers.0.mlp.fc1
 19: base_model.model.text_model.encoder.laye

In [89]:
class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.ds = dataset

    def __getitem__(self, idx):
        item = self.ds[idx]
        return {
            "text": item["findings"],
            "image": item["image"]
        }

    def __len__(self):
        return len(self.ds)

In [91]:
def collate_fn(batch):
    texts = [item["text"] for item in batch]
    images = [item["image"] for item in batch]
    
    inputs = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77
    )
    return inputs

In [93]:
train_dataset = CLIPDataset(dataset)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

def contrastive_loss(logits_per_image, logits_per_text, temperature=0.07):
    batch_size = logits_per_image.shape[0]
    labels = torch.arange(batch_size, device=logits_per_image.device)
    
    logits_per_image = logits_per_image / temperature
    logits_per_text = logits_per_text / temperature
    
    loss_img = F.cross_entropy(logits_per_image, labels)
    loss_txt = F.cross_entropy(logits_per_text, labels)
    
    return (loss_img + loss_txt) / 2

In [94]:
def safe_forward(model, input_ids, attention_mask, pixel_values):
    if hasattr(model, 'base_model'):
        base_model = model.base_model
    else:
        base_model = model
        
    text_outputs = base_model.text_model(
        input_ids=input_ids,
        attention_mask=attention_mask
    )
    
    vision_outputs = base_model.vision_model(pixel_values=pixel_values)
    text_embeds = text_outputs.pooler_output
    image_embeds = vision_outputs.pooler_output
    text_embeds = base_model.text_projection(text_embeds)
    image_embeds = base_model.visual_projection(image_embeds)

    class CLIPOutput:
        def __init__(self, text_embeds, image_embeds):
            self.text_embeds = text_embeds
            self.image_embeds = image_embeds
    
    return CLIPOutput(text_embeds, image_embeds)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): CLIPModel(
      (text_model): CLIPTextTransformer(
        (embeddings): CLIPTextEmbeddings(
          (token_embedding): Embedding(49408, 512)
          (position_embedding): Embedding(77, 512)
        )
        (encoder): CLIPEncoder(
          (layers): ModuleList(
            (0-10): 11 x CLIPEncoderLayer(
              (self_attn): CLIPSdpaAttention(
                (k_proj): Linear(in_features=512, out_features=512, bias=True)
                (v_proj): Linear(in_features=512, out_features=512, bias=True)
                (q_proj): Linear(in_features=512, out_features=512, bias=True)
                (out_proj): Linear(in_features=512, out_features=512, bias=True)
              )
              (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (mlp): CLIPMLP(
                (activation_fn): QuickGELUActivation()
                (fc1): Linear(in_features=512, out_features=2048, bi

In [95]:
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
num_epochs = 1

for epoch in range(num_epochs):
    total_loss = 0
    num_batches = 0
    
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        outputs = safe_forward(model, input_ids, attention_mask, pixel_values)
        image_embeds = outputs.image_embeds
        text_embeds = outputs.text_embeds

        image_embeds = F.normalize(image_embeds, p=2, dim=1)
        text_embeds = F.normalize(text_embeds, p=2, dim=1)
        
        logits_per_image = torch.matmul(image_embeds, text_embeds.t())
        logits_per_text = torch.matmul(text_embeds, image_embeds.t())
        
        loss = contrastive_loss(logits_per_image, logits_per_text)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        
        if num_batches % 100 == 0:
            avg_loss = total_loss / num_batches
            print(f"Batch {num_batches}, Average Loss: {avg_loss:.4f}")
    
    if num_batches > 0:
        avg_epoch_loss = total_loss / num_batches
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_epoch_loss:.4f}")
    else:
        print(f"Epoch {epoch+1} completed with no successful batches")

Epoch 1/1:   1%|▏         | 102/7659 [00:05<06:53, 18.28it/s]

Batch 100, Average Loss: 1.3993


Epoch 1/1:   3%|▎         | 202/7659 [00:11<06:41, 18.56it/s]

Batch 200, Average Loss: 1.4009


Epoch 1/1:   4%|▍         | 302/7659 [00:16<06:42, 18.27it/s]

Batch 300, Average Loss: 1.3962


Epoch 1/1:   5%|▌         | 402/7659 [00:21<06:45, 17.89it/s]

Batch 400, Average Loss: 1.3947


Epoch 1/1:   7%|▋         | 502/7659 [00:27<06:27, 18.46it/s]

Batch 500, Average Loss: 1.3926


Epoch 1/1:   8%|▊         | 602/7659 [00:33<06:21, 18.50it/s]

Batch 600, Average Loss: 1.3913


Epoch 1/1:   9%|▉         | 702/7659 [00:38<06:26, 18.01it/s]

Batch 700, Average Loss: 1.3904


Epoch 1/1:  10%|█         | 798/7659 [00:43<06:12, 18.42it/s]

Batch 800, Average Loss: 1.3902


Epoch 1/1:  12%|█▏        | 902/7659 [00:49<06:08, 18.31it/s]

Batch 900, Average Loss: 1.3887


Epoch 1/1:  13%|█▎        | 1002/7659 [00:54<05:59, 18.52it/s]

Batch 1000, Average Loss: 1.3875


Epoch 1/1:  14%|█▍        | 1102/7659 [01:00<06:40, 16.37it/s]

Batch 1100, Average Loss: 1.3860


Epoch 1/1:  16%|█▌        | 1202/7659 [01:05<05:48, 18.54it/s]

Batch 1200, Average Loss: 1.3843


Epoch 1/1:  17%|█▋        | 1302/7659 [01:11<05:43, 18.50it/s]

Batch 1300, Average Loss: 1.3831


Epoch 1/1:  18%|█▊        | 1402/7659 [01:16<05:37, 18.57it/s]

Batch 1400, Average Loss: 1.3811


Epoch 1/1:  20%|█▉        | 1502/7659 [01:21<05:43, 17.95it/s]

Batch 1500, Average Loss: 1.3791


Epoch 1/1:  21%|██        | 1602/7659 [01:27<05:26, 18.53it/s]

Batch 1600, Average Loss: 1.3766


Epoch 1/1:  22%|██▏       | 1702/7659 [01:32<05:21, 18.50it/s]

Batch 1700, Average Loss: 1.3748


Epoch 1/1:  24%|██▎       | 1802/7659 [01:38<05:41, 17.14it/s]

Batch 1800, Average Loss: 1.3730


Epoch 1/1:  25%|██▍       | 1902/7659 [01:43<05:16, 18.22it/s]

Batch 1900, Average Loss: 1.3713


Epoch 1/1:  26%|██▌       | 2002/7659 [01:49<05:00, 18.85it/s]

Batch 2000, Average Loss: 1.3689


Epoch 1/1:  27%|██▋       | 2102/7659 [01:54<05:00, 18.50it/s]

Batch 2100, Average Loss: 1.3669


Epoch 1/1:  29%|██▉       | 2202/7659 [02:00<04:52, 18.69it/s]

Batch 2200, Average Loss: 1.3640


Epoch 1/1:  30%|███       | 2302/7659 [02:05<04:47, 18.66it/s]

Batch 2300, Average Loss: 1.3623


Epoch 1/1:  31%|███▏      | 2402/7659 [02:10<04:43, 18.56it/s]

Batch 2400, Average Loss: 1.3603


Epoch 1/1:  33%|███▎      | 2502/7659 [02:16<04:38, 18.52it/s]

Batch 2500, Average Loss: 1.3588


Epoch 1/1:  34%|███▍      | 2602/7659 [02:21<04:33, 18.51it/s]

Batch 2600, Average Loss: 1.3570


Epoch 1/1:  35%|███▌      | 2702/7659 [02:27<04:25, 18.70it/s]

Batch 2700, Average Loss: 1.3563


Epoch 1/1:  37%|███▋      | 2802/7659 [02:32<04:20, 18.61it/s]

Batch 2800, Average Loss: 1.3550


Epoch 1/1:  38%|███▊      | 2902/7659 [02:38<04:34, 17.32it/s]

Batch 2900, Average Loss: 1.3538


Epoch 1/1:  39%|███▉      | 3002/7659 [02:43<04:13, 18.34it/s]

Batch 3000, Average Loss: 1.3528


Epoch 1/1:  41%|████      | 3102/7659 [02:49<04:05, 18.56it/s]

Batch 3100, Average Loss: 1.3519


Epoch 1/1:  42%|████▏     | 3202/7659 [02:54<04:00, 18.57it/s]

Batch 3200, Average Loss: 1.3512


Epoch 1/1:  43%|████▎     | 3302/7659 [02:59<03:55, 18.48it/s]

Batch 3300, Average Loss: 1.3491


Epoch 1/1:  44%|████▍     | 3402/7659 [03:05<03:48, 18.66it/s]

Batch 3400, Average Loss: 1.3486


Epoch 1/1:  46%|████▌     | 3502/7659 [03:10<03:44, 18.52it/s]

Batch 3500, Average Loss: 1.3475


Epoch 1/1:  47%|████▋     | 3602/7659 [03:16<03:38, 18.57it/s]

Batch 3600, Average Loss: 1.3471


Epoch 1/1:  48%|████▊     | 3702/7659 [03:21<03:33, 18.51it/s]

Batch 3700, Average Loss: 1.3469


Epoch 1/1:  50%|████▉     | 3802/7659 [03:27<03:28, 18.54it/s]

Batch 3800, Average Loss: 1.3458


Epoch 1/1:  51%|█████     | 3902/7659 [03:32<03:21, 18.61it/s]

Batch 3900, Average Loss: 1.3443


Epoch 1/1:  52%|█████▏    | 4002/7659 [03:38<03:41, 16.50it/s]

Batch 4000, Average Loss: 1.3435


Epoch 1/1:  54%|█████▎    | 4102/7659 [03:43<03:14, 18.31it/s]

Batch 4100, Average Loss: 1.3430


Epoch 1/1:  55%|█████▍    | 4202/7659 [03:48<03:08, 18.34it/s]

Batch 4200, Average Loss: 1.3429


Epoch 1/1:  56%|█████▌    | 4302/7659 [03:54<03:07, 17.89it/s]

Batch 4300, Average Loss: 1.3418


Epoch 1/1:  57%|█████▋    | 4402/7659 [03:59<02:56, 18.50it/s]

Batch 4400, Average Loss: 1.3408


Epoch 1/1:  59%|█████▉    | 4502/7659 [04:05<02:49, 18.63it/s]

Batch 4500, Average Loss: 1.3399


Epoch 1/1:  60%|██████    | 4602/7659 [04:10<02:50, 17.96it/s]

Batch 4600, Average Loss: 1.3392


Epoch 1/1:  61%|██████▏   | 4702/7659 [04:16<02:38, 18.69it/s]

Batch 4700, Average Loss: 1.3381


Epoch 1/1:  63%|██████▎   | 4802/7659 [04:21<02:32, 18.75it/s]

Batch 4800, Average Loss: 1.3376


Epoch 1/1:  64%|██████▍   | 4902/7659 [04:26<02:30, 18.35it/s]

Batch 4900, Average Loss: 1.3373


Epoch 1/1:  65%|██████▌   | 5002/7659 [04:32<02:23, 18.57it/s]

Batch 5000, Average Loss: 1.3365


Epoch 1/1:  67%|██████▋   | 5102/7659 [04:37<02:17, 18.53it/s]

Batch 5100, Average Loss: 1.3359


Epoch 1/1:  68%|██████▊   | 5202/7659 [04:43<02:11, 18.75it/s]

Batch 5200, Average Loss: 1.3350


Epoch 1/1:  69%|██████▉   | 5302/7659 [04:48<02:10, 18.04it/s]

Batch 5300, Average Loss: 1.3340


Epoch 1/1:  71%|███████   | 5402/7659 [04:54<02:01, 18.65it/s]

Batch 5400, Average Loss: 1.3334


Epoch 1/1:  72%|███████▏  | 5502/7659 [04:59<01:55, 18.64it/s]

Batch 5500, Average Loss: 1.3325


Epoch 1/1:  73%|███████▎  | 5602/7659 [05:04<01:52, 18.25it/s]

Batch 5600, Average Loss: 1.3317


Epoch 1/1:  74%|███████▍  | 5702/7659 [05:10<01:44, 18.77it/s]

Batch 5700, Average Loss: 1.3309


Epoch 1/1:  76%|███████▌  | 5802/7659 [05:15<01:39, 18.71it/s]

Batch 5800, Average Loss: 1.3300


Epoch 1/1:  77%|███████▋  | 5902/7659 [05:21<01:33, 18.74it/s]

Batch 5900, Average Loss: 1.3299


Epoch 1/1:  78%|███████▊  | 6002/7659 [05:26<01:29, 18.62it/s]

Batch 6000, Average Loss: 1.3294


Epoch 1/1:  80%|███████▉  | 6102/7659 [05:32<01:24, 18.53it/s]

Batch 6100, Average Loss: 1.3287


Epoch 1/1:  81%|████████  | 6202/7659 [05:37<01:20, 18.07it/s]

Batch 6200, Average Loss: 1.3282


Epoch 1/1:  82%|████████▏ | 6302/7659 [05:42<01:13, 18.42it/s]

Batch 6300, Average Loss: 1.3275


Epoch 1/1:  84%|████████▎ | 6402/7659 [05:48<01:09, 18.04it/s]

Batch 6400, Average Loss: 1.3267


Epoch 1/1:  85%|████████▍ | 6502/7659 [05:53<01:02, 18.52it/s]

Batch 6500, Average Loss: 1.3262


Epoch 1/1:  86%|████████▌ | 6602/7659 [05:59<00:57, 18.52it/s]

Batch 6600, Average Loss: 1.3259


Epoch 1/1:  88%|████████▊ | 6702/7659 [06:04<00:51, 18.42it/s]

Batch 6700, Average Loss: 1.3254


Epoch 1/1:  89%|████████▉ | 6802/7659 [06:10<00:45, 18.79it/s]

Batch 6800, Average Loss: 1.3250


Epoch 1/1:  90%|█████████ | 6902/7659 [06:15<00:40, 18.66it/s]

Batch 6900, Average Loss: 1.3246


Epoch 1/1:  91%|█████████▏| 7002/7659 [06:21<00:34, 18.84it/s]

Batch 7000, Average Loss: 1.3240


Epoch 1/1:  93%|█████████▎| 7102/7659 [06:26<00:29, 18.57it/s]

Batch 7100, Average Loss: 1.3235


Epoch 1/1:  94%|█████████▍| 7202/7659 [06:31<00:24, 18.45it/s]

Batch 7200, Average Loss: 1.3228


Epoch 1/1:  95%|█████████▌| 7302/7659 [06:37<00:19, 18.53it/s]

Batch 7300, Average Loss: 1.3223


Epoch 1/1:  97%|█████████▋| 7402/7659 [06:42<00:13, 18.65it/s]

Batch 7400, Average Loss: 1.3218


Epoch 1/1:  98%|█████████▊| 7502/7659 [06:48<00:10, 14.80it/s]

Batch 7500, Average Loss: 1.3210


Epoch 1/1:  99%|█████████▉| 7602/7659 [06:53<00:03, 18.51it/s]

Batch 7600, Average Loss: 1.3204


Epoch 1/1: 100%|██████████| 7659/7659 [06:56<00:00, 18.38it/s]

Epoch 1 completed. Average Loss: 1.3200





In [98]:
model.save_pretrained("./clip-lora-mimic-cxr")
processor.save_pretrained("./clip-lora-mimic-cxr")

[]

In [101]:
from transformers import CLIPProcessor, CLIPModel
import torch
from peft import PeftModel, PeftConfig
from PIL import Image
from torch.nn import functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = CLIPProcessor.from_pretrained("./clip-lora-mimic-cxr")
base_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = PeftModel.from_pretrained(base_model, "./clip-lora-mimic-cxr")
model.to(device)
model.eval()

@torch.no_grad()
def safe_forward(model, input_ids, attention_mask, pixel_values):
    if hasattr(model, 'base_model'):
        base_model = model.base_model
    else:
        base_model = model
    
    text_outputs = base_model.text_model(
        input_ids=input_ids,
        attention_mask=attention_mask
    )
    vision_outputs = base_model.vision_model(pixel_values=pixel_values)
    
    text_embeds = base_model.text_projection(text_outputs.pooler_output)
    image_embeds = base_model.visual_projection(vision_outputs.pooler_output)
    
    return text_embeds, image_embeds

image_path = "/kaggle/input/testing1/test_img1.jpg"
image = Image.open(image_path).convert("RGB")

candidates = [
    "There is evidence of left lower lobe pneumonia.",
    "The heart size is enlarged with signs of pulmonary edema.",
    "Lungs are clear with no acute cardiopulmonary abnormality."
]

inputs = processor(
    text=candidates,
    images=[image] * len(candidates),
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=77
)

inputs = {k: v.to(device) for k, v in inputs.items()}

text_embeds, image_embeds = safe_forward(
    model,
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    pixel_values=inputs["pixel_values"]
)

text_embeds = F.normalize(text_embeds, p=2, dim=1)
image_embeds = F.normalize(image_embeds, p=2, dim=1)
similarities = torch.matmul(image_embeds, text_embeds.T)

best_idx = similarities.argmax().item()
print(f"Most similar sentence:\n \"{candidates[best_idx]}\"")
print(f"Similarity scores: {similarities.cpu().numpy()[0]}")

Most similar sentence:
 "Lungs are clear with no acute cardiopulmonary abnormality."
Similarity scores: [0.18905325 0.20127088 0.2176774 ]
