 # Pokemon Image Clustering #

In [1]:
# for loading/processing the images
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array 
from keras.applications.vgg16 import preprocess_input
# for other things
import random
import numpy as np
import pandas as pd
import scipy.stats as stats
import matplotlib.pyplot as plt

In [2]:
import os

pokemons = []
rootdir = '../data/PokemonData'
for subdir, dirs, files in os.walk(rootdir):
    for file in files:
        pokemons.append([file, subdir.split('\\')[-1]])
        
print(pokemons[:10])

[['0282b2f3a22745f1a436054ea15a0ae5.jpg', 'Abra'], ['06b9eec4827d4d49b1b4c284308708df.jpg', 'Abra'], ['10a9f06ec6524c66b779ea80354f8519.jpg', 'Abra'], ['1788abb8b51f48509cfac8067bd99e14.jpg', 'Abra'], ['28cfad92ad934d1f9b579cbff4b5d012.jpg', 'Abra'], ['2eb2a528f9a247358452b3c740df69a0.jpg', 'Abra'], ['2fd28e699b7c4208acd1637fbad5df2d.jpeg', 'Abra'], ['32240b108a8140f8b31c495166fc453c.jpg', 'Abra'], ['34532bb006714727ade4075f0a72b92d.jpg', 'Abra'], ['3680c3f65a484c3ba05a7cb93e1d7ae3.jpg', 'Abra']]


In [3]:
pokemon_labels = pd.DataFrame(pokemons, columns = ['FileName', 'Label'])
# credit to https://towardsdatascience.com/how-to-cluster-images-based-on-visual-similarity-cd6e7209fe34
#    for help with image processing
# load the image as a 224x224 array
import matplotlib.image as mpimg
import imghdr
img_dict = {}
for pokemon in range(len(pokemon_labels)):
    img_path = '../data/PokemonData/'+pokemon_labels.iloc[pokemon]['Label']+'/'+pokemon_labels.iloc[pokemon]['FileName']
    img_type = imghdr.what(img_path)
    if img_type != "png" and img_type != "jpg" and img_type != "jpeg":
        os.remove(img_path)
    else:
        img = load_img(img_path, target_size=(224,224))
        # convert from 'PIL.Image.Image' to numpy array
        img = np.array(img)
        reshaped_img = img.reshape(224,224,3)
        x = preprocess_input(reshaped_img)
        if pokemon_labels.iloc[pokemon]['Label'] in img_dict.keys(): img_dict[pokemon_labels.iloc[pokemon]['Label']].append(x)
        else : img_dict[pokemon_labels.iloc[pokemon]['Label']] = [x]

In [None]:
class img_K_means:
    def __init__(self,img_dict,K,dist_func):
        self.imgs       = img_dict #dictionary
        self.clustroids = np.asarray(random.sample(img_dict.keys(), K))
        self.centroids  = [224,224,3]*K
        self.distances  = np.zeros([K])
        self.clusters   = {}
        self.k = K
        self.dist_func  = dist_func
        for i in range(K): self.centroids[i] = random.choice(img_dict[self.clustroids[i]])
        for index,i in enumerate(self.clustroids): self.clusters[i] = [[i,self.centroids[index]]]
        
        
    def cluster(self):
        ctr=0
        for img_label in self.imgs.keys():
            for img_val in self.imgs[img_label]:        # For each data point
                for j, clustroid in enumerate(self.clustroids):                       # go through each clustroid
                    self.distances[j] = self.squared_dist(img_val,self.centroids[j])  # and calculate distance to clustroid.
                ctr+=1
                index = np.argmin(self.distances)                                     # Get the index of the closest clustroid
                if self.clustroids[index] in self.clusters.keys(): self.clusters[self.clustroids[index]].append([img_label,img_val])
                else: self.clusters[self.clustroids[index]] = [[img_label,img_val]]   #assign to corresponding cluster
                #for tuples in self.clusters[self.clustroids[index]]: print(tuples[0])
                #print("one point assigned: ", img_label, ctr)
                
    def compute_centroids(self):
        new_clusters = {}
        for index,cluster in enumerate(self.clusters.keys()):
            if index>self.k-1:break
            centroid = np.zeros([224,224,3])
            for tuples in self.clusters[cluster]: 
                centroid += np.abs(tuples[1])
            self.centroids[index] = np.abs(centroid)/len(self.clusters[cluster]) #averages the image RGB's in the cluster list
            ctr=0
            clustroid = ""
            clustroid_val = np.zeros([224,224,3])
            distances = np.full((len(self.clusters[cluster])),1000000) #start with large distances so min() doesn't grab 0's
            for tuples in self.clusters[cluster]: 
                distances[ctr] = self.squared_dist(tuples[1],self.centroids[index])
                if distances[ctr] == min(distances): 
                    clustroid = tuples[0] #gets name of pokemon of closest image to centroid
                    clustroid_val = tuples[1] #value of closest image to centroid
                ctr+=1
            
            if clustroid == "": #sometimes clustroid won't get filled by this point. Not sure why, so here's my caveman solution
                clustroid = cluster
                clustroid_val = random.choice(self.imgs[clustroid])
            while clustroid in new_clusters.keys(): # if we get collisions, randomize to a new clustroid
                clustroid = random.choice(list(self.imgs.keys()))
                clustroid_val = random.choice(self.imgs[cluster])
                
                
            print("new clustroid:",clustroid, "old clustroid:", cluster)
            new_clusters[clustroid]=[[clustroid,clustroid_val]] #if we get an error here, collision and thats not great
        print(new_clusters.keys())
        key_list = list(self.clusters.keys())
        print(type(key_list),key_list)
        for key in key_list:
            if key not in new_clusters.keys(): self.clusters.pop(key,None) # we must pop old keys, otherwise the dict will grow
        self.clusters = new_clusters
        print(self.clusters.keys())
        
    def squared_dist(self,img1,img2):
        return np.sqrt(np.sum((img1-img2)**2))
    
    def fit(self):
        cloysters = self.clusters
        for i in range(1000):
            self.cluster()
            self.compute_centroids()
            if cloysters == self.clusters: break
            cloysters = self.clusters
            
first_try = img_K_means(img_dict,10,0)
first_try.fit()
                

            

new clustroid: Paras old clustroid: Grimer
new clustroid: Primeape old clustroid: Magmar
new clustroid: Electrode old clustroid: Omastar
new clustroid: Rapidash old clustroid: Rapidash
new clustroid: Parasect old clustroid: Parasect
new clustroid: Seadra old clustroid: Charizard
new clustroid: Rhydon old clustroid: Fearow
new clustroid: Golem old clustroid: Pinsir
new clustroid: Slowpoke old clustroid: Magikarp
new clustroid: Chansey old clustroid: Alakazam
dict_keys(['Paras', 'Primeape', 'Electrode', 'Rapidash', 'Parasect', 'Seadra', 'Rhydon', 'Golem', 'Slowpoke', 'Chansey'])
<class 'list'> ['Grimer', 'Magmar', 'Omastar', 'Rapidash', 'Parasect', 'Charizard', 'Fearow', 'Pinsir', 'Magikarp', 'Alakazam']
dict_keys(['Paras', 'Primeape', 'Electrode', 'Rapidash', 'Parasect', 'Seadra', 'Rhydon', 'Golem', 'Slowpoke', 'Chansey'])
new clustroid: Paras old clustroid: Paras
new clustroid: Primeape old clustroid: Primeape
new clustroid: Electrode old clustroid: Electrode
new clustroid: Rapidash ol

In [None]:
print(first_try.clusters)