In [1]:
!nvidia-smi

Sun Nov 29 09:44:58 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.38       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P8    10W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

### *Install/Import Libraries*

In [2]:
%pip install -q "monai[nibabel, tqdm]"

In [3]:
import os
import glob
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
from torchsummary import summary

import monai
from monai.data import Dataset, CacheDataset, DataLoader
from monai.losses import DiceLoss
from monai.metrics import compute_meandice
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AddChanneld,
    AsDiscrete,
    CastToTyped,
    LoadNiftid,
    Orientationd,
    RandAffined,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandGaussianNoised,
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
    ToTensord,
)
from monai.transforms import Compose

In [4]:
# Fraction of the volumes assigned to the training split, and the number of
# dataset items collated into one loader batch.
train_split, batch_size = 0.8, 2

# Dictionary keys shared by every transform: keys[0] is the CT volume,
# keys[1] is its segmentation mask.
keys = ('image', 'label')

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### *Load data*

In [5]:
DIR = './COVID-19-20/Train'

# Both lists are sorted so that the i-th CT volume and the i-th segmentation
# are paired by position.
images = sorted(glob.glob(os.path.join(DIR, '*_ct.nii.gz')))
labels = sorted(glob.glob(os.path.join(DIR, '*_seg.nii.gz')))
if len(images) != len(labels):
    # zip() would silently drop the surplus files and could mis-pair the rest.
    raise ValueError(
        f'Found {len(images)} image files but {len(labels)} label files in {DIR}'
    )

# Same split size as before, clamped so it can never exceed the dataset.
n_train = min(int(train_split * len(images)) + 1, len(images))
n_val = len(images) - n_train

train_data = [{keys[0]: img, keys[1]: seg} for img, seg in zip(images[:n_train], labels[:n_train])]
# Slice from n_train onward: the previous `[-n_val:]` returned the WHOLE list
# when n_val was 0 (and overlapped the training set when it was negative),
# leaking training volumes into validation.
val_data = [{keys[0]: img, keys[1]: seg} for img, seg in zip(images[n_train:], labels[n_train:])]

In [6]:
def get_transforms(mode='train', keys=('image', 'label')):
    """Build the MONAI dictionary-transform pipeline for one stage.

    All modes share the same deterministic preprocessing: load the NIfTI
    files, add a channel axis, reorient to LPS, resample to
    (1.25, 1.25, 3.0) spacing and map image intensities from
    [-1000, 500] to [0, 1] with clipping.

    Args:
        mode: 'train' adds padding, random affine, random patch cropping,
            Gaussian noise and random flips; 'val' adds padding and the same
            random patch cropping without the other augmentations; 'infer'
            applies only the shared preprocessing.
        keys: dictionary keys to transform; keys[0] is the image and, when
            present, keys[1] is the label.

    Returns:
        A ``Compose`` object wrapping the selected transforms.

    Raises:
        ValueError: if ``mode`` is not 'train', 'val' or 'infer'.
    """
    # Validate first: previously an unknown mode fell through every branch
    # and failed with an UnboundLocalError at the return statement.
    if mode not in ('train', 'val', 'infer'):
        raise ValueError(f"mode must be 'train', 'val' or 'infer', got {mode!r}")

    # Per-key settings are sliced to len(keys) so an image-only key tuple works.
    spacing_modes = ("bilinear", "nearest")[: len(keys)]
    patch_size = (192, 192, 16)

    # Deterministic preprocessing common to every mode.
    steps = [
        LoadNiftid(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 3.0), mode=spacing_modes),
        ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
    ]

    if mode == 'infer':
        # Same slicing as the Spacingd modes: float32 image, plus uint8 label
        # only when a label key was supplied.
        steps.append(CastToTyped(keys, dtype=(np.float32, np.uint8)[: len(keys)]))
    else:
        # Pad in-plane so every volume is at least 192x192 before cropping.
        steps.append(SpatialPadd(keys, spatial_size=(192, 192, -1), mode="reflect"))
        if mode == 'train':
            steps.append(RandAffined(keys,
                                     prob=0.15,
                                     rotate_range=(-0.05, 0.05), scale_range=(-0.1, 0.1),
                                     mode=("bilinear", "nearest"),
                                     as_tensor_output=False))
        # Each volume yields num_samples patches, so one dataset item becomes
        # four training/validation samples.
        steps.append(RandCropByPosNegLabeld(keys, label_key=keys[1], spatial_size=patch_size, num_samples=4))
        if mode == 'train':
            steps.extend([
                RandGaussianNoised(keys[0], prob=0.15, std=0.01),
                RandFlipd(keys, spatial_axis=0, prob=0.5),
                RandFlipd(keys, spatial_axis=1, prob=0.5),
                RandFlipd(keys, spatial_axis=2, prob=0.5),
            ])
        steps.append(CastToTyped(keys, dtype=(np.float32, np.uint8)))

    steps.append(ToTensord(keys))
    return Compose(steps)

