In [None]:
import os

# Project root; swap the two lines below to run on Linux.
base_path = "C:\\Projects\\ature\\"  # WINDOWS
# base_path = "/home/ak/Projects/ature/"  # LINUX

# Build sub-paths with os.path.join so the separator matches the host OS and no
# doubled separator appears (plain concatenation onto a base_path that already
# ends in a separator produced "ature\\\\data...").
data_file_path = os.path.join(base_path, "data", "DRIVE", "test", "images")
mask_path = os.path.join(base_path, "data", "DRIVE", "test", "mask")
ground_truth_path = os.path.join(base_path, "data", "DRIVE", "test", "1st_manual")
log_path = os.path.join(base_path, "logs")

# The project-local packages (commons, preprocess) are imported right after this;
# presumably the chdir is what makes them resolvable from the project root.
os.chdir(base_path)

from commons.IMAGE import Image
from commons.ImgLATTICE import Lattice
import preprocess.utils.img_utils as imgutils
from commons.MAT import Mat
from PIL import Image as IMG
import numpy as np
from commons import constants as const
import cv2
from preprocess.algorithms import fast_mst as fmst
import itertools as itr
from itertools import count
import matplotlib.pyplot as plt

In [None]:
def run_segmentation(img_obj, lattice_obj, params):
    """Skeletonize the image and run the fast-MST segmentation for one parameter set.

    Parameters
    ----------
    img_obj : commons.IMAGE.Image
        Filtered image. ``create_skeleton`` is called on it, and its
        ``img_skeleton`` attribute is then used to derive the seed nodes.
    lattice_obj : commons.ImgLATTICE.Lattice
        Lattice graph to segment.
    params : sequence of 4 numbers
        ``(skeletonize_threshold, alpha, gabor_contribution, segmentation_threshold)``
        in exactly this order (one element of the parameter grid product).

    Returns
    -------
    The graph returned by ``fmst.run_segmentation``. It used to be computed and
    discarded; returning it lets callers inspect it, while callers that ignore
    the return value are unaffected.
    """
    (skeletonize_threshold, alpha,
     gabor_contribution, segmentation_threshold) = params

    ##### Create skeleton based on threshold; its pixels seed the segmentation.
    img_obj.create_skeleton(threshold=skeletonize_threshold,
                            kernels=imgutils.get_chosen_skeleton_filter())
    seed_node_list = imgutils.get_seed_node_list(img_obj.img_skeleton)

    ##### Run segmentation. The Gabor and original image contributions sum to 1.
    return fmst.run_segmentation(image_object=img_obj,
                                 lattice_object=lattice_obj,
                                 seed_list=seed_node_list,
                                 segmentation_threshold=segmentation_threshold,
                                 alpha=alpha,
                                 img_gabor_contribution=gabor_contribution,
                                 img_original_contribution=1 - gabor_contribution)

In [None]:
def get_precision_recall(segmented, truth):
    """Pixel-wise precision and recall of a binary segmentation.

    A pixel equal to 255 is foreground and a pixel equal to 0 is background;
    pixels with any other value are ignored, exactly as before.

    Parameters
    ----------
    segmented : 2-D array-like
        Predicted segmentation.
    truth : 2-D array-like
        Ground-truth segmentation; assumed to have the same shape as ``segmented``.

    Returns
    -------
    (precision, recall) : tuple of float
        ``TP / (TP + FP)`` and ``TP / (TP + FN)``. When a denominator is zero
        (nothing predicted, or an empty ground truth) that ratio is reported as
        0.0 instead of raising ZeroDivisionError.
    """
    segmented = np.asarray(segmented)
    truth = np.asarray(truth)

    # Vectorized counting replaces the per-pixel Python double loop.
    seg_fg = segmented == 255
    truth_fg = truth == 255
    tp = int(np.count_nonzero(seg_fg & truth_fg))          # True Positive
    fp = int(np.count_nonzero(seg_fg & (truth == 0)))      # False Positive
    fn = int(np.count_nonzero((segmented == 0) & truth_fg))  # False Negative

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall

In [None]:
def generate_precision_recall_plot(log_file):
    """Plot F1, precision and recall per iteration from a sweep log CSV.

    The CSV's first row is a header and is skipped. Columns 0-3 are read as
    iteration, F1, precision and recall. The figure is saved next to the log
    as ``<log_file>.png``.

    Parameters
    ----------
    log_file : str
        Path to the CSV log.

    Returns
    -------
    str or None
        Path of the saved PNG, or ``None`` when the log contains no data rows.
    """
    # ndmin=2 keeps a single-row log 2-D so the column slicing below still works.
    log = np.loadtxt(log_file, skiprows=1, delimiter=',', ndmin=2)
    if log.size == 0:
        # Header only: nothing to plot.
        return None

    # Use a dedicated figure and close it afterwards. Drawing on the global pyplot
    # figure made repeated calls (one per image) pile up lines and leak figures.
    fig, ax = plt.subplots()
    try:
        ax.set_title('Precision vs Recall plot')
        ax.set_xlabel('Iterations')
        ax.set_ylabel('Precision vs Recall')
        ax.plot(log[:, 0], log[:, 1], label='F1')
        ax.plot(log[:, 0], log[:, 2], label='Precision')
        ax.plot(log[:, 0], log[:, 3], label='Recall')
        ax.legend(fontsize='small', bbox_to_anchor=(1, 1.2))
        png_path = log_file + '.png'
        # The legend sits outside the axes; 'tight' keeps it from being clipped.
        fig.savefig(png_path, bbox_inches='tight')
    finally:
        plt.close(fig)
    return png_path

In [None]:
############# ENTRY POINT HERE ###############
############################################
# Parameter sweep: segment every test image with every parameter combination
# and log F1 / precision / recall per combination to a CSV (one CSV per image).

# Earlier, wider grid (kept for reference):
# SK_THRESHOLD_PARAMS = np.arange(10, 71, 20)
# ALPHA_PARAMS = np.arange(5, 10, 1)
# GABOR_CONTRIBUTION_PARAMS = np.arange(0.7, 0.9, 0.1)
# SEGMENTATION_THRESHOLD_PARAMS = np.arange(7, 10, 0.5)

SK_THRESHOLD_PARAMS = np.arange(40, 61, 10)
ALPHA_PARAMS = np.arange(5, 10, 2)
# NOTE(review): float step - np.arange may yield values like 0.7999999999999999
# and an extra element close to 0.9 due to rounding; verify this is intended.
GABOR_CONTRIBUTION_PARAMS = np.arange(0.7, 0.9, 0.1)
SEGMENTATION_THRESHOLD_PARAMS = np.arange(7, 10, 0.5)

# itertools.product returns a one-shot iterator, but the grid is iterated once
# per image below. It must be materialized, otherwise every image after the
# first is evaluated with zero parameter combinations.
PARAMS_COMBINATION = list(itr.product(SK_THRESHOLD_PARAMS, ALPHA_PARAMS,
                                      GABOR_CONTRIBUTION_PARAMS,
                                      SEGMENTATION_THRESHOLD_PARAMS))

LOG_HEADER = ('ITERATION,FSCORE,PRECISION,RECALL,'
              'SKELETONIZE_THRESHOLD,'
              'IMG_LATTICE_COST_ASSIGNMENT_ALPHA,'
              'IMG_LATTICE_COST_GABOR_IMAGE_CONTRIBUTION,'
              'SEGMENTATION_THRESHOLD\n')

