<a href="https://colab.research.google.com/github/shireesh-kumar/RVS-UNet/blob/main/TRAIN_RetinalVesselSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Code adapted from https://github.com/nikhilroxtomar/Retina-Blood-Vessel-Segmentation-in-PyTorch/tree/main
# Dataset: FIVES retinal vessel segmentation dataset

# Import required libraries
import os
import numpy as np
import cv2
import random
from glob import glob

import torch
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_tensor


from sklearn.model_selection import train_test_split
from PIL import Image

In [None]:
#Constants
H = 256
W = 256
batch_size = 8
num_epochs = 20
lr = 1e-4
checkpoint_path = "files/checkpoint.pth"

# Use the GPU when one is visible, otherwise fall back to the CPU. A hard-coded
# torch.device('cuda') makes the torch.zeros(..., device=device) calls below
# raise on a CPU-only runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pre-allocated device buffers that the DataLoader collate function copies each
# batch into, instead of allocating new tensors per batch.
# Shapes: (batch_size, 3, H, W) for images and (batch_size, 1, H, W) for masks,
# matching the per-sample tensors produced by DriveDataset.__getitem__.
inputs_preallocated = torch.zeros(batch_size, 3, H, W, device=device, dtype=torch.float32)
labels_preallocated = torch.zeros(batch_size, 1, H, W, device=device, dtype=torch.float32)

# Root folder of the FIVES dataset on the mounted Google Drive (Colab).
data_root = "/content/drive/MyDrive/FIVES A Fundus Image Dataset for AI-based Vessel Segmentation"
# Sorting both lists pairs image i with mask i. This assumes images and ground
# truths share file names -- TODO confirm for the FIVES folder layout.
train_x = sorted(glob(os.path.join(data_root, "train", "Original", "*.png")))
train_y = sorted(glob(os.path.join(data_root, "train", "Ground truth", "*.png")))


In [None]:
#Dataset Custom Class

class DriveDataset(Dataset):
    """Paired fundus-image / vessel-mask dataset read from disk with OpenCV.

    Each item is an ``(image, mask)`` pair of float32 tensors:

    * image: shape (3, H, W), pixel values divided by 255 (so in [0, 1]),
      in OpenCV's BGR channel order as returned by ``cv2.imread``.
    * mask: shape (1, H, W), pixel values divided by 255 (0/1 if the ground
      truth is stored as 0/255 -- assumed, not checked here).

    Args:
        images_path: list of image file paths.
        masks_path: list of mask file paths, aligned index-by-index with
            ``images_path``.

    Raises:
        ValueError: if the two path lists have different lengths.
    """

    def __init__(self, images_path, masks_path):
        if len(images_path) != len(masks_path):
            raise ValueError(
                f"Got {len(images_path)} images but {len(masks_path)} masks; "
                "every image needs exactly one mask."
            )
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    @staticmethod
    def _read(path, flags):
        """Load *path* with ``cv2.imread``, raising instead of returning None."""
        data = cv2.imread(path, flags)
        if data is None:
            # cv2.imread reports missing/unreadable files by returning None,
            # which would otherwise fail later inside cv2.resize.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return data

    def __getitem__(self, index):
        image = self._read(self.images_path[index], cv2.IMREAD_COLOR)
        # cv2.resize expects dsize as (width, height).
        image = cv2.resize(image, (W, H), interpolation=cv2.INTER_AREA)
        image = image / 255.0                        # (H, W, 3)
        image = np.transpose(image, (2, 0, 1))       # (3, H, W)
        image = torch.from_numpy(image.astype(np.float32))

        mask = self._read(self.masks_path[index], cv2.IMREAD_GRAYSCALE)
        # Nearest-neighbour resizing keeps the original label values (no blending).
        mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)
        mask = mask / 255.0                          # (H, W)
        mask = np.expand_dims(mask, axis=0)          # (1, H, W)
        mask = torch.from_numpy(mask.astype(np.float32))

        return image, mask

    def __len__(self):
        return self.n_samples

