Imports

In [None]:
#Import Python Packages
import copy
from functools import reduce
from glob import glob
from glob import iglob
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
from operator import __add__
import os
from os.path import splitext
from os import listdir
import pandas as pd;
import PIL
from PIL import Image, ImageEnhance
from PIL import Image, ImageOps
from random import sample
from random import randint
import random;
from scipy import stats
from scipy import ndimage
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.measurements import label
from scipy.special import softmax
from skimage import exposure
from skimage import feature
from skimage import transform as tf
import sys
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn.modules import Conv2d, Module
from torch.utils.data import DataLoader, random_split, Dataset
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
from typing import Any
from tqdm import tqdm
import warnings
# Suppress every Python warning raised from this point on.
# NOTE(review): this also hides deprecation and numerical warnings (e.g. NaN
# from division); consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

Neural Network

In [None]:
class GaborConv2d(Module):
    """2-D convolution whose kernels are rendered from per-channel Gabor parameters.

    Every (out_channel, in_channel) pair owns four scalar Parameters -- ``freq``,
    ``theta`` (orientation), ``sigma`` (envelope width) and ``psi`` (phase).
    ``calculate_weights`` turns them into a kernel-sized Gabor kernel and writes
    it into ``self.conv_layer.weight``; ``forward`` then applies the wrapped
    ``Conv2d``.
    """

    # Earlier defaults were padding=0 and padding_mode="zeros"; this version
    # defaults to 48 pixels of reflection padding.
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=48, dilation=1,groups=1,bias=False, padding_mode="reflect"): 
        """Create the wrapped Conv2d and randomly initialise the Gabor parameters.

        All constructor arguments are forwarded positionally to
        ``torch.nn.Conv2d``. NOTE(review): ``calculate_weights`` writes
        ``weight[i, j]`` for every j < in_channels, which only matches the
        Conv2d weight layout when ``groups=1`` -- confirm before using groups.
        """
        super().__init__()
        # True once eval-mode kernels have been computed and cached (see forward).
        self.is_calculated = False

        # The convolution actually applied; its weight is overwritten by calculate_weights().
        self.conv_layer = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode)
        # (kH, kW) tuple as normalised by Conv2d.
        self.kernel_size = self.conv_layer.kernel_size

        # small addition to avoid division by zero
        # (only added to sigma inside the Gaussian exponent in calculate_weights)
        self.delta = 1e-3

        # Frequency: (pi/2) * sqrt(2)**(-k), k drawn uniformly from {0, ..., 4}
        # for each (out, in) pair, cast to a floating-point tensor.
        self.freq = Parameter(
            (math.pi / 2)
            * math.sqrt(2)
            ** (-torch.randint(0, 5, (out_channels, in_channels))).type(torch.Tensor),
            requires_grad=True,
        )
        # Theta (orientation): k * pi/8, k drawn uniformly from {0, ..., 7}.
        self.theta = Parameter(
            (math.pi / 8)
            * torch.randint(0, 8, (out_channels, in_channels)).type(torch.Tensor),
            requires_grad=True,
        )
        # Sigma (Gaussian envelope width): initialised to pi / freq.
        self.sigma = Parameter(math.pi / self.freq, requires_grad=True)

        # Psi (phase offset): uniform in [0, pi).
        self.psi = Parameter(
            math.pi * torch.rand(out_channels, in_channels), requires_grad=True
        )

        # Grid half-extents ceil(kH / 2) and ceil(kW / 2), frozen (requires_grad=False).
        self.x0 = Parameter(
            torch.ceil(torch.Tensor([self.kernel_size[0] / 2]))[0], requires_grad=False
        )
        self.y0 = Parameter(
            torch.ceil(torch.Tensor([self.kernel_size[1] / 2]))[0], requires_grad=False
        )

        # Kernel coordinates: kH points evenly spaced over [-x0+1, x0] and kW
        # points over [-y0+1, y0] (unit spacing when the kernel size is even).
        # torch.meshgrid defaults to 'ij' indexing, so the first output -- bound
        # to self.y -- varies along dim 0 (kernel rows).
        self.y, self.x = torch.meshgrid(
            [
                torch.linspace(-self.x0 + 1, self.x0 + 0, self.kernel_size[0]),
                torch.linspace(-self.y0 + 1, self.y0 + 0, self.kernel_size[1]),
            ]
        )
        # Stored as Parameters (requires_grad defaults to True), so the grids
        # follow .to(device) and appear in the state dict.
        self.y = Parameter(self.y.clone());
        self.x = Parameter(self.x.clone());

        # NOTE(review): never read inside this class (forward uses
        # self.conv_layer.weight) and left uninitialised by torch.empty;
        # presumably kept so checkpoints containing a "weight" key still load.
        self.weight = Parameter(
            torch.empty(self.conv_layer.weight.shape, requires_grad=True),
            requires_grad=True,
        )

        # Re-registering an existing name ("freq", "theta", "sigma", "psi",
        # "weight") just reassigns the same Parameter. The other four names are
        # aliases: x_shape/y_shape -> x0/y0 and y_grid/x_grid -> y/x, so each of
        # those tensors appears under two keys in state_dict(); removing either
        # name changes the keys that a strict load_state_dict() expects.
        self.register_parameter("freq", self.freq)
        self.register_parameter("theta", self.theta)
        self.register_parameter("sigma", self.sigma)
        self.register_parameter("psi", self.psi)
        self.register_parameter("x_shape", self.x0)
        self.register_parameter("y_shape", self.y0)
        self.register_parameter("y_grid", self.y)
        self.register_parameter("x_grid", self.x)
        self.register_parameter("weight", self.weight)

    def forward(self, input_tensor):
        """Apply the Gabor convolution to ``input_tensor``.

        In training mode the kernels are rebuilt on every call. In eval mode
        they are built on the first call and then cached via ``is_calculated``.
        NOTE(review): the cache is not invalidated if parameters change while
        the module stays in eval mode (e.g. ``load_state_dict`` after an
        eval-mode forward) -- call ``train()``/``eval()`` again to refresh.
        """
        if self.training:
            self.calculate_weights()
            self.is_calculated = False
        if not self.training:
            if not self.is_calculated:
                self.calculate_weights()
                self.is_calculated = True
        return self.conv_layer(input_tensor)

    def calculate_weights(self):
        """Render each (out, in) pair's Gabor kernel into ``conv_layer.weight``.

        For every pair the coordinate grid is rotated by ``theta`` and the kernel
        is ``exp(-0.5 * (x'^2 + y'^2) / (sigma + delta)^2) * cos(freq * x' + psi)
        / (2 * pi * sigma^2)``.

        NOTE(review): the result is assigned through ``.data``, which bypasses
        autograd, so no gradient reaches freq / theta / sigma / psi through this
        path; and because training-mode forward() overwrites the weight on every
        call, optimizer updates to ``conv_layer.weight`` are discarded too.
        Confirm whether these parameters are meant to be learned.
        """
        for i in range(self.conv_layer.out_channels):
            for j in range(self.conv_layer.in_channels):
                # Broadcast this pair's scalar parameters over the kernel grid.
                sigma = self.sigma[i, j].expand_as(self.y) 
                freq = self.freq[i, j].expand_as(self.y) 
                theta = self.theta[i, j].expand_as(self.y) 
                psi = self.psi[i, j].expand_as(self.y) 

                # Rotate the coordinate grid by theta.
                rotx = self.x * torch.cos(theta) + self.y * torch.sin(theta)
                roty = -self.x * torch.sin(theta) + self.y * torch.cos(theta)

                # Gaussian envelope (delta keeps the denominator non-zero) ...
                g = torch.exp(
                    -0.5 * ((rotx ** 2 + roty ** 2) / (sigma + self.delta) ** 2)
                )
                # ... modulated by a cosine carrier along the rotated x axis ...
                g = g * torch.cos(freq * rotx + psi)
                # ... and scaled by 1 / (2*pi*sigma^2) (no delta in this term).
                g = g / (2 * math.pi * sigma ** 2)
                # In-place write through .data: not recorded by autograd.
                self.conv_layer.weight.data[i, j] = g

    def _forward_unimplemented(self, *inputs: Any):
        """
        Override of torch.nn.Module's ``_forward_unimplemented`` placeholder,
        added because code checkers demanded it (the original note suspected a
        PyTorch quirk). It is not reached through normal calls because
        ``forward`` is defined above; if invoked it raises NotImplementedError.
        """
        raise NotImplementedError
        
        






