In [46]:
from pathlib import Path

import pandas as pd


def parse_nih_data(csv_path: str | Path, nih_data_path: str | Path, boxes=False) -> pd.DataFrame:
    nih_df = pd.read_csv(csv_path)
    nih_df = nih_df.rename(
        dict(zip(nih_df.columns.tolist(),
                 ['image_idx', 'label', 'bbox_x', 'bbox_y', 'bbox_w', 'bbox_h', 'a', 'b', 'c'])),
        axis='columns')
    if boxes:
        res_list = ['image_idx', 'label', 'bbox_x', 'bbox_y', 'bbox_w', 'bbox_h']
    else:
        res_list = ['image_idx', 'label']

    nih_df = nih_df[res_list]
    nih_df['image_path'] = nih_df['image_idx'].apply(lambda x: str(Path(nih_data_path) / x))
    nih_df = nih_df.drop_duplicates(subset='image_idx')
    nih_df['parsed_label'] = nih_df['label'].apply(lambda x: x.split('|'))
    return nih_df


# Balanced NIH label table, with image paths resolved against the selected-images directory.
df = parse_nih_data(
    csv_path='../data/nih_data/balanced_data.csv',
    nih_data_path='../data/nih_data/selected_images',
)

In [47]:


import torch
from transformers import CLIPProcessor, CLIPModel

# Model and processor must come from the same checkpoint, so the name is defined once.
CLIP_CHECKPOINT = "flaviagiammarino/pubmed-clip-vit-base-patch32"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)


In [187]:
# Visual findings per NIH label, used as text prompts for zero-shot CLIP matching.
disease_descriptors = {
    "Cardiomegaly": [
        "Increased size of the heart shadow",
        "Enlargement of the heart silhouette",
        "Increased diameter of the heart border",
        "Increased cardiothoracic ratio",
    ],
    "Edema": [
        "Blurry vascular markings in the lungs",
        "Kerley B lines",
        "Increased interstitial markings in the lungs",
        "Widening of interstitial spaces",
    ],
    "Consolidation": [
        "Loss of lung volume",
        "Increased density of lung tissue",
        "Obliteration of the diaphragmatic silhouette",
        "Presence of opacities",
    ],
    "Pneumonia": [
        "Consolidation of lung tissue",
        "Air bronchograms",
        "Cavitation",
        "Interstitial opacities",
    ],
    "Atelectasis": [
        "Increased opacity",
        "Volume loss of the affected lung region",
        "Displacement of the diaphragm",
        "Blunting of the costophrenic angle",
        "Shifting of the mediastinum",
    ],
    "Pneumothorax": [
        "Tracheal deviation",
        "Deep sulcus sign",
        "Increased radiolucency",
        "Flattening of the hemidiaphragm",
        "Absence of lung markings",
        "Shifting of the mediastinum",
    ],
    "Pleural Effusion": [
        "Blunting of costophrenic angles",
        "Opacity in the lower lung fields",
        "Mediastinal shift",
        "Reduced lung volume",
        "Meniscus sign or veil-like appearance",
    ],
    "Infiltration": [
        "Irregular or fuzzy borders around white areas",
        "Blurring",
        "Hazy or cloudy areas",
        "Increased density or opacity of lung tissue",
        "Air bronchograms",
    ],
    "Mass": [
        "Calcifications or mineralizations",
        "Shadowing",
        "Distortion or compression of tissues",
        "Anomalous structure or irregularity in shape",
    ],
    "Nodule": [
        "Nodular shape that protrudes into a cavity or airway",
        "Distinct edges or borders",
        "Calcifications or speckled areas",
        "Small round oval shaped spots",  # was "oral" — typo for "oval"
        "White shadows",
    ],
    "Emphysema": [
        "Flattened hemidiaphragm",
        "Pulmonary bullae",
        "Hyperlucent lungs",
        "Horizontalisation of ribs",
        "Barrel Chest",
    ],
    "Fibrosis": [
        "Reticular shadowing of the lung peripheries",
        "Volume loss",
        "Thickened and irregular interstitial markings",
        "Bronchial dilation",
        "Shaggy heart borders",
    ],
    "Pleural Thickening": [
        "Thickened pleural line",
        "Loss of sharpness of the mediastinal border",
        "Calcifications on the pleura",
        "Lobulated peripheral shadowing",
        "Loss of lung volume",
    ],
    "Hernia": [
        "Bulge or swelling in the abdominal wall",
        "Protrusion of intestine or other abdominal tissue",
        "Swelling or enlargement of the herniated sac or surrounding tissues",
        "Retro-cardiac air-fluid level",
        "Thickening of intestinal folds",
    ],
}

