In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from glob import glob
import numpy as np
import nibabel as nib
import os
import logging
import pickle
import yaml
from torch.utils.data import Dataset
from torchvision import transforms
from torch.autograd import Function, Variable
from tqdm import tqdm
import math

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
torch.backends.cudnn.enabled = False
torch.cuda.get_device_name()

'Tesla P40'

In [4]:
print ('Available devices ', torch.cuda.device_count())
print('Active CUDA Device: GPU', torch.cuda.current_device())

Available devices  1
Active CUDA Device: GPU 0


In [5]:
BN_MOMENTUM = 0.1
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

In [6]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

In [7]:
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm3d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

In [8]:
class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm3d(
                    num_channels[branch_index] * block.expansion,
                    momentum=BN_MOMENTUM
                ),
            )

        layers = []
        layers.append(
            block(
                self.num_inchannels[branch_index],
                num_channels[branch_index],
                stride,
                downsample
            )
        )
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index]
                )
            )

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels)
            )

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv3d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1, 1, 0, bias=False
                            ),
                            nn.BatchNorm3d(num_inchannels[i]),
                            nn.Upsample(scale_factor=2**(j-i), mode='nearest')
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv3d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3, 2, 1, bias=False
                                    ),
                                    nn.BatchNorm3d(num_outchannels_conv3x3)
                                )
                            )
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv3d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3, 2, 1, bias=False
                                    ),
                                    nn.BatchNorm3d(num_outchannels_conv3x3),
                                    nn.ReLU(True)
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []

        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse

In [9]:
blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': Bottleneck
}

