In [1]:
import torch
import tensorflow as tf
import scipy as sc
import numpy as np
from PIL import Image
import os
import re
import torch.nn.functional as F
import cv2
import numpy as np
import time
from scipy.io import loadmat, savemat
import torch.nn as nn
from tensorflow.keras.optimizers import Adam
from torch.autograd import Variable
from torchvision.models import resnet
import torchvision.transforms as transforms
from torch.utils.data import Dataset

import matplotlib.pyplot as plt

In [5]:
# Normalise the answer (surrounding whitespace, letter case) so that e.g. ' lp'
# still selects a pipeline: the dispatch further down compares exact codes.
model_preference = input("Type 'LP' for LinkNet + PSPNet segmentation \nType 'PP' for PSPNet + PSPNet segmentation \nType 'V' for vessels segmentation\n").strip().upper()

Type 'LP' for LinkNet + PSPNet segmentation 
Type 'PP' for PSPNet + PSPNet segmentation 
Type 'V' for vessels segmentation
 V


In [6]:
# Trim surrounding whitespace: a stray space would otherwise turn 'all' into a
# non-existent image name. Image names themselves stay case-sensitive.
type_of_segmentation = input("Type 'all' to perform segmentation for all images \nType name of image for single image segmentation\n").strip()

Type 'all' to perform segmentation for all images 
Type name of image for single image segmentation
 vesselsA


# Segmentation pipeline

In [8]:
# Shared PIL -> tensor conversion used by every pipeline below
# (torchvision ToTensor: HWC uint8 image -> CHW float tensor scaled to [0, 1]).
# Inference runs on the CPU: every checkpoint is loaded with map_location='cpu'.
transform = transforms.ToTensor()

class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU whose padding keeps the spatial size
    (for stride 1 and an odd effective kernel size).

    A second, BatchNorm-free path (``convInst``: Conv2d -> ReLU, with its own
    weights) is used when the input is 1x1 spatially, e.g. the global-pool
    branch of ``PyramidPooling``.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution.
        kernel_size: convolution kernel size.
        padding: IGNORED. Padding is always recomputed from ``kernel_size`` and
            ``dilation``; the parameter only exists for signature compatibility.
        stride: convolution stride.
        dilation: convolution dilation.
        bias: whether the convolutions have a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, padding=0, stride=1, dilation=1, bias=False):
        super(ConvBlock, self).__init__()
        # "Same" padding for the dilated kernel: effective size = k + (k - 1) * (d - 1).
        # This overrides whatever `padding` the caller passed.
        padding = (kernel_size + (kernel_size - 1) * (dilation - 1)) // 2
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        # NOTE(review): presumably BN is skipped here because BatchNorm in training
        # mode cannot normalise a single value per channel (batch size 1, 1x1 map) -- confirm.
        self.convInst = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias),
            nn.ReLU()
        )

    def forward(self, x):
        """Apply ``convInst`` to 1x1 inputs and ``conv`` (with BatchNorm) otherwise."""
        if x.shape[-2] == 1 and x.shape[-1] == 1:
            return self.convInst(x)
        return self.conv(x)

    @staticmethod
    def upsample(input, size=None, scale_factor=None, align_corners=False):
        """Bilinear resize of an (N, C, H, W) tensor via ``F.interpolate``.

        Declared static: it was previously an instance method without ``self``,
        so ``block.upsample(t, ...)`` bound the block itself to ``input``.
        """
        return F.interpolate(input, size=size, scale_factor=scale_factor, mode='bilinear', align_corners=align_corners)
        
