<a href="https://colab.research.google.com/github/sajidcsecu/radioGenomic/blob/main/NewUnetDataPreparation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import numpy as np
from NewRadioImage import RadioImage
import SimpleITK as sitk
import pydicom
import pydicom_seg
from torch.utils.data import Dataset,DataLoader
from itertools import chain
import torch
import matplotlib.pyplot as plt
import os
from glob import glob
from tqdm import tqdm
from albumentations import HorizontalFlip, VerticalFlip, Rotate
import cv2
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms



# class PatientDataset(Dataset):
#     def __init__(self, image, mask):
#         self.image = image
#         self.mask = mask
#         self.n_samples = len(self.image)
#
#     def __getitem__(self, index):
#         img = torch.unsqueeze(self.image[index,:,:], 0)
#         msk = torch.unsqueeze(self.mask[index,:,:], 0)
#         return img,msk
#
#     def __len__(self):
#         return self.n_samples

class PatientDataset2DUNet(Dataset):
    """Slice-wise CT / segmentation dataset that reads DICOM data lazily.

    Construction visits every patient once and builds an index of
    ``(ct_path, seg_path, slice_index)`` tuples.  Pixel data is not kept in
    memory: every ``__getitem__`` call re-reads the CT series and the SEG
    object from disk and returns a single 2-D slice.
    """

    def __init__(self, patients, metadata, train=True):
        """
        Args:
            patients (list): List of patient IDs.
            metadata (DataFrame): Metadata containing patient information.
                The columns 'Subject ID', 'Modality' and 'File Location'
                are read.
            train (bool): If True, filters out empty slices.
        """
        self.patients = patients
        self.metadata = metadata
        self.train = train
        # Index of (img_path, mask_path, slice_idx) tuples, one per sample.
        self.slices = self._extract_slices()

    def get_path(self, subject, modality):
        """Return the 'File Location' of the first row of *subject* whose
        'Modality' equals *modality*, or None when there is no such row."""
        rows = subject[subject['Modality'] == modality]
        return rows['File Location'].iloc[0] if not rows.empty else None

    def _extract_slices(self):
        """Build the sample index for all patients.

        Patients without a CT or SEG path, or whose segmentation cannot be
        read, contribute no samples.

        Returns:
            list: ``(img_path, mask_path, slice_idx)`` tuples.
        """
        slices = []
        for patient in self.patients:
            print(f"Processing Patient: {patient}")
            subject = self.metadata[self.metadata['Subject ID'] == patient]

            img_path = self.get_path(subject, "CT")
            msk_path = self.get_path(subject, "SEG")
            if not (img_path and msk_path):
                continue

            img = self.read_ct_array(img_path)
            msk = self.read_seg_array(msk_path, "GTV-1")
            if img is None or msk is None:
                continue

            # Only the slice count of the CT and the occupancy of the mask
            # are needed here, so no float conversion of the volumes is done.
            n_image_slices = sitk.GetArrayFromImage(img).shape[0]
            mask = sitk.GetArrayFromImage(msk)

            # Use the common prefix when CT and SEG disagree on slice count.
            min_slices = min(n_image_slices, mask.shape[0])
            mask = mask[:min_slices]

            if self.train:
                # Training only uses slices that contain some segmentation.
                slice_indices = np.where(np.any(mask, axis=(1, 2)))[0]
            else:
                slice_indices = np.arange(min_slices)

            slices.extend(
                (img_path, msk_path, int(slice_idx)) for slice_idx in slice_indices
            )

        return slices

    def __len__(self):
        """Return the number of indexed slices."""
        return len(self.slices)

    def read_ct_array(self, path):
        """Read the DICOM series stored in directory *path* with SimpleITK's
        GDCM reader and return the resulting image."""
        reader = sitk.ImageSeriesReader()
        reader.SetImageIO("GDCMImageIO")
        reader.SetFileNames(reader.GetGDCMSeriesFileNames(path))
        return reader.Execute()

    def read_seg_array(self, path, seg_type="GTV-1"):
        """Read one segment from the DICOM SEG file ``<path>/1-1.dcm``.

        Args:
            path: Directory that contains ``1-1.dcm``.
            seg_type: 'SegmentDescription' of the segment to extract.

        Returns:
            The segment image, or None when anything goes wrong (missing
            file, unknown segment description, reader failure).  The error
            is printed rather than raised so that one broken patient does
            not abort indexing.
        """
        try:
            segmentation = pydicom.dcmread(os.path.join(path, '1-1.dcm'))
            seg_df = pd.DataFrame({f: [s[f].value for s in segmentation.SegmentSequence] for f in ['SegmentNumber', 'SegmentDescription']})
            seg_number = seg_df.loc[seg_df['SegmentDescription'] == seg_type, 'SegmentNumber'].iloc[0]
            return pydicom_seg.SegmentReader().read(segmentation).segment_image(seg_number)
        except Exception as e:
            print(f"Error reading segmentation from {path}: {e}")
            return None

    @staticmethod
    def _min_max_normalize(image):
        """Rescale *image* linearly so that its values span [0, 1].

        The divisor is the intensity range (max - min), floored at 1e-6 so a
        constant image maps to zeros instead of dividing by zero.
        """
        lowest = image.min()
        span = image.max() - lowest
        return (image - lowest) / max(span, 1e-6)

    def __getitem__(self, idx):
        """Load and return sample *idx*.

        Returns:
            tuple: ``(image, mask)`` tensors of shape ``[1, H, W]``; the image
            is min-max normalized per slice.  Returns None when the
            segmentation can no longer be read.
            NOTE(review): a None sample cannot be batched by DataLoader's
            default collate function - confirm callers handle this.
        """
        img_path, mask_path, slice_idx = self.slices[idx]

        # The whole volume is read from disk; only one slice is kept.
        img = self.read_ct_array(img_path)
        msk = self.read_seg_array(mask_path, "GTV-1")

        if img is None or msk is None:
            return None

        image = sitk.GetArrayFromImage(img)[slice_idx].astype(np.float32)  # (H, W)
        mask = sitk.GetArrayFromImage(msk)[slice_idx].astype(np.float32)  # (H, W)

        image = self._min_max_normalize(image)

        # Add the channel dimension (C=1) expected by 2-D networks.
        image = torch.from_numpy(image).unsqueeze(0)
        mask = torch.from_numpy(mask).unsqueeze(0)

        return image, mask