class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 conv -> BatchNorm -> ReLU.

    ``padding=1`` keeps the spatial size unchanged; the first stage maps
    ``in_channels`` to ``out_channels`` and the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Build conv, norm, activation for each stage in order so the Sequential
        # indices (0..5) and parameter initialisation order stay fixed.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Halve the spatial resolution with 2x2 max-pooling, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Upsample the deeper feature map, align it with the skip tensor, concatenate, then DoubleConv.

    Args:
        in_channels: channels after concatenating the skip tensor and the
            upsampled tensor (input channels of the DoubleConv).
        out_channels: output channels of the DoubleConv.
        bilinear: use parameter-free bilinear upsampling (True) or a 2x2
            transposed convolution on ``in_channels // 2`` channels (False).
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, zero-pad it to ``x2``'s H/W, and return conv(cat[x2, x1])."""
        x1 = self.up(x1)
        # Inputs are NCHW. Plain Python ints are all F.pad needs; no tensors
        # have to be allocated on every forward pass.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # Split the difference so x1 is centred (any odd pixel goes right/bottom).
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to per-class scores."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        scores = self.conv(x)
        return scores

class filt_cat(nn.Module):
    """Zero-pad ``x1`` to ``x2``'s spatial size and concatenate ``[x2, x1]`` on dim 1.

    The layer has no parameters; ``in_channels`` and ``out_channels`` are
    accepted but not used.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

    def forward(self, x1, x2):
        """Return ``cat([x2, pad(x1)], dim=1)`` with ``x1`` centred inside ``x2``'s H x W."""
        # Inputs are NCHW. Plain Python ints are all F.pad needs; no tensors
        # have to be allocated on every call.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # Split the difference so x1 is centred (any odd pixel goes right/bottom).
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        return torch.cat([x2, x1], dim=1)



