In [None]:
!nvidia-smi

Wed Dec  2 10:49:04 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.38       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   70C    P8    11W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

### **Install/Import libraries**

In [None]:
%pip install -q "monai[nibabel, tqdm]"

In [None]:
import os
import glob
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from model import U2NET
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchsummary import summary

import monai
from monai.data import Dataset, DataLoader
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AddChanneld,
    AsDiscrete,
    CastToTyped,
    LoadNiftid,
    Orientationd,
    RandAffined,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandGaussianNoised,
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
    ToTensord,
)
from monai.transforms import Compose

In [None]:
# ---- experiment settings ----
train_split = 0.8   # fraction of the cases used for training; the rest are validation
batch_size = 1      # volumes per DataLoader batch

keys = ('image', 'label')  # dict keys shared by the data lists and the transforms
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### **Load data**

In [None]:
DIR = '/content/drive/MyDrive/COVID-19-20/Train'

# Volumes and masks are paired by sorted filename: <case>_ct.nii.gz <-> <case>_seg.nii.gz.
images = sorted(glob.glob(os.path.join(DIR, '*_ct.nii.gz')))
labels = sorted(glob.glob(os.path.join(DIR, '*_seg.nii.gz')))
assert images, f"no '*_ct.nii.gz' files found in {DIR}"
image_ids = [os.path.basename(p)[: -len('_ct.nii.gz')] for p in images]
label_ids = [os.path.basename(p)[: -len('_seg.nii.gz')] for p in labels]
assert image_ids == label_ids, "CT volumes and segmentation masks do not pair up one-to-one"

# The `+ 1` is kept from the original split (199 cases -> 160 train / 39 val);
# clamp it so a tiny dataset cannot push n_train past the number of cases.
n_train = min(int(train_split * len(images)) + 1, len(images))
n_val = len(images) - n_train

# Slice validation from n_train onward: `images[-n_val:]` returns the whole list
# when n_val == 0, which would put every training case into validation too.
train_data = [{keys[0]: img, keys[1]: seg} for img, seg in zip(images[:n_train], labels[:n_train])]
val_data = [{keys[0]: img, keys[1]: seg} for img, seg in zip(images[n_train:], labels[n_train:])]

In [None]:
def get_transforms(mode='train', keys=('image', 'label')):
    """Build the MONAI dictionary-transform pipeline for one stage.

    Every mode shares the same deterministic preprocessing:
      - load NIfTI and add a channel dim;
      - reorient to LPS;
      - resample to 1.25 x 1.25 x 3.0 spacing (bilinear for the image, nearest
        for the label);
      - map image intensities from [-1000, 500] to [0, 1], clipped.

    Args:
        mode: One of three stages.
            'train' pads in-plane to at least 192x192, samples 4 random
            192x192x16 pos/neg crops, and applies affine, noise and flip
            augmentation.
            'val' applies the same padding and random crops, without augmentation.
            'infer' keeps the whole resampled volume.
        keys: Dictionary keys to transform. keys[0] is the image; keys[1]
            (required for 'train'/'val') is the label used to sample crops.

    Returns:
        A ``Compose`` of the selected transforms, ending with a cast
        (float32 image, uint8 label) and conversion to tensors.

    Raises:
        ValueError: If ``mode`` is not 'train', 'val' or 'infer'.
    """
    if mode not in ('train', 'val', 'infer'):
        raise ValueError(f"mode must be 'train', 'val' or 'infer', got {mode!r}")

    # One interpolation mode / output dtype per key (image, then label). Slicing to
    # len(keys) keeps the settings consistent with an image-only key tuple.
    interp_modes = ("bilinear", "nearest")[: len(keys)]
    out_dtypes = (np.float32, np.uint8)[: len(keys)]

    # Deterministic preprocessing shared by every mode.
    preprocess = [
        LoadNiftid(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 3.0), mode=interp_modes),
        ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
    ]

    if mode == 'train':
        stage = [
            SpatialPadd(keys, spatial_size=(192, 192, -1), mode="reflect"),          # ensure at least 192x192
            RandAffined(keys,
                        prob=0.15,
                        rotate_range=(-0.05, 0.05), scale_range=(-0.1, 0.1),
                        mode=interp_modes,
                        as_tensor_output=False),
            RandCropByPosNegLabeld(keys, label_key=keys[1], spatial_size=(192, 192, 16), num_samples=4),
            RandGaussianNoised(keys[0], prob=0.15, std=0.01),
            RandFlipd(keys, spatial_axis=0, prob=0.5),
            RandFlipd(keys, spatial_axis=1, prob=0.5),
            RandFlipd(keys, spatial_axis=2, prob=0.5),
        ]
    elif mode == 'val':
        # NOTE(review): validation also samples random crops, so val loss/Dice are
        # stochastic estimates on crops rather than full-volume scores.
        stage = [
            SpatialPadd(keys, spatial_size=(192, 192, -1), mode="reflect"),          # ensure at least 192x192
            RandCropByPosNegLabeld(keys, label_key=keys[1], spatial_size=(192, 192, 16), num_samples=4),
        ]
    else:  # 'infer': keep the whole resampled volume
        stage = []

    return Compose(preprocess + stage + [CastToTyped(keys, dtype=out_dtypes), ToTensord(keys)])

