In [None]:
# TODO: this implementation could be greatly improved by using a 1D array representation instead of a 2D grid.
# It also pays unnecessary overhead for string conversions.

from typing import *
import heapq
import copy

In [134]:
type GridInt = List[List[int]]
type Pos = Tuple[int, int]

class Node:
  """A* frontier entry: a grid state with its path cost (g) and heuristic estimate (h).

  Nodes compare by f = g + h, so a ``heapq`` of nodes pops the lowest-f state first.
  Unpacking a node (``g_cost, state = node``) yields its path cost and its grid.
  """

  def __init__(self, h_cost: int, g_cost: int, state: 'GridInt'):
    self.h_cost = h_cost  # heuristic estimate of the moves still needed
    self.g_cost = g_cost  # moves already taken from the initial state
    self.state = state

  @property
  def f_cost(self) -> int:
    """Estimated total cost of a solution through this node (g + h)."""
    return self.g_cost + self.h_cost

  def __lt__(self, obj: 'Node') -> bool:
    return self.f_cost < obj.f_cost

  def __str__(self) -> str:
    outer, inner = ' ' * 4, ' ' * 8
    rows = f'\n{inner}'.join(map(str, self.state))
    return f'Node(\n{outer}cost={self.f_cost},\n{outer}state={{\n{inner}{rows}\n{outer}}}\n)'

  def __repr__(self) -> str:
    return self.__str__()

  def __iter__(self) -> 'Iterator[Union[int, GridInt]]':
    # Yields g_cost, then state (two separate items), so `cost, state = node` works.
    return iter((self.g_cost, self.state))

# Target layout: blank (0) in the top-left corner, then tiles 1..8 in row-major order.
goal: GridInt = [[3 * r + c for c in range(3)] for r in range(3)]

# The scrambled board to solve.
initial: GridInt = [[1, 2, 3], [5, 6, 0], [7, 8, 4]]

In [135]:
def is_goal(state: 'GridInt', goal: 'GridInt') -> bool:
  """Return True when ``state`` matches ``goal`` cell for cell.

  Compares the whole grids rather than zipping their rows. ``zip`` stops at the
  shorter grid, so a grid with missing rows would wrongly count as solved.
  """
  return state == goal

def serialize(state: 'GridInt') -> str:
  """Encode a grid as a comma-separated, row-major string usable as a dict key.

  The separator is required once tiles reach 10 or more. Plain digit
  concatenation maps both ``[[1, 12], ...]`` and ``[[11, 2], ...]`` to
  ``"112..."``, so distinct states would collide.
  """
  return ','.join(str(num) for row in state for num in row)

def deserialize(state_str: str, m: int) -> 'GridInt':
  """Inverse of ``serialize``: rebuild the grid from its string, ``m`` cells per row."""
  if not state_str:
    # ''.split(',') would give [''], and int('') raises.
    return []
  cells = [int(token) for token in state_str.split(',')]
  return [cells[i:i + m] for i in range(0, len(cells), m)]

def manhattan(a: 'Pos', b: 'Pos') -> int:
  """Manhattan (L1) distance between two (row, col) positions."""
  return abs(a[0] - b[0]) + abs(a[1] - b[1])

def heuristic(state: 'GridInt', goal_pos: Dict[int, 'Pos']) -> int:
  """Sum of the Manhattan distances of every numbered tile from its goal cell.

  The blank (``0``) is skipped. Counting it can overestimate the remaining moves:
  a state one move from the goal would score 2. That would make the heuristic
  inadmissible, and A* could then return a non-shortest path.

  Args:
    state: current grid.
    goal_pos: maps each tile value to its (row, col) in the goal grid; tiles
      missing from it contribute 0.

  Returns:
    A non-negative estimate that never exceeds the number of moves left, because
    each move shifts exactly one tile by one cell.
  """
  return sum(
    manhattan((i, j), goal_pos.get(tile, (i, j)))
    for i, row in enumerate(state)
    for j, tile in enumerate(row)
    if tile != 0
  )

def get_neighbors(state: 'GridInt') -> List['GridInt']:
  """Return every grid reachable from ``state`` by sliding the blank (``0``) one cell.

  Neighbours are produced by moving the blank right, down, left, then up. Moves
  that would leave the grid are skipped. ``state`` itself is not modified.

  Raises:
    ValueError: if ``state`` contains no blank tile.
  """
  n = len(state)
  m = len(state[0]) if state else 0
  blank = next(((i, j) for i in range(n) for j in range(m) if state[i][j] == 0), None)
  if blank is None:
    raise ValueError('state has no blank tile (0)')
  zr, zc = blank

  states: List['GridInt'] = []
  for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
    nr, nc = zr + dr, zc + dc
    if not (0 <= nr < n and 0 <= nc < m):
      continue
    # Cells are ints, so copying each row is a full copy and far cheaper than deepcopy.
    new_state = [row[:] for row in state]
    new_state[zr][zc], new_state[nr][nc] = new_state[nr][nc], new_state[zr][zc]
    states.append(new_state)

  return states