# One pipeline per split, built with the shared key tuple (train first, then val).
train_transforms, val_transforms = [get_transforms(stage, keys) for stage in ('train', 'val')]

In [7]:
# Create DataSet
# Each split pairs its list of {image, label} path dicts with its own
# transform pipeline.

train_ds = Dataset(data=train_data, transform=train_transforms)
val_ds = Dataset(data=val_data, transform=val_transforms)

In [8]:
# Set DataLoader
# batch_size counts volumes; the crop transform returns 4 patches per volume,
# so each batch delivered by these loaders holds batch_size * 4 patches.

# Shuffle the training volumes every epoch; without it the model always sees
# the volumes in the same sorted-filename order.
train_loader = DataLoader(train_ds,
                          batch_size=batch_size,
                          shuffle=True)
# Validation order is kept fixed.
val_loader = DataLoader(val_ds,
                        batch_size=batch_size)

In [10]:
# explore DataLoader
# Pull one batch from each loader and report the tensor shapes.
# NOTE: `images` / `labels` are rebound here from the file-path lists of the
# data-loading cell to batch tensors.

for header, loader in (('Training data Info:', train_loader),
                       ('\nValidation data Info:', val_loader)):
    print(header)
    dataiter = iter(loader)
    # Built-in next() works on every PyTorch version; the iterator's .next()
    # method no longer exists in current releases.
    data = next(dataiter)
    images, labels = data['image'], data['label']
    print("shape of images : {}".format(images.shape))
    print("shape of labels : {}".format(labels.shape))

Training data Info:
shape of images : torch.Size([8, 1, 192, 192, 16])
shape of labels : torch.Size([8, 1, 192, 192, 16])

Validation data Info:
shape of images : torch.Size([8, 1, 192, 192, 16])
shape of labels : torch.Size([8, 1, 192, 192, 16])


In [21]:
# Drop the names bound while inspecting the loaders so the sample batch and
# its iterator are no longer referenced.
del images, labels
del data, dataiter

### *Create Model*

In [9]:
def conv_block(in_chan, out_chan, final_layer=False):
    """Return a 3x3x3 convolution unit as an ``nn.Sequential``.

    The unit is Conv3d (stride 1, padding 1, no bias) followed by
    BatchNorm3d. Unless ``final_layer`` is true, a LeakyReLU with negative
    slope 0.1 is appended as the third layer.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.
        final_layer: when true, omit the activation.
    """
    # The convolution has no bias term; BatchNorm3d follows it and carries
    # its own bias. (nn.InstanceNorm3d is the alternative that was left
    # commented out in the original.)
    layers = [
        nn.Conv3d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm3d(out_chan),
    ]
    if not final_layer:
        layers.append(nn.LeakyReLU(0.1))
    return nn.Sequential(*layers)

