In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
import nibabel as nib
import os 
import glob
import  matplotlib.pyplot as plt
from skimage.transform import rescale, resize, downscale_local_mean
import random
import math
from torch.utils.data import DataLoader
from utils import *

In [None]:
# Root folders of the per-site case volumes and of the extracted 2-D slices.
# Override with the QAZAL_CASE_DIR / QAZAL_SLICE_DIR environment variables when
# running on another machine; the defaults are the original local paths.
case = os.environ.get('QAZAL_CASE_DIR', '/home/arisa/Documents/qazal/case')
sli_path = os.environ.get('QAZAL_SLICE_DIR', '/home/arisa/Documents/qazal/slice')

In [None]:
def make_box(mask):
    '''
    Bounding boxes of the foreground of a 2-D mask.

    Every box is ``[x_min, y_min, x_max, y_max]`` with inclusive bounds, where y
    indexes axis 0 (rows) and x indexes axis 1 (columns). A box is ``[]`` when the
    corresponding region is empty.

    Parameters
    ----------
    mask : np.ndarray
        Boolean mask, or a label mask in which 0 is background and every other
        value is a separate label.

    Returns
    -------
    dict
        Boolean mask: ``{'bbox_bool': box of the True pixels}``.
        Label mask: ``{'bbox_label': {'label<k>': box of label k, ...},
        'bbox_bool': box of all nonzero pixels}``.
    '''
    def _bbox(region):
        # Row / column indices of every pixel that belongs to the region.
        coords = np.nonzero(region)
        if coords[0].size == 0:
            return []
        rows, cols = coords[0], coords[1]
        return [cols.min(), rows.min(), cols.max(), rows.max()]

    if mask.dtype == bool:
        return {'bbox_bool': _bbox(mask)}

    # Every nonzero value is a label. Do not rely on 0 being the smallest value
    # or even being present.
    labels = [value for value in np.unique(mask) if value != 0]
    return {
        'bbox_label': {f"label{int(value)}": _bbox(mask == value) for value in labels},
        'bbox_bool': _bbox(mask != 0),
    }

In [None]:
def slice_label(masks):
    '''
    Group the slice indices of a label volume by which labels each slice holds.

    Slices are taken along the last axis (``masks[:, :, idx]``).

    Returns a dict with four lists of slice indices:
      'label_0'     -- slices with a single distinct value
      'label_1_0'   -- slices with 0 and 1 but no 2
      'label_2_0'   -- slices with 0 and 2 but no 1
      'label_2_1_0' -- every other slice
    '''
    groups = {'label_0': [], 'label_1_0': [], 'label_2_0': [], 'label_2_1_0': []}
    for idx in range(masks.shape[-1]):
        values = np.unique(masks[:, :, idx])
        has_bg, has_one, has_two = (0 in values), (1 in values), (2 in values)
        if len(values) == 1:
            key = 'label_0'
        elif has_bg and has_one and not has_two:
            key = 'label_1_0'
        elif has_bg and has_two and not has_one:
            key = 'label_2_0'
        else:
            key = 'label_2_1_0'
        groups[key].append(idx)
    return groups

In [None]:
def normalization_channel(image, dim_channel = 2):
    '''
    Standardise ``image`` to zero mean / unit std along axis ``dim_channel``.

    Works for both ``torch.Tensor`` and ``np.ndarray``. Both branches use the
    population standard deviation (ddof = 0), so a tensor and an array holding
    the same data give the same result. Positions whose std is 0 (constant
    along the axis) are only mean-centred, i.e. become 0, instead of 0/0 = NaN.
    '''
    if isinstance(image, torch.Tensor):
        mean = torch.mean(image, dim_channel, keepdim=True)
        # Population std, written out so it matches np.std on every torch version.
        std = (image - mean).pow(2).mean(dim_channel, keepdim=True).sqrt()
        std = torch.where(std == 0, torch.ones_like(std), std)
    else:
        mean = np.mean(image, dim_channel, keepdims=True)
        std = np.std(image, dim_channel, keepdims=True)
        std = np.where(std == 0, 1, std)
    return (image - mean) / std

def normalization(image):
    '''
    Standardise a whole array to zero mean and unit (population) std.

    A constant array (std == 0, e.g. an empty background slice) is only
    mean-centred and comes back as all zeros instead of all NaN.
    '''
    mean = np.mean(image)
    std = np.std(image)
    if std == 0:
        return image - mean
    return (image - mean) / std

