In [1]:
import numpy as np
import itertools
from skimage import color, io

# Exponents used by calculate_dissimilarity: per-pixel differences are raised
# to P, summed, and the sum is raised to Q / P.
P = 0.3
Q = 0.0625

# Spatial relations. They index the first axis of the dissimilarity,
# compatibility and best-neighbour arrays. Opposite directions are numbered
# as adjacent pairs (0/1 and 2/3), which opposite_relation relies on.
LEFT, RIGHT, UP, DOWN = range(4)

In [2]:
def load_image(path):
    """Read the image stored at `path` and return it converted with `color.rgb2lab`.

    NOTE(review): `rgb2lab` is given whatever `io.imread` returns; presumably
    the input files are plain 3-channel RGB -- confirm for files with alpha.
    """
    rgb = io.imread(path)
    return color.rgb2lab(rgb)


def split_into_squares(image, nrows, ncols):
    """Cut `image` into an `nrows` x `ncols` grid of equally sized tiles.

    The tiles are returned as a flat list in row-major order (left to right,
    then top to bottom). `np.split` / `np.hsplit` raise `ValueError` if the
    image does not divide evenly.
    """
    squares = []
    for band in np.split(image, nrows):
        # each horizontal band is cut into `ncols` tiles
        squares.extend(np.hsplit(band, ncols))
    return squares


def calculate_dissimilarity(x_i, x_j, relation):
    """Return how poorly piece `x_j` fits on the `relation` side of piece `x_i`.

    Both pieces are 3-D arrays. For each piece the border facing the other is
    extrapolated linearly one step outwards (2 * border - next-to-border) and
    compared with the other piece's actual border; the two absolute
    differences are combined as (|a|**P + |b|**P) ** (Q / P) and summed over
    all border elements.

    LEFT and UP are evaluated as RIGHT and DOWN with the pieces swapped.

    Raises:
        TypeError: if `relation` is not one of LEFT, RIGHT, UP, DOWN.
    """
    nrows, ncols, _ = x_i.shape

    # Mirror cases: "j is left of i" is the same as "i is right of j".
    if relation == LEFT:
        return calculate_dissimilarity(x_j, x_i, RIGHT)
    if relation == UP:
        return calculate_dissimilarity(x_j, x_i, DOWN)

    if relation == RIGHT:
        # last two columns of x_i against the first two columns of x_j
        border_i, inner_i = x_i[:, ncols - 1], x_i[:, ncols - 2]
        border_j, inner_j = x_j[:, 0], x_j[:, 1]
    elif relation == DOWN:
        # last two rows of x_i against the first two rows of x_j
        border_i, inner_i = x_i[nrows - 1], x_i[nrows - 2]
        border_j, inner_j = x_j[0], x_j[1]
    else:
        raise TypeError(f'invalid relation: {relation}')

    predicted_from_i = np.abs((2 * border_i - inner_i) - border_j)
    predicted_from_j = np.abs((2 * border_j - inner_j) - border_i)
    combined = np.power(predicted_from_i, P) + np.power(predicted_from_j, P)
    return np.sum(np.power(combined, Q / P))


def build_dissimilarity_matrix(squares):
    """Compute the pairwise dissimilarity of all pieces for every relation.

    Args:
        squares: sequence of piece arrays accepted by `calculate_dissimilarity`.

    Returns:
        Float array of shape (4, n, n) where entry [relation][i][j] is
        `calculate_dissimilarity(squares[i], squares[j], relation)`.
        The diagonal (a piece against itself) is `np.inf`: a piece can never
        be its own neighbour. It used to be left as uninitialised memory.
    """
    count = len(squares)
    # np.full instead of np.empty so the entries skipped below are well defined.
    dissimilarity_matrix = np.full((4, count, count), np.inf)
    for relation in range(4):
        for i, x_i in enumerate(squares):
            for j, x_j in enumerate(squares):
                if i == j:
                    continue
                dissimilarity_matrix[relation][i][j] = calculate_dissimilarity(x_i, x_j, relation)
    return dissimilarity_matrix