In [10]:
class Skipnet(nn.Module):
    """Two-level 3D encoder-decoder without internal skip connections.

    The input is encoded at full resolution, pooled twice (each pooling
    halves every spatial dimension), passed through a bottleneck and then
    upsampled twice by nearest-neighbour interpolation, so the output has the
    input's spatial size when every spatial dimension is a multiple of 4.

    Args:
        in_chan: channels of the input tensor.
        out_chan: channels of the returned tensor.
        filters: base channel width; the levels use filters, 2*filters and
            4*filters channels.
    """

    def __init__(self, in_chan, out_chan, filters):
        super().__init__()

        self.in_chan, self.out_chan, self.filters = in_chan, out_chan, filters
        width1, width2, width4 = filters * 1, filters * 2, filters * 4

        # Sub-modules keep their original attribute names (including the
        # "bottel_neck" spelling) and creation order, since both are part of
        # the saved state_dict / optimizer layout.

        # Encoder: conv_N changes the channel count, convN refines at that width.
        self.conv_1 = conv_block(in_chan, width1)
        self.conv_2 = conv_block(width1, width2)

        self.conv1 = conv_block(width1, width1)
        self.conv2 = conv_block(width2, width2)
        self.max_pool = nn.MaxPool3d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottel_neck = conv_block(width2, width4)
        self.dropout = nn.Dropout3d(0.1)

        # Decoder
        self.upsample = nn.Upsample(scale_factor=(2.0, 2.0, 2.0), mode='nearest')
        self.dconv_1 = conv_block(width4, width2)
        self.dconv_2 = conv_block(width2, width1)

        # Output block (convolution + batch norm, no activation)
        self.output = conv_block(width1, out_chan, final_layer=True)

    def forward(self, x):
        """Map ``x`` of shape (N, in_chan, H, W, D) to (N, out_chan, H, W, D)."""
        # Encoder level 1: full resolution, `filters` channels.
        level1 = self.conv_1(x)
        level1 = self.conv1(level1)

        # Encoder level 2: half resolution, 2*filters channels.
        level2 = self.dropout(self.max_pool(level1))
        level2 = self.conv_2(level2)
        level2 = self.conv2(level2)

        # Bottleneck: quarter resolution, 4*filters channels.
        centre = self.dropout(self.max_pool(level2))
        centre = self.bottel_neck(centre)

        # Decoder step 1: back to half resolution, 2*filters channels.
        decoded = self.upsample(centre)
        decoded = self.dconv_1(self.dropout(decoded))

        # Decoder step 2: back to full resolution, `filters` channels.
        decoded = self.upsample(decoded)
        decoded = self.dconv_2(self.dropout(decoded))

        # Project to out_chan channels.
        return self.output(decoded)

In [11]:
class Unet3D(nn.Module):
    """Three-level 3D U-Net whose skip connections pass through a Skipnet.

    Each encoder feature map is processed by its own ``Skipnet`` before being
    concatenated with the upsampled decoder features. The final block output
    is squashed with a sigmoid.

    Because of the three poolings here plus the two inside the deepest
    Skipnet, every spatial dimension of the input should be a multiple of 16
    for the concatenations to line up.

    Args:
        in_chan: channels of the input volume.
        out_chan: channels of the predicted map.
        filters: base channel width of the main encoder/decoder.
    """

    def __init__(self, in_chan, out_chan, filters):
        super().__init__()

        self.in_chan, self.out_chan, self.filters = in_chan, out_chan, filters
        width1, width2, width4, width8 = filters * 1, filters * 2, filters * 4, filters * 8

        # Sub-modules keep their original attribute names (including the
        # "bottel_neck" spelling) and creation order, since both are part of
        # the saved state_dict / optimizer layout.

        # Encoder: conv_N changes the channel count, convN refines at that width.
        self.conv_1 = conv_block(in_chan, width1)
        self.conv_2 = conv_block(width1, width2)
        self.conv_3 = conv_block(width2, width4)

        self.conv1 = conv_block(width1, width1)
        self.conv2 = conv_block(width2, width2)
        self.conv3 = conv_block(width4, width4)

        self.max_pool = nn.MaxPool3d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottel_neck = conv_block(width4, width8)
        self.dropout = nn.Dropout3d(0.1)

        # Decoder
        self.upsample = nn.Upsample(scale_factor=(2.0, 2.0, 2.0), mode='nearest')
        # The Skipnets preserve the channel count of the feature map they
        # receive; their internal width is fixed at 16.
        self.skip1 = Skipnet(in_chan=width4, out_chan=width4, filters=16)
        self.skip2 = Skipnet(in_chan=width2, out_chan=width2, filters=16)
        self.skip3 = Skipnet(in_chan=width1, out_chan=width1, filters=16)
        # Input widths are (upsampled channels + skip channels):
        # 8f + 4f = 12f, 4f + 2f = 6f, 2f + 1f = 3f.
        self.dconv_1 = conv_block(filters * 12, width4)
        self.dconv_2 = conv_block(filters * 6, width2)
        self.dconv_3 = conv_block(filters * 3, width1)

        # Output block (convolution + batch norm, no activation)
        self.output = conv_block(width1, out_chan, final_layer=True)

    def forward(self, x):
        """Return sigmoid probabilities with the spatial size of ``x``.

        Shape notes below use f = filters and an input of (N, in_chan, H, W, D).
        """
        # ---- Encoder ----
        enc1 = self.conv1(self.conv_1(x))                      # (N, 1f, H,   W,   D)
        enc2 = self.dropout(self.max_pool(enc1))
        enc2 = self.conv2(self.conv_2(enc2))                   # (N, 2f, H/2, W/2, D/2)
        enc3 = self.dropout(self.max_pool(enc2))
        enc3 = self.conv3(self.conv_3(enc3))                   # (N, 4f, H/4, W/4, D/4)

        # ---- Bottleneck ----
        centre = self.dropout(self.max_pool(enc3))
        centre = self.bottel_neck(centre)                      # (N, 8f, H/8, W/8, D/8)

        # ---- Decoder level 3 ----
        dec = self.upsample(centre)                            # (N, 8f, H/4, W/4, D/4)
        bridge = self.skip1(enc3)                              # (N, 4f, H/4, W/4, D/4)
        dec = torch.cat([dec, bridge], dim=1)                  # (N, 12f, ...)
        dec = self.dconv_1(self.dropout(dec))                  # (N, 4f, ...)

        # ---- Decoder level 2 ----
        dec = self.upsample(dec)                               # (N, 4f, H/2, W/2, D/2)
        bridge = self.skip2(enc2)                              # (N, 2f, H/2, W/2, D/2)
        dec = torch.cat([dec, bridge], dim=1)                  # (N, 6f, ...)
        dec = self.dconv_2(self.dropout(dec))                  # (N, 2f, ...)

        # ---- Decoder level 1 ----
        dec = self.upsample(dec)                               # (N, 2f, H, W, D)
        bridge = self.skip3(enc1)                              # (N, 1f, H, W, D)
        dec = torch.cat([dec, bridge], dim=1)                  # (N, 3f, ...)
        dec = self.dconv_3(self.dropout(dec))                  # (N, 1f, ...)

        # ---- Head ----
        logits = self.output(dec)                              # (N, out_chan, H, W, D)
        return torch.sigmoid(logits)