# Build the training and validation pipelines (the 'infer' pipeline is not built here).
train_transforms, val_transforms = (get_transforms(stage, keys) for stage in ('train', 'val'))

In [None]:
# Create DataSet: wrap each list of {image, label} path dicts with its transform pipeline.
train_ds = Dataset(transform=train_transforms, data=train_data)
val_ds = Dataset(transform=val_transforms, data=val_data)

In [None]:
# Set DataLoader
# Shuffle the training cases each epoch. The file lists are sorted, so without
# shuffling every epoch would visit the patients in the same fixed order.
train_loader = DataLoader(train_ds,
                          batch_size=batch_size,
                          shuffle=True)
# Validation order does not matter. Validation batches contain the random crops
# produced by val_transforms; they are fed to the model directly (sliding_window_inference
# is imported but not used for validation in this notebook).
val_loader = DataLoader(val_ds,
                        batch_size=batch_size)

In [None]:
# explore DataLoader: peek at one batch from each loader to sanity-check tensor shapes.
# Uses the built-in next() (DataLoader iterators no longer provide a .next() method in
# recent PyTorch) and distinct names so the `images`/`labels` file lists are not clobbered.
for title, loader in (('Training data Info:', train_loader), ('\nValidation data Info:', val_loader)):
    print(title)
    sample_batch = next(iter(loader))
    print("shape of images : {}".format(sample_batch['image'].shape))
    print("shape of labels : {}".format(sample_batch['label'].shape))

del title, loader, sample_batch

### **Create Model**

In [None]:
def weights_init(m):
    """Initialise one module in place; intended for ``model.apply(weights_init)``.

    - ``nn.Conv3d``: Xavier-normal weights, zero bias (when the layer has a bias).
    - ``nn.BatchNorm3d``: scale ~ N(1.0, 0.02), zero shift (when the layer is affine).

    Other module types are left untouched.
    """
    if isinstance(m, nn.Conv3d):
        torch.nn.init.xavier_normal_(m.weight)
        if m.bias is not None:          # Conv3d(..., bias=False) has no bias tensor
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm3d):
        # Centre the BN scale on 1 (DCGAN convention). A mean of 0 would start every
        # normalised activation at ~0 and starve the following layers of signal.
        if m.weight is not None:        # BatchNorm3d(..., affine=False) has no weight/bias
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

# One input channel (CT) and one output channel; move to the selected device
# first, then apply the custom initialisation.
model = U2NET(in_ch=1, out_ch=1)
model = model.to(device)
model.apply(weights_init)

In [None]:
#summary(model, (1, 192, 192, 16))

### **Optimizer, Scheduler and Loss**

In [None]:
# Adam with a fixed learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Plateau-based LR decay, currently disabled (no scheduler.step() is called in the training loops):
#scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=9, min_lr=0.00001, verbose=True)

