In [None]:
#!pip install opencv-python scikit-image numpy matplotlib fast_slic tqdm

# Superpixel function definitions

In [None]:
import os
import time

import cv2
from fast_slic import Slic
from fast_slic.avx2 import SlicAvx2
import numpy as np
from skimage.color import rgb2gray
from skimage.filters import sobel
from skimage.segmentation import felzenszwalb, slic, quickshift, watershed

import glob
from tqdm import tqdm


def read_image(fn):
    """Load an image file through OpenCV.

    The file is read with the ``cv2.IMREAD_UNCHANGED`` flag.

    Args:
        fn (str): Path of the image file to load.

    Returns:
        numpy.ndarray: Whatever ``cv2.imread`` returns for ``fn``.
    """
    loaded = cv2.imread(fn, cv2.IMREAD_UNCHANGED)
    return loaded


def write_image(fn, img):
    """Save an image array to disk through OpenCV.

    Args:
        fn (str): Destination file path.
        img (numpy.ndarray): The image array to be written.

    Returns:
        None. Whatever ``cv2.imwrite`` returns is discarded, so this function
        does not report whether the write succeeded.
    """
    cv2.imwrite(fn, img)


def map_labels(labels, label_map, dtype=None):
    """Replace every label in ``labels`` by its target from ``label_map``.

    A lookup table indexed by label id is built once and then applied to the
    whole label array in a single indexing operation. Labels that have no
    entry in ``label_map`` are mapped to themselves (broadcast over the
    target width).

    Args:
        labels (numpy.ndarray): Array of non-negative integer label ids; the
            ids are used directly as row indices of the lookup table.
        label_map (dict): Maps a label id to its target, either a scalar or a
            1-D sequence. All targets must have the same length as the first.
        dtype (data-type, optional): Data type of the result. If not
            specified, the dtype of the first value of `label_map` is used.

    Returns:
        numpy.ndarray: Array of shape ``(*labels.shape, dim)`` where ``dim``
        is the target length (1 for scalar targets). When ``label_map`` is
        empty, ``labels`` is returned reshaped to ``(*labels.shape, 1)``.

    Raises:
        AssertionError: If the first mapping target has more than one
            dimension or has length zero.
    """
    if len(label_map) == 0:
        return labels.reshape((*labels.shape, 1))

    # The first mapping target fixes the output width and the default dtype.
    first_target = np.asarray(next(iter(label_map.values())))
    assert first_target.ndim <= 1
    dim = first_target.shape[0] if first_target.ndim == 1 else 1
    assert dim > 0

    if dtype is None:
        dtype = first_target.dtype

    # Nothing to map: np.unique() would be empty and `label_ids[-1]` would
    # raise IndexError, so return an empty result of the right shape instead.
    if labels.size == 0:
        return np.empty((*labels.shape, dim), dtype=dtype)

    label_ids = np.unique(labels)
    # Convert to a Python int before adding 1 so that a narrow unsigned dtype
    # (e.g. uint8 with label 255) cannot wrap around to a zero-sized table.
    table_rows = int(label_ids[-1]) + 1
    # Rows for ids that never occur stay uninitialised; they are never read.
    lookup_table = np.empty((table_rows, dim), dtype=dtype)
    for label_id in label_ids:
        lookup_table[label_id] = label_map.get(label_id, label_id)

    return lookup_table[labels.reshape(-1), :].reshape((*labels.shape, dim))


def fast_slic(img, n_sp, compactness=10., max_num_iter=5):
    """Run SLIC superpixel segmentation on the input image.

    Args:
        img (numpy.ndarray): The input image.
        n_sp (int): The desired number of superpixels.
        compactness (float, optional): Compactness parameter. Default is 10.
        max_num_iter (int, optional): Maximum number of iterations. Default is 5.

    Returns:
        The segmentation returned by ``slic`` for the given arguments.
    """
    # NOTE(review): despite the function name, this delegates to the `slic`
    # imported from skimage.segmentation, not to the fast_slic package --
    # confirm that this is the intended backend.
    # The caller's `compactness` and `max_num_iter` are forwarded (they used
    # to be overridden by hard-coded values) and the result is returned
    # instead of being dropped.
    return slic(img, n_sp, compactness=compactness, max_num_iter=max_num_iter)


