In [34]:
%config InlineBackend.figure_format = 'svg'
import os
import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from skimage import io, data, filters
from skimage.viewer import ImageViewer
from skimage.exposure import histogram
from skimage.measure import label
from skimage.segmentation import flood, flood_fill
from scipy import ndimage
from scipy.ndimage.measurements import label
from functools import reduce
from PIL import Image
import cv2

# https://www.youtube.com/watch?v=TyWtx7q2D7Y
# Connected components

def scale_down(img, scalar):
    """Return a ``(width, height)`` pair for ``img`` divided by ``scalar``.

    Args:
        img: Any object exposing ``.shape`` whose first two entries are the
            row count and the column count.
        scalar: Divisor applied to both dimensions.

    Returns:
        tuple: ``(columns / scalar, rows / scalar)`` -- note the order is
        width first, which is what ``figsize`` expects.
    """
    shape = img.shape
    n_rows = shape[0]
    n_cols = shape[1]
    width = n_cols / scalar
    tall = n_rows / scalar
    return (width, tall)

def show(img, scalar=1):
    """Display ``img`` in a new matplotlib figure.

    Args:
        img: Image array handed to ``io.imshow``.
        scalar: When truthy, the figure size is ``scale_down(img, scalar)``;
            when falsy (e.g. ``0`` or ``None``) matplotlib's default figure
            size is used.
    """
    # Only pass ``figsize`` when a scale factor was actually requested.
    figure_kwargs = {'figsize': scale_down(img, scalar)} if scalar else {}
    plt.figure(**figure_kwargs)
    io.imshow(img, aspect='auto')
    io.show()

def centroid(t1,t2):
    """Return the integer midpoints of two numeric ranges.

    Args:
        t1: Iterable of numbers (used with two entries, e.g. ``(min, max)``).
        t2: Iterable of numbers, treated the same way as ``t1``.

    Returns:
        tuple: ``(int(sum(t1) / 2), int(sum(t2) / 2))``; ``int`` truncates
        toward zero.
    """
    def _half_sum(pair):
        return int(sum(pair) / 2)

    return (_half_sum(t1), _half_sum(t2))

# Todo refactor to check set of sets for largest then retrieve row
def unique(array, ignore=[0]):
    """Return the distinct items of ``array`` in first-seen order.

    Args:
        array: Iterable of hashable items.
        ignore: Values to leave out of the result entirely. Defaults to
            ``[0]``. The default list is never mutated.

    Returns:
        list: Each non-ignored item once, ordered by first appearance.
    """
    skipped = set(ignore)
    # A dict keeps insertion order, so its keys double as an ordered set.
    ordered = {}
    for item in array:
        if item in ordered or item in skipped:
            continue
        ordered[item] = None
    return list(ordered)
            
def get_label_sets(labeled):
    """Collect the distinct, non-empty label combinations found row by row.

    Each row of ``labeled`` is reduced with ``unique`` (which drops ``0`` by
    default). A row is kept only the first time its *set* of labels is seen.

    Args:
        labeled: Iterable of rows of label values.

    Returns:
        list[dict]: One ``{'length': n, 'row': labels}`` entry per distinct
        non-empty label set, in order of first appearance.
    """
    # Only rows are scanned for now; columns/groups are not handled.
    seen_combos = set()
    distinct = []
    for raw_row in labeled:
        labels = unique(raw_row)
        key = frozenset(labels)
        if key in seen_combos:
            continue
        seen_combos.add(key)
        if labels:
            distinct.append({'length': len(labels), 'row': labels})
    return distinct

def get_label_hist(labeled):
    """Draw a bar chart of how many labels each distinct row set contains.

    Args:
        labeled: Label rows, forwarded to ``get_label_sets``.
    """
    lengths = [entry['length'] for entry in get_label_sets(labeled)]
    positions = range(len(lengths))

    fig = plt.figure()
    # Axes fill the whole figure area.
    ax = fig.add_axes([0,0,1,1])
    ax.bar(positions, lengths)
    plt.show()
    
def filter_bounds(bounds):
    """Pair up neighbouring entries of ``bounds`` that are not consecutive.

    Args:
        bounds: Indexable sequence of integers (as produced by
            ``get_bounds``).

    Returns:
        list[tuple]: ``(bounds[k-1], bounds[k])`` for every ``k`` where the
        two values differ by something other than exactly 1.
    """
    return [
        (bounds[k - 1], bounds[k])
        for k in range(1, len(bounds))
        if bounds[k] - 1 != bounds[k - 1]
    ]
        