class DiceLoss(nn.Module):
    """Soft Dice loss averaged over the seven network outputs ``d0``..``d6``.

    For a single prediction ``p`` against label ``g`` the loss is::

        1 - mean( (2 * sum(p * g) + eps) / (sum(p^2) + sum(g^2) + eps) )

    The sums run over ``axis`` and the mean over the remaining dims. The default
    (2, 3, 4) assumes 5-D (batch, channel, spatial x3) tensors — TODO confirm.

    NOTE(review): the loss is applied directly to the model outputs. It is only
    meaningful if those are probabilities in [0, 1]; confirm against ``U2NET``.
    """

    def __init__(self, axis=(2, 3, 4), epsilon=0.00001):
        super().__init__()
        self.axis = axis          # dims summed per sample/channel
        self.epsilon = epsilon    # smoothing term; keeps empty masks finite

    def get_loss(self, y_pred, y_true):
        """Soft Dice loss between one prediction and the label (0-dim tensor)."""
        overlap = torch.sum(y_pred * y_true, dim=self.axis)
        pred_energy = torch.sum(y_pred ** 2, dim=self.axis)
        true_energy = torch.sum(y_true ** 2, dim=self.axis)
        ratio = (2.0 * overlap + self.epsilon) / (pred_energy + true_energy + self.epsilon)
        return 1 - torch.mean(ratio)

    def forward(self, d0, d1, d2, d3, d4, d5, d6, labels):
        """Unweighted mean of ``get_loss`` over all seven outputs against ``labels``."""
        per_output = [self.get_loss(out, labels) for out in (d0, d1, d2, d3, d4, d5, d6)]
        return sum(per_output) / len(per_output)

def dice_coefficient(y_true, y_pred, axis=(2, 3, 4), epsilon=0.00001):
    """Soft Dice coefficient ``(2 * sum(p * g) + eps) / (sum(p) + sum(g) + eps)``.

    Sums run over ``axis``, then the ratios are averaged over the remaining dims.
    No thresholding is applied, so soft predictions yield a soft score. Note the
    argument order: label first, prediction second.

    Returns:
        A 0-dim tensor.
    """
    intersection = (y_pred * y_true).sum(dim=axis)
    denom = y_pred.sum(dim=axis) + y_true.sum(dim=axis) + epsilon
    return ((2.0 * intersection + epsilon) / denom).mean()

### **Training**

In [None]:
epochs = 50
# Per-epoch history (mean loss / mean Dice), saved in every checkpoint.
hist_train_loss = []
hist_val_loss = []
hist_train_dice = []
hist_val_dice = []
# Best validation loss so far. np.inf, not the np.Inf alias removed in NumPy 2.0.
val_loss_min = np.inf

soft_dice_loss = DiceLoss()
PATH = '/content/drive/MyDrive/COVID-19-20/'   # output directory for weights and checkpoints

#### *Training from the beginning*

In [None]:
# Train for `epochs` epochs. After each epoch: validate, save the best weights
# (lowest val loss) to U2net3D_AutoEncoder_model.pth, and write a resumable checkpoint.
for epoch in range(epochs):
    train_loss = 0.0
    val_loss = 0.0
    num_train_sample = 0.0
    num_val_sample = 0.0

    # ---- train ----
    dice_sum = 0.0
    dice_count = 0
    model.train()
    for batch in tqdm(train_loader):
        optimizer.zero_grad(set_to_none=True)
        inputs = batch['image'].to(device)
        labels = batch['label'].to(device)

        d, d1, d2, d3, d4, d5, d6 = model(inputs)
        loss = soft_dice_loss(d, d1, d2, d3, d4, d5, d6, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.size(0)
        num_train_sample += inputs.size(0)
        # The Dice metric is for monitoring only. Compute it without autograd and reduce
        # it to a Python float; otherwise `dice_sum` becomes a GPU tensor that chains
        # an autograd graph across every iteration of the epoch (and ends up in the
        # history lists and checkpoint).
        with torch.no_grad():
            dice_sum += dice_coefficient(labels, d).item()
        dice_count += 1
        del d, d1, d2, d3, d4, d5, d6, loss              # del temporary outputs and loss
    train_dice = dice_sum / dice_count
    train_loss = train_loss / num_train_sample
    hist_train_loss.append(train_loss)
    hist_train_dice.append(train_dice)

    # ---- validate ----
    model.eval()
    with torch.no_grad():
        dice_sum = 0.0
        dice_count = 0
        for batch in tqdm(val_loader):
            inputs = batch['image'].to(device)
            labels = batch['label'].to(device)

            d, d1, d2, d3, d4, d5, d6 = model(inputs)
            loss = soft_dice_loss(d, d1, d2, d3, d4, d5, d6, labels)
            val_loss += loss.item() * inputs.size(0)
            num_val_sample += inputs.size(0)
            dice_sum += dice_coefficient(labels, d).item()   # plain float, not a GPU tensor
            dice_count += 1
        val_dice = dice_sum / dice_count
        val_loss = val_loss / num_val_sample
        hist_val_loss.append(val_loss)
        hist_val_dice.append(val_dice)

    print(f'\nEpoch: {epoch+1}: \nTrain loss:      {train_loss}, \tTrain Dice:      {train_dice}, \nValidation loss: {val_loss}, \tValidation Dice: {val_dice}')
    # Print the LR as a plain number (a `{...}` set literal here would print "{0.001}").
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")

    if val_loss <= val_loss_min:
        print(f'Validation loss is decreased from {val_loss_min} ---> {val_loss}.\nSaving Model ...')
        torch.save(model.state_dict(), PATH+'U2net3D_AutoEncoder_model.pth')
        val_loss_min = val_loss

    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'hist_train_loss': hist_train_loss,
                'hist_val_loss': hist_val_loss,
                'hist_train_dice': hist_train_dice,
                'hist_val_dice': hist_val_dice,
                'val_loss_min': val_loss_min},
               PATH + 'U2net3D_AutoEncoder_model_checkpoints.pt')