class PatientDatasetAllInOneTensor(Dataset):
    """Slice-wise dataset that loads every patient volume into memory up front.

    All slices of all patients are concatenated into two float16 arrays
    (images and masks) at construction time; ``__getitem__`` only indexes
    into them.
    """

    def __init__(self, patients, metadata, train=True):
        """
        Args:
            patients (list): List of patient IDs.
            metadata (DataFrame): Metadata containing patient information.
            train (bool): If True, filters out empty slices.
        """
        self.patients = patients
        self.metadata = metadata
        self.train = train
        self.images, self.masks = self._extract_data()

    @staticmethod
    def _normalize_volume(image):
        """Min-max scale a volume to [0, 1] and return it as float16.

        The arithmetic is done in float32: casting raw intensities to float16
        first would quantize values above 2048 before they are rescaled.  An
        empty volume (no slices left after filtering) is returned unchanged
        apart from the dtype, because min/max are undefined for it.
        """
        if image.size == 0:
            return image.astype(np.float16)

        scaled = image.astype(np.float32)  # astype copies; the input is untouched
        scaled -= scaled.min()
        peak = scaled.max()
        if peak > 0:
            scaled /= peak
        return scaled.astype(np.float16)

    def _extract_data(self):
        """Read, filter and normalize every patient, then stack the slices.

        Returns:
            tuple: ``(images, masks)`` arrays of shape ``[total_slices, H, W]``.

        Raises:
            ValueError: from ``np.concatenate`` when ``self.patients`` is empty
                or the patients' slice shapes differ.
        """
        # NOTE(review): RadioImage is a project helper; it is assumed to return
        # SimpleITK images for the subject's CT and 'GTV-1' segment - confirm.
        ri = RadioImage()
        all_images = []
        all_masks = []

        for patient in self.patients:
            print(f"Processing Patient: {patient}")
            subject = self.metadata[self.metadata['Subject ID'] == patient]

            img = ri.read_dicom_ct(subject)
            msk = ri.read_dicom_seg(subject, 'GTV-1')

            image = sitk.GetArrayFromImage(img)
            mask = sitk.GetArrayFromImage(msk).astype(np.float16)

            print(f"Image Shape: {image.shape}, Mask Shape: {mask.shape}")

            # Use the common prefix when CT and SEG disagree on slice count.
            min_slices = min(image.shape[0], mask.shape[0])
            image, mask = image[:min_slices], mask[:min_slices]

            # Keep only slices that contain some segmentation for training.
            if self.train:
                non_empty_slices = np.any(mask, axis=(1, 2))
                image, mask = image[non_empty_slices], mask[non_empty_slices]

            # Normalization is per patient and uses the slices kept above.
            all_images.append(self._normalize_volume(image))
            all_masks.append(mask)

        # One contiguous array per kind keeps indexing trivial.
        all_images = np.concatenate(all_images, axis=0)  # [Total Slices, H, W]
        all_masks = np.concatenate(all_masks, axis=0)  # [Total Slices, H, W]

        return all_images, all_masks

    def __len__(self):
        """Return the total number of stored slices."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors of shape ``[1, H, W]`` for slice *idx*."""
        image = torch.from_numpy(self.images[idx]).unsqueeze(0)  # Add channel dim (C=1)
        mask = torch.from_numpy(self.masks[idx]).unsqueeze(0)  # Add channel dim (C=1)
        return image, mask