def get_bounds(int_img):
    """Return the indices of the rows of ``int_img`` whose sum is falsy (zero).

    Args:
        int_img: Iterable of numeric rows.

    Returns:
        list[int]: Positions of the rows that sum to zero, in ascending order.
    """
    empty_rows = []
    for idx, line in enumerate(int_img):
        if sum(line):
            continue
        empty_rows.append(idx)
    return empty_rows
#     bounds = []
#     for i, row in enumerate(int_img):            
#         if not sum(row):
#             bounds.append(i)
#     return bounds


def get_labeled_for_rows(labeled, row_tups):
    """Collect the component labels present in each band of ``labeled``.

    For every ``(i, j)`` in ``row_tups`` the rows strictly between ``i`` and
    ``j`` (``labeled[i+1:j]``) are reduced with ``unique`` (background ``0``
    dropped). The rows are visited from the one with the most labels to the
    one with the fewest (ties keep their original order), and labels are
    emitted in the order they are first encountered.

    Args:
        labeled: 2-D label array (sliceable by row).
        row_tups: Iterable of ``(i, j)`` index pairs delimiting each band, as
            produced by ``filter_bounds``.

    Returns:
        list[list]: One list of distinct labels per band. A band with no rows
        yields an empty list.
    """
    sprite_rows = []
    for i, j in row_tups:
        band_rows = [unique(r) for r in labeled[i+1:j]]

        # Every label occurring anywhere in the band. Built from the label
        # lists themselves, and safe for an empty band.
        band_labels = set()
        for row_labels in band_rows:
            band_labels.update(row_labels)

        ordered, covered = [], set()
        # Richest rows first; ``sorted`` is stable so ties keep row order.
        for row_labels in sorted(band_rows, key=len, reverse=True):
            if covered == band_labels:
                # Everything is accounted for; the remaining rows could only
                # contribute duplicates.
                break
            for lab in row_labels:
                if lab not in covered:
                    covered.add(lab)
                    ordered.append(lab)
        sprite_rows.append(ordered)
    return sprite_rows


In [38]:
# Load the sprite sheet from the working directory.
img = io.imread('link.png')
# Flood from the top-left pixel of channel 0 with zero tolerance; the result
# is used below as a mask for the region treated as background.
background = flood(img[..., 0], (0,0), tolerance=0.0)
# Blank out the masked pixels in place (all channels set to 0).
img[background] = 0
#show(img, 50)
# Invert the mask and cast to int: 1 = non-background, 0 = background.
intback = np.invert(background).astype(int)
print(intback)
print(type(intback), intback.shape)
# Transpose so that iterating "rows" of intback walks the columns of the
# original image. Every later use of intback sees this transposed version.
intback = intback.T
print(type(intback), intback.shape)
print('-----')
# Indices of the all-zero rows of the transposed mask.
print(get_bounds(intback))

