In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the Kaggle input directory and print the full path of every file.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for name in files:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# File names of the CT volumes, sorted immediately. os.listdir returns entries in
# arbitrary order, and later cells pair arr[i] with arr1[i] by index, so the list
# must not depend on a separate cell having been run to sort it.
arr = sorted(os.listdir('../input/covid19-ct-scans/ct_scans'))

In [None]:
# Sort the CT file names in place, then show them one per line.
arr.sort()
for ct_name in arr:
    print(ct_name)

In [None]:
# File names of the infection masks, in sorted order.
arr1 = sorted(os.listdir('../input/covid19-ct-scans/infection_mask'))

In [None]:

# Show the infection-mask file names one per line.
for mask_name in arr1:
    print(mask_name)

In [None]:
import nibabel as nib
import os
import torch
import torchvision.transforms as T

# Data original source downloaded from : https://zenodo.org/record/3757476#.Xp0FhB9fgUE
# Same data in Kaggle  : https://www.kaggle.com/andrewmvd/covid19-ct-scans

# Folders holding the CT volumes and the matching infection masks
data_path, mask_path = "../input/covid19-ct-scans/ct_scans/", "../input/covid19-ct-scans/infection_mask/"

# Lower / upper bound of the lung window, in Hounsfield units
HU_min, HU_max = -1000, 400

# names of files in data folder
# arr = os.listdir(data_path)

# Indices (into the sorted file listings) of the volumes used for training.
# NOTE(review): the original comment attributed these indices to coronacases.org —
# confirm against the sorted listing printed above.
train_data_numbers = range(16, 19)
# val_data_numbers = range(5, 6)  # data from radiopedia.org

# Resize transform for image and mask, set to the 572 x 572 model input dimension
resize_image = T.Resize(size=(572, 572))


# function to preprocess data
def data_preprocess(path, hu_min, hu_max):
    """Load a volume with nibabel and rescale it to [0, 1] within a HU window.

    Args:
        path: Location of the volume file handed to ``nib.load``.
        hu_min: Lower clamp value (Hounsfield units).
        hu_max: Upper clamp value (Hounsfield units).

    Returns:
        torch.Tensor: The volume clamped to ``[hu_min, hu_max]`` and mapped
        linearly so that ``hu_min`` becomes 0 and ``hu_max`` becomes 1.
    """
    voxels = torch.tensor(nib.load(path).get_fdata())  # volume as a torch tensor
    windowed = torch.clamp(voxels, min=hu_min, max=hu_max)  # apply the lung window
    return (windowed - hu_min) / (hu_max - hu_min)


# function to obtain mask
def mask_obtain(fpath):
    """Load a mask volume with nibabel and return its voxel array as a torch tensor.

    Args:
        fpath: Location of the mask file handed to ``nib.load``.

    Returns:
        torch.Tensor: The mask values, unmodified.
    """
    return torch.tensor(nib.load(fpath).get_fdata())


# add zero padding if size of image less than required input size
def padding_size(slices, target_size=572):
    """Compute the padding that centres each slice on a square canvas.

    Args:
        slices: Tensor whose dimensions 1 and 2 are the two spatial axes
            (the caller passes a stack of slices permuted to (C, H, W)).
        target_size: Side length of the padded result. Defaults to 572, the
            value previously hard-coded in this function.

    Returns:
        list[int]: ``[pad4, pad3, pad2, pad1]`` where ``pad3``/``pad4`` belong to
        dimension 2 and ``pad1``/``pad2`` to dimension 1. This is the order in
        which the caller hands the list to ``torch.nn.ConstantPad2d``. When the
        missing amount is odd, ``pad4`` / ``pad2`` receive the extra pixel.
    """
    def _split(extent):
        # Split the missing amount into a smaller and a larger half; the two
        # are equal when the difference is even and differ by one when it is odd.
        missing = target_size - extent
        smaller = missing // 2
        return smaller, missing - smaller

    pad1, pad2 = _split(slices.size()[1])
    pad3, pad4 = _split(slices.size()[2])

    return [pad4, pad3, pad2, pad1]

In [None]:
# for i in range(len(arr)) :
#     print (arr[i])

In [None]:
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# training data
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Collect the padded volumes and masks in lists and concatenate once at the end.
# Seeding the stack with an uninitialised torch.empty slice and calling torch.cat
# inside the loop re-copied the whole accumulated stack on every iteration.
train_data_parts = []
train_label_parts = []

# Form train data and train label, one volume at a time
for i in train_data_numbers:
    file_path = data_path + arr[i]  # path of the data
    data = data_preprocess(file_path, HU_min, HU_max)  # preprocess data
    # (H,W,C) ---> (C,H,W), since ConstantPad2d pads the last two dimensions
    data = data.permute(2, 0, 1)
    P = padding_size(data)  # required padding sizes
    data = torch.nn.ConstantPad2d(P, 0)(data)  # zero-pad the slices
    train_data_parts.append(data)

    file_path_mask = mask_path + arr1[i]  # path to the mask
    label = mask_obtain(file_path_mask)  # obtain the mask
    # The mask must end up the same size as the padded data, so it gets the same padding
    label = label.permute(2, 0, 1)
    label = torch.nn.ConstantPad2d(P, 0)(label)
    train_label_parts.append(label)

# Stack all slices / masks along dimension C
train_data = torch.cat(train_data_parts, 0)
train_label = torch.cat(train_label_parts, 0)
# arr[i] and arr1[i] are paired purely by index; a mismatch would show up here
assert train_data.size()[0] == train_label.size()[0], "slice / mask count mismatch"

# Indices of slices whose mask maximum equals 1, i.e. masks with a segmented region
idx = [i for i in range(train_label.size()[0]) if torch.max(train_label[i, :, :]) == 1]

