In [None]:
from PIL import Image
from pathlib import Path

In [None]:
from scipy.misc import imsave
import os

In [None]:
import numpy as np

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook

In [None]:
from collections import namedtuple

In [None]:
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import MinMaxScaler
from sklearn.pipeline import Pipeline

In [None]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

In [None]:
# Axis-aligned pixel rectangle. It is used as img[y0:y1, x0:x1], so x1/y1 are
# exclusive ends. The typename must match the variable name: that gives a
# readable repr (Region(x0=...)) and lets pickle find the class.
Region = namedtuple('Region', 'x0 y0 x1 y1')

# Load the Files

In [None]:
# Sort the list so positional picks such as files[5] name the same sheet on every
# machine. Path.glob yields entries in arbitrary (filesystem) order.
# NOTE(review): the hand-tuned indices used below (files[0], [2], [3], [5]) were
# chosen under the old unsorted order. Confirm they still point at the intended sheets.
files = sorted(Path("../../../data/sprites/").glob("*.png"))
files[5]

In [None]:
@interact(f=files)
def draw_file(f):
    """Display the sprite sheet selected in the ``f`` dropdown in a fresh figure."""
    with Image.open(f) as sheet:
        pixels = np.array(sheet)
    # Close earlier figures so repeated selections do not pile up.
    plt.close('all')
    plt.figure()
    plt.imshow(pixels)
    plt.show()

# Show the Files

In [None]:
# The sheet used for the step-by-step walkthrough below.
f = files[5]
with Image.open(f) as sheet:
    im = np.array(sheet)
plt.imshow(im)

# Pre-Processing

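We want to make the sheet's background transparent. `remove_color` sets the alpha of every pixel whose RGB value equals a given color to 0. Here that color is sampled from the top-left pixel.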

In [None]:
def remove_color(img, color, inplace=True):
    """Make every pixel of one RGB colour fully transparent.

    Parameters
    ----------
    img : array_like, shape (H, W, 3) or (H, W, 4)
        RGB or RGBA image.
    color : array_like
        Colour to remove. Only its first three entries (R, G, B) are compared,
        so an RGBA pixel such as ``img[y, x, :]`` can be passed directly.
    inplace : bool, default True
        If True and ``img`` is already an RGBA ndarray, its alpha channel is
        modified in place. An RGB image always yields a new RGBA array, because
        a fourth channel has to be added.

    Returns
    -------
    ndarray, shape (H, W, 4)
        The RGBA image with alpha set to 0 wherever the RGB value equals
        ``color``. The dtype of ``img`` is preserved.
    """
    img = np.asarray(img)
    if not inplace:
        img = img.copy()
    # True where the pixel does NOT match the colour, i.e. where it stays visible.
    keep = ~np.all(img[:, :, :3] == color[:3], axis=2)
    if img.shape[2] == 4:
        # Existing alpha is kept for visible pixels and zeroed for matches.
        img[:, :, 3] *= keep
    else:
        # Build the new alpha channel in img's dtype. `keep * 255` on its own is
        # int64, and dstack would silently promote the whole image (uint8 -> int64).
        alpha = (keep * 255).astype(img.dtype)
        img = np.dstack((img, alpha))
    return img

In [None]:
# Treat the colour of the top-left pixel as the background and make every pixel
# of that colour transparent. inplace=False works on a copy of `im`.
im = remove_color(im, color=im[0, 0, :], inplace=False)
plt.imshow(im)

We also want to remove anything noisy in the image, such as text. `remove_region` does this by setting the alpha of a hand-picked rectangle to 0, which makes it fully transparent.

In [None]:
def remove_region(img, region: Region, inplace=True):
    """Make a rectangle of an RGBA image fully transparent.

    Sets alpha (channel 3) to 0 for rows ``region.y0:region.y1`` and columns
    ``region.x0:region.x1``. These are end-exclusive slices, so a region with
    ``y1 <= y0`` or ``x1 <= x0`` changes nothing. ``img`` must already have an
    alpha channel.

    If ``inplace`` is False the input is copied first. The (possibly copied)
    array is returned either way.
    """
    pixels = np.asarray(img)
    if not inplace:
        pixels = pixels.copy()
    rows = slice(region.y0, region.y1)
    cols = slice(region.x0, region.x1)
    pixels[rows, cols, 3] = 0
    return pixels

In [None]:
# Hand-picked rectangle (pixel coordinates in files[5]) hidden before clustering.
im = remove_region(im, Region(x0=425, y0=250, x1=750, y1=375), inplace=False)
plt.imshow(im)

# Clustering

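We use clustering to detect the sprites. Each sprite is a blob of nearby opaque pixels, so running DBSCAN on the pixel coordinates puts each sprite into its own cluster.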

## Convert to Scatter Plot

First we convert the image to a scatter representation.

In [None]:
# Row/column indices of every fully opaque pixel. Plotted as a scatter they are
# (x = column, y = row), overlaid on the image.
im_scatter_y, im_scatter_x = np.nonzero(im[:, :, 3] == 255)
plt.scatter(im_scatter_x, im_scatter_y, s=.01)
plt.imshow(im)

Now we define our clustering pipeline.

In [None]:
# Single-step pipeline. eps is a distance in pixels, because X holds pixel
# coordinates. min_samples keeps scikit-learn's default here, whereas
# postprocess() defaults it to 2.
cluster = Pipeline(steps=[('clustering', DBSCAN(eps=2.))])

And our data.

In [None]:
# One row per opaque pixel: column 0 = x, column 1 = y.
X = np.column_stack((im_scatter_x, im_scatter_y))
X.shape

We fit the clustering model and count the clusters. DBSCAN labels noise points `-1`, and those are not a cluster.

In [None]:
cluster.fit(X)
dbscan = cluster.named_steps['clustering']
labels = dbscan.labels_
# Mark the core samples here instead of relying on plot_clusters (which also
# sets them). That way get_sprite works even if the plotting cell is skipped.
core_samples_mask = np.zeros_like(labels, dtype=bool)
core_samples_mask[dbscan.core_sample_indices_] = True
# Includes -1 (noise) when present. Later cells rely on this when pairing
# labels with colours.
unique_labels = np.unique(labels)
# DBSCAN labels noise points -1; that label is not a cluster.
n_clusters = len(unique_labels) - (1 if -1 in unique_labels else 0)
print("n clusters:", n_clusters)

## Plotting Clusters

Now we plot it to check that it's right.

In [None]:
# REF: http://scikit-learn.org/stable/auto_examples/cluster/plot_dbscan.html#sphx-glr-auto-examples-cluster-plot-dbscan-py
def plot_clusters(img, cluster, X, core_samples_mask, all_labels, colors):
    """Overlay each cluster's DBSCAN core samples on ``img``.

    Parameters
    ----------
    img : array_like
        Image drawn underneath the points, half transparent.
    cluster : Pipeline
        Fitted pipeline whose ``'clustering'`` step is a fitted DBSCAN.
    X : ndarray, shape (n_points, 2)
        ``(x, y)`` coordinates that were clustered.
    core_samples_mask : ndarray of bool, shape (n_points,)
        **Modified in place**: entries at DBSCAN's ``core_sample_indices_``
        are set to True. Callers in this notebook later pass the same mask to
        ``get_sprite``.
    all_labels : ndarray, shape (n_points,)
        Cluster label of each point (``-1`` = noise).
    colors : sequence
        One RGBA colour per unique label, in sorted-label order.

    Returns
    -------
    matplotlib Axes
    """
    core_samples_mask[cluster.named_steps['clustering'].core_sample_indices_] = True
    unique_labels = np.unique(all_labels)
    fig, ax = plt.subplots()
    for k, col in zip(unique_labels, colors):
        if k == -1:
            # Black used for noise. Noise points are never core samples, so in
            # practice nothing is drawn for this label.
            col = [0, 0, 0, 1]

        class_member_mask = (all_labels == k)

        xy = X[class_member_mask & core_samples_mask]
        # Use the 'o' format to draw markers only. Without an explicit marker,
        # Axes.plot draws a line through the points in the default colour
        # cycle and markerfacecolor has no effect.
        ax.plot(xy[:, 0], xy[:, 1], 'o',
                markerfacecolor=tuple(col),
                markeredgecolor=tuple(col),
                markersize=1,
                alpha=.5)

    ax.imshow(img, alpha=.5)
    return ax

In [None]:
# One colour per label (noise included), spread evenly over the Spectral colormap.
colors = [plt.cm.Spectral(v) for v in np.linspace(0, 1, len(unique_labels))]
plot_clusters(im, cluster,
              X=X,
              core_samples_mask=core_samples_mask,
              all_labels=labels,
              colors=colors)

## Bounding Boxes

Now for each cluster we get the bounding box.

In [None]:
def get_sprite(X,
               core_samples_mask,
               all_labels,
               given_label,
               color,
               x_pad=2,
               y_pad=2):
    """Padded bounding box of one DBSCAN cluster, as a ``Region``.

    Parameters
    ----------
    X : ndarray, shape (n_points, 2)
        ``(x, y)`` pixel coordinates of every clustered point.
    core_samples_mask : ndarray of bool, shape (n_points,)
        True where the point is a DBSCAN core sample. Only core samples define
        the box. If the cluster has none marked (e.g. the mask was never
        filled), all of its members are used instead.
    all_labels : ndarray, shape (n_points,)
        Cluster label of each point.
    given_label : int
        Label of the cluster to box.
    color : object
        Unused. Kept so existing callers that pass a colour keep working.
    x_pad, y_pad : int
        Extra pixels added on each side of the box.

    Returns
    -------
    Region
        ``x1``/``y1`` are exclusive, so the box can be used directly as a
        slice (see ``get_region``). Lower bounds are clamped at 0, upper
        bounds at one past the largest coordinate in ``X``.

    Raises
    ------
    ValueError
        If no point carries ``given_label``.
    """
    members = all_labels == given_label
    xy = X[members & core_samples_mask]
    if len(xy) == 0:
        # Mask not populated for this cluster: fall back to all of its members.
        xy = X[members]
    if len(xy) == 0:
        raise ValueError("no points with label {!r}".format(given_label))
    x_min, y_min = xy.min(axis=0)
    x_max, y_max = xy.max(axis=0)
    x_limit, y_limit = X.max(axis=0) + 1
    # "+ 1" turns the inclusive max coordinate into an exclusive slice end.
    return Region(x0=int(max(0, x_min - x_pad)),
                  y0=int(max(0, y_min - y_pad)),
                  x1=int(min(x_limit, x_max + x_pad + 1)),
                  y1=int(min(y_limit, y_max + y_pad + 1)))
    

In [None]:
# One bounding box per real cluster. Label -1 is DBSCAN noise, not a sprite.
sprite_regions = [
    get_sprite(X=X,
               core_samples_mask=core_samples_mask,
               all_labels=labels,
               given_label=label,
               color=label_color)
    for label, label_color in zip(unique_labels, colors)
    if label != -1
]
sprite_regions

And we get those regions from the image.

In [None]:
def get_region(img, region: Region):
    """Crop ``img[y0:y1, x0:x1]`` as described by ``region``.

    This is basic slicing, so for an ndarray ``img`` the crop shares memory with
    ``img``: writing to the crop writes to the image.
    """
    rows = slice(region.y0, region.y1)
    cols = slice(region.x0, region.x1)
    return np.asarray(img)[rows, cols]

In [None]:
# Crop every bounding box out of the processed image (views into `im`).
sprites = [get_region(im, box) for box in sprite_regions]

## Plotting Sprites
Now we check them all to see if they are right.

In [None]:
# Lay the sprites out on a roughly square grid: ceil(sqrt(n)) rows and enough
# columns that rows * cols >= n. A floor-based column count can index past the
# end of `sprites` (n=5) or silently drop sprites (n=7).
n_sprites = len(sprites)
n_rows = max(1, int(np.ceil(np.sqrt(n_sprites))))
n_cols = max(1, int(np.ceil(n_sprites / n_rows)))
# squeeze=False keeps `axes` 2-D even for a single row, column or sprite.
fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True, squeeze=False)
for ax, sprite in zip(axes.flat, sprites):
    ax.imshow(sprite)
# Hide the leftover empty cells at the end of the grid.
for ax in axes.flat[n_sprites:]:
    ax.axis('off')

In [None]:
@interact(n=[*range(len(sprites))])
def draw_sprite(n):
    """Show sprite number ``n`` (picked from the dropdown) on its own in a new figure."""
    # Close old figures first so repeated widget changes do not pile them up.
    plt.close('all')
    plt.figure()
    plt.imshow(sprites[n])
    plt.show()

## Saving

In [None]:
# Write each sprite to <sheet dir>/<sheet name>/<index>.png.
# scipy.misc.imsave was removed in SciPy 1.2, so this cell writes with PIL and
# no longer needs the scipy import.
out_dir = f.parent / f.stem
out_dir.mkdir(parents=True, exist_ok=True)
for i, sprite in enumerate(sprites):
    # PIL cannot encode int64 arrays (remove_color may return one for RGB input),
    # so convert to 8-bit first.
    Image.fromarray(np.asarray(sprite).astype(np.uint8)).save(out_dir / (str(i) + ".png"))

# Put it all together

In [None]:
def preprocess(f,
               color_locations,
               removal_regions,
               plot=True,
               save=False,
               **kwargs):
    """Load a sprite sheet and make its background (and hand-picked regions) transparent.

    Parameters
    ----------
    f : str or Path
        Image file to load.
    color_locations : iterable of (x, y)
        Pixel positions whose RGB colour is treated as background. Every pixel
        of each such colour gets alpha 0 (see ``remove_color``).
    removal_regions : iterable of Region
        Rectangles whose alpha is set to 0 (see ``remove_region``), e.g. text.
    plot : bool, default True
        Show the image before and after processing.
    save : bool, default False
        If True, write the processed image to ``<f.parent>/<f.stem>/<f.stem>.png``
        and return None instead of the image.
    **kwargs
        Ignored. Lets ``main`` pass one shared set of keyword arguments to both
        ``preprocess`` and ``postprocess``.

    Returns
    -------
    ndarray, shape (H, W, 4), or None when ``save`` is True.
    """
    f = Path(f)
    # Convert to RGBA so the result always has an alpha channel, which
    # postprocess clusters on, even for RGB, greyscale or palette PNGs.
    with Image.open(f) as src:
        img = np.array(src.convert('RGBA'))

    if plot:
        plt.figure(0)
        plt.imshow(img)
        plt.title("Image Pre Processing")

    # Remove background colours, sampled at the given (x, y) positions.
    for (x, y) in color_locations:
        img = remove_color(img, img[y, x, :], inplace=True)

    # Remove regions
    for reg in removal_regions:
        img = remove_region(img, reg, inplace=True)

    if plot:
        plt.figure(1)
        plt.imshow(img)
        plt.title("Image Post Processing")

    # Save or return
    if save:
        out_dir = f.parent / f.stem
        out_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): main() points postprocess at this same directory, and
        # postprocess deletes every *.png there before writing sprites, so this
        # file would be removed by a subsequent main() run.
        Image.fromarray(img.astype(np.uint8)).save(out_dir / (f.stem + ".png"))
        return None
    return img

In [None]:
def _cluster_opaque_pixels(img, cluster_args):
    """Cluster the (x, y) coordinates of fully opaque pixels; return (pipeline, X, labels, core mask)."""
    im_scatter_y, im_scatter_x = np.where(img[:, :, 3] == 255)
    X = np.vstack((im_scatter_x, im_scatter_y)).T
    cluster = Pipeline([('clustering', DBSCAN(**cluster_args))])
    cluster.fit(X)
    dbscan = cluster.named_steps['clustering']
    labels = dbscan.labels_
    # Fill the core-sample mask here so sprite extraction does not depend on
    # plot_clusters (which also sets it) having been called.
    core_samples_mask = np.zeros_like(labels, dtype=bool)
    core_samples_mask[dbscan.core_sample_indices_] = True
    return cluster, X, labels, core_samples_mask


def _plot_sprite_grid(sprites, sprite_regions):
    """Show every sprite in one roughly square grid of subplots."""
    n_sprites = len(sprites)
    # ceil(sqrt(n)) rows and enough columns that rows * cols >= n.
    n_rows = max(1, int(np.ceil(np.sqrt(n_sprites))))
    n_cols = max(1, int(np.ceil(n_sprites / n_rows)))
    # squeeze=False keeps `axes` 2-D even for a single row, column or sprite.
    fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True, squeeze=False)
    for ax, sprite, region in zip(axes.flat, sprites, sprite_regions):
        try:
            ax.imshow(sprite)
        except Exception as e:
            # Best effort: one bad crop should not hide the others.
            print("Show Error:", region, e)
    for ax in axes.flat[n_sprites:]:
        ax.axis('off')
    fig.suptitle('All Sprites')


def _save_sprites(sprites, sprite_regions, save_dir):
    """Replace every *.png in save_dir with one <index>.png per sprite."""
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    # Delete existing PNGs first, so no files from an earlier run remain.
    for old_file in list(save_dir.glob('*.png')):
        old_file.unlink()
    for i, (sprite, region) in enumerate(zip(sprites, sprite_regions)):
        try:
            # scipy.misc.imsave was removed in SciPy 1.2; write with PIL instead.
            # PIL cannot encode int64 arrays, so convert to 8-bit first.
            Image.fromarray(np.asarray(sprite).astype(np.uint8)).save(save_dir / (str(i) + '.png'))
        except Exception as e:
            print("Save Error:", region, e)


def postprocess(img,
                cluster_args=None,
                post_color_locations=None,
                x_pad=2,
                y_pad=2,
                plot=True,
                save=True,
                **kwargs):
    """Cut a preprocessed RGBA sprite sheet into individual sprites.

    Fully opaque pixels (alpha == 255) are clustered with DBSCAN on their
    (x, y) coordinates. Each non-noise cluster's padded bounding box is then
    cropped out of ``img``.

    Parameters
    ----------
    img : ndarray, shape (H, W, 4)
        Output of ``preprocess``. The crops are views into it, so the
        post-colour removal below also changes ``img``'s alpha channel.
    cluster_args : dict, optional
        Keyword arguments for ``DBSCAN``. Defaults to ``{'eps': 2., 'min_samples': 2}``.
    post_color_locations : list of (x, y), optional
        Pixels of ``img`` whose RGB colour is made transparent inside every sprite.
    x_pad, y_pad : int
        Padding around each bounding box (see ``get_sprite``).
    plot : bool
        Plot the clusters and a grid of all sprites.
    save : bool, str or Path
        If a str/Path, that directory's existing *.png files are deleted, one
        ``<index>.png`` is written per sprite, and None is returned. Any other
        value (including the default True) returns the sprites instead.
    **kwargs
        Ignored. Lets ``main`` pass one shared set of keyword arguments.

    Returns
    -------
    list of ndarray, or None when the sprites were saved.
    """
    # None sentinels instead of mutable default arguments.
    if cluster_args is None:
        cluster_args = {'eps': 2., 'min_samples': 2}
    if post_color_locations is None:
        post_color_locations = []

    cluster, X, labels, core_samples_mask = _cluster_opaque_pixels(img, cluster_args)
    unique_labels = np.unique(labels)
    # DBSCAN labels noise points -1; that label is not a cluster.
    print("n clusters:", len(unique_labels) - (1 if -1 in unique_labels else 0))

    colors = [plt.cm.Spectral(each)
              for each in np.linspace(0, 1, len(unique_labels))]
    if plot:
        ax = plot_clusters(img, cluster, X=X, core_samples_mask=core_samples_mask,
                           all_labels=labels, colors=colors)
        ax.set_title("Clusters")

    # One bounding box and crop per non-noise cluster.
    sprite_regions = [get_sprite(X=X,
                                 core_samples_mask=core_samples_mask,
                                 all_labels=labels,
                                 given_label=k,
                                 color=col,
                                 x_pad=x_pad,
                                 y_pad=y_pad) for k, col in zip(unique_labels, colors) if k != -1]
    sprites = [get_region(img, reg) for reg in sprite_regions]

    # Remove colours that only need to disappear inside the sprites. Each
    # sprite is a 4-channel view, so remove_color edits it in place.
    for sprite in sprites:
        for (x, y) in post_color_locations:
            remove_color(sprite, img[y, x, :], inplace=True)

    if plot:
        _plot_sprite_grid(sprites, sprite_regions)

    # Save or return
    if isinstance(save, (str, Path)):
        _save_sprites(sprites, sprite_regions, save)
        return None
    return sprites

In [None]:
def main(save=True, **kwargs):
    """Run ``preprocess`` then ``postprocess`` on one sprite sheet.

    Parameters
    ----------
    save : bool, str or Path, default True
        True -> save the sprites to ``<f.parent>/<f.stem>/``.
        Non-empty str or Path -> save the sprites to that directory.
        Anything else -> do not save; the sprites are returned.
    **kwargs
        Passed to both ``preprocess`` and ``postprocess``. Must include ``f``
        (the sprite-sheet path, str or Path) and ``preprocess``'s required
        ``color_locations`` and ``removal_regions``.

    Returns
    -------
    Whatever ``postprocess`` returns: None when the sprites were saved,
    otherwise the list of sprite arrays.
    """
    img = preprocess(save=False, **kwargs)

    if save is True:
        # Default output directory: next to the sheet, named after it.
        sheet = Path(kwargs['f'])
        save = sheet.parent / sheet.stem
    elif save and isinstance(save, (str, Path)):
        # `save` is main's own parameter; it is never inside kwargs.
        save = Path(save)
    else:
        save = None

    return postprocess(img, save=save, **kwargs)

In [None]:
# Whole pipeline on the walkthrough sheet. With main's default save=True the
# sprites are written to <sheet dir>/<sheet name>/.
main(
    f=files[5],
    color_locations=[(0, 0)],
    removal_regions=[Region(x0=425, y0=250, x1=750, y1=375)],
)

# Finally

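Here is the workflow for a new sprite sheet:

1. Preview it with the interactive backend to find pixel coordinates.
2. Call `main` with the background-color locations (`color_locations`), the regions to hide (`removal_regions`) and any colors to remove from the cropped sprites (`post_color_locations`).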

In [None]:
# Switch to the interactive `notebook` backend for the preview below. It is
# presumably used to read off pixel coordinates for color_locations /
# removal_regions.
%matplotlib notebook

In [None]:
# Preview one sprite sheet to pick color_locations / removal_regions for main().
# Separate names keep this cell from overwriting the walkthrough's `f` and `im`.
# The earlier save cell writes into f's directory, so clobbering `f` would send
# a re-run of that cell to the wrong folder.
preview_file = files[3]
with Image.open(preview_file) as src:
    preview_im = np.array(src)
plt.close('all')
plt.figure()
plt.imshow(preview_im)
plt.show()

In [None]:
%matplotlib inline
main(f=files[0],
     color_locations=[ (348,630), (25, 150)],
     removal_regions=[Region(x0=13,y0=590,x1=562,y1=619),
                      Region(x0=53,y0=14,x1=344,y1=12),
                      Region(x0=347,y0=9,x1=560,y1=62),
                      Region(x0=460,y0=245,x1=520,y1=260),
                      Region(x0=35,y0=186,x1=207,y1=283),
                      Region(x0=220,y0=180,x1=265,y1=192),
                      Region(x0=330,y0=130,x1=400,y1=145)],
     post_color_locations=[(481,455),(459,422),(179,499)])

In [None]:
%matplotlib inline
main(f=files[2],
     color_locations=[(0,0)],
     removal_regions=[Region(x0=89,y0=300,x1=262,y1=354)],
     post_color_locations=[])