In [None]:
!pip install git+https://github.com/RyanWangZf/MedCLIP.git
!pip install opencv-python
# !pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
#!pip install "git+https://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI"

In [5]:
# ---------------------------------------------------------------------------
# Configuration: resolutions, thresholds, dataset locations and class labels.
# ---------------------------------------------------------------------------

# Working resolution for prompt bounding-box coordinates and attention masks.
IMG_SIZE = 224
SEL_SEARCH_IMG_SIZE = 341
# Pixel size used when scaling normalised boxes for blur-focusing.
# Presumably matches the PNGs under EVAL_IMG_DIR; use 256 with the resized set.
AUX_IMG_SIZE = 1024
IOU_THRESHOLD = 0.01

# Alternative image source (pair with AUX_IMG_SIZE = 256):
#   '/kaggle/input/vinbigdata-chest-xray-resized-png-256x256/train'
EVAL_IMG_DIR = '/kaggle/input/vinbigdata-1024-image-dataset/vinbigdata/train'
FILE_RESOLUTION_DIR = '/kaggle/input/vinbigdata-chest-xray-resized-png-256x256/train_meta.csv'
SEGMENTATION_RESULT = '/kaggle/input/evaluation/iou_evaluation_sam_pretrained_selective_grid.csv'
TRAIN_CSV_PATH = '/kaggle/input/vinbigdata-chest-xray-abnormalities-detection/train.csv'

# Index-aligned class lists: position i is the same class in both lists and
# matches the COCO category id i. The first list is the text placed in prompts;
# the second is the canonical class name used as the prompt-dict key.
VINBIGDATA_TASK_LABELS = [
    "Aortic enlargement",
    "Atelectasis",
    "Calcification",
    "Cardiomegaly",
    "Consolidation",
    "Interstitial",
    "Infiltration",
    "Lung Opacity",
    "Nodule",
    "lesion",
    "Pleural effusion",
    "Pleural thickening",
    "Pneumothorax",
    "Pulmonary fibrosis",
]

VINBIGDATA_TASKS = [
    "Aortic enlargement",
    "Atelectasis",
    "Calcification",
    "Cardiomegaly",
    "Consolidation",
    "ILD",
    "Infiltration",
    "Lung Opacity",
    "Nodule/Mass",
    "Other lesion",
    "Pleural effusion",
    "Pleural thickening",
    "Pneumothorax",
    "Pulmonary fibrosis",
]

