In [2]:
from torch.utils.data import Dataset
from datasets import load_dataset

from pathlib import Path
import os

import lightning.pytorch as pl
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import cv2
import shutil
import numpy as np

from patchify import patchify

import nibabel as nib
import matplotlib.pyplot as plt

from nibabel.processing import resample_to_output
from concurrent.futures import ThreadPoolExecutor
import concurrent
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
print("Wersja CUDNN:", torch.backends.cudnn.version())

Wersja CUDNN: 8902


In [13]:
class CTDataset(Dataset):
    def __init__(self, images_filepaths, transform=None):
        self.images_filepaths = images_filepaths
        self.transform = transform

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        image_filepath = self.images_filepaths[idx]
        image_file = np.load(str(image_filepath))

        path_elements = list(Path(image_filepath).parts)
        print(path_elements)
        index = path_elements.index('scans')
        path_elements[index] = 'airways'

        mask_filepath = os.path.join(*path_elements)
        mask_file = np.load(str(mask_filepath))

        if self.transform is not None:
            transformed_images = []
            transformed_masks = []
            for i in range(0, image_file.shape[-1]):
                # print(image_file[..., i].shape)
                # print(i)
                # print("image_file[..., i]", image_file[..., i])
                image_slice = image_file[..., i]

                # print("type: ", type(image_slice))
                # print("image_slice: ", image_slice)
                # print("image_slice shape: ", image_slice.shape)
                mask_slice = mask_file[..., i]

                image_slice = np.stack([image_slice] * 3, axis=-1)
                
                image_slice = image_slice.astype(np.int16)
                mask_slice = mask_slice.astype(np.uint8)

                # print("max: ", np.max(image_slice))
                # print("min: ", np.min(image_slice))
                # print("dtype: ", image_slice[0].dtype)

                transformed = self.transform(image=image_slice, mask=mask_slice)

                transformed_images.append(transformed["image"])
                transformed_masks.append(transformed["mask"])

            image_file = transformed_images
            mask_file = transformed_masks
        return image_file, mask_file

In [20]:
class CTDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

        self.augmentations = A.Compose([
        A.ToFloat(max_value=1024+400, always_apply=True),
        A.Resize(height=64, width=64),
        # A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        # A.RandomBrightnessContrast(p=0.5),
        # A.Normalize(mean=[-1024/400], std=[1/400], always_apply=True),
        ToTensorV2()
        ])
        self.transforms = A.Compose([
        A.ToFloat(max_value=1024+400, always_apply=True),
        A.Resize(height=64, width=64),
        # A.Normalize(mean=[-1024/400], std=[1/400], always_apply=True),
        ToTensorV2(),
        ])

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        self.ct_dataset= None

        self.path_to_file = '/home/pawel/Documents/RISA/3D_segmentation/dataset'


    def prepare_data(self):
        if os.path.exists(self.path_to_file):
            print("Path exists")
            images_paths = sorted(Path(self.path_to_file).rglob('*.npy'))
            print(images_paths)
            for image_path in images_paths:
                image = np.load(str(image_path))

                if image is None:
                    print("Unlink image: ", image_path)
                    image_path.unlink()
        else:
            print("Path does not exist")


    def setup(self):
        # Split the data and assign datasets for use in dataloaders

        paths = sorted(Path(os.path.join(self.path_to_file, 'scans')).glob('*.npy'))

        train_paths, val_paths = train_test_split(paths, test_size = 0.3, random_state=42)
        val_paths, test_paths = train_test_split(val_paths, test_size = 0.5, random_state=42)

        # print('Typ 1:', self.all_ct_scan_patches)
        # print('Typ 2:', type(self.all_ct_mask_patches))
        print('Indeksy zbioru walidacyjnego:', len(val_paths), val_paths)
        print('Indeksy zbioru testowego:', len(test_paths), test_paths)
        print('Indeksy zbioru treningowego:', len(train_paths), train_paths)

        self.train_dataset = CTDataset(train_paths, transform=self.augmentations)
        self.val_dataset = CTDataset(val_paths, transform=self.transforms)
        self.test_dataset = CTDataset(test_paths, transform=self.transforms)

        data, mask = self.train_dataset.__getitem__(0)
        print("len1", len(data))
        print("len2", np.shape(data[0]))
        print("data.shape", np.shape(data))
        print("dtype: ", type(data))

        plt.imshow(data[0].reshape(50, 50), cmap='gray')
        plt.show()

        print("data: ", data[0])

        print(self.train_dataset)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=12)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=12)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=12)

In [21]:
datamodule = CTDataModule()
datamodule.prepare_data()
datamodule.setup()

Path exists
[PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_0.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_1.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_2.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_3.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_4.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_5.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_1_0.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_1_1.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_1_2.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_1_3.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airwa

RuntimeError: shape '[50, 50]' is invalid for input of size 12288