In [32]:
from torch.utils.data import Dataset
from datasets import load_dataset

from pathlib import Path
import os
import torch.nn as nn

import lightning.pytorch as pl
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import cv2
import shutil
import numpy as np

from patchify import patchify

import nibabel as nib
import matplotlib.pyplot as plt

from nibabel.processing import resample_to_output

import sys
sys.path.append('/home/pawel/Documents/RISA/3D_segmentation/segmentation_models.pytorch.3d')

# print(sys.path)

import segmentation_models_pytorch as smp
import torch
import torchmetrics

# Dataset

In [262]:
class CTDataset(Dataset):
    def __init__(self, images_filepaths, transform=None):
        self.images_filepaths = images_filepaths
        self.transform = transform

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        image_filepath = self.images_filepaths[idx]
        image_file = np.load(str(image_filepath))

        path_elements = list(Path(image_filepath).parts)
        index = path_elements.index('scans')
        path_elements[index] = 'airways'

        mask_filepath = os.path.join(*path_elements)
        mask_file = np.load(str(mask_filepath))

        if self.transform is not None:
            transformed_images = []
            transformed_masks = []
            for i in range(0, image_file.shape[-1]):
                
                image_slice = image_file[..., i]
                mask_slice = mask_file[..., i]

                image_slice = np.stack([image_slice] * 3, axis=-1)
                
                # image_slice = image_slice.astype(np.int16)
                # mask_slice = mask_slice.astype(np.int16)

                # print(type(image_slice[0][0][0]), type(mask_slice[0][0]))

                transformed = self.transform(image=image_slice, mask=mask_slice)

                transformed_images.append(transformed["image"])
                transformed_masks.append(transformed["mask"])

            image_file = torch.stack(transformed_images, dim=0)
            mask_file = torch.stack(transformed_masks, dim=0).type(torch.float32)

            image_file = image_file.permute(1, 0, 2, 3)
        return image_file, mask_file

# Data Module

In [263]:
class CTDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

        self.augmentations = A.Compose([
        A.ToFloat(max_value=1024+400, always_apply=True),
        A.Resize(height=64, width=64),
        # A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        # A.RandomBrightnessContrast(p=0.5),
        # A.Normalize(mean=[-1024/400], std=[1/400], always_apply=True),
        ToTensorV2()
        ])
        self.transforms = A.Compose([
        A.ToFloat(max_value=1024+400, always_apply=True),
        A.Resize(height=64, width=64),
        # A.Normalize(mean=[-1024/400], std=[1/400], always_apply=True),
        ToTensorV2(),
        ])

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        self.path_to_file = '/home/pawel/Documents/RISA/3D_segmentation/dataset'


    def prepare_data(self):
        if os.path.exists(self.path_to_file):
            print("Path exists")
            images_paths = sorted(Path(self.path_to_file).rglob('*.npy'))
            print(images_paths)
            for image_path in images_paths:
                image = np.load(str(image_path))

                if image is None:
                    print("Unlink image: ", image_path)
                    image_path.unlink()
        else:
            print("Path does not exist")


    def setup(self, stage):
        paths = sorted(Path(os.path.join(self.path_to_file, 'scans')).glob('*.npy'))

        train_paths, val_paths = train_test_split(paths, test_size=0.3, random_state=42)
        val_paths, test_paths = train_test_split(val_paths, test_size=0.5, random_state=42)

        self.train_dataset = CTDataset(train_paths, transform=self.augmentations)
        self.val_dataset = CTDataset(val_paths, transform=self.transforms)
        self.test_dataset = CTDataset(test_paths, transform=self.transforms)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1)

In [264]:
class MyDiceLoss(smp.losses.DiceLoss):
    def __init__(self, mode=smp.losses.BINARY_MODE, *args, **kwargs):
        super().__init__(mode=mode, *args, **kwargs)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # Zamiana kolejności wymiarów
        y_pred = y_pred.permute(0, 1, 3, 4, 2)
        y_true = y_true.permute(0, 1, 3, 4, 2)
        return super().forward(y_pred, y_true)

# Create model

In [271]:
class AirWayModel(pl.LightningModule):

    def __init__(self, encoder_name: str):
        super().__init__()
        # self.model = smp.create_model(
        #     arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        # )
        self.model = smp.Unet_3D(
            encoder_name=encoder_name)

        # self.loss_function = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
        # self.loss_function = nn.CrossEntropyLoss()
        # self.loss_function = MyDiceLoss(mode=smp.losses.BINARY_MODE)
        self.loss_function = smp.losses.DiceLoss(mode='binary')

        metrics = torchmetrics.MetricCollection([
            torchmetrics.F1Score(task='BINARY'),
            torchmetrics.Precision(task='BINARY'),
            torchmetrics.Recall(task='BINARY')
        ])
        self.train_metrics = metrics.clone('train_')
        self.val_metrics = metrics.clone('val_')
        self.test_metrics = metrics.clone('test_')

        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        labels_tensor = labels.unsqueeze(0)
        labels = torch.mean(labels_tensor, dim=1, keepdim=True)
        outputs = self(inputs)
        loss = self.loss_function(outputs, labels)

        # self.accuracy.update(outputs, labels)
        # self.log('train_loss', loss, prog_bar=True)
        # self.log('train_acc', self.accuracy, prog_bar=True)
        # self.log('val_acc', self.validation_step, prog_bar=True)

        self.train_metrics.update(outputs, labels)
        self.log('train_loss', loss, prog_bar=True)
        self.log_dict(self.train_metrics)

        # self.append("train/loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        labels_tensor = labels.unsqueeze(0)
        # print(labels)
        # print(inputs)
        labels_mean = torch.mean(labels_tensor, dim=0, keepdim=True)  # Oblicz średnią
        outputs = self(inputs)
        loss = self.loss_function(outputs, labels_mean)

        self.val_metrics.update(outputs, labels_mean)
        self.log('val_loss', loss, prog_bar=True)
        self.log_dict(self.val_metrics)


    def test_step(self, batch, batch_idx):
        labels = torch.mean(labels, dim=1, keepdim=True)
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.loss_function(outputs, labels)

        # self.accuracy.update(outputs, labels)

        # self.log('test_loss', loss)
        # self.log('test_acc', self.accuracy)

        self.test_metrics.update(outputs, labels)
        self.log('test_loss', loss, prog_bar=True)
        self.log_dict(self.test_metrics)

    def configure_optimizers(self):
        # Tym razem użyjmy optimizera Adam - uczenie powinno być szybsze
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [272]:
# segment_model_3D = AirWayModel(arch='unet_3d', encoder_name='tu-resnet10t', in_channels=1, out_classes=2, encoder_depth = 4, decoder_channels = (128, 64, 32, 16))
segment_model_3D = AirWayModel(encoder_name='resnet18_3D')

