<a href="https://colab.research.google.com/github/shireesh-kumar/RVS-UNet/blob/main/TEST_RetinalVesselSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Importing required libraries

In [None]:
#importing required libraries
import os, time
from operator import add
import numpy as np
from glob import glob
import cv2
from tqdm import tqdm
import imageio
import torch
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score
import torch.nn as nn
import random

# Model and utility code (copied from the training notebook)

In [None]:
#Model

class conv_block(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    With kernel_size=3 and padding=1 the spatial size is preserved; only the
    channel count changes from ``in_c`` to ``out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # These attribute names become state_dict keys ("conv1.weight", ...),
        # so they must match the names used when the checkpoint was saved.
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        out = inputs
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        return out

class encoder_block(nn.Module):
    """U-Net encoder stage.

    ``forward`` returns a pair ``(features, pooled)``: the conv_block output
    (kept as the skip connection) and its 2x2 max-pooled copy (passed on to
    the next, coarser stage).
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Attribute names are state_dict keys; keep them unchanged.
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        features = self.conv(inputs)
        return features, self.pool(features)

class decoder_block(nn.Module):
    """U-Net decoder stage: upsample 2x, concatenate the skip tensor, refine.

    ``in_c`` is the channel count of the incoming coarse features. ``out_c``
    is the channel count of the upsampled map, and the skip tensor is expected
    to have ``out_c`` channels as well, since the refining conv_block takes
    ``2 * out_c`` input channels.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # kernel_size=2 with stride=2 doubles height and width.
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(2 * out_c, out_c)

    def forward(self, inputs, skip):
        upsampled = self.up(inputs)
        # dim=1 is the channel axis for the (N, C, H, W) tensors used by Conv2d.
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)