def solve(initial: GridInt, goal: GridInt) -> Optional[List[str]]:
  """Search with A* for a sequence of blank moves that turns ``initial`` into ``goal``.

  Each move slides the blank (``0``) into an adjacent cell (see ``get_neighbors``)
  and costs 1. The returned path is guaranteed shortest only if ``heuristic``
  never overestimates the number of moves left.

  Args:
    initial: starting grid.
    goal: target grid, holding the same tiles as ``initial``.

  Returns:
    The serialized states (see ``serialize``) along the path, from ``initial``
    to ``goal`` inclusive, or ``None`` if ``goal`` is unreachable. Unreachability
    is detected only after every reachable state has been expanded.
  """
  goal_pos: Dict[int, Pos] = {tile: (i, j) for i, row in enumerate(goal) for j, tile in enumerate(row)}
  # serialized state -> (best known path cost, serialized predecessor; None for the start)
  distances: Dict[str, Tuple[int, Optional[str]]] = {serialize(initial): (0, None)}

  def reconstruct_path(goal_str: str) -> List[str]:
    """Follow predecessor links from ``goal_str`` back to the start; return start-first."""
    path: List[str] = [goal_str]
    current: Optional[str] = distances[goal_str][1]
    # Compare with None explicitly: a serialized state may itself be falsy ('').
    while current is not None:
      path.append(current)
      current = distances[current][1]

    return path[::-1]

  pq = [Node(heuristic(initial, goal_pos), 0, initial)]
  while pq:
    cost, state = heapq.heappop(pq)
    state_str = serialize(state)

    # Skip stale heap entries: a cheaper path to this state was recorded after
    # this entry was pushed, and that cheaper entry is (or was) in the heap too.
    if cost > distances[state_str][0]:
      continue

    if is_goal(state, goal):
      return reconstruct_path(state_str)

    new_cost = cost + 1
    for neighbor in get_neighbors(state):
      neighbor_str = serialize(neighbor)
      if new_cost >= distances.get(neighbor_str, (float('inf'), None))[0]:
        continue
      distances[neighbor_str] = (new_cost, state_str)
      heapq.heappush(pq, Node(heuristic(neighbor, goal_pos), new_cost, neighbor))

  # Frontier exhausted without reaching the goal: it is unreachable from `initial`.
  return None

# Solve the puzzle and print each board on the path, starting with `initial`.
steps = solve(initial, goal)
m = len(initial[0])
if steps is None:
  # solve() returns None when the goal cannot be reached (e.g. an unsolvable permutation).
  print('No solution: goal is not reachable from the initial state.')
else:
  for i, step in enumerate(steps, start=1):
    print(f'step #{i}')
    print(*deserialize(step, m), sep="\n", end="\n\n")

step #1
[1, 2, 3]
[5, 6, 0]
[7, 8, 4]

step #2
[1, 2, 3]
[5, 0, 6]
[7, 8, 4]

step #3
[1, 2, 3]
[0, 5, 6]
[7, 8, 4]

step #4
[0, 2, 3]
[1, 5, 6]
[7, 8, 4]

step #5
[2, 0, 3]
[1, 5, 6]
[7, 8, 4]

step #6
[2, 5, 3]
[1, 0, 6]
[7, 8, 4]

step #7
[2, 5, 3]
[1, 6, 0]
[7, 8, 4]

step #8
[2, 5, 0]
[1, 6, 3]
[7, 8, 4]

step #9
[2, 0, 5]
[1, 6, 3]
[7, 8, 4]

step #10
[0, 2, 5]
[1, 6, 3]
[7, 8, 4]

step #11
[1, 2, 5]
[0, 6, 3]
[7, 8, 4]

step #12
[1, 2, 5]
[6, 0, 3]
[7, 8, 4]

step #13
[1, 2, 5]
[6, 3, 0]
[7, 8, 4]

step #14
[1, 2, 5]
[6, 3, 4]
[7, 8, 0]

step #15
[1, 2, 5]
[6, 3, 4]
[7, 0, 8]

step #16
[1, 2, 5]
[6, 3, 4]
[0, 7, 8]

step #17
[1, 2, 5]
[0, 3, 4]
[6, 7, 8]

step #18
[1, 2, 5]
[3, 0, 4]
[6, 7, 8]

step #19
[1, 2, 5]
[3, 4, 0]
[6, 7, 8]

step #20
[1, 2, 0]
[3, 4, 5]
[6, 7, 8]

step #21
[1, 0, 2]
[3, 4, 5]
[6, 7, 8]

step #22
[0, 1, 2]
[3, 4, 5]
[6, 7, 8]

