In [19]:
from PIL import Image
import numpy as np
import math
import matplotlib.pyplot as plt
import os
import re

def file_names(directory, extension='.jpg'):
    """Return the paths of the image files directly inside ``directory``.

    Parameters
    ----------
    directory : str
        Folder to scan (not recursive).
    extension : str, optional
        File-name suffix to keep, compared case-insensitively so that
        ``.JPG`` files are found too. Defaults to ``'.jpg'``.

    Returns
    -------
    list of str
        Paths of the form ``directory/<name>``, sorted by file name so the
        order (and therefore the order of any results built from it) is
        reproducible across runs and platforms.
    """
    suffix = extension.lower()
    # Match the end of the name only: a regex search for '.jpg' would also
    # accept names such as 'xjpg' or 'photo.jpg.bak'.
    return [os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.lower().endswith(suffix)]

def read_img(img_path):
    """Load an image file into a NumPy array of ``int32`` pixel values.

    Parameters
    ----------
    img_path : str
        Path of the image file to read.

    Returns
    -------
    numpy.ndarray
        Pixel array as produced by PIL; presumably (height, width, 3) for the
        RGB JPEG scans this notebook reads — TODO confirm for other modes.
    """
    # The context manager closes the file even if the conversion raises;
    # the array is fully materialised before the file is closed.
    with Image.open(img_path) as img:
        # int32 (rather than PIL's typical uint8) gives headroom for later
        # arithmetic on pixel values.
        return np.array(img, dtype='int32')

def reshape_array(arr):
    """Collapse the first two axes of a 3-D array into a single axis.

    An (H, W, C) array becomes (H * W, C): one row per pixel, taken in
    row-major (C) order, with the channel axis left intact.
    """
    height, width, channels = arr.shape[0], arr.shape[1], arr.shape[2]
    pixel_count = height * width
    return arr.reshape(pixel_count, channels)

def display_img(array):
    """Show a pixel array as an image in a new matplotlib figure.

    Parameters
    ----------
    array : numpy.ndarray
        Pixel values on a 0-255 scale; presumably shaped (height, width, 3)
        RGB — TODO confirm for other callers.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the image was drawn on, for adding titles or labels.
    """
    # Clip before casting: a bare astype('uint8') wraps out-of-range values
    # (e.g. 256 -> 0, -1 -> 255) instead of saturating them.
    pixels = np.clip(array, 0, 255).astype('uint8')
    # imshow renders uint8 arrays directly, so no PIL round trip is needed.
    _, ax = plt.subplots()
    ax.imshow(pixels)
    return ax

def split_img(arr, nrow, ncol):
    """Cut an array into an ``nrow`` x ``ncol`` grid of equal-sized tiles.

    Every tile is ``arr.shape[0] // nrow`` rows by ``arr.shape[1] // ncol``
    columns. Leftover rows/columns at the bottom/right edge that do not divide
    evenly are discarded, so all tiles share one shape.

    Parameters
    ----------
    arr : numpy.ndarray
        Array with at least two dimensions. Axes 0 and 1 are split; any
        further axes (e.g. colour channels) are kept whole.
    nrow, ncol : int
        Number of tiles along axis 0 and axis 1.

    Returns
    -------
    list of numpy.ndarray
        ``nrow * ncol`` tiles in row-major order (left to right, then top to
        bottom). Tiles are slices (views) of ``arr``.

    Raises
    ------
    ValueError
        If ``nrow`` or ``ncol`` is below 1 or larger than its axis, which
        would otherwise produce empty tiles.
    """
    def tile_size(length, count, axis_name):
        # Integer tile length along one axis; rejects counts that would
        # yield zero-length tiles.
        if count < 1 or count > length:
            raise ValueError(f'cannot split {length} {axis_name} into {count} tiles')
        return length // count

    tile_h = tile_size(arr.shape[0], nrow, 'rows')
    tile_w = tile_size(arr.shape[1], ncol, 'columns')

    return [arr[r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w]
            for r in range(nrow)
            for c in range(ncol)]

In [20]:
# Build the feature array used for clustering. Each scan is split into a
# PIECE_ROWS x PIECE_COLS grid of puzzle pieces, each piece into an n x n grid
# of parts, and the mean RGB value of every part is recorded.
PUZZLE_DIR = 'puzzle_scans/puzzle_1'
PIECE_ROWS, PIECE_COLS = 5, 4   # 20 pieces per scan
n = 3                           # each piece is split into n x n parts
EXPECTED_PIECES = 300           # total pieces expected across all scans

# list of files to read in
files = file_names(PUZZLE_DIR)
assert files, f'No .jpg files found in {PUZZLE_DIR}'
all_pics = []

for picture in files:

    # store picture as array
    pic = read_img(picture)

    # split each image into PIECE_ROWS x PIECE_COLS pieces; compare the shapes
    # to each other (a non-empty shape tuple is always truthy, so all() on the
    # shapes alone would never fail)
    pieces = split_img(pic, PIECE_ROWS, PIECE_COLS)
    assert len({piece.shape for piece in pieces}) == 1, 'All arrays are not the same size'

    # split each piece into n by n parts
    chopped = [split_img(piece, n, n) for piece in pieces]
    assert len({part.shape for parts in chopped for part in parts}) == 1, \
        'All parts of all pieces are not the same size'

    # average RGB value of each part of each piece (parts are flattened to
    # one row per pixel without overwriting `chopped`)
    avg_rgb = np.zeros((len(chopped), n**2, 3))
    for i, parts in enumerate(chopped):
        for j, part in enumerate(parts):
            avg_rgb[i, j, :] = np.mean(reshape_array(part), axis=0)

    all_pics.append(avg_rgb)

# combine all pictures into one large array for clustering
all_pics = np.vstack(all_pics)
assert all_pics.shape == (EXPECTED_PIECES, n**2, 3), \
    f'Output array should be size ({EXPECTED_PIECES},{n**2},3) but is actually size {all_pics.shape}'

(300, 9, 3)