In [130]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt


In [131]:
import os
from glob import glob
from tqdm import tqdm
import imageio
from albumentations import HorizontalFlip,VerticalFlip,ElasticTransform,GridDistortion,OpticalDistortion

import cv2


In [132]:


def load_data(path):
    """ X = Images and Y = masks """

    train_x = sorted(glob(os.path.join(path, "train", "train", "*.jpg")))
    train_y = sorted(glob(os.path.join(path, "train_masks", "train_masks", "*.gif")))

    test_x = sorted(glob(os.path.join(path, "validation_car", "*.jpg")))
    test_y = sorted(glob(os.path.join(path, "validation_mask", "*.gif")))

    return (train_x, train_y), (test_x, test_y)


(train_x, train_y), (test_x, test_y)=load_data('/content/drive/MyDrive/Segmentation/car_unet')

print(len(train_x),len(train_y),len(test_x),len(test_y))


1476 1476 72 72


In [133]:

# def create_dir(path):
#     if not os.path.exists(path):
#         os.makedirs(path)

In [134]:
import imageio
def augment_data(images, masks, save_path, augment=True):
    H = 512
    W = 512

    for idx, (x, y) in tqdm(enumerate(zip(images, masks)), total=len(images)):
        # """ Extracting names """
        name = x.split("/")[-1].split(".")[0]

        # """ Reading image and mask """
        x = cv2.imread(x, cv2.IMREAD_COLOR)
       # y = cv2.imread(y,cv2.IMREAD_COLOR)
        y = imageio.mimread(y)[0]

        # if augment == True:
        #     aug = HorizontalFlip(p=1.0)
        #     augmented = aug(image=x, mask=y)
        #     x1 = augmented["image"]
        #     y1 = augmented["mask"]

        #     aug = VerticalFlip(p=1.0)
        #     augmented = aug(image=x, mask=y)
        #     x2 = augmented["image"]
        #     y2 = augmented["mask"]

        #     aug = ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03)
        #     augmented = aug(image=x, mask=y)
        #     x3 = augmented['image']
        #     y3 = augmented['mask']

        #     aug = GridDistortion(p=1)
        #     augmented = aug(image=x, mask=y)
        #     x4 = augmented['image']
        #     y4 = augmented['mask']

        #     aug = OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)
        #     augmented = aug(image=x, mask=y)
        #     x5 = augmented['image']
        #     y5 = augmented['mask']

        #     X = [x1, x2, x3, x4, x5]
        #     Y = [y1, y2, y3, y4, y5]

        # else:

        X = [x]    
        Y = [y]

        index = 0
        for i, m in zip(X, Y):
            i = cv2.resize(i, (W, H))
            m = cv2.resize(m, (W, H))

            if len(X) == 1:
                tmp_image_name = f"{name}.jpg"
                tmp_mask_name = f"{name}.jpg"
            else:
                tmp_image_name = f"{name}_{index}.jpg"
                tmp_mask_name = f"{name}_{index}.jpg"

            image_path = os.path.join(save_path, "validation_car", tmp_image_name)
            mask_path = os.path.join(save_path, "validation_mask", tmp_mask_name)

            cv2.imwrite(image_path, i)
            cv2.imwrite(mask_path, m)

            index += 1


In [135]:

#augment_data(train_x, train_y, "/content/drive/MyDrive/Segmentation/car_unet/", augment=True)
#augment_data(test_x, test_y, "/content/drive/MyDrive/Segmentation/retina_unet/new_data/test/", augment=False)

In [136]:
#augment_data(test_x, test_y, "/content/drive/MyDrive/Segmentation/car_unet/", augment=True)

In [137]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
    
        print(len(self.downs))
        print(len(self.ups))
    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)


In [138]:
model=UNET()


4
8


# Creating class dataset

In [139]:

import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
class retina_dataset(Dataset):

  def __init__(self,images_path,masks_path):
     
   # images=os.listdir(image_path)
    self.images_path=images_path
    self.masks_path=masks_path
    self.n_samples=len(images_path)


  def __getitem__(self,index):
        """ Reading image """
        image = cv2.imread(self.images_path[index], cv2.IMREAD_COLOR)
        image = image/255.0 ## (512, 512, 3)
        image = np.transpose(image, (2, 0, 1))  ## (3, 512, 512)
        image = image.astype(np.float32)
        image = torch.from_numpy(image)

        """ Reading mask """
        mask = cv2.imread(self.masks_path[index], cv2.IMREAD_GRAYSCALE)
        mask = mask/255.0   ## (512, 512)
        mask = np.expand_dims(mask, axis=0) ## (1, 512, 512)
        mask = mask.astype(np.float32)
        mask = torch.from_numpy(mask)

        return image,mask
  def __len__(self):
    return self.n_samples     



# losses

In [140]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)


        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

        return 1 - dice

class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss

        return Dice_BCE


# training 

In [141]:
""" Load dataset """
train_x = sorted(glob("/content/drive/MyDrive/Segmentation/retina_unet/new_data/train/image/*"))
train_y = sorted(glob("/content/drive/MyDrive/Segmentation/retina_unet/new_data/train/mask/*"))


test_x = sorted(glob("/content/drive/MyDrive/Segmentation/retina_unet/new_data/test/image/*"))
test_y = sorted(glob("/content/drive/MyDrive/Segmentation/retina_unet/new_data/test/mask/*"))

data_str = f"Dataset Size:\nTrain: {len(train_x)} - Valid: {len(test_x)}\n"
print(data_str)

Dataset Size:
Train: 120 - Valid: 20



In [142]:

""" Dataset and loader """
train_dataset = retina_dataset(train_x, train_y)
valid_dataset = retina_dataset(test_x, test_y)

from torch.utils.data import DataLoader
batch_size=2
train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )

valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2
    )


In [143]:
device = torch.device('cuda')   ## GTX 1060 6GB
model = UNET()
model = model.to(device)
lr=1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn = DiceBCELoss()

print(device)


4
8
cuda


# training part of the UNET architecture

In [144]:
def train(model, loader, optimizer, loss_fn, device):
    epoch_loss = 0.0

    model.train()
    loop=tqdm(loader)
    for x, y in loop:
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss = epoch_loss/len(loader)
    return epoch_loss

def evaluate(model, loader, loss_fn, device):
    epoch_loss = 0.0
    loop=tqdm(loader)
    model.eval()
    with torch.no_grad():
        for x, y in loop:
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()

        epoch_loss = epoch_loss/len(loader)
    return epoch_loss    

In [145]:
# epochs=50

# """ Calculate the time taken """
# def epoch_time(start_time, end_time):
#     elapsed_time = end_time - start_time
#     elapsed_mins = int(elapsed_time / 60)
#     elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
#     return elapsed_mins, elapsed_secs
# checkpoint_path='/content/drive/MyDrive/Segmentation/retina_unet/checkpoint.pth'
# best_valid_loss = float("inf")
# import time

# for epoch in range(epochs):
#         start_time = time.time()

#         train_loss = train(model, train_loader, optimizer, loss_fn, device)
#         valid_loss = evaluate(model, valid_loader, loss_fn, device)
#         """ Saving the model """
#         if valid_loss < best_valid_loss:
#             data_str = f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}"
#             print(data_str)

#             best_valid_loss = valid_loss
#             torch.save(model.state_dict(), checkpoint_path)

#         end_time = time.time()
#         epoch_mins, epoch_secs = epoch_time(start_time, end_time)

#         data_str = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n'
#         data_str += f'\tTrain Loss: {train_loss:.3f}\n'
#         data_str += f'\t Val. Loss: {valid_loss:.3f}\n'
#         print(data_str)

# Evaluation of the UNET

In [146]:
# test x and test y are there for the test purpose 
import torch
import numpy as np
import cv2
import tqdm
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score
from operator import add