class SliceDataset(Dataset):
    """Dataset over pre-exported image / mask picture files.

    Each sample is read with OpenCV as a grayscale picture, scaled by its own
    maximum and returned as a float32 tensor of shape ``[1, H, W]``.
    """

    def __init__(self, images_path, masks_path):
        """
        Args:
            images_path: Sequence of image file paths.
            masks_path: Sequence of mask file paths, index-aligned with
                ``images_path``.
        """
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    @staticmethod
    def _read_grayscale(path):
        """Read *path* as a grayscale array.

        Raises:
            FileNotFoundError: when OpenCV cannot read the file
                (``cv2.imread`` signals failure by returning None).
        """
        pixels = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if pixels is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return pixels

    @staticmethod
    def _to_unit_tensor(pixels):
        """Scale *pixels* by its maximum and return a ``[1, H, W]`` float32 tensor.

        An all-zero picture (for example an empty mask) is returned as zeros;
        dividing it by its maximum would produce NaN.
        """
        peak = pixels.max()
        if peak > 0:
            pixels = pixels / peak
        pixels = pixels.astype(np.float32)
        return torch.from_numpy(np.expand_dims(pixels, axis=0))

    def __getitem__(self, index):
        """Return the ``(image, mask)`` tensor pair stored at *index*."""
        image = self._to_unit_tensor(self._read_grayscale(self.images_path[index]))
        mask = self._to_unit_tensor(self._read_grayscale(self.masks_path[index]))
        return image, mask

    def __len__(self):
        """Return the number of image paths."""
        return self.n_samples