In [None]:
def weights_init(m):
    """Initialise one module in place; intended for ``model.apply``.

    * ``nn.Conv3d``: Xavier-normal weights.
    * ``nn.BatchNorm3d`` with affine parameters: scale drawn from
      N(1.0, 0.02) and bias set to 0.

    Any other module type is left untouched.

    Args:
        m: the module visited by ``apply``.
    """
    if isinstance(m, nn.Conv3d):
        torch.nn.init.xavier_normal_(m.weight)
    elif isinstance(m, nn.BatchNorm3d) and m.affine:
        # The scale must be centred on 1, not 0: a near-zero scale multiplies
        # every normalised activation by ~0.02 (with random sign), so the
        # signal is almost wiped out at every block of the network.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

# Single-channel CT in, single-channel probability map out, base width 16.
model = Unet3D(in_chan=1, out_chan=1, filters=16)
# Move to the target device first so the custom initialisation below runs there.
model = model.to(device)
model.apply(weights_init)

In [13]:
#summary(model, (1, 192, 192, 16))

### *Training*

#### Loss and Optimizer

In [14]:
epochs = 100

# Per-epoch history; these lists are also written into every checkpoint.
hist_train_loss = []
hist_val_loss = []
hist_train_dice = []
hist_val_dice = []

# Best validation loss seen so far. `np.inf` is the spelling that exists in
# every NumPy version; the `np.Inf` alias was removed in NumPy 2.0.
val_loss_min = np.inf

# Directory prefix for the saved model and checkpoint files.
PATH = './COVID-19-20/'

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Cut the learning rate by 10x after 5 epochs without a relative improvement
# of the monitored value. `verbose` is left at its default (it was passed as
# False before) because newer PyTorch releases deprecate the argument.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

