In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.cluster import KMeans

# Flaw of Averages

### Driver Code

In [None]:
def create_col_pic(shape: tuple, color: np.ndarray):
    """Create an image filled with a single color.

    Parameters
    ----------
    shape : tuple
        Shape of the array to create, e.g. ``(height, width, 3)``.
    color : np.ndarray
        Color to fill with; it must broadcast against ``shape``. Float
        values (such as a per-channel mean) are accepted.

    Returns
    -------
    np.ndarray
        ``uint8`` array of the requested shape in which every pixel is
        ``color`` rounded to the nearest integer and clipped to [0, 255].
    """
    # Writing floats straight into a uint8 array truncates toward zero
    # (254.9 -> 254) and wraps out-of-range values, so round and clip first.
    fill = np.clip(np.around(np.asarray(color, dtype=float)), 0, 255)
    out = np.empty(shape, dtype='uint8')
    out[...] = fill
    return out

In [None]:
def compare(fname, save_out=None):
    """Show an image side by side with a flat swatch of its average color.

    Parameters
    ----------
    fname : str
        Path of the image file to load with ``cv2.imread``.
    save_out : str, optional
        If given (and truthy), the figure is also written to this path with
        ``plt.savefig``.

    Returns
    -------
    tuple
        ``(img, avg_col)``: the loaded image with its channels reordered for
        matplotlib, and the per-channel mean over all pixels (float array,
        not rounded).

    Raises
    ------
    OSError
        If ``cv2.imread`` cannot read ``fname``.
    """
    bgr = cv2.imread(fname)
    if bgr is None:
        # imread signals failure by returning None instead of raising.
        raise OSError('could not read image: {}'.format(fname))
    # OpenCV decodes to BGR; matplotlib expects RGB.
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    avg_col = np.average(img.reshape((-1, 3)), axis=0)

    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.axis('off')
    plt.title('Original Image')

    plt.subplot(1, 2, 2)
    # Round before painting so the uint8 swatch is the nearest color to the
    # mean rather than a truncated one.
    plt.imshow(create_col_pic(img.shape, np.around(avg_col)))
    plt.axis('off')
    plt.title('Average Color of Image')

    if save_out:
        plt.savefig(save_out)

    return img, avg_col

### Examples of Averaging being okay!

In [None]:
_,_ = compare('../demonstrations/figure1.jpg')

##### You can't see the background surrounding the green square, but the value is (255, 255, 255)

In [None]:
_, _ = compare('../demonstrations/figure2.jpg')

### Examples of Averaging failing...

##### Failure to capture a range of colors

In [None]:
_, _ = compare('../demonstrations/rainbow.png')

##### Dampening

In [None]:
_, _ = compare('../demonstrations/figure3.jpg')

### Examples of Averages being... dubious?

In [None]:
_, _ = compare('../demonstrations/red_blue.png')

In [None]:
_, _ = compare('../demonstrations/black_white.png')

### Examples on some test screenshots

In [None]:
_, _ = compare('../demonstrations/madmax_clip.jpg')

In [None]:
_, _ = compare('../demonstrations/spideman_city_clip.jpg')

In [None]:
_, _ = compare('../demonstrations/spiderman_clip_green.jpg')

# Interpolation Doesn't Seem that Great

In [None]:
def compare_interpolation(fname, width1, width2, method, save_out=None):
    """Plot an image next to two horizontally resized versions of itself.

    Both resized versions keep the original height and only change the
    width, so the panels show what ``cv2.resize`` does when an image is
    squeezed into ``width1`` and ``width2`` columns.

    Parameters
    ----------
    fname : str
        Path of the image file to load with ``cv2.imread``.
    width1, width2 : int
        Target widths (in pixels) of the second and third panels.
    method : int
        OpenCV interpolation flag passed to ``cv2.resize`` (for example
        ``cv2.INTER_LINEAR``). It is printed as-is in the panel titles.
    save_out : str, optional
        If given (and truthy), the figure is also written to this path with
        ``plt.savefig``.

    Raises
    ------
    OSError
        If ``cv2.imread`` cannot read ``fname``.
    """
    bgr = cv2.imread(fname)
    if bgr is None:
        # imread signals failure by returning None instead of raising.
        raise OSError('could not read image: {}'.format(fname))
    # OpenCV decodes to BGR; matplotlib expects RGB.
    raw = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    height = raw.shape[0]

    # Note: no plt.axis('off') here. Calling it before any subplot exists
    # only affects an implicitly created full-figure axes, not the panels.
    plt.figure(figsize=(12, 24))
    plt.subplot(1, 3, 1)
    plt.imshow(raw)
    plt.title('original image')

    for position, width in ((2, width1), (3, width2)):
        plt.subplot(1, 3, position)
        # cv2.resize takes the target size as (width, height).
        resized = cv2.resize(raw, (width, height), interpolation=method)
        plt.imshow(resized)
        plt.title('cv2.resize ({},{}), interpolation={}'.format(width, height, method))

    # The last panel is still the current axes: hide its x ticks only.
    plt.xticks([])

    if save_out:
        plt.savefig(save_out)

