<a href="https://colab.research.google.com/github/sajidcsecu/radioGenomic/blob/main/NewUnetDataPreparation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import pandas as pd
import numpy as np
from NewRadioImage import RadioImage
import SimpleITK as sitk
from torch.utils.data import Dataset
from itertools import chain
import torch
import matplotlib.pyplot as plt
import os
from glob import glob
from tqdm import tqdm
from albumentations import HorizontalFlip, VerticalFlip, Rotate
import cv2



# class PatientDataset(Dataset):
#     def __init__(self, image, mask):
#         self.image = image
#         self.mask = mask
#         self.n_samples = len(self.image)
#
#     def __getitem__(self, index):
#         img = torch.unsqueeze(self.image[index,:,:], 0)
#         msk = torch.unsqueeze(self.mask[index,:,:], 0)
#         return img,msk
#
#     def __len__(self):
#         return self.n_samples

class PatientDataset(Dataset):
    """Dataset of paired grayscale image / mask files that are read lazily.

    Each item is an ``(image, mask)`` pair of ``float32`` tensors carrying a
    leading channel axis of size 1. Image and mask are each divided by their
    own maximum value.
    """

    def __init__(self,images_path,masks_path):
        """Store the file lists.

        Args:
            images_path: Indexable sequence of image file paths.
            masks_path: Indexable sequence of mask file paths, aligned
                index-by-index with ``images_path``.
        """
        self.images_path = images_path
        self.masks_path = masks_path
        # The dataset length follows the image list.
        self.n_samples = len(images_path)

    @staticmethod
    def _read_grayscale(path):
        """Read ``path`` as a single-channel image.

        Raises:
            FileNotFoundError: If OpenCV could not read the file.
                ``cv2.imread`` signals failure by returning ``None`` rather
                than raising, which would otherwise surface later as an
                unrelated ``AttributeError``.
        """
        pixels = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if pixels is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return pixels

    @staticmethod
    def _to_unit_tensor(array):
        """Divide ``array`` by its maximum and return it as a float32 tensor.

        The result gains a leading axis of size 1. An array whose maximum is
        0 (for example an empty mask) is returned as zeros instead of being
        divided by zero, which would fill the tensor with NaN.
        """
        peak = array.max()
        scaled = array / peak if peak > 0 else array
        scaled = np.expand_dims(scaled.astype(np.float32), axis=0)
        return torch.from_numpy(scaled)

    def __getitem__(self,index):
        """Return the ``(image, mask)`` tensor pair stored at ``index``.

        Raises:
            FileNotFoundError: If either file cannot be read.
        """
        # Reading image
        image = self._to_unit_tensor(self._read_grayscale(self.images_path[index]))

        # Reading mask
        mask = self._to_unit_tensor(self._read_grayscale(self.masks_path[index]))

        return image,mask

    def __len__(self):
        """Return the number of samples (the length of the image list)."""
        return self.n_samples




