In [None]:
from skimage import segmentation, io
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.metrics import calinski_harabasz_score, silhouette_score
import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np
import pandas as pd
from utils.utils import generate_synthetic_image, plot_dendrogram
from utils.image_manager import ImagesManager
from utils.mask_manager import MaskManager
from utils.features_manager import FeaturesManager
import os
import warnings
plt.rcParams["axes.grid"] = False
from sklearn.preprocessing import scale

In [None]:
# Load the saved Keras model stored at ./model/model.h5 (path is relative to the working directory).
model = keras.models.load_model('./model/model.h5')

In [None]:
# Alternative dataset: path_to_images = './data/eval_xai_method/dog_resized'
path_to_images = './data/debug'

# Every directory entry except the macOS Finder metadata file.
all_images = list(filter(lambda entry: entry != '.DS_Store', os.listdir(path_to_images)))

In [None]:
# Display the file names found in path_to_images.
all_images

In [None]:
# Total number of images found; used as the sample size, so every image ends up selected.
num_images = len(all_images)

In [None]:
# Sampling every image without replacement is a shuffle. Later cells only look at
# the first entries of this array, so the order is drawn from a seeded generator to
# make the notebook reproducible under Restart & Run All.
selection_rng = np.random.default_rng(42)
images_chosen = selection_rng.choice(all_images, size=num_images, replace=False)

In [None]:
# Display the shuffled selection of image names.
images_chosen

In [None]:
# Run the model on every selected image, keeping each raw prediction and
# showing the image next to its printed prediction.
prediction_list = []
for image_name in images_chosen:
    image_path = os.path.join(path_to_images, image_name)
    img = io.imread(image_path)
    # Divide by 255 (presumably 8-bit images, giving values in [0, 1] — confirm).
    img_norm = img / 255
    # Prepend a batch axis so the model receives a batch of one image.
    img_batch = img_norm[np.newaxis, ...]
    pred = model.predict(img_batch)
    prediction_list.append(pred)
    io.imshow(img)
    plt.show()
    print(pred)

In [None]:
# Explore quickshift parameters that scale with the image size, on the first selected image.
ratio_pixels_gaussian_kernel = 0.0075 # Corresponds to sigma=4 for a 200x200 image
# presumably the 95% quantile of a chi-square with 2 degrees of freedom (~5.99), rounded — confirm
chi2_95 = 6

for image_name in images_chosen[:1]:
    image_path = os.path.join(path_to_images, image_name)
    img = io.imread(image_path)
    img_norm = img/255
    # Only the product of the two spatial dimensions is used below.
    x_pixels, y_pixels, _ = img_norm.shape
    area_total = x_pixels*y_pixels
    # Linear scale factor relative to a 200x200 image (s == 1 when the area is 200*200).
    s = np.sqrt(area_total)/200
    max_dist = 10*s
    # Evaluates as (1/2)*s, i.e. 0.5*s.
    # NOTE(review): this exceeds 1 once s > 2; quickshift documents `ratio` as a value
    # between 0 and 1 — confirm whether 1/(2*s) or a clamp was intended.
    ratio = 1/2*s
    # sigma^2 = area * ratio_pixels_gaussian_kernel / (chi2_95 * pi)
    sigma_sq = area_total*ratio_pixels_gaussian_kernel/(chi2_95*np.pi)
    sigma = np.sqrt(sigma_sq)
    superpixels = segmentation.quickshift(img_norm, kernel_size=sigma, max_dist=max_dist, ratio=ratio)
    num_superpixels_image = np.unique(superpixels).shape[0]
    img_batch = np.expand_dims(img_norm, axis=0)
    # NOTE(review): `pred` is not used anywhere else in this cell.
    pred = model.predict(img_batch)
    print('num_superpixels:', num_superpixels_image)
    print('image_shape:', img_norm.shape)
    print('sigma:', sigma)
    print('max_dist:', max_dist)
    print('ratio:', ratio)
    # Draw the superpixel boundaries in red on top of the normalised image.
    io.imshow(segmentation.mark_boundaries(img_norm, superpixels, color=(1,0,0), mode='inner'))
    plt.show()

