In [None]:
%matplotlib inline

import time as time

import numpy as np
import scipy as sp

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from sklearn.feature_extraction.image import grid_to_graph
from sklearn.cluster import AgglomerativeClustering
from sklearn.utils.testing import SkipTest
from sklearn.utils.fixes import sp_version

# Default subsampling stride used by comp/avg/c2f: comp keeps every
# ZOOM_FACTOR-th row AND every ZOOM_FACTOR-th column, so a compressed image has
# roughly 1/ZOOM_FACTOR**2 as many pixels as the original.
ZOOM_FACTOR = 25


# Subsampling helper: shrinks an image by keeping a regular grid of its pixels.
def comp(img, z = ZOOM_FACTOR):
    """Subsample ``img`` by keeping every ``z``-th row and every ``z``-th column.

    Starting from row/column 0, the result holds roughly ``1/z**2`` as many
    pixels as ``img``. ``img`` must support 2-D slice indexing (e.g. a NumPy
    array); the returned value is whatever that indexing yields.
    """
    n_rows = len(img)
    n_cols = len(img[0])
    keep_rows = slice(0, n_rows, z)
    keep_cols = slice(0, n_cols, z)
    return img[keep_rows, keep_cols]

# Greyscale conversion by averaging the R, G and B channels of every pixel.
def avg(img, compress = True, z = ZOOM_FACTOR):
    """Return ``(img, grey)``: the (optionally subsampled) image and its greyscale map.

    Parameters
    ----------
    img : array_like
        Image indexed as ``img[row][col][channel]`` with at least three
        channels. Only channels 0-2 are read, so an alpha channel is ignored.
    compress : bool, optional
        If True, ``img`` is first subsampled with ``comp(img, z)``.
    z : int, optional
        Subsampling stride forwarded to ``comp``.

    Returns
    -------
    img
        The subsampled image when ``compress`` is True, otherwise the input
        object unchanged.
    grey : numpy.ndarray of float64, shape (rows, cols)
        ``R/3 + G/3 + B/3`` per pixel, truncated toward zero. The values are
        whole numbers stored as floats. For float images scaled to [0, 1],
        this truncation maps almost every pixel to 0. The visible caller
        (``cluster``) only uses the array's shape.
    """
    if compress:
        img = comp(img, z)
    # Vectorised over all pixels instead of a per-pixel Python loop. Each
    # channel is divided by 3.0 before summing, in R, G, B order, and the sum
    # is truncated toward zero. For integer-typed images this reproduces
    # int(r/3.0 + g/3.0 + b/3.0) exactly.
    rgb = np.asarray(img)[..., :3].astype(np.float64)
    grey = np.trunc(rgb[..., 0] / 3.0 + rgb[..., 1] / 3.0 + rgb[..., 2] / 3.0)
    return img, grey


# Coordinate mapping from the subsampled image back to the full-size image.
def c2f(x, z = ZOOM_FACTOR):
    """Scale compressed-image coordinate ``x`` by the stride ``z``.

    The product ``z * x`` is converted with ``int`` (truncation toward zero).
    """
    full_coord = z * x
    return int(full_coord)


