In [None]:
import pdb
import cv2
import torch
import os, glob
import numpy as np
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader, Dataset
from segmentation_models_pytorch.losses import DiceLoss

In [None]:
def check_if_noise_isPresent(mask):
    """Decide whether ``mask`` should go through the small-blob removal step.

    Returns True when the mask has exactly one external blob, or when its
    blobs stay separate even after a 5x5 dilation (i.e. they are not merely
    fragments of one object split by a thin gap). Returns False when there are
    no blobs at all, or when several blobs fuse into one after dilation.

    NOTE(review): the single-blob case returns True although there is nothing
    to remove; ``new_remove_noise`` then removes nothing, so the outcome is
    harmless, but the function name is misleading for that case.
    """
    def _external_blobs(img):
        contours, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return contours

    if len(_external_blobs(mask.copy())) == 1:
        return True

    # Grow each blob by 2 px on every side so fragments separated by a gap of
    # a few pixels fuse into one component.
    structuring_element = np.ones((5, 5), np.uint8)
    grown_mask = cv2.dilate(mask.copy(), structuring_element, iterations=1)
    return len(_external_blobs(grown_mask)) > 1

def new_remove_noise(mask, threshold=0.1):
    """Erase blobs that are small relative to the largest blob in ``mask``.

    Args:
        mask: single-channel uint8 mask (0 = background, non-zero =
            foreground) — assumed, as ``cv2.findContours`` requires.
        threshold: a blob is dropped when its contour area is below
            ``threshold`` times the area of the largest blob.

    Returns:
        ``mask`` itself when ``check_if_noise_isPresent`` reports nothing to
        clean or no contour is found; otherwise a new mask with the small
        blobs subtracted out.
    """
    if not check_if_noise_isPresent(mask):
        return mask

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return mask

    areas = [cv2.contourArea(contour) for contour in contours]
    cutoff = max(areas) * threshold

    # Paint every small blob solid white on an "eraser" image, then subtract it
    # (saturating at 0) from the original mask.
    eraser = np.zeros_like(mask)
    for contour, area in zip(contours, areas):
        if area < cutoff:
            cv2.drawContours(eraser, [contour], -1, 255, -1)

    return cv2.subtract(mask, eraser)


def get_area_of_leaf(orginal_image_mask, binary_threshold=50, max_center_fraction=0.25):
    """Find the largest blob whose centroid lies near the image centre.

    The input is binarized (pixels > ``binary_threshold`` become foreground),
    external contours are extracted, and only contours whose centroid lies
    within ``max_center_fraction * min(width, height)`` pixels of the image
    centre are considered.

    Args:
        orginal_image_mask: uint8 mask or image. Input that is not 2-D is
            treated as BGR and converted to grayscale. The array is NOT
            modified.
        binary_threshold: grey level above which a pixel counts as foreground.
        max_center_fraction: radius of the accepted centre region, as a
            fraction of the shorter image side.

    Returns:
        ``(area, x, y, w, h)``: contour area of the largest central blob and
        its axis-aligned bounding box, or ``(-1, 0, 0, 0, 0)`` when no blob
        qualifies.
    """
    gray = orginal_image_mask
    if gray.ndim != 2:
        gray = cv2.cvtColor(gray, cv2.COLOR_BGR2GRAY)

    # Build a fresh binary array. This avoids mutating the caller's mask (the
    # old code thresholded the argument in place). It also guarantees every
    # pixel is strictly 0 or 255: the old pair of `>`/`<` assignments left
    # pixels equal to the threshold untouched.
    binary = np.where(gray > binary_threshold, 255, 0).astype(np.uint8)

    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    height, width = binary.shape[:2]
    center_x, center_y = width // 2, height // 2
    max_distance_from_center = min(width, height) * max_center_fraction

    center_contours = []
    for contour in contours:
        moments = cv2.moments(contour)
        if moments['m00'] == 0:
            continue  # zero-area contour: centroid undefined
        cx = int(moments['m10'] / moments['m00'])
        cy = int(moments['m01'] / moments['m00'])
        if np.hypot(cx - center_x, cy - center_y) <= max_distance_from_center:
            center_contours.append(contour)

    if not center_contours:
        return -1, 0, 0, 0, 0

    largest_contour = max(center_contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)
    return cv2.contourArea(largest_contour), x, y, w, h