def dice_coefficient(y_true, y_pred, axis=(2, 3, 4), epsilon=0.00001):
    """Mean soft Dice score between ``y_pred`` and ``y_true``.

    For every position not reduced by ``axis`` the score is
    ``(2 * sum(pred * true) + eps) / (sum(pred) + sum(true) + eps)``;
    the returned value is the mean of those scores.

    Args:
        y_true: reference mask tensor.
        y_pred: prediction tensor, broadcast-compatible with ``y_true``.
        axis: dimensions summed over (the three spatial axes by default).
        epsilon: smoothing term added to numerator and denominator.

    Returns:
        A 0-dim tensor.
    """
    overlap = torch.sum(y_pred * y_true, dim=axis)
    mass = torch.sum(y_pred, dim=axis) + torch.sum(y_true, dim=axis)
    scores = (2.0 * overlap + epsilon) / (mass + epsilon)
    return torch.mean(scores)

def soft_dice_loss(y_true, y_pred, axis=(2, 3, 4), epsilon=0.00001):
    """Soft Dice loss with squared terms in the denominator.

    For every position not reduced by ``axis`` the ratio
    ``(2 * sum(pred * true) + eps) / (sum(pred**2) + sum(true**2) + eps)``
    is computed; the loss is one minus the mean of those ratios.

    Args:
        y_true: reference mask tensor.
        y_pred: prediction tensor, broadcast-compatible with ``y_true``.
        axis: dimensions summed over (the three spatial axes by default).
        epsilon: smoothing term added to numerator and denominator.

    Returns:
        A 0-dim tensor.
    """
    overlap = torch.sum(y_pred * y_true, dim=axis)
    squared_mass = torch.sum(y_pred ** 2, dim=axis) + torch.sum(y_true ** 2, dim=axis)
    ratios = (2.0 * overlap + epsilon) / (squared_mass + epsilon)
    return 1 - torch.mean(ratios)

#### Training Loop

*  **optimizer.zero_grad(set_to_none=True)**
    
    *  Instead of setting the gradients to zero, set them to `None`. This will in general have a lower memory footprint and can modestly improve performance. However, it changes certain behaviors. For example: 1. When the user tries to access a gradient and perform manual ops on it, a `None` attribute or a tensor full of 0s will behave differently. 2. If the user requests `zero_grad(set_to_none=True)` followed by a backward pass, `.grad`s are guaranteed to be `None` for params that did not receive a gradient. 3. `torch.optim` optimizers behave differently depending on whether the gradient is 0 or `None` (in the former case the step is performed with a gradient of 0, and in the latter the step is skipped altogether).

In [None]:
# One iteration per epoch: a training pass, a validation pass, bookkeeping,
# a scheduler update, then saving the best weights and a resumable checkpoint.
for epoch in range(epochs):
    # ---------------- training pass ----------------
    train_loss = 0.0
    num_train_sample = 0
    dice_sum = 0.0
    dice_count = 0
    model.train()
    for batch in tqdm(train_loader):
        inputs = batch['image'].to(device)
        labels = batch['label'].to(device)
        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        loss = soft_dice_loss(labels, outputs)
        loss.backward()
        optimizer.step()
        # Loss is weighted by the number of patches so the epoch value is a
        # per-sample mean.
        train_loss += loss.item() * inputs.size(0)
        num_train_sample += inputs.size(0)
        # Metric only: detach and convert to a Python float. Accumulating the
        # tensor itself kept every batch's autograd graph alive for the whole
        # epoch and put graph-attached CUDA tensors into the history lists.
        dice_sum += dice_coefficient(labels, outputs.detach()).item()
        dice_count += 1
    train_dice = dice_sum / dice_count
    train_loss = train_loss / num_train_sample
    hist_train_loss.append(train_loss)
    hist_train_dice.append(train_dice)

    # ---------------- validation pass ----------------
    val_loss = 0.0
    num_val_sample = 0
    dice_sum = 0.0
    dice_count = 0
    model.eval()
    with torch.no_grad():
        for batch in tqdm(val_loader):
            inputs = batch['image'].to(device)
            labels = batch['label'].to(device)
            outputs = model(inputs)
            loss = soft_dice_loss(labels, outputs)
            val_loss += loss.item() * inputs.size(0)
            num_val_sample += inputs.size(0)
            dice_sum += dice_coefficient(labels, outputs).item()
            dice_count += 1
    val_dice = dice_sum / dice_count
    val_loss = val_loss / num_val_sample
    hist_val_loss.append(val_loss)
    hist_val_dice.append(val_dice)

    print(f'\nEpoch: {epoch+1}: \nTrain loss:      {train_loss}, \tTrain Dice:      {train_dice}, \nValidation loss: {val_loss}, \tValidation Dice: {val_dice}')
    # f-string (the previous set literal printed the value wrapped in braces).
    # Printed before the scheduler update, so this is the rate used this epoch.
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")

    if val_loss <= val_loss_min:
        print(f'Validation loss is decreased from {val_loss_min} ---> {val_loss}.\nSaving Model ...')
        torch.save(model.state_dict(), PATH+'Unet3D_AutoEncoder_model.pth')
        val_loss_min = val_loss

    # Update the scheduler BEFORE writing the checkpoint so the saved
    # optimizer/scheduler state already reflects this epoch's validation loss;
    # otherwise a resumed run silently loses the last scheduler decision.
    scheduler.step(val_loss)

    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'hist_train_loss': hist_train_loss,
                'hist_val_loss': hist_val_loss,
                'hist_train_dice': hist_train_dice,
                'hist_val_dice': hist_val_dice,
                'val_loss_min': val_loss_min},
               PATH + 'Unet3D_AutoEncoder_model_checkpoints.pt')