def calculate_compatibility(dissimilarity_matrix, i, j, relation):
    """Turn the dissimilarity of pieces `i` and `j` into a compatibility score.

    The dissimilarity is divided by the 25th percentile of piece `i`'s
    dissimilarities to all *other* pieces for this relation (entry `i` of the
    row is dropped), then mapped through exp(-x), so smaller dissimilarity
    gives a score closer to 1.
    """
    row = dissimilarity_matrix[relation][i]
    # exclude the piece's own entry from the normalisation statistics
    others = np.delete(row, i)
    first_quartile = np.percentile(others, 25)
    return np.exp(-row[j] / first_quartile)


def build_compatibility_matrix(dissimilarity_matrix):
    """Convert a (4, n, n) dissimilarity matrix into compatibility scores.

    Args:
        dissimilarity_matrix: array of shape (4, n, n).

    Returns:
        Float array of shape (4, n, n) with entry [relation][i][j] equal to
        `calculate_compatibility(dissimilarity_matrix, i, j, relation)`.
        The diagonal is 0.0. It used to be uninitialised memory, and because
        `calculate_best_neighbours` takes the argmax over whole rows, a
        garbage diagonal value could make a piece its own best neighbour.
    """
    _, order, _ = dissimilarity_matrix.shape
    # np.zeros instead of np.empty: the skipped diagonal must be well defined.
    compatibility_matrix = np.zeros((4, order, order))
    for relation in range(4):
        for i in range(order):
            for j in range(order):
                if i == j:
                    continue
                compatibility_matrix[relation][i][j] = calculate_compatibility(dissimilarity_matrix, i, j, relation)
    return compatibility_matrix


def calculate_best_neighbours(compatibility_matrix):
    """For every relation and piece, find the most compatible other piece.

    Returns an int array of shape (4, n) where entry [relation][i] is the
    index `j` maximising `compatibility_matrix[relation][i][j]` (first index
    on ties, as with `np.argmax`).
    """
    _, order, _ = compatibility_matrix.shape
    best_neighbours = np.zeros((4, order), dtype=int)
    for relation, i in np.ndindex(4, order):
        best_neighbours[relation, i] = np.argmax(compatibility_matrix[relation, i])
    return best_neighbours


def opposite_relation(relation):
    """Return the relation pointing the other way (0 <-> 1, 2 <-> 3)."""
    # Opposites are numbered as adjacent pairs: the even member maps up,
    # everything else maps down.
    return relation + 1 if relation in (0, 2) else relation - 1

In [3]:
def find_best_estimated_seed(best_neighbours):
    """Pick the piece that has the most mutual best neighbours.

    For each of the four relations, piece `i` scores a point when the best
    neighbour of its best neighbour (in the opposite relation) is `i` itself.
    Returns the index of the piece with the highest score (first one on ties).
    """
    _, order = best_neighbours.shape
    num_best_buddies = np.zeros(order, dtype=int)
    for relation in range(4):
        outgoing = best_neighbours[relation]
        incoming = best_neighbours[opposite_relation(relation)]
        for i in range(order):
            # does i's favourite neighbour pick i back?
            if incoming[outgoing[i]] == i:
                num_best_buddies[i] += 1
    return np.argmax(num_best_buddies)


In [4]:
# Run the pipeline on one training image: load, cut into an 8 x 8 grid of
# pieces, then derive dissimilarity -> compatibility -> best neighbours.
# The names defined here are reused by later cells.
image = load_image('../data/data_train/64/1200.png')
squares = split_into_squares(image, 8, 8)
dissimilarity_matrix = build_dissimilarity_matrix(squares)
compatibility_matrix = build_compatibility_matrix(dissimilarity_matrix)
best_neighbours = calculate_best_neighbours(compatibility_matrix)
# Sanity check: the best right neighbour of piece 0, and the best left
# neighbour of piece 3. The recorded output (3 and 0) shows they pick each other.
print(best_neighbours[RIGHT][0])
print(best_neighbours[LEFT][3])