class AAR_Net(nn.Module):
    """U-Net segmenter with a Gabor-filter front end.

    The input passes through one GaborConv2d with 144x144 kernels; ``filt_cat``
    zero-pads that response back to the input size and concatenates it after
    the raw input, giving ``2 * n_channels`` channels for a 4-level U-Net
    encoder/decoder. ``forward`` returns per-pixel class probabilities
    (softmax over dim 1), not raw logits.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output classes.
        bilinear: decoder upsampling by bilinear interpolation (True) or by
            transposed convolution (False).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(AAR_Net, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Gabor front end. With GaborConv2d's defaults (padding=48, stride 1)
        # a 144x144 kernel shrinks each spatial dimension by 47 pixels.
        self.g0 = GaborConv2d(n_channels, n_channels, kernel_size=(144, 144))

        # Pad the Gabor response back to the input size and stack it with the input.
        self.fc = filt_cat(n_channels, 2*n_channels)
        self.inc = DoubleConv(2*n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return class probabilities of shape (N, n_classes, H, W) for input ``x``."""
        f = self.g0(x)
        x0 = self.fc(f, x)
        x1 = self.inc(x0)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        # Same computation as nn.Softmax(dim=1), without constructing a new
        # module on every forward pass.
        return F.softmax(x, dim=1)



Gaussian Filter

In [None]:
def fspecial_gauss(size, sigma, normalize=False):
    """Return a ``size`` x ``size`` Gaussian window with standard deviation ``sigma``.

    Modelled on MATLAB's ``fspecial('gaussian', ...)``, with two differences:

    * Unless ``normalize`` is True the kernel is *not* scaled to unit sum; its
      peak value is exactly 1.0 (at zero offset).
    * Offsets run over the integers ``(-size) // 2 + 1`` .. ``size // 2``. For
      odd ``size`` this is symmetric about the centre pixel. For even ``size``
      the zero offset sits at index ``size // 2 - 1``, half a pixel away from
      MATLAB's (half-integer) centre.

    Args:
        size: side length of the square kernel.
        sigma: standard deviation in pixels.
        normalize: if True, divide by the kernel sum (when positive) so the
            kernel sums to 1, as MATLAB does.

    Returns:
        2-D float ndarray of shape ``(size, size)``.
    """
    # Explicit parentheses: the original `-size//2` already meant (-size)//2.
    lo = (-size) // 2 + 1
    hi = size // 2 + 1
    x, y = np.mgrid[lo:hi, lo:hi]
    g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
    if normalize:
        total = g.sum()
        if total > 0:
            g = g / total
    return g


Test Network