# takes in an image array and optional number of clusters desired
# returns a 2-D array, the size of the (possibly compressed) image, giving each pixel's cluster id
def cluster(impath, n = 15, compress = True, usefiltr = False):
    """Segment an image into ``n`` spatially connected colour clusters.

    Runs scikit-learn's ``AgglomerativeClustering`` with Ward linkage on the
    pixels' RGB values. A pixel-grid connectivity graph restricts which
    pixels/regions may be merged.

    Parameters
    ----------
    impath : array_like
        Image array indexed as (rows, cols, channels). Despite the name, this
        is an array, not a file path. Only the first three channels are
        clustered, so an RGBA image's alpha channel is ignored.
    n : int, optional
        Number of clusters to produce.
    compress : bool, optional
        If True, the image is subsampled by ``avg`` (via ``comp``) before
        clustering.
    usefiltr : bool, optional
        Currently unused. Kept for interface compatibility; it was intended to
        filter out small clusters.

    Returns
    -------
    numpy.ndarray
        Array with the (possibly subsampled) image's height and width whose
        entries are cluster ids. Ids start at 0, so 0 is a valid cluster.

    Side effects: creates a two-panel matplotlib figure and draws the
    (possibly subsampled) image in the left panel. Also prints progress,
    elapsed time, pixel count and cluster count.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 15))
    axes = axes.flatten()

    img, grey = avg(impath, compress)
    axes[0].imshow(img)

    # Graph linking neighbouring pixel positions of the (rows, cols) grid.
    # It is passed as the connectivity constraint for the clustering.
    connectivity = grid_to_graph(*grey.shape)

    # One feature row per pixel, in the same row-major order as the graph
    # nodes. Slicing to three channels keeps the row count equal to the node
    # count for RGBA input; reshaping 4 channels into rows of 3 would not.
    X = np.reshape(np.asarray(img)[..., :3], (-1, 3))
    print("Compute structured hierarchical clustering...")
    st = time.perf_counter()

    # Ward linkage: per the original author, the only linkage setting that
    # gave usable results for these images.
    ward = AgglomerativeClustering(n_clusters=n, linkage='ward',
                                   connectivity=connectivity,
                                   compute_full_tree='auto')
    ward.fit(X)
    label = np.reshape(ward.labels_, grey.shape)
    print("Elapsed time: ", time.perf_counter() - st)
    print("Number of pixels: ", label.size)
    print("Number of clusters: ", np.unique(label).size)
    return label
    
# Per-cluster statistics for a clustered image: area, perimeter, colour mean and
# standard deviation, and perimeter/area ratio.


def _boundary_mask(label):
    # Boolean mask of pixels with at least one in-image 4-neighbour whose label differs.
    mask = np.zeros(label.shape, dtype=bool)
    down = label[1:, :] != label[:-1, :]
    mask[1:, :] |= down
    mask[:-1, :] |= down
    right = label[:, 1:] != label[:, :-1]
    mask[:, 1:] |= right
    mask[:, :-1] |= right
    return mask


def _safe_divide(num, den):
    # Element-wise num / den that yields NaN (rather than inf plus a warning) where den == 0.
    return np.divide(num, den, out=np.full(np.shape(num), np.nan), where=den > 0)


def get_cs(impath, label, compress=True):
    """Return shape and colour statistics for every cluster in ``label``.

    Parameters
    ----------
    impath : array_like
        The same image array that was given to ``cluster``. It is an array,
        not a file path, indexed as (rows, cols, channels). Only the first
        three channels (R, G, B) are used, so an alpha channel is ignored.
    label : array_like of non-negative int, shape (rows, cols)
        Cluster id per pixel, as returned by ``cluster``.
    compress : bool, optional
        Must match the ``compress`` value used for ``cluster``. If True (the
        default for both), ``impath`` is subsampled with ``comp`` so that it
        lines up with ``label``.

    Returns
    -------
    tuple
        ``(area, perimeter, (R_avg, G_avg, B_avg), (R_std, G_std, B_std),
        perimeter / area)``. Every array is 1-D, float64, indexed by cluster
        id, and has length ``label.max() + 1``.

        - ``area``: the number of pixels in the cluster.
        - ``perimeter``: the number of its pixels that have at least one
          in-image 4-neighbour with a different label.
        - Colour means and standard deviations: population statistics
          (ddof=0) over all of the cluster's pixels.
        - Ids that never occur have area 0 and NaN for every mean, standard
          deviation and ratio.

    Raises
    ------
    ValueError
        If the (possibly subsampled) image is not (rows, cols, >=3 channels)
        or its height/width differ from ``label``'s shape.
    """
    img = np.asarray(comp(impath) if compress else impath)
    label = np.asarray(label)
    if img.ndim != 3 or img.shape[2] < 3 or img.shape[:2] != label.shape:
        raise ValueError(
            "expected an (rows, cols, >=3 channels) image matching label "
            "shape %s, got image shape %s" % (label.shape, img.shape))

    flat = label.ravel().astype(np.intp)
    rgb = img[..., :3].reshape(-1, 3).astype(np.float64)
    n_clusters = int(flat.max()) + 1 if flat.size else 0

    area = np.bincount(flat, minlength=n_clusters).astype(np.float64)
    perimeter = np.bincount(flat[_boundary_mask(label).ravel()],
                            minlength=n_clusters).astype(np.float64)

    # Per-cluster channel sums -> means, shape (n_clusters, 3).
    sums = np.stack(
        [np.bincount(flat, weights=rgb[:, c], minlength=n_clusters)
         for c in range(3)],
        axis=1)
    means = _safe_divide(sums, area[:, None])

    # Two-pass (deviation-based) population variance, which is numerically
    # steadier than E[x^2] - E[x]^2. means[flat] only reads ids that occur,
    # so the NaN rows of empty clusters never enter the deviations.
    dev_sq = (rgb - means[flat]) ** 2
    sq_sums = np.stack(
        [np.bincount(flat, weights=dev_sq[:, c], minlength=n_clusters)
         for c in range(3)],
        axis=1)
    stds = np.sqrt(_safe_divide(sq_sums, area[:, None]))

    return (area, perimeter, tuple(means.T), tuple(stds.T),
            _safe_divide(perimeter, area))
    
# Full pipeline for one image: cluster it, then measure each cluster.
def process_image(impath):
    """Return the per-cluster statistics of ``impath``.

    Calls ``cluster(impath)`` with its defaults. It then hands the resulting
    label array, together with the same image, to ``get_cs`` and returns that
    result unchanged.
    """
    labels = cluster(impath)
    return get_cs(impath, labels)

# takes a list of all the images to train on as well as the correct output for each image                
def train_perceptron(list_of_img, correct_result):
    for img in range(list_of_img):
        