3
0


In [5]:
# Index of the piece with the most mutual best neighbours; used below as the seed.
print(find_best_estimated_seed(best_neighbours))

0


In [6]:
# Scratch check of np.roll: a shift of 1 along axis 1 moves every column one
# step to the right and wraps the last column round to the front.
arr = np.arange(9).reshape((3, 3))
print(arr)
print(np.roll(arr, 1, 1))

[[0 1 2]
 [3 4 5]
 [6 7 8]]
[[2 0 1]
 [5 3 4]
 [8 6 7]]


In [7]:
def adjacent(i, j):
    """Yield the four coordinates orthogonally adjacent to (i, j).

    Order: row above, row below, column to the left, column to the right.
    No bounds checking is done, so coordinates may lie outside any grid.
    """
    for row_step, col_step in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        yield i + row_step, j + col_step

In [8]:
def is_in_puzzle(puzzle, i, j):
    """Return whether (i, j) is a valid, non-negative index into the 2-D `puzzle`."""
    height, width = puzzle.shape
    row_inside = 0 <= i < height
    return row_inside and 0 <= j < width

In [9]:
def is_occupied_slot(puzzle, i, j):
    """Return whether (i, j) lies inside `puzzle` and holds a part (value >= 0)."""
    # The bounds check comes first so negative indices never wrap around.
    if not is_in_puzzle(puzzle, i, j):
        return False
    return puzzle[i][j] >= 0

In [10]:
def find_candidate_slots(puzzle):
    """Return the free slots that touch the largest number of occupied slots.

    Every unoccupied position orthogonally adjacent to an occupied one is
    considered, including positions just outside the array bounds. Of those,
    only the ones with the maximal count of occupied neighbours are returned.

    Args:
        puzzle: 2-D integer array; negative entries are empty slots.

    Returns:
        Set of (row, col) tuples. Empty if `puzzle` has no occupied slot
        (previously this case raised `ValueError` from `max()` on an empty
        sequence).
    """
    slots = {}
    nrows, ncols = puzzle.shape
    for i in range(nrows):
        for j in range(ncols):
            if puzzle[i][j] < 0:
                continue
            for x, y in adjacent(i, j):
                # skip slots already counted and slots that hold a part
                if (x, y) in slots or is_occupied_slot(puzzle, x, y):
                    continue
                # count the number of occupied slots around (x, y)
                slots[(x, y)] = sum(1 for p, q in adjacent(x, y) if is_occupied_slot(puzzle, p, q))
    if not slots:
        # nothing has been placed yet, so there is nothing to grow from
        return set()
    max_neighbours = max(slots.values())
    return {slot for slot, num_neighbours in slots.items() if num_neighbours == max_neighbours}

In [11]:
# Checks for find_candidate_slots on a 3 x 3 board.
puzzle = np.full((3, 3), -1)
puzzle[1][1] = 1
# a lone centre part: its four neighbours each touch exactly one part
assert find_candidate_slots(puzzle) == {(0, 1), (1, 0), (1, 2), (2, 1)}
puzzle[0][1] = 1
puzzle[1][0] = 1
# the corner (0, 0) now touches two parts, more than any other free slot
assert find_candidate_slots(puzzle) == {(0, 0)}
puzzle[0][0] = 1
print(puzzle)
# With a 2 x 2 block in the corner, the recorded output includes slots outside
# the board (negative coordinates).
print(find_candidate_slots(puzzle))
# A board with empty top and bottom rows and all the rows in between filled.
puzzle = np.array([[-1, -1, -1, -1, -1, -1, -1, -1],
       [32, 21, 22, 16, 14, 60, 29, 39],
       [33, 31, 54, 26, 10, 20,  1, 41],
       [56, 28, 12, 17, 25, 58, 59, 43],
       [45, 46, 44, 18,  0,  3, 62,  6],
       [38, 48, 13, 63, 27, 19, 50, 51],
       [23, 37, 40, 53, 42, 24, 61, 57],
       [-1, -1, -1, -1, -1, -1, -1, -1]])