class build_unet(nn.Module):
    """Four-level U-Net mapping a 3-channel image to a 1-channel logit map.

    No activation is applied to the output; the test loop in this notebook
    applies ``torch.sigmoid`` itself. Each encoder stage halves height and
    width and each decoder stage doubles them, so both should be divisible by
    16 for the skip concatenations to line up.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order define the state_dict layout
        # expected by the saved checkpoint; do not rename or reorder.

        # Encoder: 3 -> 64 -> 128 -> 256 -> 512 channels.
        self.e1 = encoder_block(3, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)

        # Bottleneck at the coarsest resolution.
        self.b = conv_block(512, 1024)

        # Decoder: mirrors the encoder back down to 64 channels.
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)

        # 1x1 convolution producing a single logit per pixel.
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        skips = []
        features = inputs
        for stage in (self.e1, self.e2, self.e3, self.e4):
            skip, features = stage(features)
            skips.append(skip)

        features = self.b(features)

        # Deepest decoder stage consumes the deepest skip connection first.
        for stage, skip in zip((self.d1, self.d2, self.d3, self.d4), reversed(skips)):
            features = stage(features, skip)

        return self.outputs(features)

# #Checking the output of the model
# x = torch.randn((2, 3, 512, 512))
# f = build_unet()
# y = f(x)
# print(y.shape)

In [None]:
#Utility Functions
def seeding(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` so cuDNN selects
    deterministic convolution algorithms.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

def create_dir(path):
    """Create directory ``path`` (including missing parents) if it does not exist.

    ``exist_ok=True`` replaces the former ``os.path.exists`` check, which could
    race with another process creating the same directory between the check
    and ``os.makedirs``. If ``path`` exists but is not a directory,
    ``FileExistsError`` is raised instead of silently continuing.
    """
    os.makedirs(path, exist_ok=True)

# Testing

In [None]:
#Testing

def _binarize(values, threshold):
    """Flatten ``values`` to a 1-D uint8 vector: 1 where value > threshold, else 0.

    Accepts a torch tensor on any device, or anything ``np.asarray`` accepts.
    """
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().numpy()
    return (np.asarray(values) > threshold).astype(np.uint8).reshape(-1)


def calculate_metrics(y_true, y_pred, threshold=0.5):
    """Compute pixel-wise binary segmentation metrics for one sample.

    Args:
        y_true: ground-truth mask (torch tensor or array-like), any shape.
        y_pred: predicted probabilities with the same number of elements.
        threshold: values strictly greater than this count as foreground.

    Returns:
        ``[jaccard, f1, recall, precision, accuracy]`` computed over all pixels.
    """
    y_true = _binarize(y_true, threshold)
    y_pred = _binarize(y_pred, threshold)

    # zero_division=0 is the value sklearn's default ("warn") already falls
    # back to when a metric is undefined (e.g. an empty prediction); setting it
    # explicitly avoids an UndefinedMetricWarning for every such sample.
    score_jaccard = jaccard_score(y_true, y_pred, zero_division=0)
    score_f1 = f1_score(y_true, y_pred, zero_division=0)
    score_recall = recall_score(y_true, y_pred, zero_division=0)
    score_precision = precision_score(y_true, y_pred, zero_division=0)
    score_acc = accuracy_score(y_true, y_pred)

    return [score_jaccard, score_f1, score_recall, score_precision, score_acc]

def mask_parse(mask):
    """Turn a single-channel mask of shape (H, W) into a 3-channel (H, W, 3) array.

    The mask is copied into every channel and its dtype is preserved, so it
    can be stacked next to a colour image.
    """
    return np.repeat(np.expand_dims(mask, axis=-1), 3, axis=-1)

def _load_image(path, size):
    """Read a colour image, resize it to ``size`` (width, height) and build the model input.

    Returns:
        ``(image, x)``: the resized uint8 image of shape (H, W, 3), kept for the
        visual report, and a float32 tensor of shape (1, 3, H, W) scaled to [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:  # cv2.imread reports failure by returning None, not raising
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv2.resize(image, size)
    x = np.transpose(image, (2, 0, 1)) / 255.0        # (3, H, W) in [0, 1]
    x = np.expand_dims(x, axis=0).astype(np.float32)  # (1, 3, H, W)
    return image, torch.from_numpy(x)


def _load_mask(path, size):
    """Read a grayscale mask, resize it to ``size`` (width, height) and build the target.

    Returns:
        ``(mask, y)``: the resized uint8 mask of shape (H, W) and a float32
        tensor of shape (1, 1, H, W) scaled to [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {path}")
    mask = cv2.resize(mask, size)
    y = (mask / 255.0).reshape(1, 1, *mask.shape).astype(np.float32)
    return mask, torch.from_numpy(y)


def _timed_predict(model, x, device):
    """Run ``model`` and a sigmoid on ``x``; return ``(probabilities, seconds)``.

    CUDA kernels run asynchronously, so the device is synchronised before each
    clock read. Otherwise only the kernel-launch time would be measured and
    the reported FPS would be inflated.
    """
    is_cuda = device.type == "cuda"
    if is_cuda:
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    probs = torch.sigmoid(model(x))
    if is_cuda:
        torch.cuda.synchronize(device)
    return probs, time.perf_counter() - start


if __name__ == "__main__":
    # Seeding
    seeding(42)

    # Folders
    create_dir("results")

    # Load dataset: images and masks are paired purely by sorted filename order.
    test_x = sorted(glob("/content/drive/MyDrive/FIVES A Fundus Image Dataset for AI-based Vessel Segmentation/test/Original/*.png"))
    test_y = sorted(glob("/content/drive/MyDrive/FIVES A Fundus Image Dataset for AI-based Vessel Segmentation/test/Ground truth/*.png"))
    if not test_x:
        raise FileNotFoundError("No test images found; check the dataset path.")
    if len(test_x) != len(test_y):
        # zip() would otherwise silently drop or mispair samples.
        raise ValueError(
            f"Found {len(test_x)} images but {len(test_y)} masks; "
            "every image needs exactly one mask."
        )

    # Hyperparameters
    H = 256
    W = 256
    size = (W, H)  # cv2.resize takes (width, height)
    checkpoint_path = "files/checkpoint.pth"

    # Load the checkpoint
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = build_unet().to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    metrics_score = [0.0, 0.0, 0.0, 0.0, 0.0]  # running sums: jaccard, f1, recall, precision, acc
    time_taken = []

    for image_path, mask_path in tqdm(zip(test_x, test_y), total=len(test_x)):
        # splitext keeps names containing extra dots intact (split(".") did not).
        name = os.path.splitext(os.path.basename(image_path))[0]

        image, x = _load_image(image_path, size)
        mask, y = _load_mask(mask_path, size)
        x = x.to(device)
        y = y.to(device)

        with torch.no_grad():
            pred_y, elapsed = _timed_predict(model, x, device)
        time_taken.append(elapsed)

        score = calculate_metrics(y, pred_y)
        metrics_score = list(map(add, metrics_score, score))

        pred_mask = (pred_y[0, 0].cpu().numpy() > 0.5).astype(np.uint8)  # (H, W), values 0/1

        # Report panel: input | grey bar | ground truth | grey bar | prediction.
        # The separator is uint8 so the whole panel stays uint8 (it was float64).
        separator = np.full((size[1], 10, 3), 128, dtype=np.uint8)
        panel = np.concatenate(
            [image, separator, mask_parse(mask), separator, mask_parse(pred_mask) * 255],
            axis=1,
        )
        out_path = os.path.join("results", f"{name}.png")
        if not cv2.imwrite(out_path, panel):  # imwrite returns False instead of raising
            raise OSError(f"Failed to write {out_path}")

    num_samples = len(test_x)
    jaccard, f1, recall, precision, acc = (total / num_samples for total in metrics_score)
    print(f"Jaccard: {jaccard:1.4f} - F1: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f}")

    fps = 1/np.mean(time_taken)
    print("FPS: ", fps)

100%|██████████| 200/200 [07:45<00:00,  2.33s/it]

Jaccard: 0.5653 - F1: 0.7182 - Recall: 0.6712 - Precision: 0.7806 - Acc: 0.9630
FPS:  0.5017865851886366





In [None]:
#Downloading the results
# !zip -r /content/results.zip /content/results