[![Binder](https://mybinder.org/badge_logo.svg)](https://nbviewer.org/github/vicente-gonzalez-ruiz/vector_quantization/blob/main/docs/RGB_VQ.ipynb)

[![Colab](https://badgen.net/badge/Launch/on%20Google%20Colab/blue?icon=notebook)](https://colab.research.google.com/github/vicente-gonzalez-ruiz/vector_quantization/blob/main/docs/RGB_VQ.ipynb)

# [Vector Quantization (in the color domain) of an RGB image](https://scikit-learn.org/stable/auto_examples/cluster/plot_color_quantization.html#sphx-glr-auto-examples-cluster-plot-color-quantization-py)
Color VQ using [K-means](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans).

In [None]:
try:
    import matplotlib.pyplot as plt
except:
    !pip install matplotlib
    import matplotlib
    import matplotlib.pyplot as plt
    import matplotlib.axes as ax
    #plt.rcParams['text.usetex'] = True
    #plt.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath}'] #for \text command
%matplotlib inline

In [None]:
try:
    from sklearn.cluster import KMeans
    from sklearn.utils import shuffle
except:
    !pip install scikit-learn
    from sklearn.cluster import KMeans
    from sklearn.utils import shuffle

In [None]:
try:
    from skimage import io
    from skimage.color import rgb2gray
except:
    !pip install scikit-image
    from skimage import io
    from skimage.color import rgb2gray

In [None]:
try:
    import numpy as np
except:
    !pip install numpy
    import numpy as np

In [None]:
try:
    import pylab
except:
    !pip install pylab
    import pylab

In [None]:
import math
import os

In [None]:
try:
    import gzip
except:
    !pip install gzip
    import gzip

In [None]:
try:
    from information_theory.distortion import RMSE
except:
    !pip install "information_theory @ git+https://github.com/vicente-gonzalez-ruiz/information_theory"
    from information_theory.distortion import RMSE

import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from sklearn.cluster import KMeans
from sklearn.utils import shuffle
import image_3 as RGB_image
import image_1 as gray_image
import distortion
import information
import os
import pylab
import gzip

## Configuration

In [None]:
# Image to quantize (here, a URL).
fn = "http://www.hpca.ual.es/~vruiz/images/lena.png"
# Number of bins (codebook size).
n_clusters = 256
# Codebook sizes explored by the RD curve: 2, 4, 8, ..., 128.
range_of_N_bins = [2**exponent for exponent in range(1, 8)]
# Sample type of the image components: np.uint8 for 8 bpp/component images,
# np.uint16 for 16 bpp/component images.
image_dtype = np.uint8

##RGB_image.write = RGB_image.debug_write # Faster, but lower compression
#RGB_image.write = information.write # The fastest, but returns only an estimation of the length
##gray_image.write = gray_image.debug_write # Faster, but lower compression
#gray_image.write = information.write # The fastest, but returns only an estimation of the length

## Read the image and show it

In [None]:
# Read the image named by `fn` and display it.
img = io.imread(fn)
plt.figure()
plt.title(fn)
# skimage.io.imshow is deprecated (scikit-image >= 0.25); plot with Matplotlib directly.
plt.imshow(img)
plt.show()

## Example

In [None]:
# Prepare the data for K-means: scale the pixel values by 1/255 and view the
# image as a table of color vectors, one row per pixel and d columns.
# (Following NumPy's shape convention, w is the number of rows and h the
# number of columns.)
normalized_img = img.astype(np.float64) / 255
shape_of_img = tuple(img.shape)
w, h, d = shape_of_img
img_reshaped = normalized_img.reshape((-1, d))
# Train on a random subset of 1000 pixels only, to keep the fit fast.
some_samples = shuffle(img_reshaped, random_state=0, n_samples=1_000)
# K-means clustering tool with a 32-color codebook.
kmeans = KMeans(n_clusters=32, random_state=0)

In [None]:
# Find centroids in clusters (CPU intensive)
# Only the pixel subset `some_samples` is used for training, not the whole image.
kmeans.fit(some_samples)

In [None]:
# Quantize (encode): each pixel is replaced by the index of its nearest centroid.
labels = kmeans.predict(img_reshaped)
centroids = kmeans.cluster_centers_
# Dequantize (decode): look every index up in the codebook and restore the
# original (w, h, components) image shape.
img_dequantized = np.take(centroids, labels, axis=0).reshape(w, h, -1)

In [None]:
# Show the image reconstructed with the floating-point palette. Its values are
# cluster centers of the pixels normalized by 1/255, so Matplotlib displays
# them directly as RGB. (skimage.io.imshow is deprecated in scikit-image >= 0.25.)
plt.figure()
plt.title("Dequantized Image")
plt.imshow(img_dequantized)
plt.show()

In [None]:
# Floating-point codebook: one row per cluster, one column per color component.
print(centroids)

In [None]:
# Convert the floating-point palette to 8-bit integers. Round to the nearest
# level: a plain astype(np.uint8) truncates toward zero and biases every
# component downward by half a level on average. The clip only guards against
# tiny floating-point excursions outside [0, 1].
centroids = np.clip(np.rint(centroids * 255), 0, 255).astype(np.uint8)
print(centroids)

In [None]:
# Rebuild the image from the same indexes, now using the 8-bit integer codebook.
img_dequantized = centroids[labels].reshape(w, h, -1)

In [None]:
# Show the image reconstructed with the integer (8-bit) palette.
# skimage.io.imshow is deprecated (scikit-image >= 0.25); plot with Matplotlib directly.
plt.figure()
plt.title("Dequantized Image")
plt.imshow(img_dequantized)
plt.show()

In [None]:
# Store the 8-bit codebook gzip-compressed; the cell output is its size in bytes.
with gzip.GzipFile("/tmp/codebook.npy.gz", mode="w") as f:
    np.save(f, centroids)
os.path.getsize("/tmp/codebook.npy.gz")

In [None]:
# Read the codebook back from the compressed file and print it.
with gzip.open("/tmp/codebook.npy.gz", "rb") as f:
    centroids = np.load(file=f)
print(centroids)

## RD performance
The size of the (gzip-compressed) codebook is included in the bit-rate, together with the size of the PNG-encoded quantization indexes.

In [None]:
def save(img, fn):
    """Write `img` to the file `fn` and return the resulting file size in bytes.

    The size is also printed. skimage's low-contrast warning is disabled,
    because images with few distinct values (e.g. quantization indexes) would
    otherwise trigger it.
    """
    io.imsave(fn, img, check_contrast=False)
    size_in_bytes = os.path.getsize(fn)
    print(f"Written {size_in_bytes} bytes in {fn}")
    return size_in_bytes

In [None]:
def RD_curve(img, range_of_N_bins):
    """Compute the rate-distortion points of K-means color VQ applied to `img`.

    For every codebook size `n` in `range_of_N_bins`:
    - a K-means codebook of `n` colors is trained on a random subset of (at
      most) 1000 pixels;
    - every pixel is replaced by the index of its nearest codeword, and the
      index image is written as a PNG;
    - the codebook, rounded to 8 bits (what a decoder needs), is stored
      gzip-compressed.

    Args:
        img: image as a (rows, cols, components) array; presumably 8-bit
            values, since the pixels are scaled by 1/255.
        range_of_N_bins: iterable of codebook sizes, each in [1, 256] because
            the indexes are stored as 8-bit samples.

    Returns:
        List of (rate, distortion) tuples. rate is in bits/pixel and counts
        both the PNG of indexes and the compressed codebook; distortion is
        `RMSE(img, y)`, where y is the image rebuilt from the 8-bit codebook.

    Raises:
        ValueError: if some codebook size does not fit in 8-bit indexes.
    """
    rows, cols, components = img.shape
    n_pixels = rows * cols
    pixels = (np.asarray(img, dtype=np.float64) / 255).reshape(n_pixels, components)
    # The training subset does not depend on n (fixed random_state), so draw it
    # once. Never request more samples than the image has pixels.
    some_samples = shuffle(pixels, random_state=0, n_samples=min(1_000, n_pixels))
    points = []
    for n in range_of_N_bins:
        if not 1 <= n <= 256:
            # Larger codebooks would silently wrap around in the uint8 indexes.
            raise ValueError(f"n={n}: the number of bins must be in [1, 256]")
        kmeans = KMeans(n_clusters=n, random_state=0)
        kmeans.fit(some_samples)
        # The decoder only receives 8-bit colors: round (not truncate) the
        # codebook, and use it both for the rate and for the reconstruction.
        codebook = np.clip(np.rint(kmeans.cluster_centers_ * 255), 0, 255).astype(np.uint8)
        k = kmeans.predict(pixels).astype(np.uint8)  # Up to 256 bins
        print("Quantization indexes: ", np.unique(k))
        y = codebook[k].reshape(rows, cols, components)
        rate = save(k.reshape(rows, cols), f"/tmp/{n}.png") * 8 / n_pixels
        with gzip.GzipFile("/tmp/codebook.npy.gz", "w") as f:
            np.save(file=f, arr=codebook)
        rate += os.path.getsize("/tmp/codebook.npy.gz") * 8 / n_pixels
        _distortion = RMSE(img, y)
        plt.title(f"{n}")
        plt.imshow(y)  # cmap/vmin/vmax would be ignored for RGB data anyway
        plt.show()
        points.append((rate, _distortion))
        print(f"n={n:>3}, rate={rate:>7} bits/pixel, distortion={_distortion:>6.1f}")
    return points

In [None]:
# One (rate, distortion) point per codebook size in range_of_N_bins (one K-means fit per size).
RD_points = RD_curve(img, range_of_N_bins)

In [None]:
#YCoCg_SQ = []
#with open(f'../YCoCg_SQ/YCoCg_SQ.txt', 'r') as f:
#    for line in f:
#        BPP, RMSE = line.split('\t')
#        YCoCg_SQ.append((float(BPP), float(RMSE)))

In [None]:
# Plot the rate-distortion curve of the VQ(RGB)+PNG codec.
pylab.figure(dpi=150)
#pylab.scatter(*zip(*RD_points), label=f"VQ+PNG", s=1, marker='.')
#pylab.plot(*zip(*YCoCg_SQ), c='b', marker='o', label=r"Deadzone($\mathbf{\Delta}^{\mathrm{Y}}_i = \mathbf{\Delta}^{\mathrm{Co}}_i = \mathbf{\Delta}^{\mathrm{Cg}}_i$)+PNG", linestyle="dashed")
# Raw string: "\m" is an invalid escape sequence in a normal literal
# (SyntaxWarning since Python 3.12); the resulting label text is unchanged.
pylab.plot(*zip(*RD_points), c='m', marker='x', label=r"VQ($\mathrm{RGB}$)+PNG", linestyle="dotted")
pylab.title(fn)
pylab.xlabel("Bits/Pixel")
pylab.ylabel("RMSE")
pylab.legend(loc='upper right')
pylab.show()

In [None]:
#with open(f"RGB_VQ.txt", 'w') as f:
#    for item in RD_points:
#        f.write(f"{item[0]}\t{item[1]}\n")

In [None]:
# Block until the user presses Enter. Presumably a deliberate stop so that
# "Run all" does not continue into the cells below — TODO confirm.
input()

## Ignore the rest ...

## Compare to SQ

In [None]:
def load(path):
    """Read a rate-distortion curve from a tab-separated text file.

    Each non-blank line must contain two numbers separated by a tab
    (``rate<TAB>distortion``). Blank lines, such as a trailing empty line,
    are skipped.

    Args:
        path: path of the file to read.

    Returns:
        List of ``(rate, distortion)`` float tuples, in file order.

    Raises:
        ValueError: if a non-blank line does not hold exactly two numeric,
            tab-separated fields.
    """
    curve = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing the unpack
            rate, _distortion = line.split("\t")
            curve.append((float(rate), float(_distortion)))
    return curve

In [None]:
# Load the RD curve of the scalar-quantization (SQ) notebook for comparison.
# NOTE(review): assumes ../RGB_SQ/RGB_SQ.txt has already been generated.
SQ_color = load("../RGB_SQ/RGB_SQ.txt")

In [None]:
# Plot the VQ and SQ rate-distortion curves on the same axes.
pylab.figure(dpi=150)
pylab.plot(*zip(*RD_points), label="VQ", c='m', linestyle="dotted", marker='x')
pylab.plot(*zip(*SQ_color), label="SQ", c='k', linestyle="dotted", marker='x')
pylab.xlabel("Bits/Pixel")
pylab.ylabel("RMSE")
pylab.title(fn)
pylab.legend(loc='upper right')
pylab.show()

In [None]:
import time
# Loop forever: execution never proceeds past this cell unless the kernel is
# interrupted (presumably a deliberate stop for "Run all" — TODO confirm).
while True:
    time.sleep(1)

In [None]:
# Load the dead-zone scalar-quantization RD points, running the deadzone
# notebook first when needed (IPython %run magics: this cell is not plain Python).
if 'google.colab' in str(get_ipython()):
    # On Colab, the deadzone notebook is always run from the working directory;
    # presumably it writes ./dead-zone_RD_points.txt — TODO confirm.
    %run ./deadzone.ipynb
    deadzone_RD = load("./dead-zone_RD_points.txt")
    print("loaded ./dead-zone_RD_points.txt")
elif not os.path.exists("../scalar_quantization/dead-zone_RD_points.txt"):
    # The points file is missing locally: regenerate it by running the notebook.
    %run ../scalar_quantization/deadzone.ipynb
    deadzone_RD = load("../scalar_quantization/dead-zone_RD_points.txt")
    print("loaded ../scalar_quantization/dead-zone_RD_points.txt after regenerating the file")
else:
    # Reuse the existing points file.
    deadzone_RD = load("../scalar_quantization/dead-zone_RD_points.txt")
    print("loaded ../scalar_quantization/dead-zone_RD_points.txt without regenerating the file")