print(find_candidate_slots(puzzle))

[[ 1  1 -1]
 [ 1  1 -1]
 [-1 -1 -1]]
{(1, 2), (-1, 1), (-1, 0), (2, 1), (2, 0), (0, -1), (1, -1), (0, 2)}
{(7, 3), (4, 8), (2, 8), (7, 7), (0, 7), (0, 3), (1, -1), (4, -1), (5, 8), (7, 2), (7, 6), (0, 4), (5, -1), (0, 0), (7, 1), (2, -1), (7, 5), (0, 5), (0, 1), (7, 0), (6, 8), (3, -1), (6, -1), (3, 8), (0, 6), (1, 8), (7, 4), (0, 2)}


In [12]:
def best_buddies(best_neighbours, relation, i, j):
    """Return whether pieces `i` and `j` choose each other as best neighbour.

    `j` must be `i`'s best neighbour for `relation`, and `i` must be `j`'s
    best neighbour for the opposite relation.
    """
    if best_neighbours[relation][i] != j:
        return False
    return best_neighbours[opposite_relation(relation)][j] == i

In [13]:
def does_part_fit_in_slot(puzzle, best_neighbours, slot, part):
    """Return whether `part` is best buddies with every occupied neighbour of `slot`.

    Args:
        puzzle: 2-D integer array of placed part indices (negative = empty).
        best_neighbours: array indexed as [relation][part].
        slot: (row, col) position being tested; may lie outside `puzzle`.
        part: index of the part to test.

    Returns:
        False as soon as one occupied neighbour is not a best buddy of `part`,
        True otherwise (including when the slot has no occupied neighbour).
    """
    i, j = slot
    for relation in range(4):
        x, y = related_coords(relation, i, j)
        if is_occupied_slot(puzzle, x, y) and not best_buddies(best_neighbours, relation, part, puzzle[x][y]):
            return False
    return True

In [14]:
def related_coords(relation, i, j):
    """Return the coordinates one step from (i, j) in the direction `relation`.

    Rows grow downwards and columns grow to the right.

    Raises:
        ValueError: if `relation` is not one of RIGHT, LEFT, UP, DOWN.
    """
    steps = (
        (RIGHT, (0, 1)),
        (LEFT, (0, -1)),
        (UP, (-1, 0)),
        (DOWN, (1, 0)),
    )
    for direction, (row_step, col_step) in steps:
        if relation == direction:
            return i + row_step, j + col_step
    raise ValueError(f'invalid relation: {relation}')

In [15]:
def average_compatibility_with_slot(puzzle, compatibility_matrix, slot, part):
    """Return the mean compatibility of `part` with the occupied neighbours of `slot`.

    Args:
        puzzle: 2-D integer array of placed part indices (negative = empty).
        compatibility_matrix: array indexed as [relation][part][neighbour].
        slot: (row, col) position being evaluated; may lie outside `puzzle`.
        part: index of the part to evaluate.

    Returns:
        The average of `compatibility_matrix[relation][part][neighbour]` over
        the occupied neighbours, or 0.0 if the slot has none (previously that
        case raised `ZeroDivisionError`).
    """
    i, j = slot
    scores = []
    for relation in range(4):
        x, y = related_coords(relation, i, j)
        if is_occupied_slot(puzzle, x, y):
            scores.append(compatibility_matrix[relation][part][puzzle[x][y]])

    if not scores:
        # no occupied neighbour: there is no evidence for this placement
        return 0.0
    return sum(scores) / len(scores)