In [None]:
def crop_specific(im, mask, crop_size):
    '''
    Cut a ``crop_size x crop_size`` window centred on the mask's foreground.

    The window is centred on the bounding box of the nonzero ``mask`` pixels;
    when the leftover margin is odd, the extra pixel goes before the box. The
    window is then shifted, if needed, so that it lies inside the slice. Slices
    smaller than ``crop_size`` are zero-padded at the bottom/right. Both
    returned arrays therefore always have spatial shape ``(crop_size, crop_size)``.
    If the foreground box is larger than ``crop_size``, the window covers the
    middle of the box.

    Parameters
    ----------
    im : np.ndarray
        Image slice, indexed ``[row, column, ...]`` like ``mask``.
    mask : np.ndarray
        Segmentation slice; nonzero pixels are foreground.
    crop_size : int
        Side length of the window.

    Returns
    -------
    (crop_im, crop) : tuple of np.ndarray, or None if ``mask`` has no foreground.
    '''
    coords = np.nonzero(mask)
    if coords[0].size == 0:
        return None

    def _window(lo, hi, size):
        # Split (crop_size - box length) around the box, with the odd pixel going
        # before it. Then clamp so the slice never starts below 0 (a negative
        # start would silently wrap around) and never runs past the end.
        slack = crop_size - (hi - lo + 1)
        start = lo - (slack - slack // 2)
        start = min(max(start, 0), max(size - crop_size, 0))
        return start, start + crop_size

    def _pad(arr):
        # Zero-pad rows/columns at the end so a slice smaller than crop_size
        # still yields a full window.
        extra_rows = max(crop_size - arr.shape[0], 0)
        extra_cols = max(crop_size - arr.shape[1], 0)
        if extra_rows == 0 and extra_cols == 0:
            return arr
        width = [(0, extra_rows), (0, extra_cols)] + [(0, 0)] * (arr.ndim - 2)
        return np.pad(arr, width, mode='constant')

    y0, y1 = _window(coords[0].min(), coords[0].max(), mask.shape[0])
    x0, x1 = _window(coords[1].min(), coords[1].max(), mask.shape[1])
    return _pad(im)[y0:y1, x0:x1], _pad(mask)[y0:y1, x0:x1]

# Dataset

In [None]:
class promise(torch.utils.data.Dataset):
    '''
    2-D slice dataset for the BIDMC / HK / UCL sites with a case-level split.

    Case volumes are listed from ``<case_root>/{BIDMC,HK,UCL}``. Slice files are
    looked up flat in ``slice_root`` as ``*<site>_<case>_<n>*`` (images) and
    ``*<site>_<segmentation stem>*`` (masks). Cases are split per site, so no
    case contributes slices to more than one split. Image slice ``i`` is always
    paired with the mask slice ``i`` of the same case.

    Parameters
    ----------
    slice_root : str
        Directory holding the NumPy slice files.
    crop : int
        Side length of the square crop returned by ``__getitem__``.
    mode : {'train', 'test', 'val'}
        Which split to expose.
    booli : bool
        If True, every label >= 1 is merged into a single foreground value 1.
    test_size : float
        Fraction of each site's cases held out for test + val. ``ceil(0.7 * held)``
        of the held-out cases go to test and the rest to val.
    case_root : str, optional
        Directory with the per-site case folders; defaults to the global ``case``.
    '''

    def __init__(self, slice_root, crop=192, mode='train', booli=True, test_size=0.3, case_root=None):
        if mode not in ('train', 'test', 'val'):
            raise ValueError(f"mode must be 'train', 'test' or 'val', got {mode!r}")
        super().__init__()
        self.bool = booli
        self.crop = crop
        case_root = case if case_root is None else case_root
        root = glob.escape(slice_root)

        # Every image / mask slice of the three sites (kept as informational attributes).
        self.image = sorted(
            (p for site in ('HK', 'BIDMC', 'UCL')
             for p in glob.glob(root + f'/*{site}*Case[0-9][0-9]_[0-9]*')),
            key=self._slice_number)
        self.segment = sorted(
            (p for site in ('HK', 'BIDMC', 'UCL')
             for p in glob.glob(root + f'/*{site}*[sS]eg*')),
            key=self._slice_number)

        # Split each site's cases independently; (image, mask) case ids stay paired.
        splits = {'train': [], 'test': [], 'val': []}
        for walk_idx, (dirpath, _, _) in enumerate(os.walk(case_root)):
            site = os.path.basename(dirpath)
            if walk_idx == 0 or site not in ('BIDMC', 'HK', 'UCL'):
                continue
            image_cases = self._case_stems(glob.glob(glob.escape(dirpath) + '/*[0-9].*'))
            segment_cases = self._case_stems(glob.glob(glob.escape(dirpath) + '/*n.*'))
            parts = self._split_cases(image_cases, segment_cases, test_size, 0.7)
            for name, pairs in zip(('train', 'test', 'val'), parts):
                splits[name].extend((f'{site}_{im}', f'{site}_{seg}') for im, seg in pairs)

        self.general_im, self.general_seg = self._collect_slices(slice_root, splits[mode])

    @staticmethod
    def _slice_number(path):
        # 'HK_Case00_17.npy' -> 17: the integer after the last '_' of the basename.
        return int(os.path.basename(path).split('.')[-2].split('_')[-1])

    @staticmethod
    def _case_stems(paths):
        # Sorted, de-duplicated case ids: 'Case00.mhd' and 'Case00.raw' both map to
        # 'Case00', so one case can never be drawn into two different splits.
        return list(dict.fromkeys(os.path.basename(p).split('.')[0] for p in sorted(paths)))

    @staticmethod
    def _split_cases(image_cases, segment_cases, test_size, test_fraction, seed=23):
        '''
        Split aligned case lists into (train, test, val) lists of (image, mask) pairs.

        ``ceil(test_size * n)`` cases are held out, and ``ceil(test_fraction * held)``
        of those go to test. A private ``random.Random(seed)`` makes the same picks
        as ``random.seed(seed); random.sample(...)`` without resetting the global RNG.
        '''
        if len(image_cases) != len(segment_cases):
            raise ValueError(
                f'{len(image_cases)} image cases but {len(segment_cases)} segmentation cases')
        pairs = list(zip(image_cases, segment_cases))
        held_out = random.Random(seed).sample(range(len(pairs)), k=math.ceil(test_size * len(pairs)))
        test_pos = random.Random(seed).sample(
            range(len(held_out)), k=math.ceil(test_fraction * len(held_out)))
        test_idx = [held_out[p] for p in test_pos]
        held_set, test_set = set(held_out), set(test_idx)
        train = [pair for i, pair in enumerate(pairs) if i not in held_set]
        test = [pairs[i] for i in test_idx]
        val = [pairs[i] for i in held_out if i not in test_set]
        return train, test, val

    @classmethod
    def _collect_slices(cls, slice_root, case_pairs):
        '''Return aligned (image_paths, mask_paths), case by case, each sorted by slice number.'''
        root = glob.escape(slice_root)
        images, segments = [], []
        for im, seg in case_pairs:
            im_slices = sorted(glob.glob(root + '/*{}_[0-9]*'.format(glob.escape(im))),
                               key=cls._slice_number)
            seg_slices = sorted(glob.glob(root + '/*{}*'.format(glob.escape(seg))),
                                key=cls._slice_number)
            if len(im_slices) != len(seg_slices):
                raise ValueError(f'{im}: {len(im_slices)} image slices but '
                                 f'{len(seg_slices)} segmentation slices ({seg})')
            images.extend(im_slices)
            segments.extend(seg_slices)
        return images, segments

    def __len__(self):
        return len(self.general_im)

    def __getitem__(self, val):
        '''Return ``{'image', 'seg'}`` float32 arrays of shape (1, crop, crop).'''
        crop = self.crop
        image_array = np.load(self.general_im[val])
        segment_array = np.load(self.general_seg[val])
        if np.unique(segment_array.astype(bool)).size == 2:
            # Both background and foreground present: crop around the foreground.
            im_c, seg_c = crop_specific(image_array, segment_array, crop)
        else:
            # Nothing to centre on: rescale the whole slice instead.
            im_c, seg_c = resize(image_array, (crop, crop)), resize(segment_array, (crop, crop))
        image_array_normal = normalization(im_c)
        if self.bool:
            seg_c[seg_c >= 1.0] = 1.0
        return {'image': np.float32(np.expand_dims(image_array_normal, 0)),
                'seg': np.float32(np.expand_dims(seg_c, 0))}
    
    
    

In [None]:
class ISBIDataset(torch.utils.data.Dataset):
    '''
    2-D slice dataset for the ISBI / ISBI-1.5 sites with a case-level split.

    Case volumes are listed from ``<case_root>/{ISBI,ISBI-1.5}``. Slice files are
    looked up flat in ``slice_root`` as ``*<site>_<case>_<n>*`` (images) and
    ``*<site>_<segmentation stem>*`` (masks). Cases are split per site, so no
    case contributes slices to more than one split. Image slice ``i`` is always
    paired with the mask slice ``i`` of the same case.

    Parameters
    ----------
    slice_root : str
        Directory holding the NumPy slice files.
    crop : int
        Side length of the square crop returned by ``__getitem__``.
    mode : {'train', 'test', 'val'}
        Which split to expose.
    booli : bool
        If True, every label >= 1 is merged into a single foreground value 1.
    test_size : float
        Fraction of each site's cases held out for test + val. ``ceil(0.6 * held)``
        of the held-out cases go to test and the rest to val.
    case_root : str, optional
        Directory with the per-site case folders; defaults to the global ``case``.
    '''

    def __init__(self, slice_root, crop=192, mode='train', booli=True, test_size=0.2, case_root=None):
        if mode not in ('train', 'test', 'val'):
            raise ValueError(f"mode must be 'train', 'test' or 'val', got {mode!r}")
        super().__init__()
        self.bool = booli
        self.crop = crop
        case_root = case if case_root is None else case_root
        root = glob.escape(slice_root)

        # Every ISBI image / mask slice (kept as informational attributes).
        self.image = sorted(glob.glob(root + '/*ISBI*Case[0-9][0-9]_[0-9]*'), key=self._slice_number)
        self.segment = sorted(glob.glob(root + '/*ISBI*[sS]eg*'), key=self._slice_number)

        # Split each site's cases independently; (image, mask) case ids stay paired.
        splits = {'train': [], 'test': [], 'val': []}
        for walk_idx, (dirpath, _, _) in enumerate(os.walk(case_root)):
            site = os.path.basename(dirpath)
            if walk_idx == 0 or site not in ('ISBI', 'ISBI-1.5'):
                continue
            image_cases = self._case_stems(glob.glob(glob.escape(dirpath) + '/*[0-9].*'))
            segment_cases = self._case_stems(glob.glob(glob.escape(dirpath) + '/*n.*'))
            parts = self._split_cases(image_cases, segment_cases, test_size, 0.6)
            for name, pairs in zip(('train', 'test', 'val'), parts):
                splits[name].extend((f'{site}_{im}', f'{site}_{seg}') for im, seg in pairs)

        self.general_im, self.general_seg = self._collect_slices(slice_root, splits[mode])

    @staticmethod
    def _slice_number(path):
        # 'ISBI_Case00_17.npy' -> 17: the integer after the last '_' of the basename.
        return int(os.path.basename(path).split('.')[-2].split('_')[-1])

    @staticmethod
    def _case_stems(paths):
        # Sorted, de-duplicated case ids: 'Case00.mhd' and 'Case00.raw' both map to
        # 'Case00', so one case can never be drawn into two different splits.
        return list(dict.fromkeys(os.path.basename(p).split('.')[0] for p in sorted(paths)))

    @staticmethod
    def _split_cases(image_cases, segment_cases, test_size, test_fraction, seed=23):
        '''
        Split aligned case lists into (train, test, val) lists of (image, mask) pairs.

        ``ceil(test_size * n)`` cases are held out, and ``ceil(test_fraction * held)``
        of those go to test. A private ``random.Random(seed)`` makes the same picks
        as ``random.seed(seed); random.sample(...)`` without resetting the global RNG.
        '''
        if len(image_cases) != len(segment_cases):
            raise ValueError(
                f'{len(image_cases)} image cases but {len(segment_cases)} segmentation cases')
        pairs = list(zip(image_cases, segment_cases))
        held_out = random.Random(seed).sample(range(len(pairs)), k=math.ceil(test_size * len(pairs)))
        test_pos = random.Random(seed).sample(
            range(len(held_out)), k=math.ceil(test_fraction * len(held_out)))
        test_idx = [held_out[p] for p in test_pos]
        held_set, test_set = set(held_out), set(test_idx)
        train = [pair for i, pair in enumerate(pairs) if i not in held_set]
        test = [pairs[i] for i in test_idx]
        val = [pairs[i] for i in held_out if i not in test_set]
        return train, test, val

    @classmethod
    def _collect_slices(cls, slice_root, case_pairs):
        '''Return aligned (image_paths, mask_paths), case by case, each sorted by slice number.'''
        root = glob.escape(slice_root)
        images, segments = [], []
        for im, seg in case_pairs:
            im_slices = sorted(glob.glob(root + '/*{}_[0-9]*'.format(glob.escape(im))),
                               key=cls._slice_number)
            seg_slices = sorted(glob.glob(root + '/*{}*'.format(glob.escape(seg))),
                                key=cls._slice_number)
            if len(im_slices) != len(seg_slices):
                raise ValueError(f'{im}: {len(im_slices)} image slices but '
                                 f'{len(seg_slices)} segmentation slices ({seg})')
            images.extend(im_slices)
            segments.extend(seg_slices)
        return images, segments

    def __len__(self):
        return len(self.general_im)

    def __getitem__(self, val):
        '''Return ``{'image', 'seg'}`` float32 arrays of shape (1, crop, crop).'''
        crop = self.crop
        image_array = np.load(self.general_im[val])
        segment_array = np.load(self.general_seg[val])
        if np.unique(segment_array.astype(bool)).size == 2:
            # Both background and foreground present: crop around the foreground.
            im_c, seg_c = crop_specific(image_array, segment_array, crop)
        else:
            # Nothing to centre on: rescale the whole slice instead.
            im_c, seg_c = resize(image_array, (crop, crop)), resize(segment_array, (crop, crop))
        image_array_normal = normalization(im_c)
        if self.bool:
            seg_c[seg_c >= 1.0] = 1.0
        return {'image': np.float32(np.expand_dims(image_array_normal, 0)),
                'seg': np.float32(np.expand_dims(seg_c, 0))}
    
    
    

In [None]:
class i2cv(torch.utils.data.Dataset):
    '''
    2-D slice dataset for the I2CVB site with a case-level train/test split.

    ``int(0.2 * n)`` cases of ``<case_root>/I2CVB`` are drawn (seed 23) as the
    test set. Every other I2CVB slice found in ``slice_root`` is train. Images
    and masks of each split are sorted lexicographically by path.

    Parameters
    ----------
    slice_root : str
        Directory holding the NumPy slice files; used for both splits.
    crop : int
        Side length of the square crop returned by ``__getitem__``.
    mode : {'train', 'test'}
        Which split to expose.
    booli : bool
        If True, every label >= 1 is merged into a single foreground value 1.
    case_root : str, optional
        Directory with the case folders; defaults to the global ``case``.
    '''

    def __init__(self, slice_root, crop=192, mode='train', booli=True, case_root=None):
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        super().__init__()
        self.bool = booli
        self.crop = crop
        case_root = case if case_root is None else case_root
        root = glob.escape(slice_root)
        self.image = sorted(glob.glob(root + '/*I2CVB*Case[0-9][0-9]_[0-9]*'), key=self._slice_number)
        self.segment = sorted(glob.glob(root + '/*I2CVB*[sS]eg*'), key=self._slice_number)

        # Draw whole test cases so that no patient is split across train and test.
        test_im, test_seg = [], []
        for walk_idx, (dirpath, _, _) in enumerate(os.walk(case_root)):
            folder = os.path.basename(dirpath)
            if walk_idx == 0 or folder != 'I2CVB':
                continue
            image_cases = self._case_stems(glob.glob(glob.escape(dirpath) + '/*[0-9].*'))
            segment_cases = self._case_stems(glob.glob(glob.escape(dirpath) + '/*n.*'))
            if len(image_cases) != len(segment_cases):
                raise ValueError(f'{dirpath}: {len(image_cases)} image cases but '
                                 f'{len(segment_cases)} segmentation cases')
            # Private RNG: same picks as random.seed(23); random.sample(...), without
            # resetting the global RNG. One draw serves both lists since they are aligned.
            picked = random.Random(23).sample(range(len(image_cases)), k=int(0.2 * len(image_cases)))
            for idx in picked:
                im_id = glob.escape(f'{folder}_{image_cases[idx]}')
                seg_id = glob.escape(f'{folder}_{segment_cases[idx]}')
                test_im.extend(glob.glob(root + '/*{}_[0-9]*'.format(im_id)))
                test_seg.extend(glob.glob(root + '/*{}*'.format(seg_id)))
        self.image_slice_test = sorted(test_im)
        self.segment_slice_test = sorted(test_seg)

        # Train = every slice of the site that is not a test slice.
        held_im, held_seg = set(self.image_slice_test), set(self.segment_slice_test)
        self.train_im = sorted(p for p in set(self.image) if p not in held_im)
        self.train_seg = sorted(p for p in set(self.segment) if p not in held_seg)

        if mode == 'train':
            self.general_im, self.general_seg = self.train_im, self.train_seg
        else:
            self.general_im, self.general_seg = self.image_slice_test, self.segment_slice_test
        # Image i is paired with mask i; differing counts would misalign every pair.
        if len(self.general_im) != len(self.general_seg):
            raise ValueError(f'{len(self.general_im)} image slices but '
                             f'{len(self.general_seg)} segmentation slices for mode {mode!r}')

    @staticmethod
    def _slice_number(path):
        # 'I2CVB_Case00_17.npy' -> 17: the integer after the last '_' of the basename.
        return int(os.path.basename(path).split('.')[-2].split('_')[-1])

    @staticmethod
    def _case_stems(paths):
        # Sorted, de-duplicated case ids: 'Case00.mhd' and 'Case00.raw' both map to 'Case00'.
        return list(dict.fromkeys(os.path.basename(p).split('.')[0] for p in sorted(paths)))

    def __len__(self):
        return len(self.general_im)

    def __getitem__(self, val):
        '''Return ``{'image', 'seg'}`` float32 arrays of shape (1, crop, crop).'''
        crop = self.crop
        image_array = np.load(self.general_im[val])
        segment_array = np.load(self.general_seg[val])
        if np.unique(segment_array.astype(bool)).size == 2:
            # Both background and foreground present: crop around the foreground.
            im_c, seg_c = crop_specific(image_array, segment_array, crop)
        else:
            # Nothing to centre on: rescale the whole slice instead.
            im_c, seg_c = resize(image_array, (crop, crop)), resize(segment_array, (crop, crop))
        image_array_normal = normalization(im_c)
        if self.bool:
            seg_c[seg_c >= 1.0] = 1.0
        return {'image': np.float32(np.expand_dims(image_array_normal, 0)),
                'seg': np.float32(np.expand_dims(seg_c, 0))}
    
    
    

In [None]:
class dat(torch.utils.data.Dataset):
    '''
    2-D slice dataset over every site with a case-level train/test split.

    For each sub-folder of ``case_root``, ``int(0.2 * n)`` of its cases are drawn
    (seed 23) as test cases. Every other slice found in ``slice_root`` is train.
    Images and masks of each split are sorted lexicographically by path.

    Parameters
    ----------
    slice_root : str
        Directory holding the NumPy slice files; used for both splits.
    crop : int
        Side length of the square crop returned by ``__getitem__``.
    mode : {'train', 'test'}
        Which split to expose.
    booli : bool
        If True, every label >= 1 is merged into a single foreground value 1.
    case_root : str, optional
        Directory with the case folders; defaults to the global ``case``.
    '''

    def __init__(self, slice_root, crop=192, mode='train', booli=True, case_root=None):
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        super().__init__()
        self.bool = booli
        self.crop = crop
        case_root = case if case_root is None else case_root
        root = glob.escape(slice_root)
        self.image = sorted(glob.glob(root + '/*Case[0-9][0-9]_[0-9]*'), key=self._slice_number)
        self.segment = sorted(glob.glob(root + '/*[sS]eg*'), key=self._slice_number)

        # Draw whole test cases per folder so no patient is split across train and test.
        test_im, test_seg = [], []
        for walk_idx, (dirpath, _, _) in enumerate(os.walk(case_root)):
            if walk_idx == 0:
                continue
            folder = os.path.basename(dirpath)
            image_cases = self._case_stems(glob.glob(glob.escape(dirpath) + '/*[0-9].*'))
            segment_cases = self._case_stems(glob.glob(glob.escape(dirpath) + '/*n.*'))
            if len(image_cases) != len(segment_cases):
                raise ValueError(f'{dirpath}: {len(image_cases)} image cases but '
                                 f'{len(segment_cases)} segmentation cases')
            # Private RNG: same picks as random.seed(23); random.sample(...), without
            # resetting the global RNG. One draw serves both lists since they are aligned.
            picked = random.Random(23).sample(range(len(image_cases)), k=int(0.2 * len(image_cases)))
            for idx in picked:
                im_id = glob.escape(f'{folder}_{image_cases[idx]}')
                seg_id = glob.escape(f'{folder}_{segment_cases[idx]}')
                test_im.extend(glob.glob(root + '/*{}_[0-9]*'.format(im_id)))
                test_seg.extend(glob.glob(root + '/*{}*'.format(seg_id)))
        self.image_slice_test = sorted(test_im)
        self.segment_slice_test = sorted(test_seg)

        # Train = every slice that is not a test slice.
        held_im, held_seg = set(self.image_slice_test), set(self.segment_slice_test)
        self.train_im = sorted(p for p in set(self.image) if p not in held_im)
        self.train_seg = sorted(p for p in set(self.segment) if p not in held_seg)

        if mode == 'train':
            self.general_im, self.general_seg = self.train_im, self.train_seg
        else:
            self.general_im, self.general_seg = self.image_slice_test, self.segment_slice_test
        # Image i is paired with mask i; differing counts would misalign every pair.
        if len(self.general_im) != len(self.general_seg):
            raise ValueError(f'{len(self.general_im)} image slices but '
                             f'{len(self.general_seg)} segmentation slices for mode {mode!r}')

    @staticmethod
    def _slice_number(path):
        # 'HK_Case00_17.npy' -> 17: the integer after the last '_' of the basename.
        return int(os.path.basename(path).split('.')[-2].split('_')[-1])

    @staticmethod
    def _case_stems(paths):
        # Sorted, de-duplicated case ids: 'Case00.mhd' and 'Case00.raw' both map to 'Case00'.
        return list(dict.fromkeys(os.path.basename(p).split('.')[0] for p in sorted(paths)))

    def __len__(self):
        return len(self.general_im)

    def __getitem__(self, val):
        '''Return ``{'image', 'seg'}`` float32 arrays of shape (1, crop, crop).'''
        crop = self.crop
        image_array = np.load(self.general_im[val])
        segment_array = np.load(self.general_seg[val])
        if np.unique(segment_array.astype(bool)).size == 2:
            # Both background and foreground present: crop around the foreground.
            im_c, seg_c = crop_specific(image_array, segment_array, crop)
        else:
            # Nothing to centre on: rescale the whole slice instead.
            im_c, seg_c = resize(image_array, (crop, crop)), resize(segment_array, (crop, crop))
        image_array_normal = normalization(im_c)
        if self.bool:
            seg_c[seg_c >= 1.0] = 1.0
        return {'image': np.float32(np.expand_dims(image_array_normal, 0)),
                'seg': np.float32(np.expand_dims(seg_c, 0))}
    
    
    

# model

In [2]:
class resnet_block(nn.Module):
    '''
    Bottleneck residual block with channel dropout.

    Main path: 1x1 conv -> 3x3 conv (strided) -> 1x1 conv, each followed by
    BatchNorm and ReLU. It is added to a shortcut projection, then passed
    through ReLU and ``dropout2d(p=0.2)``. Dropout is active only in training
    mode (``module.train()``).

    Parameters
    ----------
    input_activation : int
        Number of input channels.
    intermediate : int
        Channels inside the bottleneck.
    expand : int
        Output channels = ``intermediate * expand``.
    stride : int
        Stride of the 3x3 conv and of the default shortcut projection.
    down : nn.Module, optional
        Shortcut module. Defaults to a 1x1 conv mapping the input to the output
        channels with the same stride.
    '''

    def __init__(self, input_activation, intermediate, expand = 1, stride = 1, down = None):
        super().__init__()
        self.expand = expand
        output = intermediate * self.expand
        self.conv1x1_1 = nn.Conv2d(input_activation, intermediate, 1)
        self.BN1 = nn.BatchNorm2d(intermediate)

        self.conv3x3 = nn.Conv2d(intermediate, intermediate, 3, stride=stride, padding=1)
        self.BN2 = nn.BatchNorm2d(intermediate)

        self.conv1x1_2 = nn.Conv2d(intermediate, output, 1)
        self.BN3 = nn.BatchNorm2d(output)

        self.down = down if down is not None else nn.Conv2d(input_activation, output, 1, stride=stride)

    def forward(self, inp):
        shortcut = inp
        c = F.relu(self.BN1(self.conv1x1_1(inp)))
        c = F.relu(self.BN2(self.conv3x3(c)))
        c = F.relu(self.BN3(self.conv1x1_2(c)))
        if self.down is not None:
            shortcut = self.down(inp)

        out = F.relu(c + shortcut)
        # Functional dropout needs the module's mode explicitly; without it dropout
        # would also fire in eval(), making inference stochastic.
        return F.dropout2d(out, p=0.2, training=self.training)

In [3]:
class block(nn.Module):
    '''
    Two 3x3 conv -> BatchNorm -> ReLU layers followed by channel dropout.

    Parameters
    ----------
    input_activation : int
        Number of input channels.
    intermediate : int
        Number of output channels of both convolutions.
    expand : int
        Accepted for signature compatibility with ``resnet_block``; unused.
    stride : int
        Stride of the first conv only, so the block downsamples by ``stride`` once.
    p : float
        ``dropout2d`` probability; dropout is active only in training mode.
    '''

    def __init__(self, input_activation, intermediate, expand = 1, stride = 1, p = 0.1):
        super().__init__()
        self.p = p
        self.conv3x3 = nn.Conv2d(input_activation, intermediate, 3, stride=stride, padding=1)
        self.BN1 = nn.BatchNorm2d(intermediate)

        # Stride 1 here: striding both convs would downsample by stride**2.
        self.conv3x3_1 = nn.Conv2d(intermediate, intermediate, 3, stride=1, padding=1)
        self.BN2 = nn.BatchNorm2d(intermediate)

    def forward(self, inp):
        c = F.relu(self.BN1(self.conv3x3(inp)))
        c = F.relu(self.BN2(self.conv3x3_1(c)))
        # Pass the module mode so dropout is disabled in eval().
        return F.dropout2d(c, p=self.p, training=self.training)

In [4]:
class Unet(nn.Module):
    '''
    Six-level U-Net for single-channel 2-D input, built from ``block`` layers.

    Encoder: a 3x3 stem to 16 channels, then levels of 32, 64, 128, 256, 512 and
    1024 channels. Each level except the deepest is followed by 2x2 max-pooling.
    Decoder: 2x2 transposed-conv upsampling, concatenation with the encoder
    output of the same resolution, and a ``block``. A 1x1 conv plus sigmoid
    gives ``n_class`` per-pixel scores.

    Height and width must be divisible by 32 (five poolings) for the skip
    connections to line up; the output has the input's spatial size.
    '''

    def __init__(self, n_class):
        super().__init__()
        # Module creation order is kept so that parameter initialisation and
        # state_dict layout stay unchanged.
        self.inp = nn.Conv2d(1, 16, 3, padding=1)

        self.en_block1 = block(16, 32)
        self.en_block2 = block(32, 64)
        self.en_block3 = block(64, 128)
        self.en_block4 = block(128, 256)
        self.en_block5 = block(256, 512)
        self.en_block6 = block(512, 1024)

        self.transpose5 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.transpose4 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.transpose3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.transpose2 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.transpose1 = nn.ConvTranspose2d(64, 32, 2, 2)

        self.de_block1 = block(64, 32)
        self.de_block2 = block(128, 64)
        self.de_block3 = block(256, 128)
        self.de_block4 = block(512, 256)
        self.de_block5 = block(1024, 512)

        self.out_conv = nn.Conv2d(32, n_class, 1)

    def forward(self, inp):
        features = self.inp(inp)

        # Encoder: store each level's output as a skip connection, then halve H and W.
        skips = []
        for encoder in (self.en_block1, self.en_block2, self.en_block3,
                        self.en_block4, self.en_block5):
            features = encoder(features)
            skips.append(features)
            features = F.max_pool2d(features, 2)

        # Bottleneck: (1024, h/32, w/32).
        features = self.en_block6(features)

        # Decoder: upsample, join with the same-resolution skip, refine.
        stages = zip(
            (self.transpose5, self.transpose4, self.transpose3, self.transpose2, self.transpose1),
            (self.de_block5, self.de_block4, self.de_block3, self.de_block2, self.de_block1),
            reversed(skips),
        )
        for upsample, decoder, skip in stages:
            features = decoder(torch.cat([upsample(features), skip], 1))

        return torch.sigmoid(self.out_conv(features))

In [5]:
class Unet_resnet(nn.Module):
    '''
    Six-level U-Net for single-channel 2-D input, built from ``resnet_block`` layers.

    Same topology as ``Unet``: a 3x3 stem to 16 channels, and encoder levels of
    32, 64, 128, 256, 512 and 1024 channels with 2x2 max-pooling between them.
    The decoder uses 2x2 transposed-conv upsampling with skip concatenation. A
    1x1 conv plus sigmoid gives ``n_class`` per-pixel scores.

    Height and width must be divisible by 32 (five poolings) for the skip
    connections to line up; the output has the input's spatial size.
    '''

    def __init__(self, n_class):
        super().__init__()
        # Module creation order is kept so that parameter initialisation and
        # state_dict layout stay unchanged.
        self.inp = nn.Conv2d(1, 16, 3, padding=1)

        self.en_block1 = resnet_block(16, 32)
        self.en_block2 = resnet_block(32, 64)
        self.en_block3 = resnet_block(64, 128)
        self.en_block4 = resnet_block(128, 256)
        self.en_block5 = resnet_block(256, 512)
        self.en_block6 = resnet_block(512, 1024)

        self.transpose5 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.transpose4 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.transpose3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.transpose2 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.transpose1 = nn.ConvTranspose2d(64, 32, 2, 2)

        self.de_block1 = resnet_block(64, 32)
        self.de_block2 = resnet_block(128, 64)
        self.de_block3 = resnet_block(256, 128)
        self.de_block4 = resnet_block(512, 256)
        self.de_block5 = resnet_block(1024, 512)

        self.out_conv = nn.Conv2d(32, n_class, 1)

    def forward(self, inp):
        features = self.inp(inp)

        # Encoder: store each level's output as a skip connection, then halve H and W.
        skips = []
        for encoder in (self.en_block1, self.en_block2, self.en_block3,
                        self.en_block4, self.en_block5):
            features = encoder(features)
            skips.append(features)
            features = F.max_pool2d(features, 2)

        # Bottleneck: (1024, h/32, w/32).
        features = self.en_block6(features)

        # Decoder: upsample, join with the same-resolution skip, refine.
        stages = zip(
            (self.transpose5, self.transpose4, self.transpose3, self.transpose2, self.transpose1),
            (self.de_block5, self.de_block4, self.de_block3, self.de_block2, self.de_block1),
            reversed(skips),
        )
        for upsample, decoder, skip in stages:
            features = decoder(torch.cat([upsample(features), skip], 1))

        return torch.sigmoid(self.out_conv(features))

In [None]:
# Multi-label (booli=False) datasets with 192x192 crops for the three case-level splits.
dataset_train, dataset_test, dataset_val = (
    promise(sli_path, 192, mode=split, booli=False) for split in ('train', 'test', 'val')
)

In [None]:
# Only the training loader shuffles. Test and validation keep a fixed order so
# evaluation runs are reproducible and per-batch results can be traced to slices.
dataloader = torch.utils.data.DataLoader(dataset_train, 32, shuffle = True)
dataloadert = torch.utils.data.DataLoader(dataset_test, 32, shuffle = False)
dataloaderv = torch.utils.data.DataLoader(dataset_val, 32, shuffle = False)

In [6]:
# One output channel per label 0 / 1 / 2: the datasets are built with booli=False,
# and the Dice losses/metrics below index channels 0-2 (or argmax over channels).
m = Unet(3)

# train

In [None]:
def one_hot(mask, class_label = [0,1,2]):
    '''
    One-hot encode an integer label mask.

    ``mask`` is indexed as (B, 1, H, W); channel 0 holds the labels. Returns a
    float32 tensor of shape (B, len(class_label), H, W) whose channel k is 1
    where ``mask == class_label[k]``.
    '''
    channels = [(mask[:, 0, :, :] == value).to(torch.float32) for value in class_label]
    return torch.stack(channels, 1)

In [None]:
def ce_dice(inputs, target):
    '''
    Cross-entropy plus mean multi-class Dice loss.

    ``ce`` is not defined in this cell; presumably it comes from
    ``from utils import *`` (TODO confirm). It receives the label map
    ``target[:, 0]``; ``dice_multi_loss`` receives the full target.
    '''
    cross_entropy = ce(inputs, target[:, 0, :, :])
    return cross_entropy + dice_multi_loss(inputs, target)

In [None]:
def dice_loss(inputs, target):
    '''
    Soft Dice loss: 1 - (2*sum(target*inputs) + eps) / (sum(target) + sum(inputs) + eps).

    Sums run over every element at once (the whole batch together), not per sample.
    eps = 1e-8 keeps the ratio defined when both tensors are all zeros (loss 0 in that case).
    '''
    eps = 1e-8
    overlap = (target * inputs).sum()
    numerator = 2.0 * overlap + eps
    denominator = target.sum() + inputs.sum() + eps
    return 1 - numerator / denominator

In [None]:
def dice_multi_loss(inputs, target, class_label = [0,1,2]):
    '''
    Mean soft-Dice loss over the requested classes.

    inputs      : per-class predictions; channel `label` of inputs is scored against class `label`.
    target      : label mask in the layout `one_hot` reads (label values in channel 0).
    class_label : class values to average over; each value is also the channel index into `inputs`.
    returns     : mean of dice_loss over the classes in class_label.
    raises      : ValueError if class_label is empty.
    '''
    if len(class_label) == 0:
        raise ValueError("class_label must contain at least one label")
    # Encode exactly the requested classes, so one-hot channel k corresponds to class_label[k]
    # (previously the default [0,1,2] was always encoded and indexed by label value).
    one_hot_target = one_hot(target, class_label)
    per_class = [dice_loss(inputs[:, label, :, :], one_hot_target[:, k, :, :])
                 for k, label in enumerate(class_label)]
    return sum(per_class) / len(class_label)
        

In [None]:
def dice_metric(inputs, target):
    '''
    Dice coefficient: (2*sum(target*inputs) + eps) / (sum(target) + sum(inputs) + eps).

    Sums run over every element at once (the whole batch together), not per sample.
    eps = 1e-8 makes two all-zero tensors score 1.
    '''
    eps = 1e-8
    overlap = (target * inputs).sum()
    numerator = 2.0 * overlap + eps
    denominator = target.sum() + inputs.sum() + eps
    return numerator / denominator

In [None]:
def dice_multi_metric(inputs, target, class_label = [0,1,2]):
    '''
    Mean soft Dice coefficient over the requested classes.

    inputs      : per-class predictions; channel `label` of inputs is scored against class `label`.
    target      : label mask in the layout `one_hot` reads (label values in channel 0).
    class_label : class values to average over; each value is also the channel index into `inputs`.
    returns     : mean of dice_metric over the classes in class_label.
    raises      : ValueError if class_label is empty.
    '''
    if len(class_label) == 0:
        raise ValueError("class_label must contain at least one label")
    # Encode exactly the requested classes, so one-hot channel k corresponds to class_label[k].
    one_hot_target = one_hot(target, class_label)
    per_class = [dice_metric(inputs[:, label, :, :], one_hot_target[:, k, :, :])
                 for k, label in enumerate(class_label)]
    return sum(per_class) / len(class_label)
        

In [None]:
def dice_multi_metric3(inputs, target, class_label = [1,2]):
    '''
    Hard Dice score per class, computed on the argmax prediction.

    inputs      : per-class scores; each pixel's predicted label is the argmax over dim 1.
    target      : label mask in the layout `one_hot` reads (label values in channel 0).
    class_label : class values to score (label 0 is not included by default).
    returns     : tuple of Dice scores, one per entry of class_label, in the same order.
                  With the default two labels it unpacks as `d1, d2 = ...`, as before.
    '''
    predict_label = torch.argmax(inputs, 1)
    # Channel k of the one-hot target corresponds to class_label[k].
    one_hot_target = one_hot(target, class_label)
    scores = []
    for k, label in enumerate(class_label):
        predicted_mask = (predict_label == label).to(torch.float32)
        scores.append(dice_metric(predicted_mask, one_hot_target[:, k, :, :]))
    return tuple(scores)

In [None]:
def accuracy(output, label, threshold = 0.8):
    '''
    Per-sample agreement score between a thresholded prediction and a label tensor.

    Entries of `output` strictly above `threshold` are snapped to 1.0 (all other entries
    are left unchanged). Elementwise matches with `label` are counted per sample over dims
    (1, 2, 3), and each count is divided by that sample's own label sum. For a binary label
    this behaves roughly like TP / positives, plus any entries where output is exactly 0
    where the label is 0.

    output    : prediction tensor with the batch on dim 0 (reduced over dims 1-3). Not modified.
    label     : tensor with the same shape as `output`.
    threshold : snapping threshold (default 0.8, the previously hard-coded value).
    returns   : 1-D tensor with one score per sample.
    '''
    smooth = 1e-8
    # torch.where builds a new tensor, so the caller's `output` is left untouched.
    snapped = torch.where(output > threshold, torch.ones_like(output), output)
    correct = (snapped == label).sum((1, 2, 3))
    # Normalise each sample by its own label sum, so batch entries are scored independently.
    positives = label.sum((1, 2, 3))
    return (smooth + correct) / (positives + smooth)

In [None]:
def bce_dice_loss(inputs, target):
    '''
    Sum of a binary-cross-entropy term and a Dice term.

    Both tensors are first moved to the notebook-level `device`. `nn.BCELoss` expects
    `inputs` to be probabilities in [0, 1] and `target` to have the same shape.
    NOTE(review): `dice_coef_loss` is not defined in this chunk — presumably provided by
    `utils`; confirm it exists (this notebook itself defines `dice_loss`).
    '''
    inputs = inputs.to(device)
    # Move the real target (previously `inputs` was assigned here, so the loss compared the
    # prediction with itself).
    target = target.to(device)

    dicescore = dice_coef_loss(inputs, target)
    # Apply the BCE criterion (previously it was built but unused, and Dice was added twice).
    bceloss = nn.BCELoss()(inputs, target)

    return bceloss + dicescore

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:


def _mean(values):
    '''Arithmetic mean of a list of floats; NaN for an empty list (e.g. an empty loader).'''
    return sum(values) / len(values) if values else float('nan')


def _train_one_epoch(model, dataload_train, loss_fn, optimizer):
    '''Run one optimisation pass over `dataload_train` and return the mean batch loss.'''
    # Training mode for layers such as dropout / batch-norm, if the model has any.
    model.train()
    batch_losses = []
    for data in dataload_train:  # each batch is a dict {'image': tensor, 'seg': tensor}
        image = data['image'].to(device)
        segment = data['seg'].to(device)

        segment_hat = model(image)
        # NOTE(review): `bc` is not defined in this chunk — presumably a binary cross-entropy
        # criterion from earlier in the notebook or `utils`; it is applied to channel 0 of the
        # prediction and target only. Confirm.
        batch_loss = loss_fn(segment_hat, segment) + bc(segment_hat[:,0,...], segment[:,0,...])
        batch_losses.append(batch_loss.item())

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
    return _mean(batch_losses)


def _validate(model, dataload_validation, loss_fn):
    '''
    Evaluate `model` on `dataload_validation` without tracking gradients.

    Returns (mean loss, mean dice_metric, mean Dice for label 1, mean Dice for label 2),
    each averaged over batches. The validation loss is `loss_fn` alone (no `bc` term).
    '''
    # Evaluation mode for layers such as dropout / batch-norm, if the model has any.
    model.eval()
    losses, dice_means, dice_l1, dice_l2 = [], [], [], []
    with torch.no_grad():
        for data_val in dataload_validation:
            imaget = data_val['image'].to(device)
            segment_t = data_val['seg'].to(device)

            segment_hat_t = model(imaget)
            losses.append(loss_fn(segment_hat_t, segment_t).item())
            dice_means.append(dice_metric(segment_hat_t, segment_t).item())
            dice_val1, dice_val2 = dice_multi_metric3(segment_hat_t, segment_t)
            dice_l1.append(dice_val1.item())
            dice_l2.append(dice_val2.item())
    return _mean(losses), _mean(dice_means), _mean(dice_l1), _mean(dice_l2)


def train_step(model, dataload_train, dataload_validation, loss = dice_loss,
               epochs = 100,
               lr = 0.001,
              best = 0.86):
    '''
    Train `model` with Adam and validate it after every epoch.

    model               : torch module; moved to the notebook-level `device`.
    dataload_train      : iterable of dicts {'image': tensor, 'seg': tensor} used for training.
    dataload_validation : same layout, evaluated after each epoch.
    loss                : criterion called as loss(prediction, seg); during training a `bc`
                          term on channel 0 is added to it.
    epochs, lr          : number of epochs and the Adam learning rate.
    best                : mean validation dice_metric a checkpoint must beat. It is raised to
                          each new best score, so only improvements are saved.
    returns             : (per-epoch mean training loss,
                           per-epoch [mean validation loss, mean validation dice_metric]).
    '''
    loss_fn = loss
    # nn.Module instances and functools.partial objects have no __name__ attribute.
    loss_name = getattr(loss_fn, '__name__', type(loss_fn).__name__)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    hist_loss_per_epoch = []
    hist_val_per_epoch = []
    for epoch in range(epochs):
        mean_loss = _train_one_epoch(model, dataload_train, loss_fn, optimizer)
        print(f'Loss in epoch{epoch} is :', mean_loss)
        hist_loss_per_epoch.append(mean_loss)

        mean_val_loss, mean_dice_val, mean_dice_val_l1, mean_dice_val_l2 = _validate(
            model, dataload_validation, loss_fn)
        hist_val_per_epoch.append([mean_val_loss, mean_dice_val])

        print(f'dice_val in epoch{epoch} is :', mean_dice_val)
        print(f'l1_val in epoch{epoch} is :', mean_dice_val_l1)
        print(f'l2_val in epoch{epoch} is :', mean_dice_val_l2)
        print('------------------------------------------')
        if mean_dice_val > best:
            # Raise the bar so later epochs are saved only when they improve on this one.
            best = mean_dice_val
            torch.save({"epoch" : epoch,
                        "model_state" : model.state_dict(),
                        "optimizer" : optimizer.state_dict()},
                       f"model_{model.__class__.__name__}_epoch_{epoch}_loss_{loss_name}.pt")

    return hist_loss_per_epoch, hist_val_per_epoch



                

In [None]:
# One-epoch smoke run: train `m` on the training loader and validate on the val loader.
hist = train_step(
    model=m,
    dataload_train=dataloader,
    dataload_validation=dataloaderv,
    loss=dice_loss,
    epochs=1,
    lr=0.001,
)