# Keep only the slices selected above
train_data_new = train_data[idx, :, :]
train_label_new = train_label[idx, :, :]
# (C,H,W) ---> (H,W,C) since the Dataset class indexes slices on the last dimension
train_data_new = train_data_new.permute(1, 2, 0)
train_label_new = train_label_new.permute(1, 2, 0)

In [None]:
# Explicit paths to one CT volume and one infection-mask file.
# NOTE(review): neither name is referenced elsewhere in the visible notebook —
# confirm whether this cell is still needed.
data_path1 = '../input/covid19-ct-scans/ct_scans/coronacases_org_002.nii'
mask_path1 = '../input/covid19-ct-scans/infection_mask/coronacases_002.nii'

In [None]:
# # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# # # validation data
# # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# # Initialize for stacking
# val_data = torch.empty((1, 572, 572))
# val_label = torch.empty((1, 572, 572))

# # Loop to form validation data and validation label
# for i in val_data_numbers:
#     file_path = data_path +arr[i]  # Path to validation data
#     data = data_preprocess(file_path, HU_min, HU_max)  # Preprocess the data
#     data = data.permute(2, 0,
#                         1)  # change the dimension (H,W,C) ---> (C,H,W) , since ConstantPad2d works with this config
#     P = padding_size(data)  # Obtain padding sizes
#     data = torch.nn.ConstantPad2d(P, 0)(data)  # zero pad the slices
#     val_data = torch.cat((val_data, data), 0)  # stack the data along dimension C

#     file_path_mask = mask_path +arr1[i]
#     label = mask_obtain(file_path_mask)  # Path to mask
#     label = label.permute(2, 0, 1)  # change the dimension (H,W,C) ---> (C,H,W) , since ConstantPad2d works with this config
#     label = torch.nn.ConstantPad2d(P, 0)(label)  # zero pad mask
#     val_label = torch.cat((val_label, label), 0)  # stack the masks along dimension C

# # remove the empty
# val_data = val_data[1:val_data.size()[0], :, :]
# val_label = val_label[1:val_label.size()[0], :, :]

# # Determine which slices are all black
# idx = []
# for i in range(val_label.size()[0]):
#     img_max = torch.max(val_label[i, :, :])
#     if img_max == 1:
#         idx.append(i)  # having white regions

# # Choose data without completely black mask, i.e, having atleast some white segmented region
# val_data_new = val_data[idx, :, :]
# val_label_new = val_label[idx, :, :]

# # (C,H,W) ---> (H,W,C) since Dataset class has this config (this part is not necessary if we change the config of
# # Dataset class)
# val_data_new = val_data_new.permute(1, 2, 0)
# val_label_new = val_label_new.permute(1, 2, 0)

In [None]:
import torch.nn.functional as F
# from torchvision.transforms.functional import sigmoid

def DCE(inputs, targets, smooth=1):
    """Soft Dice loss computed from raw logits.

    Args:
        inputs: Logits; a sigmoid is applied inside this function.
        targets: Ground-truth tensor with the same number of elements as ``inputs``.
        smooth: Constant added to numerator and denominator.

    Returns:
        torch.Tensor: Scalar ``1 - dice``.
    """
    # reshape (not view) so that non-contiguous tensors, e.g. masks produced by
    # slicing/cropping, can be flattened without raising.
    probs = torch.sigmoid(inputs).reshape(-1)
    truth = targets.reshape(-1)

    intersection = (probs * truth).sum()
    dice = (2. * intersection + smooth) / (probs.sum() + truth.sum() + smooth)
    return 1 - dice


def GDCE(inputs, targets, smooth=1):
    """Class-weighted (generalized) Dice loss computed from raw logits.

    Elements are split into two groups by target value (0 and 1). Each group is
    weighted by the inverse square of its element count.

    Args:
        inputs: Logits; a sigmoid is applied inside this function.
        targets: Ground-truth tensor with the same number of elements as ``inputs``.
        smooth: Constant added to numerator and denominator.

    Returns:
        torch.Tensor: Scalar loss.
    """
    # torch.sigmoid instead of F.sigmoid: the alias ``F`` is bound to different
    # libraries in different cells of this notebook, so it is not reliable here.
    # reshape (not view) so non-contiguous tensors can be flattened.
    probs = torch.sigmoid(inputs).reshape(-1)
    truth = targets.reshape(-1)

    # targets and corresponding predictions of class 1 (target value 0)
    idx1 = (truth == 0)
    T1 = truth[idx1]
    P1 = probs[idx1]

    # targets and corresponding predictions of class 2 (target value 1)
    idx2 = (truth == 1)
    T2 = truth[idx2]
    P2 = probs[idx2]

    # Weights for each class; a class with no elements contributes nothing
    # instead of raising ZeroDivisionError.
    W1 = 1 / (len(T1) * len(T1)) if len(T1) > 0 else 0.0
    W2 = 1 / (len(T2) * len(T2)) if len(T2) > 0 else 0.0

    # Numerator and denominator of generalized dice loss
    NR = W1 * (T1 * P1).sum() + W2 * (T2 * P2).sum()
    DR = W1 * (T1 + P1).sum() + W2 * (T2 + P2).sum()

    return 1 - (((2 * NR) + smooth) / (DR + smooth))