class DataPreprocess:
    """Collect patient entries from three metadata CSVs and export their slices.

    ``patient_list`` holds strings of the form
    ``"<subject id>/<sub type>/<segment type>"``; the other methods split such
    a string to locate and load the corresponding CT volume and mask.
    """

    # Output root used by entire_data_to_disk when no path is supplied.
    DEFAULT_OUTPUT_DIR = "F:\\Idiot Developer\\radioGenomic\\Segementation\\data\\full data"

    def __init__(self):
        """Read the metadata files and build ``self.patient_list``."""
        lung1_ids = self._drop_excluded(
            self._subject_ids('metadata/metadata_lung1.csv'), ['LUNG1-128'])
        md_ids = self._drop_excluded(
            self._subject_ids('metadata/metadata_md.csv'), ['interobs09'])
        test_retest_ids = self._drop_excluded(
            self._subject_ids('metadata/metadata_test_retest.csv'),
            ['RIDER-2283289298', 'RIDER-5195703382', 'RIDER-8509201188'])

        patient_list_lung1 = [p + "/None" + "/GTV-1" for p in lung1_ids]
        patient_list_test = [p + "/TEST" + "/GTVp_test_man" for p in test_retest_ids]
        patient_list_retest = [p + "/RETEST" + "/GTVp_retest_man" for p in test_retest_ids]
        # Five segment variants ("GTV-1vis-1" .. "GTV-1vis-5") per md subject,
        # grouped by variant.
        patient_list_md = [
            p + "/None" + f"/GTV-1vis-{variant}"
            for variant in range(1, 6)
            for p in md_ids
        ]

        self.patient_list = list(chain(patient_list_lung1, patient_list_test,
                                       patient_list_retest, patient_list_md))

    def get_data(self):
        """Return the list of patient entry strings."""
        return self.patient_list

    @staticmethod
    def _drop_excluded(subject_ids, excluded):
        """Return *subject_ids* without the IDs in *excluded*, order preserved.

        IDs listed in *excluded* that do not occur are ignored.
        """
        banned = set(excluded)
        return [sid for sid in subject_ids if sid not in banned]

    @staticmethod
    def _metadata_path_for(pid):
        """Return the metadata CSV path that belongs to subject *pid*."""
        if "LUNG1" in pid:
            return 'metadata/metadata_lung1.csv'
        if "interobs" in pid:
            return 'metadata/metadata_md.csv'
        return 'metadata/metadata_test_retest.csv'

    def _load_metadata(self, csv_path):
        """Read *csv_path* once per instance and reuse the DataFrame afterwards."""
        # setdefault on __dict__ keeps this working on instances created
        # without __init__.
        cache = self.__dict__.setdefault("_metadata_cache", {})
        if csv_path not in cache:
            cache[csv_path] = pd.read_csv(csv_path, sep=',', index_col=False)
        return cache[csv_path]

    def _subject_ids(self, csv_path):
        """Return the unique 'Subject ID' values of *csv_path* as a list."""
        return self._load_metadata(csv_path)["Subject ID"].unique().tolist()

    @staticmethod
    def _prepare_slices(image, mask):
        """Align, filter and rescale one volume / mask pair.

        Both arrays are cut to their common number of slices, slices whose
        mask is entirely zero are dropped, the image is shifted and scaled to
        0..255 and the mask is multiplied by 255.

        Returns:
            tuple: ``(image, mask)`` float64 arrays.  When no slice survives
            the filter both arrays have length 0; a constant image becomes all
            zeros.
        """
        num_slices_img, num_slices_mask = image.shape[0], mask.shape[0]
        if num_slices_img != num_slices_mask:
            print("Warning: The number of slices in the image and mask are not the same!")

        # Use the common prefix to avoid indexing past either array.
        min_slices = min(num_slices_img, num_slices_mask)
        image, mask = image[:min_slices], mask[:min_slices]

        non_empty_slices = np.any(mask, axis=(1, 2))
        image = image[non_empty_slices].astype(np.float64)
        mask = mask[non_empty_slices].astype(np.float64)

        # min/max are undefined for an empty array and a zero peak would
        # divide by zero, so both cases skip the rescaling.
        if image.size:
            image -= image.min()
            peak = image.max()
            if peak > 0:
                image = 255 * image / peak
        mask = 255 * mask
        return image, mask

    @staticmethod
    def _concat_as_tensor(arrays):
        """Convert NumPy arrays to tensors and concatenate them along axis 0."""
        return torch.cat(
            [torch.from_numpy(np.ascontiguousarray(a)) for a in arrays], dim=0)

    @staticmethod
    def _slice_file_name(pid, sub_type, seg_type, idx):
        """Return the PNG file name used for slice *idx* of one patient entry."""
        if "RIDER" in pid:
            return f"{pid}_{sub_type}_{idx}.png"
        if "interobs" in pid:
            # e.g. "GTV-1vis-3" -> "3"
            _, _, seg_num = seg_type.split("-")
            return f"{pid}_md_{seg_num}_{idx}.png"
        return f"{pid}_{idx}.png"

    def load_array(self, pid, sub_type, seg_type):
        """Load the annotated slices of one patient entry.

        Args:
            pid: Subject ID; it also selects the metadata file.
            sub_type: "None" for a plain study, otherwise the sub type passed
                to the test/retest readers.
            seg_type: Segment description to extract.

        Returns:
            tuple: ``(image, mask)`` NumPy float64 arrays as produced by
            ``_prepare_slices``.
        """
        ri = RadioImage()

        print("Patient ID : ", pid)
        print("Subject Type : ", sub_type)
        print("Segment Type : ", seg_type)

        metadata = self._load_metadata(self._metadata_path_for(pid))
        subject = metadata[metadata['Subject ID'] == pid]

        if sub_type == "None":
            img = ri.read_dicom_ct(subject)
            msk = ri.read_dicom_seg(subject, seg_type)
        else:
            img = ri.read_dicom_test_retest_ct(subject, sub_type)
            msk = ri.read_dicom_test_retest_seg(subject, sub_type, seg_type)

        image = sitk.GetArrayFromImage(img)
        mask = sitk.GetArrayFromImage(msk)
        print("Shape of image : ", image.shape)
        print("Shape of mask : ", mask.shape)

        return self._prepare_slices(image, mask)

    def entire_data_to_tensor(self, p_list):
        """Load every entry of *p_list* and stack all slices into two tensors.

        ``load_array`` returns NumPy arrays, so they are converted to tensors
        before concatenation, and concatenation happens once at the end
        instead of once per patient.

        Args:
            p_list: Non-empty sequence of patient entry strings.

        Returns:
            tuple: ``(image, mask)`` tensors concatenated along axis 0.

        Raises:
            IndexError: when *p_list* is empty.
        """
        first, remaining = p_list[0], p_list[1:]

        pid, sub_type, seg_type = first.split("/")
        image, mask = self.load_array(pid, sub_type, seg_type)
        print(image.shape)
        print(mask.shape)
        print("After")

        images, masks = [image], [mask]
        for patient in remaining:
            print(patient)
            pid, sub_type, seg_type = patient.split("/")
            temp_image, temp_mask = self.load_array(pid, sub_type, seg_type)
            images.append(temp_image)
            masks.append(temp_mask)

        image = self._concat_as_tensor(images)
        mask = self._concat_as_tensor(masks)
        print(image.shape)
        print(mask.shape)
        return image, mask

    def entire_data_to_disk(self, p_list, path=None):
        """Write every annotated slice of *p_list* as PNG files.

        Images go to ``<path>/images`` and masks to ``<path>/masks`` under the
        same file name.  Both directories are created if missing.

        Args:
            p_list: Sequence of patient entry strings.
            path: Output root; defaults to ``DEFAULT_OUTPUT_DIR``.
        """
        if path is None:
            path = self.DEFAULT_OUTPUT_DIR
        images_dir = os.path.join(path, "images")
        masks_dir = os.path.join(path, "masks")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        for patient in p_list:
            pid, sub_type, seg_type = patient.split("/")
            image, mask = self.load_array(pid, sub_type, seg_type)
            slice_pairs = tqdm(zip(image, mask), total=len(image))
            for idx, (image_slice, mask_slice) in enumerate(slice_pairs):
                file_name = self._slice_file_name(pid, sub_type, seg_type, idx)
                image_path = os.path.join(images_dir, file_name)
                mask_path = os.path.join(masks_dir, file_name)
                # cv2.imwrite reports failure through its return value only.
                if not cv2.imwrite(image_path, image_slice):
                    print(f"Warning: could not write {image_path}")
                if not cv2.imwrite(mask_path, mask_slice):
                    print(f"Warning: could not write {mask_path}")


