In [17]:
from torch.utils.data import Dataset
from datasets import load_dataset

from pathlib import Path
import os

import lightning.pytorch as pl
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import cv2
import shutil
import numpy as np

from patchify import patchify

import nibabel as nib
import matplotlib.pyplot as plt

from nibabel.processing import resample_to_output

import sys
sys.path.append('/home/pawel/Documents/RISA/magisterka/segmentation_models.pytorch.3d')

# print(sys.path)

import segmentation_models_pytorch as smp
import torch
import torchmetrics

# Dataset

In [3]:
class CTDataset(Dataset):
    def __init__(self, images_ct_scans, images_ct_masks, transform=None):
        self.images_ct_scans = images_ct_scans
        self.images_ct_masks = images_ct_masks
        self.transform = transform

    def __len__(self):
        return len(self.images_ct_scans)

    def __getitem__(self, idx):
        image_file = self.images_ct_scans[idx]

        mask_file = self.images_ct_masks[idx]

        print(image_file.shape[-1])

        if self.transform is not None:
            transformed_images = []
            transformed_masks = []
            for i in range(0, image_file.shape[-1]):
                image_slice = image_file[..., i]

                mask_slice = mask_file[..., i]
                
                image_slice = image_slice.astype(np.int16)
                mask_slice = mask_slice.astype(np.uint8)

                transformed = self.transform(image=image_slice, mask=mask_slice)

                transformed_images.append(transformed["image"])
                transformed_masks.append(transformed["mask"])

            image_file = transformed_images
            mask_file = transformed_masks
        return image_file, mask_file

# Data Module

In [37]:
class CTDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

        self.augmentations = A.Compose([
        # A.ToFloat(max_value=1024+400, always_apply=True),
        A.Resize(height=128, width=128),
        # A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        # A.RandomBrightnessContrast(p=0.5),
        # A.Normalize(mean=[-1024/400], std=[1/400], always_apply=True),
        A.ToFloat(max_value=255, always_apply=True),
        ToTensorV2()
        ])
        self.transforms = A.Compose([
        # A.ToFloat(max_value=1024+400, always_apply=True),
        A.Resize(height=128, width=128),
        # A.Normalize(mean=[-1024/400], std=[1/400], always_apply=True),
        A.ToFloat(max_value=255, always_apply=True),
        ToTensorV2(),
        ])

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        self.ct_dataset= None

        self.all_ct_scan_patches = []
        self.all_ct_mask_patches = []


    def prepare_data(self):
        self.ct_dataset = load_dataset("andreped/AeroPath")

        print('len dataset', self.ct_dataset['test'].num_rows)

        print(self.ct_dataset)

        for index, img in enumerate(range(len(self.ct_dataset['test']))):
            print(index)     #just stop here to see all file names printed
                    
            large_image = self.ct_dataset['test'][img]

            ct_image = nib.load(large_image["ct"])
            ct_image = resample_to_output(ct_image, order=1)
            ct_data_scan = ct_image.get_fdata().astype("int16")
            # ct_data_scan[ct_data_scan < -1024] = -1024
            # ct_data_scan[ct_data_scan > 400] = 400

            # ct_data_scan = ct_data_scan + 1024

            ct_data_scan = np.clip(ct_data_scan, -1024, 400)
            ct_data_scan += 1024

            ct_mask = nib.load(large_image["airways"])
            ct_mask = resample_to_output(ct_mask, order=1)
            ct_data_mask = ct_mask.get_fdata().astype("uint8")
                    
            percent_size_x = 0.5  # Przykładowo 20% w osi X
            percent_size_y = 0.5  # Przykładowo 20% w osi Y
            fixed_size_z = 50  # Stała liczba sliców w osi Z

            scan_shape = ct_data_scan.shape

            patch_size_x = min(int(round(scan_shape[0] * percent_size_x)), scan_shape[0])
            patch_size_y = min(int(round(scan_shape[1] * percent_size_y)), scan_shape[1])

            patch_size = (patch_size_x, patch_size_y, fixed_size_z)

            step = patch_size

            patches_scan = patchify(ct_data_scan, patch_size, step=step)
            patches_mask = patchify(ct_data_mask, patch_size, step=step)

            print("TEST TEST TEST")

            # Pętle do iteracji przez otrzymane fragmenty danego skanu
            for i in range(patches_scan.shape[0]):
                for j in range(patches_scan.shape[1]):
                    for k in range(patches_scan.shape[2]):
                        patch_scan = patches_scan[i, j, k, :, :, :]
                        patch_mask = patches_mask[i, j, k, :, :, :]
                            
                        self.all_ct_scan_patches.append(patch_scan)
                        self.all_ct_mask_patches.append(patch_mask)

        print("FINISHING PREPARE DATA")

    def setup(self):
        # Split the data and assign datasets for use in dataloaders

        all_indices = np.arange(len(self.all_ct_scan_patches))

        self.all_ct_scan_patches = np.array(self.all_ct_scan_patches, dtype=object)
        self.all_ct_mask_patches = np.array(self.all_ct_mask_patches, dtype=object)

        train_index, val_index = train_test_split(all_indices, test_size = 0.3, random_state=42)
        val_index, test_index = train_test_split(val_index, test_size = 0.5, random_state=42)

        train_scans = self.all_ct_scan_patches[train_index]
        train_masks = self.all_ct_mask_patches[train_index]

        val_scans = self.all_ct_scan_patches[val_index]
        val_masks = self.all_ct_mask_patches[val_index]

        test_scans = self.all_ct_scan_patches[test_index]
        test_masks = self.all_ct_mask_patches[test_index]

        self.train_dataset = CTDataset(train_scans, train_masks, transform=self.augmentations)
        self.val_dataset = CTDataset(val_scans, val_masks, transform=self.transforms)
        self.test_dataset = CTDataset(test_scans, test_masks, transform=self.transforms)

        print("FINISHING SETUP")

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=2)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=2)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=2)