[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
<class 'numpy.ndarray'> (365, 665)
<class 'numpy.ndarray'> (665, 365)
-----
[0]


In [33]:
#show(intback,50)
#show(img[0:50,0:50],50)

# 3x3 block of ones: diagonal neighbours count as connected (8-connectivity).
# Use the builtin ``int``; the ``np.int`` alias was deprecated in NumPy 1.20
# and removed in 1.24.
structure = np.ones((3, 3), dtype=int)
# ``label`` is scipy's here (its import comes after, and so shadows, the
# skimage one); it returns the labelled array and the component count.
labeled, ncomponents = label(intback, structure)

# Bands between runs of empty rows, then the labels found in each band.
bound_tups = filter_bounds(get_bounds(intback))
labeled_sprites = get_labeled_for_rows(labeled, bound_tups)
# print('----')
# n = 0
# for r in labeled_sprites:
#     n += len(r)
# print(ncomponents, n)

print(labeled_sprites)
# for r, row in enumerate(labeled_sprites):
#     print(row)
#     for i,n in enumerate(row):
#         filename = f"tmp/g{r}_{i}.png"
#         raw_inds = np.where(labeled==n)
#         rrow, rcol = raw_inds
#         minr, maxr = int(min(rrow)), int(max(rrow))
#         minc, maxc = int(min(rcol)), int(max(rcol))     
# #         comp_dict['row_tup'] = (minr, maxr+1)
# #         comp_dict['col_tup'] = (minc, maxc+1)
# #         topleft, topright = c['row_tup']
# #         botleft, botright = c['col_tup']
#         sub_image = img[minr:maxr+1,minc:maxc+1]
#         #show(sub_image,50)
#         io.imsave(filename,sub_image)
        

[[1, 2, 3, 4, 20, 21, 16, 22, 25, 26, 27, 29, 31, 23, 24, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 28, 30, 17, 18, 19], [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 74, 56, 57, 58, 67, 61, 62, 63, 68, 69, 64, 65, 66, 88, 89, 59, 60, 85, 86, 87, 75, 76, 71, 72, 70, 73, 77, 78, 81, 82, 83, 84, 79, 80], [90, 91, 92, 93, 102, 109, 103, 110, 104, 106, 111, 107, 113, 94, 95, 96, 97, 98, 99, 100, 101, 105, 114, 112, 108], [115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 149, 148, 144, 145, 146, 147, 150, 151, 152, 153, 154, 155, 156], [157, 158, 159, 160, 161, 182, 183, 185, 186, 189, 190, 175, 176, 178, 179, 187, 188, 180, 181, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 177, 191, 184], [192, 193, 194, 195, 198, 199, 200, 220, 205, 206, 222, 196, 197, 201, 202, 203, 204, 207, 209, 210, 223, 211, 224, 214, 215, 217, 219, 221, 2

In [None]:
#show(labeled, 50)
# Build one bounding-box record per component label in the first band of
# ``labeled_sprites``.
clusters = []
# component_sort_dict = {}
#for n in range(ncomponents):
#for n in final_rows[0]['row']:
for n in labeled_sprites[0]:
    # Row and column indices of every pixel carrying label ``n``.
    raw_inds = np.where(labeled==n)
    rrow, rcol = raw_inds
#     indices = zip(rrow, rcol)
#     indices = sorted(indices, key=lambda x:x[1])
#     indices = sorted(indices, key=lambda x:x[0])
#     least_index = indices[0]
    #component_sort_dict[n] = least_index
    # Inclusive extent of the component along each axis.
    minr, maxr = int(min(rrow)), int(max(rrow))
    minc, maxc = int(min(rcol)), int(max(rcol))
    comp_dict = {}
    comp_dict['index']= n
    # Stored as half-open (start, stop) ranges, hence the +1.
    comp_dict['row_tup'] = (minr, maxr+1)
    comp_dict['col_tup'] = (minc, maxc+1)
    #comp_dict['point'] = (maxr,maxc)
    # Integer midpoint of the bounding box (row, col).
    comp_dict['point'] = centroid((minr, maxr+1),(minc, maxc+1))
    #comp_dict['point'] = least_index
    #comp_dict['img'] = img[minr:maxr+1,minc:maxc+1]
    #sub_image = img[minr:maxr+1,minc:maxc+1]
    #show(sub_image,50)
    clusters.append(comp_dict)
    
    
#clusters = sorted(clusters, key=lambda k: k['point'][0]) 
#clusters = sorted(clusters, key=lambda k: k['point'][1]) 
#clusters = sorted(clusters, key=lambda k: k['index'])




### POST SORT LOGIC ####
# Crop and display each cluster's bounding box from ``img``.
for c in clusters:
    print(c['index'])
    # Despite the names, these are the half-open row range and column range.
    topleft, topright = c['row_tup']
    botleft, botright = c['col_tup']
    # NOTE(review): when cells run in file order, ``labeled`` is computed from
    # the transposed mask (``intback = intback.T``) while ``img`` is not
    # transposed, so the row/column ranges may be swapped relative to ``img``
    # -- confirm.
    sub_image = img[topleft:topright,botleft:botright]
    #io.imsave(f'tmp/img{n}.png',sub_image)
    #print(c['point'])
    show(sub_image,50)
        

In [None]:
#row_label_sets = sorted(row_label_sets, key=lambda k: k['length'], reverse=True)
#print(row_label_sets )
print('-------------------------')
# seen = set()
# max_ind, max_len = -1, 0
# final_rows = []
# for i,r in enumerate(row_label_sets):
#     if r['length'] < max_len:
#         row_set = set(row_label_sets[max_ind]['row'])
#         if bool(seen & row_set):
#             continue
#         else:
#             seen.update(row_set)
#             final_rows.append(row_label_sets[max_ind])
#             max_ind, max_len = -1, 0
#     max_ind, max_len = i, r['length']

            
# set_lengths = {}
# prev_len, max_ind = 0, 0 
# ordering = 0
# for i in range(len(row_label_sets)):
#     current = len(row_label_sets[i])
#     if current < prev_len:
#         set_lengths[ordering] = row_label_sets[max_ind]
#         ordering += 1
#     prev_len=current
#     max_ind = i
# print(set_lengths)
    
#show(labeled, 50)
# Build one bounding-box record for every labelled component.
clusters = []
# scipy.ndimage.label numbers components 1..ncomponents and uses 0 for the
# background, so iterate over 1..ncomponents inclusive. ``range(ncomponents)``
# would box the background and skip the last component.
for n in range(1, ncomponents + 1):
    # Row and column indices of every pixel carrying label ``n``.
    rrow, rcol = np.where(labeled == n)
    # Inclusive extent of the component along each axis.
    minr, maxr = int(rrow.min()), int(rrow.max())
    minc, maxc = int(rcol.min()), int(rcol.max())
    clusters.append({
        'index': n,
        # Half-open (start, stop) ranges, hence the +1.
        'row_tup': (minr, maxr + 1),
        'col_tup': (minc, maxc + 1),
        # Integer midpoint of the bounding box (row, col).
        'point': centroid((minr, maxr + 1), (minc, maxc + 1)),
    })
    
    
#clusters = sorted(clusters, key=lambda k: k['point'][0]) 
#clusters = sorted(clusters, key=lambda k: k['point'][1]) 
#clusters = sorted(clusters, key=lambda k: k['index'])




### POST SORT LOGIC ####
# for c in clusters[0:20]:
#     print(c['index'])
#     topleft, topright = c['row_tup']
#     botleft, botright = c['col_tup']
#     sub_image = img[topleft:topright,botleft:botright]
#     #io.imsave(f'tmp/img{n}.png',sub_image)
#     #print(c['point'])
#     show(sub_image,50)
        

##### Sub Div stuff ######
def get_diff_list(rows):
    """Return the values that sit on either side of each gap in ``rows``.

    Two neighbouring entries form a gap when they do not differ by exactly 1.
    Both ends of every gap are reported; a value that closes one gap and
    immediately opens the next is listed only once.

    Args:
        rows: Indexable sequence of numbers.

    Returns:
        list: Gap boundary values in order of appearance (empty when ``rows``
        has fewer than two entries or contains no gaps).
    """
    edges = []
    for k in range(1, len(rows)):
        before, after = rows[k - 1], rows[k]
        if after - before != 1:
            # Avoid repeating a value that was just recorded as the end of
            # the previous gap.
            if not edges or edges[-1] != before:
                edges.append(before)
            edges.append(after)
    return edges
    
def _is_uniform(pixels):
    """Return True when every pixel in ``pixels`` has the same colour tuple."""
    return len({tuple(p) for p in pixels}) == 1


# Image rows made of a single colour act as horizontal separators.
rows = [row for row in range(len(img)) if _is_uniform(img[row])]
print(rows)

# For each pair of neighbouring separator boundaries, record the columns that
# are a single colour within that band. ``col_dict`` must be built here: the
# loop further down reads it, and with this step commented out the cell
# raised NameError on a fresh kernel.
diffrows = get_diff_list(rows)
col_dict = {}
for up, down in zip(diffrows, diffrows[1:]):
    for col in range(len(img[0])):
        if _is_uniform(img[up:down, col]):
            col_dict.setdefault((up, down), []).append(col)

# Paint the separator rows so they are visible in the final plot.
for rbound in rows:
    img[rbound,:] = 255

# for tup,col in col_dict.items():
#     up,down = tup
#     img[up:down,col] = 255

# Within each band, between neighbouring column boundaries, record the rows
# that are a single colour. A later match for the same row overwrites the
# earlier one.
row_dict = {}
for (up, down), cols in col_dict.items():
    bound_cols = get_diff_list(cols)
    for left, right in zip(bound_cols, bound_cols[1:]):
        for row in range(up, down+1):
            if _is_uniform(img[row, left:right]):
                row_dict[row] = (left, right)
                #img[row,left:right] = 0
                #img[row,left:right] = 0

# for row,tup in row_dict.items():
#     left,right = tup
#     img[row,left:right] = 255
show(img,50)