In [10]:
class PoseHighResolutionNet(nn.Module):

    def __init__(self, cfg):
        self.inplanes = 64
        extra = cfg['MODEL']['EXTRA']
        super(PoseHighResolutionNet, self).__init__()

        # stem net
        self.conv1 = nn.Conv3d(1, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm3d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv3d(64, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm3d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        self.stage2_cfg = extra['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition1 = self._make_transition_layer([256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = extra['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = extra['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=False)

        self.final_layer_1 = nn.Conv3d(
            in_channels=pre_stage_channels[0],
            out_channels=cfg['MODEL']['NUM_JOINTS'],
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.final_layer_2 = nn.Conv3d(
            in_channels=2,
            out_channels=1,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.sigmoid = F.sigmoid


    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv3d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3, 1, 1, bias=False
                            ),
                            nn.BatchNorm3d(num_channels_cur_layer[i]),
                            nn.ReLU(inplace=True)
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i+1-num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i-num_branches_pre else inchannels
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv3d(
                                inchannels, outchannels, 3, 2, 1, bias=False
                            ),
                            nn.BatchNorm3d(outchannels),
                            nn.ReLU(inplace=True)
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(
                    self.inplanes, planes * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm3d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    fuse_method,
                    reset_multi_scale_output
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        x = self.final_layer_1(y_list[0])
        x = self.final_layer_2(x)
        x = self.sigmoid(x)
        #print(x.size())

        return x


In [11]:
cfg = {'MODEL' : 
                {'SIGMA': 2, 'EXTRA': 
                                    {'FINAL_CONV_KERNEL': 1, 
                                    
                                    'STAGE2': {'NUM_CHANNELS': [32, 64], 'NUM_MODULES': 1, 'FUSE_METHOD': 'SUM', 'BLOCK': 'BASIC', 'NUM_BRANCHES': 2, 'NUM_BLOCKS': [4, 4]}, 
                                    
                                    'STAGE4': {'NUM_CHANNELS': [32, 64, 128, 256], 'NUM_MODULES': 3, 'FUSE_METHOD': 'SUM', 'BLOCK': 'BASIC', 'NUM_BRANCHES': 4, 'NUM_BLOCKS': [4, 4, 4, 4]}, 
                                    
                                    'STAGE3': {'NUM_CHANNELS': [32, 64, 128], 'NUM_MODULES': 4, 'FUSE_METHOD': 'SUM', 'BLOCK': 'BASIC', 'NUM_BRANCHES': 3, 'NUM_BLOCKS': [4, 4, 4]}}, 
                                    
                                    'NAME': 'hrnet', 
                                    'INIT_WEIGHTS': True, 
                                    'NUM_JOINTS': 2}}
                                    

In [12]:
net = PoseHighResolutionNet(cfg)
net = nn.DataParallel(net)
n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('Number of parameters in network: ', n_params)
#2d : 28536113

Number of parameters in network:  84791749


In [13]:
#from google.colab import drive
#drive.mount('/content/gdrive')

In [14]:
# Path to the folder that contains folders of segmentation data
#train_path = "C:\\SSRM\\Unprocessed training dataset\\TrainingDataset_MSSEG\\train\\*\\"
#test_path = "C:\\SSRM\\Unprocessed training dataset\\TrainingDataset_MSSEG\\test\\*\\"
#val_path = "C:\\SSRM\\Unprocessed training dataset\\TrainingDataset_MSSEG\\val\\*\\"

train_path = "data/Preprocessed/train/*/"
test_path = "data/Preprocessed/test/*/"
val_path = "data/Preprocessed/validation/*/"

train_image_mask_paths = []
test_image_mask_paths = []
val_image_mask_paths = []


block_size = (16,16,16)

#Load training images
directory_paths = glob(train_path)
for path in directory_paths:
    # Load all the paths for each Flair set of data (1 Flair data and all its segmentation paths)
    flair_path = path + 'FLAIR_preprocessed.nii.gz'
    seg_path = path + 'Consensus.nii.gz'
    train_image_mask_paths.append((flair_path,seg_path))
    
directory_paths = glob(test_path)
for path in directory_paths:
    # Load all the paths for each Flair set of data (1 Flair data and all its segmentation paths)
    flair_path = path + 'FLAIR_preprocessed.nii.gz'
    seg_path = path + 'Consensus.nii.gz'
    test_image_mask_paths.append((flair_path,seg_path))
    
directory_paths = glob(val_path)
for path in directory_paths:
    # Load all the paths for each Flair set of data (1 Flair data and all its segmentation paths)
    flair_path = path + 'FLAIR_preprocessed.nii.gz'
    seg_path = path + 'Consensus.nii.gz'
    val_image_mask_paths.append((flair_path,seg_path))

#train_image_mask_paths = train_image_mask_paths[:1]
#test_image_mask_paths = test_image_mask_paths[:1]
val_image_mask_paths = val_image_mask_paths[:1]
    

In [15]:
def zero_padding(data, block_size):
    # Calculate final size to be achieved
    ceil_val = math.ceil(data.shape[0]/block_size[0])
    #Calculate required padding size 
    pad_val_c = (block_size[0] * ceil_val) - data.shape[0]
    
    # Calculate final size to be achieved
    ceil_val = math.ceil(data.shape[1]/block_size[1])
    #Calculate required padding size
    pad_val_h = (block_size[1] * ceil_val) - data.shape[1]
    
    # Calculate final size to be achieved
    ceil_val = math.ceil(data.shape[2]/block_size[2])
    # Calculate required padding size
    pad_val_w = (block_size[2] * ceil_val) - data.shape[2]
    
    # Constant padding
    #data = data.numpy()
    data = np.pad(data, ((0,pad_val_c),(0,pad_val_h),(0,pad_val_w)), 'constant')
    #data = np.array(data, dtype=np.int16)
    
    #changed dtype to float
    data = np.array(data, dtype=np.float32)
    return data


def get_data_blocks(data, block_size ):
    x = torch.from_numpy(data)
    # Add a dimension at 0th position
    x = x.unsqueeze(0)
    # Kernel Size
    kc, kh, kw = block_size[0], block_size[1], block_size[2]
    # stride
    dc, dh, dw = block_size[0], block_size[1], block_size[2]
    patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
    unfold_shape = patches.size()
    patches = patches.contiguous().view(patches.size(0), -1, kc, kh, kw)
    #Return Patches and Unfold Shape
    return patches, unfold_shape

In [16]:
def preprocess_image(image_mask_paths):
    img_mask_list = []

    for i in tqdm(range(len(image_mask_paths))):
        
        #load the img and mask
        vol = nib.load(image_mask_paths[i][0])
        m = nib.load(image_mask_paths[i][1])
        
        # Get data, normalize the image and pad
        img = np.array(vol.get_data(), np.float32) 
        img = img / np.amax(img)
        img_padded = zero_padding(img, block_size)
        
        mask = np.array(m.get_data(),np.uint8)
        mask = mask / np.amax(mask)
        mask_padded = zero_padding(mask, block_size)

        # Generate data blocks of block_size
        img_blocks, unfold_shape_img = get_data_blocks(data = img_padded, block_size = block_size)
        mask_blocks, unfold_shape_mask = get_data_blocks(data = mask_padded, block_size = block_size)

        img_array = img_blocks.numpy()
        #print(img_array.shape)
        #img_array = img_array[:,3606:3607, : , : ]
        #print(block_img.shape)
        mask_array = mask_blocks.numpy()
        #mask_array = mask_array[:,3606:3607, : , : ]
        #print(block_mask.shape)
        #print(np.sum(block_mask))
        
        '''final_sum = 0
        #print(mask_array.shape[1])
        for i in range(mask_array.shape[1]):
            temp_sum = np.sum(mask_array[:,i:i+1, : , : ])
            if temp_sum > final_sum:
                final_sum = temp_sum
                block_mask = mask_array[:,i:i+1, : , : ]
                block_img = img_array[:,i:i+1, : , : ]
                #print(temp_sum)
                index = i
                #print("index = ", i)

        img_array = block_img
        mask_array = block_mask'''
        
        #mask_array = mask_array[: , :mask_array.shape[1]//4, : , :]
        #img_array = img_array[: , :img_array.shape[1]//4, : , :]
        #print(mask_array.shape)
        
        
        
        for i in range(len(img_array[0])):
            if np.sum(mask_array[0][i]) !=0:
                img_mask_list.append((img_array[0][i], mask_array[0][i]))

    return img_mask_list 
#a = preprocess_image(train_image_mask_paths)

In [17]:
#print(train_image_mask_paths)
#Training:
train_img_masks = preprocess_image(train_image_mask_paths)

#Training:
test_img_masks = preprocess_image(test_image_mask_paths)

#Validation:
val_img_masks = preprocess_image(val_image_mask_paths)

print('No. of blocks containing lesion: ',len(train_img_masks))
#print(len(val_img_masks))

100%|██████████| 9/9 [00:13<00:00,  1.53s/it]
100%|██████████| 3/3 [00:04<00:00,  1.49s/it]
100%|██████████| 1/1 [00:01<00:00,  1.82s/it]

No. of blocks containing lesion:  1678





In [18]:
train_val_masks = []
train_val_masks.extend(train_img_masks)
train_val_masks.extend(val_img_masks)

In [19]:
class ToTensor(object):
    """
    Convert ndarrays in sample to Tensors.
    """
    def __init__(self):
        pass

    def __call__(self, sample):
        image, label = sample['img'], sample['label']
        image = image[None,:,:]
        label = label[None,:,:]

        return {'img': torch.from_numpy(image.copy()).type(torch.FloatTensor),
                'label': torch.from_numpy(label.copy()).type(torch.FloatTensor)}

In [20]:
class CustomDataset(Dataset):
    def __init__(self, image_masks, transforms=None):

        self.image_masks = image_masks
        self.transforms = transforms

    def __len__(self):  # return count of sample we have

        return len(self.image_masks)

    def __getitem__(self, index):

        image = self.image_masks[index][0] # H, W, C
        mask = self.image_masks[index][1]

#       image = np.transpose(image, axes=[2, 0, 1]) # C, H, W

        sample = {'img': image, 'label': mask}

        if transforms:
            sample = self.transforms(sample)

        return sample

train_dataset = CustomDataset(train_val_masks, transforms=transforms.Compose([ToTensor()]))
val_dataset = CustomDataset(val_img_masks, transforms=transforms.Compose([ToTensor()]))

In [21]:

# define dice coefficient 
class DiceCoeff(Function):
    """Dice coeff for one pair of input image and target image"""
    def forward(self, prediction, target):
        self.save_for_backward(prediction, target)
        eps = 0.0001 # in case union = 0
        # Calculate intersection and union. 
        # You can convert the input image into a vector with input.contiguous().view(-1)
        # Then use torch.dot(A, B) to calculate the intersection.
        A = prediction.view(-1)
        B = target.view(-1)
        inter = torch.dot(A.float(),B.float())
        union = torch.sum(A.float()) + torch.sum(B.float()) - inter + eps
        # Calculate DICE 
        d = inter / union
        return d

# Calculate dice coefficients for batches
def dice_coeff(prediction, target):
    """Dice coeff for batches"""
    s = torch.FloatTensor(1).zero_()
    
    # For each pair of input and target, call DiceCoeff().forward(prediction, target) to calculate dice coefficient
    # Then average
    for i, (a,b) in enumerate(zip(prediction, target)):
        s += DiceCoeff().forward(a,b)
    s = s / (i + 1)
    return s


In [22]:
def eval_net(net, dataset):
    # set net mode to evaluation
    net.eval()
    tot = 0
    print('Validation began')
    print('val len: ', len(dataset))
    #print(next(net.parameters()).is_cuda)
    for i, b in enumerate(dataset):
        img = b['img'].to(device)
        B = img.shape[0]
        true_mask = b['label'].to(device)

        # Feed the image to the network to get predicted mask
        mask_pred = net.forward(img.float())
        #print('predicted')
        
        # For all pixels in predicted mask, set them to 1 if larger than 0.5. Otherwise set them to 0
        #mask_pred = mask_pred > 0.5
        
        # calculate dice_coeff()
        # note that you should add all the dice_coeff in validation/testing dataset together
        # call dice_coeff() here
        masks_probs_flat = mask_pred.view(mask_pred.numel())
        true_masks_flat = true_mask.view(true_mask.numel())
        
        tot += dice_coeff(true_masks_flat,masks_probs_flat)
        #print('tot: ',tot)
        #tot += dice_coeff(true_mask,mask_pred)
        # Return average dice_coeff()
    print('Validation done!')
    return tot / (i + 1)



In [23]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [24]:
#from torch.utils.tensorboard import SummaryWriter
#writer = SummaryWriter('/content/gdrive/My Drive/IVP Project/Dataset/runs/hrnet')

In [25]:
def dice_loss(input, target):
    smooth = 1.

    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    
    return 1 - ((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth))

In [26]:
from torch import optim
epochs = 400 # e.g. 10, or more until dice converge
batch_size = 1 # e.g. 16
lr = 0.01        # e.g. 0.01, 0.00001
N_train = len(train_img_masks)
model_save_path = './' #'/content/gdrive/My Drive/IVP Project/Dataset/models/hrnet/'  # directory to same the model after each epoch.

#optimizer = optim.SGD(net.parameters(),lr = lr,momentum=0.99, weight_decay=0.0005)
optimizer = optim.Adam(net.parameters(), lr = lr, weight_decay=0.0005)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 100, gamma=0.1)

criterion = nn.BCELoss()

net.to(device)
loss_graph_values = {}
loss_graph_list = []

# Start training
for epoch in range(epochs):
    print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
    net.train()
    # Reload images and masks for training and validation and perform random shuffling at the begining of each epoch
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    epoch_loss = 0
    count = 0
    #print('train len: ', len(train_loader))
    #print('Active CUDA Device: GPU', torch.cuda.current_device())
    for i, b in enumerate(train_loader):
        # Get images and masks from each batch
        imgs = b['img']
        true_masks = b['label']
        
        imgs = imgs.to(device)
        true_masks = true_masks.to(device)
        #print('True mask shape: ',true_masks.shape)
        
        # Feed your images into the network
        masks_pred = net.forward(imgs.float())
        #print('mask:',masks_pred.shape)
        #print('true:',true_masks.shape)
        masks_probs = masks_pred[:,0,:,:,:]
        masks_probs = masks_probs.unsqueeze(1)

        masks_probs_flat = masks_probs.reshape(1,-1)
        masks_probs_flat = masks_probs_flat.squeeze()

        true_masks_flat = true_masks.reshape(1,-1)
        true_masks_flat = true_masks_flat.squeeze()

        # Calculate the loss by comparing the predicted masks vector and true masks vector
        # And sum the losses together
        loss = criterion(masks_probs_flat,true_masks_flat.float())
        #loss = dice_loss(masks_probs_flat,true_masks_flat.float())
        epoch_loss += loss.item()
        if count % 50 == 0:
            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))
        count = count + 1
        # optimizer.zero_grad() clears x.grad for every parameter x in the optimizer.
        # It’s important to call this before loss.backward(), otherwise you’ll accumulate the gradients from multiple passes.
        optimizer.zero_grad()
        # loss.backward() computes dloss/dx for every parameter x which has requires_grad=True.
        # These are accumulated into x.grad for every parameter x
        loss.backward()
        # optimizer.step updates the value of x using the gradient x.grad.
        optimizer.step()
    print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
    loss_graph_values[epoch] = epoch_loss / i
    loss_graph_list.append(epoch_loss / i)
    #writer.add_scalar('training loss', (epoch_loss / (i+1)), epoch)

    # Perform validation with eval_net() on the validation data
    #val_dice = eval_net(net,val_loader)
    #print('Validation Dice Coeff: {}'.format(val_dice))
    scheduler.step()
    
    # Save the model after each epoch
    '''if os.path.isdir(model_save_path):
        torch.save(net.state_dict(),model_save_path + 'Brain_Seg_Epoch{}.pth'.format(epoch + 1))
    else:
        os.makedirs(model_save_path, exist_ok=True)
        torch.save(net.state_dict(),model_save_path + 'Brain_Seg_Epoch{}.pth'.format(epoch + 1))
    print('Checkpoint {} saved !'.format(epoch + 1))'''


Starting epoch 1/400.
0.0000 --- loss: 0.827476
0.0298 --- loss: 8.603207
0.0596 --- loss: 0.177520
0.0894 --- loss: 0.120447
0.1192 --- loss: 0.453063
0.1490 --- loss: 0.311244
0.1788 --- loss: 0.253262
0.2086 --- loss: 0.224499
0.2384 --- loss: 0.127449
0.2682 --- loss: 0.134853
0.2980 --- loss: 0.219378
0.3278 --- loss: 0.215672
0.3576 --- loss: 0.206397
0.3874 --- loss: 0.205711
0.4172 --- loss: 0.446368
0.4470 --- loss: 0.442088
0.4768 --- loss: 0.311061
0.5066 --- loss: 0.169029
0.5364 --- loss: 0.552770
0.5662 --- loss: 0.183194
0.5959 --- loss: 0.165732
0.6257 --- loss: 0.350513
0.6555 --- loss: 0.573462
0.6853 --- loss: 0.171759
0.7151 --- loss: 0.093578
0.7449 --- loss: 0.061222
0.7747 --- loss: 0.431967
0.8045 --- loss: 0.133333
0.8343 --- loss: 1.583256
0.8641 --- loss: 0.080998
0.8939 --- loss: 0.326599
0.9237 --- loss: 0.195657
0.9535 --- loss: 0.198137
0.9833 --- loss: 0.092074
1.0131 --- loss: 0.078587
1.0429 --- loss: 0.075507
Epoch finished ! Loss: 0.4354578933403106


KeyboardInterrupt: 

In [None]:
'''if os.path.isdir(model_save_path):
    torch.save(net.state_dict(),model_save_path + 'Brain_Seg_Epoch{}.pth'.format(epoch + 1))
else:
        os.makedirs(model_save_path, exist_ok=True)
        torch.save(net.state_dict(),model_save_path + 'Brain_Seg_Epoch{}.pth'.format(epoch + 1))
print('Checkpoint {} saved !'.format(epoch + 1))'''

In [27]:
#net.load_state_dict(torch.load('Brain_Seg_Epoch116.pth'))
#val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=True, num_workers=0)
val_dice = eval_net(net,val_loader)
print('Validation Dice Coeff: {}'.format(val_dice))

Validation began
val len:  98
Validation done!
Validation Dice Coeff: tensor([0.0351], grad_fn=<DivBackward0>)


In [None]:
def predict_img(net,full_img,out_threshold=0.1):
    # set the mode of your network to evaluation
    net.eval()

    X_img = torch.from_numpy(full_img).unsqueeze(0).unsqueeze(0).to(device)

    with torch.no_grad():

        output_img = net(X_img.float())
        out_probs = output_img.squeeze(0).squeeze(0)
        out_mask_np = (out_probs>out_threshold).cpu().numpy().astype('uint8')

    return out_mask_np

In [None]:
def reconstruct(blocks, unfold_shape):

    blocks_orig = blocks.view(unfold_shape)
    output_c = unfold_shape[1] * unfold_shape[4]
    output_h = unfold_shape[2] * unfold_shape[5]
    output_w = unfold_shape[3] * unfold_shape[6]
    blocks_orig = blocks_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
    blocks_orig = blocks_orig.view(1, output_c, output_h, output_w)
    # Remove the dimension at 0th position and convert to numpy
    blocks_orig = blocks_orig.squeeze(0).detach().numpy()
    return blocks_orig

In [None]:
train_img_masks_1 = [(train_img_masks[0][0], train_img_masks[0][1])]
print(len(train_img_masks_1))
print(train_img_masks_1[0][1].shape)

In [None]:
mask_pred = predict_img(net=net,full_img=train_img_masks_1[0][0], out_threshold=0.1)

In [None]:
import matplotlib.pyplot as plt

#Image and mask alignment
def show_slices(slices):
    """ Function to display row of image slices """
    fig, axes = plt.subplots(1, len(slices), figsize=(8,8))
    for i, slc in enumerate(slices):
        axes[i].imshow(slc.T, cmap="gray", origin="lower")

for i in range(1):
    for j in range(len(train_img_masks[i][0][0][0])):
        show_slices([train_img_masks_1[i][0][:, :, j], train_img_masks[i][1][:, :, j], mask_pred[:,:,j]])

In [None]:
show_slices([train_img_masks_1[0][0][:,:,8], train_img_masks_1[0][1][:,:,8]])
show_slices([train_img_masks_1[0][0][:,:,8], mask_pred[:,:,8]])

In [None]:
import matplotlib.pylab as plt

lists = sorted(loss_graph_values.items()) # sorted by key, return a list of tuples
lists = lists[:200]
x, y = zip(*lists) # unpack a list of pairs into two tuples

plt.plot(x, y)
plt.show()

In [None]:
loss_graph_values

In [None]:
lists