<a href="https://colab.research.google.com/github/veerendra12/CS598-DL4H-Project/blob/main/notebooks/NIHDataSet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import import_ipynb

import os
from PIL import Image

import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from torchvision.transforms import ToPILImage

from Utils import get_device
from Configuration import CONFIG
from LungRegionGenerator import lung_region_generator, lung_region_image_generator
from ModelFactory import load_segmentation_model

In [None]:
def get_NIH_class_labels():
  """Return the class label list stored under CONFIG['CLASS_LABELS']."""
  return CONFIG['CLASS_LABELS']

In [None]:
class NIHImageDataset(Dataset):
    """Dataset yielding (image, label vector) pairs described by a CSV file.

    The CSV is indexed by its "Image Index" column; each index value is used
    as a file name under CONFIG['IMAGE_DIR']. One label column is read per
    entry of CONFIG['CLASS_LABELS'] (names are whitespace-stripped before the
    column lookup).
    """

    def __init__(self, dataset_csv, transform = None, use_lung_region_images = False):
        """
        Args:
            dataset_csv: CSV location, appended to CONFIG['BASE_DIR'] by plain
                string concatenation.
            transform: Optional callable applied to the PIL image before it
                is returned.
            use_lung_region_images: When True, a segmentation model is loaded
                and every item is replaced by its generated lung-region image.
        """
        self.transform = transform

        self.df = pd.read_csv(CONFIG['BASE_DIR'] + dataset_csv)
        self.df = self.df.set_index("Image Index")
        self.use_lung_region_images = use_lung_region_images
        if use_lung_region_images:
          self.segmentation_model = load_segmentation_model()
        self.DEVICE = get_device()
        self.CLASS_LABELS = get_NIH_class_labels()

    def __len__(self):
        """
        Args: None
        Returns : Length of dataset
        """
        return len(self.df)

    def _encode_labels(self, idx):
        """Build the integer label vector for the row at position ``idx``.

        Each slot holds the row's value for that class cast to int when the
        cast value is positive, and 0 otherwise.
        """
        label_vector = np.zeros(len(self.CLASS_LABELS), dtype=int)
        for position, class_name in enumerate(self.CLASS_LABELS):
            raw_value = self.df[class_name.strip()].iloc[idx]
            # np.asarray keeps the truncating cast of the previous
            # ``.astype('int')`` call while also accepting plain Python
            # scalars coming from object-dtype columns.
            value = int(np.asarray(raw_value).astype(int))
            if value > 0:
                label_vector[position] = value
        return label_vector

    def __getitem__(self, idx):
        """Load the image at position ``idx`` together with its label vector.

        Args:
            idx: Positional index into the CSV rows.
        Returns:
            Tuple ``(image, label_one_hot)``. ``image`` is an RGB PIL image,
            or whatever ``self.transform`` returns when a transform is set;
            ``label_one_hot`` is an int numpy array with one slot per class.
        """
        image_path = os.path.join(CONFIG['IMAGE_DIR'], self.df.index[idx])
        # Image.open is lazy and keeps the file handle open; convert() forces
        # the pixel data to load and returns an independent image, so the
        # handle can be released right away instead of waiting for GC.
        with Image.open(image_path) as source_image:
            image = source_image.convert('RGB')

        if self.use_lung_region_images:
          image_tensor = CONFIG['SEGMENTATION_TRANSOFRM'](image)
          image_tensor = image_tensor.to(device=self.DEVICE)
          # NOTE(review): if lung_region_generator does not already disable
          # autograd, consider wrapping this in torch.no_grad() - confirm.
          mask = lung_region_generator(image_tensor, self.segmentation_model)
          lung_region_image = lung_region_image_generator(image_tensor, mask)

          image = ToPILImage()(lung_region_image)

        label_one_hot = self._encode_labels(idx)

        if self.transform is not None:
            image = self.transform(image)

        return (image, label_one_hot)

In [None]:
def get_NIH_FullCXR_loaders():
    """Create the train/validation datasets and loaders over full CXR images.

    Both datasets are built with ``use_lung_region_images=False``. The train
    loader shuffles, the validation loader does not; both use
    CONFIG['BATCH_SIZE'].

    Returns:
        Tuple ``(train_dataset, train_loader, validation_dataset,
        validation_loader)``.
    """
    print('Using get_NIH_FullCXR_loaders')
    samples_per_batch = CONFIG['BATCH_SIZE']

    train_dataset = NIHImageDataset(
        dataset_csv=CONFIG['TRAIN_CSV'],
        transform=CONFIG['NIH_TRANSFORMS']['train'],
        use_lung_region_images=False,
    )
    validation_dataset = NIHImageDataset(
        dataset_csv=CONFIG['VALIDATION_CSV'],
        transform=CONFIG['NIH_TRANSFORMS']['validation'],
        use_lung_region_images=False,
    )

    train_loader = DataLoader(
        train_dataset, batch_size=samples_per_batch, shuffle=True
    )
    validation_loader = DataLoader(
        validation_dataset, batch_size=samples_per_batch, shuffle=False
    )

    return train_dataset, train_loader, validation_dataset, validation_loader

In [None]:
def get_NIH_LungRegion_loaders():
    """Create the train/validation datasets and loaders over lung-region images.

    Both datasets are built with ``use_lung_region_images=True``. The train
    loader shuffles, the validation loader does not; both use
    CONFIG['BATCH_SIZE'].

    Returns:
        Tuple ``(train_dataset, train_loader, validation_dataset,
        validation_loader)``.
    """
    print('Using get_NIH_LungRegion_loaders')
    samples_per_batch = CONFIG['BATCH_SIZE']

    train_dataset = NIHImageDataset(
        dataset_csv=CONFIG['TRAIN_CSV'],
        transform=CONFIG['NIH_TRANSFORMS']['train'],
        use_lung_region_images=True,
    )
    validation_dataset = NIHImageDataset(
        dataset_csv=CONFIG['VALIDATION_CSV'],
        transform=CONFIG['NIH_TRANSFORMS']['validation'],
        use_lung_region_images=True,
    )

    train_loader = DataLoader(
        train_dataset, batch_size=samples_per_batch, shuffle=True
    )
    validation_loader = DataLoader(
        validation_dataset, batch_size=samples_per_batch, shuffle=False
    )

    return train_dataset, train_loader, validation_dataset, validation_loader