In [1]:
import torch
import torch.nn as nn
import numpy as np
import os
from torch.utils.data import Dataset
import torch
from PIL import Image
import matplotlib.pyplot as plt
from albumentations.pytorch import ToTensorV2
import albumentations as A
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
from glob import glob
import segmentation_models_pytorch as smp
import torch.nn.functional as F
import cv2
import time 
device =torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
model = torch.load("/home/nipun/Documents/Uni_Malta/LuminEye/LuminEye-Experiments/U2net/U2N2T_WITH_ESRGAN_bacth_8_epoch_300_with_diceLoss/Miche_model_2023_05_02_23:00:18_val_iou0.798.pt")


val_images = "/home/nipun/Documents/Uni_Malta/Datasets/Datasets/Miche/MICHE_MULTICLASS/Dataset/val_img"

val_masks =  "/home/nipun/Documents/Uni_Malta/LuminEye/LuminEye-Experiments/utils/Masks_with_256_val" 
n_classes = 3
batch_size = 4


valid_x = sorted(
        glob(f"{val_images}/*"))
valid_y = sorted(
        glob(f"{val_masks }/*"))

In [3]:
class Iris(Dataset):
    def __init__(self,images,masks,transform = None):
        self.transforms = transform
        self.images = images
        self.masks = masks
    def __len__(self):
        return len(self.images)
    def __getitem__(self,index):
        
        
        # print(self.masks[index])
        image = Image.open(self.images[index])
        img = np.array(image.resize((64,64)))
        
        mask = Image.open(self.masks[index])
        mask = np.array(mask.resize((256,256)))
        
               
        if self.transforms is not None:
            aug = self.transforms(image=img,mask=mask)
            img = aug['image']
            mask = aug['mask']
        return img,mask

In [4]:
def get_images(test_x,test_y,val_transform,batch_size=1,shuffle=True,pin_memory=True):
    
    val_data  = Iris(test_x,test_y,transform =val_transform)
    test_batch = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False,drop_last=True)
    return test_batch


In [5]:
transform = A.Compose([
    A.augmentations.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2()
])



val_batch = get_images(valid_x,valid_y,transform,batch_size=batch_size)

val_cls  = Iris(valid_x,valid_y,transform =transform)



In [6]:
n_classes = 3
batch_size = 1

img_resize = 256

colors = [ [  0,   0,   0],[0,255,0],[0,0,255]]
label_colours = dict(zip(range(n_classes), colors))

valid_classes = [0,85, 170]
class_names = ["Background","Pupil","Iris"]


class_map = dict(zip(valid_classes, range(len(valid_classes))))
n_classes=len(valid_classes)


def decode_segmap(temp):
    #convert gray scale to color
    # temp=temp.numpy()
    r = temp.copy()
    g = temp.copy()
    b = temp.copy()
    for l in range(0, n_classes):
        r[temp == l] = label_colours[l][0]
        g[temp == l] = label_colours[l][1]
        b[temp == l] = label_colours[l][2]

    rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    return rgb

In [7]:
def main():
    
    total_bg = 0
    total_pupil = 0
    total_iris = 0
    total_iou = 0 
    for i in range(len(val_batch)):
        image,mask = val_cls[i]
        
        
        bg_iou,pupil_iou,iris_iou= mean_iou(image=image,mask=mask)
        
        # print(bg_iou,pupil_iou,iris_iou) 
        
        total_bg += bg_iou
        total_pupil += pupil_iou
        total_iris +=iris_iou
        
        # if i ==1:
        #     break
    return total_bg/len(val_batch),total_pupil/len(val_batch),total_iris/len(val_batch)

In [8]:


eps = 1e-10

In [9]:
def mean_iou(image,mask):
    
    image = image.to(device)
    mask = mask.to(device)
    
    
    image = image.unsqueeze(0)
    mask = mask.unsqueeze(0) 
    
    with torch.no_grad():
        
        softmax = nn.Softmax(dim=1)
    
        model_output,_,_,_,_,_,_ = model(image)
        
        predicted_label = F.softmax(model_output,dim=1)
        predicted_label = torch.argmax(predicted_label,dim=1)
        
        predicted_label = predicted_label.contiguous().view(-1) # 65536
        mask = mask.contiguous().view(-1)  # 65536
        
        
        iou_single_class = []
        
        for class_member in range(0,n_classes):
            # print(class_member)
            true_predicted_class = predicted_label==class_member
            true_label = mask == class_member
            
            if true_label.long().sum().item() == 0:
                iou_single_class.append(np.nan)
            
            else:
                intersection = (torch.logical_and(
                    true_predicted_class,
                    true_label
                ).sum().float().item() )
                
                union = (torch.logical_or(
                    true_predicted_class,
                    true_label
                ).sum().float().item())
                
                iou = (intersection +eps)/(union +eps)
                
                iou_single_class.append(iou)
            
    return iou_single_class