# Create model

In [38]:
class AirWayModel(pl.LightningModule):

    def __init__(self, encoder_name: str):
        super().__init__()
        # self.model = smp.create_model(
        #     arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        # )
        self.model = smp.Unet_3D(
            encoder_name=encoder_name)

        self.loss_function = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

        metrics = torchmetrics.MetricCollection([
            torchmetrics.F1Score(task='BINARY'),
            torchmetrics.Precision(task='BINARY'),
            torchmetrics.Recall(task='BINARY')
        ])
        self.train_metrics = metrics.clone('train_')
        self.val_metrics = metrics.clone('val_')
        self.test_metrics = metrics.clone('test_')

        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        labels = torch.mean(labels, dim=1, keepdim=True)
        outputs = self(inputs)
        loss = self.loss_function(outputs, labels)

        # self.accuracy.update(outputs, labels)
        # self.log('train_loss', loss, prog_bar=True)
        # self.log('train_acc', self.accuracy, prog_bar=True)
        # self.log('val_acc', self.validation_step, prog_bar=True)

        self.train_metrics.update(outputs, labels)
        self.log('train_loss', loss, prog_bar=True)
        self.log_dict(self.train_metrics)

        # self.append("train/loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        labels = torch.mean(labels, dim=1, keepdim=True)
        outputs = self(inputs)
        loss = self.loss_function(outputs, labels)

        # self.accuracy.update(outputs, labels)

        # self.log('val_loss', loss, prog_bar=True)
        # self.log('val_acc', self.accuracy, prog_bar=True)

        self.val_metrics.update(outputs, labels)
        self.log('val_loss', loss, prog_bar=True)
        self.log_dict(self.val_metrics)

    def test_step(self, batch, batch_idx):
        labels = torch.mean(labels, dim=1, keepdim=True)
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.loss_function(outputs, labels)

        # self.accuracy.update(outputs, labels)

        # self.log('test_loss', loss)
        # self.log('test_acc', self.accuracy)

        self.test_metrics.update(outputs, labels)
        self.log('test_loss', loss, prog_bar=True)
        self.log_dict(self.test_metrics)

    def configure_optimizers(self):
        # Tym razem użyjmy optimizera Adam - uczenie powinno być szybsze
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [39]:
# segment_model_3D = AirWayModel(arch='unet_3d', encoder_name='tu-resnet10t', in_channels=1, out_classes=2, encoder_depth = 4, decoder_channels = (128, 64, 32, 16))
segment_model_3D = AirWayModel(encoder_name='resnet18')

data_module = CTDataModule()

# checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor = 'val_BinaryF1Score', mode='max', verbose = True)

# # logger = pl.loggers.NeptuneLogger(project="gawron.pawel1999/ZPO-Lab06", api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI0MDE0NTRlMy03NDY2LTQ0OWYtYTcxMS1jMDE5OThlYjIyZjQifQ==")

trainer = pl.Trainer(accelerator='gpu',
                     max_epochs=10)
trainer.fit(segment_model_3D, datamodule = data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


len dataset 27
DatasetDict({
    test: Dataset({
        features: ['ct', 'airways', 'lungs'],
        num_rows: 27
    })
})


TypeError: CTDataModule.process_image() takes 1 positional argument but 2 were given