def slic_postprocessing(img_fn, mask_fn, n_sp=1600, method="fast_slic",
                        use_optimized_label_selection=True,
                        compactness=30, scale=100, sigma=0.5, min_size=50,
                        kernel_size=3, max_dist=6, ratio=0.5):
    """Smooth a colour-coded segmentation mask with superpixels.

    The image is split into superpixels with the chosen method; every
    superpixel is then painted with the mask colour that covers most of its
    pixels. Timing information for each stage is printed to stdout.

    Args:
        img_fn (str): File path of the original image.
        mask_fn (str): File path of the segmented image. The first three
            channels of each pixel are treated as its class colour.
        n_sp (int, optional): Number of superpixels. Default is 1600.
        method (str, optional): Segmentation method to use.
            Supported methods: "slic", "fast_slic", "felzenszwalb", "quickshift", "watershed".
            Default is "fast_slic".
        use_optimized_label_selection (bool, optional): If True, each colour
            triple is packed into one integer so unique colours and per-class
            masks are computed on a 2-D array; if False, colour triples are
            compared directly. Default is True.
        compactness (float, optional): Compactness parameter for SLIC methods. Default is 30.
        scale (float, optional): Scale parameter for Felzenszwalb method. Default is 100.
        sigma (float, optional): Sigma parameter for Felzenszwalb method. Default is 0.5.
        min_size (int, optional): Minimum component size for Felzenszwalb method. Default is 50.
        kernel_size (int, optional): Kernel size for Quickshift method. Default is 3.
        max_dist (float, optional): Maximum distance for Quickshift method. Default is 6.
        ratio (float, optional): Ratio parameter for Quickshift method. Default is 0.5.

    Returns:
        numpy.ndarray: Post-processed segmentation mask, i.e. the superpixel
        label image with every label replaced by its majority class colour.

    Raises:
        ValueError: If `method` is not one of the supported methods.
    """
    supported_methods = ("slic", "fast_slic", "felzenszwalb", "quickshift", "watershed")
    # Fail fast: previously an unknown method only surfaced after all the I/O
    # as an UnboundLocalError on `sp_labels`.
    if method not in supported_methods:
        raise ValueError(
            f"Unsupported method {method!r}; expected one of {supported_methods}"
        )

    start = time.time()
    img = read_image(img_fn)
    mask = read_image(mask_fn)
    print("IO:", time.time() - start, "seconds")

    start = time.time()
    if not use_optimized_label_selection:
        class_colors = np.unique(mask.reshape(-1, 3), axis=0)
    else:
        # Pack the three channels into one integer per pixel (each channel
        # occupies three decimal digits) so uniqueness is a 1-D problem.
        channels = mask.astype(np.uint64)
        colors = 1000000 * channels[..., 0] + 1000 * channels[..., 1] + channels[..., 2]
        unique_colors = np.unique(colors)
        # Unpack the unique codes back into colour triples.
        class_colors = np.empty((len(unique_colors), 3), dtype=np.uint8)
        class_colors[:, 0] = unique_colors // 1000000
        class_colors[:, 1] = unique_colors // 1000 % 1000
        class_colors[:, 2] = unique_colors % 1000
    print("Class labels:", time.time() - start, "seconds")

    start = time.time()
    if method == "slic":
        sp_labels = slic(img, n_sp, compactness=compactness, max_num_iter=5)
    elif method == "fast_slic":
        slic_out = Slic(num_components=n_sp, compactness=compactness)
        sp_labels = slic_out.iterate(img)
    elif method == "felzenszwalb":
        sp_labels = felzenszwalb(img, scale=scale, sigma=sigma, min_size=min_size)
    elif method == "quickshift":
        sp_labels = quickshift(img, kernel_size=kernel_size, max_dist=max_dist, ratio=ratio)
    else:  # "watershed"
        gradient = sobel(rgb2gray(img))
        sp_labels = watershed(gradient, markers=n_sp, compactness=0.0001)
    print("SP method:", time.time() - start, "seconds")

    start = time.time()
    # segmentation[..., k] is the binary mask of class k; label_counts[s, k]
    # is the number of pixels of class k inside superpixel s.
    segmentation = np.empty((*mask.shape[:-1], len(class_colors)), dtype=np.uint8)
    label_counts = np.empty((sp_labels.max() + 1, len(class_colors)), dtype=np.int32)
    flat_labels = sp_labels.flatten()
    for idx, color in enumerate(class_colors):
        if not use_optimized_label_selection:
            segmentation[..., idx] = np.sum(mask == color, axis=-1) == 3
        else:
            segmentation[..., idx] = colors == unique_colors[idx]
        # Weighted bincount sums the binary class mask per superpixel label.
        label_counts[:, idx] = np.bincount(flat_labels, segmentation[..., idx].flatten())

    # Every superpixel label maps to the colour of its most frequent class.
    label_map = dict(enumerate(class_colors[label_counts.argmax(axis=-1)]))
    out_mask = map_labels(sp_labels, label_map)
    print("Mapping:", time.time() - start, "seconds")
    return out_mask

