# Dataloaders for radiology datasets 

In [31]:
import os
import numpy as np
from radvlm.data.utils import custom_collate_fn
from torch.utils.data import DataLoader
from radvlm.data.create_instructions import format_boxes
import matplotlib.pyplot as plt
import json

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

Let's first create a function that displays an image, optionally with bounding boxes drawn on top.

In [32]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def show_img(array_data, boxes=None):
    """Show the first channel of an image in grayscale, optionally with boxes.

    Args:
        array_data: Indexable image container. ``array_data[0]`` must be a
            2-D array (height, width); that slice is what gets displayed.
        boxes: Optional iterable of ``(x1, y1, x2, y2)`` boxes expressed as
            fractions of the image width/height. ``None`` or an empty
            container draws no box. Lists, NumPy arrays and other iterables
            are all accepted.
    """
    image = array_data[0]
    img_height, img_width = image.shape

    fig, ax = plt.subplots()
    # Draw the image once and keep the handle so the colorbar can reuse it
    # (the image used to be rendered a second time just for the colorbar).
    rendered = ax.imshow(image, cmap='gray')

    # Iterate instead of testing `if boxes:`; the truth value of a
    # multi-element NumPy array is ambiguous and raises ValueError.
    for box in (boxes if boxes is not None else ()):
        # Convert proportional coordinates to pixel coordinates
        x1, y1, x2, y2 = box
        x1_pixel = x1 * img_width
        y1_pixel = y1 * img_height
        x2_pixel = x2 * img_width
        y2_pixel = y2 * img_height

        # Rectangle takes the top-left corner plus a width and a height
        rect = patches.Rectangle(
            (x1_pixel, y1_pixel),
            x2_pixel - x1_pixel,
            y2_pixel - y1_pixel,
            linewidth=3,
            edgecolor='r',
            facecolor='none',
        )
        ax.add_patch(rect)

    fig.colorbar(rendered, ax=ax)
    plt.show()

def display_instruction(instruction):
    """Pretty-print ``instruction`` as JSON with a 4-space indent.

    Non-ASCII characters are written as-is rather than escaped.
    Returns ``None``.
    """
    rendered = json.dumps(instruction, ensure_ascii=False, indent=4)
    print(rendered)

Now let's retrieve the env variable `DATA_DIR`, which indicates where all datasets are located.

In [3]:
from radvlm import DATA_DIR

## MIMIC-CXR dataset - Report generation

In [4]:
from radvlm.data.datasets import MIMIC_Dataset_MM

# Root of the MIMIC-CXR-JPG data under DATA_DIR.
datasetpath = os.path.join(DATA_DIR, 'MIMIC-CXR-JPG')
# Filtered reports directory (if you have it).
filtered_reports_dir = os.path.join(datasetpath, 'filtered_reports')
# Conversations directory (if present). It is defined for convenience only:
# the constructor call below passes conversation_dir=None.
conversation_dir = os.path.join(datasetpath, 'conversations/train/standard')

dataset = MIMIC_Dataset_MM(
    datasetpath=datasetpath,
    split='train',
    # True -> the get_item function also returns the images
    flag_img=True,
    # True -> the get_item function also returns the labels
    flag_lab=True,
    # True -> create instructions for report generation
    flag_instr=True,
    # True -> ignore lateral images
    only_frontal=True,
    # Show filtered reports; set to None for the original reports
    filtered_reports_dir=filtered_reports_dir,
    # Set a path to keep only the subset of samples from MS-CXR
    sentencesBBoxpath=None,
    # Set a path to keep only the subset of samples that have conversations
    conversation_dir=None,
)

print(len(dataset))

230980


Now we can create a dataloader from the dataset. 

In [8]:
# One sample per batch, drawn in random order.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

Let's explore some samples with images and attributes. If you want a new sample, simply re-run the cell.

In [None]:
# Fetch one shuffled batch (batch_size is 1) and take its only element.
sample = next(iter(data_loader))
entry = sample[0]

