# <font color='red'>**KANDINSKY**</font>: Clustering and Quantization   
  
## **K-Means (and friends)**  

*Shaurya Agarwal*

![K-Means and Friends](./images/07-kmeans-and-friends-abstract-cluster.png)

### <font color='green'>__Support for Google Colab__  </font>  
    
Open this notebook in Colab using the following button:  
  
<a href="https://colab.research.google.com/github/shauryashaurya/learn-data-munging/blob/main/01-Numpy/01.01-Getting-Started-With-Numpy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>  

  
<font color='green'>Uncomment and execute the cell below to set up and run this notebook on Google Colab.</font>

In [None]:
# # SETUP FOR COLAB: select all the lines below and uncomment (CTRL+/ on windows)
# # Let's download and unzip the Small MovieLens Dataset
# ! mkdir ./../data
# ! wget -q https://files.grouplens.org/datasets/movielens/ml-latest-small.zip
# ! unzip ./ml-latest-small.zip -d ./../data/

In [None]:
# NOTE(review): points at the MovieLens folder downloaded by the Colab setup cell;
# no visible cell reads it (images come from ./data/cgi and ./data/edtf) — confirm before removing.
datalocation = "./../data/ml-latest-small/"

## Setup, imports etc.

In [None]:
import concurrent.futures
import csv
import functools
import os

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
from skimage import color
# 
import jax.numpy as jnp
from jax import random, jit

To render Matplotlib plots directly in a JupyterLab notebook, you can use the ```%matplotlib inline``` magic command. This command configures Matplotlib to render its plots inline within the Jupyter notebook cells, immediately below the code cells that produce them.  
For interactive plots within JupyterLab (e.g., for zooming and rotating 3D plots), you can use the ```%matplotlib widget``` magic instead.  
```%matplotlib widget``` requires the ```ipympl``` package; install it with ```pip install --upgrade --no-cache-dir ipympl``` if it's missing.  

In [None]:
%matplotlib widget

In [None]:
# hack to make plotly plots show up in the notebook
# pio.renderers.default = "notebook"

In [None]:
# Load one CGI frame, reorder OpenCV's BGR channels to RGB, and flatten it
# into one row per pixel (the last expression is displayed by the notebook).
image = cv2.imread("./data/cgi/01x25.png")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
img2 = image_rgb.reshape(-1, image_rgb.shape[-1])
img2.shape

* Look at the color conversion codes here: https://docs.opencv.org/4.9.0/de/d25/imgproc_color_conversions.html  
* OpenCV Image transformation Enumerations: https://docs.opencv.org/3.4.0/d7/d1b/group__imgproc__misc.html