In [188]:



def calculate_text_embeddings(clip_model, clip_preprocess, disease_text: list[str]) -> torch.Tensor:
    """Embed a batch of text prompts with a CLIP-style text encoder.

    Args:
        clip_model: model exposing ``get_text_features`` (e.g. ``transformers.CLIPModel``).
        clip_preprocess: processor/tokenizer called as ``clip_preprocess(text=..., return_tensors="pt", padding=True)``.
        disease_text: prompts to embed.

    Returns:
        Text features for the prompts, computed without autograd tracking.
    """
    # Pass `text=` by keyword so the call does not depend on the processor's positional order.
    inputs = clip_preprocess(text=disease_text, return_tensors="pt", padding=True)
    # Inference only: skip building the autograd graph rather than detaching afterwards.
    with torch.no_grad():
        # Forward every tokenizer output (e.g. attention_mask too), not just input_ids.
        text_embeddings = clip_model.get_text_features(**inputs)
    return text_embeddings


def calculate_image_embeddings(clip_model, clip_preprocess, images: list[object]) -> torch.Tensor:
    """Embed a batch of images with a CLIP-style image encoder.

    Args:
        clip_model: model exposing ``get_image_features`` (e.g. ``transformers.CLIPModel``).
        clip_preprocess: processor called as ``clip_preprocess(text=None, images=..., return_tensors="pt")``.
        images: images accepted by the processor (presumably PIL images — confirm with callers).

    Returns:
        Image features for the batch, computed without autograd tracking.
    """
    inputs = clip_preprocess(text=None, images=images, return_tensors="pt")
    # Inference only: skip building the autograd graph rather than detaching afterwards.
    with torch.no_grad():
        image_embedding = clip_model.get_image_features(**inputs)
    return image_embedding

In [219]:


# Turn each descriptor into a CLIP text prompt that also names its disease label.
# The previous f-string ended in a literal ",  i", so every prompt contained a stray "i".
# NOTE: this rebinds `disease_descriptors` in place. Re-running this cell without re-running
# the definition cell applies the template a second time.
disease_descriptors = {
    label: [f' a chest x-ray with {descriptor}, {label}' for descriptor in descriptors]
    for label, descriptors in disease_descriptors.items()
}

In [220]:
# Inspect the generated prompts.
# NOTE(review): the saved output below does not match the template in the preceding cell.
# It shows one label-only prompt per descriptor, so it is stale; re-run to refresh it.
disease_descriptors

