In [4]:
from matplotlib import pyplot as plt
import six
import cv2
from predict import class_colors , class_names
import numpy as np
from functions import get_image_array
from predict import visualize_segmentation

def predict(model=None, inp=None, out_fname=None,
            checkpoints_path=None, overlay_img=False,
            class_names=None, show_legends=False, colors=class_colors,
            prediction_width=None, prediction_height=None,
            read_image_type=1):
    """Run ``model`` on a single image and return the visualised segmentation.

    Args:
        model: Object exposing ``output_width``, ``output_height``,
            ``input_width``, ``input_height`` and ``n_classes`` attributes and
            a ``predict`` method. Required, despite the ``None`` default.
        inp: Either a NumPy array holding the image or a path string that is
            read with ``cv2.imread``.
        colors: Colour table forwarded to ``visualize_segmentation``.
        read_image_type: Flag passed to ``cv2.imread`` when ``inp`` is a path.
        out_fname, checkpoints_path, overlay_img, class_names, show_legends,
            prediction_width, prediction_height: Accepted for signature
            compatibility but not used by this implementation; in particular
            nothing is written to ``out_fname``.

    Returns:
        The value returned by ``visualize_segmentation`` for the per-pixel
        argmax class map (presumably a colour image - confirm against
        ``predict.visualize_segmentation``).

    Raises:
        AssertionError: If ``inp`` or ``model`` is missing, ``inp`` has an
            unsupported type or number of dimensions, or the image file
            cannot be read.
    """
    assert inp is not None, "Input must be provided."
    assert isinstance(inp, (np.ndarray, six.string_types)), \
        "Input should be a NumPy array or a file path string."
    # Fail early with a clear message instead of an AttributeError on None below.
    assert model is not None, "A model must be provided."

    if isinstance(inp, six.string_types):
        # Keep the path in its own variable: cv2.imread returns None on
        # failure, and the error message must still name the offending file.
        image_path = inp
        inp = cv2.imread(image_path, read_image_type)
        assert inp is not None, f"Image at path {image_path} could not be loaded."

    assert inp.ndim in [1, 3, 4], "Image should have 1, 3, or 4 dimensions."

    output_width = model.output_width
    output_height = model.output_height
    input_width = model.input_width
    input_height = model.input_height
    n_classes = model.n_classes

    # Convert the image to the model's input array and run it as a batch of one.
    x = get_image_array(inp, input_width, input_height)
    pr = model.predict(np.array([x]))

    # Reshape the raw output to (height, width, classes) and keep the
    # highest-scoring class index for every pixel.
    pr = pr.reshape((output_height, output_width, n_classes)).argmax(axis=-1)

    seg_img = visualize_segmentation(
        pr, inp, n_classes=n_classes, colors=colors
    )

    return seg_img


In [5]:
from model import fcn_8_vgg
# Sample input/output paths for a single-image prediction. Neither name is
# referenced again in the cells visible here.
# NOTE(review): the import below is disabled, so `predict` is presumably the
# function defined earlier in this notebook - confirm before re-enabling it.
# from predict import predict
image_path = "train_2.jpg"
output_path = "output.jpg"

# Build the network with fcn_8_vgg (27 classes, 224 x 320 height x width input)
# and then load the trained weights from the checkpoint file into it.
model = fcn_8_vgg(n_classes=27, input_height=224, input_width=320)
model.load_weights('checkpoints/model.weights.h5')

In [6]:
import cv2
import numpy as np
import os

def create_binary_masks(image, class_colors):
    """Split a colour-coded segmentation image into one 0/1 mask per class.

    Args:
        image: Array with exactly three dimensions (rows, columns, channels).
        class_colors: Sequence of colours; each one is compared against the
            channel axis of ``image``.

    Returns:
        ``uint8`` array of shape ``(rows, columns, len(class_colors))`` whose
        plane ``k`` is 1 where every channel of the pixel equals
        ``class_colors[k]`` and 0 elsewhere.
    """
    rows, cols, _channels = image.shape
    masks = np.zeros((rows, cols, len(class_colors)), dtype=np.uint8)

    for class_index, class_color in enumerate(class_colors):
        # A pixel belongs to the class only if all of its channels match.
        pixel_matches = np.all(image == class_color, axis=-1)
        masks[..., class_index] = pixel_matches.astype(np.uint8)

    return masks

def dice_coefficient_binary(pred_mask, gt_mask):
    """Compute the Dice coefficient between two 0/1 masks.

    The score is ``2*TP / (2*TP + FP + FN)``. When that denominator is zero
    (no positives in either mask) the function returns ``1.0``.

    Args:
        pred_mask: Predicted mask array.
        gt_mask: Ground-truth mask array, combined element-wise with
            ``pred_mask``.

    Returns:
        The Dice score for the pair of masks.
    """
    true_pos = np.sum(pred_mask * gt_mask)
    false_pos = np.sum(pred_mask * (1 - gt_mask))
    false_neg = np.sum((1 - pred_mask) * gt_mask)

    denominator = 2 * true_pos + false_pos + false_neg
    if denominator > 0:
        return (2 * true_pos) / denominator
    # Both masks are empty: treated as a perfect match.
    return 1.0