In [16]:
class SlotAssignError(Exception):
    """Raised by try_assign when a slot outside the puzzle cannot be brought inside."""

In [17]:
def try_assign(puzzle, slot, part, unallocated_parts):
    """Put `part` into `slot`, shifting the whole puzzle if the slot is outside it.

    A slot inside the array is written directly (the passed array is modified
    in place). A slot beyond one edge is accepted only if the row/column at
    the opposite edge is completely empty (-1); the puzzle is then rolled by
    one step away from the slot, which yields a new array, and the slot
    coordinate is shifted by the same step.

    `part` is removed from `unallocated_parts` in place.

    Returns:
        Tuple (puzzle, unallocated_parts) after the assignment.

    Raises:
        SlotAssignError: if the opposite edge is not empty, so rolling would
            push placed parts across the border.
        ValueError: if the slot is outside the puzzle but matches none of the
            four edge cases.
    """
    nrows, ncols = puzzle.shape
    i, j = slot

    if not is_in_puzzle(puzzle, i, j):
        # Pick the opposite edge that must be free, and the roll that brings
        # the slot inside. Only the first matching case is handled.
        if i < 0:
            # above the top edge -> bottom row must be empty, shift down
            opposite_edge, shift, axis = puzzle[-1], 1, 0
        elif i >= nrows:
            # below the bottom edge -> top row must be empty, shift up
            opposite_edge, shift, axis = puzzle[0], -1, 0
        elif j < 0:
            # left of the puzzle -> right column must be empty, shift right
            opposite_edge, shift, axis = puzzle[:, -1], 1, 1
        elif j >= ncols:
            # right of the puzzle -> left column must be empty, shift left
            opposite_edge, shift, axis = puzzle[:, 0], -1, 1
        else:
            raise ValueError('invalid slot')

        if not np.all(opposite_edge == -1):
            raise SlotAssignError

        puzzle = np.roll(puzzle, shift, axis)
        if axis == 0:
            i += shift
        else:
            j += shift

    # update unallocated parts, then the puzzle itself
    unallocated_parts.remove(part)
    puzzle[i][j] = part
    return puzzle, unallocated_parts


In [18]:
# Scratch check of the column test used in try_assign: after zeroing the last
# column, np.all over that column compared with 0 is True.
arr = np.arange(9).reshape((3, 3))
print(arr)
arr[:,-1] = 0
print(arr)
np.all(arr[:,-1] == 0)

[[0 1 2]
 [3 4 5]
 [6 7 8]]
[[0 1 0]
 [3 4 0]
 [6 7 0]]


True

In [19]:
def place_remaining_parts(puzzle, compatibility_matrix, unallocated_parts):
    """Place one unallocated part into the puzzle and return the updated state.

    Despite the name, each call places a single part; callers loop until
    `unallocated_parts` is empty.

    Selection: among the candidate slots, if exactly one (slot, part) pair is
    a best-buddy fit, it is used. Otherwise the pair with the highest average
    compatibility is used. If the chosen slot is rejected by `try_assign`,
    the slot is dropped from the candidates and the selection is repeated.

    Args:
        puzzle: 2-D integer array of placed part indices (negative = empty).
        compatibility_matrix: array of shape (4, n, n).
        unallocated_parts: set of part indices not yet placed; the placed
            part is removed from it in place by `try_assign`.

    Returns:
        Tuple (puzzle, unallocated_parts) as returned by `try_assign`.

    Raises:
        ValueError: when every candidate slot has been rejected.
    """
    best_neighbours = calculate_best_neighbours(compatibility_matrix)
    candidate_slots = find_candidate_slots(puzzle)

    while True:
        matches = [(slot, part)
                   for slot in candidate_slots
                   for part in unallocated_parts
                   if does_part_fit_in_slot(puzzle, best_neighbours, slot, part)]
        if len(matches) == 1:
            slot, part = matches[0]
        else:
            # zero or several best-buddy fits: fall back to the best average score
            average_compatibilities = [
                (average_compatibility_with_slot(puzzle, compatibility_matrix, slot, part), (slot, part))
                for slot in candidate_slots for part in unallocated_parts]
            best = max(average_compatibilities, key=lambda x: x[0])
            slot, part = best[1]

        try:
            # try_assign updates both the puzzle and unallocated_parts itself
            return try_assign(puzzle, slot, part, unallocated_parts)
        except SlotAssignError:
            print(f'slot {slot} is not allowed, skipping')
            candidate_slots.remove(slot)
            if not candidate_slots:
                raise ValueError('no more slots')

            

