<a href="https://colab.research.google.com/github/suzanpoudel/AI-Lab-020391/blob/main/8_puzzle_problem_using_A_algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 8 puzzle problem using A* algorithm in Python.
# Here, manhattan distance is being used to compute heuristic for any state of 8 puzzle and cost is considered to be zero (0) for all states. As heapq module seems to consider (parent, child) ordered pairs in OPEN queue as separate elements, the node pairs are serialized/deserialized to and from string using json module to account for this behavior.

from copy import deepcopy
from json import dumps, loads
from heapq import heappush, heappop


class EightPuzzle:
    """
    Solve a sliding-tile puzzle (the classic 3x3 "8 puzzle") with A* search.

    States are lists of rows with 0 marking the blank ("hole"). Every move
    costs 1 and the heuristic is the Manhattan distance of each tile from its
    goal position. That heuristic is admissible and consistent for unit-cost
    moves, so the first goal state removed from the open queue was reached by
    a shortest path and an expanded state never has to be expanded again.

    Attributes:
        goal_state: target configuration.
        open: heap of ``(f, dumps([parent, child]))`` items with
            ``f = g + h``; ties on ``f`` are broken by comparing the
            JSON-encoded pairs, which keeps the search deterministic.
        closed: ``[parent, child]`` pairs in the order states were expanded.
    """

    # (row delta, column delta) applied to the blank for each named move.
    _MOVES = {"UP": (-1, 0), "DOWN": (1, 0), "LEFT": (0, -1), "RIGHT": (0, 1)}

    def __init__(self, initial_state, goal_state):
        """
        Initialize the open queue with the initial state, plus the closed
        list and the bookkeeping used by the search.
        """
        self.goal_state = goal_state
        # Goal coordinates of every value, computed once so heuristic() does
        # not rescan the goal board per tile. setdefault keeps the first
        # row-major occurrence, matching a linear search of the goal.
        self._goal_positions = {}
        for i, row in enumerate(goal_state):
            for j, num in enumerate(row):
                self._goal_positions.setdefault(num, (i, j))

        self.open = []
        self.closed = []
        # Cheapest known number of moves from the initial state to each
        # generated state (keys are hashable tuple-of-tuples snapshots).
        self._g = {self._key(initial_state): 0}
        # Keys of already-expanded states, for O(1) membership tests.
        self._closed_keys = set()
        heappush(
            self.open, (self.heuristic(initial_state), dumps([None, initial_state]))
        )  # pq_item is pair of [parent, child], heapq is min queue

    @staticmethod
    def _key(state):
        """Return a hashable snapshot of ``state`` (a tuple of row tuples)."""
        return tuple(map(tuple, state))

    def goal_test(self, state):
        """
        Return True if every cell of ``state`` matches the goal state.
        """
        return all(
            state[i][j] == goal_value
            for i, goal_row in enumerate(self.goal_state)
            for j, goal_value in enumerate(goal_row)
        )

    def heuristic(self, state):
        """
        Return the sum of Manhattan distances between each tile's position in
        ``state`` and its position in the goal state.

        The blank (0) and any value absent from the goal contribute nothing.
        """
        h = 0
        for i, row in enumerate(state):
            for j, num in enumerate(row):
                if num == 0:
                    continue
                target = self._goal_positions.get(num)
                if target is not None:
                    h += abs(i - target[0]) + abs(j - target[1])
        return h

    def update_position(self, state, x, y, move):
        """
        Return a new state with the blank at (x, y) moved one cell.

        ``move`` must be one of "UP", "DOWN", "LEFT" or "RIGHT" (the direction
        the blank travels). ``state`` itself is not modified.

        Raises:
            ValueError: if ``move`` is unknown or would leave the board.
                Without the bounds check a negative index would silently
                wrap to the opposite edge.
        """
        try:
            dx, dy = self._MOVES[move]
        except KeyError:
            raise ValueError(
                f"invalid move {move!r}; expected one of {sorted(self._MOVES)}"
            ) from None
        new_x, new_y = x + dx, y + dy
        if not (0 <= new_x < len(state) and 0 <= new_y < len(state[new_x])):
            raise ValueError(f"move {move!r} from ({x}, {y}) leaves the board")

        # Rows hold plain values, so copying each row is a full copy.
        next_state = [list(row) for row in state]
        next_state[x][y] = state[new_x][new_y]
        next_state[new_x][new_y] = 0
        return next_state

    def successor(self, state):
        """
        Return the states reachable by sliding one tile into the blank.

        Candidates are generated in the order UP, DOWN, LEFT, RIGHT (the
        direction the blank travels), skipping moves that would leave the
        board. Board size is taken from the goal state.
        """
        row_length = len(self.goal_state)
        col_length = len(self.goal_state[0])

        # Locate the blank; fall back to (0, 0) if the state has none.
        pos_x, pos_y = next(
            (
                (i, j)
                for i in range(row_length)
                for j in range(col_length)
                if state[i][j] == 0
            ),
            (0, 0),
        )

        succ = []
        if pos_x > 0:
            succ.append(self.update_position(state, pos_x, pos_y, "UP"))
        if pos_x < row_length - 1:
            succ.append(self.update_position(state, pos_x, pos_y, "DOWN"))
        if pos_y > 0:
            succ.append(self.update_position(state, pos_x, pos_y, "LEFT"))
        if pos_y < col_length - 1:
            succ.append(self.update_position(state, pos_x, pos_y, "RIGHT"))
        return succ

    def get_pair_child(self, node_pair_iter):
        """
        Return the list of children from an iterable of (parent, child) pairs.
        """
        return [pair[1] for pair in node_pair_iter]

    def a_star_search(self):
        """
        Run A* search and return the goal state, or None if it is unreachable.

        States are ordered by ``f = g + h``, where ``g`` is the number of moves
        from the initial state. A state reached again via a cheaper path is
        re-queued. The outdated heap entry is skipped when popped, because by
        then the state is already closed.
        """
        while self.open:
            _, item = heappop(self.open)
            parent, node = loads(item)
            key = self._key(node)
            if key in self._closed_keys:
                # Stale entry: this state was already expanded more cheaply.
                continue
            self._closed_keys.add(key)
            self.closed.append([parent, node])
            if self.goal_test(node):
                return node

            g_next = self._g[key] + 1  # every move costs 1
            for succ in self.successor(node):
                succ_key = self._key(succ)
                if succ_key in self._closed_keys:
                    continue
                best = self._g.get(succ_key)
                if best is not None and best <= g_next:
                    continue  # already queued with an equal or cheaper path
                self._g[succ_key] = g_next
                f_score = g_next + self.heuristic(succ)  # heapq is min queue
                heappush(self.open, (f_score, dumps([node, succ])))
        # Open queue exhausted: goal not reachable from the initial state.
        return None

    def print_state(self, state):
        """
        Display a puzzle state one row per line.
        """
        for row in state:
            print(row)

    def generate_path(self):
        """
        Print the path from the initial state to the goal state and return it.

        The path is rebuilt from the parent links recorded in ``self.closed``.
        If the goal state was never expanded (the search has not run or found
        no solution), nothing is printed and an empty list is returned rather
        than looping forever.

        Returns:
            list of states, initial state first and goal state last.
        """
        parents = {self._key(child): parent for parent, child in self.closed}
        if self._key(self.goal_state) not in parents:
            return []

        path = []
        node = self.goal_state
        while node is not None:
            path.append(node)
            node = parents.get(self._key(node))
        path.reverse()

        for i, node in enumerate(path):
            print_string = "Initial State:" if i == 0 else f"Step {i}:"
            print(print_string)
            self.print_state(node)
            print()
        return path

    def run(self):
        """
        Driver method: search, then print the solution path and goal state.
        """
        goal_node = self.a_star_search()
        if goal_node is not None:
            self.generate_path()
            print("Found the goal state!")
            self.print_state(goal_node)
        else:
            print("Did not find goal state")


