# Image Segmentation by Clustering

**Date Created: 12/6/2020**

**The purpose of this notebook is to experiment with image segmentation using clustering algorithms, comparing DyClee's performance against the clustering algorithms implemented in Scikit-Learn that were referenced in the paper:**

Nathalie Barbosa Roa, Louise Travé-Massuyès, Victor Hugo Grisales. DyClee: Dynamic clustering for tracking evolving environments. Pattern Recognition, Elsevier, 2019, 94, pp.162-186. doi:10.1016/j.patcog.2019.05.024. hal-02135580

**NOTE: This notebook is partially informed by the segmentation demonstration on pages 249-251 in:**

Aurélien Géron. Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems (Sebastopol, CA: O'Reilly Media, Inc., 2019).

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from sklearn.cluster import MiniBatchKMeans, AgglomerativeClustering, AffinityPropagation, DBSCAN, Birch
from sklearn.datasets import load_sample_images
import tensorflow as tf

from DyClee.algorithms import SerialDyClee
from DyClee.plotting import unpack_snapshots, strip_labels

In [None]:
def normalize(img):
    """Normalize pixel values to the range 0 - 1.0.

    @param img    A NumPy array of pixel values on the 0 - 255 scale.
    @return       A new float64 array of the same shape holding img / 255.
                  The input array is left untouched.
    """
    as_float = img.astype(np.float64)
    return np.divide(as_float, 255.0)

def augment_location(img):
    """Augment an image with spatial information - pixel location (y, x).

    Every pixel gains two leading features: its row index divided by the
    number of rows and its column index divided by the number of columns,
    so both coordinates fall in [0, 1).

    @param img    A NORMALIZED image: 2-D (grayscale) or 3-D with the
                  channels on the last axis.
    @return       A new array laid out as (y, x, pixel value(s)) along the
                  last axis: shape (rows, cols, 3) for grayscale input and
                  (rows, cols, channels + 2) otherwise.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    grid = (n_rows, n_cols)

    # Fractional coordinates along each axis, broadcast out to the full grid.
    row_frac = np.arange(n_rows).astype(np.float64) / n_rows
    col_frac = np.arange(n_cols).astype(np.float64) / n_cols
    yi = np.broadcast_to(row_frac[:, np.newaxis], grid)
    xj = np.broadcast_to(col_frac[np.newaxis, :], grid)

    if img.ndim != 2:
        # Multi-channel: give the coordinate planes a channel axis and
        # prepend them to the existing channels.
        return np.concatenate((yi[..., np.newaxis], xj[..., np.newaxis], img), axis=-1)
    # Grayscale: the three planes become the new last axis.
    return np.stack((yi, xj, img), axis=-1)
        
    
def flatten_img(img):
    """Flatten an image so that every pixel is treated as one instance.

    @param img    A 2-D (grayscale) image, or an array whose last axis
                  holds the per-pixel features.
    @return       A 2-D array with one row per pixel: a single column for
                  grayscale input, otherwise one column per entry of the
                  last axis.
    """
    if img.ndim != 2:
        # Keep the last axis as the feature axis; collapse everything else.
        return img.reshape(-1, img.shape[-1])
    # Grayscale: for a 2-D array, size is exactly rows * cols.
    return img.reshape(img.size, 1)

### Load Sample Images from Scikit-Learn

In [None]:
# Load the two sample images bundled with scikit-learn.
dataset = load_sample_images()
china, flower = dataset.images
sk_sample_shape = china.shape
print(sk_sample_shape)
print(china.dtype)

# Show both originals side by side. Titles let the figure stand on its own,
# and the trailing semicolon keeps the AxesImage repr out of the cell output.
fig, ax = plt.subplots(1,2, figsize=(25,25))
ax[0].imshow(china)
ax[0].set_title("china")
ax[1].imshow(flower)
ax[1].set_title("flower");

### Mini-Batch K-Means

#### Only Colors

In [None]:
# Scale both images to [0, 1]; the cluster centers are then float RGB values
# that imshow can draw directly.
china_norm = normalize(china)
flower_norm = normalize(flower)

# Cluster counts to compare: 1, 2, 4, 8 and 16.
num_clusters = [2 ** x for x in range(5)]
fig, ax = plt.subplots(2, len(num_clusters), figsize=(25, 10))

china_flat = flatten_img(china_norm)
for i, num in enumerate(num_clusters):
    # THIS IS BASED ON THE DEMONSTRATION IN THE GERON BOOK - SEE PAGE 250
    # A fixed random_state makes the segmentation identical on every run.
    china_km = MiniBatchKMeans(n_clusters=num, random_state=0).fit(china_flat)
    # Paint every pixel with the center of the cluster it was assigned to.
    result = china_km.cluster_centers_[china_km.labels_]
    result = result.reshape(china.shape)
    ax[0,i].imshow(result)
    ax[0,i].set_title(f"china, k={num}")

flower_flat = flatten_img(flower_norm)
for i, num in enumerate(num_clusters):
    flower_km = MiniBatchKMeans(n_clusters=num, random_state=0).fit(flower_flat)
    result = flower_km.cluster_centers_[flower_km.labels_]
    result = result.reshape(flower.shape)
    ax[1,i].imshow(result)
    ax[1,i].set_title(f"flower, k={num}")

#### Augmented with Locations

In [None]:
# Prepend the normalized (y, x) position to every pixel so that clustering
# considers location as well as color.
china_aug = augment_location(china_norm)
flower_aug = augment_location(flower_norm)

# Cluster counts to compare: 1, 2, 4, 8 and 16.
num_clusters = [2 ** x for x in range(5)]
fig, ax = plt.subplots(2, len(num_clusters), figsize=(25, 10))

china_flat = flatten_img(china_aug)
for i, num in enumerate(num_clusters):
    # A fixed random_state makes the segmentation identical on every run.
    china_km = MiniBatchKMeans(n_clusters=num, random_state=0).fit(china_flat)
    result = china_km.cluster_centers_[china_km.labels_]
    # Restore the augmented image layout; the shape is taken from the
    # augmented image itself instead of hard-coding the feature count.
    aug_shape = china_aug.shape
    result = result.reshape(aug_shape)
    # Drop the two location features so only the color channels are drawn.
    result = result[:,:,2:]
    ax[0,i].imshow(result)
    ax[0,i].set_title(f"china + location, k={num}")

flower_flat = flatten_img(flower_aug)
for i, num in enumerate(num_clusters):
    flower_km = MiniBatchKMeans(n_clusters=num, random_state=0).fit(flower_flat)
    result = flower_km.cluster_centers_[flower_km.labels_]
    aug_shape = flower_aug.shape
    result = result.reshape(aug_shape)
    result = result[:,:,2:]
    ax[1,i].imshow(result)
    ax[1,i].set_title(f"flower + location, k={num}")

# DyClee Tests

**WARNING:** Running on even a limited portion of the China image will take quite a while, and the below phi parameter is NOT well-tuned.

In [None]:
%%time 

# Only the top-left 200 x 200 corner of the location-augmented China image
# is clustered, one row per pixel.
china_flat_part = flatten_img(china_aug[:200, :200])

# Context handed to DyClee: row 0 holds the per-feature minimum of the data,
# row 1 the per-feature maximum.
context = np.stack((china_flat_part.min(axis=0), china_flat_part.max(axis=0)), axis=0)
dyclee = SerialDyClee(phi=0.1, context=context, t_global=4000)
results = strip_labels(dyclee.run_dataset(data=china_flat_part))

# Print the snapshots recorded by the run.
timestamps, snapshots_ordered = unpack_snapshots(dyclee.snapshots)
print(snapshots_ordered)

In [None]:
# Gray level drawn for each DyClee label.
colors = {'0':0.07780313, '1': 0.98997821, 'Unclassed':0.5}

# phi is not tuned, so the run may produce labels beyond '0' and '1'. Give
# each unexpected label its own reproducible gray level on first sight
# instead of failing with a KeyError.
rng = np.random.default_rng(0)
gray_levels = []
for px in results:
    if px not in colors:
        colors[px] = rng.random()
    gray_levels.append(colors[px])

# Assumes results holds one label per pixel of the 200 x 200 crop, in
# flattened row-major order.
seg_img = np.array(gray_levels).reshape(200,200) * 255
plt.imshow(seg_img, cmap="binary")
plt.title("DyClee segmentation of the 200 x 200 China crop");

### MNIST

In [None]:
from sklearn.datasets import fetch_openml

# as_frame=False guarantees NumPy arrays. Newer scikit-learn releases default
# to as_frame='auto', which returns a pandas DataFrame for this dataset; then
# mnist["data"][0] becomes a column lookup and raises KeyError.
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
# The first sample, kept as a 784-long vector and as a 28 x 28 image.
five = mnist["data"][0]
five_img = five.reshape(28,28)
plt.imshow(five_img, cmap="binary")
plt.title("MNIST sample 0");

In [None]:
%%time

# Scale the digit to [0, 1] and treat every pixel as a one-feature sample.
five_img_norm = normalize(five_img)
five_flat = flatten_img(five_img_norm)

# Context handed to DyClee: row 0 holds the per-feature minimum of the data,
# row 1 the per-feature maximum.
context = np.stack((five_flat.min(axis=0), five_flat.max(axis=0)), axis=0)
dyclee = SerialDyClee(phi=0.1, context=context, t_global=784)
results = strip_labels(dyclee.run_dataset(data=five_flat))

# Print the snapshots recorded by the run.
timestamps, snapshots_ordered = unpack_snapshots(dyclee.snapshots)
print(snapshots_ordered)

In [None]:
# Gray level drawn for each DyClee label.
colors = {'0':0.07780313, '1': 0.98997821, 'Unclassed':0.5}

# The run may produce labels beyond '0' and '1'. Give each unexpected label
# its own reproducible gray level on first sight instead of failing with a
# KeyError.
rng = np.random.default_rng(0)
gray_levels = []
for px in results:
    if px not in colors:
        colors[px] = rng.random()
    gray_levels.append(colors[px])

# Assumes results holds one label per pixel in flattened row-major order;
# the layout is taken from the source digit rather than hard-coded.
seg_img = np.array(gray_levels).reshape(five_img.shape) * 255
plt.imshow(seg_img, cmap="binary")
plt.title("DyClee segmentation of MNIST sample 0");

### CIFAR10

In [None]:
# Load CIFAR-10 through Keras: load_data() yields a (train, test) pair, each
# of which unpacks into (images, labels).
cifar_train, cifar_test = tf.keras.datasets.cifar10.load_data()
x_train, y_train = cifar_train
x_test, y_test = cifar_test

In [None]:
# Color lookup for the segmentation plots: black for 'Unclassed' pixels and
# a random RGB color for each of the labels '0' ... '99'.
colors = {}
colors['Unclassed'] = np.array([0,0,0])
# One seeded generator, created once, makes the palette identical on every
# run (previously a fresh unseeded generator was built on each iteration).
rng = np.random.default_rng(42)
for i in range (100):
    # Shape (1, 3); every channel is drawn from [32, 256).
    colors[str(i)] = rng.integers(32, 256, size=(1,3))

In [None]:
# CIFAR-10 training image 6, used below as the "bird" example.
bird = x_train[6]
plt.imshow(bird)
# Title the figure; the semicolon keeps the repr out of the cell output.
plt.title("CIFAR-10 training image 6 (bird)");

In [None]:
%%time

# Preprocess: scale to [0, 1], prepend the (y, x) position of every pixel,
# then lay the image out as one row per pixel.
bird_norm = normalize(bird)
bird_norm_aug = augment_location(bird_norm)
bird_aug_flat = flatten_img(bird_norm_aug)

# Context handed to DyClee: row 0 holds the per-feature minimum of the data,
# row 1 the per-feature maximum.
context = np.stack((bird_aug_flat.min(axis=0), bird_aug_flat.max(axis=0)), axis=0)
dyclee = SerialDyClee(phi=0.08, context=context, t_global=1024)
results = strip_labels(dyclee.run_dataset(data=bird_aug_flat))

# Print the snapshots recorded by the run.
timestamps, snapshots_ordered = unpack_snapshots(dyclee.snapshots)
print(snapshots_ordered)

In [None]:
# Replace every pixel's label by its palette color. Assumes results holds one
# label per pixel in flattened row-major order; the layout is taken from the
# source image instead of hard-coding 32 x 32 x 3.
seg_img = np.vstack([colors[px] for px in results]).reshape(bird.shape)
plt.imshow(seg_img)
plt.title("DyClee segmentation of the bird image");

In [None]:
# CIFAR-10 training image 7, used below as the "horse" example.
horse = x_train[7]
plt.imshow(horse)
# Title the figure; the semicolon keeps the repr out of the cell output.
plt.title("CIFAR-10 training image 7 (horse)");

In [None]:
%%time

# Preprocess: scale to [0, 1], prepend the (y, x) position of every pixel,
# then lay the image out as one row per pixel.
horse_norm = normalize(horse)
horse_norm_aug = augment_location(horse_norm)
horse_aug_flat = flatten_img(horse_norm_aug)

# Context handed to DyClee: row 0 holds the per-feature minimum of the data,
# row 1 the per-feature maximum.
context = np.stack((horse_aug_flat.min(axis=0), horse_aug_flat.max(axis=0)), axis=0)
dyclee = SerialDyClee(phi=0.09, context=context, t_global=1024)
results = strip_labels(dyclee.run_dataset(data=horse_aug_flat))

# Print the snapshots recorded by the run.
timestamps, snapshots_ordered = unpack_snapshots(dyclee.snapshots)
print(snapshots_ordered)

In [None]:
# Replace every pixel's label by its palette color. Assumes results holds one
# label per pixel in flattened row-major order; the layout is taken from the
# source image instead of hard-coding 32 x 32 x 3.
seg_img = np.vstack([colors[px] for px in results]).reshape(horse.shape)
plt.imshow(seg_img)
plt.title("DyClee segmentation of the horse image");

In [None]:
# CIFAR-10 training image 18, used below as the second "bird" example.
bird2 = x_train[18]
plt.imshow(bird2)
# Title the figure; the semicolon keeps the repr out of the cell output.
plt.title("CIFAR-10 training image 18 (bird2)");

In [None]:
%%time

# Preprocess: scale to [0, 1], prepend the (y, x) position of every pixel,
# then lay the image out as one row per pixel.
bird2_norm = normalize(bird2)
bird2_norm_aug = augment_location(bird2_norm)
bird2_aug_flat = flatten_img(bird2_norm_aug)

# Context handed to DyClee: row 0 holds the per-feature minimum of the data,
# row 1 the per-feature maximum.
context = np.stack((bird2_aug_flat.min(axis=0), bird2_aug_flat.max(axis=0)), axis=0)
dyclee = SerialDyClee(phi=0.09, context=context, t_global=1024)
results = strip_labels(dyclee.run_dataset(data=bird2_aug_flat))

# Print the snapshots recorded by the run.
timestamps, snapshots_ordered = unpack_snapshots(dyclee.snapshots)
print(snapshots_ordered)

In [None]:
# Replace every pixel's label by its palette color. Assumes results holds one
# label per pixel in flattened row-major order; the layout is taken from the
# source image instead of hard-coding 32 x 32 x 3.
seg_img = np.vstack([colors[px] for px in results]).reshape(bird2.shape)
plt.imshow(seg_img)
plt.title("DyClee segmentation of the bird2 image");