# Per-class phrase lists combined into zero-shot text prompts.
# generate_detailed_class_prompts joins one phrase from each field, in the
# field order of each entry, separated by spaces (empty strings allowed).
# The dict order must stay aligned with VINBIGDATA_TASKS, because class
# positions are used as class ids downstream.
VINBIGDATA_CLASS_PROMPTS = {
    "Aortic enlargement": {
        "severity": ["", "mild", "moderate", "severe"],
        "subtype": [
            "ascending aortic enlargement",
            "aortic root enlargement",
            "descending aortic enlargement",
        ],
        "location": ["", "at the ascending aorta", "at the aortic arch", "at the descending aorta"],
    },
    "Atelectasis": {
        "severity": ["", "mild", "minimal"],
        "subtype": [
            "subsegmental atelectasis", "linear atelectasis", "trace atelectasis", "bibasilar atelectasis",
            "retrocardiac atelectasis", "bandlike atelectasis", "residual atelectasis",
        ],
        "location": [
            "", "at the mid lung zone", "at the upper lung zone", "at the right lung zone", "at the left lung zone", "at the lung bases",
            "at the right lung base", "at the left lung base", "at the bilateral lung bases", "at the left lower lobe", "at the right lower lobe",
        ],
    },
    "Calcification": {
        "severity": ["", "mild", "moderate", "extensive"],
        "subtype": [
            "vascular calcification", "parenchymal calcification", "lymph node calcification", "calcified granuloma",
        ],
        "location": [
            "", "in the coronary artery", "in the lung parenchyma", "in the pleura", "in the lymph nodes",
        ],
    },
    "Cardiomegaly": {
        "severity": ["", "mild", "moderate", "severe", "extreme"],
        "subtype": [
            # Each subtype needs its own comma: without it Python concatenates
            # adjacent literals into "cardiomegalyvolume overload".
            "cardiomegaly",
            "volume overload",
            "pressure overload",
            "dilated cardiomyopathy",
            "hypertrophic cardiomyopathy",
        ],
        "location": [""],
    },
    "Consolidation": {
        "severity": ["", "increased", "improved", "appearance of"],
        "subtype": [
            "bilateral consolidation", "reticular consolidation", "retrocardiac consolidation", "patchy consolidation", "airspace consolidation", "partial consolidation",
        ],
        "location": [
            "",
            "at the lower lung zone", "at the upper lung zone", "at the left lower lobe", "at the right lower lobe", "at the left upper lobe",
            "at the right upper lobe", "at the right lung base", "at the left lung base",
        ],
    },
    "Interstitial Lung Disease": {
        "severity": ["", "early stage", "progressive", "advanced"],
        "subtype": [
            "usual interstitial pneumonia pattern", "non-specific interstitial pneumonia pattern", "organizing pneumonia pattern",
            "lymphoid interstitial pneumonia pattern",
        ],
        "location": ["", "in the upper lobes", "in the lower lobes", "diffuse"],
    },
    "Infiltration": {
        "severity": ["", "mild", "moderate", "dense"],
        "subtype": [
            "patchy infiltration", "diffuse infiltration", "lobar infiltration",
        ],
        "location": [
            "", "in the right lung", "in the left lung", "in the upper lobes", "in the lower lobes", "centrally located",
        ],
    },
    "Lung Opacity": {
        "severity": ["mild", "moderate", "severe"],
        "subtype": ["ground-glass opacity", "hazy opacity"],
        "location": ["unilaterally", "bilaterally", "peripherally", "diffuse"],
    },
    "Nodule/mass": {
        "severity": ["", "small", "medium", "large"],
        "subtype": ["single nodule", "multiple nodules", "cavitary mass", "spiculated mass"],
        "location": ["", "in the right upper lobe", "in the left lower lobe", "adjacent to the pleura"],
    },
    "Lesion": {
        "severity": ["", "small", "moderate", "large"],
        "subtype": ["benign lesion", "malignant lesion", "indeterminate lesion"],
        "location": ["", "in the right lung", "in the left lung", "in the mediastinum"],
    },
    "Pleural Effusion": {
        # Field order here is severity, location, subtype, so prompts read
        # e.g. "small left subpulmonic pleural effusion".
        "severity": ["", "small", "stable", "large", "decreased", "increased"],
        "location": ["", "left", "right", "tiny"],
        "subtype": [
            "bilateral pleural effusion", "subpulmonic pleural effusion",
        ],
    },
    "Pleural thickening": {
        "severity": ["", "mild", "moderate", "severe"],
        "subtype": ["diffuse pleural thickening", "focal pleural thickening", "calcified pleural thickening"],
        "location": ["", "at the right pleura", "at the left pleura", "bilaterally"],
    },
    "Pneumothorax": {
        "severity": ["", "small", "large"],
        "subtype": ["primary spontaneous pneumothorax", "secondary pneumothorax", "tension pneumothorax"],
        "location": ["", "right-sided", "left-sided"],
    },
    "Pulmonary fibrosis": {
        "severity": ["", "mild", "moderate", "advanced"],
        "subtype": ["pulmonary fibrosis", "usual interstitial pneumonia", "non-specific interstitial pneumonia", "traction bronchiectasis"],
        "location": ["", "in the upper lobes", "in the lower lobes", "diffuse"],
    },
}

# COCO category table: id i is VinBigData class_id i (0-13 findings, 14 = "No finding").
COCO_CATEGORIES = [
    {"id": category_id, "name": category_name}
    for category_id, category_name in enumerate([
        "Aortic enlargement",
        "Atelectasis",
        "Calcification",
        "Cardiomegaly",
        "Consolidation",
        "ILD",
        "Infiltration",
        "Lung Opacity",
        "Nodule/Mass",
        "Other lesion",
        "Pleural effusion",
        "Pleural thickening",
        "Pneumothorax",
        "Pulmonary fibrosis",
        "No finding",
    ])
]