data_module = CTDataModule()

checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor = 'val_loss', mode='max', verbose = True)

# # logger = pl.loggers.NeptuneLogger(project="gawron.pawel1999/ZPO-Lab06", api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI0MDE0NTRlMy03NDY2LTQ0OWYtYTcxMS1jMDE5OThlYjIyZjQifQ==")

trainer = pl.Trainer(accelerator='gpu',
                     callbacks = [checkpoint_callback],
                     max_epochs=10)
trainer.fit(segment_model_3D, datamodule = data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Path exists
[PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_0.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_1.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_2.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_3.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_4.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_0_5.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_1_0.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_1_1.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_1_2.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airways/10_CT_HR_0_1_3.npy'), PosixPath('/home/pawel/Documents/RISA/3D_segmentation/dataset/airwa

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | Unet_3D          | 42.6 M
1 | loss_function | DiceLoss         | 0     
2 | train_metrics | MetricCollection | 0     
3 | val_metrics   | MetricCollection | 0     
4 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
42.6 M    Trainable params
0         Non-trainable params
42.6 M    Total params
170.393   Total estimated model params size (MB)


Epoch 0: 100%|██████████| 375/375 [00:35<00:00, 10.46it/s, v_num=75, train_loss=1.000, val_loss=0.550]

Epoch 0, global step 375: 'val_loss' reached 0.55028 (best 0.55028), saving model to '/home/pawel/Documents/RISA/3D_segmentation/model_segmentation/lightning_logs/version_75/checkpoints/epoch=0-step=375.ckpt' as top 1


Epoch 1: 100%|██████████| 375/375 [00:36<00:00, 10.27it/s, v_num=75, train_loss=1.000, val_loss=0.566]

Epoch 1, global step 750: 'val_loss' reached 0.56567 (best 0.56567), saving model to '/home/pawel/Documents/RISA/3D_segmentation/model_segmentation/lightning_logs/version_75/checkpoints/epoch=1-step=750.ckpt' as top 1


Epoch 2: 100%|██████████| 375/375 [00:36<00:00, 10.30it/s, v_num=75, train_loss=1.000, val_loss=0.587]

Epoch 2, global step 1125: 'val_loss' reached 0.58698 (best 0.58698), saving model to '/home/pawel/Documents/RISA/3D_segmentation/model_segmentation/lightning_logs/version_75/checkpoints/epoch=2-step=1125.ckpt' as top 1


Epoch 3: 100%|██████████| 375/375 [00:36<00:00, 10.27it/s, v_num=75, train_loss=1.000, val_loss=0.458]

Epoch 3, global step 1500: 'val_loss' was not in top 1


Epoch 4: 100%|██████████| 375/375 [00:36<00:00, 10.32it/s, v_num=75, train_loss=1.000, val_loss=0.426]

Epoch 4, global step 1875: 'val_loss' was not in top 1


Epoch 5: 100%|██████████| 375/375 [00:36<00:00, 10.30it/s, v_num=75, train_loss=1.000, val_loss=0.415] 

Epoch 5, global step 2250: 'val_loss' was not in top 1


Epoch 6: 100%|██████████| 375/375 [00:36<00:00, 10.27it/s, v_num=75, train_loss=1.000, val_loss=0.412] 

Epoch 6, global step 2625: 'val_loss' was not in top 1


Epoch 7: 100%|██████████| 375/375 [00:36<00:00, 10.32it/s, v_num=75, train_loss=1.000, val_loss=0.390] 

Epoch 7, global step 3000: 'val_loss' was not in top 1


Epoch 8: 100%|██████████| 375/375 [00:36<00:00, 10.30it/s, v_num=75, train_loss=1.000, val_loss=0.392] 

Epoch 8, global step 3375: 'val_loss' was not in top 1


Epoch 9: 100%|██████████| 375/375 [00:36<00:00, 10.27it/s, v_num=75, train_loss=1.000, val_loss=0.393] 

Epoch 9, global step 3750: 'val_loss' was not in top 1
`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 375/375 [00:36<00:00, 10.27it/s, v_num=75, train_loss=1.000, val_loss=0.393]