{'Cardiomegaly': [' a chest x-ray with Cardiomegaly',
  ' a chest x-ray with Cardiomegaly',
  ' a chest x-ray with Cardiomegaly',
  ' a chest x-ray with Cardiomegaly'],
 'Edema': [' a chest x-ray with Edema',
  ' a chest x-ray with Edema',
  ' a chest x-ray with Edema',
  ' a chest x-ray with Edema'],
 'Consolidation': [' a chest x-ray with Consolidation',
  ' a chest x-ray with Consolidation',
  ' a chest x-ray with Consolidation',
  ' a chest x-ray with Consolidation'],
 'Pneumonia': [' a chest x-ray with Pneumonia',
  ' a chest x-ray with Pneumonia',
  ' a chest x-ray with Pneumonia',
  ' a chest x-ray with Pneumonia'],
 'Atelectasis': [' a chest x-ray with Atelectasis',
  ' a chest x-ray with Atelectasis',
  ' a chest x-ray with Atelectasis',
  ' a chest x-ray with Atelectasis',
  ' a chest x-ray with Atelectasis'],
 'Pneumothorax': [' a chest x-ray with Pneumothorax',
  ' a chest x-ray with Pneumothorax',
  ' a chest x-ray with Pneumothorax',
  ' a chest x-ray with Pneumothorax',


In [232]:
# Embed one label-only prompt per disease in a single batched forward pass.
# NOTE(review): only the keys of `disease_descriptors` are used here; the descriptor prompts
# themselves are not embedded.
_label_prompt_embs = calculate_text_embeddings(
    model, processor, [f' a chest x-ray with {label}' for label in disease_descriptors])
# Keep a leading axis so each value is a (1, dim) slice, as a one-prompt call would return.
desc_embs = {label: emb.unsqueeze(0) for label, emb in zip(disease_descriptors, _label_prompt_embs)}

In [233]:
# Shape of the first label's text embedding.
[*desc_embs.values()][0].shape

torch.Size([1, 512])

In [234]:
# desc_embs = {label: embs.mean(dim=0, keepdim=True) for label, embs in desc_embs.items()}
# desc_embs

In [235]:
# Labels that have a text embedding; each one is scored against the image embedding further down.
desc_embs.keys()

dict_keys(['Cardiomegaly', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Infiltration', 'Mass', 'Nodule', 'Emphysema', 'Fibrosis', 'Pleural Thickening', 'Hernia'])

In [236]:
from PIL import Image

SAMPLE_ROW = 3  # example row used for the single-image spot-check below
row = df.iloc[SAMPLE_ROW]
label = row['parsed_label']  # already split on '|' by parse_nih_data
image_path = row["image_path"]
# Use Image.open as a context manager so the file handle is closed once embedding is done.
with Image.open(image_path) as image:
    image_emb = calculate_image_embeddings(clip_model=model, clip_preprocess=processor, images=[image])

In [238]:
print(row['label'])
# One (1, dim) text embedding per label, concatenated into an (n_labels, dim) matrix.
# torch.cat gives that shape directly without relying on squeeze().
label_embs = torch.cat(list(desc_embs.values()), dim=0)
# image_emb comes from a single image, so it broadcasts against every label row,
# giving one similarity per label.
cos_sim = torch.nn.functional.cosine_similarity(image_emb, label_embs)
cos_dict = dict(zip(desc_embs.keys(), cos_sim))
for disease, sim in cos_dict.items():
    print(f'{disease}: {sim.item():.4f}')



Edema|Infiltration
('Cardiomegaly', tensor(0.2613))
('Edema', tensor(0.2614))
('Consolidation', tensor(0.2629))
('Pneumonia', tensor(0.2651))
('Atelectasis', tensor(0.2771))
('Pneumothorax', tensor(0.2666))
('Pleural Effusion', tensor(0.2711))
('Infiltration', tensor(0.2698))
('Mass', tensor(0.2536))
('Nodule', tensor(0.2463))
('Emphysema', tensor(0.2638))
('Fibrosis', tensor(0.2644))
('Pleural Thickening', tensor(0.2647))
('Hernia', tensor(0.2602))


In [56]:
# stacked_avg_embeddings = torch.stack(list(label_avg_embeddings)).squeeze()

In [57]:
# from tqdm import tqdm
# from PIL import Image
# 
# threshold = 0.625
# res_df = pd.DataFrame(columns=['image_index', 'preds', 'probas', 'probas_dict', 'thr'])
# for index, row in tqdm(list(df.iterrows()), desc='validate'):
#     label = row["label"].split('|')
#     image_path = row["image_path"]
#     image_emb = calculate_image_embeddings(clip_model=model, clip_preprocess=processor, image=Image.open(image_path))
#     cos_sim = torch.nn.functional.cosine_similarity(image_emb, stacked_avg_embeddings)
#     cos_sim = (cos_sim + 1) / 2
#     cos_dict = (dict(zip(labels, list(cos_sim.detach().numpy()))))
#     cos_dict = dict(sorted(cos_dict.items(), key=lambda item: item[1], reverse=True))
#     preds = []
#     probas = []
#     for l, v in cos_dict.items():
#         if v > threshold:
#             preds.append(l)
#             probas.append(v)
# 
#     new_row = {
#         'image_index': [row['image_idx']],
#         'preds': [preds],
#         'probas': [probas],
#         'probas_dict': [cos_dict],
#         'thr': [threshold],
#     }
#     res_df = pd.concat([res_df, pd.DataFrame(new_row)], ignore_index=True)
