In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tensorflow import keras
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image

from quickshift_tabular.quickshift_tabular import quickshift_tab
from utils.features_manager import FeaturesManager
from utils.image_manager import ImagesManager
from utils.mask_manager import MaskManager

In [None]:
# Folder holding the images to process and the file extension used to list them.
images_extension = 'jpg'
path_to_images = './data/debug'

# Trained classifier, loaded from disk. Note: it is not referenced by the
# superpixel-extraction cells of this notebook section.
classifier = keras.models.load_model('./model/model.h5')

In [None]:
# Pretrained model: ResNet50 with ImageNet weights (classification head kept),
# then truncated so `model` outputs the activations of the penultimate layer.
# This truncated network is the optional feature extractor used below.
pretrained_model = ResNet50(include_top=True, weights='imagenet')
output_interest = pretrained_model.layers[-2].output
model = keras.Model(
    inputs=pretrained_model.inputs,
    outputs=output_interest,
)

In [None]:
# Freeze the feature extractor: its weights must never be updated.
# Setting `trainable` on a Keras model propagates to every sub-layer, so this
# freezes all layers and also makes `model.trainable` itself report False.
model.trainable = False

In [None]:
def _remap_mask(mask, original_labels, new_labels):
    """Relabel a superpixel mask with the clusters found by tabular quickshift.

    Args:
        mask: array of superpixel labels produced by the image quickshift.
        original_labels: superpixel labels, in the same order as `new_labels`.
        new_labels: cluster label assigned to each entry of `original_labels`.

    Returns:
        A float64 array with `mask`'s shape. Every pixel whose original label is
        listed in `original_labels` holds its new label; pixels whose label was
        not mapped are NaN (explicitly flagged instead of uninitialised memory).
    """
    new_mask = np.full(mask.shape, np.nan)
    for original_value, new_value in zip(original_labels, new_labels):
        new_mask[mask == original_value] = new_value
    return new_mask


def _plot_superpixels(img, image_name, new_mask, cluster_labels, want_to_plot_clusters):
    """Show the image, its relabelled mask and, optionally, each cluster on its own."""
    plt.imshow(img)
    plt.title(f'Image: {image_name}')
    plt.show()
    plt.imshow(new_mask)
    plt.title('Superpixels after tabular quickshift')
    plt.show()

    if want_to_plot_clusters:
        for sp in cluster_labels:
            # Keep only the pixels belonging to cluster `sp`; everything else becomes 0.
            img_occluded = img * (new_mask == sp)[..., np.newaxis]
            plt.imshow(img_occluded)
            plt.title(f'Cluster {sp}')
            plt.show()