In [None]:
def convert_and_save(image_path, conversion_func, suffix, header):
    """
    Convert an image to another color space and save its pixels to CSV.

    The image is read with OpenCV (BGR order) and converted to RGB. It is then
    passed through ``conversion_func``, flattened to one row per pixel, and
    written to ``<image path without extension>_<suffix>.csv``.

    Parameters:
    - image_path: Path to the original image.
    - conversion_func: Callable that takes the RGB image array and returns the
      converted image (same height and width, one value per channel).
    - suffix: Suffix for the output CSV file name indicating the color space.
    - header: Column names for the CSV file, one per channel.

    Raises:
    - FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check cvtColor fails with an opaque assertion error.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Convert the image using the provided conversion function
    converted_image = conversion_func(image_rgb)

    # One row per pixel, one column per channel.
    pixels = converted_image.reshape((-1, converted_image.shape[2]))

    # splitext (unlike rsplit on ".") leaves extension-less paths such as
    # "./data/img" and dotted directory names intact.
    base_path = os.path.splitext(image_path)[0]
    output_path = f"{base_path}_{suffix}.csv"

    # Integer images are written as plain integers; numpy's default '%.18e'
    # would turn every 8-bit value into a ~24-character float literal.
    fmt = "%d" if np.issubdtype(pixels.dtype, np.integer) else "%.18e"
    np.savetxt(
        output_path,
        pixels,
        fmt=fmt,
        delimiter=",",
        header=",".join(header),
        comments="",
        encoding="utf-8",
    )

In [None]:
def image_to_color_spaces_parallel(image_path, *, parallel=False, max_workers=None):
    """
    Convert an image to several color spaces and save each one as a CSV.

    For each entry in ``conversions`` below, ``convert_and_save`` writes
    ``<image path without extension>_<SUFFIX>.csv`` with SUFFIX one of
    RGB, XYZ, Lab, HSV, HSL. All conversions go through OpenCV, so the
    saved values follow OpenCV's conventions for the input image type.

    Parameters:
    - image_path: Path to the source image.
    - parallel: If True, run the conversions concurrently in a thread pool.
      Otherwise (the default) run them one after another.
    - max_workers: Thread-pool size when ``parallel`` is True. None lets
      ``ThreadPoolExecutor`` choose.

    Raises:
    - Any exception raised by ``convert_and_save``, in both modes.
    """

    # Conversion functions: each receives the RGB image prepared by
    # convert_and_save and returns the converted image.
    def to_rgb(image_rgb):
        return image_rgb

    def to_xyz(image_rgb):
        return cv2.cvtColor(image_rgb, cv2.COLOR_RGB2XYZ)

    def to_lab(image_rgb):
        return cv2.cvtColor(image_rgb, cv2.COLOR_RGB2Lab)

    def to_hsv(image_rgb):
        return cv2.cvtColor(image_rgb, cv2.COLOR_RGB2HSV)

    # TODO: HSL is not supported in skimage.color; OpenCV provides it under the HLS name.
    def to_hsl(image_rgb):
        return cv2.cvtColor(image_rgb, cv2.COLOR_RGB2HLS)

    # Conversion specifications (function, suffix, header)
    conversions = [
        (to_rgb, "RGB", ["R", "G", "B"]),
        (to_xyz, "XYZ", ["X", "Y", "Z"]),
        (to_lab, "Lab", ["L*", "a*", "b*"]),
        (to_hsv, "HSV", ["H", "S", "V"]),
        # COLOR_RGB2HLS emits channels in H, L, S order, so the header must
        # follow that order (it used to claim H, S, L). The "HSL" suffix is
        # kept so output file names do not change.
        (to_hsl, "HSL", ["H", "L", "S"]),
    ]

    if parallel:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(convert_and_save, image_path, func, suffix, header)
                for func, suffix, header in conversions
            ]
            # result() re-raises a worker's exception; a bare wait() would
            # silently drop failed conversions.
            for future in futures:
                future.result()
    else:
        for func, suffix, header in conversions:
            convert_and_save(image_path, func, suffix, header)

In [None]:
# Example usage: write per-color-space CSVs for CGI frames 01x25 .. 06x25.
images = [f"{n:02d}x25.png" for n in range(1, 7)]
for image in images:
    print("pre-processing image: ", image)
    image_path = f"./data/cgi/{image}"
    image_to_color_spaces_parallel(image_path)

In [None]:
# Same pre-processing for the Eight-Down-Toofaan-Mail photographs 001 .. 012.
images = [f"{n:03d}.png" for n in range(1, 13)]
for image in images:
    print("pre-processing image: ", image)
    image_path = f"./data/edtf/{image}"
    image_to_color_spaces_parallel(image_path)

In [None]:
# if you *HAD* to do it this way, otherwise just use markdown...or HTML
# from IPython.display import Image as IPImage

# def display_image(image_path):
#     """Display an image using IPython display."""
#     display(IPImage(filename=image_path))

In [None]:
# blows up the size of the notebook
# display_image(image_path)

Instead, embed images with Markdown image syntax, which does not bloat the notebook:
```
![Image Title](Image Path)
```

![sample image](./data/edtf/001.png)

![sample image](./data/edtf/012.png)

In [None]:
def load_data_from_csv(file_path):
    """
    Read a color-space CSV (first row is the header) into a NumPy array.

    Parameters:
    - file_path: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
    - 2-D NumPy array with one row per data row of the CSV.
    """
    return pd.read_csv(file_path).to_numpy()

In [None]:
def convert_to_rgb(data, color_space):
    """
    Convert pixel data from ``color_space`` to 8-bit RGB for visualization.

    Parameters:
    - data: NumPy array of color values, reshapeable to (-1, 3).
    - color_space: 'XYZ', 'Lab', 'HSV' or 'HSL'. Any other value is treated
      as RGB and ``data`` is returned unchanged.

    Expected input conventions:
    - XYZ / Lab: as accepted by ``skimage.color.xyz2rgb`` / ``lab2rgb``.
    - HSV / HSL: OpenCV's float conventions (hue in degrees, other channels
      in [0, 1]). HSL data must be in OpenCV's H, L, S channel order.
    NOTE(review): CSVs written by ``convert_and_save`` use OpenCV's 8-bit
    encodings, which differ from these float conventions.

    Returns:
    - (N, 3) uint8 RGB array for the converted spaces (every branch now
      returns uint8; HSV/HSL used to return float32 scaled by 255).
    - ``data`` itself for RGB / unknown spaces.
    """
    converters = {
        "XYZ": lambda img: color.xyz2rgb(img),
        "Lab": lambda img: color.lab2rgb(img),
        "HSV": lambda img: cv2.cvtColor(img.astype(np.float32), cv2.COLOR_HSV2RGB),
        "HSL": lambda img: cv2.cvtColor(img.astype(np.float32), cv2.COLOR_HLS2RGB),
    }
    convert = converters.get(color_space)
    if convert is None:
        return data  # Assuming RGB or already in RGB

    # The converters operate on image-shaped arrays, so view the pixels as a
    # single-row image.
    rgb01 = convert(data.reshape(1, -1, 3))
    # Clip guards against float overshoot so the uint8 cast cannot wrap
    # around (a no-op for skimage, which already clips to [0, 1]).
    return (np.clip(rgb01, 0.0, 1.0) * 255).reshape(-1, 3).astype(np.uint8)

In [None]:
def normalize_colors_for_plotting(data, color_space="RGB"):
    """
    Scale color values to [0, 1] so matplotlib can use them as point colors.

    Parameters:
    - data: (N, 3) array of color values.
    - color_space: The color space of the data.

    Returns:
    - RGB / HSV / HSL: ``data`` unchanged (assumed to be in [0, 1] already).
    - XYZ: each channel divided by its maximum. A channel whose maximum is
      0 is left as-is instead of becoming NaN.
    - Lab: L mapped from [0, 100], and a*, b* from about [-110, 110], onto [0, 1].
    - Any other space: ``data`` unchanged.

    The input array is never modified in place.
    """
    if color_space in ["RGB", "HSV", "HSL"]:  # These are typically already in [0, 1]
        return data
    if color_space == "XYZ":
        # Normalize XYZ per channel. Zero maxima are replaced by 1 so all-zero
        # channels stay 0. The division is out-of-place: the previous `data /= ...`
        # mutated the caller's array and failed on integer input.
        channel_max = np.max(data, axis=0)
        channel_max = np.where(channel_max == 0, 1, channel_max)
        return data / channel_max
    if color_space == "Lab":
        # Normalize L in [0, 100], a and b in approximately [-110, 110]
        return (data + [0, 110, 110]) / [100, 220, 220]
    return data

In [None]:
def plot_image_and_color_space(image_path, color_space="RGB"):
    """
    Plot the 3D point cloud of an image's pixels in the given color space.

    Each pixel becomes one point. Points are colored by their own coordinates
    after ``normalize_colors_for_plotting``, so only the RGB view shows the
    true pixel colors.

    Parameters:
    - image_path: Path to the RGB image (.png format).
    - color_space: 'RGB', 'XYZ', 'Lab' or 'HSV' ('HSL' is not supported yet).

    Raises:
    - ValueError: for 'HSL' or an unknown color space (checked before any I/O).
    - FileNotFoundError: if the image cannot be read.
    """
    # color space -> (skimage converter or None for identity, axis labels)
    spaces = {
        "RGB": (None, ["R", "G", "B"]),
        "XYZ": (color.rgb2xyz, ["X", "Y", "Z"]),
        "Lab": (color.rgb2lab, ["L", "a", "b"]),
        "HSV": (color.rgb2hsv, ["H", "S", "V"]),
    }
    if color_space == "HSL":
        raise ValueError("HSL not supported yet...")
    if color_space not in spaces:
        raise ValueError("Unsupported color space")
    to_space, axes_titles = spaces[color_space]

    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # Float RGB in [0, 1], the input range skimage's rgb2* converters expect.
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) / 255.0
    data = image_rgb if to_space is None else to_space(image_rgb)

    # One row per pixel, scaled to [0, 1] so the values double as RGBA colors.
    data_for_plotting = normalize_colors_for_plotting(data.reshape(-1, 3), color_space)

    # Only the 3D point cloud is drawn (the image itself is not shown).
    plt.figure(figsize=(9, 5))
    ax = plt.subplot(1, 1, 1, projection="3d")
    ax.scatter(
        data_for_plotting[:, 0],
        data_for_plotting[:, 1],
        data_for_plotting[:, 2],
        c=data_for_plotting,
        marker="o",
        s=0.001,
    )
    ax.set_title(f"{color_space} Color Space")
    ax.set_xlabel(axes_titles[0])
    ax.set_ylabel(axes_titles[1])
    ax.set_zlabel(axes_titles[2])

    plt.show()

In [None]:
def plot_color_space_with_rgb_colors(base_path, color_space="RGB"):
    """
    Show an image next to the 3D point cloud of its pixels in a color space.

    Reads three files:
    - the image ``<base_path>.png``;
    - the CSV ``<base_path>_<color_space>.csv`` (point coordinates);
    - the CSV ``<base_path>_RGB.csv`` (point colors).
    Each point is therefore drawn in its pixel's original RGB color.

    Parameters:
    - base_path: The base path used when saving the CSV files, without the
      color-space suffix and without the extension.
    - color_space: 'RGB', 'XYZ', 'Lab' or 'HSV' ('HSL' is not supported yet).

    Raises:
    - ValueError: for 'HSL' or an unknown color space (checked before any file is read).
    - FileNotFoundError: if ``<base_path>.png`` cannot be read.
    """
    axis_labels = {
        "RGB": ["R", "G", "B"],
        "XYZ": ["X", "Y", "Z"],
        "Lab": ["L", "a", "b"],
        "HSV": ["H", "S", "V"],
    }
    if color_space == "HSL":
        raise ValueError("HSL not supported yet...")
    if color_space not in axis_labels:
        raise ValueError("Unsupported color space")
    axes_titles = axis_labels[color_space]

    image_file = base_path + ".png"
    image_bgr = cv2.imread(image_file)
    if image_bgr is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError(f"Could not read image: {image_file}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) / 255.0

    # Point coordinates in the requested color space.
    color_space_data = pd.read_csv(f"{base_path}_{color_space}.csv").to_numpy()
    # Matching RGB values, scaled from [0, 255] to [0, 1] for matplotlib colors.
    rgb_colors = pd.read_csv(f"{base_path}_RGB.csv").to_numpy() / 255.0

    plt.figure(figsize=(10, 5))
    # plt.subplot(rows, columns, index): image on the right, point cloud on the left.
    plt.subplot(1, 2, 2)
    plt.imshow(image_rgb)
    plt.axis('off')
    plt.title('Original Image')

    ax = plt.subplot(1, 2, 1, projection="3d")
    ax.scatter(
        color_space_data[:, 0],
        color_space_data[:, 1],
        color_space_data[:, 2],
        c=rgb_colors,
        marker="o",
        s=0.005,
    )
    ax.set_title(f"{color_space} Color Space Point Cloud Colored by RGB")
    ax.set_xlabel(axes_titles[0])
    ax.set_ylabel(axes_titles[1])
    ax.set_zlabel(axes_titles[2])

    plt.show()

In [None]:
# base_path = "./data/cgi/01x25"  # Without the color space suffix and .csv
# color_space = "Lab"  # Example color space
# plot_color_space_with_rgb_colors(base_path, color_space)

In [None]:
# plot_color_space_with_rgb_colors('./data/cgi/01x25', 'Lab')
# plot_color_space_with_rgb_colors('./data/cgi/02x25', 'HSV')
# plot_color_space_with_rgb_colors('./data/cgi/03x25', 'XYZ')
# plot_color_space_with_rgb_colors('./data/cgi/04x25', 'Lab')
# plot_color_space_with_rgb_colors('./data/cgi/05x25', 'HSV')
# plot_color_space_with_rgb_colors('./data/cgi/06x25', 'XYZ')

In [None]:
plot_color_space_with_rgb_colors(base_path="./data/edtf/001", color_space="HSV")
# plot_color_space_with_rgb_colors('./data/edtf/002', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/003', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/004', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/005', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/006', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/007', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/008', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/009', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/010', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/011', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/012', 'Lab')

In [None]:
# plot_color_space_with_rgb_colors('./data/edtf/001', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/002', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/003', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/004', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/005', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/006', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/007', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/008', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/009', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/010', 'Lab')
plot_color_space_with_rgb_colors(base_path="./data/edtf/011", color_space="HSV")
# plot_color_space_with_rgb_colors('./data/edtf/012', 'RGB')

In [None]:
# plot_color_space_with_rgb_colors('./data/edtf/001', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/002', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/003', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/004', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/005', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/006', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/007', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/008', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/009', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/010', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/011', 'Lab')
plot_color_space_with_rgb_colors(base_path="./data/edtf/012", color_space="Lab")

In [None]:
# plot_color_space_with_rgb_colors('./data/edtf/001', 'HSV')
plot_color_space_with_rgb_colors(base_path="./data/edtf/002", color_space="HSV")
# plot_color_space_with_rgb_colors('./data/edtf/003', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/004', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/005', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/006', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/007', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/008', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/009', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/010', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/011', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/012', 'Lab')

In [None]:
# plot_color_space_with_rgb_colors('./data/edtf/001', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/002', 'HSV')
plot_color_space_with_rgb_colors(base_path="./data/edtf/003", color_space="XYZ")
# plot_color_space_with_rgb_colors('./data/edtf/004', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/005', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/006', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/007', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/008', 'HSV')
# plot_color_space_with_rgb_colors('./data/edtf/009', 'XYZ')
# plot_color_space_with_rgb_colors('./data/edtf/010', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/011', 'Lab')
# plot_color_space_with_rgb_colors('./data/edtf/012', 'Lab')

In [None]:
def cartesian_distance(point1, point2):
    """Return the Cartesian (Euclidean) distance between two points.

    The squared differences are summed over every element of the
    (broadcast) difference, so the result is a scalar.
    """
    return jnp.sqrt(jnp.sum((point1 - point2) ** 2))

In [None]:
def manhattan_distance(point1, point2):
    """Return the Manhattan (L1, city-block) distance between two points."""
    absolute_gaps = jnp.abs(point1 - point2)
    return absolute_gaps.sum()

In [None]:
def euclidean_distance(x, y):
    """
    Return Euclidean distances between two sets of points (or two points).

    - If ``x`` is (n, d) and ``y`` is (k, d): return the (n, k) matrix whose
      entry [i, j] is the distance from x[i] to y[j]. This is the shape
      ``assign_clusters`` needs for its ``argmin(axis=1)``. The previous
      version returned a single scalar here (or failed to broadcast).
    - Otherwise (e.g. two single points): return the scalar distance over
      all elements, the same as ``cartesian_distance``.
    """
    if jnp.ndim(x) == 2 and jnp.ndim(y) == 2:
        # (n, 1, d) - (1, k, d) broadcasts to (n, k, d); reduce over d.
        diff = x[:, None, :] - y[None, :, :]
        return jnp.sqrt(jnp.sum(diff ** 2, axis=-1))
    return jnp.sqrt(jnp.sum((x - y) ** 2))

In [None]:
def initialize_centroids(data, k, seed=0):
    """Pick ``k`` distinct rows of ``data`` at random as starting centroids.

    The selection is deterministic for a given ``seed``.
    """
    key = random.PRNGKey(seed)
    candidate_rows = jnp.arange(data.shape[0])
    picked = random.choice(key, candidate_rows, shape=(k,), replace=False)
    return data[picked]

In [None]:
# distance_fn is a Python callable, not an array, so jit must treat it as a
# static argument; as a traced argument it raises TypeError at call time.
# (Each distinct distance_fn triggers its own compilation.)
@functools.partial(jit, static_argnums=(2,))
def assign_clusters(data, centroids, distance_fn):
    """
    Return, for each row of ``data``, the index of its nearest centroid.

    Parameters:
    - data: Points to assign (expected (n, d)).
    - centroids: Current centroids (expected (k, d)).
    - distance_fn: Callable returning an (n, k) distance matrix for
      ``(data, centroids)``.

    Returns:
    - (n,) array of centroid indices (argmin over the k axis).
    """
    distances = distance_fn(data, centroids)
    return jnp.argmin(distances, axis=1)

In [None]:
def update_centroids(data, labels, k):
    """Recompute each centroid as the mean of the points labelled with it.

    Returns an array whose row ``i`` is the mean of ``data[labels == i]``,
    for i in range(k). A cluster with no members yields a NaN row.
    """
    means = []
    for cluster_id in range(k):
        members = data[labels == cluster_id]
        means.append(members.mean(axis=0))
    return jnp.array(means)

In [None]:
def kmeans1(data, k, distance_fn=euclidean_distance, max_iters=100, tol=1e-4, save_to_csv=False, csv_path=None):
    """
    Cluster ``data`` into ``k`` groups by alternating assign/update rounds.

    Parameters:
    - data: Array of points, one per row.
    - k: Number of clusters.
    - distance_fn: Callable returning an (n, k) distance matrix for
      ``(data, centroids)``; used to assign points to centroids.
    - max_iters: Maximum number of rounds (must be >= 1).
    - tol: Stop once no centroid moves farther than this (Euclidean norm).
    - save_to_csv, csv_path: When both are set, ``save_iteration_csv`` is
      called after every round, including the round that converges.

    Returns:
    - (centroids, labels): final centroids, and the labels from the last
      assignment step.

    Raises:
    - ValueError: if ``max_iters`` < 1 (there would be no labels to return).
    """
    if max_iters < 1:
        raise ValueError("max_iters must be at least 1")
    centroids = initialize_centroids(data, k)
    for i in range(max_iters):
        old_centroids = centroids
        labels = assign_clusters(data, centroids, distance_fn)
        centroids = update_centroids(data, labels, k)
        # An empty cluster produces a NaN mean. Keep its previous centroid
        # instead, so NaN does not poison later distances and the tol check.
        centroids = jnp.where(jnp.isnan(centroids), old_centroids, centroids)

        # Save before the convergence check so the final round is recorded too.
        if save_to_csv and csv_path:
            save_iteration_csv(centroids, labels, i, csv_path)

        # Convergence check
        if jnp.max(jnp.linalg.norm(centroids - old_centroids, axis=1)) < tol:
            break

    return centroids, labels

In [None]:
def save_iteration_csv(centroids, labels, iteration, base_csv_path):
    """
    Save one iteration's centroids and cluster sizes to a CSV file.

    Writes ``<base_csv_path>_iteration_<iteration>.csv``. There is one row per
    centroid, holding its index, its coordinates, and the number of labels
    equal to that index.

    Parameters:
    - centroids: (k, d) array-like of centroid coordinates.
    - labels: Array-like of cluster indices, one per data point.
    - iteration: Iteration number used in the file name.
    - base_csv_path: Path prefix for the output file.
    """
    centroids_np = np.array(centroids)
    labels_np = np.array(labels)
    n_dims = centroids_np.shape[1]
    # Keep the X/Y/Z names for 3-D data; use generic names for other sizes
    # (the old code hard-coded three coordinates).
    coord_names = ["X", "Y", "Z"] if n_dims == 3 else [f"D{j}" for j in range(n_dims)]
    with open(f"{base_csv_path}_iteration_{iteration}.csv", "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        # "Count" was missing, so every data row had one more field than the header.
        writer.writerow(["Centroid", *coord_names, "Count"])
        for i, centroid in enumerate(centroids_np):
            points_in_cluster = int(np.sum(labels_np == i))
            writer.writerow([i, *centroid, points_in_cluster])

In [None]:
# Smoke test: cluster 1000 uniformly random 3-D points (unseeded) into 3 groups,
# writing one CSV per iteration with the prefix "kmeans_output".
data = np.random.rand(1000, 3)
data_jax = jnp.array(data)
centroids, labels = kmeans1(
    data_jax,
    k=3,
    max_iters=10,
    save_to_csv=True,
    csv_path="kmeans_output",
)
print("Centroids:", centroids)

In [None]:
from jax import random
import jax.numpy as jnp

def kmeans_jax(data, k, num_iters=10, seed=0):
    """
    Run ``num_iters`` rounds of k-means (Euclidean distance) on ``data``.

    Parameters:
    - data: (n, d) array of points.
    - k: Number of clusters. Initial centroids are k distinct rows, so
      k must not exceed n.
    - num_iters: Number of assign/update rounds; 0 keeps the initial centroids.
    - seed: PRNG seed for picking the initial centroids (0, as before).

    Returns:
    - (centroids, labels): (k, d) centroids and (n,) labels. Each label is the
      index of the nearest *returned* centroid.
    """
    n = data.shape[0]
    rng = random.PRNGKey(seed)
    centroids = data[random.choice(rng, jnp.arange(n), shape=(k,), replace=False)]

    def nearest(cents):
        # (n, 1, d) - (1, k, d) -> (n, k) distance matrix; nearest index per point.
        distances = jnp.sqrt(jnp.sum((data[:, None, :] - cents[None, :, :]) ** 2, axis=2))
        return jnp.argmin(distances, axis=1)

    for _ in range(num_iters):
        labels = nearest(centroids)
        updated = jnp.array([data[labels == c].mean(axis=0) for c in range(k)])
        # An empty cluster has a NaN mean; keep its previous position instead.
        centroids = jnp.where(jnp.isnan(updated), centroids, updated)

    # Assign against the final centroids so labels and centroids agree. The
    # old code returned labels from before the last update and crashed when
    # num_iters == 0.
    return centroids, nearest(centroids)

In [None]:
import cv2
import numpy as np

# Photograph to segment; OpenCV loads BGR, so reorder the channels to RGB.
image_path = './data/edtf/012.png'
image_bgr = cv2.imread(image_path)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

# NOTE(review): this CSV (a different image, in Lab) is not used by the cells
# below; clustering runs on the RGB pixels of image_path instead.
data_csv_path = './data/edtf/001_Lab.csv'

# K-Means input: one row per pixel, channel values scaled from [0, 255] to [0, 1].
pixels = image_rgb.reshape(-1, 3) / 255.0

In [None]:
# One row per pixel, three channels (from the reshape(-1, 3) above).
pixels.shape

In [None]:
# Peek at a fixed slice of 10 pixels (rows 5000-5009) to sanity-check the [0, 1] values.
pixels[5000:5010]

In [None]:
data_jax = jnp.array(pixels)  # Convert the NumPy pixel array to a JAX array for kmeans_jax

In [None]:
k = 2  # Number of clusters, i.e. the number of colors in the segmented image

In [None]:
# Cluster the normalized RGB pixels: centroids are k colors, labels map each pixel to one of them.
centroids, labels = kmeans_jax(data_jax, k, num_iters=10)

In [None]:
from skimage import color
def convert_to_rgb2(data, space):
    """
    Convert color values (e.g. k-means centroids) to float RGB in [0, 1].

    Parameters:
    - data: Array-like of color values whose last axis holds 3 channels.
    - space: Color space of ``data``: 'RGB', 'XYZ', 'Lab' or 'HSV'.
      RGB input may be in [0, 1] or [0, 255]; values above 1 are taken to
      mean the 8-bit scale. XYZ/Lab/HSV input uses skimage's float conventions.

    Returns:
    - NumPy array with the same shape as ``data``, values clipped to [0, 1]
      (ready for ``plt.imshow``). The previous version cast to uint16 before
      dividing by 255, which truncated [0, 1] colors to (almost) black.

    Raises:
    - ValueError: for 'HSL' or an unknown color space.
    """
    data = np.asarray(data)
    print('convert_to_rgb2: data:', data)
    if space == "RGB":
        rgb = data / 255.0 if data.size and data.max() > 1.0 else data
    elif space in ("XYZ", "Lab", "HSV"):
        converters = {
            "XYZ": color.xyz2rgb,
            # Default D65 illuminant, matching skimage's rgb2lab default;
            # 'D60' (used before) is not an illuminant skimage recognizes.
            "Lab": color.lab2rgb,
            "HSV": color.hsv2rgb,
        }
        # View the colors as a single-row image, the layout skimage accepts.
        rgb = converters[space](data.reshape(1, -1, 3)).reshape(data.shape)
    elif space == "HSL":
        # HSL not supported yet...
        raise ValueError("HSL not supported yet...")
    else:
        raise ValueError("Unsupported color space")
    return np.clip(rgb, 0.0, 1.0)

In [None]:
# The centroids were computed on RGB pixels scaled to [0, 1], so they must be
# converted as 'RGB'; interpreting them as HSV produced unrelated colors.
centroids_rgb = convert_to_rgb2(centroids, 'RGB')

In [None]:
# Inspect the clustering result: centroid colors, their converted form, and per-pixel labels.
print('centroids: ', centroids)
print('centroids_rgb: ', centroids_rgb)
print('labels: ', labels)

In [None]:
# Map each pixel to its centroid color
# segmented_image = jnp.array([centroids[label] for label in labels])

In [None]:
# Replace every pixel by its centroid color, then restore the original image layout.
segmented_image = centroids[labels].reshape(image_rgb.shape)
# segmented_image = centroids_rgb[labels].reshape(image_rgb.shape)

In [None]:
# Reshape back to the original image shape
# segmented_image = segmented_image.reshape(image_rgb.shape)

In [None]:
# Display the result
import matplotlib.pyplot as plt

# Show the color-quantized (segmented) image.
fig2 = plt.figure(figsize=(10, 5))
plt.imshow(segmented_image)
plt.axis('off')  # Hide axes
plt.show()