In [1]:
from torch.utils.data import Dataset
from datasets import load_dataset

from pathlib import Path
import os

import lightning.pytorch as pl
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import cv2
import shutil
import numpy as np

from patchify import patchify

import nibabel as nib
import matplotlib.pyplot as plt

from nibabel.processing import resample_to_output

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
class CTDataset(Dataset):
    def __init__(self, images_ct_scans, images_ct_masks, transform=None):
        self.images_ct_scans = images_ct_scans
        self.images_ct_masks = images_ct_masks
        self.transform = transform

    def __len__(self):
        return len(self.images_ct_scans)

    def __getitem__(self, idx):
        image_file = str(self.images_ct_scans[idx])
        print("image_file", image_file)
        image = cv2.imread(image_file)
        print("image", image)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # path_elements = list(Path(image_file).parts)
        # index = path_elements.index('images')
        # path_elements[index] = 'labels'

        mask_filepath = self.images_ct_masks[idx]
        mask = cv2.imread(str(mask_filepath))
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

        # if self.transform is not None:
        #     image = self.transform(image=image)["image"]
            # mask = self.transform(image=mask)["image"]
        return image, mask

In [15]:
class CTDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

        self.augmentations = A.Compose([
        A.Resize(height=50, width=50),
        # A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        # A.RandomBrightnessContrast(p=0.5),
        # A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        A.ToFloat(max_value=255, always_apply=True),
        ToTensorV2()
        ])
        self.transforms = A.Compose([
        A.Resize(height=50, width=50),
        # A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        A.ToFloat(max_value=255, always_apply=True),
        ToTensorV2(),
        ])

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        self.ct_dataset= None

        self.all_ct_scan_patches = []
        self.all_ct_mask_patches = []


    def prepare_data(self):
        self.ct_dataset = load_dataset("andreped/AeroPath")

        print('len dataset', self.ct_dataset['test'].num_rows)


        print(self.ct_dataset)
        for img in range(len(self.ct_dataset['test'])):
            # print(img)     #just stop here to see all file names printed
                    
            large_image = self.ct_dataset['test'][img]

            ct_image = nib.load(large_image["ct"])
            ct_image = resample_to_output(ct_image, order=1)
            ct_data_scan = ct_image.get_fdata().astype("int32")
            # print(ct_data_scan.shape)

            ct_mask = nib.load(large_image["airways"])
            ct_mask = resample_to_output(ct_mask, order=1)
            ct_data_mask = ct_mask.get_fdata().astype("int32")
                    
            percent_size_x = 0.5  # Przykładowo 20% w osi X
            percent_size_y = 0.5  # Przykładowo 20% w osi Y
            fixed_size_z = 50  # Stała liczba sliców w osi Z

            scan_shape = ct_data_scan.shape

            patch_size_x = min(int(round(scan_shape[0] * percent_size_x)), scan_shape[0])
            patch_size_y = min(int(round(scan_shape[1] * percent_size_y)), scan_shape[1])

            patch_size = (patch_size_x, patch_size_y, fixed_size_z)

            print(patch_size)

            step = patch_size

            patches_scan = patchify(ct_data_scan, patch_size, step=step)
            patches_mask = patchify(ct_data_mask, patch_size, step=step)

            # Pętle do iteracji przez otrzymane fragmenty danego skanu
            for i in range(patches_scan.shape[0]):
                for j in range(patches_scan.shape[1]):
                    for k in range(patches_scan.shape[2]):
                        patch_scan = patches_scan[i, j, k, :, :, :]
                        patch_mask = patches_mask[i, j, k, :, :, :]
                            
                        self.all_ct_scan_patches.append(patch_scan)
                        self.all_ct_mask_patches.append(patch_mask)

##################################################################### for testing purposes
            # print(len(all_img_patches))
            # fig, ax = plt.subplots(2, 2, figsize=(20, 12))
            # ax[0,0].cla()
            # ax[0,1].cla()
            # ax[1,0].cla()
            # ax[1,1].cla()
            # ax[0,0].imshow(all_img_patches[0][..., 1], cmap="gray")
            # ax[0,1].imshow(all_img_patches[patches.shape[2] + 1][..., 1], cmap="gray")
            # ax[1,0].imshow(all_img_patches[2*patches.shape[2] + 1][..., 1], cmap="gray")
            # ax[1,1].imshow(all_img_patches[3*patches.shape[2] + 1][..., 1], cmap="gray")

            # plt.show()

            # break
