# Lojdova slagalica

In [38]:
import copy
import heapq
from collections import defaultdict

In [16]:
def serialize(matrix):
    """Flatten a board (a list of rows) into a ':'-separated string.

    Example: [[4, 1, 3], [0, 2, 5], [7, 8, 6]] -> '4:1:3:0:2:5:7:8:6'.
    The string form is hashable, so boards can be used as set members
    and dict keys during the search.
    """
    return ':'.join(str(cell) for row in matrix for cell in row)

In [17]:
def deserialize(state):
    """Parse a ':'-separated board string back into a square matrix.

    Inverse of ``serialize``: '4:1:3:0:2:5:7:8:6' becomes
    [[4, 1, 3], [0, 2, 5], [7, 8, 6]]. The side length is derived from the
    number of cells, so any n x n board works (3 x 3 boards are split exactly
    as before).

    Raises:
        ValueError: if a cell is not an integer, or the number of cells is
            not a perfect square (so no square board can be formed).
    """
    cells = [int(token) for token in state.split(':')]
    side = int(round(len(cells) ** 0.5))
    if side * side != len(cells):
        raise ValueError(
            'expected a square board, got %d cells: %r' % (len(cells), state))
    return [cells[start:start + side] for start in range(0, len(cells), side)]

In [24]:
# Returns every possible next move relative to the current state.

def get_neighbours(state):
    """Return every board reachable from *state* by sliding one tile into the blank.

    Args:
        state: serialized board (as produced by ``serialize``); 0 is the blank.

    Returns:
        An iterator of ``(neighbour_state, 1)`` pairs; every move costs 1.
        Moves are generated in the order: blank moves up, down, left, right.

    Raises:
        ValueError: if the board contains no blank (0) cell.
    """
    matrix = deserialize(state)
    n = len(matrix)

    # Stop at the first blank found. A nested-loop ``break`` only exits the
    # inner loop, and a missing blank must not fall through with -1 indices
    # (negative indexing would silently build bogus boards).
    blank = next(
        ((i, j) for i in range(n) for j in range(n) if matrix[i][j] == 0),
        None,
    )
    if blank is None:
        raise ValueError('board has no blank (0) cell: %r' % (state,))
    blank_i, blank_j = blank

    neighbours = []
    # Offset of the tile that slides into the blank: up, down, left, right.
    for di, dj in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        tile_i, tile_j = blank_i + di, blank_j + dj
        if 0 <= tile_i < n and 0 <= tile_j < n:
            # Rows hold plain ints, so a shallow per-row copy is enough.
            new_matrix = [row[:] for row in matrix]
            new_matrix[blank_i][blank_j] = new_matrix[tile_i][tile_j]
            new_matrix[tile_i][tile_j] = 0
            neighbours.append(serialize(new_matrix))

    return zip(neighbours, [1] * len(neighbours))

In [23]:
def get_next_node(open_set, heuristic_guess):
    """Pick the open node with the smallest heuristic estimate.

    Only nodes that have an entry in *heuristic_guess* with a finite score
    are eligible. Ties go to whichever node *open_set* yields first.
    Returns None when no node is eligible (including an empty open set).
    """
    infinity = float('inf')
    eligible = [
        candidate for candidate in open_set
        if candidate in heuristic_guess and heuristic_guess[candidate] < infinity
    ]
    if not eligible:
        return None
    # min() keeps the first of several equal minima, matching a strict-< scan.
    return min(eligible, key=lambda candidate: heuristic_guess[candidate])

In [22]:
# Puzzle instance to solve; 0 marks the blank square.
start_state = [[4, 5, 1], [2, 8, 3], [7, 6, 0]]
# Goal board: tiles 1..8 in row-major order, blank in the bottom-right corner.
end_state = [[1, 2, 3], [4, 5, 6], [7, 8, 0]]