print("\n--------- Image Path ---------")
print(entry["img_path"])

print("\n--------- Image ---------")
show_img(entry["img"])

print("\n--------- Report ---------")
print(entry["txt"])

print("\n--------- Labels ---------")
print(entry["labels"])

print("\n--------- RG Instructions ---------")
display_instruction(entry["instr"])

# Uncomment if you passed a conversation dir to the dataset:
# print("\n--------- Conversation ---------")
# display_instruction(sample[0]["conversation"])



## MS-CXR - Phrase Grounding
This class is derived from MIMIC-CXR and adds the grounded phrases. It is organized per phrase, so different datapoints can share the same image.

In [33]:
from radvlm.data.datasets import MS_CXR

# MS-CXR is derived from MIMIC-CXR, so the images are read from the MIMIC root.
datasetpath_mimic = os.path.join(DATA_DIR, 'MIMIC-CXR-JPG')
# Location of the MS-CXR sentence / bounding-box annotations.
sentencesBBoxpath = os.path.join(DATA_DIR, 'MS-CXR', 'sentences_and_BBox_mscxr')

dataset = MS_CXR(
    datasetpath=datasetpath_mimic,
    sentencesBBoxpath=sentencesBBoxpath,
    split="train",
    only_frontal=True,
    flag_img=True,
    flag_lab=True,
    flag_instr=True,
    seed=0,
)
print(len(dataset))

964


In [34]:
# Shuffled loader that yields a single MS-CXR datapoint per batch.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

In [None]:
# Fetch one shuffled batch (batch_size is 1) and take its only element.
sample = next(iter(data_loader))
entry = sample[0]

print("\n--------- Image Path ---------")
print(entry["img_path"])

print("\n--------- Phrase ---------")
print(entry["label"])

print("\n--------- Bounding box ---------")
print(format_boxes(entry["boxes"]))

# The image is shown with the box(es) for the phrase overlaid.
print("\n--------- Image ---------")
show_img(entry["img"], entry["boxes"])

print("\n--------- Phrase grounding Instruction ---------")
display_instruction(entry["instr"])

## Chest ImaGenome dataset - Anatomical grounding

In [7]:
from radvlm.data.datasets import Chest_ImaGenome_Dataset

# The images are read from the MIMIC-CXR-JPG root.
datasetpath_mimic = os.path.join(DATA_DIR, 'MIMIC-CXR-JPG')
datasetpath_chestima = os.path.join(DATA_DIR, 'CHEST_IMA')
# Filtered reports directory (if you have it). It is defined for convenience
# only: the constructor call below passes filtered_reports_dir=None.
filtered_reports_dir = os.path.join(datasetpath_mimic, 'filtered_reports')

split = "train"
dataset = Chest_ImaGenome_Dataset(
    datasetpath=datasetpath_mimic,
    datasetpath_chestima=datasetpath_chestima,
    split=split,
    # Optional: set a path if you want the filtered (LLM-generated) reports
    filtered_reports_dir=None,
    flag_img=True,
    flag_instr=True,
    flag_txt=True,
    flag_lab=False,
    # True -> get_item retrieves just one (randomly picked) region;
    # set to False if you want all regions
    pick_one_region=True,
    sentencesBBoxpath=None,
)

print(len(dataset))

162794


In [8]:
# Shuffled loader that yields a single Chest ImaGenome datapoint per batch.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

In [None]:
# Fetch one shuffled batch (batch_size is 1) and take its only element.
sample = next(iter(data_loader))
entry = sample[0]

print("\n--------- Image Path ---------")
print(entry["img_path"])

print("\n--------- Report ---------")
print(entry["txt"])

print("\n--------- Region name ---------")
print(entry["label"])

# format_boxes rounds to 2 decimals; this is for display purposes only.
print("\n--------- Bounding box ---------")
print(format_boxes(entry["boxes"]))

print("\n--------- Image ---------")
show_img(entry["img"], entry["boxes"])