In [None]:
def get_superpixels_information(path_to_images, images_extension='jpg',
                                normalise_features=False, num_channels=3,
                                kernel_size=4, max_dist=10, ratio=.3):
    """Segment every image in a directory and collect per-superpixel features.

    For each image name reported by ``ImagesManager`` a mask is obtained from
    ``MaskManager.get_mask`` and passed to ``FeaturesManager.get_features``.

    Parameters
    ----------
    path_to_images : str
        Directory handed to the three manager classes as ``location``.
    images_extension : str
        Extension forwarded to ``ImagesManager``.
    normalise_features : bool
        Forwarded to ``FeaturesManager``. Previously this argument was ignored
        and ``True`` was always used; callers that relied on that must now pass
        ``normalise_features=True`` explicitly.
    num_channels : int
        Forwarded to ``FeaturesManager``.
    kernel_size, max_dist, ratio
        Forwarded unchanged to ``MaskManager.get_mask`` (presumably quickshift
        parameters, judging by their names — confirm in ``MaskManager``).

    Returns
    -------
    df_all : pandas.DataFrame
        Concatenation of the feature frames of all images.
    dict_masks : dict
        Image name -> mask.
    dict_deleted_superpixels : dict
        Image name -> deleted superpixels, only for images that have any.
    num_superpixels : list of int
        Number of distinct labels in each image's mask, in processing order.

    Raises
    ------
    ValueError
        If no image was found, so there is nothing to concatenate.
    """
    images_manager = ImagesManager(location=path_to_images, images_extension=images_extension)
    images_names = images_manager.get_images_names()

    df_features_list = []
    dict_masks = {}
    dict_deleted_superpixels = {}
    num_superpixels = []

    for image_name in images_names:
        masks_manager = MaskManager(location=path_to_images, image_name=image_name)
        mask = masks_manager.get_mask(kernel_size=kernel_size, max_dist=max_dist, ratio=ratio)
        dict_masks[image_name] = mask
        num_superpixels.append(len(np.unique(mask)))

        # Honour the caller's choice instead of a hard-coded True.
        features_manager = FeaturesManager(location=path_to_images, image_name=image_name,
                                           mask=mask, normalise_features=normalise_features,
                                           num_channels=num_channels)
        df_features, deleted_superpixels = features_manager.get_features()
        df_features_list.append(df_features)

        if len(deleted_superpixels) > 0:
            dict_deleted_superpixels[image_name] = deleted_superpixels

    if not df_features_list:
        # pd.concat would raise the same exception type with a less helpful message.
        raise ValueError(
            "No images with extension '{}' found in '{}'".format(images_extension, path_to_images)
        )

    df_all = pd.concat(df_features_list)

    # True if any cell of the combined frame is NaN.
    if df_all.isna().any().any():
        warnings.warn('There are null values in the dataset')

    return df_all, dict_masks, dict_deleted_superpixels, num_superpixels

In [None]:
# Extract the superpixel features from the same directory the rest of the notebook
# reads images from, so that switching `path_to_images` cannot leave the feature
# table and the displayed images pointing at different datasets.
df_features, masks, superpixels_few_pixels, num_superpixels = get_superpixels_information(
    path_to_images,
    normalise_features=True
)

In [None]:
# Feature table without the two mass-centre columns.
df_features_drop = df_features.drop(['mass_center0', 'mass_center1'], axis=1)
# Names of the index levels, kept for reference.
df_indexes_name = df_features_drop.index.names
# Same table with the index levels turned into ordinary columns.
df_features_no_indexes = df_features_drop.reset_index()

In [None]:
# Preview the first rows of the feature table without the mass-centre columns.
df_features_drop.head()