def _binary_confusion_stats(outputs: torch.Tensor, labels: torch.Tensor):
    """Threshold sigmoid(outputs) at 0.5 and return smp's ``(tp, fp, fn, tn)``.

    ``outputs`` are treated as logits (sigmoid is applied here); ``labels`` are
    presumably 0/1 masks — TODO confirm against the training code.
    """
    predicted = (outputs.sigmoid() > 0.5).long()
    return smp.metrics.get_stats(predicted, labels.long(), mode="binary")

def _iou(outputs: torch.Tensor, labels: torch.Tensor):
    """Micro-reduced IoU of the thresholded prediction against ``labels``."""
    tp, fp, fn, tn = _binary_confusion_stats(outputs, labels)
    return smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

def _accuracy(outputs: torch.Tensor, labels: torch.Tensor):
    """Macro-reduced pixel accuracy of the thresholded prediction against ``labels``."""
    tp, fp, fn, tn = _binary_confusion_stats(outputs, labels)
    return smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")

class CustomModelBase(torch.nn.Module):
    """Base ``nn.Module`` that carries metric callables alongside the network.

    The metric functions are stored as plain attributes (``accuracy_function``
    and ``iou_function``); this class only stores them.

    NOTE(review): this class presumably has to exist under this exact name so
    that ``torch.load`` can unpickle the saved full-model checkpoint — do not
    rename it or its attributes.
    """

    def __init__(self, accuracy_function=_accuracy, iou_function=_iou):
        super().__init__()
        self.accuracy_function = accuracy_function
        self.iou_function = iou_function

class CreateModel(CustomModelBase):
    """Thin wrapper whose ``forward`` delegates to the wrapped ``model``."""

    def __init__(self, model):
        super().__init__()
        # Keep the attribute name ``model``: forward() reads it, and a pickled
        # instance presumably restores it under this name.
        self.model = model

    def forward(self, x):
        wrapped_network = self.model
        return wrapped_network(x)

def runInference(image, threshold=0.5):
    """Segment a PIL image with the global ``model`` and return a full-size mask.

    Reads notebook globals defined in later cells: ``resize``, ``mean``,
    ``std``, ``device`` and ``model``.

    Args:
        image: PIL image of any mode; it is converted to RGB before inference.
        threshold: probability above which a pixel is labelled foreground.

    Returns:
        uint8 array of shape ``(image.height, image.width)`` with values 0 or 255.
    """
    h, w = image.height, image.width

    # convert('RGB') so grayscale / RGBA / palette inputs yield exactly the 3
    # channels that the 3-element mean/std normalization expects.
    resized = resize(image.convert('RGB'))
    tensor_input = TF.normalize(TF.to_tensor(resized), mean=mean, std=std)
    tensor_input = tensor_input.to(device).float()

    # no_grad here so the function also works outside a caller's no_grad block
    # (.numpy() refuses tensors that require grad).
    with torch.no_grad():
        logits = model(tensor_input[None])[0].squeeze()

    prob_mask = logits.sigmoid().cpu().numpy().astype(np.float32)
    # Resize the probabilities first and threshold afterwards, so the returned
    # mask is strictly binary. Bilinearly resizing a 0/255 mask leaves
    # intermediate grey values along the edges.
    prob_mask = cv2.resize(prob_mask, (w, h), interpolation=cv2.INTER_LINEAR)
    return ((prob_mask > threshold) * 255).astype(np.uint8)

In [None]:
# Load the full pickled model (not a state_dict). CustomModelBase / CreateModel
# above presumably must be defined for unpickling to succeed.
# SECURITY: torch.load unpickles arbitrary Python objects — only load trusted files.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
# full-model pickles; pass weights_only=False there if loading fails.
MODEL_PATH = ''  # TODO: path to the saved model checkpoint
# Fall back to CPU so the notebook also runs on machines without CUDA
# (a hard-coded map_location='cuda' fails there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load(MODEL_PATH, map_location=device)
model.eval();

