In [2]:
from torch.utils.data import Dataset
from datasets import load_dataset

from pathlib import Path
import os

import lightning.pytorch as pl
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import cv2
import shutil
import numpy as np

from patchify import patchify

import nibabel as nib
import matplotlib.pyplot as plt

from nibabel.processing import resample_to_output

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class CTDataset(Dataset):
    """Paired ``(image, mask)`` dataset over CT scans or CT scan patches.

    Each entry of ``images_ct_scans`` may be either

    * a file path (``str`` / ``os.PathLike``): the image is read with OpenCV and
      converted to RGB, and the mask is read from the same path with its first
      ``images`` directory component replaced by ``labels``
      (``images_ct_masks`` is not consulted in this case), or
    * an in-memory array (e.g. a NumPy patch): it is used as-is and paired with
      ``images_ct_masks[idx]``.

    Args:
        images_ct_scans: Sequence of image file paths or image arrays.
        images_ct_masks: Sequence of mask arrays aligned index-by-index with
            ``images_ct_scans`` (only used for array entries).
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` keys, as an
            ``albumentations.Compose`` does. Image and mask go through the same
            call so spatial operations (resize, flips, crops) stay aligned.
    """

    def __init__(self, images_ct_scans, images_ct_masks, transform=None):
        self.images_ct_scans = images_ct_scans
        self.images_ct_masks = images_ct_masks
        self.transform = transform

    def __len__(self):
        return len(self.images_ct_scans)

    @staticmethod
    def _read_rgb(path):
        """Read an image file with OpenCV and return it as an RGB array."""
        image = cv2.imread(str(path))
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _mask_path(image_file):
        """Return ``image_file`` with its first ``images`` path component replaced by ``labels``."""
        parts = list(Path(image_file).parts)
        try:
            parts[parts.index('images')] = 'labels'
        except ValueError:
            raise ValueError(f"No 'images' directory component in path: {image_file}") from None
        return os.path.join(*parts)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx``, transformed jointly when a transform is set."""
        item = self.images_ct_scans[idx]
        if isinstance(item, (str, os.PathLike)):
            image = self._read_rgb(item)
            mask = self._read_rgb(self._mask_path(item))
        else:
            # Pre-extracted array (e.g. a patch built by CTDataModule).
            image = item
            mask = self.images_ct_masks[idx]

        if self.transform is not None:
            # Transform image and mask together so geometric ops keep them aligned.
            transformed = self.transform(image=image, mask=mask)
            image, mask = transformed["image"], transformed["mask"]
        return image, mask

In [12]:
class CTDataModule(pl.LightningDataModule):
    """Lightning data module that cuts AeroPath CT volumes into 3-D patches.

    ``prepare_data`` loads the ``andreped/AeroPath`` dataset with
    ``datasets.load_dataset``. It then takes the first ``max_scans`` records of
    its ``test`` split and resamples each CT volume and airway mask with nibabel's
    ``resample_to_output``. Each volume is cut into non-overlapping patches of
    ``(patch_fraction_xy * X, patch_fraction_xy * Y, patch_depth)`` voxels.
    ``setup`` splits the patches 70 / 15 / 15 into train / val / test
    ``CTDataset`` objects.

    Args:
        max_scans: Number of records from the start of the ``test`` split to
            process; ``None`` processes all of them. Defaults to 3.
        batch_size: Batch size used by every dataloader.
        patch_fraction_xy: Patch edge length along X and Y as a fraction of the
            resampled volume's extent on that axis.
        patch_depth: Fixed number of slices (Z) per patch.
    """

    def __init__(self, max_scans=3, batch_size=12, patch_fraction_xy=0.5, patch_depth=50):
        super().__init__()

        self.max_scans = max_scans
        self.batch_size = batch_size
        self.patch_fraction_xy = patch_fraction_xy
        self.patch_depth = patch_depth

        # NOTE(review): patches are int32 arrays of CT values (see _load_volume).
        # ToFloat(max_value=255) assumes 8-bit intensities, and A.Resize treats the
        # last (Z) axis as channels. Confirm this preprocessing is intended.
        self.augmentations = A.Compose([
            A.Resize(height=50, width=50),
            A.ToFloat(max_value=255, always_apply=True),
            ToTensorV2(),
        ])
        self.transforms = A.Compose([
            A.Resize(height=50, width=50),
            A.ToFloat(max_value=255, always_apply=True),
            ToTensorV2(),
        ])

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        self.ct_dataset = None

        self.all_ct_scan_patches = []
        self.all_ct_mask_patches = []

    def prepare_data(self):
        """Load the AeroPath dataset and extract scan/mask patches from it."""
        self.ct_dataset = load_dataset("andreped/AeroPath")
        print(self.ct_dataset)
        self._extract_patches()

    @staticmethod
    def _load_volume(path, order):
        """Load a NIfTI file, resample it with ``resample_to_output`` and return its data as int32."""
        image = resample_to_output(nib.load(path), order=order)
        return image.get_fdata().astype("int32")

    def _patch_size(self, scan_shape):
        """Return the ``(x, y, z)`` patch size for a resampled volume of shape ``scan_shape``."""
        patch_size_x = min(int(round(scan_shape[0] * self.patch_fraction_xy)), scan_shape[0])
        patch_size_y = min(int(round(scan_shape[1] * self.patch_fraction_xy)), scan_shape[1])
        return (patch_size_x, patch_size_y, self.patch_depth)

    def _extract_patches(self):
        """Rebuild ``all_ct_scan_patches`` / ``all_ct_mask_patches`` from ``self.ct_dataset``."""
        # Reset first so calling this twice does not duplicate patches.
        self.all_ct_scan_patches = []
        self.all_ct_mask_patches = []

        records = self.ct_dataset['test']
        n_scans = len(records)
        if self.max_scans is not None:
            n_scans = min(n_scans, self.max_scans)

        for idx in range(n_scans):
            record = records[idx]

            ct_data_scan = self._load_volume(record["ct"], order=1)
            # Nearest-neighbour (order=0) keeps mask labels exact; linear
            # interpolation followed by the int32 cast would erode label borders.
            ct_data_mask = self._load_volume(record["airways"], order=0)
            if ct_data_mask.shape != ct_data_scan.shape:
                raise ValueError(
                    f"Scan {idx}: mask shape {ct_data_mask.shape} does not match "
                    f"CT shape {ct_data_scan.shape} after resampling"
                )

            patch_size = self._patch_size(ct_data_scan.shape)
            print(patch_size)

            # Non-overlapping patches: step equals the patch size.
            patches_scan = patchify(ct_data_scan, patch_size, step=patch_size)
            patches_mask = patchify(ct_data_mask, patch_size, step=patch_size)

            # patchify yields (n_i, n_j, n_k, *patch_size); flattening the grid axes
            # in C order visits patches in the same i, j, k order as a triple loop.
            self.all_ct_scan_patches.extend(patches_scan.reshape(-1, *patch_size))
            self.all_ct_mask_patches.extend(patches_mask.reshape(-1, *patch_size))

    @staticmethod
    def _take(items, indices):
        """Select ``items[i]`` for each ``i`` in ``indices`` (works for plain lists)."""
        return [items[i] for i in indices]

    def setup(self, stage=None):
        """Split the patches 70 / 15 / 15 into train / val / test datasets.

        Args:
            stage: Passed by the Lightning ``Trainer`` ('fit', 'validate', ...);
                accepted for compatibility and otherwise ignored.
        """
        if not self.all_ct_scan_patches:
            # Lightning runs prepare_data on a single process, so state built there
            # can be missing here; rebuild the patches on this process if needed.
            if self.ct_dataset is None:
                self.ct_dataset = load_dataset("andreped/AeroPath")
            self._extract_patches()

        all_indices = np.arange(len(self.all_ct_scan_patches))

        train_index, val_index = train_test_split(all_indices, test_size=0.3, random_state=42)
        val_index, test_index = train_test_split(val_index, test_size=0.5, random_state=42)

        scans, masks = self.all_ct_scan_patches, self.all_ct_mask_patches
        self.train_dataset = CTDataset(
            self._take(scans, train_index), self._take(masks, train_index),
            transform=self.augmentations,
        )
        self.val_dataset = CTDataset(
            self._take(scans, val_index), self._take(masks, val_index),
            transform=self.transforms,
        )
        self.test_dataset = CTDataset(
            self._take(scans, test_index), self._take(masks, test_index),
            transform=self.transforms,
        )

    def train_dataloader(self):
        # Reshuffle training patches every epoch.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

In [13]:
# Build the AeroPath data module, then run Lightning's two data hooks by hand:
# prepare_data() loads and patches the scans, setup() builds the splits.
datamodule = CTDataModule()
for data_hook in (datamodule.prepare_data, datamodule.setup):
    data_hook()

DatasetDict({
    test: Dataset({
        features: ['ct', 'airways', 'lungs'],
        num_rows: 27
    })
})
(194, 194, 50)
(190, 190, 50)
(190, 190, 50)
Indeksy zbioru walidacyjnego: [34 12  4 44 36 25 48  8 37]
Indeksy zbioru testowego: [ 0  5 46 52 13 31  6 17  3]
Indeksy zbioru treningowego: [41 19 30 49 50 54 15  9 27 26 16 24 33 55 40 11 32 56 43 29 53  1 21  2
 45 39 35 23 47 10 22 18 57 20  7 42 14 28 51 38]