In [None]:
class Clusteriser:
    """Complete-linkage agglomerative clustering with a silhouette-based choice of k.

    Parameters
    ----------
    num_superpixels : sequence of int
        Superpixel counts; their rounded mean bounds the number of clusters tried.
    df : array-like of shape (n_samples, n_features)
        Feature table that is clustered.
    """

    def __init__(self, num_superpixels, df):
        self.num_superpixels = num_superpixels
        self.df = df

    @staticmethod
    def _build_model(n_cluster):
        """Create the complete-linkage / L1 model on both old and new scikit-learn."""
        try:
            # Newer scikit-learn releases call the distance argument `metric`.
            return AgglomerativeClustering(n_clusters=n_cluster, linkage='complete', metric='l1')
        except TypeError:
            # Older releases only accept the (since removed) `affinity` keyword.
            return AgglomerativeClustering(n_clusters=n_cluster, linkage='complete', affinity='l1')

    def get_fixed_clusters(self, n_cluster):
        """Cluster ``self.df`` into ``n_cluster`` groups.

        Returns
        -------
        metric : float
            Silhouette score of the labelling, computed with Euclidean distances
            (note: the clustering itself uses L1 distances).
        labels : numpy.ndarray
            Cluster label of every row of ``self.df``.
        """
        model = self._build_model(n_cluster)
        labels = model.fit(self.df).labels_
        metric = silhouette_score(self.df, labels, metric='euclidean')

        return metric, labels

    def get_mutiple_clusters(self):
        """Score every cluster count from 2 up to the rounded mean superpixel count (exclusive).

        The upper bound is additionally capped at the number of samples because
        the silhouette score needs at most ``n_samples - 1`` distinct labels.

        Returns
        -------
        n_clusters_list : list of int
        metric_list : list of float
            Silhouette score for the matching entry of ``n_clusters_list``.
        """
        mean_superpixels = int(round(np.mean(self.num_superpixels)))
        upper_bound = min(mean_superpixels, len(self.df))
        n_clusters_list = []
        metric_list = []

        for n_cluster in range(2, upper_bound):
            metric, _ = self.get_fixed_clusters(n_cluster)
            metric_list.append(metric)
            n_clusters_list.append(n_cluster)
            print('n_cluster:', n_cluster, 'metric:', metric)

        return n_clusters_list, metric_list

    def get_best_clusterisation(self):
        """Return the labels obtained with the best-scoring number of clusters.

        Raises
        ------
        ValueError
            If no candidate number of clusters could be evaluated.
        """
        n_clusters_list, metric_list = self.get_mutiple_clusters()
        if len(metric_list) == 0:
            raise ValueError('No candidate number of clusters to evaluate '
                             '(mean number of superpixels or sample count is too small)')

        idx_best_metric = int(np.argmax(metric_list))
        best_n_clusters = n_clusters_list[idx_best_metric]

        # Re-fit with the selected cluster count (previously a hard-coded 10 was used here).
        _, labels = self.get_fixed_clusters(best_n_clusters)
        return np.array(labels)

In [None]:
# Cluster all superpixels together and record, row by row, which cluster each
# (image, superpixel) pair was assigned to.
cluster = Clusteriser(num_superpixels=num_superpixels, df=df_features_drop)
new_segmentation = cluster.get_best_clusterisation()
old_segmentation = df_features_no_indexes['num_superpixel'].to_numpy()
images_names = df_features_no_indexes['image_name']
df_mapping = pd.DataFrame(dict(
    new_segmentation=new_segmentation,
    old_segmentation=old_segmentation,
    image_name=images_names,
))

In [None]:
# Cluster labels actually present in the new segmentation.
np.unique(new_segmentation)