class DataPreprocess:
    """Builds the patient/segmentation entry list and exports slices from it.

    Every entry is a string of the form
    ``"<Subject ID>/<sub type>/<segment type>"``. ``sub type`` is the literal
    ``"None"`` for the LUNG1 and interobs entries and ``"TEST"`` / ``"RETEST"``
    for the RIDER entries.
    """

    LUNG1_METADATA = 'metadata/metadata_lung1.csv'
    MD_METADATA = 'metadata/metadata_md.csv'
    TEST_RETEST_METADATA = 'metadata/metadata_test_retest.csv'
    # Default export directory of entire_data_to_disk (kept for backward compatibility).
    DEFAULT_OUTPUT_DIR = "F:\\Idiot Developer\\radioGenomic\\Segementation\\data\\full data"

    def __init__(self):
        """Read the three metadata CSV files and build ``self.patient_list``.

        The subject IDs listed below are excluded. An excluded ID that is not
        present in its CSV is simply ignored.
        """
        self._metadata_cache = {}
        lung1_ids = self._patient_ids(self.LUNG1_METADATA, ['LUNG1-128'])
        md_ids = self._patient_ids(self.MD_METADATA, ['interobs09'])
        rider_ids = self._patient_ids(
            self.TEST_RETEST_METADATA,
            ['RIDER-2283289298', 'RIDER-5195703382', 'RIDER-8509201188'],
        )

        # Order: LUNG1, RIDER test, RIDER retest, then the five interobs
        # segmentations (GTV-1vis-1 .. GTV-1vis-5) of every interobs subject.
        entries = [f"{p}/None/GTV-1" for p in lung1_ids]
        entries += [f"{p}/TEST/GTVp_test_man" for p in rider_ids]
        entries += [f"{p}/RETEST/GTVp_retest_man" for p in rider_ids]
        for seg_num in range(1, 6):
            entries += [f"{p}/None/GTV-1vis-{seg_num}" for p in md_ids]
        self.patient_list = entries

    def _read_metadata(self, csv_path):
        """Return the metadata table at ``csv_path``, reading each file only once."""
        # setdefault keeps this working on instances created without __init__.
        cache = self.__dict__.setdefault("_metadata_cache", {})
        if csv_path not in cache:
            cache[csv_path] = pd.read_csv(csv_path, sep=',', index_col=False)
        return cache[csv_path]

    def _patient_ids(self, csv_path, excluded):
        """Return the unique subject IDs of ``csv_path`` in file order, minus ``excluded``."""
        skip = set(excluded)
        subject_ids = self._read_metadata(csv_path)["Subject ID"].unique().tolist()
        return [subject_id for subject_id in subject_ids if subject_id not in skip]

    def get_data(self):
        """Return the list of patient entries built by ``__init__``."""
        return self.patient_list

    def load_array(self, pid, sub_type, seg_type):
        """Load one subject's CT volume and segmentation as NumPy arrays.

        Only slices (index along axis 0) whose mask has at least one non-zero
        pixel are kept. The kept image slices are rescaled together to the
        range 0..255 and the mask is multiplied by 255.

        Args:
            pid: Subject ID. Selects the metadata file: IDs containing
                ``"LUNG1"`` or ``"interobs"`` use their own CSV, everything
                else uses the test/retest CSV.
            sub_type: ``"None"`` to use the plain DICOM readers of
                ``RadioImage``; any other value is passed to its test/retest
                readers.
            seg_type: Segment name passed to the ``RadioImage`` segmentation
                reader.

        Returns:
            ``(image, mask)`` as ``float64`` arrays with the same number of
            slices. Both have zero slices when no mask slice is non-empty.
        """
        ri = RadioImage()

        print("Patient ID : ", pid)
        print("Subject Type : ", sub_type)
        print("Segment Type : ", seg_type)
        if "LUNG1" in pid:
            csv_path = self.LUNG1_METADATA
        elif "interobs" in pid:
            csv_path = self.MD_METADATA
        else:
            csv_path = self.TEST_RETEST_METADATA
        metadata = self._read_metadata(csv_path)
        subject = metadata[metadata['Subject ID'] == pid]

        if sub_type == "None":
            img = ri.read_dicom_ct(subject)
            msk = ri.read_dicom_seg(subject, seg_type)
        else:
            img = ri.read_dicom_test_retest_ct(subject, sub_type)
            msk = ri.read_dicom_test_retest_seg(subject, sub_type, seg_type)

        image = sitk.GetArrayFromImage(img)
        mask = sitk.GetArrayFromImage(msk)
        print("Shape of image : ", image.shape)
        print("Shape of mask : ", mask.shape)

        # Ensure the number of slices are the same
        num_slices_img, num_slices_mask = image.shape[0], mask.shape[0]
        if num_slices_img != num_slices_mask:
            print("Warning: The number of slices in the image and mask are not the same!")

        # Truncate both to the shorter stack to avoid indexing errors.
        min_slices = min(num_slices_img, num_slices_mask)
        image = image[:min_slices]
        mask = mask[:min_slices]

        # Keep only slices whose mask contains at least one non-zero pixel.
        non_empty_slices = np.any(mask, axis=(1, 2))
        image = image[non_empty_slices].astype(np.float64)
        mask = mask[non_empty_slices].astype(np.float64)

        # min()/max() raise on an empty array, so skip scaling when no slice
        # survived the filter.
        if image.size:
            image -= image.min()
            peak = image.max()
            # A constant-intensity volume has peak 0; leave it as zeros
            # instead of dividing by zero.
            if peak > 0:
                image = 255 * image / peak
        mask = 255 * mask
        return image, mask

    def entire_data_to_tensor(self, p_list):
        """Load every entry of ``p_list`` and concatenate the slices into tensors.

        Args:
            p_list: Non-empty sequence of ``"pid/sub_type/seg_type"`` strings.

        Returns:
            ``(image, mask)`` tensors joined along axis 0. ``torch.cat``
            requires every entry to share the same slice height and width.

        Raises:
            ValueError: If ``p_list`` is empty.
        """
        if len(p_list) == 0:
            raise ValueError("p_list must contain at least one patient entry")

        images, masks = [], []
        for patient in p_list:
            print(patient)
            pid, sub_type, seg_type = patient.split("/")
            temp_image, temp_mask = self.load_array(pid, sub_type, seg_type)
            # load_array returns NumPy arrays; torch.cat only accepts tensors.
            images.append(torch.from_numpy(temp_image))
            masks.append(torch.from_numpy(temp_mask))

        # One concatenation at the end avoids re-copying the growing tensor
        # on every iteration.
        image = torch.cat(images, dim=0)
        mask = torch.cat(masks, dim=0)
        print(image.shape)
        print(mask.shape)

        return image, mask

    def entire_data_to_disk(self,p_list, path=None):
        """Write every slice of every entry in ``p_list`` as PNG files.

        Images go to ``<path>/images`` and masks to ``<path>/masks`` under the
        same file name. Both folders are created when missing.

        Args:
            p_list: Iterable of ``"pid/sub_type/seg_type"`` strings.
            path: Output root directory. Defaults to ``DEFAULT_OUTPUT_DIR``.

        Raises:
            OSError: If OpenCV reports that a file could not be written.
        """
        if path is None:
            path = self.DEFAULT_OUTPUT_DIR
        images_dir = os.path.join(path, "images")
        masks_dir = os.path.join(path, "masks")
        # cv2.imwrite does not create missing folders.
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        for patient in p_list:
            pid, sub_type, seg_type = patient.split("/")
            image,mask = self.load_array(pid, sub_type, seg_type)

            # The file-name prefix depends only on the entry, not on the slice.
            if "RIDER" in pid :
                prefix = f"{pid}_{sub_type}"
            elif "interobs" in pid:
                _, _, seg_num = seg_type.split("-")
                prefix = f"{pid}_md_{seg_num}"
            else:
                prefix = pid

            slices = tqdm(zip(image, mask), total=len(image))
            for idx, (image_slice, mask_slice) in enumerate(slices):
                file_name = f"{prefix}_{idx}.png"
                image_path = os.path.join(images_dir, file_name)
                mask_path = os.path.join(masks_dir, file_name)
                # imwrite returns False on failure instead of raising.
                if not cv2.imwrite(image_path, image_slice):
                    raise OSError(f"Could not write image file: {image_path}")
                if not cv2.imwrite(mask_path, mask_slice):
                    raise OSError(f"Could not write mask file: {mask_path}")