print("\n--------- Anatomy grounding Instructions ---------")
display_instruction(entry["instr"])

## VinDr-CXR dataset - Abnormality detection
This dataset class is designed for abnormality detection: it contains all samples from the original VinDr-CXR (healthy and non-healthy) and provides instructions for the abnormality detection task.

In [36]:
from radvlm.data.datasets import VinDr_CXR_Dataset

# Root of the VinDr-CXR data under DATA_DIR.
datasetpath = os.path.join(DATA_DIR, "VinDr-CXR")

dataset = VinDr_CXR_Dataset(
    datasetpath=datasetpath,
    split="train",
    flag_instr=True,
    flag_img=True,
)

print(len(dataset))

15000


In [37]:
# Shuffled, one sample per batch, loaded through a single worker.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    collate_fn=custom_collate_fn,
)

In [None]:
# Fetch one shuffled batch (batch_size is 1) and take its only element.
sample = next(iter(data_loader))
entry = sample[0]

print("\n--------- Image Path ---------")
print(entry["img_path"])

print("\n--------- Abnormality names ---------")
print(entry["labels"])

# Boxes are printed raw here (no format_boxes rounding).
print("\n--------- Bounding box ---------")
print(entry["boxes"])

print("\n--------- Image ---------")
show_img(entry["img"], entry["boxes"])

print("\n--------- Abnormality detection instruction ---------")
display_instruction(entry["instr"])

## VinDr-CXR for mono-class grounding
This class is designed for abnormality grounding: it contains no healthy samples, only samples with an abnormality, and each sample carries exactly one abnormality.

In [39]:
from radvlm.data.datasets import VinDr_CXR_Single_Label_Dataset


# Root of the VinDr-CXR data under DATA_DIR.
datasetpath = os.path.join(DATA_DIR, "VinDr-CXR")

dataset = VinDr_CXR_Single_Label_Dataset(
    datasetpath=datasetpath,
    split="train",
    flag_instr=True,
    flag_img=True,
)

print(len(dataset))

16089


In [40]:
# Shuffled, one sample per batch, loaded through a single worker.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    collate_fn=custom_collate_fn,
)

In [None]:
# Fetch one shuffled batch (batch_size is 1); sample[0] is its only element.
sample = next(iter(data_loader))

print("\n--------- Image Path ---------")
print(sample[0]["img_path"])

print("\n--------- Abnormality name ---------")
print(sample[0]["label"])

print("\n--------- Bounding box ---------")
print(format_boxes(sample[0]["boxes"]))

print("\n--------- Image ---------")
show_img(sample[0]["img"], sample[0]["boxes"])

# Header fixed: it used to print a stray "}" after "instruction".
print("\n--------- Abnormality grounding instruction ---------")
display_instruction(sample[0]["instr"])

## CheXpert Dataset - abnormality classification

In [None]:
from radvlm.data.datasets import CheXpert_Dataset_MM

# Root of the CheXpert data under DATA_DIR.
datasetpath = os.path.join(DATA_DIR, 'CheXpert')

dataset = CheXpert_Dataset_MM(
    datasetpath=datasetpath,
    split='train',
    only_frontal=True,
    unique_patients=False,
    flag_img=True,
    flag_lab=True,
    flag_instr=True,
)
print(len(dataset))

In [19]:
# Shuffled, one sample per batch, loaded through a single worker.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    collate_fn=custom_collate_fn,
)

In [None]:
# Fetch one shuffled batch (batch_size is 1); sample[0] is its only element.
sample = next(iter(data_loader))

print("\n--------- Image Path ---------")
print(sample[0]["img_path"])

print("\n--------- Image ---------")
show_img(sample[0]["img"])

print("\n--------- Abnormality labels ---------")
print(sample[0]["labels"])

# Header fixed: it used to print a stray "}" after "instruction".
print("\n--------- Abnormality classification instruction ---------")
display_instruction(sample[0]["instr"])

## CheXpert-Plus - report generation 