# Alternative four-field prompt table (adjective, description, subtype, location)
# consumed by generate_detailed_class_prompts_extended. Entry order follows the
# VinBigData class ids 0-13; key spellings differ from VINBIGDATA_TASKS.
DETAILED_PROMPTS = {
    "Aortic Enlargement": {
        "adjective": ["mild", "moderate", "severe"],
        "description": ["dilatation", "elongation"],
        "subtype": ["ascending aorta", "aortic arch", "descending aorta"],
        "location": ["along the mediastinum"],
    },
    "Atelectasis": {
        "adjective": ["subsegmental", "partial", "complete"],
        "description": ["collapse", "volume loss"],
        "subtype": ["linear", "lobar", "round"],
        "location": ["lower lobe", "upper lobe", "basilar"],
    },
    "Calcification": {
        "adjective": ["discrete", "scattered", "extensive"],
        "description": ["high density", "nodular"],
        "subtype": ["vascular", "nodular", "massive"],
        "location": ["in the heart", "in the lungs", "in the pleura"],
    },
    "Cardiomegaly": {
        "adjective": ["mild", "moderate", "severe"],
        "description": ["enlarged cardiac silhouette", "increased cardiothoracic ratio"],
        "subtype": ["ventricular hypertrophy", "chamber enlargement"],
        "location": ["central thorax"],
    },
    "Consolidation": {
        "adjective": ["focal", "multifocal", "lobar"],
        "description": ["air-space opacification", "alveolar filling"],
        "subtype": ["bacterial", "viral", "aspiration"],
        "location": ["lower lobe", "upper lobe", "peripheral"],
    },
    "ILD": {
        "adjective": ["early", "progressive", "advanced"],
        "description": ["reticular", "reticulonodular", "honeycombing"],
        "subtype": ["UIP", "NSIP", "DIP"],
        "location": ["lower lobes", "diffuse"],
    },
    "Infiltration": {
        "adjective": ["patchy", "diffuse", "localized"],
        "description": ["interstitial marking", "ground-glass appearance"],
        "subtype": ["lymphatic", "interstitial pneumonia", "edema"],
        "location": ["perihilar", "basilar", "peripheral"],
    },
    "Lung Opacity": {
        "adjective": ["mild", "moderate", "dense"],
        "description": ["non-specific opacity", "ground-glass opacity"],
        "subtype": ["focal", "diffuse", "multifocal"],
        "location": ["upper zone", "lower zone", "diffuse"],
    },
    "Nodule/Mass": {
        "adjective": ["small", "medium", "large"],
        "description": ["well-defined", "spiculated", "cavitary"],
        "subtype": ["benign", "malignant", "indeterminate"],
        "location": ["upper lobe", "lower lobe", "peripheral"],
    },
    "Other Lesion": {
        "adjective": ["solitary", "multiple", "confluent"],
        "description": ["well-demarcated", "ill-defined", "calcified"],
        "subtype": ["hamartoma", "fibrosis", "hematoma"],
        "location": ["central", "peripheral", "apical"],
    },
    "Pleural Effusion": {
        "adjective": ["small", "moderate", "large"],
        "description": ["blunting of costophrenic angle", "meniscus sign"],
        "subtype": ["transudative", "exudative", "hemothorax"],
        "location": ["left-sided", "right-sided", "bilateral"],
    },
    "Pleural Thickening": {
        "adjective": ["smooth", "nodular", "irregular"],
        "description": ["fibrotic bands", "pleural rind"],
        "subtype": ["diffuse", "focal", "apical"],
        "location": ["apex", "costal pleura", "diaphragmatic pleura"],
    },
    "Pneumothorax": {
        "adjective": ["small", "large", "tension"],
        "description": ["visceral pleural line", "absence of lung markings"],
        "subtype": ["spontaneous", "traumatic", "iatrogenic"],
        "location": ["apical", "basal", "lateral"],
    },
    "Pulmonary Fibrosis": {
        "adjective": ["mild", "moderate", "severe"],
        "description": ["reticular pattern", "honeycombing", "traction bronchiectasis"],
        "subtype": ["idiopathic", "chronic", "acute exacerbation"],
        "location": ["basal regions", "upper lobes", "diffuse"],
    },
}


In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import torch
import os
from medclip import MedCLIPModel, MedCLIPVisionModelViT, MedCLIPProcessor, PromptClassifier, MedCLIPVisionModel
from medclip.prompts import generate_chexpert_class_prompts, process_class_prompts
import pandas as pd
import json
import ast
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
import cv2
from torchvision.ops import nms
from tqdm import tqdm

def _build_prompt_classifier(vision_cls):
    """Create a MedCLIP model with the given vision backbone, call its
    from_pretrained(), wrap it in an ensembling PromptClassifier and move the
    classifier to CUDA.

    Returns the (model, classifier) pair.
    """
    model = MedCLIPModel(vision_cls=vision_cls)
    model.from_pretrained()
    classifier = PromptClassifier(model, ensemble=True)
    classifier.cuda()
    return model, classifier


# ViT backbone: used for the full image (and, in the evaluation cell, also for crops).
vit_model, vit_clf = _build_prompt_classifier(MedCLIPVisionModelViT)
# ResNet backbone: originally labelled "for cropped image".
# NOTE(review): resnet_clf is not referenced by the evaluation cell in this view - confirm intent.
resnet_model, resnet_clf = _build_prompt_classifier(MedCLIPVisionModel)

processor = MedCLIPProcessor()
print('medclip Loaded')