In [None]:
#Loss Function
class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy and soft Dice loss for binary segmentation.

    ``forward`` takes raw logits (no sigmoid applied) and targets with the same
    number of elements and values in [0, 1]. Both terms are computed over all
    elements of the batch flattened together.

    Args:
        weight: accepted for API compatibility; currently unused.
        size_average: accepted for API compatibility; currently unused.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE + (1 - Dice)`` as a scalar tensor.

        Args:
            inputs: model logits of any shape.
            targets: ground truth with the same number of elements as ``inputs``.
            smooth: additive smoothing in the Dice ratio; avoids 0/0 on empty masks.
        """
        # reshape (not view) so non-contiguous tensors are accepted as well.
        logits = inputs.reshape(-1)
        targets = targets.reshape(-1).to(logits.dtype)
        probs = torch.sigmoid(logits)

        intersection = (probs * targets).sum()
        dice_loss = 1 - (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        # Computing BCE directly from logits uses the log-sum-exp trick. This is
        # numerically stabler than sigmoid + binary_cross_entropy for large |logits|.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='mean')

        return bce + dice_loss

In [None]:
#Utility Functions
def seeding(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # In benchmark mode cuDNN picks convolution algorithms by timing them. The
    # choice can vary between runs, which undermines the deterministic flag above.
    torch.backends.cudnn.benchmark = False

def create_dir(path):
    """Create directory *path*, including missing parents, if it does not exist.

    Args:
        path: directory to create; an existing directory is left untouched.
    """
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(path, exist_ok=True)

In [None]:
#Model

class conv_block(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    Height and width are preserved (3x3 kernels with padding 1).

    Args:
        in_c: number of input channels.
        out_c: number of output channels of both convolutions.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Registration order is kept as-is: it fixes both the state_dict key
        # order and the order in which parameters draw from the RNG at init.
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        # One parameter-free ReLU module shared by both stages.
        self.relu = nn.ReLU()

    def forward(self, inputs):
        first_stage = self.relu(self.bn1(self.conv1(inputs)))
        return self.relu(self.bn2(self.conv2(first_stage)))

class encoder_block(nn.Module):
    """U-Net encoder stage: a conv_block followed by 2x2 max pooling.

    ``forward`` returns ``(features, pooled)``. ``features`` is the
    full-resolution conv_block output, which build_unet keeps as a skip
    connection. ``pooled`` is its 2x downsampled version, fed to the next stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        # Stride defaults to the kernel size, so this halves height and width.
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, inputs):
        features = self.conv(inputs)
        return features, self.pool(features)

class decoder_block(nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsampling, skip concat, conv_block.

    Args:
        in_c: channels of the incoming lower-resolution feature map.
        out_c: channels after upsampling. The skip tensor must also have
            ``out_c`` channels, because conv_block receives ``2 * out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(out_c + out_c, out_c)

    def forward(self, inputs, skip):
        x = self.up(inputs)
        # An odd height/width at the matching encoder stage is floored by max
        # pooling, so the upsampled map comes out one pixel smaller than the skip.
        # Pad it to the skip's size (extra pixel on bottom/right) so torch.cat
        # works. When sizes already match (e.g. the 256x256 inputs used here),
        # this is a no-op.
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x, skip], dim=1)  # concatenate along channels
        return self.conv(x)

class build_unet(nn.Module):
    """Four-level U-Net mapping a 3-channel image to a 1-channel logit map.

    Encoder widths are 64-128-256-512 with a 1024-channel bottleneck. The
    decoder mirrors them back down to 64, and a 1x1 convolution produces the
    output. No sigmoid is applied here, so the output is raw logits. Height and
    width divisible by 16 make the four pooled/upsampled resolutions line up
    with the skip connections.
    """

    def __init__(self):
        super().__init__()

        # Attribute names (e1..e4, b, d1..d4, outputs) are the checkpoint's
        # state_dict prefixes and must not change. Registration order is kept too.

        # Encoder
        self.e1 = encoder_block(3, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)

        # Bottleneck
        self.b = conv_block(512, 1024)

        # Decoder
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)

        # 1x1 classifier producing a single logit channel
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        skips = []
        x = inputs
        for encoder in (self.e1, self.e2, self.e3, self.e4):
            skip, x = encoder(x)
            skips.append(skip)

        x = self.b(x)

        # The deepest skip (from e4) is consumed by the first decoder (d1).
        for decoder, skip in zip((self.d1, self.d2, self.d3, self.d4), reversed(skips)):
            x = decoder(x, skip)

        return self.outputs(x)

# Checking the output of the model
# x = torch.randn((2, 3, 512, 512))
# f = build_unet()
# y = f(x)
# print(y.shape)