### *Training from checkpoints*

In [17]:
# Restore the full training state written at the end of the last completed
# epoch. map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can also be resumed in a CPU-only session.
checkpoint = torch.load(PATH + 'Unet3D_AutoEncoder_model_checkpoints.pt',
                        map_location=device)

# The checkpoint stores the index of the epoch that finished; resume at the next one.
epoch = checkpoint['epoch'] + 1
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])

# Metric history and best-so-far validation loss continue from the saved values.
hist_train_loss = checkpoint['hist_train_loss']
hist_val_loss = checkpoint['hist_val_loss']
hist_train_dice = checkpoint['hist_train_dice']
hist_val_dice = checkpoint['hist_val_dice']
val_loss_min = checkpoint['val_loss_min']

In [18]:
# Display the zero-based index of the epoch at which training will resume
# (the stored epoch index plus one).
epoch

13

In [19]:
# Continue training from the restored `epoch` up to `epochs`. Each iteration
# runs a training pass, a validation pass, bookkeeping, a scheduler update,
# then saves the best weights and a resumable checkpoint.
while epoch < epochs:
    # ---------------- training pass ----------------
    train_loss = 0.0
    num_train_sample = 0
    dice_sum = 0.0
    dice_count = 0
    model.train()
    for batch in tqdm(train_loader):
        inputs = batch['image'].to(device)
        labels = batch['label'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = soft_dice_loss(labels, outputs)
        loss.backward()
        optimizer.step()
        # Loss is weighted by the number of patches so the epoch value is a
        # per-sample mean.
        train_loss += loss.item() * inputs.size(0)
        num_train_sample += inputs.size(0)
        # Metric only: detach and convert to a Python float. Accumulating the
        # tensor itself kept every batch's autograd graph alive for the whole
        # epoch and put graph-attached CUDA tensors into the history lists.
        dice_sum += dice_coefficient(labels, outputs.detach()).item()
        dice_count += 1
    train_dice = dice_sum / dice_count
    train_loss = train_loss / num_train_sample
    hist_train_loss.append(train_loss)
    hist_train_dice.append(train_dice)

    # ---------------- validation pass ----------------
    val_loss = 0.0
    num_val_sample = 0
    dice_sum = 0.0
    dice_count = 0
    model.eval()
    with torch.no_grad():
        for batch in tqdm(val_loader):
            inputs = batch['image'].to(device)
            labels = batch['label'].to(device)
            outputs = model(inputs)
            loss = soft_dice_loss(labels, outputs)
            val_loss += loss.item() * inputs.size(0)
            num_val_sample += inputs.size(0)
            dice_sum += dice_coefficient(labels, outputs).item()
            dice_count += 1
    val_dice = dice_sum / dice_count
    val_loss = val_loss / num_val_sample
    hist_val_loss.append(val_loss)
    hist_val_dice.append(val_dice)

    print(f'\nEpoch: {epoch+1}: \nTrain loss:      {train_loss}, \tTrain Dice:      {train_dice}, \nValidation loss: {val_loss}, \tValidation Dice: {val_dice}')
    # f-string (the previous set literal printed the value wrapped in braces).
    # Printed before the scheduler update, so this is the rate used this epoch.
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")

    if val_loss <= val_loss_min:
        print(f'Validation loss is decreased from {val_loss_min} ---> {val_loss}.\nSaving Model ...')
        torch.save(model.state_dict(), PATH+'Unet3D_AutoEncoder_model.pth')
        val_loss_min = val_loss

    # Update the scheduler BEFORE writing the checkpoint so the saved
    # optimizer/scheduler state already reflects this epoch's validation loss;
    # otherwise a resumed run silently loses the last scheduler decision.
    scheduler.step(val_loss)

    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'hist_train_loss': hist_train_loss,
                'hist_val_loss': hist_val_loss,
                'hist_train_dice': hist_train_dice,
                'hist_val_dice': hist_val_dice,
                'val_loss_min': val_loss_min},
               PATH + 'Unet3D_AutoEncoder_model_checkpoints.pt')

    print()
    epoch += 1