#### Work on all images in a directory (sorted for a deterministic order).
# Absolute paths are used instead of os.chdir-ing between directories. The old
# code left the cwd at log_path after each image, so opening the next image by
# its bare file name failed.
for test_image in sorted(os.listdir(data_file_path)):

    print('### WORKING ON: ' + test_image)
    image_id = test_image.split('_')[0]

    ### Load image as array; keep channel 1 (green, assuming an RGB-mode image).
    original = IMG.open(os.path.join(data_file_path, test_image))
    original = np.array(original.getdata(), np.uint8).reshape(original.size[1], original.size[0], 3)
    img_obj = Image(image_arr=original[:, :, 1])

    #### Load the corresponding mask as a 2-D array
    mask_file = image_id + '_test_mask.gif'
    mask = IMG.open(os.path.join(mask_path, mask_file))
    print("Mask loaded: " + mask_file)
    mask = np.array(mask.getdata(), np.uint8).reshape(mask.size[1], mask.size[0])

    #### Load ground truth segmented result as a 2-D array
    ground_truth_file = image_id + '_manual1.gif'
    truth = IMG.open(os.path.join(ground_truth_path, ground_truth_file))
    print("Ground truth loaded: " + ground_truth_file)
    truth = np.array(truth.getdata(), np.uint8).reshape(truth.size[1], truth.size[0])

    img_obj.apply_bilateral()
    img_obj.apply_gabor(kernel_bank=imgutils.get_chosen_gabor_bank())
    print('Filter applied.')

    lattice_obj = Lattice(image_arr_2d=img_obj.img_gabor)
    lattice_obj.generate_lattice_graph()
    print('Lattice created.')

    log_file_path = os.path.join(log_path, test_image + "_result.csv")
    # `with` guarantees the log is closed even if a segmentation run raises.
    with open(log_file_path, 'w') as log_file:
        log_file.write(LOG_HEADER)

        # NOTE(review): img_obj and lattice_obj are reused across combinations;
        # this assumes each segmentation run resets lattice_obj.accumulator - verify.
        for i, params in enumerate(PARAMS_COMBINATION, start=1):
            run_segmentation(img_obj, lattice_obj, params)

            ### Apply mask
            segmented = cv2.bitwise_and(lattice_obj.accumulator, lattice_obj.accumulator, mask=mask)

            precision, recall = get_precision_recall(segmented, truth)
            denominator = precision + recall
            # F1 is defined as 0 when both precision and recall are 0.
            f1_score = 2 * precision * recall / denominator if denominator else 0.0

            log_file.write(str(i) + ',' +
                           str(round(f1_score, 3)) + ',' +
                           str(round(precision, 3)) + ',' +
                           str(round(recall, 3)) + ',' +
                           ','.join(map(str, params)) + '\n')
            # Flush per row so partial results survive an interrupted sweep.
            log_file.flush()
            print('Number of parameter combinations tried: ' + str(i), end='\r')
    print()  # finish the '\r' progress line

    generate_precision_recall_plot(log_file_path)

In [None]:
SK_THRESHOLD_PARAMS = np.arange(40, 61, 10)
ALPHA_PARAMS = np.arange(5, 10, 2)
# NOTE(review): float step - np.arange may include an element close to 0.9 due to rounding.
GABOR_CONTRIBUTION_PARAMS = np.arange(0.7, 0.9, 0.1)
SEGMENTATION_THRESHOLD_PARAMS = np.arange(7, 10, 0.5)

# Materialize the grid: a bare itertools.product iterator is exhausted by the
# first consumer (e.g. counting it), leaving nothing for a later sweep.
PARAMS_COMBINATION = list(itr.product(SK_THRESHOLD_PARAMS, ALPHA_PARAMS,
                                      GABOR_CONTRIBUTION_PARAMS,
                                      SEGMENTATION_THRESHOLD_PARAMS))


In [None]:
# Materialize before counting: if PARAMS_COMBINATION is still an itertools.product
# iterator, list() would otherwise consume it and leave nothing for a later sweep.
PARAMS_COMBINATION = list(PARAMS_COMBINATION)
len(PARAMS_COMBINATION)