In [None]:
# Expected layout (see the processing loop): input_path/<folder>/<rep>/<image files>
input_path = ''   # TODO: set input root folder
output_path = ''  # TODO: set output root folder; masks/ and segmented_images/ are created per rep
resize = transforms.Resize(size=(512, 512))  # network input size, used by runInference
# Keep only sub-directories (skipping stray files such as .DS_Store, which would
# crash the later os.listdir), sorted for a deterministic processing order.
folders = sorted(
    entry for entry in os.listdir(input_path)
    if os.path.isdir(os.path.join(input_path, entry))
)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ImageNet normalization statistics; read as globals by runInference.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# If the central leaf found on the full image covers less than this fraction of
# the image area, retry on a centre crop that drops CROP_MARGIN px on each side.
MIN_LEAF_AREA_FRACTION = 0.005
CROP_MARGIN = 120


def _list_subdirs(path):
    """Sorted names of the sub-directories of ``path`` (plain files are skipped)."""
    return sorted(
        name for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    )


def _predict_leaf_mask(image):
    """Return a 0/255 uint8 (H, W) mask for an RGB PIL image, with a centre-crop retry."""
    img_w, img_h = image.size
    mask = runInference(image)
    leaf_area, *_ = get_area_of_leaf(mask)
    if leaf_area >= img_w * img_h * MIN_LEAF_AREA_FRACTION:
        return mask

    # Too small to crop: the slice below would be empty, so keep the full-image result.
    if img_h <= 2 * CROP_MARGIN or img_w <= 2 * CROP_MARGIN:
        return mask

    m = CROP_MARGIN
    cropped = np.asarray(image)[m:-m, m:-m]
    crop_mask = runInference(Image.fromarray(cropped), threshold=0.5)
    # Paste the crop's mask back into a full-size empty canvas.
    full_mask = np.zeros((img_h, img_w), dtype=np.uint8)
    full_mask[m:-m, m:-m] = crop_mask
    return full_mask


def _save_outputs(dest_path, image_id, rgb, mask):
    """Write ``<stem>_mask.png`` and ``<stem>_seg.png`` under ``dest_path``."""
    # splitext instead of [:-4] so ".jpeg"/".tiff" names are not mangled.
    stem = os.path.splitext(image_id)[0]
    # The mask is single-channel; COLOR_BGR2RGB would raise on it, so expand
    # the gray mask to 3 identical channels instead.
    mask_3ch = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
    cv2.imwrite(os.path.join(dest_path, "masks", f"{stem}_mask.png"), mask_3ch)

    seg_rgb = np.where(mask_3ch == 0, 0, rgb).astype(np.uint8)
    # cv2.imwrite expects BGR channel order; the PIL-derived array is RGB.
    seg_bgr = cv2.cvtColor(seg_rgb, cv2.COLOR_RGB2BGR)
    cv2.imwrite(os.path.join(dest_path, "segmented_images", f"{stem}_seg.png"), seg_bgr)


with torch.no_grad():
    for folder in folders:
        folder_path = os.path.join(input_path, folder)
        if not os.path.isdir(folder_path):
            continue  # stray top-level file
        for rep in _list_subdirs(folder_path):
            image_paths = sorted(
                p for p in glob.glob(os.path.join(folder_path, rep, '*'))
                if os.path.isfile(p)
            )
            dest_path = os.path.join(output_path, folder, rep)
            # Create the output sub-folders once per rep, not once per image.
            for sub_dir in ("masks", "segmented_images"):
                os.makedirs(os.path.join(dest_path, sub_dir), exist_ok=True)

            for image_path in tqdm(image_paths, desc=f"{folder}/{rep}"):
                try:
                    # convert('RGB') loads the pixels (so the file can be closed)
                    # and normalizes grayscale/RGBA inputs to 3 channels.
                    with Image.open(image_path) as opened:
                        image = opened.convert('RGB')
                except OSError as exc:  # includes PIL.UnidentifiedImageError
                    tqdm.write(f"Skipping unreadable file {image_path}: {exc}")
                    continue

                leaf_mask = new_remove_noise(_predict_leaf_mask(image))
                assert leaf_mask.shape == (image.height, image.width), leaf_mask.shape
                _save_outputs(dest_path, os.path.basename(image_path), np.asarray(image), leaf_mask)