#####################################################################

        # images = np.array(all_img_patches)
        # images = np.expand_dims(images, -1)
                        
            if img == 1:
                break


    def setup(self):
        # Split the data and assign datasets for use in dataloaders

        all_indices = np.arange(self.ct_dataset['test'].num_rows)

        # self.all_ct_scan_patches = np.array([np.array(patch) for patch in self.all_ct_scan_patches])
        # self.all_ct_mask_patches = np.array([np.array(patch) for patch in self.all_ct_mask_patches])

        self.all_ct_scan_patches = np.array(self.all_ct_scan_patches, dtype=object)
        self.all_ct_mask_patches = np.array(self.all_ct_mask_patches, dtype=object)

        train_index, val_index = train_test_split(all_indices, test_size = 0.3, random_state=42)
        val_index, test_index = train_test_split(val_index, test_size = 0.5, random_state=42)

        # print('Typ 1:', self.all_ct_scan_patches)
        # print('Typ 2:', type(self.all_ct_mask_patches))
        print('Indeksy zbioru walidacyjnego:', val_index)
        print('Indeksy zbioru testowego:', test_index)
        print('Indeksy zbioru treningowego:', train_index)

        train_scans = self.all_ct_scan_patches[train_index]
        train_masks = self.all_ct_mask_patches[train_index]

        val_scans = self.all_ct_scan_patches[val_index]
        val_masks = self.all_ct_mask_patches[val_index]

        test_scans = self.all_ct_scan_patches[test_index]
        test_masks = self.all_ct_mask_patches[test_index]


        self.train_dataset = CTDataset(train_scans, train_masks, transform=self.augmentations)
        self.val_dataset = CTDataset(val_scans, val_masks, transform=self.transforms)
        self.test_dataset = CTDataset(test_scans, test_masks, transform=self.transforms)

        data = self.train_dataset.__getitem__(0)
        print(data)

        print(self.train_dataset)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=12)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=12)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=12)

In [16]:
datamodule = CTDataModule()
datamodule.prepare_data()
datamodule.setup()

len dataset 27
DatasetDict({
    test: Dataset({
        features: ['ct', 'airways', 'lungs'],
        num_rows: 27
    })
})
(194, 194, 50)
(190, 190, 50)
Indeksy zbioru walidacyjnego: [ 9  0 21 16]
Indeksy zbioru testowego: [17 13 11  8 12]
Indeksy zbioru treningowego: [24  1  4  5  2 15 22  3 25 23 18 26 20  7 10 14 19  6]
image_file [[[  398   383   310 ...   321   309   317]
  [  272   375   393 ...   309   297   288]
  [  181   299   350 ...   242   236   279]
  ...
  [-1005 -1004 -1007 ...  -993 -1011 -1003]
  [-1007 -1010 -1008 ...  -998 -1001 -1009]
  [    0     0     0 ...     0     0     0]]

 [[  491   452   375 ...   314   313   328]
  [  467   486   483 ...   309   276   278]
  [  382   418   368 ...   245   267   247]
  ...
  [ -997  -999 -1007 ...  -996  -997  -989]
  [-1009 -1011 -1013 ...  -996 -1001  -995]
  [    0     0     0 ...     0     0     0]]

 [[  442   460   384 ...   291   280   299]
  [  426   462   418 ...   267   241   258]
  [  394   451   379 ...   28

[ WARN:0@1710.625] global loadsave.cpp:248 findDecoder imread_('[[[  398   383   310 ...   321   309   317]
  [  272   375   393 ...   309   297   288]
  [  181   299   350 ...   242   236   279]
  ...
  [-1005 -1004 -1007 ...  -993 -1011 -1003]
  [-1007 -1010 -1008 ...  -998 -1001 -1009]
  [    0     0     0 ...     0     0     0]]

 [[  491   452   375 ...   314   313   328]
  [  467   486   483 ...   309   276   278]
  [  382   418   368 ...   245   267   247]
  ...
  [ -997  -999 -1007 ...  -996  -997  -989]
  [-1009 -1011 -1013 ...  -996 -1001  -995]
  [    0     0     0 ...     0     0     0]]

 [[  442   460   384 ...   291   280   299]
  [  426   462   418 ...   267   241   258]
  [  394   451   379 ...   288   262   246]
  ...
  [ -985  -992 -1002 ... -1006  -992  -998]
  [ -996 -1006 -1000 ... -1008  -997  -994]
  [    0     0     0 ...     0     0     0]]

 ...

 [[-1004 -1009 -1004 ...  -994 -1001  -990]
  [-1015  -996 -1004 ... -1003 -1004 -1013]
  [-1012  -998  -996 ... -

error: OpenCV(4.8.1) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'