if __name__ == "__main__":
    # Smoke test: split the LUNG1 patients into train / validation / test,
    # build the lazy slice datasets and pull one training batch.

    # Alternative entry point that exports all cohorts through DataPreprocess:
    # datapreprocess = DataPreprocess()
    # print("Final List")
    # patient_list = datapreprocess.get_data()
    # print(len(patient_list))
    # print(patient_list)

    metadata_lung1 = pd.read_csv('metadata/metadata_lung1.csv', sep=',', index_col=False)

    # Filtering by membership does not fail when an excluded ID is absent
    # from the metadata.
    excluded_patients = {'LUNG1-128'}
    patient_list_lung1 = [
        patient for patient in metadata_lung1["Subject ID"].unique().tolist()
        if patient not in excluded_patients
    ]
    print(patient_list_lung1)

    # 10% of all patients for validation, then 10% of the remainder for testing.
    train_patient, valid_patient = train_test_split(patient_list_lung1, test_size=0.1, random_state=42)
    train_patient, test_patient = train_test_split(train_patient, test_size=0.1, random_state=42)
    print("Number of Total Patients : ", len(patient_list_lung1))
    print("Number of Patients for Training : ", len(train_patient))
    print("Number of Patients for Validation : ", len(valid_patient))
    print("Number of Patients for Testing : ", len(test_patient))

    batch_size = 2
    # os.cpu_count() may return None; fall back to loading in the main process.
    num_workers = os.cpu_count() or 0

    # PatientDatasetAllInOneTensor takes the same constructor arguments and
    # can be substituted here to keep all slices in memory instead.
    print("Training Loading...")
    train_dataset = PatientDataset2DUNet(train_patient, metadata_lung1, train=True)
    print("Valid Loading...")
    valid_dataset = PatientDataset2DUNet(valid_patient, metadata_lung1, train=False)
    print("Testing Loading...")
    test_dataset = PatientDataset2DUNet(test_patient, metadata_lung1, train=False)

    # Only the training loader shuffles; evaluation order stays fixed.
    loader_options = {"batch_size": batch_size, "num_workers": num_workers}
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_options)
    valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, **loader_options)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_options)

    # Fetch one batch
    for images, masks in train_loader:
        print(f"Batch Image Shape: {images.shape}")  # Expected: [batch_size, 1, H, W]
        print(f"Batch Mask Shape: {masks.shape}")  # Expected: [batch_size, 1, H, W]
        break

    print(f"Total images in Training Dataset: {len(train_dataset)}")
    print(f"Total images in Valid Dataset: {len(valid_dataset)}")
    print(f"Total images in Testing Dataset: {len(test_dataset)}")


    # # datapreprocess.load_array("LUNG1-172/None/GTV-1")
    # datapreprocess.entire_data_to_disk(patient_list)
    # np.random.seed(42)
    # p_list = np.random.choice(patient_list,60)
    # print(len(p_list))
    # image,mask = datapreprocess.entire_data_to_tensor(p_list)
    # torch.save(image, "F:\\Idiot Developer\\radioGenomic\\files\\segmentation\\image.pt")
    # torch.save(mask, "F:\\Idiot Developer\\radioGenomic\\files\\segmentation\\mask.pt")
    # image = torch.load("F:\\Idiot Developer\\radioGenomic\\files\\segmentation\\image.pt")
    # mask = torch.load("F:\\Idiot Developer\\radioGenomic\\files\\segmentation\\mask.pt")
    # print(image.shape)
    # print(mask.shape)
    # print(len(image))
    # patient_dataset = PatientDataset(image,mask)
    # train_dataset, test_dataset = torch.utils.data.random_split(patient_dataset, [0.8, 0.2])
    # print(train_dataset[0][0].shape)
    # print(train_dataset[0][1].shape)
    # print( len(train_dataset))
    # print(len(test_dataset))
    # print(train_dataset[0][0])
    # print(test_dataset[0][0])
    # plt.imshow(train_dataset[0][0],cmap="gray")
    # plt.show()
    # plt.imshow(train_dataset[0][1],cmap="gray")
    # plt.show()
    # image, mask = dp. __getitem__(500)

    # print(type(dp))