In [15]:
# Loading the dataset
class DatasetLoader(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image, boxes, file_name)`` per DataFrame row.

    Args:
        df: DataFrame whose first column holds the image id (the PNG file stem
            inside EVAL_IMG_DIR) and whose ``output_column`` holds a
            Python-literal string of boxes, e.g. "[[x1, y1, x2, y2], ...]".
        output_column: name of the column with the box strings.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, df, output_column, transform=None):
        self.df = df
        self.transform = transform
        self.output_column = output_column

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        # Purely positional access for both fields. Mixing .iloc with label-based
        # .loc picks the wrong row once the index is not a 0..n-1 RangeIndex.
        # An out-of-range idx raises IndexError, which also ends a plain
        # `for ... in dataset` loop.
        file_name = self.df.iloc[idx, 0]
        box_column = self.df.columns.get_loc(self.output_column)
        box_string = self.df.iloc[idx, box_column]

        image_path = os.path.join(EVAL_IMG_DIR, str(file_name) + '.png')
        # Image.open is lazy and keeps the file open; copy() decodes the pixels
        # so the context manager can release the handle.
        with Image.open(image_path) as opened:
            image = opened.copy()

        # literal_eval only accepts Python literals, never executable code.
        bounding_boxes = ast.literal_eval(box_string)
        boxes_tensor = torch.tensor(bounding_boxes, dtype=torch.float32)

        if self.transform:
            image = self.transform(image)

        return image, boxes_tensor, file_name

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Evaluation inputs: each row is read by DatasetLoader (image id + box string
# in the 'list_gt_bboxes' column).
eval_df = pd.read_csv(SEGMENTATION_RESULT)
eval_dataset = DatasetLoader(eval_df, 'list_gt_bboxes')
eval_loader = DataLoader(eval_dataset, batch_size=1)

# Ground-truth annotations and per-image resolution metadata (used by build_target).
train_df = pd.read_csv(TRAIN_CSV_PATH)
file_resolution_df = pd.read_csv(FILE_RESOLUTION_DIR)

In [17]:
# Cropping the images and drawing the bounding boxes
from PIL import Image, ImageDraw, ImageFilter
import numpy as np
import torch

def crop_images(image, bounding_box, gaussian_radius):
    """Blur everything outside ``bounding_box`` so that region stands out.

    Despite the name, nothing is cropped: the result has the same size as
    ``image``.

    Args:
        image: a PIL image.
        bounding_box: [x_min, y_min, x_max, y_max] in ``image`` pixel coordinates.
        gaussian_radius: radius passed to ``ImageFilter.GaussianBlur``.

    Returns:
        NumPy array of the composite: original pixels inside the box, blurred
        pixels everywhere else.
    """
    # Image.filter returns a new image, so the input is never modified.
    background = image.filter(ImageFilter.GaussianBlur(radius=gaussian_radius))

    # Mask value 255 selects the sharp image; 0 selects the blurred background.
    keep_mask = Image.new("L", image.size, 0)
    ImageDraw.Draw(keep_mask).rectangle(bounding_box, fill=255)

    return np.array(Image.composite(image, background, keep_mask))


def draw(img, results, labels, colour=(0, 0, 255), width=1, font_colour=(0, 0, 0), font_scale=1,
            font_thickness=1,
            T=0.6):
    """
    Overlay optional masks, box outlines and text labels on a copy of ``img``.

    :param img: image to draw on; converted with ``np.array`` (a new array is returned).
    :param results: dict with key 'boxes' (iterable of (x1, y1, x2, y2); values are
        truncated to int) and optional key 'masks' (boolean arrays, one per box,
        each shaped like its box region).
    :param labels: text labels, indexed in the same order as results['boxes'].
    :param colour: 3-tuple used to tint masked pixels (channels 0, 1, 2) and as
        the rectangle colour passed to cv2.rectangle.
    :param width: rectangle line thickness.
    :param font_colour: fill colour of the label text.
    :param font_scale: OpenCV font scale.
    :param font_thickness: label stroke thickness (the black outline pass uses +1).
    :param T: mask opacity; masked pixels become roughly (1 - T) * pixel + T * colour.
    :return: the annotated image as a NumPy array.
    """
    img = np.array(img)
    R, G, B = colour
    if 'masks' in results:
        for mask, box in zip(results['masks'], results['boxes']):
            x1, y1, x2, y2 = (int(v) for v in box)
            for channel, tint in enumerate((R, G, B)):
                # Basic slicing yields a view, so the masked assignment writes into img.
                plane = img[y1:y2, x1:x2, channel]
                plane[mask] = ((1 - T) * plane[mask]).astype(np.uint8) + np.uint8(tint * T)

    for cnt, box in enumerate(results['boxes']):
        # Cast each coordinate from its own position; y1 must come from box[1].
        x1, y1, x2, y2 = (int(v) for v in box)
        text_origin = (x1 + (x2 - x1) // 10, y1 + (y2 - y1) // 5)
        img = cv2.rectangle(img, (x1, y1), (x2, y2), colour, width)
        # A thicker black pass first acts as an outline behind the coloured text.
        img = cv2.putText(img, labels[cnt], text_origin, cv2.FONT_ITALIC,
                          font_scale, (0, 0, 0), font_thickness + 1, cv2.LINE_AA)
        img = cv2.putText(img, labels[cnt], text_origin, cv2.FONT_ITALIC,
                          font_scale, font_colour, font_thickness, cv2.LINE_AA)
    return img



In [9]:
import random

def add_bbox_prompts(input_to_add_prompt, bounding_box, indices_list, natural_prompt=False):
    """Attach one location-aware prompt per selected class to a processor output.

    ``bounding_box`` is scaled by IMG_SIZE (so presumably normalised to 0-1),
    truncated to ints and embedded verbatim in the text, e.g.
    "Cardiomegaly within the area [10, 20, 90, 120]". With ``natural_prompt``
    the zone phrase from categorize_bbox_to_zone is inserted after the label.
    The tokenised prompts are stored under ``input_to_add_prompt['prompt_inputs']``.
    """
    scaled = resize_bounding_box_normalized(bounding_box, IMG_SIZE, IMG_SIZE)
    pixel_box = [int(coord.item()) for coord in scaled]

    prompts = {}
    for index in indices_list:
        label = VINBIGDATA_TASK_LABELS[index]
        if natural_prompt:
            text = f"{label} {categorize_bbox_to_zone(pixel_box)} within the area {pixel_box}"
        else:
            text = f"{label} within the area {pixel_box}"
        prompts[VINBIGDATA_TASKS[index]] = [text]

    input_to_add_prompt['prompt_inputs'] = process_class_prompts(prompts)
    
def add_filtered_prompts(input_to_add_prompt, indexes, num_prompts=20):
    """Attach sampled severity/subtype/location prompts for the classes in ``indexes``.

    Up to ``num_prompts`` prompts per class come from generate_detailed_class_prompts.
    They are tokenised with process_class_prompts and stored under
    ``input_to_add_prompt['prompt_inputs']``.
    NOTE(review): the classifier's logit columns presumably follow the key order
    of that prompt dict - keep it aligned with any per-class scores combined later.
    """
    sampled_prompts = generate_detailed_class_prompts(n=num_prompts, indexes=indexes)
    input_to_add_prompt['prompt_inputs'] = process_class_prompts(sampled_prompts)
    
def add_prompts(input_to_add_prompt, num_prompts=20):
    """Attach sampled severity/subtype/location prompts for every class.

    This is add_filtered_prompts with no class filter: up to ``num_prompts``
    prompts per class, stored under ``input_to_add_prompt['prompt_inputs']``.
    """
    add_filtered_prompts(input_to_add_prompt, None, num_prompts=num_prompts)
    
def generate_prompts_baseline():
    """Build the plain zero-shot baseline: one prompt per class, the bare label.

    Returns
    -------
    dict
        Maps VINBIGDATA_TASKS[i] to a one-element list holding
        VINBIGDATA_TASK_LABELS[i], for every index of VINBIGDATA_TASK_LABELS.
    """
    return {
        VINBIGDATA_TASKS[position]: [label]
        for position, label in enumerate(VINBIGDATA_TASK_LABELS)
    }

def categorize_bbox_to_zone(bbox, img_size=None):
    """
    Map a bounding box to a coarse lung-zone phrase for use in text prompts.

    Parameters:
    - bbox: [x_min, y_min, x_max, y_max] in pixel coordinates of a square image.
    - img_size: side length of that image; defaults to IMG_SIZE.

    Returns:
    - "at the mid lung zone" when the box centre lies in the middle vertical band
      (between 0.33 and 0.66 of the height, inclusive).
    - Otherwise "at the {side} upper lobe" or "at the {side} lower lobe". Here side
      is "left" when the centre is in the left half of the image and "right"
      otherwise.
    """
    if img_size is None:
        img_size = IMG_SIZE

    x_min, y_min, x_max, y_max = bbox
    x_center = (x_min + x_max) // 2
    y_center = (y_min + y_max) // 2

    upper_zone_threshold = img_size * 0.33
    lower_zone_threshold = img_size * 0.66

    if upper_zone_threshold <= y_center <= lower_zone_threshold:
        return "at the mid lung zone"

    vertical = "upper" if y_center < upper_zone_threshold else "lower"
    # NOTE(review): "left"/"right" describe the image half, not the patient.
    # Chest radiographs are usually displayed with the patient's right on the
    # image left, so these labels may be anatomically mirrored - confirm.
    side = "left" if x_center < img_size / 2 else "right"
    return f"at the {side} {vertical} lobe"

def generate_detailed_class_prompts_extended(n = None, indexes=None, class_prompts=None):
    """Generate multi-field text prompts per class from DETAILED_PROMPTS.

    Parameters
    ----------
    n : int, optional
        If given and smaller than the number of combinations, ``random.sample``
        ``n`` prompts per class; otherwise keep every combination.
    indexes : list of int, optional
        Positions (in table order) of the classes to include. The returned dict
        follows the order of ``indexes``, so its entries line up positionally
        with ``indexes``. Repeated or out-of-range positions are ignored.
        ``None`` or empty selects every class in table order.
    class_prompts : dict, optional
        Table of class name -> {field: [phrases]}; defaults to DETAILED_PROMPTS.

    Returns
    -------
    dict
        Class name -> list of prompts. Each prompt is one phrase per field, in
        the field order of that class's entry, joined by single spaces.
    """
    import itertools

    if class_prompts is None:
        class_prompts = DETAILED_PROMPTS
    entries = list(class_prompts.items())

    if indexes:
        order = [i for i in dict.fromkeys(indexes) if 0 <= i < len(entries)]
    else:
        order = range(len(entries))

    prompts = {}
    for position in order:
        name, fields = entries[position]
        # product() varies the last field fastest, like the equivalent nested loops.
        cls_prompts = [" ".join(words) for words in itertools.product(*fields.values())]

        # Randomly sample n prompts for zero-shot classification.
        # TODO: make use of all candidate prompts for autoprompt tuning.
        if n is not None and n < len(cls_prompts):
            prompts[name] = random.sample(cls_prompts, n)
        else:
            prompts[name] = cls_prompts
    return prompts

def generate_detailed_class_prompts(n = None, indexes=None, class_prompts=None):
    """Generate severity/subtype/location text prompts per VinBigData class.

    Parameters
    ----------
    n : int, optional
        If given and smaller than the number of combinations, ``random.sample``
        ``n`` prompts per class; otherwise keep every combination.
    indexes : list of int, optional
        Positions (in table order) of the classes to include. The returned dict
        follows the order of ``indexes``, so classifier logits built from it
        line up positionally with ``indexes``. Repeated or out-of-range
        positions are ignored. ``None`` or empty selects every class in table
        order.
    class_prompts : dict, optional
        Table of class name -> {field: [phrases]}; defaults to
        VINBIGDATA_CLASS_PROMPTS.

    Returns
    -------
    dict
        Class name -> list of prompts. Each prompt is one phrase per field, in
        the field order of that class's entry, joined by single spaces (empty
        phrases leave extra spaces, e.g. " x ").
    """
    import itertools

    if class_prompts is None:
        class_prompts = VINBIGDATA_CLASS_PROMPTS
    entries = list(class_prompts.items())

    if indexes:
        order = [i for i in dict.fromkeys(indexes) if 0 <= i < len(entries)]
    else:
        order = range(len(entries))

    prompts = {}
    for position in order:
        name, fields = entries[position]
        # product() varies the last field fastest, like the equivalent nested loops.
        cls_prompts = [" ".join(words) for words in itertools.product(*fields.values())]

        # Randomly sample n prompts for zero-shot classification.
        if n is not None and n < len(cls_prompts):
            prompts[name] = random.sample(cls_prompts, n)
        else:
            prompts[name] = cls_prompts
    return prompts

In [10]:
def resize_bounding_box_unnormalized(bbox, original_width, original_height, new_width, new_height):
    """Rescale a pixel-space [x_min, y_min, x_max, y_max] box to another image size.

    x coordinates are multiplied by new_width / original_width and y coordinates
    by new_height / original_height. Each result is truncated toward zero with int().
    """
    x_scale = new_width / original_width
    y_scale = new_height / original_height
    corners = (bbox[0], bbox[1], bbox[2], bbox[3])
    scales = (x_scale, y_scale, x_scale, y_scale)
    return [int(coord * scale) for coord, scale in zip(corners, scales)]

def resize_bounding_box_normalized(bbox, new_width, new_height):
    """Scale a (presumably 0-1 normalised) [x_min, y_min, x_max, y_max] box to pixels.

    No rounding is applied. Element types follow the inputs, so 0-d tensors stay
    tensors and floats stay floats.
    """
    x_scale, y_scale = new_width, new_height
    return [
        bbox[0] * x_scale,
        bbox[1] * y_scale,
        bbox[2] * x_scale,
        bbox[3] * y_scale,
    ]

def convert_to_original_resolution(file_name, bbox, toDefaultFlag, width=None, height=None):
    """Convert a box between a working resolution and the image's original resolution.

    The original (height, width) are the 2nd and 3rd columns of the
    ``file_resolution_df`` row whose ``image_id`` equals ``file_name``.

    Args:
        file_name: image id to look up.
        bbox: [x_min, y_min, x_max, y_max].
        toDefaultFlag: True to scale *up* to the original resolution, False to
            scale *down* to ``width`` x ``height``.
        width, height: working resolution. When scaling up with a falsy
            ``width``, ``bbox`` is treated as normalised and only multiplied by
            the original size. Both are required when scaling down.

    Returns:
        The rescaled 4-element box.

    Raises:
        IndexError: if ``file_name`` has no row in ``file_resolution_df``.
        TypeError: if scaling down without ``width`` and ``height``.
    """
    matches = file_resolution_df[file_resolution_df['image_id'] == file_name]
    # Take the first match positionally, so the lookup does not depend on the
    # DataFrame's index labels.
    _, original_height, original_width = matches.iloc[0]

    if toDefaultFlag:
        # scale up
        if width:
            return resize_bounding_box_unnormalized(bbox, width, height, original_width, original_height)
        return resize_bounding_box_normalized(bbox, original_width, original_height)

    # scale down
    if width is None or height is None:
        raise TypeError("width and height are required when scaling down (toDefaultFlag=False)")
    return resize_bounding_box_unnormalized(bbox, original_width, original_height, width, height)


In [11]:
def process_and_crop_images(bounding_boxes,image_tensor,gaussian_radius):
    """Build one MedCLIP processor input per box, each focused on that box.

    Args:
        bounding_boxes: iterable of normalised [x_min, y_min, x_max, y_max] boxes.
        image_tensor: the full PIL image. Despite the name, crop_images calls
            PIL methods on it.
        gaussian_radius: blur radius applied outside each box.

    Returns:
        List of processor outputs (``return_tensors="pt"``), in box order.
    """
    # Scale boxes by this image's actual size (PIL reports (width, height)) rather
    # than assuming every image is AUX_IMG_SIZE x AUX_IMG_SIZE.
    image_width, image_height = image_tensor.size

    processed_cropped_images = []
    for bbox in bounding_boxes:
        pixel_box = resize_bounding_box_normalized(bbox, image_width, image_height)
        focused = crop_images(image=image_tensor, bounding_box=pixel_box, gaussian_radius=gaussian_radius)
        # Let Pillow infer the mode from the array (2-D uint8 -> "L", HxWx3 -> "RGB").
        # Forcing "L" misinterprets multi-channel or 16-bit arrays.
        pil_image = Image.fromarray(focused)
        processed_cropped_images.append(processor(images=pil_image, return_tensors="pt"))
    return processed_cropped_images

def build_target(file_name,coco_images,coco_targets, annotations_df=None, resolution_df=None):
    """Append the COCO image record and ground-truth annotations for ``file_name``.

    Args:
        file_name: VinBigData image id.
        coco_images: list receiving one {'id', 'width', 'height'} record, using the
            resolution from ``resolution_df``.
        coco_targets: list receiving COCO annotations (bbox = [x, y, width, height]).
        annotations_df: ground truth with image_id, class_id, x_min, y_min, x_max,
            y_max columns; defaults to ``train_df``.
        resolution_df: table whose first three columns are (image_id, height,
            width); defaults to ``file_resolution_df``.

    Notes:
        * Rows with class_id 14 ("No finding") never become boxes. Presumably
          their coordinates are empty/NaN in train.csv, which int() rejects.
          One placeholder annotation is emitted only when every row for the
          image is "No finding".
        * Annotation ids are ``len(coco_targets) + 1``. They are therefore
          unique across images (as long as only this function appends to
          ``coco_targets``) and never 0; pycocotools' COCOeval treats a matched
          ground-truth id of 0 as unmatched.

    Raises:
        IndexError: if ``file_name`` has no row in ``resolution_df``.
    """
    no_finding_class_id = 14

    if annotations_df is None:
        annotations_df = train_df
    if resolution_df is None:
        resolution_df = file_resolution_df

    target_list = annotations_df[annotations_df['image_id'] == file_name].to_dict(orient='records')

    resolution_row = resolution_df[resolution_df['image_id'] == file_name].iloc[0]
    _, dicom_height, dicom_width = resolution_row
    coco_images.append({'id': file_name, 'width': int(dicom_width), 'height': int(dicom_height)})

    findings = [row for row in target_list if row['class_id'] != no_finding_class_id]
    if target_list and not findings:
        coco_targets.append({
            'id': len(coco_targets) + 1, 'image_id': file_name, 'category_id': no_finding_class_id,
            'bbox': [0, 0, 1, 1], 'area': 1, 'iscrowd': 0,
        })
        return

    for row in findings:
        x_min, y_min = int(row['x_min']), int(row['y_min'])
        x_max, y_max = int(row['x_max']), int(row['y_max'])
        width = x_max - x_min
        height = y_max - y_min
        coco_targets.append({
            'id': len(coco_targets) + 1, 'image_id': file_name, 'category_id': int(row['class_id']),
            'bbox': [x_min, y_min, width, height], 'area': int(width * height), 'iscrowd': 0,
        })

In [12]:
def generate_attention_mask(image, bbox):
    """Return an IMG_SIZE x IMG_SIZE, 3-channel PIL mask: 255 inside ``bbox``, 0 elsewhere.

    Args:
        image: unused; kept for interface compatibility.
        bbox: [x_min, y_min, x_max, y_max] in IMG_SIZE pixel coordinates
            (truncated to int).
    """
    # Start at 0 and write 255 directly. A negative fill such as -1e9 cannot be
    # represented in uint8, and NumPy's out-of-range float cast is undefined.
    mask = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
    mask[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] = 255
    # Replicate into three identical channels (H x W x 3).
    mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
    return Image.fromarray(mask)

In [None]:
# Fusion weights for full-image vs. per-box class probabilities.
FULL_IMAGE_WEIGHT=0.6 # weight of the full-image probabilities
CROP_IMAGE_WEIGHT=0.4 # weight of the per-box (blur-focused) probabilities
FALSE_POSITIVE_THRESHOLD=0.8 # not used in this cell
SKIP_CLASS_THRESHOLD=0.5 # not used in this cell
K=7 # number of top full-image classes each box is classified among
gaussian_radius=20 # blur radius applied outside each box

results={'boxes':[]}
labels=[]
csv_list=[]
coco_predictions=[]
coco_targets=[]
coco_images=[]
filtered_image_fp=[]
filtered_image_bbox=[]
with torch.no_grad():
    # Iterating the Dataset directly uses __getitem__ until it raises IndexError.
    for image, bounding_boxes, file_name in eval_dataset:
        # Full-image zero-shot prediction over all classes.
        inputs = processor(images=image, return_tensors="pt")
        add_prompts(inputs, num_prompts=20)
        output = vit_clf(**inputs)
        full_image_logits = output['logits']
        full_image_probs = full_image_logits.softmax(dim=1)

        # Keep the K highest-scoring classes, sorted by class index.
        # topk returns them by score, but the per-box prompt dict built from
        # these indices lists classes in ascending class order. Sorting makes
        # position j mean the same class in both probability vectors.
        # (PromptClassifier presumably emits one logit column per prompt-dict
        # key, in key order.)
        values, indices = torch.topk(full_image_logits[0], K)
        indices_list = sorted(indices.tolist())

        filtered_full_image_logits = full_image_logits[0, indices_list]
        filtered_full_image_probs = filtered_full_image_logits.softmax(dim=0)

        # Per-box inputs: the image with everything outside each box blurred.
        image_tensor = np.array(image)
        processed_cropped_images = process_and_crop_images(bounding_boxes, image, gaussian_radius)

        # NOTE(review): the model-loading cell labels resnet_clf as the model for
        # cropped images, but vit_clf is used here - confirm which is intended.
        for i, processed_cropped_image in enumerate(processed_cropped_images):
            add_filtered_prompts(processed_cropped_image, indices_list, 10)
            anchor_output = vit_clf(**processed_cropped_image)
            anchor_probs = anchor_output['logits'].softmax(dim=1)
            dicom_bbox = convert_to_original_resolution(file_name, bounding_boxes[i], True)
            weighted_probs = (FULL_IMAGE_WEIGHT * filtered_full_image_probs) + (CROP_IMAGE_WEIGHT * anchor_probs)

            filtered_index = weighted_probs.argmax().item()
            prediction_class_id = indices_list[filtered_index]
            prediction_confidence = weighted_probs.max().item()

            # COCO detections use [x, y, width, height] in original-resolution pixels.
            x_min, y_min, x_max, y_max = (int(coord) for coord in dicom_bbox)
            width = x_max - x_min
            height = y_max - y_min
            bbox_str = [x_min, y_min, width, height]
            coco_predictions.append({'image_id': file_name, 'category_id': prediction_class_id, 'bbox': bbox_str, 'score': prediction_confidence})
        build_target(file_name, coco_images, coco_targets)