In [1]:
from collections import defaultdict
import cv2
from cv2 import ximgproc
import numpy as np
from tqdm import tqdm
import pandas as pd
from skimage.segmentation import felzenszwalb, slic, quickshift, watershed
from skimage.color import rgb2gray
from skimage.filters import sobel

from bsds_datamodule import BSDSDatamodule
import superpixel_benchmark as benchmark
import superpixel_tools as tools

pd.set_option("display.max_columns", None)

In [2]:
def get_metrics_dict():
    methods = ["slic", "quick", "watershed", "felzenszwalb", "seeds"]
    metrics_dict = {}
    for method in methods:
        metrics_dict[method] = {
            "undersegmentation_error": [],
            "np_undersegmentation_error": [],
            "boundary_recall": [],
            "boundary_precision": [],
            "achievable_segmentation_acc": [],
            "compactness": [],
            "explained_variation": [],
            "mean_distance_to_edge": [],
            "intra_cluster_variation": [],
            "contour_density": [],
            "number_of_components": [],
        }

    return metrics_dict


def compute_segments(img):
    # compute segments
    segments_fz = felzenszwalb(img, scale=400, sigma=0.5, min_size=50)

    segments_slic = slic(
        img, n_segments=400, compactness=10, sigma=0.5, start_label=1, max_iter=25
    )

    segments_quick = quickshift(img, kernel_size=10, max_dist=10, ratio=0.5)

    gradient = sobel(rgb2gray(img))
    segments_watershed = watershed(gradient, markers=400, compactness=0.4)

    # set parameters for superpixel segmentation
    # convert img color space
    converted_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    num_superpixels = 400  # desired number of superpixels
    num_iterations = (
        10  # number of pixel level iterations. The higher, the better quality
    )
    prior = 0  # for shape smoothing term. must be [0, 5]
    num_levels = 4
    num_histogram_bins = 9  # number of histogram bins
    height, width, channels = converted_img.shape
    seeds = ximgproc.createSuperpixelSEEDS(
        width,
        height,
        channels,
        num_superpixels,
        num_levels,
        prior,
        num_histogram_bins,
    )
    seeds.iterate(converted_img, num_iterations)
    segments_seeds = seeds.getLabels()

    # relabel segments
    segments_fz = tools.relabel_connected_superpixels(
        segments_fz.astype(np.intc)
    )
    segments_slic = tools.relabel_connected_superpixels(
        segments_slic.astype(np.intc)
    )
    segments_quick = tools.relabel_connected_superpixels(
        segments_quick.astype(np.intc)
    )
    segments_watershed = tools.relabel_connected_superpixels(
        segments_watershed.astype(np.intc)
    )
    segments_seeds = tools.relabel_connected_superpixels(
        segments_seeds.astype(np.intc)
    )

    return {
        "felzenszwalb": segments_fz,
        "slic": segments_slic,
        "quick": segments_quick,
        "watershed": segments_watershed,
        "seeds": segments_seeds,
    }


def compute_metrics(metrics, segments_dict, gt, img):
    for method, segmentation in segments_dict.items():
        metrics[method]["undersegmentation_error"].append(
            benchmark.compute_undersegmentation_error(segmentation, gt)
        )
        metrics[method]["np_undersegmentation_error"].append(
            benchmark.compute_np_undersegmentation_error(segmentation, gt)
        )
        metrics[method]["boundary_recall"].append(
            benchmark.compute_boundary_recall(segmentation, gt)
        )
        metrics[method]["boundary_precision"].append(
            benchmark.compute_boundary_precision(segmentation, gt)
        )
        metrics[method]["achievable_segmentation_acc"].append(
            benchmark.compute_achievable_segmentation_accuracy(
                segmentation, gt
            )
        )
        metrics[method]["compactness"].append(
            benchmark.compute_compactness(segmentation)
        )
        metrics[method]["explained_variation"].append(
            benchmark.compute_explained_variation(segmentation, img)
        )
        metrics[method]["mean_distance_to_edge"].append(
            benchmark.compute_mean_distance_to_edge(segmentation, gt)
        )
        metrics[method]["intra_cluster_variation"].append(
            benchmark.compute_intra_cluster_variation(segmentation, img)
        )
        metrics[method]["contour_density"].append(
            benchmark.compute_contour_density(segmentation)
        )
        metrics[method]["number_of_components"].append(
            tools.count_superpixels(segmentation)
        )

    return metrics


def compute_avg_metrics(all_metrics):
    avg_metrics = defaultdict(dict)
    for method, metrics in all_metrics.items():
        for metric, vals in metrics.items():
            avg_metrics[method][metric] = sum(vals) / len(vals)

    return avg_metrics

In [3]:
# setup data
dm = BSDSDatamodule(batch_size=1)
dm.prepare_data()
dm.setup()

metrics = get_metrics_dict()
for batch in tqdm(dm.val_dataloader(), desc="Evaluating", colour="BLUE"):
    img, gt, _ = batch.x, batch.y, batch.batch
    img, gt = img.numpy(), gt.numpy()
    segments = compute_segments(img)
    metrics = compute_metrics(metrics, segments, gt, img)

avg_metrics = compute_avg_metrics(metrics)
df = pd.DataFrame.from_dict(avg_metrics)
df

Evaluating: 100%|[34m██████████[0m| 100/100 [14:05<00:00,  8.46s/it]


Unnamed: 0,slic,quick,watershed,felzenszwalb,seeds
undersegmentation_error,0.054039,0.09717,0.073003,0.130776,0.046281
np_undersegmentation_error,0.106684,0.191238,0.144357,0.24534,0.091144
boundary_recall,0.722134,0.539347,0.401053,0.711017,0.853232
boundary_precision,0.115604,0.170654,0.074701,0.199729,0.084098
achievable_segmentation_acc,0.945961,0.90283,0.926997,0.869224,0.953719
compactness,0.309679,0.20879,0.470384,0.117582,0.097973
explained_variation,0.766108,0.686668,0.664917,0.668611,0.799595
mean_distance_to_edge,1.36599,3.859158,2.68055,2.513954,0.768056
intra_cluster_variation,28.859566,11.787597,41.130578,20.50185,30.486117
contour_density,0.209467,0.108188,0.179609,0.132499,0.352237
