In [None]:
from skimage import segmentation, io
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.metrics import calinski_harabasz_score, silhouette_score
import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np
import pandas as pd
from utils.utils import generate_synthetic_image, plot_dendrogram
from utils.image_manager import ImagesManager
from utils.mask_manager import MaskManager
from utils.features_manager import FeaturesManager
from utils.clusteriser import Clusteriser
import os
import warnings
plt.rcParams["axes.grid"] = False
from sklearn.preprocessing import scale
from quickshift_tabular.quickshift_tabular import quickshift_tab

In [None]:
# Load the trained Keras classifier saved at ./model/model.h5.
model = keras.models.load_model(filepath='./model/model.h5')

In [None]:
# Folder of input images and the file extension handed to ImagesManager.
path_to_images, images_extension = './data/debug', 'jpg'

In [None]:
def _relabel_mask(mask, original_labels, new_labels, fill_value=-1):
    """Return a copy of ``mask`` where each original superpixel id is replaced by a new label.

    Parameters
    ----------
    mask : np.ndarray
        Label image; each pixel holds the id of the superpixel it belongs to.
    original_labels : sequence
        Superpixel ids to translate (one per row of the feature table).
    new_labels : sequence
        Replacement label for each entry of ``original_labels`` (same order).
    fill_value : float, optional
        Label given to pixels whose id is not in ``original_labels`` (e.g. superpixels
        dropped during feature extraction). An explicit sentinel replaces the
        uninitialised memory that ``np.empty`` used to leave in those pixels.
        NOTE(review): assumes quickshift_tab never emits -1 as a real label — confirm.

    Returns
    -------
    np.ndarray
        Float array with ``mask.shape`` holding the new labels.
    """
    relabelled = np.full(mask.shape, fill_value, dtype=float)
    for original_value, new_value in zip(original_labels, new_labels):
        relabelled[mask == original_value] = new_value
    return relabelled


def _show_segmentation(img, segmentation, labels):
    """Plot the image, its segmentation, and the image restricted to each segment in turn."""
    plt.imshow(img)
    plt.show()
    plt.imshow(segmentation)
    plt.show()
    for sp_value in np.unique(labels):
        # Add a trailing axis so the 2-D boolean mask broadcasts over the image's last
        # axis (assumes img is (H, W, C) — TODO confirm for num_channels != 3).
        keep = (segmentation == sp_value)[..., np.newaxis]
        img_occluded = img * keep
        plt.imshow(img_occluded, vmin=np.min(img_occluded), vmax=np.max(img_occluded))
        plt.show()