100%|██████████| 160/160 [39:12<00:00, 14.70s/it]
100%|██████████| 39/39 [05:17<00:00,  8.13s/it]



Epoch: 1: 
Train loss:      0.8336648002266884, 	Train Dice:      0.09252109378576279, 
Validation loss: 0.7061381317101992, 	Validation Dice: 0.19312192499637604
Learning Rate: {0.001}
Validation loss is decreased from inf ---> 0.7061381317101992.
Saving Model ...


100%|██████████| 160/160 [26:44<00:00, 10.03s/it]
100%|██████████| 39/39 [02:30<00:00,  3.86s/it]



Epoch: 2: 
Train loss:      0.737138481810689, 	Train Dice:      0.17908094823360443, 
Validation loss: 0.7256916593282651, 	Validation Dice: 0.2031329870223999
Learning Rate: {0.001}


100%|██████████| 160/160 [26:43<00:00, 10.02s/it]
100%|██████████| 39/39 [02:27<00:00,  3.79s/it]



Epoch: 3: 
Train loss:      0.7116575364023447, 	Train Dice:      0.20767350494861603, 
Validation loss: 0.8067365510341449, 	Validation Dice: 0.13234709203243256
Learning Rate: {0.001}


100%|██████████| 160/160 [26:30<00:00,  9.94s/it]
100%|██████████| 39/39 [02:30<00:00,  3.86s/it]



Epoch: 4: 
Train loss:      0.7063496949151158, 	Train Dice:      0.22521273791790009, 
Validation loss: 0.7182975373206995, 	Validation Dice: 0.22539862990379333
Learning Rate: {0.001}


100%|██████████| 160/160 [26:51<00:00, 10.07s/it]
100%|██████████| 39/39 [02:32<00:00,  3.90s/it]



Epoch: 5: 
Train loss:      0.7104199239052832, 	Train Dice:      0.22383485734462738, 
Validation loss: 0.7222568400395222, 	Validation Dice: 0.21088379621505737
Learning Rate: {0.001}


100%|██████████| 160/160 [26:57<00:00, 10.11s/it]
100%|██████████| 39/39 [02:31<00:00,  3.89s/it]



Epoch: 6: 
Train loss:      0.6646099496632815, 	Train Dice:      0.263853520154953, 
Validation loss: 0.7170091561782055, 	Validation Dice: 0.21115851402282715
Learning Rate: {0.001}


100%|██████████| 160/160 [26:52<00:00, 10.08s/it]
100%|██████████| 39/39 [02:27<00:00,  3.78s/it]



Epoch: 7: 
Train loss:      0.6810004183091223, 	Train Dice:      0.25827500224113464, 
Validation loss: 0.7519041276895083, 	Validation Dice: 0.1833687424659729
Learning Rate: {0.001}


  0%|          | 0/160 [00:00<?, ?it/s]

Epoch     7: reducing learning rate of group 0 to 1.0000e-04.


 86%|████████▋ | 138/160 [23:03<03:40, 10.03s/it]


RuntimeError: ignored

#### *Training from checkpoints*

In [None]:
# Resume from the last per-epoch checkpoint written by the training loop.
RESUME_LR = 1e-4   # learning rate to continue training with