In [46]:
def solve_loyd(start_state, end_state, h):
    """Solve the sliding-tile puzzle with A* search.

    Args:
        start_state: initial board as a list of rows (0 marks the blank).
        end_state: goal board in the same format.
        h: heuristic called as ``h(serialized_state)``. It estimates the number
            of moves still needed; with an admissible heuristic the returned
            path is a shortest one.

    Returns:
        The list of serialized states from start to goal (both included), or
        an empty list when the goal is unreachable from the start.
    """
    start = serialize(start_state)
    goal = serialize(end_state)

    parents = {start: None}
    cheapest_path = defaultdict(lambda: float('inf'))
    cheapest_path[start] = 0

    # Priority queue of (f = g + h, insertion order, g, state). A heap avoids
    # scanning the whole open set on every step, which made the search
    # quadratic; that hurts most when the goal is unreachable and every
    # reachable state gets expanded. Entries superseded by a cheaper route
    # are not removed eagerly; they are skipped when popped. The insertion
    # counter breaks f-ties in FIFO order.
    open_heap = [(h(start), 0, 0, start)]
    pushes = 1

    path_found = False
    while open_heap:
        _, _, cost, current = heapq.heappop(open_heap)
        if cost > cheapest_path[current]:
            continue  # stale entry: a cheaper route to `current` was found later
        if current == goal:
            path_found = True
            break

        for neighbour, weight in get_neighbours(current):
            new_cost = cost + weight
            if new_cost < cheapest_path[neighbour]:
                parents[neighbour] = current
                cheapest_path[neighbour] = new_cost
                heapq.heappush(
                    open_heap,
                    (new_cost + h(neighbour), pushes, new_cost, neighbour),
                )
                pushes += 1

    path = []
    if path_found:
        node = goal
        while node is not None:
            path.append(node)
            node = parents[node]
        path.reverse()
    return path

In [47]:
def loyd_h(state, goal=None):
    """Manhattan-distance heuristic for the sliding-tile puzzle.

    Sums, over every numbered tile, the row distance plus the column distance
    between its current cell and its cell on the goal board. The blank (0) is
    skipped: counting it would let the estimate exceed the true number of
    moves (loss of admissibility), and A* would no longer guarantee a
    shortest path.

    Args:
        state: serialized board (':'-separated cells, row-major, 0 = blank).
        goal: optional serialized goal board. Defaults to the standard layout
            1, 2, ..., n*n - 1, 0 in row-major order (for 3 x 3 this is
            [[1, 2, 3], [4, 5, 6], [7, 8, 0]]).

    Returns:
        The integer estimate; it is 0 exactly when every tile is in its goal cell.
    """
    cells = [int(token) for token in state.split(':')]
    side = int(round(len(cells) ** 0.5))
    if goal is None:
        goal_cells = list(range(1, len(cells))) + [0]
    else:
        goal_cells = [int(token) for token in goal.split(':')]
    # Row-major flat index of each tile on the goal board.
    goal_index = {tile: k for k, tile in enumerate(goal_cells)}

    total = 0
    for k, tile in enumerate(cells):
        if tile == 0:
            continue  # the blank is not a tile
        row, col = divmod(k, side)
        goal_row, goal_col = divmod(goal_index[tile], side)
        total += abs(row - goal_row) + abs(col - goal_col)
    return total

In [48]:
# Run A* from the start board to the goal board using the Manhattan heuristic.
p = solve_loyd(start_state=start_state, end_state=end_state, h=loyd_h)

In [52]:
# Print every board along the solution path, one row per line, with a
# blank line after each board.
for state in p:
    d_state = deserialize(state)
    print(*d_state, sep='\n', end='\n\n')

[4, 5, 1]
[2, 8, 3]
[7, 6, 0]

[4, 5, 1]
[2, 8, 3]
[7, 0, 6]

[4, 5, 1]
[2, 0, 3]
[7, 8, 6]

[4, 0, 1]
[2, 5, 3]
[7, 8, 6]

[4, 1, 0]
[2, 5, 3]
[7, 8, 6]

[4, 1, 3]
[2, 5, 0]
[7, 8, 6]

[4, 1, 3]
[2, 0, 5]
[7, 8, 6]

[4, 1, 3]
[0, 2, 5]
[7, 8, 6]

[0, 1, 3]
[4, 2, 5]
[7, 8, 6]

[1, 0, 3]
[4, 2, 5]
[7, 8, 6]

[1, 2, 3]
[4, 0, 5]
[7, 8, 6]

[1, 2, 3]
[4, 5, 0]
[7, 8, 6]

[1, 2, 3]
[4, 5, 6]
[7, 8, 0]