In [154]:
# load model weights with which we will test

import time
def calculate_metric(y_true,y_pred):
  # ground truth
  y_true=y_true.cpu().numpy()
  y_true=y_true>0.5
  y_true = y_true.astype(np.uint8)
  y_true = y_true.reshape(-1)

  # prediction

  y_pred=y_pred.cpu().numpy()
  y_pred=y_pred>0.5
  y_pred=y_pred.astype(np.int8)
  y_pred=y_pred.reshape(-1)
  
  # calculations of 
  score_jaccard = jaccard_score(y_true, y_pred)
  score_f1 = f1_score(y_true, y_pred)
  score_recall = recall_score(y_true, y_pred)
  score_precision = precision_score(y_true, y_pred)
  score_acc = accuracy_score(y_true, y_pred)

  return [score_jaccard, score_f1, score_recall, score_precision, score_acc]



def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)    ## (512, 512, 1)
    mask = np.concatenate([mask, mask, mask], axis=-1)  ## (512, 512, 3)
    return mask

checkpoint_path = "/content/drive/MyDrive/Segmentation/retina_unet/checkpoint.pth"

""" Load the checkpoint """
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNET()
model = model.to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
#model.eval()    



4
8


<All keys matched successfully>

In [155]:
model.eval()
metrics_score = [0.0, 0.0, 0.0, 0.0, 0.0]
time_taken = []
import tqdm
for i, (x, y) in (enumerate(zip(test_x, test_y))):
        """ Extract the name """
        name = x.split("/")[-1].split(".")[0]

        """ Reading image """
        image = cv2.imread(x, cv2.IMREAD_COLOR) ## (512, 512, 3)
        ## image = cv2.resize(image, size)
        x = np.transpose(image, (2, 0, 1))      ## (3, 512, 512)
        x = x/255.0
        x = np.expand_dims(x, axis=0)           ## (1, 3, 512, 512)
        x = x.astype(np.float32)
        x = torch.from_numpy(x)
        x = x.to(device)

        """ Reading mask """
        mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)  ## (512, 512)
        ## mask = cv2.resize(mask, size)
        y = np.expand_dims(mask, axis=0)            ## (1, 512, 512)
        y = y/255.0
        y = np.expand_dims(y, axis=0)               ## (1, 1, 512, 512)
        y = y.astype(np.float32)
        y = torch.from_numpy(y)
        y = y.to(device)

        with torch.no_grad():
            """ Prediction and Calculating FPS """
            start_time = time.time()
            pred_y = model(x)
            pred_y = torch.sigmoid(pred_y)
            total_time = time.time() - start_time
            time_taken.append(total_time)


            score = calculate_metric(y, pred_y)
            metrics_score = list(map(add, metrics_score, score))
            pred_y = pred_y[0].cpu().numpy()        ## (1, 512, 512)
            pred_y = np.squeeze(pred_y, axis=0)     ## (512, 512)
            pred_y = pred_y > 0.5
            pred_y = np.array(pred_y, dtype=np.uint8)

        """ Saving masks """
        ori_mask = mask_parse(mask)
        pred_y = mask_parse(pred_y)
        line = np.ones((512, 10, 3)) * 128

        cat_images = np.concatenate(
            [image, line, ori_mask, line, pred_y * 255], axis=1
        )
        cv2.imwrite(f"/content/drive/MyDrive/Segmentation/retina_unet/results/{name}.png", cat_images)

jaccard = metrics_score[0]/len(test_x)
f1 = metrics_score[1]/len(test_x)
recall = metrics_score[2]/len(test_x)
precision = metrics_score[3]/len(test_x)
acc = metrics_score[4]/len(test_x)
print(f"Jaccard: {jaccard:1.4f} - F1: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f}")

fps = 1/np.mean(time_taken)
print("FPS: ", fps)



Jaccard: 0.6674 - F1: 0.8003 - Recall: 0.7918 - Precision: 0.8144 - Acc: 0.9657
FPS:  192.96355167059636