In [None]:
def _window_starts(length, window, stride):
    """Return start offsets of ``window``-sized tiles stepping by ``stride`` over ``length``.

    The final tile is always aligned with the far edge, so every pixel is
    covered and the accumulated blending weight is never zero.

    Raises:
        ValueError: if ``length`` is smaller than one window.
    """
    if length < window:
        raise ValueError(f'image dimension {length} is smaller than the window size {window}')
    starts = list(range(0, length - window + 1, stride))
    if starts[-1] != length - window:
        starts.append(length - window)
    return starts


def _load_ensemble(torch_path, Experiment_Name, Model_Name, tissues, channels, device,
                   n_groups=5, min_epoch=40):
    """Load one AAR_Net per cross-validation group, each at its best validation epoch.

    The epoch is the argmax of row 0 of
    ``Validation_Dice_<Model>_<Experiment>_Group_<g>.npy`` (after squeeze),
    ignoring the first ``min_epoch`` epochs; checkpoint file names use 1-based
    epoch numbers. Returned networks are on ``device`` and in eval mode.
    """
    nets = []
    for g in range(n_groups):
        run_name = Model_Name + '_' + Experiment_Name + '_Group_' + str(g)
        validation = np.load(torch_path + 'Validation_Dice_' + run_name + '.npy').squeeze()
        epoch = np.argmax(validation[0, min_epoch:]) + min_epoch

        net = AAR_Net(n_channels=channels, n_classes=tissues)
        checkpoint = torch_path + run_name + '_Epoch_' + str(epoch + 1) + '.pth'
        net.load_state_dict(torch.load(checkpoint, map_location=device))
        net.to(device=device)
        net.eval()
        nets.append(net)
    return nets


def _predict_mask(img, nets, tissues, weight_kernel, device, window, stride):
    """Return the per-pixel argmax label map for one 2-D image.

    The image is tiled with overlapping ``window`` x ``window`` crops. Each crop
    is min-max normalised to [-0.5, 0.5] and passed through every network as a
    batch of four flips (identity, left-right, up-down, both). The flips are
    undone, each prediction is weighted by ``weight_kernel`` and accumulated,
    and the label is the argmax of the weighted average over all tiles.
    """
    height, width = img.shape[:2]
    collective = np.zeros((tissues, height, width))
    weight = np.zeros((tissues, height, width))

    for row in _window_starts(height, window, stride):
        for col in _window_starts(width, window, stride):
            cropped_img = img[row:row + window, col:col + window].copy()

            # Min-max normalise the window, then centre it on zero.
            lo, hi = np.min(cropped_img), np.max(cropped_img)
            if hi > lo:
                cropped_img = (cropped_img - lo) / (hi - lo)
            else:
                # Constant window: the plain division would produce NaN everywhere.
                cropped_img = np.zeros(cropped_img.shape)
            cropped_img = cropped_img - 0.5

            # Test-time augmentation batch: identity, left-right, up-down, both.
            batch = np.stack([
                cropped_img,
                np.fliplr(cropped_img),
                np.flipud(cropped_img),
                np.fliplr(np.flipud(cropped_img)),
            ])
            batch = np.reshape(batch, (-1, 1, window, window))
            input_img = torch.from_numpy(batch).to(device=device, dtype=torch.float32)

            with torch.no_grad():
                # Sum of the ensemble members' outputs, added in model order.
                output = nets[0](input_img)
                for net in nets[1:]:
                    output = output + net(input_img)
            output = output.cpu().detach().numpy()

            # Undo each flip so all four predictions align with the crop. Each
            # prediction is (tissues, H, W): width is axis 2, height is axis 1.
            # Flipping whole arrays handles any number of tissue classes.
            aligned = (
                output[0],
                output[1][:, :, ::-1],
                output[2][:, ::-1, :],
                output[3][:, ::-1, ::-1],
            )

            # NOTE(review): AAR_Net.forward already returns softmax probabilities,
            # which are summed over the ensemble; applying softmax again here
            # (kept for compatibility with previously produced masks) compresses
            # class differences before tiles are blended -- confirm intent.
            blended = None
            for prediction in aligned:
                term = np.multiply(softmax(prediction, axis=0), weight_kernel)
                blended = term if blended is None else blended + term
            collective[:, row:row + window, col:col + window] += blended
            weight[:, row:row + window, col:col + window] += len(aligned) * weight_kernel

    # Weighted average over overlapping tiles, then the most probable class.
    collective = np.divide(collective, weight)
    return np.argmax(collective, axis=0)


