In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install monai

In [3]:
import torch
import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import SkipConnection
from monai.utils import alias, export

import warnings
from typing import Sequence, Tuple, Union

import torch
import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import SkipConnection
from monai.utils import alias, export

In [4]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device('cpu')

In [5]:
class ChannelAttention(nn.Module):
    def __init__(self, submodule, in_planes, out_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.submodule = submodule
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        self.in_planes = in_planes
        self.fc = nn.Sequential(nn.Conv3d(in_planes, in_planes // ratio, 1, bias=False),
                               nn.GELU(),
                               nn.Conv3d(in_planes // ratio, out_planes, 1, bias=False))

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # print("CAT X = ", x.shape, flush=True)
        y = self.submodule(x)
        # print("CAT Y = ", y.shape, flush=True)
        x_av = self.avg_pool(x)
        # print("CAT AVG MID = ", x.shape,  self.in_planes, flush=True)
        # print("FC = ", self.fc, flush=True)
        avg_out = self.fc(x_av)
        # print("CAT AVG = ", avg_out.shape, self.in_planes, flush=True)
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        # print("CA output = ", y.shape, out.shape, avg_out.shape, self.in_planes, flush=True)
        # print("CHANNEL = ", out.shape)
        return y*self.sigmoid(out)

In [6]:
class SpatialAttention(nn.Module):
    def __init__(self, submodule, in_channels, kernel_size=7, out_channels=None, add_conv_1x1=False):
        super(SpatialAttention, self).__init__()
        self.submodule = submodule
        self.conv_flat = nn.Conv3d(in_channels, in_channels, 1, bias=False)
        self.conv1 = nn.Conv3d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.pool_layer = nn.MaxPool3d(2)
        self.upscale_layer = nn.Upsample(scale_factor=2, mode='nearest')
        self.gelu = nn.GELU()
        self.sigmoid = nn.Sigmoid()
        self.add_conv_1x1 = add_conv_1x1
        if add_conv_1x1:
            if out_channels is None:
                raise Exception("Out channels needed for conv 1x1")
            self.conv_1x1 = nn.Conv3d(in_channels, out_channels, 1, bias=False)
            self.act = torch.nn.PReLU()

    def forward(self, x):
        y = self.submodule(x)
        # x_2 = self.conv_flat(x)
        # print("X2 = ", x_2.shape)
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        # print("X = ", x.shape, y.shape, avg_out.shape, max_out.shape, flush=True)
        x = torch.cat([avg_out, max_out], dim=1)
        # print("X2 = ", x.shape, flush=True)
        x = self.conv1(x)
        if x.shape[-1] > y.shape[-1]:
            x = self.pool_layer(x)
        elif x.shape[-1] < y.shape[-1]:
            x = self.upscale_layer(x)
        # print("X3 = ", x.shape, y.shape, flush=True)
        # print("SPATIAL = ", x.shape)
        x = y*self.sigmoid(x)
        if self.add_conv_1x1:
            x = self.conv_1x1(x)
            x = self.act(x)
        return x

In [7]:
class ConvBlock(nn.Module):
    def __init__(self, in_c=32, out_c=32, kernel_size=3, activation=nn.GELU):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv3d(n_c, out_c, kernel_size, padding=kernel_size//2, bias=False)
        self.batchnorm = nn.BatchNorm3d(out_c)
        self.activation = activation()
        self.conv2 = nn.Conv3d(out_c, out_c, kernel_size, padding=kernel_size//2, bias=False)
        self.batchnorm2 = nn.BatchNorm3d(out_c)
        self.activation2 = activation()

    def forward(self, x):
        x = self.conv1(x)
        x = self.batchnorm(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = self.activation2(x)
        return x


In [8]:
__all__ = ["UNet", "Unet", "unet"]


@export("monai.networks.nets")
@alias("Unet")
class UNet(nn.Module):
    def __init__(
        self,
        dimensions: int,
        in_channels: int,
        out_channels: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        up_kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 0,
        act: Union[Tuple, str] = Act.PRELU,
        norm: Union[Tuple, str] = Norm.INSTANCE,
        dropout=0.0,
    ) -> None:
        super().__init__()

        if len(channels) < 2:
            raise ValueError(
                "the length of `channels` should be no less than 2.")
        delta = len(strides) - (len(channels) - 1)
        if delta < 0:
            raise ValueError(
                "the length of `strides` should equal to `len(channels) - 1`.")
        if delta > 0:
            warnings.warn(f"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.")
        if isinstance(kernel_size, Sequence):
            if len(kernel_size) != dimensions:
                raise ValueError(
                    "the length of `kernel_size` should equal to `dimensions`.")
        if isinstance(up_kernel_size, Sequence):
            if len(up_kernel_size) != dimensions:
                raise ValueError(
                    "the length of `up_kernel_size` should equal to `dimensions`.")

        self.dimensions = dimensions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.strides = strides
        self.kernel_size = kernel_size
        self.up_kernel_size = up_kernel_size
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout

        def _create_block(
            inc: int,
            outc: int,
            channels: Sequence[int],
            strides: Sequence[int],
            is_top: bool) -> nn.Sequential:
            print("IN CREATE BLOCK : ", inc, outc, channels, strides, is_top)
            c = channels[0]
            s = strides[0]
            print(c, channels)
            subblock: nn.Module
            if len(channels)>2:
                subblock = _create_block(
                    c, c, channels[1:], strides[1:], False)                
                upc = c * 2
                if len(channels) > len(self.channels)-1:
                    add_spatial = True
                    add_channel = False
                else:
                    add_spatial = False
                    add_channel = False
            else:
                subblock = self._get_bottom_layer(c, channels[1])
                print("CREATED bottom LAYER : ", inc, outc, channels, strides, is_top)
                subblock = ChannelAttention(subblock, in_planes=c, out_planes=channels[1])
                # subblock = ChannelAttention(subblock, in_planes=channels[0])
                upc = c + channels[1]
                add_spatial = False
                add_channel = False

            down = self._get_down_layer(inc, c, s, is_top)
            print("CREATED DOWN LAYER : ", inc, c, s, is_top)
            if add_spatial:
                down = SpatialAttention(down, in_channels=inc)
            if add_channel:
                down = ChannelAttention(down, in_planes=inc, out_planes=c)

            
            if len(channels)==len(self.channels) and add_spatial:
                up = self._get_up_layer(upc, upc, s, is_top)
            else:
                up = self._get_up_layer(upc, outc, s, is_top)
            print("CREATED UP LAYER : ", upc, outc, s, is_top)
            # print(up, flush=True)
            if add_spatial:
                # print("CHANNELS = ", len(channels), channels[0], channels[1], flush=True)
                print("CH = ",channels, len(channels))
                add_conv_1x1 = len(channels)==len(self.channels)
                up = SpatialAttention(up, in_channels=upc, out_channels=outc, add_conv_1x1=add_conv_1x1)
            if add_channel:
                up = ChannelAttention(up, in_planes=upc, out_planes=outc)

            print("OUT OF CREATE BLOCK : ", inc, outc, channels, strides, is_top)
            return nn.Sequential(down, SkipConnection(subblock), up)

        self.model = _create_block(
            in_channels, out_channels, self.channels, self.strides, True)

    def _get_down_layer(self,
        in_channels: int,
        out_channels: int,
        strides: int,
        is_top: bool) -> nn.Module:
        # print("CREATING DOWN LAYER : ", in_channels, out_channels, strides, is_top)
        if self.num_res_units > 0:
            return ResidualUnit(
                self.dimensions,
                in_channels,
                out_channels,
                strides=strides,
                kernel_size=self.kernel_size,
                subunits=self.num_res_units,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
            )
        return Convolution(
            self.dimensions,
            in_channels,
            out_channels,
            strides=strides,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
        )

    def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module:

        # print("CREATING bottom LAYER : ", in_channels, out_channels)
        return self._get_down_layer(in_channels, out_channels, 1, False)

    def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:

        # print("CREATING UP LAYER : ", in_channels, out_channels, strides, is_top)
        conv: Union[Convolution, nn.Sequential]

        conv = Convolution(
            self.dimensions,
            in_channels,
            out_channels,
            strides=strides,
            kernel_size=self.up_kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            conv_only=is_top and self.num_res_units == 0,
            is_transposed=True,
        )

        if self.num_res_units > 0:
            ru = ResidualUnit(
                self.dimensions,
                out_channels,
                out_channels,
                strides=1,
                kernel_size=self.kernel_size,
                subunits=1,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
                last_conv_only=is_top,
            )
            conv = nn.Sequential(conv, ru)

        return conv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x


Unet = unet = UNet

In [None]:
!pip install -U monai

In [10]:
import os
best_metric_model_file = "last_model.pth"

print(list(os.listdir('./')), flush=True)

import psutil
import argparse
import json
import logging
import sys
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
    RandAffined,
    RandShiftIntensityd,
    Rand3DElasticd,
    RandFlipd,
    RandGaussianNoised,
    CenterSpatialCropd,
    SpatialPadd,
    ToTensord,
    CastToTyped
)
# from monai.handlers.utils import 
from monai.networks.nets import UNet, UNETR, DynUNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch,SmartCacheDataset
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import numpy as np

['.config', 'drive', 'sample_data']


In [11]:
root_data = '/content/drive/MyDrive/kits23/dataset'
root_dir = '/content/drive/MyDrive/model'

train_images=[]
train_labels=[]
for case in list(sorted(os.listdir(root_data))):
        train_images.append(root_data+"/"+case+"/imaging.nii.gz"),
        train_labels.append(root_data+"/"+case+"/segmentation.nii.gz")

data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
# print(data_dicts[:6])

In [12]:
train_files, val_files = data_dicts[80:], data_dicts[:80]
# val_files.extend([data_dicts[151],data_dicts[156]])
print("VAL FILES = ", len(val_files), flush=True)
print("TRAIN FILES = ", len(train_files), flush=True)
print("MEMORY = ", str(round(psutil.virtual_memory().total / (1024.0 **3)))+" GB", flush=True)

VAL FILES =  80
TRAIN FILES =  409
MEMORY =  13 GB


In [13]:
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(
            2, 1.62, 1.62), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"], a_min=-80, a_max=305,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        SpatialPadd(keys=["image", "label"], spatial_size=(160,160,64), mode="constant"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(160, 160, 64),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        Rand3DElasticd(
            keys=["image", "label"],
            mode=("bilinear", "nearest"),
            prob=0.5,
            sigma_range=(5, 8),
            magnitude_range=(50, 150),
            spatial_size=(160, 160, 64),
            translate_range=(10, 10, 5),
            rotate_range=(np.pi/36,np.pi/36, np.pi),
            scale_range=(0.1, 0.1, 0.1),
            padding_mode="zeros",
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.25,
        ),
        RandGaussianNoised(keys=["image"], prob=0.25, mean=0.0, std=0.1),
        CastToTyped(keys=["label"], dtype=np.uint8),
        EnsureTyped(keys=["image", "label"]),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(
            2, 1.62, 1.62), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"],a_min=-80, a_max=305,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        SpatialPadd(keys=["image", "label"], spatial_size=(160,160,64), mode="constant"),
        CastToTyped(keys=["label"], dtype=np.uint8),
        EnsureTyped(keys=["image", "label"]),
    ]
)

In [14]:
train_files = [d for d in train_files if d.get('image') != '/content/drive/MyDrive/kits23/dataset/case_00554/imaging.nii.gz']

In [15]:
print("CREATING TRAIN DS", flush=True)
print(train_files)
train_ds = SmartCacheDataset(
    data=train_files, transform=train_transforms,
    cache_rate=0.1, replace_rate=0.5)
print(len(train_ds))
# train_ds = Dataset(data=train_files, transform=train_transforms)
print("CREATED TRAIN DS", flush=True)
# use batch_size=2 to load images and use RandCropByPosNegLabeld
# to generate 2 x 4 images for network training
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0)
print("CREATED TRAIN DATALOADER", flush=True)

CREATING TRAIN DS
[{'image': '/content/drive/MyDrive/kits23/dataset/case_00080/imaging.nii.gz', 'label': '/content/drive/MyDrive/kits23/dataset/case_00080/segmentation.nii.gz'}, {'image': '/content/drive/MyDrive/kits23/dataset/case_00081/imaging.nii.gz', 'label': '/content/drive/MyDrive/kits23/dataset/case_00081/segmentation.nii.gz'}, {'image': '/content/drive/MyDrive/kits23/dataset/case_00082/imaging.nii.gz', 'label': '/content/drive/MyDrive/kits23/dataset/case_00082/segmentation.nii.gz'}, {'image': '/content/drive/MyDrive/kits23/dataset/case_00083/imaging.nii.gz', 'label': '/content/drive/MyDrive/kits23/dataset/case_00083/segmentation.nii.gz'}, {'image': '/content/drive/MyDrive/kits23/dataset/case_00084/imaging.nii.gz', 'label': '/content/drive/MyDrive/kits23/dataset/case_00084/segmentation.nii.gz'}, {'image': '/content/drive/MyDrive/kits23/dataset/case_00085/imaging.nii.gz', 'label': '/content/drive/MyDrive/kits23/dataset/case_00085/segmentation.nii.gz'}, {'image': '/content/drive/M

Loading dataset: 100%|██████████| 40/40 [05:17<00:00,  7.93s/it]

40
CREATED TRAIN DS
CREATED TRAIN DATALOADER





In [None]:
val_ds = SmartCacheDataset(
    data=val_files, transform=val_transforms, cache_rate=0.1,replace_rate=0.5)
# val_ds = Dataset(data=val_files, transform=val_transforms)
print("CREATED VAL DS", flush=True)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)
print("CREATED VAL DATALOADER", flush=True)

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device('cpu')

In [None]:
print("CREATING MODEL", flush=True)
model = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=4,
    channels=(64, 128, 256, 512, 512),
    strides=(2, 2, 2, 2), 
)
# model.load_state_dict(torch.load(best_metric_model_file,
#     map_location=torch.device(device)))

print("CREATED MODEL", flush=True)
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

In [None]:
max_epochs = 25
val_interval = 1
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=4)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=4)])
model.to(device)