def get_superpixels_information(path_to_images, images_extension,
                                normalise_features=True, num_channels=3,
                                kernel_size=4, max_dist=20, ratio=.2,
                                kernel_size_tab=1, max_dist_tab=0.3, ratio_tab=0.5,
                                verbose=True, num_working_examples=None):
    """Build superpixel features for every image with a two-stage segmentation.

    For each image the pipeline does the following:
    1. Compute a first mask with ``MaskManager.get_mask``.
    2. Build one feature row per superpixel with ``FeaturesManager``.
    3. Merge superpixels by running ``quickshift_tab`` on those rows, with the spatial
       columns placed first.
    4. Recompute the features on the merged mask.

    Parameters
    ----------
    path_to_images, images_extension : str
        Passed to ``ImagesManager`` to locate the images.
    normalise_features : bool
        Forwarded to both ``FeaturesManager`` calls. This argument was previously ignored.
    num_channels : int
        Forwarded to ``FeaturesManager``.
    kernel_size, max_dist, ratio
        Parameters of the first-stage ``MaskManager.get_mask``.
    kernel_size_tab, max_dist_tab, ratio_tab
        Parameters of ``quickshift_tab``.
    verbose : bool
        Print superpixel counts and plot every segment.
    num_working_examples : int or None
        If given, only the first ``num_working_examples`` images are processed.

    Returns
    -------
    df_features : pandas.DataFrame
        Concatenation of the per-image feature tables on the merged masks. It is empty
        when no image is processed. Previously the function returned after the first
        image because ``return`` was inside the loop.
    deleted_superpixels : dict
        Maps image name to the ``deleted_superpixels`` value returned by the second
        ``FeaturesManager.get_features`` call for that image.
    dict_masks : dict
        Maps image name to the merged label mask.
    list_num_superpixels : list of int
        Number of merged superpixels per image, in processing order.
    """
    images_manager = ImagesManager(location=path_to_images, images_extension=images_extension)
    images_names = images_manager.get_images_names()
    if num_working_examples is not None:
        # Slicing already clamps to len(images_names).
        images_names = images_names[:num_working_examples]

    spatial_columns = ['mass_center0', 'mass_center1']
    feature_frames = []
    deleted_superpixels = {}
    dict_masks = {}
    list_num_superpixels = []

    for image_name in images_names:
        img = images_manager.get_image(image_name)

        # Stage 1: initial superpixel mask.
        mask = MaskManager().get_mask(img=img, kernel_size=kernel_size, max_dist=max_dist, ratio=ratio)
        if verbose:
            print('---------------------------------------------------------')
            print('Working on image:', image_name)
            print('\tNum superpixels before tabular quickshift:', len(np.unique(mask)))

        features_manager = FeaturesManager(img=img, image_name=image_name, mask=mask,
                                           normalise_features=normalise_features,
                                           num_channels=num_channels)
        df_superpixels, _ = features_manager.get_features()

        # Stage 2: merge superpixels with tabular quickshift. The spatial columns go first.
        # NOTE(review): presumably quickshift_tab treats the leading columns as coordinates — confirm.
        other_columns = [col for col in df_superpixels.columns if col not in spatial_columns]
        df_superpixels = df_superpixels.reindex(columns=spatial_columns + other_columns)
        # The first element of each index tuple is the superpixel id stored in `mask`.
        mask_list = [idx[0] for idx in df_superpixels.index.tolist()]

        merged_labels = quickshift_tab(X=df_superpixels.to_numpy(), ratio=ratio_tab,
                                       kernel_size=kernel_size_tab, max_dist=max_dist_tab)
        num_merged = len(np.unique(merged_labels))
        if verbose:
            print('\tNum superpixels after tabular quickshift:', num_merged)

        mask_tabular_quickshift = _relabel_mask(mask, mask_list, merged_labels)
        dict_masks[image_name] = mask_tabular_quickshift
        list_num_superpixels.append(num_merged)

        if verbose:
            _show_segmentation(img, mask_tabular_quickshift, merged_labels)

        # Recompute the features on the merged segmentation.
        features_manager = FeaturesManager(img=img, image_name=image_name,
                                           mask=mask_tabular_quickshift,
                                           normalise_features=normalise_features,
                                           num_channels=num_channels)
        df_image, deleted = features_manager.get_features()
        feature_frames.append(df_image)
        deleted_superpixels[image_name] = deleted

    # NOTE(review): assumes each per-image index includes the image name, so the
    # concatenated index stays unique — confirm against FeaturesManager.
    df_features = pd.concat(feature_frames) if feature_frames else pd.DataFrame()
    return df_features, deleted_superpixels, dict_masks, list_num_superpixels


In [None]:
# Run the superpixel pipeline on the first image only (num_working_examples=1).
data = get_superpixels_information(
    path_to_images=path_to_images,
    images_extension=images_extension,
    num_working_examples=1,
    verbose=True,
    num_channels=3,
    normalise_features=True,
    kernel_size=4, max_dist=20, ratio=.2,                # first-stage MaskManager mask
    kernel_size_tab=1, max_dist_tab=0.3, ratio_tab=0.5,  # quickshift_tab merge step
)

In [None]:
# Unpack the 4-tuple returned by get_superpixels_information.
df_features, deleted_superpixels, dict_masks, list_num_superpixels = data

In [None]:
# Cluster on appearance features only: remove the two spatial (mass-centre) columns.
df_features_drop = df_features.drop(['mass_center0', 'mass_center1'], axis='columns')