if __name__ == "__main__":
    # Script entry point: build the combined patient/segmentation entry list,
    # print it, then export the slices of every entry to disk as PNG files.
    datapreprocess = DataPreprocess()
    print("Final List")
    patient_list = datapreprocess.get_data()
    # Print the number of entries followed by the entries themselves.
    print(len(patient_list))
    print(patient_list)
    # Example of loading a single entry; load_array takes pid, sub_type and
    # seg_type as three separate arguments:
    # datapreprocess.load_array("LUNG1-172", "None", "GTV-1")
    datapreprocess.entire_data_to_disk(patient_list)
    # np.random.seed(42)
    # p_list = np.random.choice(patient_list,60)
    # print(len(p_list))
    # image,mask = datapreprocess.entire_data_to_tensor(p_list)
    # torch.save(image, "F:\\Idiot Developer\\radioGenomic\\files\\segmentation\\image.pt")
    # torch.save(mask, "F:\\Idiot Developer\\radioGenomic\\files\\segmentation\\mask.pt")
    # image = torch.load("F:\\Idiot Developer\\radioGenomic\\files\\segmentation\\image.pt")
    # mask = torch.load("F:\\Idiot Developer\\radioGenomic\\files\\segmentation\\mask.pt")
    # print(image.shape)
    # print(mask.shape)
    # print(len(image))
    # patient_dataset = PatientDataset(image,mask)
    # train_dataset, test_dataset = torch.utils.data.random_split(patient_dataset, [0.8, 0.2])
    # print(train_dataset[0][0].shape)
    # print(train_dataset[0][1].shape)
    # print( len(train_dataset))
    # print(len(test_dataset))
    # print(train_dataset[0][0])
    # print(test_dataset[0][0])
    # plt.imshow(train_dataset[0][0],cmap="gray")
    # plt.show()
    # plt.imshow(train_dataset[0][1],cmap="gray")
    # plt.show()
    # image, mask = dp. __getitem__(500)

    # print(type(dp))