checkpoint = torch.load(PATH + 'U2net3D_AutoEncoder_model_checkpoints.pt', map_location=device)

epoch = checkpoint['epoch'] + 1   # the checkpoint stores the index of the last completed epoch
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
# load_state_dict restores the saved lr into every param group. `optimizer.defaults`
# only applies to groups added later, so the new lr must be written to param_groups.
for param_group in optimizer.param_groups:
    param_group['lr'] = RESUME_LR
optimizer.defaults['lr'] = RESUME_LR

hist_train_loss = checkpoint['hist_train_loss']
hist_val_loss = checkpoint['hist_val_loss']
hist_train_dice = checkpoint['hist_train_dice']
hist_val_dice = checkpoint['hist_val_dice']
val_loss_min = checkpoint['val_loss_min']

In [None]:
# Continue training from the restored `epoch` up to `epochs`. After each epoch:
# validate, save the best weights (lowest val loss), and overwrite the checkpoint.
while epoch < epochs:
    train_loss = 0.0
    val_loss = 0.0
    num_train_sample = 0.0
    num_val_sample = 0.0

    # ---- train ----
    dice_sum = 0.0
    dice_count = 0
    model.train()
    for batch in tqdm(train_loader):
        optimizer.zero_grad(set_to_none=True)
        inputs = batch['image'].to(device)
        labels = batch['label'].to(device)

        d, d1, d2, d3, d4, d5, d6 = model(inputs)
        loss = soft_dice_loss(d, d1, d2, d3, d4, d5, d6, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.size(0)
        num_train_sample += inputs.size(0)
        # The Dice metric is for monitoring only. Compute it without autograd and reduce
        # it to a Python float; otherwise `dice_sum` becomes a GPU tensor that chains
        # an autograd graph across every iteration of the epoch (and ends up in the
        # history lists and checkpoint).
        with torch.no_grad():
            dice_sum += dice_coefficient(labels, d).item()
        dice_count += 1
        del d, d1, d2, d3, d4, d5, d6, loss              # del temporary outputs and loss
    train_dice = dice_sum / dice_count
    train_loss = train_loss / num_train_sample
    hist_train_loss.append(train_loss)
    hist_train_dice.append(train_dice)

    # ---- validate ----
    model.eval()
    with torch.no_grad():
        dice_sum = 0.0
        dice_count = 0
        for batch in tqdm(val_loader):
            inputs = batch['image'].to(device)
            labels = batch['label'].to(device)

            d, d1, d2, d3, d4, d5, d6 = model(inputs)
            loss = soft_dice_loss(d, d1, d2, d3, d4, d5, d6, labels)
            val_loss += loss.item() * inputs.size(0)
            num_val_sample += inputs.size(0)
            dice_sum += dice_coefficient(labels, d).item()   # plain float, not a GPU tensor
            dice_count += 1
        val_dice = dice_sum / dice_count
        val_loss = val_loss / num_val_sample
        hist_val_loss.append(val_loss)
        hist_val_dice.append(val_dice)

    print(f'\nEpoch: {epoch+1}: \nTrain loss:      {train_loss}, \tTrain Dice:      {train_dice}, \nValidation loss: {val_loss}, \tValidation Dice: {val_dice}')
    # Print the LR as a plain number (a `{...}` set literal here would print "{0.001}").
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")

    if val_loss <= val_loss_min:
        print(f'Validation loss is decreased from {val_loss_min} ---> {val_loss}.\nSaving Model ...')
        torch.save(model.state_dict(), PATH+'U2net3D_AutoEncoder_model.pth')
        val_loss_min = val_loss

    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'hist_train_loss': hist_train_loss,
                'hist_val_loss': hist_val_loss,
                'hist_train_dice': hist_train_dice,
                'hist_val_dice': hist_val_dice,
                'val_loss_min': val_loss_min},
               PATH + 'U2net3D_AutoEncoder_model_checkpoints.pt')

    epoch += 1

100%|██████████| 160/160 [25:50<00:00,  9.69s/it]
100%|██████████| 39/39 [02:27<00:00,  3.79s/it]



Epoch: 10: 
Train loss:      0.6596331609413028, 	Train Dice:      0.28460660576820374, 
Validation loss: 0.8296552575551547, 	Validation Dice: 0.13870349526405334
Learning Rate: {0.001}


 68%|██████▊   | 108/160 [17:58<09:19, 10.76s/it]