In [None]:
class Clusteriser:
    """Complete-linkage agglomerative clustering (L1 distance) of feature rows, scored by silhouette.

    NOTE(review): this notebook-local class shadows the ``Clusteriser`` imported from
    ``utils.clusteriser`` in the first cell.
    """

    def __init__(self, num_superpixels, df, verbose):
        """
        Parameters
        ----------
        num_superpixels : sequence of numbers
            Superpixel counts per image. Their rounded mean caps the cluster counts that
            ``get_mutiple_clusters`` tries.
        df : pandas.DataFrame or array-like
            Feature matrix with one row per sample to cluster.
        verbose : bool
            Print the score of every candidate cluster count.
        """
        self.num_superpixels = num_superpixels
        self.df = df
        self.verbose = verbose

    @staticmethod
    def _make_model(n_cluster):
        """Build a complete-linkage, L1-distance AgglomerativeClustering with ``n_cluster`` clusters."""
        try:
            # scikit-learn >= 1.2 names the distance parameter `metric`. `affinity` was
            # deprecated in 1.2 and removed in 1.4.
            return AgglomerativeClustering(n_clusters=n_cluster, linkage='complete', metric='l1')
        except TypeError:
            # Older scikit-learn only accepts `affinity`.
            return AgglomerativeClustering(n_clusters=n_cluster, linkage='complete', affinity='l1')

    def get_fixed_clusters(self, n_cluster):
        """Cluster ``self.df`` into ``n_cluster`` groups.

        Returns
        -------
        (float, np.ndarray)
            Silhouette score of the labelling, and the cluster label of each row.
        """
        labels = self._make_model(n_cluster).fit(self.df).labels_
        # NOTE(review): clustering uses L1 distances but the score uses Euclidean —
        # confirm this mismatch is intended.
        metric = silhouette_score(self.df, labels, metric='euclidean')
        return metric, labels

    def get_mutiple_clusters(self, limit=None):
        """Score every cluster count from 2 up to, but excluding, an upper bound.

        The upper bound is the rounded mean of ``num_superpixels``, lowered to ``limit``
        when ``limit`` is smaller.

        Returns
        -------
        (list of int, list of float)
            The cluster counts tried and their silhouette scores.

        Raises
        ------
        ValueError
            If ``num_superpixels`` is empty.
        """
        if np.size(self.num_superpixels) == 0:
            raise ValueError('num_superpixels is empty; cannot derive a cluster range')

        # int() guarantees a Python int for range(), whatever numpy's round() returns.
        mean_superpixels = int(round(np.mean(self.num_superpixels)))
        total_clusters = mean_superpixels if limit is None else min(limit, mean_superpixels)

        n_clusters_list = []
        metric_list = []
        for n_cluster in range(2, total_clusters):
            metric, _ = self.get_fixed_clusters(n_cluster)
            metric_list.append(metric)
            n_clusters_list.append(n_cluster)
            if self.verbose:
                print('n_cluster:', n_cluster, 'metric:', metric)

        return n_clusters_list, metric_list

    # Correctly spelled alias; the original (misspelled) name is kept for existing callers.
    get_multiple_clusters = get_mutiple_clusters

    def get_best_clusterisation(self, limit=None):
        """Return the labels of the cluster count with the highest silhouette score.

        Raises
        ------
        ValueError
            If the candidate range is empty (the upper bound is 2 or less).
        """
        n_clusters_list, metric_list = self.get_mutiple_clusters(limit=limit)
        if not n_clusters_list:
            raise ValueError('no candidate cluster counts to compare')
        best_n_clusters = n_clusters_list[int(np.argmax(metric_list))]
        _, labels = self.get_fixed_clusters(best_n_clusters)
        return np.array(labels)

In [None]:
# Score cluster counts from 2 up to (excluding) min(20, rounded mean superpixels per image).
cluster = Clusteriser(df=df_features_drop, num_superpixels=list_num_superpixels, verbose=True)
n_clusters, metric_list = cluster.get_mutiple_clusters(20)

In [None]:
# Silhouette score (from Clusteriser.get_fixed_clusters) for each candidate cluster count.
fig, ax = plt.subplots()
ax.scatter(n_clusters, metric_list)
ax.set(xlabel='number of clusters', ylabel='silhouette score',
       title='Silhouette score per number of clusters')
plt.show()

In [None]:
# Number of clusters to keep. NOTE(review): hard-coded; confirm against the score plot above.
# The best-scoring count is n_clusters[np.argmax(metric_list)].
N_CLUSTERS = 6
# new_segmentation holds one cluster label per row of df_features_drop.
metric, new_segmentation = cluster.get_fixed_clusters(n_cluster=N_CLUSTERS)
print(new_segmentation)

In [None]:
# Preview the first rows of the clustering features.
df_features_drop.head(n=5)

In [None]:
for cl in new_segmentation:
    # Collect all (image, superpixel) pairs.
    # For each image, paint the superpixels that are activated.
    # NOTE(review): no loop body is visible here; if the cell really ends at this point
    # it raises a SyntaxError ("expected an indented block").