def _mask_to_color(max_mask):
    """Render a label mask as uint8 RGB: class 0 red, class 1 green, class 2 blue, others black."""
    color = np.zeros(max_mask.shape + (3,), dtype='uint8')
    color[max_mask == 0, 0] = 255
    color[max_mask == 1, 1] = 255
    color[max_mask == 2, 2] = 255
    return color


def test_net(python_path, torch_path, img_path, Experiment_Name, Model_Name, tissues, channels, Plot_Figures, Save_Figures,
             *, window=192, stride=8):
    """Segment every image in ``SNP_Net_Images.npy`` with the 5-group AAR-Net ensemble.

    Args:
        python_path: directory prefix (with trailing separator) holding
            ``SNP_Net_Images.npy``; ``AAR_Net_Masks.npy`` is written here.
        torch_path: directory prefix holding the ``Validation_Dice_*.npy``
            files and ``.pth`` checkpoints.
        img_path: directory prefix for the saved ``AAR_<idx>.jpg`` colour masks.
        Experiment_Name, Model_Name: name components of the checkpoint files.
        tissues: number of output classes of the network.
        channels: number of input channels of the network.
        Plot_Figures: show each image and its colour mask with matplotlib.
        Save_Figures: save each colour mask as a JPEG.
        window, stride: sliding-window tile size and step (defaults match the
            original 192 / 8 pixel settings).

    Returns:
        ndarray of per-pixel class labels, one 2-D mask per input image.
    """
    # Use the GPU when available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    SNP = np.load(python_path + 'SNP_Net_Images.npy')

    # Gaussian blending weight that favours the centre of each tile.
    weight_kernel = fspecial_gauss(window, 20)

    nets = _load_ensemble(torch_path, Experiment_Name, Model_Name, tissues, channels, device)

    AAR_Masks = []
    for idx in range(np.shape(SNP)[0]):
        img = SNP[idx, :, :].copy()
        max_mask = _predict_mask(img, nets, tissues, weight_kernel, device, window, stride)

        # All masks so far are rewritten after every image, so partial results
        # survive an interruption.
        AAR_Masks.append(max_mask)
        np.save(python_path + 'AAR_Net_Masks.npy', np.array(AAR_Masks))

        if Plot_Figures or Save_Figures:
            color_new_mask = _mask_to_color(max_mask)

        if Plot_Figures:
            # Original image, then its colour mask, each in a fresh figure
            # (the old `plt.figure;` never called the function).
            plt.figure()
            plt.imshow(img, cmap='gray')
            plt.show()
            plt.figure()
            plt.imshow(color_new_mask)
            plt.show()

        if Save_Figures:
            fname = 'AAR_' + str(idx) + '.jpg'
            save_img = Image.fromarray(color_new_mask)
            save_img.save(img_path + fname)

    return np.array(AAR_Masks)

Main

In [None]:
# Parameters
Experiment_Name = 'Original'  # name component of the checkpoint / validation files
Model_Name = 'AAR_Net'
tissues = 3                   # passed to AAR_Net as n_classes
channels = 1                  # passed to AAR_Net as n_channels
Plot_Figures = True
Save_Figures = True

# Paths -- file names are appended by plain string concatenation, so each
# directory must end with a '/'.
python_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Python/SNP-Net/'
torch_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Torch/AAR-Net/'
img_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Images/AAR-Net/'

# Run the AAR-Net ensemble over every SNP image.
AAR_Masks = test_net(
    python_path=python_path,
    torch_path=torch_path,
    img_path=img_path,
    Experiment_Name=Experiment_Name,
    Model_Name=Model_Name,
    tissues=tissues,
    channels=channels,
    Plot_Figures=Plot_Figures,
    Save_Figures=Save_Figures,
)

Number of Images with Applanation Artifacts

In [None]:
               
# Count images whose predicted mask labels class 0 on at most 95% of pixels;
# those are reported as SNP images with applanation artifacts.
masks_AAR = (np.load(python_path + 'AAR_Net_Masks.npy') == 0) * 1.0

count = 0
for i, mask_AAR in enumerate(masks_AAR):
    # Skip images where class 0 covers more than 95% of the pixels.
    if not np.sum(mask_AAR) > 0.95 * np.size(mask_AAR):
        count += 1

print(f'Number of SNP Images with Applanation Artifacts: {count}')