Import

In [None]:
#Import Python Packages
import copy;
from functools import reduce
from glob import glob
from glob import iglob
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
from operator import __add__
import os
from os.path import splitext
from os import listdir
import pandas as pd;
from PIL import Image, ImageOps
import PIL
from random import randint
import random;
import sys
from scipy import ndimage
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.measurements import label
from scipy.special import softmax
from skimage import exposure
from skimage import feature
from skimage import transform as tf
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Dataset
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from torch import optim
from torch.nn import Parameter
from torch.nn.modules import Conv2d, Module
from tqdm import tqdm
from typing import Any
import warnings
warnings.filterwarnings('ignore')


Neural Network

In [None]:


class GaborConv2d(Module):
    """Convolution layer whose kernels are generated from learnable Gabor parameters.

    Every (out_channel, in_channel) kernel is

        exp(-0.5 * (x'^2 + y'^2) / (sigma + delta)^2) * cos(freq * x' + psi) / (2 * pi * sigma^2)

    where (x', y') is the sampling grid rotated by ``theta``. ``freq``, ``theta``,
    ``sigma`` and ``psi`` each have shape (out_channels, in_channels).

    In training mode the kernels are rebuilt on every forward pass and the input
    is convolved with them directly, so gradients reach the Gabor parameters.
    In eval mode the kernels are computed once (without autograd) and cached in
    ``conv_layer.weight`` until the module is put back into training mode.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=48, dilation=1,groups=1,bias=False, padding_mode="reflect"):
        super().__init__()
        # True once eval-mode kernels have been written into conv_layer.weight
        self.is_calculated = False

        # Supplies stride / padding / dilation / groups / bias / padding_mode; its
        # weight is overwritten with the generated Gabor kernels.
        self.conv_layer = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode)
        self.kernel_size = self.conv_layer.kernel_size

        # small addition to avoid division by zero
        self.delta = 1e-3

        # Frequency: (pi / 2) * sqrt(2) ** (-k) with k drawn from {0, ..., 4}
        self.freq = Parameter(
            (math.pi / 2)
            * math.sqrt(2)
            ** (-torch.randint(0, 5, (out_channels, in_channels))).type(torch.Tensor),
            requires_grad=True,
        )
        # Orientation: random multiples of pi / 8
        self.theta = Parameter(
            (math.pi / 8)
            * torch.randint(0, 8, (out_channels, in_channels)).type(torch.Tensor),
            requires_grad=True,
        )
        # Envelope width, initialised from the frequency
        self.sigma = Parameter(math.pi / self.freq, requires_grad=True)

        # Phase offset, uniform in [0, pi)
        self.psi = Parameter(
            math.pi * torch.rand(out_channels, in_channels), requires_grad=True
        )

        # Half kernel extents (rounded up), used to centre the sampling grid
        self.x0 = Parameter(
            torch.ceil(torch.Tensor([self.kernel_size[0] / 2]))[0], requires_grad=False
        )
        self.y0 = Parameter(
            torch.ceil(torch.Tensor([self.kernel_size[1] / 2]))[0], requires_grad=False
        )

        self.y, self.x = torch.meshgrid(
            [
                torch.linspace(-self.x0 + 1, self.x0 + 0, self.kernel_size[0]),
                torch.linspace(-self.y0 + 1, self.y0 + 0, self.kernel_size[1]),
            ]
        )
        # The sampling grid is fixed; only the Gabor parameters are meant to learn.
        self.y = Parameter(self.y.clone(), requires_grad=False)
        self.x = Parameter(self.x.clone(), requires_grad=False)

        # Not used by the convolution; kept so the "weight" state_dict key (and
        # existing checkpoints) stay compatible. Zero-initialised rather than
        # left as uninitialised memory.
        self.weight = Parameter(
            torch.zeros(self.conv_layer.weight.shape),
            requires_grad=True,
        )

        # Extra aliases kept for state_dict compatibility
        self.register_parameter("freq", self.freq)
        self.register_parameter("theta", self.theta)
        self.register_parameter("sigma", self.sigma)
        self.register_parameter("psi", self.psi)
        self.register_parameter("x_shape", self.x0)
        self.register_parameter("y_shape", self.y0)
        self.register_parameter("y_grid", self.y)
        self.register_parameter("x_grid", self.x)
        self.register_parameter("weight", self.weight)

    def forward(self, input_tensor):
        """Convolve ``input_tensor`` with the current Gabor kernels."""
        if self.training:
            # Convolve with the freshly built (autograd-attached) kernels. Using
            # conv_layer.weight here would leave freq/theta/sigma/psi without
            # gradients, because that tensor is only filled through .data.
            weight = self.calculate_weights()
            self.is_calculated = False
            return self.conv_layer._conv_forward(input_tensor, weight, self.conv_layer.bias)
        if not self.is_calculated:
            with torch.no_grad():
                self.calculate_weights()
            self.is_calculated = True
        return self.conv_layer(input_tensor)

    def calculate_weights(self):
        """Build every Gabor kernel from the current parameters.

        A detached copy is written into ``conv_layer.weight`` so that eval mode
        and ``state_dict`` see the latest kernels.

        :return: tensor of shape (out_channels, in_channels, k_h, k_w); it stays
                 attached to the autograd graph when grad mode is enabled
        """
        # Broadcast the per-(out, in) parameters against the (k_h, k_w) grid.
        sigma = self.sigma[:, :, None, None]
        freq = self.freq[:, :, None, None]
        theta = self.theta[:, :, None, None]
        psi = self.psi[:, :, None, None]

        rotx = self.x * torch.cos(theta) + self.y * torch.sin(theta)
        roty = -self.x * torch.sin(theta) + self.y * torch.cos(theta)

        g = torch.exp(
            -0.5 * ((rotx ** 2 + roty ** 2) / (sigma + self.delta) ** 2)
        )
        g = g * torch.cos(freq * rotx + psi)
        g = g / (2 * math.pi * sigma ** 2)

        self.conv_layer.weight.data.copy_(g.detach())
        return g

    def _forward_unimplemented(self, *inputs: Any):
        """
        code checkers makes implement this method,
        looks like error in PyTorch
        """
        raise NotImplementedError
        
        

class DoubleConv(nn.Module):
    """Two rounds of 3x3 convolution -> BatchNorm -> ReLU.

    ``padding=1`` keeps the spatial size unchanged. The first convolution maps
    ``in_channels`` to ``out_channels`` and the second maps ``out_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        return self.double_conv(x)