def mean_dice_coefficient_colored(pred_image, gt_image, class_colors):
    """Score a colour-coded prediction against a colour-coded ground truth.

    Both images are split into per-class masks with ``create_binary_masks``
    and each pair of class masks is scored with ``dice_coefficient_binary``.

    Args:
        pred_image: Predicted segmentation image.
        gt_image: Ground-truth segmentation image.
        class_colors: Sequence of colours, one per class.

    Returns:
        Tuple ``(mean_dice, dice_scores)``: the ``np.mean`` of the per-class
        scores, and the list of per-class scores in ``class_colors`` order.
    """
    predicted_masks = create_binary_masks(pred_image, class_colors)
    reference_masks = create_binary_masks(gt_image, class_colors)

    per_class_scores = [
        dice_coefficient_binary(predicted_masks[:, :, k], reference_masks[:, :, k])
        for k in range(len(class_colors))
    ]

    return np.mean(per_class_scores), per_class_scores

def calculate_mean_dice_for_folder(pred_folder, gt_folder, class_colors, model):
    """Average the per-class Dice coefficient over every image pair in two folders.

    Files ending in ``.jpg`` or ``.png`` are listed from both folders, sorted
    by name and paired positionally. Each file from ``pred_folder`` is passed
    to ``predict`` with ``model``; the result is compared with the file read
    from ``gt_folder``.

    Args:
        pred_folder: Directory with the images to run through the model.
        gt_folder: Directory with the colour-coded reference images.
        class_colors: Sequence of colours, one per class.
        model: Model forwarded to ``predict``.

    Returns:
        Tuple ``(overall_mean_dice, mean_dice_by_class)``: the mean over
        classes, and an array holding the per-class mean over all evaluated
        pairs.

    Raises:
        OSError: If a ground-truth file cannot be read by ``cv2.imread``.
        ValueError: If no image pairs are found, or a prediction and its
            ground truth differ in height/width.
    """
    import warnings  # local import keeps this block self-contained

    image_suffixes = ('.jpg', '.png')
    pred_files = sorted(f for f in os.listdir(pred_folder) if f.endswith(image_suffixes))
    gt_files = sorted(f for f in os.listdir(gt_folder) if f.endswith(image_suffixes))

    if len(pred_files) != len(gt_files):
        # zip() below stops at the shorter list, so some files are skipped and
        # positional pairing may no longer line up - make that visible.
        warnings.warn(
            f"Folder sizes differ: {len(pred_files)} file(s) in {pred_folder} vs "
            f"{len(gt_files)} file(s) in {gt_folder}; only the first "
            f"{min(len(pred_files), len(gt_files))} sorted pair(s) are evaluated."
        )

    overall_dice_scores = np.zeros(len(class_colors))
    num_pairs = 0  # count what is actually scored, not what is listed

    for pred_file, gt_file in zip(pred_files, gt_files):
        # Predict the segmentation for the image
        pred_image_path = os.path.join(pred_folder, pred_file)
        pred_image = predict(model=model, inp=pred_image_path)

        # Load the ground truth image; cv2.imread signals failure with None.
        gt_image_path = os.path.join(gt_folder, gt_file)
        gt_image = cv2.imread(gt_image_path)
        if gt_image is None:
            raise OSError(f"Ground-truth image at path {gt_image_path} could not be loaded.")

        # Mismatched sizes would either fail deep inside NumPy or, worse,
        # silently broadcast; reject them with a message naming both files.
        if pred_image.shape[:2] != gt_image.shape[:2]:
            raise ValueError(
                f"Size mismatch between prediction for {pred_image_path} "
                f"{pred_image.shape[:2]} and ground truth {gt_image_path} "
                f"{gt_image.shape[:2]}."
            )

        _, dice_scores = mean_dice_coefficient_colored(pred_image, gt_image, class_colors)

        # Accumulate the dice scores for each class
        overall_dice_scores += np.array(dice_scores)
        num_pairs += 1

    if num_pairs == 0:
        raise ValueError(
            f"No image pairs to evaluate in {pred_folder} and {gt_folder}."
        )

    # Mean per class over the evaluated pairs, then mean over classes.
    mean_dice_by_class = overall_dice_scores / num_pairs
    overall_mean_dice = np.mean(mean_dice_by_class)

    return overall_mean_dice, mean_dice_by_class

# Folder paths
# pred_folder holds the input images that calculate_mean_dice_for_folder feeds
# to predict(); gt_folder holds the images they are scored against.
# NOTE(review): gt_folder points at 'eval_predicted_images' - confirm this
# directory really contains ground-truth masks and not earlier model output.
pred_folder = 'evaluation_data/eval_images'
gt_folder = 'evaluation_data/eval_predicted_images'
# Calculate mean Dice Coefficient across all images and classes
# (class_colors and model come from earlier cells of this notebook).
overall_mean_dice, mean_dice_by_class = calculate_mean_dice_for_folder(pred_folder, gt_folder, class_colors, model)

# Report the overall score and the per-class breakdown.
print(f"Overall Mean Dice Coefficient: {overall_mean_dice}")
print(f"Mean Dice Coefficient by Class: {mean_dice_by_class}")


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 536ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 335ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 328ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 337ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 334ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 352ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 351ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 338ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 336ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 343ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 338ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 352ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 350ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m 