In [None]:
def get_new_masks(images_names, old_masks, df_mapping, unassigned_label=-1):
    """Relabel superpixel masks with the cluster each superpixel was assigned to.

    Parameters
    ----------
    images_names : iterable of str
        Images to process; each must be a key of ``old_masks``.
    old_masks : dict
        Image name -> integer array of superpixel labels.
    df_mapping : pandas.DataFrame
        One row per superpixel with columns ``image_name``,
        ``old_segmentation`` (superpixel label) and ``new_segmentation``
        (cluster label).
    unassigned_label : int, optional
        Value written to pixels whose superpixel has no row in ``df_mapping``.
        Such pixels used to be left as uninitialised memory.

    Returns
    -------
    dict
        Image name -> integer array with the shape of the old mask, holding
        cluster labels.
    """
    masks = dict()

    for image_name in images_names:
        old_mask = old_masks[image_name]
        # Start from a known value so unmapped pixels are identifiable.
        new_mask = np.full(old_mask.shape, unassigned_label, dtype=int)

        # Positional boolean selection: independent of the frame's index labels.
        is_current_image = (df_mapping.image_name == image_name).to_numpy()
        df_image = df_mapping[is_current_image]

        for cl, old_ids in df_image.groupby('new_segmentation')['old_segmentation']:
            new_mask[np.isin(old_mask, old_ids.to_numpy())] = cl

        masks[image_name] = new_mask
    return masks


def print_image_mask(masks, images_names, path_to_images):
    """Show every image followed by its mask.

    Parameters
    ----------
    masks : dict
        Image name -> mask array.
    images_names : iterable of str
        Names to display; each is looked up in ``masks`` and read from
        ``path_to_images``.
    path_to_images : str
        Directory containing the image files.
    """
    for name in images_names:
        label_mask = masks[name]
        picture = io.imread(os.path.join(path_to_images, name))

        io.imshow(picture)
        plt.show()

        # Stretch the colour scale over the labels present in this mask.
        lowest, highest = np.min(label_mask), np.max(label_mask)
        plt.imshow(label_mask, vmin=lowest, vmax=highest)
        plt.show()

In [None]:
# Build the cluster-labelled mask of every image that appears in the feature table.
new_segmentation_dic = get_new_masks(
    np.unique(df_features_no_indexes['image_name']),
    masks,
    df_mapping,
)

In [None]:
# Show each original image followed by its cluster-labelled mask.
print_image_mask(new_segmentation_dic, masks.keys(), path_to_images)

In [None]:
# For every cluster, show each image that contributes superpixels to it with all
# pixels outside those superpixels zeroed out.
unique_clusters = np.unique(new_segmentation)
for cluster_selected in unique_clusters:
    print('cluster:', cluster_selected)
    # Positional row selection: relies on new_segmentation being in the same row
    # order as df_features_no_indexes.
    idx = np.where(new_segmentation == cluster_selected)
    df_filter = df_features_no_indexes.iloc[idx]
    # NOTE(review): this rebinds the global `images_names`, which an earlier cell
    # set to the full image_name column.
    images_names = np.unique(df_features_no_indexes.image_name.iloc[idx])
    
    for image_name in images_names:
        idx_image_name = np.where(df_filter.image_name == image_name)
        # Superpixel numbers of this image that belong to the current cluster.
        image_superpixels = np.unique(df_filter.num_superpixel.iloc[idx_image_name])
        # NOTE(review): the next two results are not used later in this cell; the
        # isin filter also matches on superpixel number only, across every image
        # in df_filter.
        idx_image_name_superpixels = np.where(df_filter.num_superpixel.isin(image_superpixels))
        df_image_superpixels = df_filter.iloc[idx_image_name_superpixels]
        current_img_path = os.path.join(path_to_images, image_name)
        current_img = io.imread(current_img_path)
        current_mask = masks[image_name]
        # Boolean mask of the pixels lying in the selected superpixels.
        new_mask = np.isin(current_mask, image_superpixels)
        # Trailing axis so the mask broadcasts over the channels
        # (assumes current_img is height x width x channels — TODO confirm).
        new_mask = new_mask[..., np.newaxis]
        new_image = current_img * new_mask 
        #io.imshow(new_image)
        #plt.show()
        plt.imshow(new_image, vmin=np.min(new_image), vmax=np.max(new_image))
        plt.show()
    print('--------------------------------------------')

# Trash