In [None]:
#Training
def train(model, loader, optimizer, loss_fn, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network to optimize; switched to train mode.
        loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        device: device each batch is moved to before the forward pass
            (``Tensor.to`` returns the tensor itself if it is already there).

    Returns:
        Mean of ``loss.item()`` over the batches of this epoch. Each loss is
        measured before its optimizer step.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train(): loader produced no batches")
    return total_loss / num_batches

def evaluate(model, loader, loss_fn, device):
    """Return the mean per-batch loss of *model* on *loader* without updating it.

    Args:
        model: network to evaluate; switched to eval mode.
        loader: iterable of ``(inputs, targets)`` batches.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        device: device each batch is moved to before the forward pass
            (``Tensor.to`` returns the tensor itself if it is already there).

    Returns:
        Mean of ``loss.item()`` over all batches, computed under ``torch.no_grad``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            y_pred = model(x)
            total_loss += loss_fn(y_pred, y).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate(): loader produced no batches")
    return total_loss / num_batches


if __name__ == "__main__":

    seeding(42)
    create_dir("files")

    # Fail early with a clear message when the dataset glob found nothing
    # (e.g. Google Drive not mounted). train_test_split would otherwise raise an
    # unhelpful error on empty lists.
    if not train_x:
        raise FileNotFoundError("No training images found; check the dataset path and Drive mount.")
    if len(train_x) != len(train_y):
        raise ValueError(f"Found {len(train_x)} images but {len(train_y)} masks.")

    train_x, valid_x, train_y, valid_y = train_test_split(train_x, train_y, test_size=0.2, random_state=42)

    def custom_collate_fn(batch):
        """Copy a list of (image, mask) samples into the preallocated device buffers.

        Returns views of only the first ``len(batch)`` rows. A final, smaller
        batch therefore does not carry stale samples left over from the previous
        batch. The returned tensors alias shared buffers that the next call
        overwrites. This is safe here only because num_workers=0 and each batch
        is fully consumed before the next one is collated.
        """
        n = len(batch)
        for i, (image, mask) in enumerate(batch):
            inputs_preallocated[i].copy_(image)
            labels_preallocated[i].copy_(mask)
        return inputs_preallocated[:n], labels_preallocated[:n]

    train_dataset = DriveDataset(train_x, train_y)
    valid_dataset = DriveDataset(valid_x, valid_y)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=custom_collate_fn
    )

    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=custom_collate_fn
    )

    model = build_unet()
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
    loss_fn = DiceBCELoss()

    best_valid_loss = float("inf")

    for epoch in range(num_epochs):

        train_loss = train(model, train_loader, optimizer, loss_fn, device)
        valid_loss = evaluate(model, valid_loader, loss_fn, device)
        # ReduceLROnPlateau only adjusts the learning rate when it is stepped
        # with the monitored metric; without this call it has no effect.
        scheduler.step(valid_loss)

        if valid_loss < best_valid_loss:
            data_str = f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}"
            print(data_str)

            best_valid_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)
        data_str_metric = ''
        data_str_metric += f'\tTrain Loss: {train_loss:.3f}\n'
        data_str_metric += f'\t Val. Loss: {valid_loss:.3f}\n'
        print(data_str_metric)

Valid loss improved from inf to 1.1974. Saving checkpoint: files/checkpoint.pth
	Train Loss: 1.217
	 Val. Loss: 1.197

Valid loss improved from 1.1974 to 0.9612. Saving checkpoint: files/checkpoint.pth
	Train Loss: 1.018
	 Val. Loss: 0.961

Valid loss improved from 0.9612 to 0.8897. Saving checkpoint: files/checkpoint.pth
	Train Loss: 0.926
	 Val. Loss: 0.890

Valid loss improved from 0.8897 to 0.8251. Saving checkpoint: files/checkpoint.pth
	Train Loss: 0.852
	 Val. Loss: 0.825

Valid loss improved from 0.8251 to 0.7688. Saving checkpoint: files/checkpoint.pth
	Train Loss: 0.791
	 Val. Loss: 0.769

Valid loss improved from 0.7688 to 0.7154. Saving checkpoint: files/checkpoint.pth
	Train Loss: 0.735
	 Val. Loss: 0.715

Valid loss improved from 0.7154 to 0.6658. Saving checkpoint: files/checkpoint.pth
	Train Loss: 0.687
	 Val. Loss: 0.666

Valid loss improved from 0.6658 to 0.6311. Saving checkpoint: files/checkpoint.pth
	Train Loss: 0.640
	 Val. Loss: 0.631

Valid loss improved from 0.