def get_superpixels_information(path_to_images, images_extension,
                                normalise_features=True, num_channels=3,
                                kernel_size=4, max_dist=20, ratio=.2,
                                kernel_size_tab=1, max_dist_tab=0.3, ratio_tab=0.5,
                                num_working_examples=None,
                                use_model_to_extract_features=False,
                                model_feature_extractor=None, target_shape=None,
                                verbose=True, want_to_plot_clusters=False,
                                random_seed=42):
    """Segment images into superpixels, then merge similar ones with tabular quickshift.

    For every selected image:
      1. compute a superpixel mask (MaskManager.get_mask, image quickshift);
      2. compute one feature row per superpixel (FeaturesManager.get_features);
      3. cluster the feature rows with `quickshift_tab`;
      4. replace each superpixel id by its cluster id in the features and the mask.

    Args:
        path_to_images, images_extension: forwarded to ImagesManager to list/read images.
        normalise_features, num_channels, use_model_to_extract_features,
        model_feature_extractor, target_shape: forwarded to FeaturesManager.get_features.
        kernel_size, max_dist, ratio: image-quickshift parameters (MaskManager.get_mask).
        kernel_size_tab, max_dist_tab, ratio_tab: tabular-quickshift parameters.
        num_working_examples: if not None, only the first N image names are processed.
        verbose: print progress and plot image / mask for each image.
        want_to_plot_clusters: with `verbose`, also plot every cluster separately.
        random_seed: forwarded to both quickshift steps.

    Returns:
        Tuple `(df, list_num_superpixels, dict_masks)`:
          - df: concatenation of the per-image feature frames, with the
            `num_superpixel` field replaced by the tabular-quickshift cluster id;
          - list_num_superpixels: number of distinct clusters per processed image;
          - dict_masks: image name -> float mask of cluster ids (NaN where a
            pixel's superpixel had no feature row).

    Raises:
        ValueError: if no image is selected (empty folder or num_working_examples == 0).
    """
    images_manager = ImagesManager(location=path_to_images, images_extension=images_extension)
    images_names = images_manager.get_images_names()

    # Slicing already clamps to the number of available images.
    if num_working_examples is not None:
        images_names = images_names[:num_working_examples]

    if len(images_names) == 0:
        raise ValueError(f'No images selected from {path_to_images!r} '
                         f'(extension {images_extension!r}, '
                         f'num_working_examples={num_working_examples}).')

    dict_masks = {}
    list_num_superpixels = []
    list_df = []

    for image_name in images_names:
        img = images_manager.get_image(image_name)

        # Superpixels from the image quickshift.
        masks_manager = MaskManager()
        mask = masks_manager.get_mask(img=img, kernel_size=kernel_size, max_dist=max_dist,
                                      ratio=ratio, random_seed=random_seed)
        if verbose:
            print('---------------------------------------------------------')
            print('Working on image:', image_name)
            print('\tNum superpixels before tabular quickshift:', len(np.unique(mask)))

        # One feature row per superpixel.
        features_manager = FeaturesManager(img=img, image_name=image_name, mask=mask)
        df_features, list_total_zeros = features_manager.get_features(
            use_model_to_extract_features=use_model_to_extract_features,
            model_feature_extractor=model_feature_extractor,
            target_shape=target_shape,
            normalise_features=normalise_features,
            num_channels=num_channels)
        # Guard: np.mean of an empty sequence is NaN and int(NaN) raises.
        if verbose and use_model_to_extract_features and np.size(list_total_zeros) > 0:
            print('\tTotal features equal to 0:', int(np.mean(list_total_zeros)))

        # Superpixel id of each feature row (second element of the index tuple).
        mask_list = [x[1] for x in df_features.index.tolist()]
        mask_tabular_quickshift_list = quickshift_tab(X=df_features.to_numpy(),
                                                      superpixels_order=np.array(mask_list),
                                                      mask=mask,
                                                      ratio=ratio_tab,
                                                      kernel_size=kernel_size_tab,
                                                      max_dist=max_dist_tab,
                                                      random_seed=random_seed)
        cluster_labels = np.unique(mask_tabular_quickshift_list)
        if verbose:
            print('\tNum superpixels after tabular quickshift:', len(cluster_labels))

        # Replace the superpixel id by its cluster id while keeping the original
        # index layout. Done out of place so the frame returned by
        # FeaturesManager is not mutated.
        index_names = list(df_features.index.names)
        df_features = (df_features.reset_index()
                       .assign(num_superpixel=mask_tabular_quickshift_list)
                       .set_index(index_names))

        mask_tabular_quickshift = _remap_mask(mask, mask_list, mask_tabular_quickshift_list)
        dict_masks[image_name] = mask_tabular_quickshift
        list_num_superpixels.append(len(cluster_labels))
        list_df.append(df_features)

        if verbose:
            _plot_superpixels(img, image_name, mask_tabular_quickshift,
                              cluster_labels, want_to_plot_clusters)

    # Single dataframe with the features of every processed image.
    df = pd.concat(list_df)

    return df, list_num_superpixels, dict_masks
    

In [None]:
data = get_superpixels_information(
    path_to_images=path_to_images,
    images_extension=images_extension,
    normalise_features=True,
    num_channels=3,
    # Image quickshift parameters
    kernel_size=4,
    max_dist=20,
    ratio=.2,
    # Tabular quickshift parameters
    kernel_size_tab=1,
    max_dist_tab=12,
    ratio_tab=0.5,
    # Number of images to process
    num_working_examples=2,
    # Manual features (False) or features extracted by `model` (True)
    use_model_to_extract_features=False,
    model_feature_extractor=model,
    # When the model extracts the features, each patch is reshaped to this size
    target_shape=(224, 224),
    # Display options
    verbose=True,
    want_to_plot_clusters=True,
)