In [10]:
total_bg,total_pupil,total_iris = main()



In [11]:
print(total_bg,total_pupil,total_iris)

0.9825123014910376 0.6446349500796045 0.8550768371936728


: 

In [32]:
from custom_model import  U2NET
from scipy.ndimage.morphology import distance_transform_edt as edt
from scipy.ndimage import convolve

  from scipy.ndimage.morphology import distance_transform_edt as edt


In [33]:
class HausdorffDTLoss(nn.Module):
    """Binary Hausdorff loss based on distance transform"""

    def __init__(self, alpha=2.0, **kwargs):
        super(HausdorffDTLoss, self).__init__()
        self.alpha = alpha

    @torch.no_grad()
    def distance_field(self, img: np.ndarray) -> np.ndarray:
        field = np.zeros_like(img)

        for batch in range(len(img)):
            fg_mask = img[batch] > 0.5

            if fg_mask.any():
                bg_mask = ~fg_mask

                fg_dist = edt(fg_mask)
                bg_dist = edt(bg_mask)

                field[batch] = fg_dist + bg_dist

        return field

    def forward(
        self, pred: torch.Tensor, target: torch.Tensor, debug=False
    ) -> torch.Tensor:
        """
        Uses one binary channel: 1 - fg, 0 - bg
        pred: (b, 1, x, y, z) or (b, 1, x, y)
        target: (b, 1, x, y, z) or (b, 1, x, y)
        
        
        """
    
        assert pred.dim() == 4 or pred.dim() == 5, "Only 2D and 3D supported"
        assert (
            pred.dim() == target.dim()
        ), "Prediction and target need to be of same dimension"

        # pred = torch.sigmoid(pred)

        pred_dt = torch.from_numpy(self.distance_field(pred.detach().cpu().numpy())).float()
        target_dt = torch.from_numpy(self.distance_field(target.detach().cpu().numpy())).float()

        pred_error = (pred - target) ** 2
        distance = pred_dt ** self.alpha + target_dt ** self.alpha

        dt_field = pred_error * distance
        loss = dt_field.mean()

        if debug:
            return (
                loss.cpu().numpy(),
                (
                    dt_field.cpu().numpy()[0, 0],
                    pred_error.cpu().numpy()[0, 0],
                    distance.cpu().numpy()[0, 0],
                    pred_dt.cpu().numpy()[0, 0],
                    target_dt.cpu().numpy()[0, 0],
                ),
            )

        else:
            return loss

In [34]:
class HausdorffDistance:
    def hd_distance(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:

        if np.count_nonzero(x) == 0 or np.count_nonzero(y) == 0:
            return np.array([np.Inf])

        indexes = np.nonzero(x)
        distances = edt(np.logical_not(y))
        
        print(distances)
        
        
        print(np.array(np.max(distances[indexes])))

        return np.array(np.max(distances[indexes]))
    
    
    def compute(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # assert (
        #     pred.shape[1] == 1 and target.shape[1] == 1
        # ), "Only binary channel supported"

        pred = (pred > 0.5)
        target = (target > 0.5)

        right_hd = torch.from_numpy(
            self.hd_distance(pred.cpu().numpy(), target.cpu().numpy())  # 1.4142135623730951
        ).float()

        left_hd = torch.from_numpy(
            self.hd_distance(target.cpu().numpy(), pred.cpu().numpy())
        ).float()

        return torch.max(right_hd, left_hd)

In [51]:
image = torch.randn(8,3,256,256)

mask = torch.ones_like(image)

In [52]:
model  = U2NET()

In [53]:
pred_output = model(image)



: 

: 

In [None]:
pred = torch.sigmoid(pred_output[0]) 

In [None]:
loss = HausdorffDTLoss()

In [None]:
loss(pred,mask)

tensor(46.9851, grad_fn=<MeanBackward0>)

In [48]:
disatance = HausdorffDistance()

In [50]:
disatance.compute(pred,mask)

tensor([inf])