We're going to need numpy. We will also use heapq, a module from Python's standard library that implements the min-heaps explained above.

In [None]:
import numpy as np
import heapq

First, we create a class called $\texttt{Astar}$.

In [None]:
class Astar:

Let's start with the helper functions we're going to need.

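First, we need to find a node's valid neighbours. Since $\texttt{maze[x, y] = 1}$ indicates that a cell is passable, a neighbour is valid when it lies inside the grid and its maze entry is 1. We test the four cardinal directions this way; diagonal moves are not allowed.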

In [None]:
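    def neighbours(self, current):
        # Return the passable cells adjacent to current (no diagonal moves)
        x, y = current
        neighbours = []

        # Each check tests the grid bounds first, so we never index outside
        # the maze (a negative index would silently wrap around in numpy)
        if x < self.n - 1 and self.maze[x + 1, y] == 1:
            neighbours.append((x + 1, y))

        if y < self.m - 1 and self.maze[x, y + 1] == 1:
            neighbours.append((x, y + 1))

        if x > 0 and self.maze[x - 1, y] == 1:
            neighbours.append((x - 1, y))

        if y > 0 and self.maze[x, y - 1] == 1:
            neighbours.append((x, y - 1))

        return neighbours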
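To see the check in isolation, here is a minimal standalone sketch of the same logic written as a loop over direction offsets. The function name `grid_neighbours` and the 3×3 maze are illustrative assumptions, not part of the class.

```python
import numpy as np

def grid_neighbours(maze, cell):
    """Return the passable 4-connected neighbours of cell in a 0/1 maze."""
    n, m = maze.shape
    x, y = cell
    result = []
    # Same order as the method above: x + 1, y + 1, x - 1, y - 1
    for dx, dy in ((1, 0), (0, 1), (-1, 0), (0, -1)):
        nx, ny = x + dx, y + dy
        # Bounds check first, so we never index outside the array
        if 0 <= nx < n and 0 <= ny < m and maze[nx, ny] == 1:
            result.append((nx, ny))
    return result

maze = np.array([[1, 1, 0],
                 [0, 1, 1],
                 [1, 1, 1]])

print(grid_neighbours(maze, (1, 1)))  # [(2, 1), (1, 2), (0, 1)]
```

Testing the bounds before indexing matters with numpy: `maze[-1, y]` does not raise an error but silently reads the last row, so without the `x > 0` test a cell on the top edge would appear to have a neighbour on the bottom edge.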

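We will also need a function to calculate the heuristic $h(x)$ at a given node. We use the Manhattan (grid) distance, the number of vertical plus horizontal steps to the goal, which never overestimates the true path length when only the four cardinal moves are allowed.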

In [None]:
    def griddist(self, current, goal):
        return abs(goal[0] - current[0]) + abs(goal[1] - current[1])
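A heuristic is called admissible when it never overestimates the true remaining cost. The sketch below checks this for the Manhattan distance on a small made-up maze by comparing `griddist` (re-declared here as a free function) against exact distances from a breadth-first search. `bfs_distances` is a hypothetical helper written only for this check.

```python
from collections import deque
import numpy as np

def griddist(current, goal):
    # Manhattan distance: vertical steps plus horizontal steps
    return abs(goal[0] - current[0]) + abs(goal[1] - current[1])

def bfs_distances(maze, goal):
    """Exact shortest-path length from every passable cell to goal (unit cost, 4-connected)."""
    n, m = maze.shape
    dist = np.full((n, m), np.inf)
    dist[goal] = 0
    queue = deque([goal])
    while queue:
        x, y = queue.popleft()
        for dx, dy in ((1, 0), (0, 1), (-1, 0), (0, -1)):
            nx, ny = x + dx, y + dy
            if 0 <= nx < n and 0 <= ny < m and maze[nx, ny] == 1 and dist[nx, ny] == np.inf:
                dist[nx, ny] = dist[x, y] + 1
                queue.append((nx, ny))
    return dist

# The wall in the middle row forces a detour around the right-hand side
maze = np.array([[1, 1, 1],
                 [0, 0, 1],
                 [1, 1, 1]])
goal = (2, 0)
true_dist = bfs_distances(maze, goal)

# The heuristic never exceeds the true distance for any reachable cell
for i in range(3):
    for j in range(3):
        if true_dist[i, j] < np.inf:
            assert griddist((i, j), goal) <= true_dist[i, j]

print(griddist((0, 0), goal), true_dist[0, 0])  # 2 6.0
```

The gap between 2 and 6 at the top-left corner shows that the heuristic can be a loose underestimate when walls are in the way; it only has to never be an overestimate.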

Next, we write the constructor, which takes in the maze, the start point, and the end point.

In [None]:
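    def __init__(self, maze, start, end, mu=0.5):
        self.maze = maze
        self.n, self.m = maze.shape
        self.start = start
        self.xstart, self.ystart = start
        self.end = end
        # Weight between g and h used in search(); search() reads self.mu, so it
        # must be set here. mu = 0.5 ranks nodes exactly like f = g + h (standard A*)
        self.mu = mu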

Initialise a 2D numpy array for $g$ and fill it with infinities, since no node has been reached yet. Then set $g(\text{start}) = 0$.

In [None]:
        self.gs = np.full((self.n,self.m), np.inf)
        
        self.gs[self.xstart,self.ystart] = 0

Initialise $h$ using the griddist function we defined earlier.

In [None]:
        self.hs = np.zeros((self.n,self.m))

        for i in range(self.n):
            for j in range(self.m):
                self.hs[i,j] = self.griddist((i,j),end)
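The double loop is easy to read, but the same table can also be built with numpy broadcasting. A possible vectorised sketch, assuming the same `(row, column)` indexing as the class; `manhattan_grid` is a hypothetical name:

```python
import numpy as np

def manhattan_grid(shape, goal):
    """Manhattan distance from every cell of an n x m grid to goal, without Python loops."""
    rows, cols = np.indices(shape)  # rows[i, j] == i, cols[i, j] == j
    # astype(float) keeps the result a float array, like self.hs
    return (np.abs(rows - goal[0]) + np.abs(cols - goal[1])).astype(float)

hs = manhattan_grid((3, 4), (2, 3))
print(hs)
```

`np.indices` returns one array of row indices and one of column indices, so the subtraction and `np.abs` act on every cell at once instead of one cell per loop iteration.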

The $f$ cost is calculated as the sum of the $g$ cost and the $h$ cost.

In [None]:
        self.fs = self.gs + self.hs
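Because every entry of $g$ except the start is $\infty$, and $\infty + h = \infty$ for any finite $h$, the initial $f$ array is infinite everywhere except at the start, where $f(\text{start}) = h(\text{start})$. A small self-contained illustration on a made-up 2×3 grid:

```python
import numpy as np

n, m = 2, 3
start, goal = (0, 0), (1, 2)

gs = np.full((n, m), np.inf)  # g is unknown (infinite) everywhere...
gs[start] = 0                 # ...except at the start

# Manhattan distance to the goal for every cell
hs = np.zeros((n, m))
for i in range(n):
    for j in range(m):
        hs[i, j] = abs(goal[0] - i) + abs(goal[1] - j)

fs = gs + hs  # inf + finite = inf, so only the start has a finite f
print(fs)
```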

Create a dictionary called $\texttt{came_from}$ to store which nodes were reached from where. This will be useful later when we reconstruct our path.

In [None]:
        self.came_from = {}

Initialise the closed list using a set.

In [None]:
        self.closed_set = set()

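Given what we now know about min-heaps, they're the obvious choice for our open list: at every step we want the open node with the lowest $f$ cost, and a min-heap pops it in $O(\log n)$ time. Each heap entry is a tuple $(f, \text{node})$ so that the heap is ordered by $f$.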

In [None]:
        self.open_heap = []
        heapq.heappush(self.open_heap, (self.fs[self.xstart,self.ystart], self.start))
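Two details of `heapq` are worth seeing in isolation. Tuples are compared element by element, so `(f, node)` entries come out in order of $f$, with ties falling back to comparing the node tuples. And `heapq` has no decrease-key operation, so a node whose $f$ improves is simply pushed again; the closed-set check in `search()` later discards the stale copy. A minimal sketch with made-up entries:

```python
import heapq

open_heap = []
# (f, node) tuples: the smallest f is popped first; equal f values fall
# back to comparing the node tuples
heapq.heappush(open_heap, (7.0, (2, 1)))
heapq.heappush(open_heap, (5.0, (0, 3)))
heapq.heappush(open_heap, (5.0, (0, 1)))
# (2, 1) is pushed again with a better f; the old (7.0, (2, 1)) entry stays in the heap
heapq.heappush(open_heap, (4.0, (2, 1)))

closed = set()
order = []
while open_heap:
    f, node = heapq.heappop(open_heap)
    if node in closed:  # stale duplicate: this node was already expanded
        continue
    closed.add(node)
    order.append((f, node))

print(order)  # [(4.0, (2, 1)), (5.0, (0, 1)), (5.0, (0, 3))]
```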

Now that everything is set up, we can move on to the main part of the algorithm. We define a member function $\texttt{search()}$ that implements the A* algorithm. It carries out steps 3-6 of our workflow and returns the reconstructed path, or -1 if the open list runs out before the end is reached (meaning no path exists).

In [None]:
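    def search(self):
        while self.open_heap:
            # Pop the open node with the lowest f cost
            _, current = heapq.heappop(self.open_heap)
            x, y = current

            # heapq has no decrease-key, so a node can sit in the heap several
            # times; skip entries for nodes that have already been expanded
            if current in self.closed_set:
                continue

            self.closed_set.add(current)

            if current == self.end:
                return self.reconstruct_path()

            for neighbour in self.neighbours(current):
                nx, ny = neighbour
                # Every move costs 1
                possible_g = self.gs[x, y] + 1

                # Found a cheaper route to this neighbour: record it and queue it
                if possible_g < self.gs[nx, ny]:
                    self.came_from[neighbour] = current
                    self.gs[nx, ny] = possible_g
                    self.fs[nx, ny] = self.mu*possible_g + (1-self.mu)*self.hs[nx, ny]

                    heapq.heappush(self.open_heap, (self.fs[nx, ny], (nx, ny)))
        return -1  # Open list exhausted: no path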
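Note that `search()` scores neighbours with $\mu g + (1-\mu) h$ rather than the plain $g + h$. With $\mu = 0.5$ this is $\tfrac{1}{2}(g + h)$, which ranks nodes exactly as $g + h$ does; $\mu = 1$ orders by $g$ alone (uniform-cost, Dijkstra-like search) and $\mu = 0$ by $h$ alone (greedy best-first search). The sketch below checks these orderings on three hypothetical frontier nodes with made-up $(g, h)$ costs; `weighted_f` and `expansion_order` are illustrative names.

```python
def weighted_f(g, h, mu):
    # Same weighting as the update line in search(): mu*g + (1 - mu)*h
    return mu * g + (1 - mu) * h

# Hypothetical frontier nodes with their (g, h) costs
frontier = {"a": (1, 8), "b": (6, 1), "c": (3, 3)}

def expansion_order(mu):
    # Node labels sorted by weighted f, smallest first
    return sorted(frontier, key=lambda k: weighted_f(*frontier[k], mu))

astar_order = sorted(frontier, key=lambda k: sum(frontier[k]))  # plain g + h

print(expansion_order(0.5), astar_order)  # ['c', 'b', 'a'] ['c', 'b', 'a']
print(expansion_order(1.0))  # ['a', 'c', 'b']: only g matters
print(expansion_order(0.0))  # ['b', 'c', 'a']: only h matters
```

Values of $\mu$ below 0.5 weight the heuristic more heavily than $g$, which can speed up the search but means the returned path is no longer guaranteed to be the shortest.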

The search function returns the reconstructed path. Here's how it does it:

In $\texttt{search()}$ there is a small line:
$\texttt{self.came_from[neighbour] = current}$
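This adds an entry to the $\texttt{came_from}$ dictionary recording which node we reached $\texttt{neighbour}$ from. If a cheaper route to $\texttt{neighbour}$ turns up later, the entry is overwritten, so the dictionary always holds the best predecessor found so far.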

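$\texttt{reconstruct_path}$ starts at the end node and follows the $\texttt{came_from}$ links back to the start, collecting each node along the way. It then reverses the list, so it returns the path as a list of $(x, y)$ tuples running from start to end.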

In [None]:
    def reconstruct_path(self):
        path = []
        node = self.end
        while node in self.came_from:
            path.append(node)
            node = self.came_from[node]
        path.append(self.start)
        path.reverse()

        return path
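To make the back-tracing concrete, here is the same logic as a free function, run on a hand-written `came_from` dictionary. The dictionary contents are made up for illustration; in the class they are filled in by `search()`.

```python
def reconstruct_path(came_from, start, end):
    """Follow came_from links back from end, then reverse to get start -> end."""
    path = []
    node = end
    while node in came_from:
        path.append(node)
        node = came_from[node]
    path.append(start)
    path.reverse()
    return path

# Hypothetical came_from from a search between (0, 0) and (1, 2):
# each key was reached from its value
came_from = {
    (0, 1): (0, 0),
    (1, 1): (0, 1),
    (1, 2): (1, 1),
    (1, 0): (0, 0),  # explored, but not on the final path
}
print(reconstruct_path(came_from, (0, 0), (1, 2)))  # [(0, 0), (0, 1), (1, 1), (1, 2)]
```

Entries for explored nodes that are not on the final path, like `(1, 0)` here, are simply never visited by the back-trace.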