In [None]:
compare_interpolation('../demonstrations/rainbow.png', 50, 1, cv2.INTER_LINEAR)

In [None]:
compare_interpolation('../demonstrations/figure3.jpg', 50, 1, cv2.INTER_LINEAR)

In [None]:
compare_interpolation('../demonstrations/rotated.png', 50, 1, cv2.INTER_LINEAR)

In [None]:
compare_interpolation('../demonstrations/madmax_clip.jpg', 50, 1, cv2.INTER_LINEAR)

In [None]:
compare_interpolation('../demonstrations/spiderman_clip_green.jpg', 50, 1, cv2.INTER_LINEAR)

## Using machine learning (more specifically, k-means clustering) to extract more meaningful colors

#### NOTE: our contrived examples are (pretty) uniform in colors, so our clustering algorithm converges quickly

In [None]:
def fit_colors(image, num_col, random_state=None):
    """Cluster the pixels of an image into ``num_col`` representative colors.

    Parameters
    ----------
    image : np.ndarray
        Image whose data can be reshaped to ``(-1, 3)``, i.e. one row per
        pixel with three channel values.
    num_col : int
        Number of clusters (colors) to fit.
    random_state : int or None, optional
        Forwarded to ``KMeans``. Pass an integer to make the fit
        reproducible; the default ``None`` keeps the previous unseeded
        behavior.

    Returns
    -------
    sklearn.cluster.KMeans
        The fitted model.
    """
    pixels = image.reshape((-1, 3))
    model = KMeans(n_clusters=num_col, init='k-means++', n_init=20,
                   random_state=random_state)
    model.fit(pixels)
    return model

In [None]:
def reconstruct_img(cluster_center, labels, shape: tuple):
    """Rebuild an image by painting every pixel with its cluster's color.

    Parameters
    ----------
    cluster_center : array-like, shape (n_clusters, n_channels)
        Color of each cluster; values are rounded to the nearest integer.
    labels : array-like of int
        Cluster index of every pixel, listed row by row (row-major order).
    shape : tuple
        Shape of the image to produce, e.g. ``(height, width, 3)``.

    Returns
    -------
    np.ndarray
        ``uint8`` image of the requested shape.

    Raises
    ------
    ValueError
        If the number of labels does not match the number of pixels implied
        by ``shape``.
    """
    # Round the palette once, then look every pixel up in a single indexing
    # operation instead of a Python loop over height x width.
    palette = np.clip(np.around(np.asarray(cluster_center, dtype=float)), 0, 255)
    palette = palette.astype('uint8')
    pixel_labels = np.asarray(labels, dtype=np.intp)
    return palette[pixel_labels].reshape(shape)

In [None]:
def compare_og_recon(fname, num_col, save_out=None):
    """Show an image next to its reconstruction from ``num_col`` k-means colors.

    Parameters
    ----------
    fname : str
        Path of the image file to load with ``cv2.imread``.
    num_col : int
        Number of k-means clusters (colors) used for the reconstruction.
    save_out : str, optional
        If given (and truthy), the figure is also written to this path with
        ``plt.savefig``.

    Raises
    ------
    OSError
        If ``cv2.imread`` cannot read ``fname``.
    """
    bgr = cv2.imread(fname)
    if bgr is None:
        # imread signals failure by returning None instead of raising.
        raise OSError('could not read image: {}'.format(fname))
    # OpenCV decodes to BGR; matplotlib expects RGB.
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Reuse the notebook's shared helper instead of configuring KMeans again.
    model = fit_colors(img, num_col)
    rec_img = reconstruct_img(model.cluster_centers_, model.labels_, img.shape)

    plt.figure(figsize=(12, 9))
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.axis('off')
    plt.title('Original Image')

    plt.subplot(1, 2, 2)
    plt.imshow(rec_img)
    plt.axis('off')
    plt.title('Reconstructed Image (k={})'.format(num_col))

    if save_out:
        plt.savefig(save_out)