class PyramidPooling(nn.Module):
    """PSPNet pyramid pooling module.

    For each bin size in ``pooling_size`` (1, 2, 3, 6) the input is
    adaptive-average-pooled, reduced to ``in_channels // 4`` channels by a 1x1
    ``ConvBlock``, and bilinearly resized back to the input resolution. The
    input and the four branch outputs are concatenated along the channel axis.

    Args:
        in_channels: channel count of the incoming feature map.
    """

    def __init__(self, in_channels):
        super(PyramidPooling, self).__init__()
        self.pooling_size = [1, 2, 3, 6]
        self.channels = in_channels // 4

        # Branches are registered as pool1..pool4, in that order, so the module
        # layout matches existing checkpoints.
        for branch_index, bins in enumerate(self.pooling_size, start=1):
            branch = nn.Sequential(
                nn.AdaptiveAvgPool2d(bins),
                ConvBlock(in_channels, self.channels, kernel_size=1),
            )
            setattr(self, 'pool{}'.format(branch_index), branch)

    def forward(self, x):
        """Return ``cat([x, up(pool1(x)), ..., up(pool4(x))], dim=1)``."""
        target_size = x.size()[-2:]
        branches = (self.pool1, self.pool2, self.pool3, self.pool4)
        pyramid = [upsample(branch(x), size=target_size) for branch in branches]
        return torch.cat([x] + pyramid, dim=1)
        
def upsample(input, size=None, scale_factor=None, align_corners=False):
    """Bilinearly resize an (N, C, H, W) tensor.

    Thin wrapper over ``F.interpolate`` with ``mode='bilinear'``; exactly one of
    ``size`` / ``scale_factor`` should be given.
    """
    return F.interpolate(
        input,
        size=size,
        scale_factor=scale_factor,
        mode='bilinear',
        align_corners=align_corners,
    )