def FocalTverskyLoss(inputs, targets, smooth=1, alpha=0.7, beta=0.3, gamma=(4/3)):
    """Focal Tversky loss computed from raw logits.

    Args:
        inputs: Logits; a sigmoid is applied inside this function (remove it if
            the model already ends with a sigmoid or equivalent activation).
        targets: Ground-truth tensor with the same number of elements as ``inputs``.
        smooth: Constant added to numerator and denominator.
        alpha: Weight of the false-positive term.
        beta: Weight of the false-negative term.
        gamma: Exponent applied to ``1 - Tversky``.

    Returns:
        torch.Tensor: Scalar loss.
    """
    # reshape (not view) so that non-contiguous tensors, e.g. masks produced by
    # slicing/cropping, can be flattened without raising.
    probs = torch.sigmoid(inputs).reshape(-1)
    truth = targets.reshape(-1)

    # True Positives, False Positives & False Negatives
    TP = (probs * truth).sum()
    FP = ((1 - truth) * probs).sum()
    FN = (truth * (1 - probs)).sum()

    Tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)
    # Clamp at zero: rounding can push the base slightly below zero, and a
    # negative base raised to a fractional gamma would produce NaN.
    FocalTversky = torch.clamp(1 - Tversky, min=0) ** gamma

    return FocalTversky

In [None]:
import torch
import torch.nn as nn
from torch.nn import Module

# Use the first CUDA GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


# Two sequential convolution section of Unet
def doubleconv(inp, out):
    """Build two consecutive Conv(3x3, no padding) -> BatchNorm -> ReLU stages.

    Args:
        inp: Number of input channels of the first convolution.
        out: Number of output channels of both convolutions.

    Returns:
        nn.Sequential: Six layers; the second convolution maps ``out`` -> ``out``.
    """
    stages = []
    for in_channels in (inp, out):
        stages += [
            nn.Conv2d(in_channels, out, kernel_size=3),
            # track_running_stats=False: no running mean/variance is kept
            nn.BatchNorm2d(out, track_running_stats=False),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*stages)


# Crop the encoder feature to the size of corresponding decoder for concatenation
def crop_feat1(input_tensor, target_tensor):
    """Centre-crop ``input_tensor`` to the spatial size of ``target_tensor``.

    Height (dim 2) and width (dim 3) are handled independently, so non-square
    feature maps are cropped correctly. For square inputs the result is the
    same as cropping both axes by the dim-2 difference.

    Args:
        input_tensor: (N, C, H, W) feature map to crop.
        target_tensor: (N, C, h, w) tensor providing the target height and width.

    Returns:
        torch.Tensor: View of ``input_tensor`` with spatial size (h, w). When a
        size difference is odd, the extra row/column is dropped at the end.
    """
    out_h, out_w = target_tensor.size()[2], target_tensor.size()[3]
    top = (input_tensor.size()[2] - out_h) // 2
    left = (input_tensor.size()[3] - out_w) // 2
    return input_tensor[:, :, top:top + out_h, left:left + out_w]


# Spatial Channel Attention Block
class sca(Module):
    """Channel attention followed by spatial attention.

    Args:
        inp: Number of channels of the feature map. The channel-attention
            bottleneck uses ``inp // 16`` channels.
    """

    def __init__(self, inp):
        super().__init__()
        squeezed = inp // 16
        # Shared 1x1-conv bottleneck applied to both pooled descriptors
        self.c_attn_conv = nn.Sequential(
            nn.Conv2d(inp, squeezed, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(squeezed, inp, 1, bias=False),
        )
        self.c_sig = nn.Sigmoid()

        self.avg_ch = nn.AdaptiveAvgPool2d(1)
        self.max_ch = nn.AdaptiveMaxPool2d(1)
        # 7x7 convolution over the stacked [mean, max] maps -> one gate map
        self.s_attn_conv = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input_tensor):
        """Return ``input_tensor`` scaled by a channel gate and then a spatial gate."""
        # Channel attention: gate = sigmoid(bottleneck(avg) + bottleneck(max))
        channel_logits = self.c_attn_conv(self.avg_ch(input_tensor)) \
            + self.c_attn_conv(self.max_ch(input_tensor))
        reweighted = input_tensor * self.c_sig(channel_logits)

        # Spatial attention: per-pixel mean and max across the channel dimension
        mean_map = reweighted.mean(dim=1, keepdim=True)
        max_map = reweighted.max(dim=1, keepdim=True).values
        spatial_gate = self.s_attn_conv(torch.cat([mean_map, max_map], dim=1))

        return torch.mul(reweighted, spatial_gate)


# Function to apply spatial channel attention block
def spatial_channel_attn(input_tensor):
    """Apply a spatial-channel attention block to ``input_tensor``.

    The tensor is cast to float32 on the module-level ``device``. The previous
    hard-coded ``torch.cuda.FloatTensor`` cast raised on CPU-only machines even
    though ``device`` falls back to the CPU.

    NOTE(review): a brand-new ``sca`` module is created on every call, so its
    weights are freshly initialised each time and are not part of any
    ``parameters()`` collection handed to an optimizer. Callers that need a
    trainable attention block should own an ``sca`` instance instead.

    Args:
        input_tensor: (N, C, H, W) feature map.

    Returns:
        torch.Tensor: Attention-weighted feature map of the same shape.
    """
    input_tensor = input_tensor.to(device=device, dtype=torch.float32)
    inp = input_tensor.size()[1]  # channel count decides the block's width

    sca_model = sca(inp).to(device)
    return sca_model(input_tensor)


# Atrous spatial pyramid pooling block
class aspp(Module):
    """Atrous spatial pyramid pooling with five dilated 3x3 branches.

    Args:
        inp: Channels of the input; also the channels of the output.
        out: Channels produced by each dilated branch.
    """

    def __init__(self, inp, out):
        super().__init__()
        # Dilation rates 4, 6, 12, 18, 24; padding equals the rate so the
        # spatial size is preserved.
        self.aconv0 = self._atrous_branch(inp, out, 4)
        self.aconv1 = self._atrous_branch(inp, out, 6)
        self.aconv2 = self._atrous_branch(inp, out, 12)
        self.aconv3 = self._atrous_branch(inp, out, 18)
        self.aconv4 = self._atrous_branch(inp, out, 24)
        # 1x1 fusion of the concatenated branches back to `inp` channels
        self.final_conv = nn.Sequential(
            nn.Conv2d(out * 5, inp, kernel_size=1, stride=1, padding=0, dilation=1, bias=False),
            nn.BatchNorm2d(inp, track_running_stats=False),
            nn.ReLU(inplace=True))

    @staticmethod
    def _atrous_branch(inp, out, rate):
        """One dilated Conv(3x3) -> BatchNorm -> ReLU branch."""
        return nn.Sequential(
            nn.Conv2d(inp, out, kernel_size=3, stride=1, padding=rate, dilation=rate, bias=False),
            nn.BatchNorm2d(out, track_running_stats=False),
            nn.ReLU(inplace=True))

    def forward(self, input_tensor):
        """Run all branches on the same input, concatenate on channels and fuse."""
        branches = (self.aconv0, self.aconv1, self.aconv2, self.aconv3, self.aconv4)
        pyramid = torch.cat([branch(input_tensor) for branch in branches], dim=1)
        return self.final_conv(pyramid)