In [None]:
def compare_kmeans(fname, num_cluster, save_out=None):
    """Compare an image's average color with its k-means color palette.

    Draws a 2x2 grid: the original image, a swatch of its average color,
    the image reconstructed from ``num_cluster`` k-means colors, and the
    k-means colors themselves stacked as horizontal bands.

    Parameters
    ----------
    fname : str
        Path of the image file to load with ``cv2.imread``.
    num_cluster : int
        Number of k-means clusters (colors) to fit.
    save_out : str, optional
        If given (and truthy), the figure is also written to this path with
        ``plt.savefig``.

    Raises
    ------
    OSError
        If ``cv2.imread`` cannot read ``fname``.
    """
    bgr = cv2.imread(fname)
    if bgr is None:
        # imread signals failure by returning None instead of raising.
        raise OSError('could not read image: {}'.format(fname))
    # OpenCV decodes to BGR; matplotlib expects RGB.
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    avg_col = np.average(img.reshape((-1, 3)), axis=0)

    plt.subplots_adjust(wspace=0.4, hspace=0.4)
    plt.subplot(2, 2, 1)
    plt.imshow(img)
    plt.axis('off')
    plt.title('Original Image')

    plt.subplot(2, 2, 2)
    # Round before painting so the uint8 swatch is the nearest color to the
    # mean rather than a truncated one.
    plt.imshow(create_col_pic(img.shape, np.around(avg_col)))
    plt.axis('off')
    plt.title('Average Color of Image')

    # Reuse the notebook's shared helper instead of configuring KMeans again.
    model = fit_colors(img, num_cluster)

    recon = reconstruct_img(model.cluster_centers_, model.labels_, img.shape)
    plt.subplot(2, 2, 3)
    plt.imshow(recon)
    plt.axis('off')
    plt.title('Reconstructed Image (k={})'.format(num_cluster))

    plt.subplot(2, 2, 4)
    all_cols = np.zeros(img.shape, dtype='uint8')
    # Round the centers so these bands show exactly the colors used in the
    # reconstruction panel.
    centers = np.around(model.cluster_centers_)
    # Evenly spaced band edges that always end on the last row. A fixed band
    # height of height // k would leave the bottom height % k rows black.
    # If there are more clusters than rows, some bands are necessarily empty.
    edges = np.linspace(0, img.shape[0], len(centers) + 1).astype(int)
    for top, bottom, center in zip(edges[:-1], edges[1:], centers):
        all_cols[top:bottom, :] = center
    plt.imshow(all_cols)
    plt.axis('off')
    plt.title('KMeans Colors (k={})'.format(num_cluster))

    if save_out:
        plt.savefig(save_out)

In [None]:
compare_kmeans('../demonstrations/figure3.jpg', 2)

In [None]:
compare_kmeans('../demonstrations/rainbow.png', 7)

In [None]:
compare_kmeans('../demonstrations/rainbow.png', 4)

### Instance of our clustering algorithm being unable to overquantify (a special case because of our contrived example)

In [None]:
compare_kmeans('../demonstrations/rainbow.png', 10)

In [None]:
compare_kmeans('../demonstrations/madmax_clip.jpg', 8)

In [None]:
compare_kmeans('../demonstrations/spiderman_clip_green.jpg', 8)

In [None]:
compare_kmeans('../demonstrations/spideman_city_clip.jpg', 8)

In [None]:
compare_og_recon('../demonstrations/figure3.jpg', 2)

In [None]:
compare_og_recon('../demonstrations/madmax_clip.jpg', 8)

In [None]:
compare_og_recon('../demonstrations/spideman_city_clip.jpg', 8)

In [None]:
compare_og_recon('../demonstrations/spiderman_clip_green.jpg', 8)

In [None]:
compare_og_recon('../demonstrations/rainbow.png', 7)

In [None]:
compare_og_recon('../demonstrations/rainbow.png', 4)

### Instances of our clustering algorithm being unable to overquantify (a special case because of our contrived example)

In [None]:
compare_og_recon('../demonstrations/rainbow.png', 10)