# Running the experiments

In [None]:
def generate_data(method_name, n_sp, compactness, used_method=""):
    """Smooth every PNG of the selected source directories with superpixels.

    The source directories are read from the module-level ``data_path`` list.
    For each one, an output directory named
    ``{used_method}_{method_name}_{n_sp}_{compactness}`` is created next to it
    (in the same parent directory) and each ``.png`` file is run through
    ``slic_postprocessing`` -- with the file serving as both the image and the
    mask -- and written there under the same file name.

    Args:
        method_name (str): Name of the segmentation method.
        n_sp (int): Number of superpixels.
        compactness (float): Compactness parameter.
        used_method (str, optional): Name of the source directory kind to
            process; it is also the prefix of the output directory name.
            When empty, every entry of ``data_path`` is processed.
            Defaults to "".

    Returns:
        None
    """
    dir_name = f"{used_method}_{method_name}_{n_sp}_{compactness}"

    # Only process directories that belong to `used_method`. Iterating over
    # all of `data_path` made every run write the images of all source kinds
    # into the same output folder, where they overwrote one another.
    if used_method:
        sources = [
            env for env in data_path
            if os.path.basename(os.path.normpath(env)) == used_method
        ]
    else:
        sources = list(data_path)

    for env in tqdm(sources):
        # Use the real parent directory instead of slicing a fixed number of
        # characters off the path, which only worked for specific name lengths.
        parent_dir = os.path.dirname(os.path.normpath(env))
        smoothed_dir = os.path.join(parent_dir, dir_name)
        os.makedirs(smoothed_dir, exist_ok=True)

        for image_name in os.listdir(env):
            if not image_name.endswith(".png"):
                continue
            image_path = os.path.join(env, image_name)
            # The same file is passed as the image and as the mask.
            out_img = slic_postprocessing(image_path, image_path,
                                          n_sp=n_sp, method=method_name,
                                          use_optimized_label_selection=True,
                                          compactness=compactness)
            write_image(os.path.join(smoothed_dir, image_name), out_img)


In [None]:
# Collect the segmentation directories to process. Entries three levels below
# `dataset_dir` (dataset/<environment>/<state>/<entry>) are kept when their
# name is one of `used_method`.
dataset_dir = "./semantic_matching"
data_path = []
# The stray trailing colon after this list was a SyntaxError and is removed.
used_method = ["supix", "transf", "seman"]
environments_name = glob.glob(f"{dataset_dir}/*")
for env_name in environments_name:
    env_states = glob.glob(f"{env_name}/*")
    for env_state in env_states:
        env_views = glob.glob(f"{env_state}/*")
        for segms in env_views:
            # basename() also works when glob returns OS-specific separators.
            if os.path.basename(segms) in used_method:
                data_path.append(segms)
print(data_path)

In [None]:
# Run the fast_slic smoothing for the first entry of `used_method`.
method_name, n_sp, compactness = "fast_slic", 100, 10
generate_data(method_name, n_sp, compactness, used_method=used_method[0])

In [None]:
# Run the fast_slic smoothing for the second entry of `used_method`.
method_name, n_sp, compactness = "fast_slic", 100, 10
generate_data(method_name, n_sp, compactness, used_method=used_method[1])

In [None]:
# Run the fast_slic smoothing for the third entry of `used_method`.
method_name, n_sp, compactness = "fast_slic", 100, 10
generate_data(method_name, n_sp, compactness, used_method=used_method[2])