In [1]:
import os
from os.path import join, splitext
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# Root directory of the CheXpert Plus dataset.
# Set the CHEXPERT_PLUS_DIR environment variable to point at a different copy;
# the fallback is the path originally hard-coded here, which is relative to the
# working directory the notebook is launched from.
data_folder = os.environ.get(
    'CHEXPERT_PLUS_DIR',
    '../../../../../../../storage/ice1/shared/bmed6780/mip_group_2/CheXpert Plus',
)

In [26]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class CheXpertDataset(Dataset):
    """Images paired with per-finding label vectors read from a JSON-lines file.

    Layout read from ``root_dir``:

    * ``PNG/<folder>/PNG/<relative image path>.png`` -- image files, searched
      across every extension-less entry ``<folder>`` of ``root_dir/PNG``.
    * ``chexbert_labels/findings_fixed.json`` -- one JSON object per line with
      a ``path_to_image`` entry; every other entry is treated as a label.

    Label encoding (one float per label, in the key order of the JSON object):

    * ``null`` (finding not mentioned) -> ``0.0``
    * ``-1`` (uncertain)              -> ``0.5``
    * anything else                   -> kept as is

    ``img_paths[i]`` and ``targets[i]`` always describe the same sample: a
    record only contributes when its image file is actually found on disk.
    """

    def __init__(self, root_dir, transform=None, mode="train"):
        """
        Args:
            root_dir (str): Dataset root containing the 'PNG' and 'chexbert_labels' folders.
            transform (callable, optional): Optional transform to be applied on an image.
            mode (str): Split to load, e.g. "train" or "valid". A record is kept when the
                first component of its 'path_to_image' equals this value.
        """
        self.root = root_dir
        self.img_path = os.path.join(self.root, 'PNG')
        # Sorted so that the lookup order (and therefore the chosen file when an
        # image exists in more than one folder) does not depend on os.listdir order.
        self.img_folders = sorted(
            name for name in os.listdir(self.img_path)
            if os.path.splitext(name)[1] == ''
        )

        self.target_folder = os.path.join(self.root, 'chexbert_labels')
        self.target_path = os.path.join(self.target_folder, 'findings_fixed.json')
        self.transform = transform
        self.mode = mode

        self.targets = []
        self.img_paths = []
        # Relative paths of records in the requested split whose image file was
        # not found; those records are left out of the dataset.
        self.missing_images = []

        for record in self._read_records(self.target_path):
            rel_path = record.get('path_to_image') if isinstance(record, dict) else None
            if rel_path is None:
                continue

            rel_stem = os.path.splitext(rel_path)[0]
            if rel_stem.split('/')[0] != self.mode:
                continue  # record belongs to another split

            target = self._encode_target(record)
            if target.numel() == 0:
                continue  # no labels attached to this record

            png_rel_path = rel_stem + '.png'
            img_path = self._find_image(png_rel_path)
            if img_path is None:
                self.missing_images.append(png_rel_path)
                continue

            # Append both together so indices stay aligned.
            self.img_paths.append(img_path)
            self.targets.append(target)

    @staticmethod
    def _read_records(path):
        """Parse a JSON-lines file into a list of objects, ignoring blank lines."""
        records = []
        with open(path, 'r') as f:
            for line in f:
                if line.strip():
                    records.append(json.loads(line))
        return records

    def _find_image(self, rel_path):
        """Return the first existing ``PNG/<folder>/PNG/<rel_path>``, or None if there is none."""
        for folder in self.img_folders:
            candidate = os.path.join(self.img_path, folder, 'PNG', rel_path)
            if os.path.exists(candidate):
                return candidate
        return None

    @staticmethod
    def _encode_target(record):
        """Turn every non-path entry of ``record`` into one float32 label vector."""
        values = []
        for key, value in record.items():
            if key == 'path_to_image':
                continue
            if value is None:
                values.append(0)  # finding not mentioned: treated as absent
            elif value == -1:
                values.append(0.5)  # uncertain: halfway between absent and present
            else:
                values.append(value)
        # Fixed dtype: without it an all-integer sample would become int64 while a
        # sample containing 0.5 would become float32.
        return torch.tensor(values, dtype=torch.float32)

    def __len__(self):
        """Number of (image, target) pairs."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if given, and return ``(img, target)``."""
        # The context manager releases the file handle as soon as the pixels are converted.
        with Image.open(self.img_paths[idx]) as img_file:
            img = img_file.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        return img, self.targets[idx]

In [27]:
# Preprocessing pipeline applied to every image before batching
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),  # fixed 224x224 spatial size
        transforms.ToTensor(),  # PIL image -> tensor
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # per-channel (x - 0.5) / 0.5
    ]
)

# Training split wrapped in a shuffling loader that yields batches of 4
dataset = CheXpertDataset(data_folder, transform, "train")
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

<class 'torch.Tensor'>


In [29]:
# Sanity check: pull only the first batch and show the image / target shapes.
# The None default keeps an empty loader from raising, so nothing is printed in that case.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    img, target = first_batch
    print(img.shape)
    print(target.shape)

torch.Size([4, 3, 224, 224])
torch.Size([4, 14])


In [30]:
# Number of samples in the dataset (CheXpertDataset.__len__ returns len(self.targets))
len(dataset)

223228