Import

In [None]:
#Import Python Packages
import copy
from functools import reduce
from glob import glob
from glob import iglob
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
from operator import __add__
import os
from os.path import splitext
from os import listdir
import pandas as pd;
import PIL
from PIL import Image, ImageEnhance
from PIL import Image, ImageOps
from random import sample
from random import randint
import random;
from scipy import stats
from scipy import ndimage
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.measurements import label
from scipy.special import softmax
from skimage import exposure
from skimage import feature
from skimage import transform as tf
import sys
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn.modules import Conv2d, Module
from torch.utils.data import DataLoader, random_split, Dataset
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
from typing import Any
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

Neural Network

In [None]:



class GaborConv2d(Module):
    """2-D convolution whose kernels are parametrised Gabor filters.

    For output channel ``i`` and input channel ``j`` the kernel is::

        g(x, y) = exp(-0.5 * (x'^2 + y'^2) / (sigma + delta)^2)
                  * cos(freq * x' + psi) / (2 * pi * sigma^2)

    where ``(x', y')`` is the kernel sampling grid rotated by ``theta``.
    ``freq``, ``theta``, ``sigma`` and ``psi`` are learnable parameters of
    shape ``(out_channels, in_channels)``.

    In training mode the kernel bank is rebuilt on every forward pass and the
    convolution uses that differentiable bank, so gradients reach the Gabor
    parameters. In eval mode the bank is computed once, copied into
    ``conv_layer.weight`` and reused until the module returns to training mode.

    NOTE: the bank has shape ``(out_channels, in_channels, kH, kW)``, so only
    ``groups=1`` is supported.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=48, dilation=1, groups=1, bias=False, padding_mode="reflect"):
        super().__init__()
        self.is_calculated = False

        # Supplies the stride/padding/dilation/bias configuration and holds a copy
        # of the most recently computed kernel bank (used directly in eval mode).
        self.conv_layer = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode)
        self.kernel_size = self.conv_layer.kernel_size

        # small addition to avoid division by zero
        self.delta = 1e-3

        shape = (out_channels, in_channels)

        # Frequency: (pi / 2) * sqrt(2) ** -k for a random integer k in [0, 4]
        self.freq = Parameter(
            (math.pi / 2)
            * math.sqrt(2)
            ** (-torch.randint(0, 5, shape)).type(torch.Tensor),
            requires_grad=True,
        )
        # Orientation: a random multiple of pi / 8 in [0, 7 * pi / 8]
        self.theta = Parameter(
            (math.pi / 8) * torch.randint(0, 8, shape).type(torch.Tensor),
            requires_grad=True,
        )
        # Width of the Gaussian envelope, initialised to pi / freq
        self.sigma = Parameter((math.pi / self.freq).detach(), requires_grad=True)
        # Phase offset, uniform in [0, pi)
        self.psi = Parameter(math.pi * torch.rand(shape), requires_grad=True)

        # Half-extents of the kernel (fixed).
        self.x0 = Parameter(
            torch.ceil(torch.Tensor([self.kernel_size[0] / 2]))[0], requires_grad=False
        )
        self.y0 = Parameter(
            torch.ceil(torch.Tensor([self.kernel_size[1] / 2]))[0], requires_grad=False
        )

        # Kernel sampling grid, shape (kH, kW). It is fixed: requires_grad=False
        # keeps it out of the gradient computation so the optimiser never moves it.
        y_grid, x_grid = torch.meshgrid(
            [
                torch.linspace(-self.x0 + 1, self.x0 + 0, self.kernel_size[0]),
                torch.linspace(-self.y0 + 1, self.y0 + 0, self.kernel_size[1]),
            ]
        )
        self.y = Parameter(y_grid.clone(), requires_grad=False)
        self.x = Parameter(x_grid.clone(), requires_grad=False)

        # Not used by forward(); kept so state_dict() still contains "weight".
        # Zero-initialised instead of uninitialised memory.
        self.weight = Parameter(
            torch.zeros(self.conv_layer.weight.shape), requires_grad=True
        )

        # Second names for the same tensors. Both names appear in state_dict(),
        # so they are kept for checkpoint compatibility.
        self.register_parameter("x_shape", self.x0)
        self.register_parameter("y_shape", self.y0)
        self.register_parameter("y_grid", self.y)
        self.register_parameter("x_grid", self.x)

    def forward(self, input_tensor):
        """Convolve ``input_tensor`` with the current Gabor kernel bank."""
        if self.training:
            self.is_calculated = False
            # Use the freshly built, differentiable bank so that the Gabor
            # parameters receive gradients.
            return self._conv_with_weight(input_tensor, self.calculate_weights())
        if not self.is_calculated:
            with torch.no_grad():
                self.calculate_weights()
            self.is_calculated = True
        return self.conv_layer(input_tensor)

    def calculate_weights(self):
        """Build the Gabor kernel bank and mirror it into ``conv_layer.weight``.

        :return: differentiable tensor of shape (out_channels, in_channels, kH, kW)
        """
        # (out, in) parameters -> (out, in, 1, 1) so they broadcast over the grid.
        sigma = self.sigma[:, :, None, None]
        freq = self.freq[:, :, None, None]
        theta = self.theta[:, :, None, None]
        psi = self.psi[:, :, None, None]

        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)
        rotx = self.x * cos_t + self.y * sin_t
        roty = -self.x * sin_t + self.y * cos_t

        g = torch.exp(-0.5 * ((rotx ** 2 + roty ** 2) / (sigma + self.delta) ** 2))
        g = g * torch.cos(freq * rotx + psi)
        g = g / (2 * math.pi * sigma ** 2)

        # Keep conv_layer.weight in sync (used in eval mode and saved in checkpoints).
        with torch.no_grad():
            self.conv_layer.weight.copy_(g)
        return g

    def _conv_with_weight(self, input_tensor, weight):
        """Apply ``conv_layer``'s configuration (including padding mode) with ``weight``."""
        conv = self.conv_layer
        if conv.padding_mode != "zeros":
            # Same as Conv2d: pad explicitly with the requested mode, then convolve unpadded.
            pad_h, pad_w = conv.padding
            input_tensor = F.pad(input_tensor, [pad_w, pad_w, pad_h, pad_h], mode=conv.padding_mode)
            padding = 0
        else:
            padding = conv.padding
        return F.conv2d(input_tensor, weight, conv.bias, conv.stride, padding, conv.dilation, conv.groups)

    def _forward_unimplemented(self, *inputs: Any):
        """
        code checkers makes implement this method,
        looks like error in PyTorch
        """
        raise NotImplementedError
        
        






