In [None]:
from google.colab import drive
import os
# Mount Google Drive so the dataset paths under /content/gdrive used below are reachable.
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Root folder of the CHAOS training data on the mounted Drive; CHAOS.get_paths
# appends 'Train_Sets/CT/' to it.
root_dir_train = '/content/gdrive/My Drive/chaos/CHAOS_Train_Sets'
#root_dir_test = '/content/gdrive/My Drive/chaos/CHAOS_Test_Sets'


In [None]:
!pip install pydicom

Collecting pydicom
[?25l  Downloading https://files.pythonhosted.org/packages/f4/15/df16546bc59bfca390cf072d473fb2c8acd4231636f64356593a63137e55/pydicom-2.1.2-py3-none-any.whl (1.9MB)
[K     |████████████████████████████████| 1.9MB 5.7MB/s 
[?25hInstalling collected packages: pydicom
Successfully installed pydicom-2.1.2


In [None]:
import os
import pydicom
import numpy as np
import cv2
from glob import glob
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torch

class CHAOS(Dataset):
    """CHAOS CT liver-segmentation dataset served as individual 2-D slices.

    Every DICOM slice under ``<root_dir>/Train_Sets/CT/<patient>/DICOM_anon`` is
    paired with its PNG mask under ``<patient>/Ground``. A slice-level split is
    then kept according to ``mode``.

    Args:
        root_dir: dataset root that contains ``Train_Sets/CT``.
        image_size: side length that images and masks are resized to.
        mode: ``'train'`` keeps the first 70 % of slices, ``'val'`` the next 20 %,
            ``'test'`` the last 10 %. Any other value keeps every slice.

    Each item is ``[image, mask]``: two float32 tensors of shape
    ``(1, image_size, image_size)``.
    """

    def __init__(self, root_dir, image_size=512, mode='train'):
        self.root_dir = root_dir
        self.image_size = image_size
        self.mode = mode
        self.images, self.masks = self.get_paths(self.root_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.get_image(self.images[idx], self.image_size)
        mask = self.get_mask(self.masks[idx], self.image_size)
        # (H, W, 1) numpy arrays -> (1, H, W) float32 tensors.
        image = torch.as_tensor(image, dtype=torch.float32).permute(2, 0, 1)
        mask = torch.as_tensor(mask, dtype=torch.float32).permute(2, 0, 1)
        return [image, mask]

    @staticmethod
    def _patient_sort_key(name):
        """Order patient folders numerically ("2" before "10"); non-numeric names go last."""
        return (0, int(name), name) if name.isdigit() else (1, 0, name)

    @staticmethod
    def _slice_sort_key(sample_path):
        """Return a numeric sort key for one patient's DICOM file names.

        The naming scheme is detected from one sample file name (not its
        directory, which could itself contain 'IMG').
        """
        if 'IMG' in os.path.basename(sample_path):
            # Slice number: third '-'-separated field, first two characters dropped.
            return lambda p: int(os.path.basename(p).split('.')[0].split('-')[2][2:])
        # Slice number: text before the first ',', first two characters dropped.
        return lambda p: int(os.path.basename(p).split('.')[0].split(',')[0][2:])

    @staticmethod
    def _label_sort_key(path):
        """Numeric suffix after the last '_' of a label PNG name."""
        return int(os.path.basename(path).split('_')[-1].split('.png')[0])

    def get_paths(self, root_dir):
        """Collect paired slice/label paths and return the subset for ``self.mode``.

        Returns:
            ``(image_paths, label_paths)``: equal-length lists where
            ``label_paths[i]`` is the mask for ``image_paths[i]``.

        Raises:
            ValueError: if a patient folder holds different numbers of DICOM
                slices and label PNGs, since the index pairing would be wrong.
        """
        dir_path = os.path.join(root_dir, 'Train_Sets/CT/')
        # os.listdir order is filesystem-dependent. Sort it so that every CHAOS
        # instance (train/val/test) sees the same slice order and the splits are
        # disjoint. Non-directories such as notes.txt are skipped.
        patients = sorted(
            (name for name in os.listdir(dir_path)
             if os.path.isdir(os.path.join(dir_path, name))),
            key=self._patient_sort_key)

        image_paths = []
        label_paths = []
        for patient in patients:
            images = glob(os.path.join(dir_path, patient, 'DICOM_anon', '*.dcm'))
            labels = glob(os.path.join(dir_path, patient, 'Ground', '*.png'))
            if len(images) != len(labels):
                raise ValueError(
                    'Patient {}: {} DICOM slices but {} label images'.format(
                        patient, len(images), len(labels)))
            if not images:
                continue
            images.sort(key=self._slice_sort_key(images[0]))
            labels.sort(key=self._label_sort_key)
            image_paths.extend(images)
            label_paths.extend(labels)

        # NOTE(review): the split is per slice, not per patient, so one patient's
        # slices can straddle the train/val boundary.
        n = len(image_paths)
        train_end = n * 7 // 10
        val_end = n * 9 // 10
        splits = {'train': slice(0, train_end),
                  'val': slice(train_end, val_end),
                  'test': slice(val_end, None)}
        selected = splits.get(self.mode, slice(None))  # unknown mode: keep all
        image_paths = image_paths[selected]
        label_paths = label_paths[selected]

        print(self.mode + ' image length = ' , len(image_paths))
        print(self.mode + ' label length = ' , len(label_paths))
        return image_paths, label_paths

    def get_image(self, path, image_size, hu_min=-200, hu_max=250):
        """Load a DICOM slice as an ``(image_size, image_size, 1)`` float array.

        Stored pixels are rescaled with RescaleSlope/RescaleIntercept (Hounsfield
        units for CT) and clipped to ``[hu_min, hu_max]``. They are not
        normalised to [0, 1].
        """
        dcm = pydicom.dcmread(path)
        arr = dcm.pixel_array * dcm.RescaleSlope + dcm.RescaleIntercept
        arr = cv2.resize(arr, dsize=(image_size, image_size), interpolation=cv2.INTER_AREA)
        arr = np.clip(arr, hu_min, hu_max)
        return arr[..., np.newaxis]

    def get_mask(self, path, image_size):
        """Load a label PNG as an ``(image_size, image_size, 1)`` uint8 mask of 0/1."""
        label_image = cv2.imread(path)
        if label_image is None:
            raise FileNotFoundError('Could not read label image: {}'.format(path))
        label_image = cv2.cvtColor(label_image, cv2.COLOR_BGR2GRAY)
        label_image = cv2.resize(label_image, dsize=(image_size, image_size),
                                 interpolation=cv2.INTER_AREA)
        # Any non-zero pixel (including resize-blended borders) is foreground.
        label_image[label_image > 0] = 1
        return label_image[..., np.newaxis]

# Build the slice-level train/val splits (each constructor prints its counts).
train_set, val_set = (
    CHAOS(root_dir=root_dir_train, image_size=256, mode=split)
    for split in ('train', 'val')
)

train image length =  2011
train label length =  2011
val image length =  575
val label length =  575


In [None]:
!pip install tensorboardX

Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/af/0c/4f41bcd45db376e6fe5c619c01100e9b7531c55791b7244815bac6eac32c/tensorboardX-2.1-py2.py3-none-any.whl (308kB)
[K     |█                               | 10kB 18.1MB/s eta 0:00:01[K     |██▏                             | 20kB 10.0MB/s eta 0:00:01[K     |███▏                            | 30kB 8.1MB/s eta 0:00:01[K     |████▎                           | 40kB 7.3MB/s eta 0:00:01[K     |█████▎                          | 51kB 4.4MB/s eta 0:00:01[K     |██████▍                         | 61kB 4.9MB/s eta 0:00:01[K     |███████▍                        | 71kB 5.3MB/s eta 0:00:01[K     |████████▌                       | 81kB 5.4MB/s eta 0:00:01[K     |█████████▌                      | 92kB 5.4MB/s eta 0:00:01[K     |██████████▋                     | 102kB 5.8MB/s eta 0:00:01[K     |███████████▊                    | 112kB 5.8MB/s eta 0:00:01[K     |████████████▊                   | 122kB 5

In [None]:
!pip install nibabel



In [None]:
import tensorflow as tf
from tensorflow import summary
import datetime,json
import string

%load_ext tensorboard
# Suffix each log directory with the current POSIX timestamp so runs stay separate.
current_time = format(datetime.datetime.now().timestamp())
train_log_dir = 'logs/tensorboard/train' + current_time
test_log_dir = 'logs/tensorboard/test' + current_time
train_summary_writer, test_summary_writer = (
    summary.create_file_writer(train_log_dir),
    summary.create_file_writer(test_log_dir),
)

In [None]:
# Disabled TensorFlow summary helper, kept as a bare string literal: the cell
# evaluates it as a no-op expression and only displays its repr.
'''
@tf.function
def my_func(step,loss):
  with train_summary_writer.as_default():
    tf.summary.scalar("loss",loss,step)
'''

'\n@tf.function\ndef my_func(step,loss):\n  with train_summary_writer.as_default():\n    tf.summary.scalar("loss",loss,step)\n'

In [None]:
from google.colab import drive
# Repeat mount; when Drive is already mounted this only prints a notice (see output).
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
#folder_path_fully='/content/gdrive/My Drive/Brats2018/HGG'
#test_path_fully='/content/gdrive/My Drive/Brats2018/HGG_test'

#folder_path_under='/content/gdrive/My Drive/Dataset_MRI/Under'
# BraTS slice folders on Drive.
# NOTE(review): no active data loader in the visible code reads these paths.
train_x_path='/content/gdrive/My Drive/brats_data1/train_x'
train_y_path='/content/gdrive/My Drive/brats_data1/train_y'
test_x_path='/content/gdrive/My Drive/brats_data1/test_x'
test_y_path='/content/gdrive/My Drive/brats_data1/test_y'

In [None]:
import os
from scipy.ndimage import zoom  # For resizing
#subfolders=[os.path.join(folder_path_fully,folder) for folder in os.listdir(folder_path_fully) ]

#print(subfolders)


In [None]:
import numpy as np
import nibabel as nib








In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

#import torch_geometric.nn as pyg_nn
#import torch_geometric.utils as pyg_utils

import time
from datetime import datetime

#import networkx as nx
import numpy as np
import torch
import torch.optim as optim
#from torchvision.datasets import MNIST
import multiprocessing

import numpy as np
import scipy as sp
from skimage.segmentation import slic, mark_boundaries
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
# numpy scalar types matching torch.float32 / torch.int64 tensors.
NP_TORCH_FLOAT_DTYPE, NP_TORCH_LONG_DTYPE = np.float32, np.int64


In [None]:
import torch
#import networkx as nx

%matplotlib inline

In [None]:

def dice_loss(input, target, smooth=1):
    """Soft Dice loss for binary segmentation.

    No activation is applied here, so ``input`` should already be a
    probability map (e.g. sigmoid output).

    Args:
        input: predicted foreground probabilities, any shape.
        target: ground-truth mask with the same number of elements as ``input``.
        smooth: additive smoothing in numerator and denominator; keeps the
            result finite when both tensors are all zeros.

    Returns:
        Scalar tensor ``1 - dice`` (0 for a perfect match).
    """
    # reshape (not view) so non-contiguous tensors, e.g. permuted ones, work too.
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    union = iflat.sum() + tflat.sum()
    dice_score = (2. * intersection + smooth) / (union + smooth)
    return 1 - dice_score

class BinaryDiceLoss(nn.Module):
    """Soft Dice loss for a binary mask, computed per sample.

    loss_i = 1 - (2 * sum(x * y) + smooth) / (sum(x^p + y^p) + smooth)

    Args:
        smooth: value added to numerator and denominator to avoid division by
            zero, default 1.
        p: exponent applied to both terms of the denominator, default 2.
        reduction: 'mean' averages over the batch, 'sum' sums over it, and
            'none' returns a tensor of shape [N].

    forward(predict, target):
        predict: tensor of shape [N, *].
        target: tensor with the same number of elements per sample as predict.

    Raises:
        ValueError: on an unknown ``reduction`` (raised in forward).
    """
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # reshape works on non-contiguous inputs (e.g. channel slices predict[:, i]).
        predict = predict.reshape(predict.shape[0], -1)
        target = target.reshape(target.shape[0], -1)

        # Factor 2 in the numerator makes a perfect binary prediction give loss 0;
        # without it the loss could never drop below ~0.5.
        num = 2 * torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth

        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        raise ValueError('Unexpected reduction {}'.format(self.reduction))


class DiceLoss(nn.Module):
    """Multi-class Dice loss: BinaryDiceLoss per channel, averaged over channels.

    Args:
        weight: optional per-class weights with one entry per channel
            (tensor, array or list).
        ignore_index: channel index to leave out of the loss.
        **kwargs: forwarded to BinaryDiceLoss (smooth, p, reduction).

    forward(predict, target):
        predict: tensor of shape [N, C, *], one channel per class.
        target: one-hot tensor with the same shape as predict.

    Returns:
        The (optionally weighted) per-class losses averaged over the classes
        that were not ignored.

    NOTE: a later cell defines another class named DiceLoss, which shadows this one.
    """
    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index
        self.dice = BinaryDiceLoss(**self.kwargs)

    def forward(self, predict, target):
        assert predict.shape == target.shape, 'predict & target shape do not match'
        num_classes = target.shape[1]
        if self.weight is not None:
            assert len(self.weight) == num_classes, \
                'Expect weight shape [{}], get[{}]'.format(num_classes, len(self.weight))

        total_loss = 0
        counted = 0
        for i in range(num_classes):
            if i == self.ignore_index:
                continue
            class_loss = self.dice(predict[:, i], target[:, i])
            if self.weight is not None:
                class_loss = class_loss * self.weight[i]
            total_loss = total_loss + class_loss
            counted += 1

        # Average only over the classes that contributed.
        return total_loss / max(counted, 1)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

#import torch_geometric.nn as pyg_nn
#import torch_geometric.utils as pyg_utils

import time
from datetime import datetime

#import networkx as nx
import numpy as np
import torch
import torch.optim as optim
#from torchvision.datasets import MNIST
import multiprocessing

import numpy as np
import scipy as sp
from skimage.segmentation import slic, mark_boundaries
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
# numpy scalar types matching torch.float32 / torch.int64 tensors.
NP_TORCH_FLOAT_DTYPE, NP_TORCH_LONG_DTYPE = np.float32, np.int64


In [None]:
import torch
#import networkx as nx

%matplotlib inline

In [None]:
from tensorboardX import SummaryWriter
# tensorboardX writer that logs to /content/runs.
writer = SummaryWriter(logdir='/content/runs')
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [None]:
import pdb
import torch
import torch.nn as nn

from torch.nn.functional import softmax


def conv3x3(in_c, out_c, kernel_size=3, stride=1, padding=1,
            bias=True, useBN=False, drop_rate=0):
    """Two stages of reflection-pad -> conv -> [BatchNorm] -> Dropout2d -> ReLU.

    The first stage maps ``in_c`` to ``out_c`` channels, the second ``out_c`` to
    ``out_c``. With ``useBN`` each stage gets a BatchNorm2d and an in-place ReLU.
    """
    layers = []
    for stage_in in (in_c, out_c):
        layers.append(nn.ReflectionPad2d(padding))
        layers.append(nn.Conv2d(stage_in, out_c, kernel_size, stride, padding=0, bias=bias))
        if useBN:
            layers.append(nn.BatchNorm2d(out_c))
        layers.append(nn.Dropout2d(p=drop_rate))
        layers.append(nn.ReLU(inplace=bool(useBN)))
    return nn.Sequential(*layers)


def upsample(in_c, out_c, bias=True, drop_rate=0):
    """2x learned upsampling: ConvTranspose2d(k=4, s=2, p=1) -> Dropout2d -> ReLU."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=bias),
        nn.Dropout2d(p=drop_rate),
        nn.ReLU())


class UNet(nn.Module):
    """U-Net (64-128-256-512-1024 channels) returning per-pixel sigmoid probabilities.

    Input height and width must be divisible by 16 (four max-pools).
    """

    def __init__(self, in_channel=1, class_num=1, useBN=False, drop_rate=0):
        super(UNet, self).__init__()
        self.output_dim = class_num
        self.drop_rate = drop_rate
        self.conv1 = conv3x3(in_channel, 64, useBN=useBN, drop_rate=self.drop_rate)
        self.conv2 = conv3x3(64, 128, useBN=useBN, drop_rate=self.drop_rate)
        self.conv3 = conv3x3(128, 256, useBN=useBN, drop_rate=self.drop_rate)
        self.conv4 = conv3x3(256, 512, useBN=useBN, drop_rate=self.drop_rate)
        self.conv5 = conv3x3(512, 1024, useBN=useBN, drop_rate=self.drop_rate)

        self.conv4m = conv3x3(1024, 512, useBN=useBN, drop_rate=self.drop_rate)
        self.conv3m = conv3x3(512, 256, useBN=useBN, drop_rate=self.drop_rate)
        self.conv2m = conv3x3(256, 128, useBN=useBN, drop_rate=self.drop_rate)
        self.conv1m = conv3x3(128, 64, useBN=useBN, drop_rate=self.drop_rate)

        # Output head emits logits. There is no ReLU here: forward() applies a
        # sigmoid, and a ReLU in front of it would clamp every probability to >= 0.5.
        self.conv0 = nn.Sequential(nn.ReflectionPad2d(1),
                                   nn.Conv2d(64, self.output_dim, 3, 1, 0),
                                   nn.Dropout2d(p=self.drop_rate))
        self.max_pool = nn.MaxPool2d(2)

        self.upsample54 = upsample(1024, 512, drop_rate=self.drop_rate)
        self.upsample43 = upsample(512, 256, drop_rate=self.drop_rate)
        self.upsample32 = upsample(256, 128, drop_rate=self.drop_rate)
        self.upsample21 = upsample(128, 64, drop_rate=self.drop_rate)

        # Weight initialization: zero biases, N(0, 0.01) conv weights.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                if m.bias is not None:
                    m.bias.data.zero_()
                nn.init.normal_(m.weight.data, mean=0, std=0.01)

    def forward(self, x):
        output1 = self.conv1(x)
        output2 = self.conv2(self.max_pool(output1))
        output3 = self.conv3(self.max_pool(output2))
        output4 = self.conv4(self.max_pool(output3))
        output5 = self.conv5(self.max_pool(output4))

        # Decoder: each stage upsamples the previous *decoder* output and
        # concatenates the matching encoder skip connection.
        conv4m_out = self.conv4m(torch.cat((self.upsample54(output5), output4), 1))
        conv3m_out = self.conv3m(torch.cat((self.upsample43(conv4m_out), output3), 1))
        conv2m_out = self.conv2m(torch.cat((self.upsample32(conv3m_out), output2), 1))
        conv1m_out = self.conv1m(torch.cat((self.upsample21(conv2m_out), output1), 1))

        final = self.conv0(conv1m_out)
        return torch.sigmoid(final)
# Shape smoke test: the UNet should return a mask the same size as its input.
a = torch.zeros(size=(1, 1, 192, 192))
b = UNet()(a)
print(b.shape)

torch.Size([1, 1, 192, 192])


In [None]:
import math
import torch
from torch.optim.optimizer import Optimizer, required

class RAdam(Optimizer):
    """Rectified Adam (RAdam) optimizer.

    For the first steps, while the variance of the adaptive learning rate is
    not yet tractable (``N_sma < 5``), this falls back to a bias-corrected
    momentum (SGD-like) update. Afterwards it applies the rectified Adam step.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        betas: coefficients of the running averages of the gradient and of its square.
        eps: term added to the denominator for numerical stability.
        weight_decay: decay applied directly to the parameters (scaled by ``lr``).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Cache of [step, N_sma, step_size], indexed by step % 10. These values
        # depend only on the step count and betas, so parameters at the same step
        # reuse them. NOTE(review): the cache is shared by all param groups, which
        # assumes every group uses the same betas.
        self.buffer = [[None, None, None] for _ in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform one optimization step; returns ``closure()``'s loss if given."""
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Keyword alpha=/value= forms replace the deprecated
                # positional-scalar overloads of add_/addcmul_/addcdiv_.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4)
                            * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)
                        ) / (1 - beta1 ** state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                if N_sma >= 5:
                    # Rectified adaptive step.
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                else:
                    # Variance not tractable yet: bias-corrected momentum step.
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

                p.data.copy_(p_data_fp32)

        return loss
    

In [None]:
def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, activation='sigmoid'):
    """F-beta score between a prediction and a ground-truth mask.

    With ``beta=1`` this is the (soft) Dice coefficient.

    Args:
        pr (torch.Tensor): predictions; passed through ``activation`` first.
        gt (torch.Tensor): ground truth with the same shape as ``pr``.
        beta (float): weight of recall relative to precision.
        eps (float): smoothing added to numerator and denominator.
        threshold (float, optional): if given, activated predictions are
            binarized with ``> threshold``.
        activation (str or None): 'sigmoid', 'softmax2d', or None / 'none'
            (identity, i.e. ``pr`` already holds probabilities).

    Returns:
        torch.Tensor: scalar F-beta score summed over all elements.

    Raises:
        NotImplementedError: for any other ``activation``.
    """
    if activation is None or activation == "none":
        activation_fn = None
    elif activation == "sigmoid":
        activation_fn = torch.nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = torch.nn.Softmax2d()
    else:
        raise NotImplementedError(
            "Activation implemented for sigmoid and softmax2d"
        )

    if activation_fn is not None:
        pr = activation_fn(pr)

    if threshold is not None:
        pr = (pr > threshold).float()

    tp = torch.sum(gt * pr)
    fp = torch.sum(pr) - tp
    fn = torch.sum(gt) - tp

    score = ((1 + beta ** 2) * tp + eps) \
            / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps)

    return score


class DiceLoss(nn.Module):
    """Soft Dice loss ``1 - f_score(beta=1)``.

    NOTE: this redefines (and shadows) the DiceLoss class from an earlier cell.
    """
    __name__ = 'dice_loss'

    def __init__(self, eps=1e-7, activation='sigmoid'):
        super().__init__()
        self.activation = activation
        self.eps = eps

    def forward(self, y_pr, y_gt):
        return 1 - f_score(y_pr, y_gt, beta=1.,
                           eps=self.eps, threshold=None,
                           activation=self.activation)


class BCEDiceLoss(DiceLoss):
    """Weighted sum ``lambda_dice * DiceLoss + lambda_bce * BCE``.

    When no activation is applied (``activation`` None or 'none') predictions
    are treated as probabilities and plain BCELoss is used. Otherwise they are
    treated as logits and BCEWithLogitsLoss is used.
    """
    __name__ = 'bce_dice_loss'

    def __init__(self, eps=1e-7, activation='sigmoid', lambda_dice=1.0, lambda_bce=1.0):
        super().__init__(eps, activation)
        # Must agree with f_score, which treats None and 'none' as "already probabilities".
        if activation is None or activation == 'none':
            self.bce = nn.BCELoss(reduction='mean')
        else:
            self.bce = nn.BCEWithLogitsLoss(reduction='mean')
        self.lambda_dice = lambda_dice
        self.lambda_bce = lambda_bce

    def forward(self, y_pr, y_gt):
        dice = super().forward(y_pr, y_gt)
        bce = self.bce(y_pr, y_gt)
        return (self.lambda_dice * dice) + (self.lambda_bce * bce)

def dice_no_threshold(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    eps: float = 1e-7,
    threshold: float = None,
    apply_sigmoid: bool = True,
):
    """Dice coefficient between predictions and a binary target, over all elements.

    Args:
        outputs: predictions. They are logits when ``apply_sigmoid`` is True;
            pass ``apply_sigmoid=False`` for models that already end in a sigmoid.
        targets: ground-truth mask with the same shape as ``outputs``.
        eps: added to the denominator to avoid division by zero.
        threshold: if given, probabilities are binarized with ``> threshold``.
        apply_sigmoid: whether to apply ``torch.sigmoid`` to ``outputs`` first.

    Returns:
        Scalar tensor ``2 * |X * Y| / (|X| + |Y| + eps)``. It is 0 when both
        tensors are empty.
    """
    if apply_sigmoid:
        outputs = torch.sigmoid(outputs)

    if threshold is not None:
        outputs = (outputs > threshold).float()

    intersection = torch.sum(targets * outputs)
    union = torch.sum(targets) + torch.sum(outputs)
    return 2 * intersection / (union + eps)



In [None]:
class double_conv(nn.Module):
    """(conv 3x3 => BN => ReLU) * 2, spatial size preserved (padding=1)."""

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class inconv(nn.Module):
    """Input stem: a single double_conv."""

    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)


class down(nn.Module):
    """Encoder stage: 2x2 max-pool followed by double_conv."""

    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(nn.MaxPool2d(2), double_conv(in_ch, out_ch))

    def forward(self, x):
        return self.mpconv(x)


class up(nn.Module):
    """Decoder stage: upsample ``x1`` 2x, pad it to ``x2``'s size, concat ``[x2, x1]``, double_conv.

    ``in_ch`` is the channel count after concatenation. ``x1`` is expected to
    carry ``in_ch // 2`` channels (required by the transposed-conv variant).
    """

    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
        # Needed in both modes: forward() always finishes with self.conv.
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Input is NCHW; pad x1 so odd-sized skip maps still line up.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class outconv(nn.Module):
    """1x1 convolution mapping features to class logits."""

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """BatchNorm U-Net (64..512 channels, transposed-conv upsampling) with sigmoid output.

    NOTE: this redefines (and shadows) the UNet class from an earlier cell.
    """

    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256, False)
        self.up2 = up(512, 128, False)
        self.up3 = up(256, 64, False)
        self.up4 = up(128, 64, False)
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return torch.sigmoid(x)
# Single-channel CT input, one foreground class; force float32 parameters.
model = UNet(n_channels=1, n_classes=1)
model = model.float()


In [None]:

from torch.utils.data import TensorDataset, DataLoader
import tqdm
from tqdm.auto import tqdm as tq
train_on_gpu = torch.cuda.is_available()
use_cuda = train_on_gpu
# Fall back to the CPU when no GPU is attached instead of failing in model.cuda().
device = torch.device("cuda:0" if use_cuda else "cpu")

my_dataloader = DataLoader(train_set, batch_size=10, shuffle=True)
# Validation metrics do not depend on sample order; a fixed order keeps runs comparable.
test_dataloader = DataLoader(val_set, batch_size=10, shuffle=False)
# NOTE(review): this replaces the SummaryWriter(logdir='/content/runs') created in an
# earlier cell; without a logdir, tensorboardX writes to its default directory.
writer = SummaryWriter()

model.to(device)

# The model already ends in a sigmoid, so the BCE term must take probabilities;
# activation=None selects BCELoss for that.
criterion = BCEDiceLoss(eps=1.0, activation=None)
optimizer = RAdam(model.parameters(), lr=0.005)
current_lr = optimizer.param_groups[0]['lr']
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=2, cooldown=2)

# NOTE(review): the DataLoaders above batch 10 samples; this value is not passed to them.
batch_size = 4
trainingloss = []
testloss = []
iters = []

train_input = []
test_input = []
train_pred = []
test_pred = []
train_gt = []
test_gt = []
train_iter_epoch = []
test_iter_epoch = []
test_losses = []

n_epochs = 50
train_loss_list = []
valid_loss_list = []
dice_score_list = []
lr_rate_list = []
# np.inf (lower case): the np.Inf alias was removed in NumPy 2.0.
valid_loss_min = np.inf  # track change in validation loss
for epoch in range(1, n_epochs+1):

    # keep track of training and validation loss
    train_loss = 0.0
    valid_loss = 0.0
    dice_score = 0.0
    ###################
    # train the model #
    ###################
    model.train()
    bar = tq(my_dataloader, postfix={"train_loss":0.0})
    for data, target in bar:
        # move tensors to GPU if CUDA is available
        if train_on_gpu:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data)
        # calculate the batch loss
        loss = criterion(output, target)
        #print(loss)
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
        # update training loss
        train_loss += loss.item()*data.size(0)
        bar.set_postfix(ordered_dict={"train_loss":loss.item()})
        output = model(data)
        # calculate the batch loss
        loss = criterion(output, target)
        #print(loss)
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
        # update training loss
        train_loss += loss.item()*data.size(0)
        bar.set_postfix(ordered_dict={"train_loss":loss.item()})
    writer.add_scalar('Loss/train',train_loss/len(my_dataloader.dataset), epoch)
    if(epoch>=30):
      train_iter_epoch.append(epoch)
      train_input.append(data)
      train_pred.append(output)
      train_gt.append(target)
    ######################    
    # validate the model #
    ######################
    model.eval()
    del data, target
    with torch.no_grad():
        bar = tq(test_dataloader, postfix={"valid_loss":0.0, "dice_score":0.0})
        for data, target in bar:
            # move tensors to GPU if CUDA is available
            if train_on_gpu:
                data, target = data.cuda(), target.cuda()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the batch loss
            loss = criterion(output, target)
            # update average validation loss 
            valid_loss += loss.item()*data.size(0)
            dice_cof = dice_no_threshold(output.cpu(), target.cpu()).item()
            dice_score +=  dice_cof * data.size(0)
            bar.set_postfix(ordered_dict={"valid_loss":loss.item(), "dice_score":dice_cof})
    train_loss = train_loss/len(my_dataloader.dataset)
    valid_loss = valid_loss/len(test_dataloader.dataset)
    dice_score = dice_score/len(test_dataloader.dataset)
    train_loss_list.append(train_loss)
    valid_loss_list.append(valid_loss)
    dice_score_list.append(dice_score)
    writer.add_scalar('Loss/test',valid_loss/len(test_dataloader.dataset) , epoch)
    lr_rate_list.append([param_group['lr'] for param_group in optimizer.param_groups])

    if(epoch>=30):
      test_iter_epoch.append(epoch)
      test_input.append(data)
      test_pred.append(output)
      test_gt.append(target)
    
         
    
    # print training/validation statistics 
    print('Epoch: {}  Training Loss: {:.6f}  Validation Loss: {:.6f} Dice Score: {:.6f}'.format(
        epoch, train_loss, valid_loss, dice_score))
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        valid_loss_min,
        valid_loss))
        torch.save(model.state_dict(), 'model_cifar.pt')
        valid_loss_min = valid_loss
    
    scheduler.step(valid_loss)






'''




for epoch in range(50):
    # Training
  model.train()  
  for n,(local_batch, local_labels) in enumerate(my_dataloader):
        losses=[]
        if torch.min(local_batch) == 0 and torch.max(local_batch) == 0:
          continue
    
        optimizer.zero_grad()
        #local_batch=data_transform(local_batch)
        local_batch=local_batch.cuda().float()
        local_label=local_labels.cuda().float()
        #print(local_batch.shape)
        #print(local_label.shape)
        #local_batch=local_batch.permute(0,3,1,2)
        #local_label=torch.unsqueeze(local_label,0)
        #print(local_batch.shape)
        i = model(local_batch)
        
   
        loss=criterion(i,local_label)
        #print(i.shape)
        #print(local_label.shape)
        #loss=loss_fn(i,local_label)
       
        loss.backward()
        optimizer.step()
        iters.append(n)
        losses.append((float(loss)))
  if(epoch>=30):
    train_iter_epoch.append(epoch)
    train_input.append(local_batch)
    train_pred.append(i)
    train_gt.append(local_label)



  #for n, (imgs, labels) in enumerate(train_loader):      
  print('Epoch: {:03d}, Loss: {:.4f}'.format(epoch, torch.mean(torch.FloatTensor(losses))))
  writer.add_scalar('Loss/train', torch.mean(torch.FloatTensor(losses)), epoch)

    
  
  with torch.no_grad():
        model.eval()
        
        test_losses=[]
        for images,labels in test_dataloader:
          if torch.min(images) == 0 and torch.max(labels) == 0:
            continue
          #images=data_transform(images)

          images=images.cuda().float()
          #print(images.shape)

          #images=images.permute(0,3,1,2)
          #print(images.shape)
         
          outputs = model(images)
        
          
          
          #labels.permute(0,3,1,2)
          #labels=torch.unsqueeze(labels,0)
          test_loss = criterion(outputs, labels.cuda())
          
          #test_loss=testloss(outputs, labels.cuda())
          test_losses.append(float(test_loss))
  writer.add_scalar('Loss/test', (torch.mean(torch.FloatTensor(test_losses))).item(), epoch)
  if(epoch>=30):
    test_iter_epoch.append(epoch)
    test_input.append(images)
    test_pred.append(outputs)
    test_gt.append(labels)
         
                  
  print('test loss',(torch.mean(torch.FloatTensor(test_losses))).item())
  
  writer.close()

'''

HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))

	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:1005.)





HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 1  Training Loss: 0.780906  Validation Loss: 0.671730 Dice Score: 0.105353
Validation loss decreased (inf --> 0.671730).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 2  Training Loss: 0.193517  Validation Loss: 0.815705 Dice Score: 0.103927


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 3  Training Loss: 0.189656  Validation Loss: 0.189596 Dice Score: 0.123535
Validation loss decreased (0.671730 --> 0.189596).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 4  Training Loss: 0.121629  Validation Loss: 0.191788 Dice Score: 0.123603


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 5  Training Loss: 0.106404  Validation Loss: 0.148517 Dice Score: 0.129490
Validation loss decreased (0.189596 --> 0.148517).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 6  Training Loss: 0.188912  Validation Loss: 0.153506 Dice Score: 0.126170


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 7  Training Loss: 0.128276  Validation Loss: 0.118263 Dice Score: 0.126918
Validation loss decreased (0.148517 --> 0.118263).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 8  Training Loss: 0.087233  Validation Loss: 0.102888 Dice Score: 0.128357
Validation loss decreased (0.118263 --> 0.102888).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 9  Training Loss: 0.087040  Validation Loss: 0.084274 Dice Score: 0.129629
Validation loss decreased (0.102888 --> 0.084274).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 10  Training Loss: 0.069405  Validation Loss: 0.073734 Dice Score: 0.129576
Validation loss decreased (0.084274 --> 0.073734).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 11  Training Loss: 0.065852  Validation Loss: 0.072382 Dice Score: 0.129466
Validation loss decreased (0.073734 --> 0.072382).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 12  Training Loss: 0.080355  Validation Loss: 0.076290 Dice Score: 0.129695


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 13  Training Loss: 0.060726  Validation Loss: 0.086784 Dice Score: 0.129162


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 14  Training Loss: 0.070215  Validation Loss: 0.084855 Dice Score: 0.129318


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 15  Training Loss: 0.051747  Validation Loss: 0.068301 Dice Score: 0.130252
Validation loss decreased (0.072382 --> 0.068301).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 16  Training Loss: 0.048374  Validation Loss: 0.070211 Dice Score: 0.129792


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 17  Training Loss: 0.046557  Validation Loss: 0.067841 Dice Score: 0.130234
Validation loss decreased (0.068301 --> 0.067841).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 18  Training Loss: 0.045047  Validation Loss: 0.075060 Dice Score: 0.129788


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 19  Training Loss: 0.043707  Validation Loss: 0.074142 Dice Score: 0.129996


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 20  Training Loss: 0.042514  Validation Loss: 0.068267 Dice Score: 0.129922


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 21  Training Loss: 0.043821  Validation Loss: 0.076779 Dice Score: 0.129806


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 22  Training Loss: 0.040180  Validation Loss: 0.075940 Dice Score: 0.129494


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 23  Training Loss: 0.039341  Validation Loss: 0.072284 Dice Score: 0.129739


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 24  Training Loss: 0.038820  Validation Loss: 0.070855 Dice Score: 0.129703


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 25  Training Loss: 0.039216  Validation Loss: 0.074321 Dice Score: 0.130068


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 26  Training Loss: 0.037732  Validation Loss: 0.082714 Dice Score: 0.129530


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 27  Training Loss: 0.037550  Validation Loss: 0.076029 Dice Score: 0.129517


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 28  Training Loss: 0.037664  Validation Loss: 0.079174 Dice Score: 0.129787


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 29  Training Loss: 0.037198  Validation Loss: 0.075754 Dice Score: 0.129463


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 30  Training Loss: 0.037441  Validation Loss: 0.076987 Dice Score: 0.130008


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 31  Training Loss: 0.037052  Validation Loss: 0.075784 Dice Score: 0.130242


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 32  Training Loss: 0.037002  Validation Loss: 0.075936 Dice Score: 0.129772


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 33  Training Loss: 0.036744  Validation Loss: 0.076325 Dice Score: 0.130038


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 34  Training Loss: 0.037336  Validation Loss: 0.073147 Dice Score: 0.129311


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 35  Training Loss: 0.036982  Validation Loss: 0.079140 Dice Score: 0.129744


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 36  Training Loss: 0.036763  Validation Loss: 0.093190 Dice Score: 0.129091


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 37  Training Loss: 0.036620  Validation Loss: 0.074147 Dice Score: 0.129744


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 38  Training Loss: 0.036862  Validation Loss: 0.101690 Dice Score: 0.128658


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 39  Training Loss: 0.037137  Validation Loss: 0.075698 Dice Score: 0.129604


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 40  Training Loss: 0.037588  Validation Loss: 0.076979 Dice Score: 0.129737


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 41  Training Loss: 0.037348  Validation Loss: 0.075097 Dice Score: 0.129580


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 42  Training Loss: 0.036875  Validation Loss: 0.073612 Dice Score: 0.130086


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 43  Training Loss: 0.036931  Validation Loss: 0.075069 Dice Score: 0.129557


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 44  Training Loss: 0.036529  Validation Loss: 0.078702 Dice Score: 0.129858


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 45  Training Loss: 0.036566  Validation Loss: 0.072946 Dice Score: 0.129472


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 46  Training Loss: 0.037138  Validation Loss: 0.077242 Dice Score: 0.129786


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 47  Training Loss: 0.036591  Validation Loss: 0.075847 Dice Score: 0.129436


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 48  Training Loss: 0.036823  Validation Loss: 0.075850 Dice Score: 0.129204


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 49  Training Loss: 0.036887  Validation Loss: 0.075906 Dice Score: 0.130221


HBox(children=(FloatProgress(value=0.0, max=202.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))


Epoch: 50  Training Loss: 0.036667  Validation Loss: 0.074679 Dice Score: 0.130010


"\n\n\n\n\nfor epoch in range(50):\n    # Training\n  model.train()  \n  for n,(local_batch, local_labels) in enumerate(my_dataloader):\n        losses=[]\n        if torch.min(local_batch) == 0 and torch.max(local_batch) == 0:\n          continue\n    \n        optimizer.zero_grad()\n        #local_batch=data_transform(local_batch)\n        local_batch=local_batch.cuda().float()\n        local_label=local_labels.cuda().float()\n        #print(local_batch.shape)\n        #print(local_label.shape)\n        #local_batch=local_batch.permute(0,3,1,2)\n        #local_label=torch.unsqueeze(local_label,0)\n        #print(local_batch.shape)\n        i = model(local_batch)\n        \n   \n        loss=criterion(i,local_label)\n        #print(i.shape)\n        #print(local_label.shape)\n        #loss=loss_fn(i,local_label)\n       \n        loss.backward()\n        optimizer.step()\n        iters.append(n)\n        losses.append((float(loss)))\n  if(epoch>=30):\n    train_iter_epoch.append(epoch

In [None]:
# Learning rate in effect during each epoch (ReduceLROnPlateau may lower it);
# lr_rate_list holds one list of per-param-group LRs per epoch, epochs start at 1.
epochs_axis = range(1, len(lr_rate_list) + 1)
plt.figure(figsize=(10,10))
plt.plot(epochs_axis, [group_lrs[0] for group_lrs in lr_rate_list])
plt.xlabel('epoch', fontsize=22)
plt.ylabel('learning rate during training', fontsize=22)
plt.show()

In [None]:
# Per-epoch mean training vs. validation loss.
fig, ax = plt.subplots(figsize=(10, 10))
for curve, curve_label in ((train_loss_list, "Training Loss"), (valid_loss_list, "Validation Loss")):
    ax.plot(curve, marker='o', label=curve_label)
ax.set_ylabel('loss', fontsize=22)
ax.legend()
plt.show()

In [None]:
# Per-epoch mean validation dice score.
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(dice_score_list)
ax.set_ylabel('Dice score')
plt.show()

In [None]:
# Restore the checkpoint with the lowest validation loss written by the training loop.
best_state_dict = torch.load('model_cifar.pt')
model.load_state_dict(best_state_dict)
model.eval();

In [None]:
# NOTE(review): `valid_ids`, `valid_loader` and `resize_it` are not defined in the
# visible part of this notebook (the loaders built above are `my_dataloader` and
# `test_dataloader`); confirm they come from an earlier cell before running.
valid_masks = []
count = 0
tr = min(len(valid_ids)*4, 2000)
probabilities = np.zeros((tr, 350, 525), dtype = np.float32)
# Pure inference: no_grad stops every forward pass from building an autograd graph.
with torch.no_grad():
    for data, target in tq(valid_loader):
        if train_on_gpu:
            data = data.cuda()
        target = target.cpu().detach().numpy()
        outpu = model(data).cpu().detach().numpy()
        for p in range(data.shape[0]):
            output, mask = outpu[p], target[p]
            for m in mask:
                valid_masks.append(resize_it(m))
            for probability in output:
                probabilities[count, :, :] = resize_it(probability)
                count += 1
            # Stop once every row of `probabilities` is filled. With the 1-class
            # model, `count` grows by one per sample, so a `tr - 1` bound would
            # leave the last row at zero.
            if count >= tr:
                break
        if count >= tr:
            break

In [None]:
# Pickle the whole nn.Module (architecture + weights) to Google Drive.
# NOTE(review): the Brats2018/resunet path looks left over from another project -- confirm.
torch.save(obj=model, f='/content/gdrive/My Drive/Brats2018/resunet_weights1.pt')

In [None]:
# NOTE(review): `%tensorboard` needs `%load_ext tensorboard` to have run first
# (not visible in this part of the notebook). SummaryWriter() was created without
# a log_dir, so it writes under ./runs -- i.e. /content/runs when the Colab
# working directory is /content.
%tensorboard --logdir /content/runs

In [None]:
#!pip install tensorboard

In [None]:
#%tensorboard --logdir=runs

In [None]:
'''
plt.plot(np.arange(50),trainingloss,label='Training loss')
#plt.show()
#plt.plot(np.arange(50),testloss,label='Test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
'''
'''
plt.plot(np.arange(100),trainingloss,label='Training loss')
#plt.show()
plt.plot(np.arange(100),testloss,label='Test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
#test_loss
'''


In [None]:
def plot(ori_imgs, recon_imgs, masked_imgs,masked_imgs1, figsize = (10, 10)):
    """Show four single-channel images side by side in grayscale.

    Every argument is reshaped to 192x192 before drawing, so each must hold
    exactly 192*192 values. Ticks are hidden and each panel gets a caption.
    """
    side = 192
    print()
    fig, axes = plt.subplots(1, 4, figsize=figsize)
    fig.subplots_adjust(hspace=0.4, wspace=0.4, right=0.7)
    images = (ori_imgs, recon_imgs, masked_imgs, masked_imgs1)
    captions = ('channel 0', 'channel 2', 'Channel 3', 'Channel 4')
    for axis, image, caption in zip(axes, images, captions):
        axis.imshow(np.reshape(image, [side, side]), cmap='gray')
        axis.set_xlabel(caption)
        axis.set_xticks([])
        axis.set_yticks([])
    plt.tight_layout()
    plt.show()

In [None]:
def plot_pred(ori_imgs, recon_imgs, masked_imgs, figsize = (10, 10)):
    """Show three single-channel images side by side in grayscale.

    Every argument is reshaped to 192x192 before drawing, so each must hold
    exactly 192*192 values. Ticks are hidden and each panel gets a caption.
    """
    side = 192
    print()
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    fig.subplots_adjust(hspace=0.4, wspace=0.4, right=0.7)
    images = (ori_imgs, recon_imgs, masked_imgs)
    captions = ('channel 0', 'channel 2', 'Channel 3')
    for axis, image, caption in zip(axes, images, captions):
        axis.imshow(np.reshape(image, [side, side]), cmap='gray')
        axis.set_xlabel(caption)
        axis.set_xticks([])
        axis.set_yticks([])
    plt.tight_layout()
    plt.show()

In [None]:
# Shape of the first stored validation snapshot (snapshots start at epoch 30).
print(test_input[0].shape)

In [None]:
def dice(im1, im2, empty_score=1.0):
    """
    Computes the Dice coefficient, a measure of set similarity.

    Parameters
    ----------
    im1 : array-like, bool
        Any array of arbitrary size. If not boolean, will be converted
        (every non-zero element counts as foreground).
    im2 : array-like, bool
        Any other array of identical size. If not boolean, will be converted.
    empty_score : float, default 1.0
        Value returned when both inputs are entirely background, a case where
        the Dice formula would divide 0 by 0.

    Returns
    -------
    dice : float
        Dice coefficient as a float on range [0,1].
        Maximum similarity = 1
        No similarity = 0
        Both inputs empty = ``empty_score``

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.

    Notes
    -----
    The order of inputs for `dice` is irrelevant. The result will be
    identical if `im1` and `im2` are switched.
    """
    # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
    im1 = np.asarray(im1).astype(bool)
    im2 = np.asarray(im2).astype(bool)

    if im1.shape != im2.shape:
        raise ValueError("Shape mismatch: im1 and im2 must have the same shape.")

    im_sum = im1.sum() + im2.sum()
    if im_sum == 0:
        return empty_score

    # Compute Dice coefficient
    intersection = np.logical_and(im1, im2)

    return 2. * intersection.sum() / im_sum

In [None]:
import logging
class Logger:
    """Attach a timestamped file handler to a named ``logging`` logger."""

    def __init__(self, model_name, logger_path):
        """
        Parameters
        ----------
        model_name : str
            Logger name. ``logging.getLogger`` returns the same logger object
            for the same name, so handlers persist across instances.
        logger_path : str or path-like
            File the handler appends records to.
        """
        self.logger = logging.getLogger(model_name)
        # Re-running the cell (or building a second Logger for the same model)
        # must not attach another handler for the same file; otherwise every
        # record would be written once per handler.
        target = os.path.abspath(logger_path)
        already_attached = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == target
            for handler in self.logger.handlers
        )
        if not already_attached:
            hdlr = logging.FileHandler(logger_path)
            formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
            hdlr.setFormatter(formatter)
            self.logger.addHandler(hdlr)
        self.logger.setLevel(logging.DEBUG)

    def get_logger(self):
        """Return the configured :class:`logging.Logger`."""
        return self.logger

from skimage.filters import threshold_otsu
def create_mask(predicted, threshold=0.5):
    """
    Binarise a predicted probability slice and render it as an RGB image.

    Parameters
    ----------
    predicted : torch.Tensor
        2-D slice of predicted probabilities; moved to CPU internally.
    threshold : float or None, default 0.5
        Foreground is ``predicted > threshold``. ``None`` selects an Otsu
        threshold computed from ``predicted``, falling back to 0.5 if Otsu
        raises ``ValueError`` (e.g. on a single-valued image).

    Returns
    -------
    np.ndarray
        ``uint8`` array of shape ``(H, W, 3)``: white foreground on black.
    """
    predicted = predicted.cpu().data.numpy()

    if threshold is None:
        try:
            threshold = threshold_otsu(predicted)
        except ValueError:
            threshold = 0.5
    predicted_binary = predicted > threshold

    white = np.array([255, 255, 255], dtype=np.uint8)  # prediction_output

    height, width = predicted_binary.shape[:2]
    rgb_image = np.zeros((height, width, 3), dtype=np.uint8)  # black background
    rgb_image[predicted_binary] = white
    return rgb_image

def create_diff_mask(predicted, label, threshold=0.5):
    """
    Binarise a greyscale prediction slice and colour-code it against the label.

    Parameters
    ----------
    predicted : torch.Tensor
        2-D slice of predicted probabilities; moved to CPU internally.
    label : torch.Tensor
        2-D binary ground-truth slice of the same shape (bool or numeric).
    threshold : float or None, default 0.5
        Foreground is ``predicted > threshold``. ``None`` selects an Otsu
        threshold computed from ``predicted``, falling back to 0.5 if Otsu
        raises ``ValueError`` (e.g. on a single-valued image).

    Returns
    -------
    np.ndarray
        ``uint8`` RGB image of shape ``(H, W, 3)``: black background, white
        predicted foreground, red where the label is foreground but the
        prediction is not (under-detected), blue where the prediction is
        foreground but the label is not (over-detected).
    """
    label = label.cpu().data.numpy()
    predicted = predicted.cpu().data.numpy()

    if threshold is None:
        try:
            threshold = threshold_otsu(predicted)
        except ValueError:
            threshold = 0.5
    predicted_binary = predicted > threshold

    # Subtract as floats: numpy rejects `-` between two boolean arrays, which
    # would happen for a boolean label.
    label_f = label.astype(np.float64)
    pred_f = predicted_binary.astype(np.float64)
    diff1 = (label_f - pred_f) > 0  # under-detected
    diff2 = (pred_f - label_f) > 0  # over-detected

    red = np.array([255, 0, 0], dtype=np.uint8)  # under_detected
    white = np.array([255, 255, 255], dtype=np.uint8)  # prediction_output
    blue = np.array([0, 0, 255], dtype=np.uint8)  # over_detected

    height, width = predicted_binary.shape[:2]
    rgb_image = np.zeros((height, width, 3), dtype=np.uint8)  # black background
    # Later assignments win: mismatches are painted over the white prediction.
    rgb_image[predicted_binary] = white
    rgb_image[diff1] = red
    rgb_image[diff2] = blue

    return rgb_image

def create_diff_mask_binary(predicted, label):
    """
    Colour-code the agreement between two already-binarised 2-D slices.

    Parameters
    ----------
    predicted : array-like
        Binary (0/1 or bool) predicted slice of shape ``(H, W)``; NumPy arrays
        and CPU torch tensors are accepted.
    label : array-like
        Binary ground-truth slice with the same shape.

    Returns
    -------
    np.ndarray
        ``uint8`` RGB image ``(H, W, 3)``: black background, white where
        ``predicted > 0``, red where the label exceeds the prediction
        (under-detected), blue where the prediction exceeds the label
        (over-detected).
    """
    # Work in float so the subtraction also works for boolean masks
    # (numpy raises TypeError for `-` between two bool arrays) and so torch
    # tensors are turned into plain ndarrays.
    label_f = np.asarray(label, dtype=np.float64)
    pred_f = np.asarray(predicted, dtype=np.float64)

    diff1 = (label_f - pred_f) > 0  # under-detected
    diff2 = (pred_f - label_f) > 0  # over-detected
    predicted_binary = pred_f > 0

    red = np.array([255, 0, 0], dtype=np.uint8)  # under_detected
    white = np.array([255, 255, 255], dtype=np.uint8)  # prediction_output
    blue = np.array([0, 0, 255], dtype=np.uint8)  # over_detected

    height, width = predicted_binary.shape[:2]
    rgb_image = np.zeros((height, width, 3), dtype=np.uint8)  # black background
    # Later assignments win: mismatches are painted over the white prediction.
    rgb_image[predicted_binary] = white
    rgb_image[diff1] = red
    rgb_image[diff2] = blue
    return rgb_image



def show_diff(label, predicted, diff_image):
    '''
    Display ground truth, prediction and their difference image side by side.

    All three panels are drawn with a grey colormap and have their axes
    hidden; the difference panel shares its x/y limits with the ground-truth
    panel, so zooming or panning one moves the other.
    '''
    # Build exactly three axes. Calling plt.subplot() after plt.subplots()
    # can leave orphan axes stacked underneath the ones that are drawn.
    fig = plt.figure(figsize=(8, 2.5))
    ax0 = fig.add_subplot(1, 3, 1)
    ax1 = fig.add_subplot(1, 3, 2)
    ax2 = fig.add_subplot(1, 3, 3, sharex=ax0, sharey=ax0)

    panels = (
        (ax0, label, 'GroundTruth'),
        (ax1, predicted, 'Predicted'),
        (ax2, diff_image, 'Difference image'),
    )
    for axis, image, title in panels:
        axis.imshow(image, cmap=plt.cm.gray)
        axis.set_title(title)
        axis.axis('off')  # every panel, including 'Predicted'

    plt.show()

In [None]:
def plot_bi(ori_imgs, recon_imgs, figsize = (5, 5)):
    """Show a prediction next to a mask image.

    ``ori_imgs`` is reshaped to 192x192 and drawn in grayscale; ``recon_imgs``
    is handed to ``imshow`` unchanged (so an RGB array works). Ticks are hidden.
    """
    side = 192
    print()
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    fig.subplots_adjust(hspace=0.4, wspace=0.4, right=0.7)
    panels = (
        (np.reshape(ori_imgs, [side, side]), {'cmap': 'gray'}, 'pred'),
        (recon_imgs, {}, 'mask'),
    )
    for axis, (image, imshow_kwargs, caption) in zip(axes, panels):
        axis.imshow(image, **imshow_kwargs)
        axis.set_xlabel(caption)
        axis.set_xticks([])
        axis.set_yticks([])
    plt.tight_layout()
    plt.show()
def plot_bi1(ori_imgs, recon_imgs, recon_imgs1,figsize = (5, 5), img_size=256):
    '''
    Show input, prediction and mask side by side in a 1 x 3 figure.

    Args:
        ori_imgs: left panel (x-label 'input'). It is reshaped to
            ``(img_size, img_size)`` and drawn in grayscale, so it must hold
            exactly ``img_size * img_size`` elements.
        recon_imgs: middle panel (x-label 'pred'), drawn as-is with
            matplotlib's default colormap.
        recon_imgs1: right panel (x-label 'mask'), drawn as-is. An RGB array
            is shown in colour.
        figsize: figure size in inches.
        img_size: side length used to reshape ``ori_imgs``. The default 256
            matches the value this function previously hard-coded.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    '''
    print()  # emits an empty line before the figure in the cell output
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    # NOTE(review): plt.tight_layout() below recomputes the subplot spacing,
    # so these values probably have little visible effect.
    fig.subplots_adjust(hspace=0.4, wspace=0.4, right=0.7)

    axes[0].imshow(np.reshape(ori_imgs, [img_size, img_size]), cmap='gray')
    axes[0].set_xlabel('input')

    axes[1].imshow(recon_imgs)
    axes[1].set_xlabel('pred')

    axes[2].imshow(recon_imgs1)
    axes[2].set_xlabel('mask')

    # Images only: hide the tick marks on every panel.
    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.show()

In [None]:
# For every recorded test snapshot:
#   1. print its epoch and the dice value of prediction vs. ground truth;
#   2. binarise the prediction with Otsu's threshold;
#   3. plot input / binarised prediction / colour difference mask.
for each, epoch in enumerate(test_iter_epoch):
  print('epoch no :', epoch)

  inpti = test_input[each].squeeze().cpu()
  gtti = test_gt[each].squeeze().cpu()
  # detach() so that .numpy() below also works if the stored prediction
  # still tracks gradients.
  pred_tensor = test_pred[each].squeeze().detach().cpu()
  print('dice_loss:', dice(pred_tensor, gtti))

  pred_prob = pred_tensor.numpy()
  thresh = threshold_otsu(pred_prob)
  # Build a NEW 0/1 array instead of writing into pred_prob. .numpy() shares
  # memory with the tensor (no copy when it is already on the CPU), so
  # in-place writes could silently overwrite the entries stored in test_pred.
  # A single comparison also stays correct when thresh is negative, where the
  # old "set <= t to 0, then set > t to 1" sequence flipped those zeros to 1.
  predti = (pred_prob > thresh).astype(pred_prob.dtype)

  colur_maski = create_diff_mask_binary(predti, gtti)
  plot_bi1(inpti, predti, colur_maski)
  

In [None]:
# For every recorded training snapshot:
#   1. print its epoch and the dice value of prediction vs. ground truth;
#   2. binarise the prediction with Otsu's threshold;
#   3. plot input / binarised prediction / colour difference mask.
for each, epoch in enumerate(train_iter_epoch):
  print('epoch no:', epoch)

  inptr = train_input[each].squeeze().cpu()
  gttr = train_gt[each].squeeze().cpu()
  pred_tensor = train_pred[each].squeeze().detach().cpu()
  # Call dice with two tensors, as the test-snapshot loop does, rather than
  # mixing a NumPy prediction with a tensor target.
  print('dice_loss :', dice(pred_tensor, gttr))

  pred_prob = pred_tensor.numpy()
  thresh1 = threshold_otsu(pred_prob)
  # Build a NEW 0/1 array instead of writing into pred_prob. .numpy() shares
  # memory with the tensor (no copy when it is already on the CPU), so
  # in-place writes could silently overwrite the entries stored in train_pred.
  # A single comparison also stays correct for a negative threshold.
  predtr = (pred_prob > thresh1).astype(pred_prob.dtype)

  colur_masktr = create_diff_mask_binary(predtr, gttr)
  plot_bi1(inptr, predtr, colur_masktr)
  


In [None]:
# Sanity check: the snapshot-plotting loop indexes these three lists in
# lockstep, so their lengths are expected to match.
for snapshot_list in (test_input, test_pred, test_gt):
  print(len(snapshot_list))

In [None]:
# Inspect the shape of every stored test prediction.
for pred_snapshot in test_pred:
  snapshot_shape = pred_snapshot.shape
  print(snapshot_shape)