100%|██████████| 80/80 [20:05<00:00, 15.07s/it]
100%|██████████| 20/20 [02:04<00:00,  6.21s/it]
  0%|          | 0/80 [00:00<?, ?it/s]


Epoch: 7: 
Train loss:      0.6898237600922584, 	Train Dice:      0.16612032055854797, 
Validation loss: 0.7313506786639874, 	Validation Dice: 0.14454589784145355
Learning Rate: {0.01}
Validation loss is decreased from 0.733027763855763 ---> 0.7313506786639874.
Saving Model ...



100%|██████████| 80/80 [20:24<00:00, 15.30s/it]
100%|██████████| 20/20 [01:58<00:00,  5.92s/it]
  0%|          | 0/80 [00:00<?, ?it/s]


Epoch: 8: 
Train loss:      0.6796185404062272, 	Train Dice:      0.17473052442073822, 
Validation loss: 0.7575683211668943, 	Validation Dice: 0.1276153177022934
Learning Rate: {0.01}



100%|██████████| 80/80 [19:26<00:00, 14.59s/it]
100%|██████████| 20/20 [01:59<00:00,  5.97s/it]
  0%|          | 0/80 [00:00<?, ?it/s]


Epoch: 9: 
Train loss:      0.6575667887926102, 	Train Dice:      0.19324195384979248, 
Validation loss: 0.7839267865205423, 	Validation Dice: 0.119595967233181
Learning Rate: {0.01}



100%|██████████| 80/80 [19:40<00:00, 14.75s/it]
100%|██████████| 20/20 [01:59<00:00,  6.00s/it]
  0%|          | 0/80 [00:00<?, ?it/s]


Epoch: 10: 
Train loss:      0.6853767760097981, 	Train Dice:      0.1830826848745346, 
Validation loss: 0.9138992627461752, 	Validation Dice: 0.046740513294935226
Learning Rate: {0.01}



100%|██████████| 80/80 [19:45<00:00, 14.82s/it]
100%|██████████| 20/20 [01:59<00:00,  5.97s/it]
  0%|          | 0/80 [00:00<?, ?it/s]


Epoch: 11: 
Train loss:      0.6928727746009826, 	Train Dice:      0.17283180356025696, 
Validation loss: 0.7314542868198493, 	Validation Dice: 0.16831672191619873
Learning Rate: {0.01}



100%|██████████| 80/80 [19:26<00:00, 14.59s/it]
100%|██████████| 20/20 [02:00<00:00,  6.01s/it]
  0%|          | 0/80 [00:00<?, ?it/s]


Epoch: 12: 
Train loss:      0.6519338838756085, 	Train Dice:      0.1992054134607315, 
Validation loss: 0.6927454395171924, 	Validation Dice: 0.17935971915721893
Learning Rate: {0.01}
Validation loss is decreased from 0.7313506786639874 ---> 0.6927454395171924.
Saving Model ...



100%|██████████| 80/80 [20:23<00:00, 15.30s/it]
100%|██████████| 20/20 [02:04<00:00,  6.23s/it]
  0%|          | 0/80 [00:00<?, ?it/s]


Epoch: 13: 
Train loss:      0.6626156479120254, 	Train Dice:      0.1980482041835785, 
Validation loss: 0.7145175842138437, 	Validation Dice: 0.16964475810527802
Learning Rate: {0.01}



 34%|███▍      | 27/80 [07:18<14:20, 16.24s/it]


RuntimeError: ignored