# Start configuration: tiles 2, 3 and 6 are out of place and the blank (0)
# already sits in the bottom-right corner.
initial_state = [[1, 3, 6], [4, 5, 2], [7, 8, 0]]

# Target configuration: tiles in row-major order with the blank last.
goal_state = [
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 0],
]

# Build the solver and print the solution path.
ep = EightPuzzle(initial_state, goal_state)
ep.run()

Initial State:
[1, 3, 6]
[4, 5, 2]
[7, 8, 0]

Step 1:
[1, 3, 6]
[4, 5, 2]
[7, 0, 8]

Step 2:
[1, 3, 6]
[4, 0, 2]
[7, 5, 8]

Step 3:
[1, 3, 6]
[4, 2, 0]
[7, 5, 8]

Step 4:
[1, 3, 0]
[4, 2, 6]
[7, 5, 8]

Step 5:
[1, 0, 3]
[4, 2, 6]
[7, 5, 8]

Step 6:
[1, 2, 3]
[4, 0, 6]
[7, 5, 8]

Step 7:
[1, 2, 3]
[4, 5, 6]
[7, 0, 8]

Step 8:
[1, 2, 3]
[4, 5, 6]
[7, 8, 0]

Found the goal state!
[1, 2, 3]
[4, 5, 6]
[7, 8, 0]