class DoubleConv(nn.Module):
    """Two consecutive [3x3 conv -> batch-norm -> ReLU] stages.

    Padding of 1 keeps H and W unchanged. Channels go
    ``in_channels -> out_channels -> out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_input in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(stage_input, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Halve H and W with 2x2 max-pooling, then apply a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upsample a coarse feature map, align it to a skip connection, concatenate, DoubleConv.

    With ``bilinear=True`` upsampling is a parameter-free 2x bilinear resize.
    Otherwise a stride-2 transposed convolution over ``in_channels // 2``
    channels is used. The concatenation ``[x2, x1]`` is fed to a DoubleConv
    expecting ``in_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad/crop it to ``x2``'s H and W, and fuse ``[x2, x1]``."""
        x1 = self.up(x1)
        # Tensors are NCHW. The size mismatch is computed as plain Python ints
        # (not tensors) for F.pad. Negative values crop instead of pad.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        return self.conv(torch.cat([x2, x1], dim=1))


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-class scores."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        scores = self.conv(x)
        return scores

class filt_cat(nn.Module):
    """Concatenate a (filtered) map with a reference map along channels.

    ``forward(x1, x2)`` zero-pads ``x1`` to the height and width of ``x2``
    (cropping instead when ``x1`` is larger) and returns
    ``torch.cat([x2, x1], dim=1)``. The constructor arguments are unused.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

    def forward(self, x1, x2):
        # Tensors are NCHW. The size mismatch is computed as plain Python ints
        # for F.pad; negative values crop.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        return torch.cat([x2, x1], dim=1)



class AAR_Net(nn.Module):
    """U-Net preceded by a Gabor filter bank whose response is stacked with the input.

    ``forward`` returns per-pixel class probabilities: a softmax over dim 1 of
    the final 1x1 convolution's output, shape ``(N, n_classes, H, W)``.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(AAR_Net, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Front end: 144x144 Gabor filters; filt_cat pads the response back to the
        # input size and concatenates [input, response] -> 2 * n_channels channels.
        self.g0 = GaborConv2d(n_channels, n_channels, kernel_size=(144, 144))
        self.fc = filt_cat(n_channels, 2 * n_channels)

        # Encoder
        self.inc = DoubleConv(2 * n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)

        # Decoder
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        gabor_response = self.g0(x)
        level1 = self.inc(self.fc(gabor_response, x))
        level2 = self.down1(level1)
        level3 = self.down2(level2)
        level4 = self.down3(level3)
        bottleneck = self.down4(level4)

        decoded = self.up1(bottleneck, level4)
        decoded = self.up2(decoded, level3)
        decoded = self.up3(decoded, level2)
        decoded = self.up4(decoded, level1)

        return F.softmax(self.outc(decoded), dim=1)



GAN Generator

In [None]:

class Generator(nn.Module):
    """Class-conditional convolutional image generator.

    Reads the module-level names ``n_classes``, ``latent_dim``, ``img_size`` and
    ``channels`` when constructed. ``forward(noise, labels)`` multiplies the
    noise element-wise by the label embedding, projects it to a
    ``128 x (img_size // 4)^2`` feature map, upsamples twice by 2 and returns
    ``channels`` output planes squashed by Tanh into [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Label embedding, same width as the noise vector
        self.label_emb = nn.Embedding(n_classes, latent_dim)

        # Spatial size before the two 2x upsamplings
        self.init_size = img_size // 4

        # Linear projection of the conditioned noise to the initial feature map
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # NOTE: the second positional argument of BatchNorm2d is eps, so the
        # "0.8" layers use eps=0.8 (not momentum). Kept unchanged because eps
        # affects the outputs of checkpoints trained with this definition.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        conditioned = self.label_emb(labels) * noise
        projected = self.l1(conditioned)
        feature_map = projected.view(projected.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(feature_map)

Segmentation Metrics

In [None]:
def segmentation_metrics(confusion_matrix, per_class=False):
    """Summarise a confusion matrix as a flat dict of scores.

    Always contains the class-averaged scores 'mIoU', 'mDice', 'oAcc',
    'mSens', 'mSpec' and 'mF1', in that order. With ``per_class=True`` it also
    contains one entry per class for Dice, sensitivity, specificity, F1 and
    the ground-truth / prediction ratios, keyed ``<prefix><class index>``
    (e.g. 'Dice0', 'F12', 'RGt1').
    """
    scores = dict(
        mIoU=confusion_matrix.get_mean_class_iou(),
        mDice=confusion_matrix.get_mean_class_dice(),
        oAcc=confusion_matrix.get_overall_accuracy(),
        mSens=confusion_matrix.get_mean_class_sensitivity(),
        mSpec=confusion_matrix.get_mean_class_specificity(),
        mF1=confusion_matrix.get_mean_class_f1(),
    )
    if not per_class:
        return scores

    # Getters are looked up lazily, one prefix at a time, in this order.
    per_class_sources = (
        ('Dice', 'get_dice_per_class'),
        ('Sens', 'get_sensitivity_per_class'),
        ('Spec', 'get_specificity_per_class'),
        ('F1', 'get_f1_per_class'),
        ('RGt', 'get_ground_truth_ratios'),
        ('RPr', 'get_prediction_ratios'),
    )
    for prefix, getter_name in per_class_sources:
        values = getattr(confusion_matrix, getter_name)()
        scores.update((f'{prefix}{c}', v) for c, v in enumerate(values))
    return scores


class ConfusionMatrix():
    """Running confusion matrix for a K-class classification problem.

    Rows correspond to ground-truth classes and columns to predicted classes.
    Counts are stored as Python ints (object dtype) so they cannot overflow.
    """

    def __init__(self, k):
        self.cm = np.zeros(shape=(k, k), dtype=object)  # array of python ints for unlimited precision
        self.k = k

    def reset(self):
        """Set all counts back to zero."""
        self.cm.fill(0)

    def add(self, predicted, target):
        """
        Adds new results to the confusion matrix. Filters elements without ground truth (class = -100).

        :param predicted: N, BHW or BDHW-sized tensor of class integers or NK, BKHW or BKDHW-sized tensor
                          of predicted probabilities
        :param target: N, BHW or BDHW-sized tensor of class integers or NK, BKHW or BKDHW-sized tensor
                          of one-hot encoded classes
        """
        if isinstance(predicted, torch.Tensor):
            predicted = predicted.detach().cpu().numpy()
        if isinstance(target, torch.Tensor):
            target = target.detach().cpu().numpy()

        # Axis 1 of size k is interpreted as per-class scores / one-hot encoding.
        if np.ndim(predicted) > 1 and predicted.shape[1] == self.k:
            predicted = np.argmax(predicted, 1)
        else:
            assert np.size(predicted) == 0 or ((predicted.max() < self.k) and (predicted.min() >= 0)), \
                'predicted values are not between 0 and k-1'

        if np.ndim(target) > 1 and target.shape[1] == self.k:
            assert (target >= 0).all() and (target <= 1).all(), \
                'in one-hot encoding, target values should be 0 or 1'
            # All-zero one-hot vectors mean "no ground truth".
            invalid_idx = target.sum(1) < 0.5
            target = np.argmax(target, 1)
            target[invalid_idx] = -100

        predicted = np.ravel(predicted)
        target = np.ravel(target)
        assert predicted.shape[0] == target.shape[0], \
            'number of targets and predicted outputs do not match'

        # Remove predictions for elements without ground truth
        valid_idx = target != -100
        target = target[valid_idx]
        predicted = predicted[valid_idx]

        # Encode each (target, predicted) pair as one index and count with bincount.
        # Adapted from https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
        # (sklearn.metrics.confusion_matrix is 100x slower).
        flat_index = (predicted + self.k * target).astype(np.int64)
        counts = np.bincount(flat_index, minlength=self.k ** 2)
        assert counts.size == self.k ** 2
        # astype(object) keeps the accumulated counts as Python ints.
        self.cm += counts.reshape((self.k, self.k)).astype(object)

    def value(self, normalized=False):
        """
        :return confusion matrix of K rows and K columns (rows normalised to sum 1 if normalized)
        """
        conf = self.cm.astype(np.float64)
        if normalized:
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        else:
            return conf

    def get_iou_per_class(self):
        """
        :return per-class intersection-over-union / Jaccard coefficient (NaN if class not present nor predicted)
        """
        cm = self.value()
        tp = np.diag(cm)
        gt_count = cm.sum(axis=1)
        pred_count = cm.sum(axis=0)
        with np.errstate(invalid='ignore'):
            return tp / (gt_count + pred_count - tp)

    def get_mean_class_iou(self):
        """
        :return mean per-class intersection-over-union / Jaccard coefficient (over classes present)
        """
        iou = self.get_iou_per_class()
        iou = iou[~np.isnan(iou)]
        return np.mean(iou)

    def get_overall_accuracy(self):
        """
        :return overall accuracy: correctly classified elements / all counted elements (0 if empty)
        """
        cm = self.value()
        tp = np.diag(cm).sum()
        oa = tp / max(1, np.sum(cm))
        return oa

    def get_sensitivity_per_class(self):
        """
        :return per-class sensitivity TP / (TP + FN) (NaN if class not present)
        """
        cm = self.value()
        tp = np.diag(cm)
        with np.errstate(invalid='ignore'):
            return tp / np.sum(cm, axis=1)

    def get_mean_class_sensitivity(self):
        """
        :return mean per-class sensitivity (over classes present)
        """
        sens = self.get_sensitivity_per_class()
        sens = sens[~np.isnan(sens)]
        return np.mean(sens)

    def get_specificity_per_class(self):
        """
        :return per-class one-vs-rest specificity TN / (TN + FP)
                (NaN if no element has a ground truth other than the class)
        """
        cm = self.value()
        tp = np.diag(cm)
        gt_count = cm.sum(axis=1)
        pred_count = cm.sum(axis=0)
        fp = pred_count - tp
        # Negatives of class c are all elements whose ground truth is not c;
        # TN are those of them that were not predicted as c either. This
        # includes confusions between two *other* classes.
        tn = cm.sum() - gt_count - fp
        with np.errstate(invalid='ignore'):
            return tn / (tn + fp)

    def get_mean_class_specificity(self):
        """
        :return mean per-class specificity (over classes with negatives)
        """
        spec = self.get_specificity_per_class()
        spec = spec[~np.isnan(spec)]
        return np.mean(spec)

    def get_f1_per_class(self):
        """
        :return per-class F1 = 2 TP / (2 TP + FP + FN) (NaN if class not present nor predicted)
        """
        cm = self.value()
        tp = np.diag(cm)
        with np.errstate(invalid='ignore'):
            return tp / (0.5 * (np.sum(cm, axis=1) + np.sum(cm, axis=0)))

    def get_mean_class_f1(self):
        """
        :return mean per-class f1 (over classes present)
        """
        f1 = self.get_f1_per_class()
        f1 = f1[~np.isnan(f1)]
        return np.mean(f1)

    def get_dice_per_class(self):
        """
        :return per-class Dice coefficient (NaN if class not present nor predicted)
        """
        with np.errstate(invalid='ignore'):
            jacc = self.get_iou_per_class()
            return 2 * jacc / (jacc + 1)

    def get_mean_class_dice(self):
        """
        :return mean per-class Dice coefficient (over classes present)
        """
        dice = self.get_dice_per_class()
        dice = dice[~np.isnan(dice)]
        return np.mean(dice)

    def get_ground_truth_ratios(self):
        """
        :return per-class fraction of counted elements whose ground truth is the class (NaN if empty)
        """
        cm = self.value()
        with np.errstate(invalid='ignore'):
            return cm.sum(axis=1) / cm.sum()

    def get_prediction_ratios(self):
        """
        :return per-class fraction of counted elements predicted as the class (NaN if empty)
        """
        cm = self.value()
        with np.errstate(invalid='ignore'):
            return cm.sum(axis=0) / cm.sum()



Find Edges

In [None]:
def find_edges(color_mask, n_classes=3):
    """Return a dilated binary map of class boundaries in a 2-D label image.

    A pixel is a raw edge when it belongs to class ``c`` and between 1 and 7
    of its 8 in-bounds neighbours also belong to ``c``. A pixel with no
    same-class neighbour is not an edge. Pixels on the image border have fewer
    than 8 in-bounds neighbours, so any class region touching the border also
    produces edges there. Raw edges of classes ``0 .. n_classes-1`` are merged
    and then dilated with a 3x3 square structuring element.

    :param color_mask: 2-D array of integer-valued class labels
    :param n_classes: number of labels to scan (labels 0 .. n_classes-1)
    :return: float64 array of the same shape, 1.0 on (dilated) edges, else 0.0
    """
    rows, cols = color_mask.shape

    # 3x3 ones with a zero centre: correlating with it counts the 8 neighbours.
    neighbourhood = np.ones((3, 3), dtype=np.int64)
    neighbourhood[1, 1] = 0

    all_edges = np.zeros((rows, cols))
    for color_idx in range(n_classes):
        mask = (color_mask == color_idx).astype(np.int64)
        # Out-of-image neighbours count as "not this class" (cval=0).
        same_class = ndimage.correlate(mask, neighbourhood, mode='constant', cval=0)
        edge = (mask == 1) & (same_class > 0) & (same_class < 8)
        # Union over classes (overlaps stay at 1).
        all_edges[edge] = 1

    struct1 = ndimage.generate_binary_structure(2, 2)
    return ndimage.binary_dilation(all_edges, structure=struct1).astype(all_edges.dtype)

Dataset

In [None]:
class BasicDataset(Dataset):
    """Augmented two-image mixture dataset built from per-class image stacks.

    ``Real_CCM`` and ``Fake_CCM`` are indexed as ``[layer, idx, :, :]``, where
    ``layer`` is the class index written into the masks and ``idx`` selects an
    image. With ``Train == 1`` both real and generated (``Fake_CCM``) images
    are used; otherwise only real images are used.
    """


    def __init__(self, Real_CCM, Fake_CCM, Train):
        """Store the image stacks and the training flag (no copy, no validation)."""

        #Attributes
        self.Real_CCM = Real_CCM;
        self.Fake_CCM = Fake_CCM;
        self.Train = Train;



    def __len__(self):
        """Return 6 items per real image index when training, 3 otherwise.

        Uses the size of axis 1 of ``Real_CCM``. With 6 consecutive items per
        image, the pairs (i % 3, i % 2) used in __getitem__ cover every
        (class, real/generated) combination exactly once.
        """

        #Length
        if(self.Train==1):
          return int(6*np.shape(self.Real_CCM)[1])
        else:
          return int(3*np.shape(self.Real_CCM)[1])


    def __getitem__(self, i):
        """Build one sample consisting of four augmented views.

        Returns a dict with 'images', 'masks' and 'edges'; each holds the four
        per-view 2-D arrays stacked along a new axis 0. In each view a crop of
        a first image (class ``layer_1``) may be blended with a crop of a
        randomly chosen second image (class ``layer_2``) through a random,
        smoothed boundary. The mask holds each pixel's class index, and
        'edges' is the dilated class-boundary map from find_edges.
        Randomness comes from Python's ``random`` module.
        """

        #First Layer Parameters
        # layer_1 cycles through classes 0..2; generated_1 alternates between
        # real (0) and generated (1) images (validation: always real).
        layer_1 = i%3;
        generated_1 = i%2;
        if(self.Train==1):
            idx_1 = math.floor(i/6);
        else:
            idx_1 = math.floor(i/3);
            generated_1 = 0;


        #First Image
        img_1 = self.Real_CCM[layer_1,idx_1,:,:].squeeze().copy();

        #Second Layer Parameters
        # Second image: random class, random real image, random real/generated
        # choice (validation: always real).
        layer_2 = randint(0,2);
        generated_2 = randint(0,1);
        if(self.Train==1):
            idx_2 = randint(0, np.shape(self.Real_CCM)[1]-1)
        else:
            idx_2 = randint(0, np.shape(self.Real_CCM)[1]-1)
            generated_2 = 0;


        #Second Image
        img_2 = self.Real_CCM[layer_2,idx_2,:,:].squeeze().copy();



        #Iterate
        for g_index in range(0,4):



            #Crop Image 1
            if(generated_1==0):


                #Row and Column
                # Random 192x192 crop; row/col go up to 191, so this assumes real
                # images are at least 383 px on each side -- TODO confirm.
                row = randint(0,191);
                col = randint(0,191);

                #Crop Image
                # Min-max normalise to [0, 1]. NOTE(review): a constant crop
                # gives 0/0 = NaN here.
                cropped_img_1 = img_1[row:row + 192,col:col +  192].copy();
                cropped_img_1 = (cropped_img_1 - np.min(cropped_img_1))/(np.max(cropped_img_1) - np.min(cropped_img_1))

            else:
                # A random generated image of the same class is used whole and is
                # NOT min-max normalised, unlike real crops. Presumably it is
                # already 192x192 -- TODO confirm.
                idx_1 = randint(0, np.shape(self.Fake_CCM)[1]-1)
                cropped_img_1 =  self.Fake_CCM[layer_1,idx_1,:,:].squeeze().copy();

            #Crop Image 2
            # Same procedure as for image 1.
            if(generated_2==0):


                #Row and Column
                row = randint(0,191);
                col = randint(0,191);

                #Crop Image
                cropped_img_2 = img_2[row:row + 192,col:col +  192].copy();
                cropped_img_2 = (cropped_img_2 - np.min(cropped_img_2))/(np.max(cropped_img_2) - np.min(cropped_img_2))

            else:
                idx_2 = randint(0, np.shape(self.Fake_CCM)[1]-1)
                cropped_img_2 =  self.Fake_CCM[layer_2,idx_2,:,:].squeeze().copy();


            #Flip Image 1
            # Independent random left-right / up-down flips.
            if(randint(0,1)==1):
                cropped_img_1 = np.fliplr(cropped_img_1).copy();
            if(randint(0,1)==1):
                cropped_img_1 = np.flipud(cropped_img_1).copy();

            #Flip Image 2
            if(randint(0,1)==1):
                cropped_img_2 = np.fliplr(cropped_img_2).copy();
            if(randint(0,1)==1):
                cropped_img_2 = np.flipud(cropped_img_2).copy();


            #Normalize Images
            # Rescale the crop with the larger mean so that both crops share the
            # smaller mean, then shift both by -0.5.
            # NOTE(review): a zero mean divides by zero, and a negative mean
            # (possible for un-normalised generated crops) gives a negative factor.
            if(np.mean(cropped_img_1) < np.mean(cropped_img_2)):
                cropped_img_2 = cropped_img_2 * np.mean(cropped_img_1)/np.mean(cropped_img_2);
            else:
                cropped_img_1 = cropped_img_1 * np.mean(cropped_img_2)/np.mean(cropped_img_1);

            cropped_img_1 = cropped_img_1 - 0.5;
            cropped_img_2 = cropped_img_2 - 0.5;


            #Create Masks
            # Constant masks holding each crop's class index.
            mask_1 = layer_1*np.ones(np.shape(cropped_img_1))
            mask_2 = layer_2*np.ones(np.shape(cropped_img_2))


            #Produce Filter
            # Random vertical split: row y is 1 on columns [0, x).
            # NOTE(review): int() truncates toward zero, so x + 0.1*cos(angle)
            # either keeps x or lowers it by 1. x therefore never increases,
            # stops at 0, and both "Boundary Conditions" branches are unreachable.
            width = 192;
            height = 192;
            original_filter =  np.zeros((height,width))
            x = randint(80,112);
            angle = 0;
            for y in range(0,height):

                #Update
                angle = angle + 0.5 - random.random();
                x = int(x + 0.1 * np.cos(angle));

                #Fill Filter
                if (x > 0 and x< width):
                    original_filter[y,0:x] = 1;

                #Boundary Conditions
                if (x < 0):
                  x = width;
                if (x > width):
                  x = -1;


            #Flip Filter
            if(randint(0,1)==1):
              original_filter = np.fliplr(original_filter).copy();
            if(randint(0,1)==1):
              original_filter = np.flipud(original_filter).copy();

            #Invert Filter
            if(randint(0,1)==1):
              original_filter = 1 - original_filter;



            #Smooth Filter
            # Gaussian-smoothed split (sigma 1..21 px) for blending the images;
            # rounding it back gives the hard split used for the mask.
            sigma_var = randint(1,21);
            filter_smooth = ndimage.gaussian_filter(original_filter, sigma=(sigma_var, sigma_var), order=0)
            filter_smooth_opp = 1 - filter_smooth;
            filter_binary = np.round(filter_smooth);
            filter_binary_opp = 1 - filter_binary;



            #Apply Filters
            # With probability 1/2 blend image 2 in over the filter region;
            # otherwise keep image 1 (and its constant mask) alone.
            if(randint(0,1)==1):
                img = (cropped_img_1 * filter_smooth_opp) + (cropped_img_2 * filter_smooth);
                mask = (mask_1 * filter_binary_opp) + (mask_2 * filter_binary);
            else:
                img = cropped_img_1.copy();
                mask = mask_1.copy();


            #Find Edges
            edge = find_edges(mask);

            #Expand Dim
            img = np.expand_dims(img,axis = 0);
            mask = np.expand_dims(mask,axis = 0);
            edge = np.expand_dims(edge,axis = 0);

            #Add to Group
            # Stack the four views along axis 0.
            if(g_index ==0):
                images = img;
                masks = mask;
                edges = edge;
            else:
                images = np.concatenate((images, img), 0);
                masks = np.concatenate((masks, mask), 0);
                edges = np.concatenate((edges, edge), 0);



        return {'images': images, 'masks': masks, 'edges': edges}


Evaluate Network

In [None]:
def eval_net(net, loader, tissues, device, n_val):
    """Evaluate ``net`` on ``loader`` and return F1 scores averaged over batches.

    :param net: segmentation network returning per-class scores (N, tissues, H, W)
    :param loader: DataLoader yielding dicts with 'images' and 'masks'
    :param tissues: number of classes
    :param device: device to run the network on
    :param n_val: total for the progress bar (counted in images, i.e. views)
    :return: array of length ``tissues + 1``. Element 0 is the nan-mean over
             batches of the class-averaged F1; element ``c + 1`` is the
             nan-mean of the F1 of class ``c``.
    """
    net.eval()

    def to_views(values, dtype):
        # Assumes collated batches of shape (B, V, H, W), as produced by
        # BasicDataset; the views are flattened into the batch axis.
        tensor = torch.as_tensor(values)
        return tensor.reshape(-1, 1, *tensor.shape[-2:]).to(device=device, dtype=dtype)

    mean_f1 = []
    class_f1 = [[] for _ in range(tissues)]

    # Gradients are not needed for evaluation; no_grad avoids building autograd
    # graphs and keeping their activations in memory.
    with torch.no_grad(), tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
        for batch in loader:
            imgs = to_views(batch['images'], torch.float32)
            targets = to_views(batch['masks'], torch.long)

            predictions = net(imgs)

            # Per-batch F1 (the same quantity as segmentation_metrics(cm)['mF1']).
            cm = ConfusionMatrix(tissues)
            cm.add(predictions, targets.squeeze(1))
            f1_per_class = cm.get_f1_per_class()
            mean_f1.append(cm.get_mean_class_f1())
            for c in range(tissues):
                class_f1[c].append(f1_per_class[c])

            pbar.update(imgs.shape[0])

    val = np.zeros(tissues + 1)
    val[0] = np.nanmean(np.array(mean_f1))
    for c in range(tissues):
        val[c + 1] = np.nanmean(np.array(class_f1[c]))
    return val




Train Network

In [None]:

def _load_shuffled(path):
    """Load a ``.npy`` stack, shuffle it along axis 0 and prepend a length-1 axis.

    The leading axis lets per-tissue stacks be concatenated into one
    ``[tissue, sample, ...]`` array. Uses the global NumPy RNG.
    """
    data = np.load(path)
    order = np.arange(0, np.shape(data)[0])
    np.random.shuffle(order)
    return np.expand_dims(data[order], axis=0)


def _generate_fake_images(n_layers=3, n_per_layer=32800, latent_dim=400, image_size=192, chunk_size=16):
    """Sample synthetic images per class from the module-level ``generator``.

    Returns a float64 array of shape ``(n_layers, n_per_layer, image_size, image_size)``
    where ``[layer, k]`` was generated with class label ``layer``. Noise is
    drawn from the global NumPy RNG. Images are generated in chunks under
    ``torch.no_grad()``. The generator is put in eval mode so that batch-norm
    uses running statistics and each image is independent of its chunk-mates.
    """
    generator.eval()
    fake = np.zeros((n_layers, n_per_layer, image_size, image_size))
    with torch.no_grad():
        for layer in range(n_layers):
            for start in range(0, n_per_layer, chunk_size):
                count = min(chunk_size, n_per_layer - start)
                z = torch.as_tensor(np.random.normal(0, 1, (count, latent_dim)),
                                    dtype=torch.float32, device=device)
                labels = torch.full((count,), layer, dtype=torch.long, device=device)
                images = generator(z, labels)
                fake[layer, start:start + count] = images.cpu().numpy().reshape(count, image_size, image_size)
    return fake


def _batch_views_to_device(values, dtype):
    """Flatten a collated (B, V, H, W) batch to (B*V, 1, H, W) on the global ``device``."""
    tensor = torch.as_tensor(values)
    return tensor.reshape(-1, 1, *tensor.shape[-2:]).to(device=device, dtype=dtype)


def _sqrt_inverse_frequency_weights(targets, n_classes):
    """Per-class loss weights ``sqrt(total / count_c)``; classes absent from ``targets`` get 0."""
    counts = torch.bincount(targets.reshape(-1), minlength=n_classes)[:n_classes].to(torch.float64)
    ratios = torch.where(counts > 0, targets.numel() / counts.clamp(min=1), torch.zeros_like(counts))
    return ratios.sqrt().to(torch.float32)


def train_net(net, epochs, tissues, lr, batch_size, g , python_path, torch_path, Experiment_Name, Model_Name):
    """Train ``net`` on real and GAN-generated tissue crops, checkpointing every epoch.

    Uses the module-level ``device`` and ``generator``.

    :param net: segmentation network whose output is treated as per-pixel class
                probabilities of shape (N, tissues, H, W) (AAR_Net.forward ends with a softmax)
    :param epochs: number of training epochs
    :param tissues: number of classes
    :param lr: initial Adam learning rate (halved every 20 epochs by StepLR)
    :param batch_size: dataset items per batch (each BasicDataset item holds 4 views)
    :param g: group index, used in output file names
    :param python_path: directory prefix of the input ``.npy`` image stacks
    :param torch_path: directory prefix for checkpoints and metric arrays
    :param Experiment_Name: tag used in output file names
    :param Model_Name: tag used in output file names
    :return: (training, validation) arrays of shape (tissues + 1, epochs). Row 0
             holds the mean class F1 per epoch; row c + 1 holds the F1 of class c.
    """
    # Load and shuffle each stack; the NumPy RNG is consumed in this order.
    snp_train = _load_shuffled(python_path + 'AAR_Net_Images_Training_SNP.npy')
    snp_val = _load_shuffled(python_path + 'AAR_Net_Images_Validation_SNP.npy')
    epithelium_train = _load_shuffled(python_path + 'AAR_Net_Images_Training_Epithelium.npy')
    epithelium_val = _load_shuffled(python_path + 'AAR_Net_Images_Validation_Epithelium.npy')
    stroma_train = _load_shuffled(python_path + 'AAR_Net_Images_Training_Stroma.npy')
    stroma_val = _load_shuffled(python_path + 'AAR_Net_Images_Validation_Stroma.npy')

    # [tissue, sample, H, W] with tissue 0 = SNP, 1 = Epithelium, 2 = Stroma.
    real_ccm_train = np.concatenate((snp_train, epithelium_train, stroma_train), 0)
    real_ccm_val = np.concatenate((snp_val, epithelium_val, stroma_val), 0)

    # Synthetic images are used for training only.
    fake_ccm_train = _generate_fake_images()
    fake_ccm_val = []

    train_dataset = BasicDataset(real_ccm_train, fake_ccm_train, 1)
    val_dataset = BasicDataset(real_ccm_val, fake_ccm_val, 0)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)

    #Write the Training Parameters to the Log File
    logging.info(f'''Started training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {len(train_dataset)}
        Validation size: {len(val_dataset)}
        Device:          {device.type}
    ''')

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    scheduler = StepLR(optimizer, step_size=20, gamma=0.5)

    training = np.zeros((tissues + 1, epochs))
    validation = np.zeros((tissues + 1, epochs))

    class_names = ('SNP', 'Epithelium', 'Stroma')
    # Each dataset item contains 4 views; progress bars count views.
    views_per_item = 4

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        loss_count = 0
        train_f1 = []
        train_class_f1 = [[] for _ in range(tissues)]

        with tqdm(total=views_per_item * len(train_dataset), desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = _batch_views_to_device(batch['images'], torch.float32)
                target_map = _batch_views_to_device(batch['masks'], torch.long).squeeze(1)
                edge_map = _batch_views_to_device(batch['edges'], torch.float32).squeeze(1)

                predictions = net(imgs)

                class_weights = _sqrt_inverse_frequency_weights(target_map, tissues)

                # The network outputs softmax probabilities, not logits, so use
                # NLL on their log. CrossEntropyLoss on probabilities would apply
                # a second softmax. The clamp avoids log(0) after underflow.
                log_probs = torch.log(predictions.clamp(min=1e-12))
                loss_matrix = F.nll_loss(log_probs, target_map, weight=class_weights, reduction='none')
                loss_1 = torch.mean(loss_matrix)
                # Extra weight near class boundaries. edge_map is (N, H, W) like
                # loss_matrix, so the product is element-wise per image (no
                # cross-image broadcasting).
                loss_2 = torch.mean(loss_matrix * edge_map)
                loss = loss_1 + 2 * loss_2

                batch_loss = loss.item()
                epoch_loss += batch_loss
                loss_count += 1

                # Per-batch F1 scores
                cm = ConfusionMatrix(tissues)
                cm.add(predictions, target_map)
                f1_per_class = cm.get_f1_per_class()
                train_f1.append(cm.get_mean_class_f1())
                for c in range(tissues):
                    train_class_f1[c].append(f1_per_class[c])

                pbar.set_postfix(**{'loss (batch)': batch_loss})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])

        training[0, epoch] = np.nanmean(np.array(train_f1))
        for c in range(tissues):
            training[c + 1, epoch] = np.nanmean(np.array(train_class_f1[c]))

        # The validation progress bar counts views of the *validation* set.
        validation[:, epoch] = eval_net(net, val_loader, tissues, device, views_per_item * len(val_dataset))

        logging.info('Average Train Loss: ' + str(epoch_loss / max(loss_count, 1)))
        logging.info('Average Train Dice: ' + str(training[0, epoch]))
        for c, name in enumerate(class_names[:tissues]):
            logging.info('Average Train Dice ' + name + ': ' + str(training[c + 1, epoch]))

        logging.info('Average Val Dice: ' + str(validation[0, epoch]))
        for c, name in enumerate(class_names[:tissues]):
            logging.info('Average Val Dice ' + name + ': ' + str(validation[c + 1, epoch]))

        scheduler.step()

        # Checkpoint and metrics after every epoch.
        tag = Model_Name + '_' + Experiment_Name + '_Group_' + str(g) + '_Epoch_' + str(epoch + 1)
        torch.save(net.state_dict(), torch_path + tag + '.pth')
        np.save(torch_path + 'Training_Dice_' + tag + '.npy', training)
        np.save(torch_path + 'Validation_Dice_' + tag + '.npy', validation)

    return training, validation

Load Generator

In [None]:
#Path
python_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Python/AAR-Net/'
torch_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Torch/AAR-Net/'

#Parameters
#Experiment tag selecting the checkpoint file 'Generator_<experiment>.pth'
experiment = 'Original'

#Hyper Parameters
#NOTE(review): Generator() below is constructed without arguments, so these
#globals are presumably read by its definition elsewhere in the notebook --
#confirm they match the values the checkpoint was trained with.
latent_dim = 400
n_classes = 3
img_size = 192
channels = 1


#Query CUDA once so that `cuda`, the tensor-type aliases and `device` agree
cuda = torch.cuda.is_available()
if cuda:
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
else:
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor

#NOTE: this only binds a Python variable; it does NOT restrict which GPUs torch
#can see. To do that, the CUDA_VISIBLE_DEVICES *environment variable* must be
#set before CUDA is initialised (e.g. before starting the kernel). The name is
#kept so any later reference to it still resolves.
CUDA_VISIBLE_DEVICES = 0
device = torch.device('cuda' if cuda else 'cpu')

#Load Generator weights; map_location remaps saved storages onto `device`,
#so a checkpoint saved on GPU can also be loaded on a CPU-only machine
generator = Generator()
generator.load_state_dict(torch.load(torch_path + 'Generator_' + experiment + '.pth', map_location=device))
generator.to(device=device)

#Set the Network to Eval Mode for Prediction
#(the trailing ';' stops Jupyter from echoing the module repr as cell output)
generator.eval();


Main

In [None]:
#Parameters
Experiment_Name = 'Original'
Model_Name = 'AAR_Net'


#Hyper Parameters
epochs = 80
lr = 0.005
batch_size = 5
tissues = 3
channels = 1

#Number of groups trained in turn (group index g is passed to train_net)
n_groups = 5

#The best checkpoint is chosen only among epochs whose 0-based index is >= this
#value (presumably to ignore early, still-converging epochs). If training runs
#for fewer epochs than this, all epochs are considered instead.
best_epoch_search_start = 40


def _run_file_path(base_path, prefix, model_name, experiment_name, group, suffix):
    """Return a checkpoint/metrics file path for one model/experiment/group.

    The result is ``base_path + prefix + model_name + '_' + experiment_name +
    '_Group_' + str(group) + suffix``. This is the same naming scheme that
    ``train_net`` uses for its per-epoch files (suffix such as
    ``'_Epoch_3.pth'``) and that this cell uses for the per-group summary
    files (suffix ``'.npy'``).
    """
    return f'{base_path}{prefix}{model_name}_{experiment_name}_Group_{group}{suffix}'


def _select_best_epoch(val_dice, start):
    """Return the 0-based index of the largest value in ``val_dice[start:]``.

    Args:
        val_dice: 1-D sequence of per-epoch validation Dice scores.
        start: first epoch index eligible for selection.

    If ``start`` is past the end of ``val_dice``, the slice would be empty and
    ``np.argmax`` would raise ``ValueError``; the whole sequence is searched
    instead. Standard ``np.argmax`` semantics apply: ties resolve to the
    earliest epoch, and a NaN entry (if any) is treated as the maximum.
    """
    if start >= len(val_dice):
        start = 0
    return int(np.argmax(val_dice[start:])) + start


def _delete_non_best_files(base_path, model_name, experiment_name, group, n_epochs, best_epoch):
    """Delete the per-epoch files written by ``train_net``, keeping the best model.

    For every 1-based epoch number in ``1..n_epochs``:

    - the ``.pth`` checkpoint is removed unless it belongs to ``best_epoch``
      (a 0-based index, hence the ``+ 1``);
    - the per-epoch ``Training_Dice_``/``Validation_Dice_`` ``.npy`` snapshots
      are always removed, because the complete curves have already been saved
      once per group.

    Files that do not exist are skipped, each one checked independently.
    """
    for epoch_number in range(1, n_epochs + 1):
        epoch_suffix = '_Epoch_' + str(epoch_number)
        to_remove = [
            _run_file_path(base_path, 'Training_Dice_', model_name, experiment_name, group, epoch_suffix + '.npy'),
            _run_file_path(base_path, 'Validation_Dice_', model_name, experiment_name, group, epoch_suffix + '.npy'),
        ]
        if epoch_number != best_epoch + 1:
            to_remove.append(_run_file_path(base_path, '', model_name, experiment_name, group, epoch_suffix + '.pth'))

        for path in to_remove:
            if os.path.exists(path):
                os.remove(path)


#Set up Basic Configuration of the Log File
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logging.info(f'Using device {device}')


try:
    if cuda:
        for g in range(n_groups):

            #Load a freshly initialised network for every group
            net = AAR_Net(n_channels=channels, n_classes=tissues)
            net.to(device=device)

            #Train
            training, validation = train_net(net, epochs, tissues, lr, batch_size, g, python_path, torch_path, Experiment_Name, Model_Name)

            #Save the full per-epoch Dice curves for this group
            np.save(_run_file_path(torch_path, 'Training_Dice_', Model_Name, Experiment_Name, g, '.npy'), training)
            np.save(_run_file_path(torch_path, 'Validation_Dice_', Model_Name, Experiment_Name, g, '.npy'), validation)

            #Row 0 of `validation` is the overall average validation Dice
            #(the value train_net logs as 'Average Val Dice')
            epoch = _select_best_epoch(validation[0], best_epoch_search_start)
            logging.info('Group ' + str(g) + ': keeping epoch ' + str(epoch + 1))

            #Delete Other Models and the per-epoch metric snapshots
            _delete_non_best_files(torch_path, Model_Name, Experiment_Name, g, epochs, epoch)
    else:
        #Training is only run on GPU; make the skip visible instead of silent
        logging.warning('CUDA is not available; training was skipped.')

except KeyboardInterrupt:
    #Allow a long run to be stopped from the notebook without a traceback,
    #but record that it was cut short.
    logging.info('Training interrupted by user.')