In [21]:
from radvlm.data.datasets import CheXpertPlus_Dataset

# CheXpert-Plus is read from the same root folder as CheXpert.
datasetpath = os.path.join(DATA_DIR, 'CheXpert')
filtered_reports_dir = os.path.join(datasetpath, 'filtered_reports')

# NOTE(review): flag_instr is not passed here, yet the sample cell further
# down reads sample[0]["instr"]. Confirm the class default provides it.
dataset = CheXpertPlus_Dataset(
    datasetpath=datasetpath,
    split='train',
    only_frontal=True,
    flag_img=True,
    flag_txt=True,
    flag_lab=True,
    # Optional: set to None for the original reports
    filtered_reports_dir=filtered_reports_dir,
)
print(len(dataset))

186463


In [22]:
# Shuffled, one sample per batch, loaded through a single worker.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    collate_fn=custom_collate_fn,
)

In [None]:
# Fetch one shuffled batch (batch_size is 1) and take its only element.
sample = next(iter(data_loader))
entry = sample[0]

print("\n--------- Image Path ---------")
print(entry["img_path"])

print("\n--------- Image ---------")
show_img(entry["img"])

print("\n--------- Report ---------")
print(entry["txt"])

print("\n--------- Labels ---------")
print(entry["labels"])

print("\n--------- RG Instructions ---------")
display_instruction(entry["instr"])

## PadChest - Phrase grounding
This dataset class is organized per observation, i.e., different datapoints can contain the same image (as with the VinDr-CXR mono-class dataset).

In [24]:
from radvlm.data.datasets import PadChest_grounding

# Root of the PadChest data under DATA_DIR.
datasetpath = os.path.join(DATA_DIR, 'PadChest')

dataset = PadChest_grounding(
    datasetpath=datasetpath,
    split='train',
    flag_img=True,
    flag_instr=True,
)



In [25]:
# Shuffled, one sample per batch, loaded through a single worker.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    collate_fn=custom_collate_fn,
)

In [26]:
# Fetch one shuffled batch from the PadChest grounding loader.
# NOTE(review): the next cell starts with this same statement and overwrites
# `sample`, so this cell looks redundant.
sample = next(iter(data_loader))

In [None]:
# Fetch one shuffled batch (batch_size is 1) and take its only element.
sample = next(iter(data_loader))
entry = sample[0]

print("\n--------- Image Path ---------")
print(entry["img_path"])

print("\n--------- Image ---------")
show_img(entry["img"])

print("\n--------- Report ---------")
print(entry["txt"])

print("\n--------- Phrase ---------")
print(entry["label"])

# NOTE(review): the header says "RG Instructions", but this section is about
# phrase grounding. It is presumably copied from a report-generation cell.
print("\n--------- RG Instructions ---------")
display_instruction(entry["instr"])

If you want to load the dataset per image with conversations, use the other class (below)

In [28]:
from radvlm.data.datasets import PadChest_grounding_per_image

# Reuses `datasetpath` (the PadChest root) defined in the PadChest grounding
# cell above, so that cell has to run first.
conversation_dir = os.path.join(datasetpath, 'conversations/train/grounding')

dataset = PadChest_grounding_per_image(
    datasetpath=datasetpath,
    split='train',
    conversation_dir=conversation_dir,
    flag_img=True,
)
print(len(dataset))

1945


In [29]:
# Shuffled, one sample per batch, loaded through a single worker.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    collate_fn=custom_collate_fn,
)

In [None]:
# Fetch one shuffled batch (batch_size is 1) and take its only element.
sample = next(iter(data_loader))
entry = sample[0]

print("\n--------- Image Path ---------")
print(entry["img_path"])

print("\n--------- Image ---------")
show_img(entry["img"])

print("\n--------- Report ---------")
print(entry["txt"])

# Both structures below are pretty-printed as JSON.
print("\n--------- Phrases ---------")
display_instruction(entry["sentencesBBox"])

print("\n--------- Grounded conversations ---------")
display_instruction(entry["conversation"])