class PSPNet(nn.Module):
    """PSPNet segmentation network on a ResNet-34 backbone.

    Pipeline: ResNet-34 stem + layer1..layer4 -> ``PyramidPooling`` over the
    512-channel ``layer4`` features -> 3x3 ``ConvBlock`` + Dropout + 1x1 conv
    to ``n_out_classes`` channels -> bilinear resize back to the input size.

    Args:
        n_classes: output channels of the auxiliary head ``aux`` (not used by ``forward``).
        n_out_classes: output channels of the main prediction.

    NOTE(review): ``forward`` returns raw decoder scores -- ``sigm``/``sftmax``
    are never applied -- yet the inference code in this notebook thresholds them
    at 0.5. Confirm that cut-off matches how the checkpoints were trained.
    """

    def __init__(self, n_classes=64, n_out_classes=3):
        super(PSPNet, self).__init__()
        self.out_channels = 512  # channels of resnet34.layer4 (2048 for resnet50 and deeper)

        # NOTE(review): `pretrained=True` downloads ImageNet weights at construction time and
        # is deprecated in torchvision >= 0.13 in favour of `weights=...`; kept for
        # compatibility. Unpickling a saved PSPNet does not run __init__.
        self.backbone = resnet.resnet34(pretrained=True)
        self.stem = nn.Sequential(
            *list(self.backbone.children())[:4],  # conv1, bn1, relu, maxpool
        )
        self.block1 = self.backbone.layer1
        self.block2 = self.backbone.layer2
        self.block3 = self.backbone.layer3
        self.block4 = self.backbone.layer4

        self.depth = self.out_channels // 4
        # Concatenates its input with four out_channels // 4 branches -> 2 * out_channels channels.
        self.pyramid_pooling = PyramidPooling(self.out_channels)

        self.decoder = nn.Sequential(
            ConvBlock(self.out_channels * 2, self.depth, kernel_size=3),
            nn.Dropout(0.1),
            nn.Conv2d(self.depth, n_out_classes, kernel_size=1),
        )

        # Auxiliary head sized for layer3 features (out_channels // 2 = 256); unused by forward().
        self.aux = nn.Sequential(
            ConvBlock(self.out_channels // 2, self.depth // 2, kernel_size=3),
            nn.Dropout(0.1),
            nn.Conv2d(self.depth // 2, n_classes, kernel_size=1),
        )

        self.sigm = nn.Sigmoid()
        # Channel axis made explicit; an implicit softmax dim is deprecated in PyTorch.
        self.sftmax = nn.Softmax(dim=1)

    def forward(self, x, label=None):
        """Map an (N, 3, H, W) batch to (N, n_out_classes, H, W) raw scores.

        ``label`` is accepted for call compatibility and ignored.
        """
        features = self.stem(x)
        features = self.block1(features)
        features = self.block2(features)
        features = self.block3(features)
        features = self.block4(features)

        out = self.pyramid_pooling(features)
        out = self.decoder(out)
        # Single resize back to the input resolution. A further bilinear resize to the
        # same size (align_corners=True) would be an identity, so none is applied.
        return upsample(out, size=x.size()[-2:])

class BasicBlock(nn.Module):
    """ResNet-style residual block.

    conv(k, stride) -> BN -> ReLU -> conv(k, 1) -> BN, add the (optionally
    projected) shortcut, then ReLU and ``Dropout2d``.

    Args:
        in_planes: input channels.
        out_planes: output channels.
        kernel_size: kernel size of both convolutions.
        stride: stride of the first convolution (the second always uses 1).
        padding: padding of both convolutions.
        groups: groups of both convolutions.
        bias: whether the convolutions have a bias term.
        dropout_rate: channel-dropout probability applied after the residual sum.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, groups=1, bias=False, dropout_rate = 0.0):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=bias)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size, 1, padding, groups=groups, bias=bias)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.dropout = nn.Dropout2d(dropout_rate)
        self.downsample = None
        # Project the shortcut whenever it would not match the main path: a strided
        # block changes the resolution, and a width change alters the channel count.
        # Without the channel check a stride-1 block with in_planes != out_planes
        # fails on the residual addition.
        if stride > 1 or in_planes != out_planes:
            self.downsample = nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                            nn.BatchNorm2d(out_planes),)

    def forward(self, x):
        """Return relu(main(x) + shortcut(x)) followed by Dropout2d."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        out = self.dropout(out)
        return out

class Encoder(nn.Module):
    """LinkNet encoder stage: two ``BasicBlock``s in sequence.

    Only the first block uses ``stride`` and changes the channel count; the
    second keeps ``out_planes`` channels at stride 1. All other options are
    shared by both blocks.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, groups=1, bias=False, dropout_rate = 0):
        super(Encoder, self).__init__()
        shared_options = dict(padding=padding, groups=groups, bias=bias, dropout_rate=dropout_rate)
        self.block1 = BasicBlock(in_planes, out_planes, kernel_size, stride=stride, **shared_options)
        self.block2 = BasicBlock(out_planes, out_planes, kernel_size, stride=1, **shared_options)

    def forward(self, x):
        """Apply block1 then block2."""
        return self.block2(self.block1(x))

class Decoder(nn.Module):
    """LinkNet decoder stage: 1x1 reduce -> transposed-conv upsample -> 1x1 expand.

    Each of the three convolutions is followed by BatchNorm2d and an in-place
    ReLU. The bottleneck width is ``in_planes // 4``.

    Args:
        in_planes: input channels.
        out_planes: output channels.
        kernel_size, stride, padding, output_padding: transposed-conv settings.
        groups: accepted for signature symmetry but unused.
        bias: whether the convolutions have a bias term.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=False):
        # TODO bias=True
        super(Decoder, self).__init__()
        bottleneck = in_planes // 4

        def with_bn_relu(layer, channels):
            # layer -> BatchNorm2d(channels) -> ReLU(inplace); `layer` is built by the caller first.
            return nn.Sequential(layer, nn.BatchNorm2d(channels), nn.ReLU(inplace=True))

        self.conv1 = with_bn_relu(nn.Conv2d(in_planes, bottleneck, 1, 1, 0, bias=bias), bottleneck)
        self.tp_conv = with_bn_relu(
            nn.ConvTranspose2d(bottleneck, bottleneck, kernel_size, stride, padding, output_padding, bias=bias),
            bottleneck,
        )
        self.conv2 = with_bn_relu(nn.Conv2d(bottleneck, out_planes, 1, 1, 0, bias=bias), out_planes)

    def forward(self, x):
        """Apply conv1, tp_conv and conv2 in sequence."""
        return self.conv2(self.tp_conv(self.conv1(x)))

class LinkNetBase(nn.Module):
    """LinkNet-style encoder/decoder that outputs a per-pixel sigmoid map.

    Structure: 7x7 stride-2 stem + 3x3 stride-2 max-pool, four residual
    ``Encoder`` stages, four ``Decoder`` stages whose outputs are added to the
    matching encoder outputs (LinkNet skips), and a transposed-conv classifier
    that restores the input resolution. Encoder/decoder shapes must line up for
    the additions, so inputs whose H and W are multiples of 32 are the safe case.

    Args:
        n_classes: channels of the final sigmoid map.
        num_channels: channels of the input image (1 = grayscale).

    ``encoder5``, ``decoder5``, ``conv3``, ``tp_conv3``, ``lsm`` and ``sftmax``
    are built but never used by ``forward``; presumably they are kept so the
    module layout stays compatible with saved checkpoints.
    """

    def __init__(self, n_classes=1, num_channels = 1):
        super(LinkNetBase, self).__init__()
        # Stem
        self.conv1 = nn.Conv2d(num_channels, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, 2, 1)

        # (in_planes, out_planes, kernel_size, stride, padding) for encoder1..encoder5
        encoder_specs = (
            (64, 64, 3, 1, 1),
            (64, 128, 3, 2, 1),
            (128, 256, 3, 2, 1),
            (256, 512, 3, 2, 1),
            (512, 1024, 3, 2, 1),
        )
        for stage, spec in enumerate(encoder_specs, start=1):
            setattr(self, 'encoder%d' % stage, Encoder(*spec))

        # (in_planes, out_planes, kernel_size, stride, padding, output_padding) for decoder1..decoder5
        decoder_specs = (
            (64, 64, 3, 1, 1, 0),
            (128, 64, 3, 2, 1, 1),
            (256, 128, 3, 2, 1, 1),
            (512, 256, 3, 2, 1, 1),
            (1024, 512, 3, 2, 1, 1),
        )
        for stage, spec in enumerate(decoder_specs, start=1):
            setattr(self, 'decoder%d' % stage, Decoder(*spec))

        # Classifier
        self.tp_conv1 = nn.Sequential(nn.ConvTranspose2d(64, 32, 3, 2, 1, 1),
                                      nn.BatchNorm2d(32),
                                      nn.ReLU(inplace=True),)
        self.conv2 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1),
                                   nn.BatchNorm2d(32),
                                   nn.ReLU(inplace=True),)
        self.tp_conv2 = nn.ConvTranspose2d(32, n_classes, 2, 2, 0)

        # Unused by forward().
        self.conv3 = nn.Sequential(nn.Conv2d(16, 16, 2, 2, 0),
                                   nn.BatchNorm2d(16),
                                   nn.ReLU(inplace=True),)
        self.tp_conv3 = nn.ConvTranspose2d(16, n_classes, 2, 2, 0)

        self.lsm = nn.LogSoftmax(dim=1)
        self.sigm = nn.Sigmoid()
        self.sftmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map an (N, num_channels, H, W) batch to an (N, n_classes, H, W) sigmoid map."""
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Encoder path: keep every stage output for the skip connections.
        skips = [stem]
        for encoder in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            skips.append(encoder(skips[-1]))

        # Decoder path: each stage output is added to the matching encoder output.
        decoded = skips.pop()
        for decoder in (self.decoder4, self.decoder3, self.decoder2, self.decoder1):
            decoded = skips.pop() + decoder(decoded)

        # Classifier back to input resolution, squashed to [0, 1].
        return self.sigm(self.tp_conv2(self.conv2(self.tp_conv1(decoded))))

# Project root: model checkpoints, input images and output folders all live
# under the notebook's working directory.
folder = os.getcwd()


def _load_pickled_model(*relative_parts):
    """Load a whole pickled ``nn.Module`` checkpoint (relative to ``folder``) onto the CPU.

    The checkpoints hold complete model objects (they are ``.eval()``-ed and
    called directly below), so unpickling needs the class definitions above.
    torch >= 2.6 defaults to ``weights_only=True``, which refuses such files,
    so the flag is disabled explicitly when this torch version supports it.

    SECURITY: unpickling can execute arbitrary code -- only load trusted files.
    """
    import inspect

    path = os.path.join(folder, *relative_parts)
    load_kwargs = {'map_location': torch.device('cpu')}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)


linknet_model_disk = _load_pickled_model('Models', 'LinkNet_Disk', 'LinkNet_Disk.pt')
pspnet_model_cup = _load_pickled_model('Models', 'PSPNet_Cup', 'PSPNet_resnet34_pretrained_combined_255.pt')
pspnet_model_disk = _load_pickled_model('Models', 'PSPNet_Disk', 'PSPNet_resnet34_pretrained.pt')
vessels_model = _load_pickled_model('Models', 'LinkNet_Veins', 'LinkNet_Veins.pt')

test_folder = os.path.join(folder, 'Images')
test_results_folder = os.path.join(folder, 'Segmented_Masks')
test_results_overlayed_folder = os.path.join(folder, 'Overlayed_Segmented_Masks')

# cv2.imwrite does not raise when the target directory is missing -- it just
# returns False -- so make sure both output folders exist up front.
for output_dir in (test_results_folder, test_results_overlayed_folder):
    os.makedirs(output_dir, exist_ok=True)

# Regular, non-hidden files only (skips e.g. .DS_Store or .ipynb_checkpoints,
# which Image.open cannot read); sorted for a reproducible processing order.
test_files = sorted(
    name for name in os.listdir(test_folder)
    if not name.startswith('.') and os.path.isfile(os.path.join(test_folder, name))
)
images_test = [os.path.join(test_folder, name) for name in test_files]

# Cut-off used to binarise every model output below.
threshold = 0.5
# Run the selected pipeline ('LP', 'PP' or 'V') either on every file in
# `images_test` (type_of_segmentation == 'all') or on `<type_of_segmentation>.jpg`.
# Masks go to `test_results_folder`, overlays to `test_results_overlayed_folder`,
# both under the input's file name.


def _to_batch(pil_image):
    """Convert a PIL image to a (1, C, H, W) float tensor with the shared ToTensor transform."""
    return transform(pil_image).unsqueeze(0)


def _binarize(model_output):
    """Threshold the first item of a model output batch; returns a bool (H, W, C) array."""
    first = model_output.detach().numpy()[0]
    return np.transpose(first > threshold, (1, 2, 0))


def _linknet_disk_mask(image):
    """Disk mask (uint8 0/1, 256x256) from the LinkNet disk model on a 512x512 grayscale copy."""
    batch = _to_batch(image.resize((512, 512), resample=Image.NEAREST).convert('L'))
    disk = _binarize(linknet_model_disk(batch)).astype(np.uint8)
    disk_image = Image.fromarray(disk.squeeze())
    return np.array(disk_image.resize((256, 256), resample=Image.NEAREST))


def _pspnet_disk_mask(image):
    """Disk mask (uint8 0/1) from the PSPNet disk model on a 256x256 copy."""
    batch = _to_batch(image.resize((256, 256), resample=Image.NEAREST))
    return _binarize(pspnet_model_disk(batch)).squeeze().astype(np.uint8)


def _pspnet_cup_mask(image):
    """Cup mask (uint8 0/1): output channel 2 of the PSPNet cup model on a 256x256 copy."""
    batch = _to_batch(image.resize((256, 256), resample=Image.NEAREST))
    return _binarize(pspnet_model_cup(batch))[:, :, 2].astype(np.uint8)


def _disk_cup_outputs(image, disk, cup):
    """Merged mask (disk=128, cup=255) and the same labels painted over a 256x256 copy of `image`."""
    merged_mask = np.zeros_like(disk)
    merged_mask[disk == 1] = 128
    merged_mask[cup == 1] = 255

    overlay = np.copy(image.resize((256, 256), resample=Image.NEAREST))
    overlay[disk == 1] = 128
    overlay[cup == 1] = 255
    return merged_mask, overlay


def _segment_lp(image):
    """LinkNet disk + PSPNet cup."""
    return _disk_cup_outputs(image, _linknet_disk_mask(image), _pspnet_cup_mask(image))


def _segment_pp(image):
    """PSPNet disk + PSPNet cup."""
    return _disk_cup_outputs(image, _pspnet_disk_mask(image), _pspnet_cup_mask(image))


def _segment_vessels(image):
    """Vessel mask (0/255, replicated to 3 channels) and overlay, both 512x512."""
    batch = _to_batch(image.resize((512, 512), resample=Image.NEAREST).convert('L'))
    vessels = _binarize(vessels_model(batch))
    mask = np.repeat(vessels.astype(np.uint8), 3, axis=2)
    mask[mask == 1] = 255

    overlay = np.copy(image.resize((512, 512), resample=Image.NEAREST))
    overlay[mask == 255] = 255
    return mask, overlay


def _to_cv2_channel_order(array):
    """Reorder RGB(A) channels to the BGR(A) order cv2.imwrite expects; other arrays pass through."""
    if array.ndim == 3 and array.shape[2] in (3, 4):
        reordered = array.copy()
        reordered[:, :, :3] = array[:, :, 2::-1]
        return reordered
    return array


def _write_outputs(filename, mask, overlay):
    """Write `mask` and the (PIL/RGB-ordered) `overlay` under `filename`; raise if a write fails."""
    targets = (
        (os.path.join(test_results_folder, filename), mask),
        (os.path.join(test_results_overlayed_folder, filename), _to_cv2_channel_order(overlay)),
    )
    for path, array in targets:
        # cv2.imwrite reports most failures (e.g. a missing directory) only via its return value.
        if not cv2.imwrite(path, array):
            raise OSError('cv2.imwrite failed for ' + path)


# model_preference -> (per-image segmentation function, models it runs)
_PIPELINES = {
    'LP': (_segment_lp, (linknet_model_disk, pspnet_model_cup)),
    'PP': (_segment_pp, (pspnet_model_disk, pspnet_model_cup)),
    'V': (_segment_vessels, (vessels_model,)),
}
if model_preference not in _PIPELINES:
    raise ValueError("model_preference must be 'LP', 'PP' or 'V', got %r" % (model_preference,))

segment_image, pipeline_models = _PIPELINES[model_preference]
for pipeline_model in pipeline_models:
    pipeline_model.eval()

# Inference only: no autograd graph is needed for either the batch or single-image path.
with torch.no_grad():
    if type_of_segmentation == 'all':
        for image_path in images_test:
            with Image.open(image_path) as image:
                mask, overlay_image_predict = segment_image(image)
            _write_outputs(os.path.basename(image_path), mask, overlay_image_predict)
    else:
        image_name = type_of_segmentation + ".jpg"
        with Image.open(os.path.join(test_folder, image_name)) as orig_image:
            mask, overlay_image_predict = segment_image(orig_image)
        # Preview the overlay (RGB order, as matplotlib expects) for single-image runs.
        plt.imshow(overlay_image_predict)
        _write_outputs(image_name, mask, overlay_image_predict)