In [20]:
# Solver setup: an empty 8 x 8 board (-1 = free slot) with the estimated seed
# piece placed at the centre; every other piece index starts as unallocated.
nrows = 8
ncols = 8
unallocated_parts = set(range(nrows * ncols))
best_estimated_seed = find_best_estimated_seed(best_neighbours)
puzzle = np.full((nrows, ncols), -1)
puzzle[nrows//2][ncols//2] = best_estimated_seed
unallocated_parts.remove(best_estimated_seed)
print(puzzle)

[[-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1  0 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]]


In [21]:
# Main loop: each call to place_remaining_parts places one part, so iterate
# until no unallocated part is left. The board is printed after every step.
step = 1
while unallocated_parts:
#     print(f'step = {step}')
    step += 1
    puzzle, unallocated_parts = place_remaining_parts(puzzle, compatibility_matrix, unallocated_parts)
    print(puzzle)

[[-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1  0 -1 -1 -1]
 [-1 -1 -1 -1 27 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]]
[[-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 18  0 -1 -1 -1]
 [-1 -1 -1 -1 27 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]]
[[-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 18  0 -1 -1 -1]
 [-1 -1 -1 63 27 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]]
[[-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 44 18  0 -1 -1 -1]
 [-1 -1 -1 63 27 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]]
[[-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 44 18  0 -1 -1 -

[[-1 -1 -1 -1 -1 -1 -1 -1]
 [32 21 22 16 14 60 29 39]
 [33 31 54 26 10 20  1 41]
 [56 28 12 17 25 58 59 43]
 [45 46 44 18  0  3 62  6]
 [38 48 13 63 27 19 50 51]
 [23 37 40 53 42 24 61 57]
 [-1 -1 -1 -1 36  4 -1 -1]]
[[-1 -1 -1 -1 -1 -1 -1 -1]
 [32 21 22 16 14 60 29 39]
 [33 31 54 26 10 20  1 41]
 [56 28 12 17 25 58 59 43]
 [45 46 44 18  0  3 62  6]
 [38 48 13 63 27 19 50 51]
 [23 37 40 53 42 24 61 57]
 [-1 -1 -1 34 36  4 -1 -1]]
[[-1 -1 -1 -1 -1 -1 -1 -1]
 [32 21 22 16 14 60 29 39]
 [33 31 54 26 10 20  1 41]
 [56 28 12 17 25 58 59 43]
 [45 46 44 18  0  3 62  6]
 [38 48 13 63 27 19 50 51]
 [23 37 40 53 42 24 61 57]
 [-1 -1  5 34 36  4 -1 -1]]
[[-1 -1 -1 -1 -1 -1 -1 -1]
 [32 21 22 16 14 60 29 39]
 [33 31 54 26 10 20  1 41]
 [56 28 12 17 25 58 59 43]
 [45 46 44 18  0  3 62  6]
 [38 48 13 63 27 19 50 51]
 [23 37 40 53 42 24 61 57]
 [-1 15  5 34 36  4 -1 -1]]
[[-1 -1 -1 -1 -1 -1 -1 -1]
 [32 21 22 16 14 60 29 39]
 [33 31 54 26 10 20  1 41]
 [56 28 12 17 25 58 59 43]
 [45 46 44 18  0  3 62  