# Function to carry out atrous spatial pyramid pooling (In our case input and output is made to be of same size)
def atrous_spatial_pyramid_pooling(input_tensor):
    """Apply an ASPP block with a residual connection (output size equals input size).

    The tensor is cast to float32 on the module-level ``device``. The previous
    hard-coded ``torch.cuda.FloatTensor`` cast raised on CPU-only machines even
    though ``device`` falls back to the CPU.

    NOTE(review): a brand-new ``aspp`` module is created on every call, so its
    weights are freshly initialised each time and are not part of any
    ``parameters()`` collection handed to an optimizer. Callers that need a
    trainable block should own an ``aspp`` instance instead.

    Args:
        input_tensor: (N, C, H, W) feature map.

    Returns:
        torch.Tensor: ``aspp(input) + input``, same shape as the input.
    """
    input_tensor = input_tensor.to(device=device, dtype=torch.float32)
    inp = input_tensor.size()[1]
    out = inp // 4  # each of the five branches produces a quarter of the channels

    asppmodel = aspp(inp, out).to(device)
    return asppmodel(input_tensor) + input_tensor


# Unet Model
class Unet(Module):
    """U-Net with ASPP, spatial-channel attention and deep-supervision heads.

    The encoder/decoder use unpadded 3x3 convolutions, so the output is
    spatially smaller than the input and encoder features are centre-cropped
    before concatenation.

    The ASPP and attention blocks are registered sub-modules. Previously they
    were re-created with random weights inside every forward pass, so they were
    never returned by ``parameters()``, never optimised and never saved in
    ``state_dict()``.
    """

    def __init__(self):
        super(Unet, self).__init__()

        self.maxpool2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.DownConv1 = doubleconv(1, 64)
        self.DownConv2 = doubleconv(64, 128)
        self.DownConv3 = doubleconv(128, 256)
        self.DownConv4 = doubleconv(256, 512)
        self.DownConv5 = doubleconv(512, 1024)

        self.UpTrans1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.UpTrans2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.UpTrans3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.UpTrans4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

        self.UpConv1 = doubleconv(1024, 512)
        self.UpConv2 = doubleconv(512, 256)
        self.UpConv3 = doubleconv(256, 128)
        self.UpConv4 = doubleconv(128, 64)

        # 1x1 heads: out1 is the final prediction, out2..out4 are the
        # deep-supervision outputs taken from earlier decoder stages
        self.out1 = nn.Conv2d(64, 1, kernel_size=1)
        self.out2 = nn.Conv2d(128, 1, kernel_size=1)
        self.out3 = nn.Conv2d(256, 1, kernel_size=1)
        self.out4 = nn.Conv2d(512, 1, kernel_size=1)

        # ASPP blocks (each branch outputs a quarter of the input channels)
        self.aspp_bottleneck = aspp(1024, 1024 // 4)
        self.aspp_final = aspp(64, 64 // 4)

        # One attention block per decoder stage, sized to the tensor it gates
        self.attn1 = sca(1024)
        self.attn2 = sca(512)
        self.attn3 = sca(256)
        self.attn4 = sca(128)

    @staticmethod
    def _residual_aspp(block, feat):
        """Return ``block(feat) + feat`` with ``feat`` cast to float32 first."""
        feat = feat.float()
        return block(feat) + feat

    def forward(self, x):
        """Run the network.

        Args:
            x: Input of shape (batch size, channel, height, width) with one channel.

        Returns:
            tuple: ``(out1, out2, out3, out4)`` single-channel logit maps. ``out1``
            comes from the last decoder stage (after ASPP); ``out2``, ``out3`` and
            ``out4`` come from progressively earlier decoder stages.
        """
        # encoder
        x1 = self.DownConv1(x)
        x2 = self.maxpool2x2(x1)
        x3 = self.DownConv2(x2)
        x4 = self.maxpool2x2(x3)
        x5 = self.DownConv3(x4)
        x6 = self.maxpool2x2(x5)
        x7 = self.DownConv4(x6)
        x8 = self.maxpool2x2(x7)
        x9 = self.DownConv5(x8)

        # ASPP on the bottleneck
        x9_aspp = self._residual_aspp(self.aspp_bottleneck, x9)

        # decoder: attention -> upsample -> crop skip -> concat -> double conv
        z1 = self.UpTrans1(self.attn1(x9_aspp.float()))
        y1 = crop_feat1(x7, z1)
        x10 = self.UpConv1(torch.cat([y1, z1], 1))

        z2 = self.UpTrans2(self.attn2(x10.float()))
        y2 = crop_feat1(x5, z2)
        x11 = self.UpConv2(torch.cat([y2, z2], 1))

        z3 = self.UpTrans3(self.attn3(x11.float()))
        y3 = crop_feat1(x3, z3)
        x12 = self.UpConv3(torch.cat([y3, z3], 1))

        z4 = self.UpTrans4(self.attn4(x12.float()))
        y4 = crop_feat1(x1, z4)
        x13 = self.UpConv4(torch.cat([y4, z4], 1))

        # ASPP on the last decoder stage
        x13_aspp = self._residual_aspp(self.aspp_final, x13)
        out1 = self.out1(x13_aspp)
        out2 = self.out2(x12)
        out3 = self.out3(x11)
        out4 = self.out4(x10)
        return out1, out2, out3, out4

# %%%%%%% model  check %%%%%%%%%

# if __name__ == "__main__":
#     image = torch.rand((2, 1, 512, 512))
#     model = Unet()
#     model.to(device)
#     image=image.to(device)
#     y = model(image)

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms.functional as F
# from data_preparation import train_data_new, train_label_new, val_data_new, val_label_new
# from model_all import Unet
from torch.optim import Adam
import torchvision.transforms as T
import warnings
from torch.cuda import amp
from tqdm import tqdm
# from loss import DCE
import matplotlib.pyplot as plt

# Silence every Python warning for the rest of the session (this is global,
# so it also hides warnings raised by the cells that follow).
warnings.filterwarnings("ignore")

# Gradient scaler used by the training loop below for mixed precision training
scaler = amp.GradScaler()

# Use the first CUDA GPU when CUDA is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Function to crop mask to the size of output of U-net
def crop_feat(inp, out_size=388):
    """Centre-crop the last two dimensions of a 3-D tensor to ``out_size``.

    Args:
        inp: Tensor of shape (N, H, W); here H = W = 572 as decided during
            data preparation.
        out_size: Side length of the crop. Defaults to 388, the value
            previously hard-coded here.

    Returns:
        torch.Tensor: View of shape (N, out_size, out_size). The crop is exactly
        ``out_size`` wide even when ``H - out_size`` is odd; the earlier slicing
        returned one extra row/column in that case.
    """
    top = (inp.size()[1] - out_size) // 2
    left = (inp.size()[2] - out_size) // 2
    return inp[:, top:top + out_size, left:left + out_size]


# Function to augment data (random flips and 90-degree rotations, applied identically to image and mask)
def data_augmentation(image, mask):
    """Apply the same random flips / 90-degree rotation to an image and its mask.

    Each of the three augmentations is applied with probability 0.5. Only
    ``torch`` functions are used, so the function does not depend on the
    module alias ``F``, which this notebook binds to two different libraries
    in different cells.

    Args:
        image: 2-D (H, W) tensor.
        mask: 2-D (H, W) tensor aligned with ``image``.

    Returns:
        tuple: ``(image, mask)`` after augmentation.
    """
    # Horizontal flip (reverse the last dimension)
    if torch.rand(1) > 0.5:
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    # Vertical flip (reverse the second-to-last dimension)
    if torch.rand(1) > 0.5:
        image = torch.flip(image, dims=[-2])
        mask = torch.flip(mask, dims=[-2])
    # Rotate by +90 or -90 degrees, direction chosen with equal probability
    if torch.rand(1) > 0.5:
        turns = 1 if torch.rand(1) > 0.5 else -1
        image = torch.rot90(image, turns, [0, 1])
        mask = torch.rot90(mask, turns, [0, 1])
    return image, mask


# Dataset class
class CovidSegData(Dataset):
    """Dataset yielding (slice, mask) pairs from (H, W, C) stacked tensors.

    Args:
        data: Tensor of slices in (H, W, C) layout.
        labels: Tensor of masks in (H, W, C) layout.
        transform: Truthy to enable random augmentation, falsy (``None`` or
            ``False``) to return the slices untouched.
    """

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.transform = transform
        self.labels = labels

    def __len__(self):
        # Data is in (H,W,C) format, so the number of samples is the last dimension
        return self.labels.size()[2]

    def __getitem__(self, item):
        # Select one slice and its mask
        slices = self.data[:, :, item]
        masks = self.labels[:, :, item]
        # Truthiness check: `transform=False` must disable augmentation. The
        # earlier `is not None` test augmented for False as well.
        if self.transform:
            slices, masks = data_augmentation(slices, masks)
        return slices, masks


# Loss function
def loss_function(preds1, preds2, preds3, preds4, GT):
    """Deep-supervision loss: weighted sum of four Focal Tversky terms.

    Args:
        preds1: Final prediction; compared against ``GT`` as is (``gamma=1``).
        preds2, preds3, preds4: Intermediate predictions; ``GT`` is resized to
            each one's height and width before comparison.
        GT: Ground truth, already cropped to the size of ``preds1``.

    Returns:
        torch.Tensor: ``0.5*L1 + 0.2*L2 + 0.2*L3 + 0.1*L4``.
    """
    def _gt_like(pred):
        # Ground truth resized to the spatial size of `pred`
        return T.Resize(size=(pred.size()[2], pred.size()[3]))(GT)

    D1f = FocalTverskyLoss(preds1, GT, gamma=1)
    D2f, D3f, D4f = (FocalTverskyLoss(pred, _gt_like(pred)) for pred in (preds2, preds3, preds4))

    # weighted loss -- deep supervision
    return (0.5 * D1f) + (0.2 * D2f) + (0.2 * D3f) + (0.1 * D4f)


# Load data and dataloader for training (the validation objects are still disabled)
train_dataset = CovidSegData(train_data_new, train_label_new, transform=True)
train_loader = DataLoader(train_dataset, batch_size=2, num_workers=2, shuffle=True, pin_memory=True)
# val_dataset = CovidSegData(val_data_new, val_label_new, transform=False)
# val_loader = DataLoader(val_dataset, batch_size=2, num_workers=2, shuffle=False, pin_memory=True)

# Initialize network
# (alternative: torch.load('/media/hp/DATA/CT scan work/Covid19 CTseg (Seg ICASSP)/model_all_rad_100epochs.pth'))
model = Unet()
model.to(device)  # Move the model to the selected device
# Optimizer and Scheduler
optimizer = Adam(model.parameters(), lr=2e-4, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
num_epochs = 50  # number of epochs
for epoch in range(num_epochs):
    model.train()  # model in training mode
    # Accumulators are reset once per epoch. Resetting them inside the batch
    # loop made `epoch_loss` equal to the loss of the last batch only.
    dataset_size = 0
    running_loss = 0.0
    loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=True)  # for progress bar display
    for batch_idx, (images, GT1) in loop:
        # Crop mask to the size of output of Unet, i.e. 388 in this case
        GT = crop_feat(GT1)
        # Move data to the selected device
        images = images.to(device)
        GT = GT.to(device)

        batch_size = images.size()[0]  # Size of each batch
        optimizer.zero_grad()  # zeroing gradients
        images = torch.unsqueeze(images, 1)  # add the channel dimension
        # float32 on whatever device the tensor is already on (the previous
        # torch.cuda.FloatTensor cast required a GPU)
        images = images.float()
        with amp.autocast():  # forward part with autocasting -- mixed precision training (MPT)
            preds1, preds2, preds3, preds4 = model(images)  # predictions
            loss = loss_function(preds1, preds2, preds3, preds4, GT)  # loss
        scaler.scale(loss).backward()  # scales loss and create scaled gradients for MPT
        # unscale the gradients of the optimizer assigned params, skips optimizer.step if Nan or Inf present
        scaler.step(optimizer)
        scaler.update()  # update scale for next iteration
        # NOTE(review): the scheduler is stepped per batch, so T_max=10 counts batches, not epochs — confirm intent
        scheduler.step()

        # Epoch loss calculation (running mean over the samples seen this epoch)
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        # Update progress bar
        loop.set_description(f"Epoch : [{epoch}/{num_epochs}]")
        loop.set_postfix(loss=loss.item(), epoch_loss=epoch_loss)

    # Visualization of the last batch: the outputs of the network are passed
    # through a sigmoid and thresholded at 0.5 to make them binary
    print(epoch_loss)
    # Final output of Unet
    z = torch.sigmoid(preds1) > 0.5
    z1 = z[0].detach().cpu()
    z2 = z1.squeeze()
    # Intermediate output 1
    t1 = torch.sigmoid(preds2) > 0.5
    t2 = t1[0].detach().cpu().squeeze()
    # Intermediate output 2
    a1 = torch.sigmoid(preds3) > 0.5
    a2 = a1[0].detach().cpu().squeeze()
    # Intermediate output 3
    b1 = torch.sigmoid(preds4) > 0.5
    b2 = b1[0].detach().cpu().squeeze()
    # Input slice, cropped like the mask
    y1 = images[0].detach().cpu()
    y = crop_feat(y1).squeeze()
    # Mask - ground truth
    p1 = GT[0].detach().cpu().squeeze()

    # Figure with input slice, mask and all outputs
    f, axarr = plt.subplots(1, 6)
    axarr[0].imshow(y, cmap='gray')
    axarr[1].imshow(p1, cmap='gray')
    axarr[2].imshow(z2, cmap='gray')
    axarr[3].imshow(t2, cmap='gray')
    axarr[4].imshow(a2, cmap='gray')
    axarr[5].imshow(b2, cmap='gray')
    plt.show()

In [None]:
# Save only the model's parameters (its state_dict) into the current directory.
torch.save(model.state_dict(),'./weights5.pth')

In [None]:
import torch.nn as nn
import torch

class conv2DBatchNormRelu(nn.Module):
    """Conv2d -> (optional BatchNorm2d) -> ReLU wrapped in ``self.cbr_unit``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias, dilation:
            Forwarded to ``nn.Conv2d``.
        is_batchnorm: When true, a ``BatchNorm2d`` sits between the convolution
            and the ReLU; otherwise the ReLU follows the convolution directly.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 bias=True, dilation=1, is_batchnorm=True):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                      padding=padding, bias=bias, dilation=dilation),
        ]
        if is_batchnorm:
            stages.append(nn.BatchNorm2d(out_channels))
        stages.append(nn.ReLU(inplace=True))
        self.cbr_unit = nn.Sequential(*stages)

    def forward(self, inputs):
        """Apply the conv / norm / activation stack."""
        return self.cbr_unit(inputs)

class segnetDown2(nn.Module):
    """Encoder stage: two conv-BN-ReLU units followed by 2x2 max pooling with indices."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, **conv_cfg)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, **conv_cfg)
        self.maxpool_with_argmax = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, inputs):
        """Return ``(pooled, pooling indices, size of the feature map before pooling)``."""
        features = self.conv2(self.conv1(inputs))
        pooled, indices = self.maxpool_with_argmax(features)
        return pooled, indices, features.size()

class segnetDown3(nn.Module):
    """Encoder stage: three conv-BN-ReLU units followed by 2x2 max pooling with indices."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, **conv_cfg)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, **conv_cfg)
        self.conv3 = conv2DBatchNormRelu(out_channels, out_channels, **conv_cfg)
        self.maxpool_with_argmax = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, inputs):
        """Return ``(pooled, pooling indices, size of the feature map before pooling)``."""
        features = self.conv3(self.conv2(self.conv1(inputs)))
        pooled, indices = self.maxpool_with_argmax(features)
        return pooled, indices, features.size()


class segnetUp2(nn.Module):
    """Decoder stage: max-unpooling followed by two conv-BN-ReLU units."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, **conv_cfg)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, **conv_cfg)

    def forward(self, inputs, indices, output_shape):
        """Unpool ``inputs`` to ``output_shape`` using ``indices``, then convolve."""
        restored = self.unpool(inputs, indices=indices, output_size=output_shape)
        return self.conv2(self.conv1(restored))

class segnetUp3(nn.Module):
    """Decoder stage: max-unpooling followed by three conv-BN-ReLU units."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, **conv_cfg)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, **conv_cfg)
        self.conv3 = conv2DBatchNormRelu(out_channels, out_channels, **conv_cfg)

    def forward(self, inputs, indices, output_shape):
        """Unpool ``inputs`` to ``output_shape`` using ``indices``, then convolve."""
        restored = self.unpool(inputs, indices=indices, output_size=output_shape)
        return self.conv3(self.conv2(self.conv1(restored)))

class segnet(nn.Module):
    """SegNet-style encoder/decoder with auxiliary heads on intermediate decoder stages.

    Args:
        in_channels: Channels of the input image. This value is now passed to
            the first encoder stage; previously it was ignored and the stage
            was hard-coded to one channel.
        num_classes: Channels of the main output produced by ``finconv``.

    NOTE(review): every head is a ``conv2DBatchNormRelu`` and therefore ends in a
    ReLU, so the returned maps are non-negative. The training code treats them
    as logits (sigmoid, then threshold at 0.5) — confirm that this is intended.
    """

    def __init__(self,in_channels=1,num_classes=1):
        super(segnet,self).__init__()
        self.down1=segnetDown2(in_channels=in_channels,out_channels=64)
        self.down2=segnetDown2(64,128)
        self.down3=segnetDown3(128,256)
        self.down4=segnetDown3(256,512)
        self.down5=segnetDown3(512,1024)

        self.up5=segnetUp3(1024,512)
        self.up4=segnetUp3(512,256)
        self.up3=segnetUp3(256,128)
        self.up2=segnetUp2(128,64)
        self.up1=segnetUp2(64,64)

        # Main head on the last decoder stage
        self.finconv=conv2DBatchNormRelu(64,num_classes,3,1,1)
        # Auxiliary single-channel heads on the earlier decoder stages
        self.out1 = conv2DBatchNormRelu(64, 1, 3,1,1)
        self.out2 = conv2DBatchNormRelu(128, 1, 3,1,1)
        self.out3 = conv2DBatchNormRelu(256, 1,3,1,1)
        self.out4 = conv2DBatchNormRelu(512, 1, 3,1,1)

    def forward(self,inputs):
        """Run the network.

        Returns:
            tuple: FIVE tensors ``(outputs, out1, out2, out3, out4)``. ``outputs``
            is computed from ``up1`` (the last decoder stage); ``out1`` .. ``out4``
            are computed from ``up2``, ``up3``, ``up4`` and ``up5`` respectively.
        """
        # Encoder: each stage also returns pooling indices and the pre-pooling size
        down1,indices_1,unpool_shape1=self.down1(inputs)
        down2,indices_2,unpool_shape2=self.down2(down1)
        down3,indices_3,unpool_shape3=self.down3(down2)
        down4,indices_4,unpool_shape4=self.down4(down3)
        down5,indices_5,unpool_shape5=self.down5(down4)

        # Decoder: unpool with the matching encoder indices / sizes
        up5=self.up5(down5,indices=indices_5,output_shape=unpool_shape5)
        up4=self.up4(up5,indices=indices_4,output_shape=unpool_shape4)
        up3=self.up3(up4,indices=indices_3,output_shape=unpool_shape3)
        up2=self.up2(up3,indices=indices_2,output_shape=unpool_shape2)
        up1=self.up1(up2,indices=indices_1,output_shape=unpool_shape1)
        outputs=self.finconv(up1)

        out1 = self.out1(up2)
        out2 = self.out2(up3)
        out3 = self.out3(up4)
        out4 = self.out4(up5)
        return outputs,out1,out2,out3,out4



In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms.functional as F
# from data_preparation import train_data_new, train_label_new, val_data_new, val_label_new
# from model_all import Unet
from torch.optim import Adam
import torchvision.transforms as T
import warnings
from torch.cuda import amp
from tqdm import tqdm
# from loss import DCE
import matplotlib.pyplot as plt

# Silence every Python warning for the rest of the session (this is global,
# so it also hides warnings raised by the cells that follow).
warnings.filterwarnings("ignore")

# Gradient scaler used by the training loop below for mixed precision training
scaler = amp.GradScaler()

# Use the first CUDA GPU when CUDA is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Function to crop mask to the size of output of U-net
def crop_feat(inp, out_size=388):
    """Centre-crop the last two dimensions of a 3-D tensor to ``out_size``.

    Args:
        inp: Tensor of shape (N, H, W); here H = W = 572 as decided during
            data preparation.
        out_size: Side length of the crop. Defaults to 388, the value
            previously hard-coded here.

    Returns:
        torch.Tensor: View of shape (N, out_size, out_size). The crop is exactly
        ``out_size`` wide even when ``H - out_size`` is odd; the earlier slicing
        returned one extra row/column in that case.
    """
    top = (inp.size()[1] - out_size) // 2
    left = (inp.size()[2] - out_size) // 2
    return inp[:, top:top + out_size, left:left + out_size]

def data_augmentation(image, mask):
    """Apply the same random flips / 90-degree rotation to an image and its mask.

    Each of the three augmentations is applied with probability 0.5. Only
    ``torch`` functions are used, so the function does not depend on the
    module alias ``F``, which this notebook binds to two different libraries
    in different cells.

    Args:
        image: 2-D (H, W) tensor.
        mask: 2-D (H, W) tensor aligned with ``image``.

    Returns:
        tuple: ``(image, mask)`` after augmentation.
    """
    # Horizontal flip (reverse the last dimension)
    if torch.rand(1) > 0.5:
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    # Vertical flip (reverse the second-to-last dimension)
    if torch.rand(1) > 0.5:
        image = torch.flip(image, dims=[-2])
        mask = torch.flip(mask, dims=[-2])
    # Rotate by +90 or -90 degrees, direction chosen with equal probability
    if torch.rand(1) > 0.5:
        turns = 1 if torch.rand(1) > 0.5 else -1
        image = torch.rot90(image, turns, [0, 1])
        mask = torch.rot90(mask, turns, [0, 1])
    return image, mask

class CovidSegData(Dataset):
    """Dataset yielding (slice, mask) pairs from (H, W, C) stacked tensors.

    Args:
        data: Tensor of slices in (H, W, C) layout.
        labels: Tensor of masks in (H, W, C) layout.
        transform: Truthy to enable random augmentation, falsy (``None`` or
            ``False``) to return the slices untouched.
    """

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.transform = transform
        self.labels = labels

    def __len__(self):
        # Data is in (H,W,C) format, so the number of samples is the last dimension
        return self.labels.size()[2]

    def __getitem__(self, item):
        # Select one slice and its mask
        slices = self.data[:, :, item]
        masks = self.labels[:, :, item]
        # Truthiness check: `transform=False` must disable augmentation. The
        # earlier `is not None` test augmented for False as well.
        if self.transform:
            slices, masks = data_augmentation(slices, masks)
        return slices, masks

def loss_function(preds1, preds2, preds3, preds4, GT):
    """Deep-supervision loss: weighted sum of four Focal Tversky terms.

    Args:
        preds1: Final prediction; compared against ``GT`` as is (``gamma=1``).
        preds2, preds3, preds4: Intermediate predictions; ``GT`` is resized to
            each one's height and width before comparison.
        GT: Ground truth with the spatial size of ``preds1``.

    Returns:
        torch.Tensor: ``0.5*L1 + 0.2*L2 + 0.2*L3 + 0.1*L4``.
    """
    def _gt_like(pred):
        # Ground truth resized to the spatial size of `pred`
        return T.Resize(size=(pred.size()[2], pred.size()[3]))(GT)

    D1f = FocalTverskyLoss(preds1, GT, gamma=1)
    D2f, D3f, D4f = (FocalTverskyLoss(pred, _gt_like(pred)) for pred in (preds2, preds3, preds4))

    # weighted loss -- deep supervision
    return (0.5 * D1f) + (0.2 * D2f) + (0.2 * D3f) + (0.1 * D4f)

# Load data and dataloader for training (the validation objects are still disabled)
train_dataset = CovidSegData(train_data_new, train_label_new, transform=True)
train_loader = DataLoader(train_dataset, batch_size=2, num_workers=2, shuffle=True, pin_memory=True)

# val_dataset = CovidSegData(val_data_new, val_label_new, transform=False)
# val_loader = DataLoader(val_dataset, batch_size=2, num_workers=2, shuffle=False, pin_memory=True)

# Initialize network
# (alternative: torch.load('/media/hp/DATA/CT scan work/Covid19 CTseg (Seg ICASSP)/model_all_rad_100epochs.pth'))
model = segnet()
# model1= torch.load('kaggle/input/k-fold-weights/weights2.pth')
model.to(device)  # Move the model to the selected device

# Optimizer and Scheduler
optimizer = Adam(model.parameters(), lr=2e-4, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
num_epochs = 25  # number of epochs
for epoch in range(num_epochs):
    model.train()  # model in training mode
    # Accumulators are reset once per epoch. Resetting them inside the batch
    # loop made `epoch_loss` equal to the loss of the last batch only.
    dataset_size = 0
    running_loss = 0.0
    loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=True)  # for progress bar display
    for batch_idx, (images, GT1) in loop:
        # No cropping here: the mask is used at its full size
        GT = GT1
        # Move data to the selected device
        images = images.to(device)
        GT = GT.to(device)

        batch_size = images.size()[0]  # Size of each batch
        optimizer.zero_grad()  # zeroing gradients
        images = torch.unsqueeze(images, 1)  # add the channel dimension
        # float32 on whatever device the tensor is already on (the previous
        # torch.cuda.FloatTensor cast required a GPU)
        images = images.float()
        with amp.autocast():  # forward part with autocasting -- mixed precision training (MPT)
            # segnet.forward returns FIVE maps, so unpacking four names raised
            # ValueError. The loss takes four predictions: the main output and the
            # first three auxiliary outputs are used, the last one is dropped.
            # NOTE(review): confirm that this is the intended selection.
            preds1, preds2, preds3, preds4, _ = model(images)  # predictions
            loss = loss_function(preds1, preds2, preds3, preds4, GT)  # loss
        scaler.scale(loss).backward()  # scales loss and create scaled gradients for MPT
        # unscale the gradients of the optimizer assigned params, skips optimizer.step if Nan or Inf present
        scaler.step(optimizer)
        scaler.update()  # update scale for next iteration
        # NOTE(review): the scheduler is stepped per batch, so T_max=10 counts batches, not epochs — confirm intent
        scheduler.step()

        # Epoch loss calculation (running mean over the samples seen this epoch)
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size
        # Update progress bar
        loop.set_description(f"Epoch : [{epoch}/{num_epochs}]")
        loop.set_postfix(loss=loss.item(), epoch_loss=epoch_loss)

    # Visualization of the last batch: the output of the network is passed
    # through a sigmoid and thresholded at 0.5 to make it binary
    z = torch.sigmoid(preds1) > 0.5
    z1 = z[0].detach().cpu()
    z2 = z1.squeeze()
    # Input slice, shown uncropped because the mask is not cropped in this loop either
    y1 = images[0].detach().cpu()
    y = y1.squeeze()
    # Mask - ground truth
    p1 = GT[0].detach().cpu().squeeze()

    # Figure with input slice, mask and final output
    f, axarr = plt.subplots(1, 3)
    axarr[0].imshow(y, cmap='gray')
    axarr[1].imshow(p1, cmap='gray')
    axarr[2].imshow(z2, cmap='gray')
    plt.show()