Imports

In [None]:
!pip install opencv-python
!pip install pingouin

In [None]:
#Import Python Packages
import copy;
import cv2;
import csv;
from datetime import datetime
from functools import reduce
from glob import glob
from glob import iglob
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
from operator import __add__
import os
from os.path import splitext
from os import listdir
import pandas as pd;
from PIL import Image, ImageOps
import PIL
import pingouin as pg
from random import randint
import random;
import sys
from scipy.stats import wilcoxon
from scipy import ndimage
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.measurements import label
from scipy.special import softmax
from sklearn.metrics import precision_recall_fscore_support as score
from skimage import exposure
from skimage import feature
from skimage import transform as tf
from skimage import morphology
from skimage.morphology import skeletonize
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Dataset
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from torch import optim
from torch.nn import Parameter
from torch.nn.modules import Conv2d, Module
from tqdm import tqdm
from typing import Any
import warnings
warnings.filterwarnings('ignore')
import xlsxwriter


Neural Network

In [None]:
class GaborConv2d(Module):
    """2-D convolution whose kernels are Gabor functions generated from per-kernel parameters.

    The convolution itself is performed by an inner ``Conv2d`` (``self.conv_layer``)
    built from the constructor arguments (defaults: ``padding=48``,
    ``padding_mode="reflect"``, ``bias=False``). Before convolving,
    ``calculate_weights`` overwrites every ``(out, in)`` kernel of that layer with
    a Gabor function parameterized by ``freq``, ``theta``, ``sigma`` and ``psi``
    (each of shape ``(out_channels, in_channels)``).

    NOTE(review): the kernels are written through ``conv_layer.weight.data``,
    which autograd does not track, so ``freq``/``theta``/``sigma``/``psi`` get no
    gradient through ``forward`` and keep their initial (or loaded) values.
    ``self.weight`` is registered but never read. Confirm whether trainable
    Gabor parameters were intended; changing this alters checkpoints.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=48, dilation=1,groups=1,bias=False, padding_mode="reflect"): 
        super().__init__()
        # Eval-mode cache flag: True once kernels were computed while not training.
        self.is_calculated = False

        self.conv_layer = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode)
        # Conv2d normalizes kernel_size to a (kH, kW) tuple.
        self.kernel_size = self.conv_layer.kernel_size

        # small addition to avoid division by zero
        self.delta = 1e-3

        # Frequency: (pi/2) * sqrt(2) ** -k with k drawn uniformly from {0, ..., 4}.
        self.freq = Parameter(
            (math.pi / 2)
            * math.sqrt(2)
            ** (-torch.randint(0, 5, (out_channels, in_channels))).type(torch.Tensor),
            requires_grad=True,
        )
        # Theta (orientation): multiples of pi/8, k drawn from {0, ..., 7}.
        self.theta = Parameter(
            (math.pi / 8)
            * torch.randint(0, 8, (out_channels, in_channels)).type(torch.Tensor),
            requires_grad=True,
        )
        # Sigma (envelope width): initialized to pi / freq.
        self.sigma = Parameter(math.pi / self.freq, requires_grad=True)

        # Psi (phase offset): uniform in [0, pi).
        self.psi = Parameter(
            math.pi * torch.rand(out_channels, in_channels), requires_grad=True
        )

        # Half kernel extents, ceil(k / 2); fixed (requires_grad=False).
        self.x0 = Parameter(
            torch.ceil(torch.Tensor([self.kernel_size[0] / 2]))[0], requires_grad=False
        )
        self.y0 = Parameter(
            torch.ceil(torch.Tensor([self.kernel_size[1] / 2]))[0], requires_grad=False
        )

        # Coordinate grids. torch.meshgrid without `indexing` uses 'ij' indexing:
        # self.y varies along dim 0 (rows) and self.x along dim 1. Each axis spans
        # -x0+1 .. x0 with kernel_size points (unit spacing for even kernel sizes).
        self.y, self.x = torch.meshgrid(
            [
                torch.linspace(-self.x0 + 1, self.x0 + 0, self.kernel_size[0]),
                torch.linspace(-self.y0 + 1, self.y0 + 0, self.kernel_size[1]),
            ]
        )
        # Parameter() defaults to requires_grad=True, so the grids are registered
        # as trainable parameters as well.
        self.y = Parameter(self.y.clone())
        self.x = Parameter(self.x.clone())

        # Uninitialized (torch.empty); not read by forward() or calculate_weights().
        self.weight = Parameter(
            torch.empty(self.conv_layer.weight.shape, requires_grad=True),
            requires_grad=True,
        )

        # Re-registering an existing name just reassigns it. The new alias names
        # (x_shape, y_shape, y_grid, x_grid) register the same tensors a second
        # time, so they appear under both names in state_dict(). Saved checkpoints
        # presumably contain these keys; keep them unless migrating checkpoints.
        self.register_parameter("freq", self.freq)
        self.register_parameter("theta", self.theta)
        self.register_parameter("sigma", self.sigma)
        self.register_parameter("psi", self.psi)
        self.register_parameter("x_shape", self.x0)
        self.register_parameter("y_shape", self.y0)
        self.register_parameter("y_grid", self.y)
        self.register_parameter("x_grid", self.x)
        self.register_parameter("weight", self.weight)

    def forward(self, input_tensor):
        """Refresh the Gabor kernels if needed, then apply the inner convolution.

        In training mode the kernels are recomputed on every call. In eval mode
        they are computed once and reused until the module returns to training.
        """
        if self.training:
            self.calculate_weights()
            self.is_calculated = False
        if not self.training:
            if not self.is_calculated:
                self.calculate_weights()
                self.is_calculated = True
        return self.conv_layer(input_tensor)

    def calculate_weights(self):
        """Write a Gabor kernel into ``conv_layer.weight`` for every (out, in) pair.

        g = exp(-0.5 * (x'^2 + y'^2) / (sigma + delta)^2)
            * cos(freq * x' + psi) / (2 * pi * sigma^2)
        where (x', y') is the coordinate grid rotated by theta.
        """
        for i in range(self.conv_layer.out_channels):
            for j in range(self.conv_layer.in_channels):
                # Broadcast the scalar parameters of this kernel to the grid shape.
                sigma = self.sigma[i, j].expand_as(self.y) 
                freq = self.freq[i, j].expand_as(self.y) 
                theta = self.theta[i, j].expand_as(self.y) 
                psi = self.psi[i, j].expand_as(self.y) 

                # Rotate the grid by theta.
                rotx = self.x * torch.cos(theta) + self.y * torch.sin(theta)
                roty = -self.x * torch.sin(theta) + self.y * torch.cos(theta)

                # Gaussian envelope times cosine carrier along the rotated x axis.
                g = torch.exp(
                    -0.5 * ((rotx ** 2 + roty ** 2) / (sigma + self.delta) ** 2)
                )
                g = g * torch.cos(freq * rotx + psi)
                g = g / (2 * math.pi * sigma ** 2)
                # Written via .data: this assignment is not recorded by autograd.
                self.conv_layer.weight.data[i, j] = g

    def _forward_unimplemented(self, *inputs: Any):
        """
        Overrides nn.Module's placeholder. The original author notes that code
        checkers demand it; forward() above is what actually runs.
        """
        raise NotImplementedError
        
        

class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Padding of 1 keeps the spatial size. The first convolution maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.
    All layers live in the ``nn.Sequential`` ``self.double_conv`` so the
    state_dict keys are ``double_conv.<index>.*``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for source_channels in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(source_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder step: 2x2 max-pooling (halves H and W) followed by DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Both stages stay in one Sequential named ``maxpool_conv`` so the
        # state_dict keys are unchanged.
        stages = [nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)]
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Decoder step: upsample, align to the skip connection, concatenate, DoubleConv.

    ``x1`` (coarser map) is upsampled by 2 and then padded (or cropped, for
    negative differences) to the height/width of ``x2`` (skip connection).
    ``torch.cat([x2, x1], dim=1)`` is fed to ``DoubleConv(in_channels, out_channels)``,
    so ``in_channels`` is the channel count after concatenation.

    With ``bilinear=True`` upsampling is parameter-free bilinear interpolation.
    Otherwise a ``ConvTranspose2d`` on ``in_channels // 2`` channels is used, so
    ``x1`` must then carry ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Inputs are (N, C, H, W). The size differences are plain Python ints,
        # which avoids building 1-element tensors and tensor `//` (whose rounding
        # mode for negatives changed between PyTorch versions).
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # F.pad takes (left, right, top, bottom) for the last two dims; any odd
        # remainder goes to the right/bottom. Negative amounts crop.
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``in_channels`` features to ``out_channels`` per-pixel scores."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        scores = self.conv(x)
        return scores

class filt_cat(nn.Module):
    """Align ``x1`` to the spatial size of ``x2`` and concatenate them along channels.

    ``x1`` is zero-padded or, when it is larger than ``x2``, cropped (``F.pad``
    crops for negative amounts) so its height/width match ``x2``. The result is
    ``torch.cat([x2, x1], dim=1)``.

    The constructor arguments are accepted for interface compatibility but are
    unused; the module has no parameters.

    Pad amounts are split as ``diff // 2`` before and the remainder after, using
    Python floor division. For an odd negative difference, the extra row/column
    is therefore removed from the top/left. The previous tensor-based
    arithmetic matched this on PyTorch >= 1.13, but truncated toward zero on
    older releases (cropping bottom/right instead).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

    def forward(self, x1, x2):
        # Inputs are (N, C, H, W); differences are plain Python ints.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # (left, right, top, bottom) for the last two dims.
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        return torch.cat([x2, x1], dim=1)

class _GaborUNet(nn.Module):
    """U-Net whose input is augmented with the responses of a GaborConv2d filter bank.

    Pipeline:
      ``g0`` (GaborConv2d) ->
      ``fc`` (filt_cat: the Gabor responses are size-aligned to the input and
      concatenated after it, giving ``2 * n_channels`` channels) ->
      4-level U-Net (``inc``, ``down1..4``, ``up1..4``, ``outc``) ->
      softmax over the class dimension.

    SNP_Net and AAR_Net differ only in ``gabor_kernel_size``. Submodule names
    and construction order match the original per-class definitions, so
    existing state_dict checkpoints load unchanged.
    """

    def __init__(self, n_channels, n_classes, bilinear=True, gabor_kernel_size=96):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.g0 = GaborConv2d(n_channels, n_channels,
                              kernel_size=(gabor_kernel_size, gabor_kernel_size))
        self.fc = filt_cat(n_channels, 2 * n_channels)
        self.inc = DoubleConv(2 * n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        gabor = self.g0(x)
        x0 = self.fc(gabor, x)
        x1 = self.inc(x0)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        # Returns class probabilities (softmax over dim 1), not raw logits.
        return F.softmax(self.outc(x), dim=1)


class SNP_Net(_GaborUNet):
    """Gabor-augmented U-Net with a 96x96 Gabor kernel (used for SNP segmentation)."""

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__(n_channels, n_classes, bilinear, gabor_kernel_size=96)


class AAR_Net(_GaborUNet):
    """Gabor-augmented U-Net with a 144x144 Gabor kernel (used for applanation-artifact masks)."""

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__(n_channels, n_classes, bilinear, gabor_kernel_size=144)

Gaussian Filter

In [None]:
def fspecial_gauss(size, sigma):
    """Return a (size, size) Gaussian bump with peak value 1.

    Modeled on MATLAB's ``fspecial('gaussian')``, with two differences:
    the kernel is NOT normalized to sum to 1, and offsets run from
    ``-size//2 + 1`` to ``size//2``. For odd ``size`` the peak is at index
    ``size//2``; for even ``size`` it is at index ``size//2 - 1``.
    """
    offsets = np.arange(-size // 2 + 1, size // 2 + 1)
    rows, cols = np.meshgrid(offsets, offsets, indexing='ij')
    squared_radius = rows ** 2 + cols ** 2
    return np.exp(-(squared_radius / (2.0 * sigma ** 2)))

Find Branchpoints

In [None]:
def skeleton_branchpoints(skel):
    """Return a uint8 mask (1 = branch point) of skeleton pixels with at least three 8-neighbours.

    Any non-zero input value counts as skeleton. Each pixel is scored as
    10 * (pixel) + (number of skeleton pixels among its 8 neighbours) via
    cv2.filter2D. A score above 12 means "on the skeleton with >= 3 neighbours";
    background pixels can score at most 8.

    The image is not padded, so border pixels use OpenCV's default border
    extrapolation.
    """
    binary = np.uint8(skel != 0)
    neighbour_weights = np.uint8([[1, 1, 1],
                                  [1, 10, 1],
                                  [1, 1, 1]])
    scores = cv2.filter2D(binary, -1, neighbour_weights)
    return (scores > 12).astype(np.uint8)

Find Endpoints

In [None]:
def skeleton_endpoints(skel):
    """Return a uint8 mask (1 = endpoint) of skeleton pixels with exactly one 8-neighbour.

    Any non-zero input value counts as skeleton. The image is zero-padded by one
    pixel so that pixels outside the image count as background. Each pixel is
    scored as 10 * (pixel) + (number of skeleton neighbours); a score of exactly
    11 marks an endpoint. The padding is cropped off before returning, so the
    output has the input's shape.
    """
    binary = np.uint8(skel != 0)
    padded = np.pad(binary, 1, mode='constant', constant_values=0)
    neighbour_weights = np.uint8([[1, 1, 1],
                                  [1, 10, 1],
                                  [1, 1, 1]])
    scores = cv2.filter2D(padded, -1, neighbour_weights)
    endpoint_mask = (scores == 11).astype(np.uint8)
    return endpoint_mask[1:-1, 1:-1]

Count Immune Cells

In [None]:
def count_immune_cells(mask, mask_AAR, min_size=10):
    """Count immune-cell blobs (label 2 in ``mask``), scaled by the valid-area fraction.

    Blobs are 8-connected components of ``mask == 2``. Only components with at
    least ``min_size`` pixels are counted; the default of 10 reproduces the
    original "more than 9 pixels" rule.

    The count is multiplied by ``mask_AAR.size / (number of pixels where
    mask_AAR == 1)``, i.e. extrapolated from the valid region to the whole image.

    Returns:
        float: the normalized count.

    Raises:
        ZeroDivisionError: if ``mask_AAR`` has no pixel equal to 1.
    """
    labeled, _ = label(mask == 2, np.ones((3, 3), dtype=int))

    # One pass for all component sizes; index 0 is the background.
    sizes = np.bincount(labeled.ravel())
    count = int(np.count_nonzero(sizes[1:] >= min_size))

    valid_pixels = np.count_nonzero(mask_AAR == 1)
    return count * np.size(mask_AAR) / valid_pixels

Count Neuromas

In [None]:
def count_neuromas(mask, mask_AAR, min_size=10):
    """Count neuroma blobs (label 3 in ``mask``), scaled by the valid-area fraction.

    Blobs are 8-connected components of ``mask == 3``. Only components with at
    least ``min_size`` pixels are counted; the default of 10 reproduces the
    original "more than 9 pixels" rule.

    The count is multiplied by ``mask_AAR.size / (number of pixels where
    mask_AAR == 1)``, i.e. extrapolated from the valid region to the whole image.

    Returns:
        float: the normalized count.

    Raises:
        ZeroDivisionError: if ``mask_AAR`` has no pixel equal to 1.
    """
    labeled, _ = label(mask == 3, np.ones((3, 3), dtype=int))

    # One pass for all component sizes; index 0 is the background.
    sizes = np.bincount(labeled.ravel())
    count = int(np.count_nonzero(sizes[1:] >= min_size))

    valid_pixels = np.count_nonzero(mask_AAR == 1)
    return count * np.size(mask_AAR) / valid_pixels

Count Junctions

In [None]:
def count_junctions(mask, mask_AAR):
    """Count nerve junctions (branch-point clusters), scaled by the valid-area fraction.

    Class 1 of ``mask`` is treated as nerve. Steps:
      1. Skeletonize; background islands under 10 pixels enclosed by the
         skeleton are filled, then the result is skeletonized again.
      2. Clear the one-pixel image border.
      3. Drop branch-point pixels whose removal leaves the number of
         8-connected pieces in their 3x3 neighbourhood unchanged.
      4. Remove short terminal branches: pieces of (skeleton minus branch
         points) with fewer than 10 pixels that contain an endpoint.
      5. Remove the remaining branch points, dilate with a 3x3 square and
         re-skeletonize.
      6. Count 8-connected clusters of branch-point pixels.

    The count is multiplied by ``mask_AAR.size / (number of pixels where
    mask_AAR == 1)``.

    Returns:
        float: the normalized junction count.

    Raises:
        ZeroDivisionError: if ``mask_AAR`` has no pixel equal to 1.
    """
    structure = np.ones((3, 3), dtype=int)

    # 1. Skeletonize, then fill small holes: connected components of the
    #    inverted skeleton smaller than 10 px are dropped (become skeleton).
    skel = skeletonize(mask == 1)
    invert_skel = skel == 0
    invert_skel = np.array(morphology.remove_small_objects(invert_skel, 10))
    skel = np.abs(invert_skel - 1)  # back to 0/1: skeleton plus filled holes
    skel = skeletonize(skel)

    # 2. Clear the border. Branch points lie on skeleton pixels, so this also
    #    keeps the 3x3 boxes below inside the image.
    skel[0, :] = 0
    skel[:, 0] = 0
    skel[-1, :] = 0
    skel[:, -1] = 0

    # 3. Prune redundant branch-point pixels. skel is modified in place, so
    #    later decisions see earlier removals (order matters).
    branchpoints = skeleton_branchpoints(skel)
    bp_rows, bp_cols = np.nonzero(branchpoints == 1)
    for r, c in zip(bp_rows, bp_cols):
        box = skel[r - 1:r + 2, c - 1:c + 2].copy()
        _, n_before = label(box, structure)
        box[1, 1] = 0
        _, n_after = label(box, structure)
        if n_before == n_after:
            skel[r, c] = 0

    # 4. Remove short terminal branches.
    endpoints = skeleton_endpoints(skel)
    branchpoints = skeleton_branchpoints(skel)
    skel_sep = skel.copy()
    skel_sep[branchpoints == 1] = 0
    labeled, _ = label(skel_sep, structure)
    labeled = labeled * skel_sep
    for branch_id in np.unique(labeled):
        if branch_id == 0:
            continue
        branch = labeled == branch_id
        if np.count_nonzero(branch) < 10 and np.sum(endpoints[branch]) > 0:
            skel[branch] = 0

    # 5. Remove branch points, dilate and re-skeletonize (spur clean-up).
    branchpoints = skeleton_branchpoints(skel)
    skel[branchpoints == 1] = 0
    struct1 = ndimage.generate_binary_structure(2, 2)
    skel = ndimage.binary_dilation(skel, structure=struct1)
    skel = skeletonize(skel)

    # 6. Count branch-point clusters. `label` returns the number of foreground
    #    components directly; the previous len(np.unique(labeled)) also counted
    #    the background label 0 and over-reported by one junction per image.
    branchpoints = skeleton_branchpoints(skel)
    _, count = label(branchpoints, structure)

    # Normalize by the valid (mask_AAR == 1) area.
    return count * np.size(mask_AAR) / np.count_nonzero(mask_AAR == 1)

Nerve Density

In [None]:
def calculate_nerve_density(mask, mask_AAR):
    """Return nerve (label 1) coverage expressed on a 400x400 grid, scaled by the valid-area fraction.

    The fraction of ``mask`` pixels equal to 1 is multiplied by 400 * 400 and
    rounded to an integer. The value 400 is presumably the field-of-view side
    length (e.g. in micrometres) — TODO confirm.

    The rounded value is then multiplied by ``mask_AAR.size / (number of pixels
    where mask_AAR == 1)``.

    Raises:
        ZeroDivisionError: if ``mask_AAR`` has no pixel equal to 1.
    """
    nerve_fraction = np.count_nonzero(mask == 1) / np.size(mask)
    density = round(400 * 400 * nerve_fraction)

    valid_pixels = np.count_nonzero(mask_AAR == 1)
    return density * np.size(mask_AAR) / valid_pixels


Nerve Thickness

In [None]:
def calculate_nerve_thickness(mask):
    """Estimate mean nerve thickness as nerve area divided by skeleton length.

    Counts pixels with ``mask == 1`` (area) and pixels of their skeleton
    (centreline length). The area/length ratio is rescaled by
    ``400 / sqrt(mask.size)``, i.e. as if the image side spans 400 units
    (presumably micrometres for a square image — TODO confirm).

    Raises:
        ZeroDivisionError: if the skeleton is empty.
    """
    centerline = skeletonize(mask == 1)
    centerline_pixels = np.shape(np.where(centerline == 1))[1]
    nerve_pixels = np.shape(np.where(mask == 1))[1]

    side_pixels = np.sqrt(np.size(mask))
    pixels_per_unit = 400 / side_pixels
    return (nerve_pixels / centerline_pixels) * pixels_per_unit

Nerve Tortuosity

In [None]:
def calculate_nerve_tortuosity(mask):
    """Return the curvature-based tortuosity of the nerve skeleton in ``mask``.

    Class 1 of ``mask`` is skeletonized and cut at branch points, leaving
    8-connected segments. A segment is used when it has more than 3 pixels and
    at least two endpoints. Its pixels are ordered by Euclidean distance from
    its first endpoint (row-major order); this is a heuristic ordering along
    the path. The discrete curvature at each interior step is

        K = (dx * ddy - ddx * dy) / (dx**2 + dy**2) ** 1.5

    where (dx, dy) is the current step and (ddx, ddy) its change from the
    previous step. The segment's tau_C is sum(|K|).

    The result is sum(tau_C * segment_pixels) divided by the number of skeleton
    pixels left after branch-point removal.

    Raises:
        ZeroDivisionError: if no skeleton pixels remain.
    """
    skeleton = skeletonize(mask == 1)
    branchpoints = skeleton_branchpoints(skeleton)
    # Cutting at branch points makes each connected piece a single segment.
    skeleton[branchpoints == 1] = 0

    structure = np.ones((3, 3), dtype=int)
    labeled, _ = label(skeleton, structure)
    labeled = labeled * skeleton

    weighted_sum = 0
    for segment_id in np.unique(labeled):
        if segment_id == 0:
            continue
        segment = labeled == segment_id
        n_pixels = np.count_nonzero(segment)
        if n_pixels <= 3:
            continue

        endpoints = skeleton_endpoints(segment.astype(float))
        ep_rows, ep_cols = np.nonzero(endpoints == 1)
        # An open path needs two ends. Closed loops (0 endpoints) were already
        # skipped; pieces with a single endpoint previously raised IndexError.
        if ep_rows.size < 2:
            continue

        # Order the segment's pixels by distance from the first endpoint.
        rows, cols = np.nonzero(segment)
        distances = np.sqrt(np.power(rows - ep_rows[0], 2) + np.power(cols - ep_cols[0], 2))
        order = np.argsort(distances)
        rows = rows[order]
        cols = cols[order]

        # First differences (steps) and second differences (change of step).
        # The first step has no predecessor and contributes zero curvature.
        step_x = np.diff(cols)
        step_y = np.diff(rows)
        turn_x = np.diff(step_x)
        turn_y = np.diff(step_y)
        cur_x = step_x[1:]
        cur_y = step_y[1:]
        curvature = (cur_x * turn_y - turn_x * cur_y) / np.power(
            np.power(cur_x, 2) + np.power(cur_y, 2), 3 / 2)
        tau_c = np.sum(np.abs(curvature))

        weighted_sum = weighted_sum + tau_c * n_pixels

    # Normalize by the total number of (cut) skeleton pixels.
    skeleton_pixels = np.count_nonzero(skeleton == 1)
    return weighted_sum / skeleton_pixels



Fix Mixture

In [None]:
def fix_mixture_1(prediction, class_1, class_2):
    """Make every mixed ``class_1``/``class_2`` island a single class, in place.

    Islands are 8-connected regions of pixels labelled ``class_1`` or
    ``class_2``. An island containing both labels is relabelled entirely to
    ``class_1`` if ``class_1`` pixels strictly outnumber ``class_2`` pixels,
    otherwise to ``class_2`` (ties go to ``class_2``).

    ``prediction`` is modified in place and also returned.
    """
    in_pair = (prediction == class_1) | (prediction == class_2)
    islands, n_islands = label(in_pair, np.ones((3, 3), dtype=int))

    for island_id in range(1, n_islands + 1):
        region = islands == island_id
        values = prediction[region]
        if np.unique(values).size < 2:
            continue
        first_wins = np.count_nonzero(values == class_1) > np.count_nonzero(values == class_2)
        prediction[region] = class_1 if first_wins else class_2

    return prediction

In [None]:
def fix_mixture_2(prediction, class_1, class_2):
    """Relabel mixed islands to ``class_2`` when ``class_2`` is strictly the majority, in place.

    Islands are 8-connected regions of pixels labelled ``class_1`` or
    ``class_2``. Unlike ``fix_mixture_1``, an island where ``class_1`` is at
    least as frequent is left untouched.

    ``prediction`` is modified in place and also returned.
    """
    in_pair = (prediction == class_1) | (prediction == class_2)
    islands, n_islands = label(in_pair, np.ones((3, 3), dtype=int))

    for island_id in range(1, n_islands + 1):
        region = islands == island_id
        values = prediction[region]
        is_mixed = np.unique(values).size > 1
        second_wins = np.count_nonzero(values == class_1) < np.count_nonzero(values == class_2)
        if is_mixed and second_wins:
            prediction[region] = class_2

    return prediction

Fix Class within Class

In [None]:
def fix_inner(prediction):
    """Absorb regions fully enclosed by classes 1, 2, 3 (in that order) into the enclosing class, in place.

    For each class, holes in its mask (``ndimage.binary_fill_holes``) that are
    not themselves that class are split into 8-connected regions. A region is
    relabelled to the enclosing class when it contains no background (0) pixel.
    Classes are processed sequentially, so later classes see earlier changes.

    ``prediction`` is modified in place and also returned.
    """
    connectivity = np.ones((3, 3), dtype=int)

    for outer in (1, 2, 3):
        is_outer = prediction == outer
        enclosed = ndimage.binary_fill_holes(is_outer) & ~is_outer
        regions, n_regions = label(enclosed, connectivity)

        for region_id in range(1, n_regions + 1):
            region = regions == region_id
            if not np.any(prediction[region] == 0):
                prediction[region] = outer

    return prediction

Remove Small Objects

In [None]:
def remove_small(prediction, min_size=8):
    """Set tiny islands of classes 1, 2 and 3 to background (0), in place.

    For each class, 8-connected components with fewer than ``min_size`` pixels
    (default 8, as before) are erased. Component sizes come from a single
    ``np.bincount`` pass instead of one full-image scan per component.

    ``prediction`` is modified in place and also returned.
    """
    structure = np.ones((3, 3), dtype=int)

    for current_class in range(1, 4):
        labeled, _ = label(prediction == current_class, structure)
        sizes = np.bincount(labeled.ravel(), minlength=1)
        too_small = sizes < min_size
        too_small[0] = False  # label 0 is everything outside this class
        if too_small.any():
            prediction[too_small[labeled]] = 0

    return prediction

Test AAR-Net

In [None]:
def test_AAR_net(SNP, torch_path, Experiment_Name, Model_Name, tissues, channels):
    """Segment every image in ``SNP`` with a five-model AAR_Net ensemble.

    Loading: for each group g in 0..4, ``validation[0, 40:]`` is read from
    ``Validation_Dice_<Model>_<Experiment>_Group_<g>.npy``. Its argmax (offset
    by 40) selects the checkpoint
    ``<Model>_<Experiment>_Group_<g>_Epoch_<epoch+1>.pth``.

    Inference: each image is processed in 192x192 windows on a stride-8 grid
    over a 384x384 canvas. Each window is min-max normalized to [-0.5, 0.5] and
    predicted together with its left-right, up-down and 180-degree flips. The
    flipped predictions are flipped back on every class channel, passed through
    ``softmax`` over the class axis, weighted by a centred Gaussian and
    accumulated. Each pixel's label is the argmax of the weighted average.

    NOTE(review): ``AAR_Net.forward`` already returns softmax probabilities; the
    five probability maps are summed and ``softmax`` is applied again here.

    Relies on a module-level ``device``. Windows are fed as single-channel
    input (``channels`` must be 1), and images are assumed to be at least
    384x384 — TODO confirm.

    Returns:
        np.ndarray: integer label masks of shape (num_images, 384, 384).
    """
    # Gaussian weight favouring predictions near each window's centre.
    weight_kernel = fspecial_gauss(192, 20)

    # Load the five cross-validation models.
    nets = []
    for g in range(0, 5):
        validation = np.load(torch_path  +'Validation_Dice_'+  Model_Name +'_' + Experiment_Name + '_Group_' + str(g)  +'.npy').squeeze()
        epoch = np.argmax(validation[0, 40:]) + 40
        net = AAR_Net(n_channels=channels, n_classes=tissues)
        net.load_state_dict(torch.load(torch_path + Model_Name + '_' + Experiment_Name + '_Group_' + str(g) + '_Epoch_' + str(epoch + 1) + '.pth', map_location=device))
        net.to(device=device)
        net.eval()
        nets.append(net)

    AAR_Masks = []
    for idx in range(0, np.shape(SNP)[0]):
        img = SNP[idx, :, :].copy()
        collective = np.zeros((tissues, 384, 384))
        weight = np.zeros((tissues, 384, 384))

        for row in range(0, 200, 8):
            for col in range(0, 200, 8):
                window = img[row:row + 192, col:col + 192].copy()

                # Min-max normalize to [-0.5, 0.5]. A constant window (zero
                # range) becomes all -0.5 instead of NaN from 0/0.
                lo = np.min(window)
                span = np.max(window) - lo
                window = window - lo
                if span != 0:
                    window = window / span
                window = window - 0.5

                # Test-time augmentation: original, left-right, up-down, both.
                batch = np.stack([
                    window,
                    np.fliplr(window),
                    np.flipud(window),
                    np.fliplr(np.flipud(window)),
                ])
                batch = np.reshape(batch, (-1, 1, 192, 192))
                input_img = torch.from_numpy(batch).to(device=device, dtype=torch.float32)

                with torch.no_grad():
                    # Same summation order as net_0(x) + net_1(x) + ... + net_4(x).
                    output = nets[0](input_img)
                    for net in nets[1:]:
                        output = output + net(input_img)
                output = output.cpu().numpy()

                # Undo each flip on ALL class channels. Axes -2/-1 are rows/cols;
                # previously only channels 0-2 were un-flipped.
                aligned = [
                    output[0],
                    np.flip(output[1], axis=-1),
                    np.flip(output[2], axis=-2),
                    np.flip(output[3], axis=(-2, -1)),
                ]
                weighted = [np.multiply(softmax(p, axis=0), weight_kernel) for p in aligned]

                collective[:, row:row + 192, col:col + 192] += weighted[0] + weighted[1] + weighted[2] + weighted[3]
                weight[:, row:row + 192, col:col + 192] += 4 * weight_kernel

        # Weighted average over overlapping windows, then per-pixel argmax.
        collective = np.divide(collective, weight)
        AAR_Masks.append(np.argmax(collective, axis=0))

    AAR_Masks = np.array(AAR_Masks)
    return AAR_Masks

Test SNP-Net

In [None]:
def _window_starts(length, window, stride):
    """Return the start offsets of sliding windows along one image axis.

    Offsets step by ``stride`` from 0. If the last step does not end exactly at
    ``length``, one extra window flush with the far edge is appended, so every
    pixel is covered at least once. Otherwise its accumulated weight would stay
    zero and the later normalisation would divide by zero.

    Raises:
        ValueError: if ``length`` is smaller than ``window``.
    """
    if length < window:
        raise ValueError(f'image side {length} is smaller than the patch size {window}')
    starts = list(range(0, length - window + 1, stride))
    if starts[-1] != length - window:
        starts.append(length - window)
    return starts


def test_SNP_net(SNP, masks_AAR, g, t, torch_path, Experiment_Name, Model_Name, tissues, channels):
    """Segment SNP images with a trained SNP-Net and compute per-image clinical metrics.

    Preprocessing and prediction:
      - Each image is divided by its maximum (when positive) and shifted by -0.5.
      - It is predicted with overlapping 96x96 patches at stride 36, under four
        flip variants: identity, up-down, left-right, and both.
      - Per-class softmax probabilities are flipped back to the original
        orientation and weighted by ``fspecial_gauss(96, 5)``.
      - The weighted probabilities are accumulated, normalised by the
        accumulated weight, and arg-maxed over classes.

    Post-processing:
      - Pixels where ``masks_AAR != 0`` are set to label 0.
      - The label map then goes through ``fix_mixture_1``, ``fix_mixture_2``,
        ``fix_inner`` and ``remove_small``.

    Args:
        SNP: image stack indexed as ``SNP[idx, row, col]``.
        masks_AAR: AAR-Net label maps indexed like ``SNP``.
        g, t: identifiers embedded in the checkpoint file names.
        torch_path: directory prefix (with trailing separator) holding the
            ``Validation_Dice_*.npy`` file and the ``*.pth`` checkpoints.
        Experiment_Name, Model_Name: name parts of those files.
        tissues: number of output classes of the network.
        channels: ``n_channels`` passed to ``SNP_Net``.

    Returns:
        ``(masks, immune_counts, neuroma_counts, junction_counts,
        nerve_densities, nerve_thicknesses, nerve_tortuosities)``.
        ``masks`` stacks the post-processed label maps. Each metric array
        has shape ``(n_images, 1)``.
    """
    patch = 96   # network input height/width
    stride = 36  # step between neighbouring patches
    n_images = np.shape(SNP)[0]

    # Keep-mask for applanation artifact removal: 1.0 where AAR-Net predicted
    # label 0, 0.0 for every other AAR label.
    masks_AAR_binary = (masks_AAR == 0) * 1.0

    # Gaussian weighting so patch centres count more than patch borders.
    weight_kernel = fspecial_gauss(patch, 5)

    # Per-image clinical metrics, one row per image.
    immune_counts = np.zeros((n_images, 1))
    neuroma_counts = np.zeros((n_images, 1))
    junction_counts = np.zeros((n_images, 1))
    nerve_densities = np.zeros((n_images, 1))
    nerve_thicknesses = np.zeros((n_images, 1))
    nerve_tortuosities = np.zeros((n_images, 1))

    # Load the checkpoint with the best validation Dice, searching from epoch 60 on.
    validation = np.load(torch_path + 'Validation_Dice_' + Model_Name + '_' + Experiment_Name + '_Group_' + str(t) + '_' + str(g) + '.npy').squeeze()
    epoch = np.argmax(validation[0, 60:]) + 60
    net = SNP_Net(n_channels=channels, n_classes=tissues)
    net.load_state_dict(torch.load(torch_path + Model_Name + '_' + Experiment_Name + '_Group_' + str(t) + '_' + str(g) + '_Epoch_' + str(epoch + 1) + '.pth', map_location=device))
    net.to(device=device)
    net.eval()

    masks = []
    for idx in range(n_images):
        img = SNP[idx, :, :].copy()
        mask_AAR = masks_AAR_binary[idx, :, :].copy()

        # Scale to [0, 1] and centre on 0; skip the division for an all-zero
        # image, which would otherwise produce NaNs everywhere.
        peak = np.max(img)
        if peak > 0:
            img = img / peak
        img = img - 0.5

        height, width = img.shape
        collective = np.zeros((tissues, height, width))
        weight = np.zeros((tissues, height, width))

        for row in _window_starts(height, patch, stride):
            for col in _window_starts(width, patch, stride):
                crop = img[row:row + patch, col:col + patch]

                # Test-time augmentation as one batch of 4:
                # identity, up-down, left-right, left-right followed by up-down.
                variants = np.stack([crop, np.flipud(crop), np.fliplr(crop), np.flipud(np.fliplr(crop))])
                input_img = torch.from_numpy(variants[:, np.newaxis].copy()).to(device=device, dtype=torch.float32)

                with torch.no_grad():
                    output = net(input_img).cpu().numpy()

                probs = softmax(output, axis=1)

                # Flip each variant's probability maps back BEFORE weighting, so
                # the Gaussian weights line up with the un-flipped `weight`
                # accumulator. Slicing the spatial axes handles any class count.
                probs = np.stack([
                    probs[0],
                    probs[1][:, ::-1, :],
                    probs[2][:, :, ::-1],
                    probs[3][:, ::-1, ::-1],
                ])

                collective[:, row:row + patch, col:col + patch] += (probs * weight_kernel).sum(axis=0)
                weight[:, row:row + patch, col:col + patch] += len(probs) * weight_kernel

        # Normalise by the accumulated weight. The same kernel was added to
        # every class channel, so this does not change the arg-max.
        collective = np.divide(collective, weight)
        prediction = np.argmax(collective, axis=0)

        # Applanation artifact removal: zero out pixels outside the keep-mask.
        prediction = prediction * mask_AAR

        # Fix mixed classes
        prediction = fix_mixture_1(prediction, 2, 3)
        prediction = fix_mixture_2(prediction, 1, 2)

        # Fix classes enclosed within another class
        prediction = fix_inner(prediction)

        # Remove small objects
        prediction = remove_small(prediction)

        masks.append(prediction)

        # Clinical metrics computed from the final label map
        immune_counts[idx, 0] = count_immune_cells(prediction, mask_AAR)
        neuroma_counts[idx, 0] = count_neuromas(prediction, mask_AAR)
        junction_counts[idx, 0] = count_junctions(prediction, mask_AAR)
        nerve_densities[idx, 0] = calculate_nerve_density(prediction, mask_AAR)
        nerve_thicknesses[idx, 0] = calculate_nerve_thickness(prediction)
        nerve_tortuosities[idx, 0] = calculate_nerve_tortuosity(prediction)

    masks = np.array(masks)
    return masks, immune_counts, neuroma_counts, junction_counts, nerve_densities, nerve_thicknesses, nerve_tortuosities

Initialize

In [None]:
#Cast to Cuda
# A bare Python assignment such as `CUDA_VISIBLE_DEVICES=0` never reaches the
# CUDA runtime; the setting only takes effect as a process environment variable
# set before CUDA is first initialised. setdefault keeps any value a job
# scheduler has already exported.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

cuda = torch.cuda.is_available()

# Legacy tensor-type aliases matching the selected backend.
if cuda:
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
else:
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor

device = torch.device('cuda' if cuda else 'cpu')


Main

In [None]:

#Specify Data Path and Save Path
img_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Raw/New_Data/'

# ---------------------
#  Load Images
# ---------------------

# `files` is reused by the metrics and plotting cells. Those cells skip the same
# dot-prefixed entries, so their running index lines up with the rows of `SNP`.
images = []
files = sorted(os.listdir(img_path))
for fname in files:

    # Skip hidden (dot-prefixed) directory entries.
    if fname.startswith('.'):
        continue

    # The context manager releases the file handle. Pillow keeps multi-frame
    # files (e.g. GIF) open after decoding unless the image is closed.
    with Image.open(os.path.join(img_path, fname)) as source:
        grey = source.convert('L')
    image = np.array(grey.resize((384, 384), PIL.Image.BILINEAR)).astype(np.float32)
    images.append(image)

SNP = np.array(images)



In [None]:
# ---------------------
#  Test AAR-Net
# ---------------------

# AAR-Net configuration: 3 output classes, single-channel input.
Experiment_Name, Model_Name = 'Original', 'AAR_Net'
tissues, channels = 3, 1
torch_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Torch/AAR-Net/'

# Applanation-artifact label maps, one per image in SNP.
masks_AAR = test_AAR_net(SNP, torch_path, Experiment_Name, Model_Name, tissues, channels)


# ---------------------
#  Test SNP-Net
# ---------------------

# SNP-Net configuration: 4 output classes, single-channel input.
Experiment_Name, Model_Name = 'Original', 'SNP_Net'
tissues, channels = 4, 1
torch_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Torch/SNP-Net/'

# Checkpoint identifiers of the best SNP-Net model on all 207 SNP images.
t, g = 11, 1

# Segmentation masks plus the per-image clinical metrics.
(masks, immune_counts, neuroma_counts, junction_counts,
 nerve_densities, nerve_thicknesses, nerve_tortuosities) = test_SNP_net(
    SNP, masks_AAR, g, t, torch_path, Experiment_Name, Model_Name, tissues, channels)
        
        


Save Metrics

In [None]:
excel_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Excel/'

# Timestamp in the file name so repeated runs never overwrite an earlier workbook.
current = datetime.now().strftime("%m_%d_%Y-%I_%M_%S_%p")

headers = [
    'File Name',
    'Nerve Density (um2/160,000 um2)',
    'Average Nerve Thickness (um)',
    'Average Nerve Segment Tortuosity',
    'Junction Point Density (count/160,000 um2)',
    'Neuroma Density (count/160,000 um2)',
    'Immune Cell Density (count/160,000 um2)',
]

# Same dot-file filter as the image-loading cell, so row idx matches SNP[idx].
image_files = [fname for fname in files if not fname.startswith('.')]

# The context manager closes (and thereby writes) the workbook even if a row
# raises. Metrics are written as numeric cells, so 'nan_inf_to_errors' makes
# NaN/inf become Excel error cells instead of exceptions.
with xlsxwriter.Workbook(os.path.join(excel_path, 'clinical_metrics_' + current + '.xlsx'),
                         {'nan_inf_to_errors': True}) as workbook:
    worksheet = workbook.add_worksheet()
    worksheet.write_row(0, 0, headers)

    for idx, fname in enumerate(image_files):
        # Nerve density is rounded to an integer, the other metrics to 1 decimal.
        values = [
            round(nerve_densities[idx, 0]),
            round(nerve_thicknesses[idx, 0], 1),
            round(nerve_tortuosities[idx, 0], 1),
            round(junction_counts[idx, 0], 1),
            round(neuroma_counts[idx, 0], 1),
            round(immune_counts[idx, 0], 1),
        ]
        # write_string: a file name starting with '=' must not become a formula.
        worksheet.write_string(idx + 1, 0, fname)
        # Numeric cells (not str()) so the sheet can be sorted, summed and charted.
        worksheet.write_row(idx + 1, 1, values)

print('All Clinical Metrics inserted successfully!')

Plot and Save Results

In [None]:
save_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Images/New_Data/'


#Parameters
Plot_Figures = True
Save_Figures = True


def _labels_to_rgb(label_map, rgb_labels):
    """Render a label map as an 8-bit RGB image.

    Channel ``c`` (0 = red, 1 = green, 2 = blue) is 255 wherever
    ``label_map == rgb_labels[c]`` and 0 elsewhere. Unlisted labels stay black.
    """
    rgb = np.zeros(np.shape(label_map) + (3,), dtype='uint8')
    for channel, label in enumerate(rgb_labels):
        rgb[..., channel][label_map == label] = 255
    return rgb


def _show(image, title, cmap=None):
    """Display ``image`` in a fresh figure with the given title."""
    plt.figure()
    plt.title(title)
    plt.imshow(image, cmap=cmap)
    plt.show()


# Same dot-file filter as the image-loading cell, so idx matches SNP[idx].
image_files = [fname for fname in files if not fname.startswith('.')]

for idx, fname in enumerate(image_files):

    mask_AAR = masks_AAR[idx, :, :].squeeze()
    mask = masks[idx, :, :].squeeze()
    img = SNP[idx, :, :].squeeze()

    # Output name stem: strip whatever extension the input has.
    base_name = splitext(fname)[0]

    if Plot_Figures:
        _show(img, 'Image', cmap='gray')

    # AAR-Net labels 0 / 1 / 2 -> red / green / blue.
    aar_rgb = _labels_to_rgb(mask_AAR, (0, 1, 2))
    if Plot_Figures:
        _show(aar_rgb, 'AAR-Net Segmentation')
    if Save_Figures:
        Image.fromarray(aar_rgb).save(os.path.join(save_path, base_name + '_AAR_Seg.jpg'))

    # SNP-Net labels 1 / 2 / 3 -> red / green / blue; label 0 stays black.
    snp_rgb = _labels_to_rgb(mask, (1, 2, 3))
    if Plot_Figures:
        _show(snp_rgb, 'SNP-Net Segmentation')
    if Save_Figures:
        Image.fromarray(snp_rgb).save(os.path.join(save_path, base_name + '_SNP_Seg.jpg'))
  