class Down(nn.Module):
    """Halve the spatial resolution with 2x2 max pooling, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and run the double convolution on the result."""
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Upscale ``x1`` by 2, align it to the skip tensor ``x2``, concatenate, then DoubleConv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad/crop it to ``x2``'s H and W, and fuse via ``cat([x2, x1])``."""
        x1 = self.up(x1)
        # input is NCHW. F.pad needs plain ints; a negative difference crops instead of padding.
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution that maps feature channels to per-class scores."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels, keeping H and W."""
        projected = self.conv(x)
        return projected

class filt_cat(nn.Module):
    """Concatenate ``x1`` onto ``x2`` along channels after padding/cropping ``x1`` to ``x2``'s H, W.

    The constructor arguments are accepted for interface compatibility but are unused.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

    def forward(self, x1, x2):
        """Return ``cat([x2, x1_aligned], dim=1)``, where ``x1_aligned`` is spatially matched to ``x2``."""
        # input is NCHW. F.pad needs plain ints; a negative difference crops x1 instead of padding it.
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        return torch.cat([x2, x1], dim=1)

class SNP_Net(nn.Module):
    """U-Net preceded by a Gabor filter bank.

    The input is filtered by a 96x96 ``GaborConv2d``. The responses are
    concatenated after the raw input channels and fed to a 4-level U-Net. The
    output is a per-pixel softmax over ``n_classes`` (dim=1).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        stacked = 2 * n_channels  # raw input + Gabor responses
        self.g0 = GaborConv2d(n_channels, n_channels, kernel_size=(96, 96))
        self.fc = filt_cat(n_channels, stacked)
        self.inc = DoubleConv(stacked, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return per-pixel class probabilities (softmax over dim 1)."""
        fused = self.fc(self.g0(x), x)

        # Encoder, keeping each level for the skip connections
        skip1 = self.inc(fused)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)

        # Decoder
        y = self.up1(bottom, skip4)
        y = self.up2(y, skip3)
        y = self.up3(y, skip2)
        y = self.up4(y, skip1)

        # NOTE(review): these are probabilities, not logits. train_net feeds them
        # to nn.CrossEntropyLoss, which applies log-softmax a second time.
        return torch.softmax(self.outc(y), dim=1)



Segmentation Metrics

In [None]:
def segmentation_metrics(confusion_matrix, per_class=False):
    """Summarise a confusion-matrix object as a flat dict of segmentation scores.

    Always contains the class-averaged scores 'mIoU', 'mDice', 'oAcc', 'mSens',
    'mSpec' and 'mF1'. With ``per_class=True`` it also adds one entry per class
    for each per-class metric, keyed '<prefix><class index>' (e.g. 'Dice0').
    The prefixes are 'Dice', 'Sens', 'Spec', 'F1', 'RGt' and 'RPr'.

    :param confusion_matrix: object exposing the getter methods named below
    :param per_class: also include per-class values
    :return: dict mapping metric name to value, in the order listed above
    """
    summary = (
        ('mIoU', 'get_mean_class_iou'),
        ('mDice', 'get_mean_class_dice'),
        ('oAcc', 'get_overall_accuracy'),
        ('mSens', 'get_mean_class_sensitivity'),
        ('mSpec', 'get_mean_class_specificity'),
        ('mF1', 'get_mean_class_f1'),
    )
    metrics = {key: getattr(confusion_matrix, getter)() for key, getter in summary}
    if not per_class:
        return metrics

    per_class_getters = (
        ('Dice', 'get_dice_per_class'),
        ('Sens', 'get_sensitivity_per_class'),
        ('Spec', 'get_specificity_per_class'),
        ('F1', 'get_f1_per_class'),
        ('RGt', 'get_ground_truth_ratios'),
        ('RPr', 'get_prediction_ratios'),
    )
    for prefix, getter in per_class_getters:
        for class_index, value in enumerate(getattr(confusion_matrix, getter)()):
            metrics[f'{prefix}{class_index}'] = value
    return metrics


class ConfusionMatrix():
    """Maintains a running confusion matrix for a K-class classification problem.
    Rows correspond to ground-truth targets and columns correspond to predicted targets."""

    def __init__(self, k):
        self.cm = np.zeros(shape=(k, k), dtype=object)  # array of python ints for unlimited precision
        self.k = k

    def reset(self):
        """Zero all counts."""
        self.cm.fill(0)

    def add(self, predicted, target):
        """
        Adds new results to the confusion matrix. Filters elements without ground truth (class = -100).

        :param predicted: N, BHW or BDHW-sized tensor of class integers or NK, BKHW or BKDHW-sized tensor
                          of predicted probabilities
        :param target: N, BHW or BDHW-sized tensor of class integers or NK, BKHW or BKDHW-sized tensor
                          of one-hot encoded classes
        """
        # TODO: rewrite into pytorch for speed (no need to copy back to device)

        if isinstance(predicted, torch.Tensor):
            predicted = predicted.detach().cpu().numpy()
        if isinstance(target, torch.Tensor):
            target = target.detach().cpu().numpy()

        # A dimension of size K at axis 1 is treated as class scores / one-hot.
        if np.ndim(predicted) > 1 and predicted.shape[1] == self.k:
            predicted = np.argmax(predicted, 1)
        else:
            # Empty input has no max/min; there is nothing to validate.
            assert np.size(predicted) == 0 or ((predicted.max() < self.k) and (predicted.min() >= 0)), \
                'predicted values are not between 0 and k-1'

        if np.ndim(target) > 1 and target.shape[1] == self.k:
            assert (target >= 0).all() and (target <= 1).all(), \
                'in one-hot encoding, target values should be 0 or 1'
            invalid_idx = target.sum(1) < 0.5
            target = np.argmax(target, 1)
            target[invalid_idx] = -100

        predicted = np.ravel(predicted)
        target = np.ravel(target)
        assert predicted.shape[0] == target.shape[0], \
            'number of targets and predicted outputs do not match'

        # Remove predictions for elements without ground truth
        valid_idx = target != -100
        target = target[valid_idx]
        predicted = predicted[valid_idx]

        # from https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
        # (sklearn.metrics.confusion_matrix is 100x slower)
        # Encode each (target, predicted) pair as one flat index into the KxK matrix.
        x = predicted + self.k * target
        bincount_2d = np.bincount(x.astype(np.int64),
                                  minlength=self.k ** 2)
        assert bincount_2d.size == self.k ** 2
        cm = bincount_2d.reshape((self.k, self.k))

        self.cm += cm

    def value(self, normalized=False):
        """
        :param normalized: if True, divide each row by its sum (rows with no elements stay 0)
        :return confusion matrix of K rows and K columns (float64)
        """
        conf = self.cm.astype(np.float64)
        if normalized:
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        else:
            return conf

    def get_iou_per_class(self):
        """
        :return per-class intersection-over-union / Jaccard coefficient (NaN if class not present nor predicted)
        """
        cm = self.value()
        tp = np.diag(cm)
        predicted_count = cm.sum(axis=0)
        target_count = cm.sum(axis=1)
        with np.errstate(invalid='ignore'):
            return tp / (predicted_count + target_count - tp)

    def get_mean_class_iou(self):
        """
        :return mean per-class intersection-over-union / Jaccard coefficient (over classes present)
        """
        iou = self.get_iou_per_class()
        iou = iou[~np.isnan(iou)]
        return np.mean(iou)

    def get_overall_accuracy(self):
        """
        :return overall accuracy: correctly classified elements / all elements (class-unspecific)
        """
        cm = self.value()
        tp = np.diag(cm).sum()
        oa = tp / max(1, np.sum(cm))
        return oa

    def get_sensitivity_per_class(self):
        """
        :return per-class sensitivity TP / (TP + FN) (NaN if class not present)
        """
        cm = self.value()
        tp = np.diag(cm)
        with np.errstate(invalid='ignore'):
            return tp / np.sum(cm, axis=1)

    def get_mean_class_sensitivity(self):
        """
        :return mean per-class sensitivity (over classes present)
        """
        sens = self.get_sensitivity_per_class()
        sens = sens[~np.isnan(sens)]
        return np.mean(sens)

    def get_specificity_per_class(self):
        """
        :return per-class one-vs-rest specificity TN / (TN + FP)
                (NaN if every counted element has that class as ground truth)
        """
        cm = self.value()
        tp = np.diag(cm)
        fp = cm.sum(axis=0) - tp
        fn = cm.sum(axis=1) - tp
        # True negatives: every element outside both row i and column i,
        # including confusions between two other classes.
        tn = cm.sum() - tp - fp - fn
        with np.errstate(invalid='ignore'):
            return tn / (tn + fp)

    def get_mean_class_specificity(self):
        """
        :return mean per-class specificity (over classes present)
        """
        spec = self.get_specificity_per_class()
        spec = spec[~np.isnan(spec)]
        return np.mean(spec)

    def get_f1_per_class(self):
        """
        :return per-class f1 = 2 TP / (2 TP + FP + FN) (NaN if class not present nor predicted)
        """
        cm = self.value()
        tp = np.diag(cm)
        with np.errstate(invalid='ignore'):
            return tp / (0.5*(np.sum(cm, axis=1) + np.sum(cm, axis=0)))

    def get_mean_class_f1(self):
        """
        :return mean per-class f1 (over classes present)
        """
        f1 = self.get_f1_per_class()
        f1 = f1[~np.isnan(f1)]
        return np.mean(f1)

    def get_dice_per_class(self):
        """
        :return per-class Dice coefficient (NaN if class not present nor predicted)
        """
        with np.errstate(invalid='ignore'):
            jacc = self.get_iou_per_class()
            return 2 * jacc / (jacc + 1)

    def get_mean_class_dice(self):
        """
        :return mean per-class Dice coefficient (over classes present)
        """
        dice = self.get_dice_per_class()
        dice = dice[~np.isnan(dice)]
        return np.mean(dice)

    def get_ground_truth_ratios(self):
        """
        :return per-class fraction of counted elements whose ground truth is that class
                (all zeros if nothing has been added)
        """
        cm = self.value()
        return cm.sum(axis=1) / max(1, cm.sum())

    def get_prediction_ratios(self):
        """
        :return per-class fraction of counted elements predicted as that class
                (all zeros if nothing has been added)
        """
        cm = self.value()
        return cm.sum(axis=0) / max(1, cm.sum())



Find Edges

In [None]:
def find_edges(color_mask):
    """Return a dilated binary map of class-boundary pixels in a 2-D label mask.

    For each label 1, 2 and 3, a pixel carrying that label counts as an edge when
    at least one, but fewer than all eight, of its 8-connected neighbours share
    its label. Neighbours outside the image count as different, so labelled
    pixels on the image border are edges. The per-label maps are summed and
    clipped to 1, then dilated with a 3x3 (8-connected) structuring element.

    Label 0 does not contribute to the result.
    NOTE(review): the original loop also visited label 0, but its contribution
    was overwritten when label 1 was processed. Confirm that excluding label 0
    is intended.

    :param color_mask: 2-D array of class labels
    :return: float64 array of color_mask's shape with values 0.0 / 1.0
    """
    n_rows, n_cols = color_mask.shape
    struct1 = ndimage.generate_binary_structure(2, 2)
    neighbour_offsets = [(dr, dc)
                         for dr in (-1, 0, 1)
                         for dc in (-1, 0, 1)
                         if (dr, dc) != (0, 0)]

    combined = np.zeros((n_rows, n_cols))
    for label_value in (1, 2, 3):
        binary = (color_mask == label_value).astype(np.float64)
        # Zero border so out-of-image neighbours count as "different label".
        padded = np.pad(binary, 1, mode='constant')
        same_neighbours = sum(
            padded[1 + dr:1 + dr + n_rows, 1 + dc:1 + dc + n_cols]
            for dr, dc in neighbour_offsets
        )
        # Only pixels of this label keep their neighbour count.
        score = binary * same_neighbours
        combined += ((score > 0) & (score < 8)).astype(np.float64)

    # Remove overlap between labels, then thicken the edges
    combined = np.minimum(combined, 1.0)
    return ndimage.binary_dilation(combined, structure=struct1).astype(combined.dtype)

Dataset

In [None]:
class BasicDataset(Dataset):
    """Patch dataset: each item is 16 randomly ordered 96x96 windows of one image.

    Window corners lie on a 4x4 grid (rows/cols 0, 96, 192, 288). The grid is
    shifted by one random offset in [-48, 48] per axis (shared by the 16
    windows) and clamped to [0, 288]. This presumably assumes images of at least
    384x384; TODO confirm.

    :param images, masks, masks_AAR: arrays indexed as ``[i, :, :]`` giving one
        image / label mask / artifact-removal mask per item
    :param images_compression, masks_compression: pool of images and masks used
        for the compression-artifact augmentation
    :param Train: when 1, each window is blended, with probability 1/2, with a
        random crop from the compression pool
    """

    def __init__(self, images, masks, masks_AAR, images_compression, masks_compression, Train):
        self.images = images
        self.masks = masks
        self.masks_AAR = masks_AAR
        self.images_compression = images_compression
        self.masks_compression = masks_compression
        self.Train = Train
        # Top-left corners of the 16 windows (row-major 4x4 grid, 96-pixel stride)
        self.Positions = np.array([[r, c]
                                   for r in (0, 96, 192, 288)
                                   for c in (0, 96, 192, 288)])

    def __len__(self):
        return np.shape(self.images)[0]

    def __getitem__(self, i):
        """Return dict of stacked (16, 96, 96) arrays: 'images', 'masks', 'masks_AAR', 'edges'."""
        img = self.images[i, :, :].squeeze().copy()
        mask = self.masks[i, :, :].squeeze().copy()
        mask_AAR = self.masks_AAR[i, :, :].squeeze().copy()

        # Random window order and one random shift shared by all windows
        indices = np.arange(0, 16)
        np.random.shuffle(indices)
        row_shift = randint(-48, 48)
        col_shift = randint(-48, 48)

        # Normalize so the maximum becomes 0.5 (assumes np.max(img) != 0)
        img = img / np.max(img)
        img = img - 0.5

        size = 96
        images, masks, masks_AAR, edges = [], [], [], []
        for idx in indices:
            row = min(max(self.Positions[idx, 0] + row_shift, 0), 288)
            col = min(max(self.Positions[idx, 1] + col_shift, 0), 288)

            cropped_img = img[row:row + size, col:col + size].copy()
            cropped_mask = mask[row:row + size, col:col + size].copy()
            cropped_mask_AAR = mask_AAR[row:row + size, col:col + size].copy()

            # Add compression artifact (training only, probability 1/2)
            if self.Train == 1 and randint(0, 1) == 0:
                c_idx = randint(0, np.shape(self.images_compression)[0] - 1)
                img_compression = self.images_compression[c_idx, :, :].squeeze().copy()
                mask_compression = self.masks_compression[c_idx, :, :].squeeze().copy()

                img_compression = img_compression / np.max(img_compression)
                img_compression = img_compression - 0.5

                c_row = randint(0, 288)
                c_col = randint(0, 288)
                cropped_img_compression = img_compression[c_row:c_row + size, c_col:c_col + size].copy()
                cropped_mask_compression = mask_compression[c_row:c_row + size, c_col:c_col + size].copy()

                # Soft blending weights from the smoothed artifact mask
                sigma_var = 5
                filter_smooth = ndimage.gaussian_filter(cropped_mask_compression, sigma=(sigma_var, sigma_var), order=0)
                filter_smooth_opp = 1 - filter_smooth
                artifact = np.round(filter_smooth) == 1

                cropped_img = (cropped_img * filter_smooth_opp) + (cropped_img_compression * filter_smooth)
                # Zero the artifact region in this window's target. Writing into the
                # full-size `mask` with window indices would leave the target untouched
                # and corrupt the windows cropped afterwards.
                cropped_mask[artifact] = 0

                edge = find_edges(cropped_mask)
                edge[artifact] = 0
            else:
                edge = find_edges(cropped_mask)

            images.append(cropped_img)
            masks.append(cropped_mask)
            masks_AAR.append(cropped_mask_AAR)
            edges.append(edge)

        return {'images': np.stack(images), 'masks': np.stack(masks),
                'masks_AAR': np.stack(masks_AAR), 'edges': np.stack(edges)}

Evaluate Network

In [None]:
def eval_net(net, loader, tissues, device, n_val, batch_size):
    """Compute batch-averaged validation Dice (F1) scores.

    The network runs in eval mode under ``torch.no_grad()``. Targets are
    multiplied by the 'masks_AAR' batch entry before scoring.

    :param net: segmentation network; its output is scored via ConfusionMatrix
    :param loader: iterable of dicts with 'images', 'masks', 'masks_AAR' entries
                   reshapeable to (-1, 1, 96, 96)
    :param tissues: number of classes
    :param device: torch device for the inputs
    :param n_val: total shown by the progress bar
    :param batch_size: unused; kept for interface compatibility (the progress bar
                       advances by the actual number of items per batch)
    :return: float array of length ``tissues + 1``:
             [mean F1, F1 of class 0, ..., F1 of class tissues-1], each the
             nan-mean over batches
    """
    net.eval()

    batch_mean_f1 = []
    batch_class_f1 = []

    # no_grad: validation needs no autograd graph, which would only waste memory
    with torch.no_grad(), tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
        for batch in loader:
            imgs = torch.from_numpy(np.reshape(np.array(batch['images']), (-1, 1, 96, 96))).to(device=device, dtype=torch.float32)
            targets = torch.from_numpy(np.reshape(np.array(batch['masks']), (-1, 1, 96, 96))).to(device=device, dtype=torch.long)
            targets_AAR = torch.from_numpy(np.reshape(np.array(batch['masks_AAR']), (-1, 1, 96, 96))).to(device=device, dtype=torch.long)

            predictions = net(imgs)

            # Applanation Artifact Removal
            targets = targets * targets_AAR

            # DSC Calculation
            cm = ConfusionMatrix(tissues)
            cm.add(predictions, targets.squeeze(1))
            acc_out = segmentation_metrics(cm)
            batch_mean_f1.append(acc_out['mF1'])
            batch_class_f1.append(cm.get_f1_per_class())

            pbar.update(len(batch['images']))

    val = np.zeros((tissues + 1))
    val[0] = np.nanmean(np.array(batch_mean_f1))
    val[1:] = np.nanmean(np.array(batch_class_f1, dtype=float).reshape(-1, tissues), axis=0)
    return val


Train Network

In [None]:

def _inverse_class_weights(targets, tissues):
    """Per-class loss weights ``(total_pixels / class_pixels) ** (1/5)``.

    :param targets: integer label tensor (any shape)
    :param tissues: number of classes; weights cover labels 0 .. tissues-1
    :return: float32 CPU tensor of length ``tissues``; absent classes get weight 0
    """
    labels = targets.detach().cpu().numpy()
    total_size = labels.size
    weights = []
    for c in range(tissues):
        count = np.count_nonzero(labels == c)
        ratio = total_size / count if count != 0 else 0
        weights.append(np.power(ratio, 1 / 5))
    return torch.tensor(weights, dtype=torch.float32)


def train_net(net, epochs, tissues, lr, batch_size, t ,g, python_path, torch_path, Experiment_Name, Model_Name):
    """Train ``net`` on one patient-group split, checkpointing after every epoch.

    Samples whose group (from 'SNP_Net_Groups.npy') equals ``t`` are excluded
    from both training and validation. Group ``t + 1`` (wrapping to 0 after 20)
    is the validation set; all other groups are trained on. ``g`` only appears
    in the output file names. Relies on the module-level ``device``.

    :return: (training, validation) arrays of shape (tissues + 1, epochs). Row 0
             is the mean F1 and row c + 1 is the F1 of class c, each averaged
             over batches.
    """
    # Load Numpy Files
    PID = np.load(python_path + 'SNP_Net_Groups.npy')
    images = np.load(python_path + 'SNP_Net_Images.npy')
    masks = np.load(python_path + 'SNP_Net_Masks.npy')
    masks_AAR = np.load(python_path + 'AAR_Net_Masks.npy')
    images_compression = np.load(python_path + 'SNP_Net_Images_Compression.npy')
    masks_compression = np.load(python_path + 'SNP_Net_Masks_Compression.npy')

    # Correct Applanation Artifact Removal: 1.0 where the AAR mask is 0, else 0.0
    masks_AAR = (masks_AAR == 0)
    masks_AAR = masks_AAR * 1.0

    # Split into Training & Validation Set
    indices = np.arange(0, np.shape(images)[0])
    val_idx = t + 1
    if val_idx > 20:
        val_idx = 0
    indices_train = indices[np.where((PID != t) & (PID != val_idx))]
    indices_val = indices[np.where(PID == val_idx)]

    # Shuffle Data
    np.random.shuffle(indices_train)
    images_train = images[indices_train]
    masks_train = masks[indices_train]
    masks_AAR_train = masks_AAR[indices_train]
    np.random.shuffle(indices_val)
    images_val = images[indices_val]
    masks_val = masks[indices_val]
    masks_AAR_val = masks_AAR[indices_val]

    # Construct Datasets
    train_dataset = BasicDataset(images_train, masks_train, masks_AAR_train, images_compression, masks_compression, 1)
    val_dataset = BasicDataset(images_val, masks_val, masks_AAR_val, images_compression, masks_compression, 0)

    # Construct Loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)

    # Write the Training Parameters to the Log File
    logging.info(f'''Started training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {len(train_dataset)}
        Validation size: {len(val_dataset)}
        Device:          {device.type}
    ''')

    # Optimizer and Scheduler (learning rate halves every 20 epochs)
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    scheduler = StepLR(optimizer, step_size=20, gamma=0.5)

    # Results
    training = np.zeros((tissues + 1, epochs))
    validation = np.zeros((tissues + 1, epochs))

    for epoch in range(epochs):
        net.train()

        batch_mean_f1 = []
        batch_class_f1 = []

        with tqdm(total=len(train_dataset), desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                # Each item holds 16 windows; flatten to (N, 1, 96, 96)
                imgs = torch.from_numpy(np.reshape(np.array(batch['images']), (-1, 1, 96, 96))).to(device=device, dtype=torch.float32)
                targets = torch.from_numpy(np.reshape(np.array(batch['masks']), (-1, 1, 96, 96))).to(device=device, dtype=torch.long)
                targets_AAR = torch.from_numpy(np.reshape(np.array(batch['masks_AAR']), (-1, 1, 96, 96))).to(device=device, dtype=torch.long)
                edges = torch.from_numpy(np.reshape(np.array(batch['edges']), (-1, 1, 96, 96))).to(device=device, dtype=torch.float32)

                predictions = net(imgs)

                # Applanation Artifact Removal
                targets = targets * targets_AAR

                class_weights = _inverse_class_weights(targets, tissues).to(device)

                # NOTE(review): SNP_Net returns softmax probabilities and CrossEntropyLoss
                # applies log-softmax again. Left unchanged here to avoid altering the
                # training objective; confirm whether this is intended.
                cross_entropy = nn.CrossEntropyLoss(weight=class_weights, reduction='none')
                loss_matrix = cross_entropy(predictions, targets.squeeze(1))  # (N, H, W)
                loss_matrix = loss_matrix * targets_AAR.squeeze(1)

                # Inverse-class and edge-weighted loss. edges is (N, 1, H, W): drop the
                # channel so the product is per pixel instead of broadcasting to (N, N, H, W).
                inverse_loss = torch.mean(loss_matrix)
                edge_loss = torch.mean(loss_matrix * edges.squeeze(1))
                loss = (inverse_loss + 3 * edge_loss) / 3.0

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # DSC Calculation
                cm = ConfusionMatrix(tissues)
                cm.add(predictions, targets.squeeze(1))
                acc_out = segmentation_metrics(cm)
                batch_mean_f1.append(acc_out['mF1'])
                batch_class_f1.append(cm.get_f1_per_class())

                pbar.set_postfix(**{'loss (batch)': loss.item()})
                pbar.update(len(batch['images']))

        # Average Training Dice
        training[0, epoch] = np.nanmean(np.array(batch_mean_f1))
        training[1:, epoch] = np.nanmean(np.array(batch_class_f1, dtype=float).reshape(-1, tissues), axis=0)

        # Average Validation Dice
        validation[:, epoch] = eval_net(net, val_loader, tissues, device, len(val_dataset), batch_size)

        # Print Results
        logging.info('Average Train Dice: ' + str(training[0, epoch]))
        logging.info('Average Train Dice Background: ' + str(training[1, epoch]))
        logging.info('Average Train Dice Nerve: ' + str(training[2, epoch]))
        logging.info('Average Train Dice Neuroma: ' + str(training[4, epoch]))
        logging.info('Average Train Dice Immune: ' + str(training[3, epoch]))

        logging.info('Average Val Dice: ' + str(validation[0, epoch]))
        logging.info('Average Val Dice Background: ' + str(validation[1, epoch]))
        logging.info('Average Val Dice Nerve: ' + str(validation[2, epoch]))
        logging.info('Average Val Dice Neuroma: ' + str(validation[4, epoch]))
        logging.info('Average Val Dice Immune: ' + str(validation[3, epoch]))

        scheduler.step()

        # Save model and metrics so far for this epoch
        torch.save(net.state_dict(), torch_path + Model_Name + '_' + Experiment_Name + '_Group_' + str(t) + '_' + str(g) + '_Epoch_' + str(epoch + 1) + '.pth')
        np.save(torch_path + 'Training_Dice_' + Model_Name + '_' + Experiment_Name + '_Group_' + str(t) + '_' + str(g) + '_Epoch_' + str(epoch + 1) + '.npy', training)
        np.save(torch_path + 'Validation_Dice_' + Model_Name + '_' + Experiment_Name + '_Group_' + str(t) + '_' + str(g) + '_Epoch_' + str(epoch + 1) + '.npy', validation)

    return training, validation

Initialize 

In [None]:
# Cast to Cuda
# NOTE: this is an ordinary Python variable. It does not set the
# CUDA_VISIBLE_DEVICES environment variable and has no effect on device selection.
CUDA_VISIBLE_DEVICES = 0
cuda = torch.cuda.is_available()
FloatTensor, LongTensor = (
    (torch.cuda.FloatTensor, torch.cuda.LongTensor) if cuda
    else (torch.FloatTensor, torch.LongTensor)
)
device = torch.device('cuda' if cuda else 'cpu')



Main

In [None]:
# Experiment identifiers (part of every output file name)
Experiment_Name = 'Original'
Model_Name = 'SNP_Net'

# Hyper Parameters
epochs = 100
lr = 0.0005
batch_size = 5
tissues = 4
channels = 1

# Input (.npy) and output (.pth / .npy) directories
python_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Python/SNP-Net/'
torch_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Torch/SNP-Net/'

# Set up Basic Configuration of the Log File
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logging.info(f'Using device {device}')

# Runs only when CUDA is available: 21 held-out groups x 5 runs each.
try:
    if cuda:
        for t in range(0, 21):
            for g in range(0, 5):
                run_tag = f'{Model_Name}_{Experiment_Name}_Group_{t}_{g}'

                net = SNP_Net(n_channels=channels, n_classes=tissues)
                net.to(device=device)

                training, validation = train_net(net, epochs, tissues, lr, batch_size, t, g, python_path, torch_path, Experiment_Name, Model_Name)

                # Full metric histories for this run
                np.save(f'{torch_path}Training_Dice_{run_tag}.npy', training)
                np.save(f'{torch_path}Validation_Dice_{run_tag}.npy', validation)

                # 0-based epoch with the best mean validation Dice, searched from index 60 on
                epoch = np.argmax(validation[0, 60:]) + 60

                # Keep only the best checkpoint; drop every per-epoch metric file
                for i in range(1, epochs + 1):
                    model_file = f'{torch_path}{run_tag}_Epoch_{i}.pth'
                    if i != epoch + 1 and os.path.exists(model_file):
                        os.remove(model_file)

                    train_metrics_file = f'{torch_path}Training_Dice_{run_tag}_Epoch_{i}.npy'
                    if os.path.exists(train_metrics_file):
                        os.remove(train_metrics_file)
                        os.remove(f'{torch_path}Validation_Dice_{run_tag}_Epoch_{i}.npy')

except KeyboardInterrupt:
    pass