for epoch in range(max_epochs):
    print("-" * 10, flush=True)
    print(f"epoch {epoch + 1}/{max_epochs}", flush=True)
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"train_loss: {loss.item():.4f}", flush=True)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}", flush=True)

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (160,160, 64)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(
                    root_dir, "best_metric_model_"+str(epoch)+"_"+str(f"{metric:.4f}")+".pth"))
                print("saved new best metric model", flush=True)
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}", flush=True
            )

torch.save(model.state_dict(), os.path.join(
                    root_dir, "last_model.pth"))
print(epoch_loss_values, flush=True)
print(metric_values, flush=True)
print(
f"train completed, best_metric: {best_metric:.4f} "
f"at epoch: {best_metric_epoch}", flush=True)



----------
epoch 1/25
1/20, train_loss: 2.4418
2/20, train_loss: 2.3788
3/20, train_loss: 2.3247
4/20, train_loss: 2.2836
5/20, train_loss: 2.2518
6/20, train_loss: 2.2114
7/20, train_loss: 2.1732
8/20, train_loss: 2.1501
9/20, train_loss: 2.1211
10/20, train_loss: 2.0984
11/20, train_loss: 2.0691
12/20, train_loss: 2.0587
13/20, train_loss: 2.0306
14/20, train_loss: 2.0101
15/20, train_loss: 1.9869
16/20, train_loss: 1.9626
17/20, train_loss: 1.9407
18/20, train_loss: 1.9282
19/20, train_loss: 1.9138
20/20, train_loss: 1.8963
epoch 1 average loss: 2.1116
saved new best metric model
current epoch: 1 current mean dice: 0.0194
best mean dice: 0.0194 at epoch: 1
----------
epoch 2/25
1/20, train_loss: 1.8669
2/20, train_loss: 1.8751
3/20, train_loss: 1.8459
4/20, train_loss: 1.8327
5/20, train_loss: 1.8078
6/20, train_loss: 1.7977
7/20, train_loss: 1.8020
8/20, train_loss: 1.7661
9/20, train_loss: 1.7547
10/20, train_loss: 1.7543
11/20, train_loss: 1.7297
12/20, train_loss: 1.7208
13/20, 