# Dijkstra

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [2]:
# One shared figure/axes pair; every animation cell below draws into `ax`.
fig, ax = plt.subplots()
ax.axis("equal")
# Close the freshly created figure immediately (presumably so the inline
# backend does not display a static copy; the later cells show it via
# FuncAnimation instead).
plt.close(fig)

In [3]:
class Node:
    """One grid cell, used as a vertex by the path search below.

    Attributes (all set in ``__init__``):
        x, y: grid coordinates of the cell.
        occupied: True when the cell is an obstacle.
        cost: best known path cost so far (``inf`` until reached).
        visited: search bookkeeping flag.
        prev: predecessor on the best known path; a node whose ``prev`` is
            itself has no predecessor.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y
        self.occupied = False
        self.cost = float("inf")
        self.visited = False
        # Self-reference is the "no predecessor" sentinel.
        self.prev = self

    def __lt__(self, other):
        """Order nodes by ``cost`` (lets them be compared inside a heap)."""
        return self.cost < other.cost

    def __repr__(self) -> str:
        fields = (self.x, self.y, self.cost)
        return "Node (%d, %d) : %f" % fields


In [4]:
class GridMap():
    """Rectangular occupancy grid of ``Node`` cells with a fixed obstacle layout.

    Obstacles: the full outer border, plus two interior walls --
    one at ``x = size_x // 3`` covering ``y`` in ``[0, size_y // 2)``, and
    one at ``x = 2 * (size_x // 3)`` covering ``y`` from ``size_y - 1`` down
    to ``size_y // 3 + 1``.
    """

    MAP_SIZE_X = 20
    MAP_SIZE_Y = 20

    def __init__(self, size_x=None, size_y=None):
        """Build the grid.

        Args:
            size_x: number of columns; defaults to ``GridMap.MAP_SIZE_X``.
            size_y: number of rows; defaults to ``GridMap.MAP_SIZE_Y``.
        """
        self.size_x = GridMap.MAP_SIZE_X if size_x is None else size_x
        self.size_y = GridMap.MAP_SIZE_Y if size_y is None else size_y
        print("Map size %d %d" % (self.size_x, self.size_y))

        self.map = np.empty((self.size_x, self.size_y), dtype=object)
        for x in range(self.size_x):
            for y in range(self.size_y):
                self.map[x, y] = Node(x, y)

        # Outer border.
        for x in range(self.size_x):
            self.map[x, 0].occupied = True
            self.map[x, self.size_y - 1].occupied = True
        for y in range(self.size_y):
            self.map[0, y].occupied = True
            self.map[self.size_x - 1, y].occupied = True

        # Interior walls: their column positions scale with the map WIDTH
        # (size_x); their vertical extents scale with the height (size_y).
        for y in range(self.size_y // 2):
            self.map[self.size_x // 3, y].occupied = True
        for y in range(self.size_y - 1, self.size_y // 3, -1):
            self.map[2 * (self.size_x // 3), y].occupied = True

    def plot(self, ax):
        """Draw every occupied cell as a black dot on *ax*."""
        ax.axis((-1, self.size_x, -1, self.size_y))
        occupied = [
            (x, y)
            for x in range(self.map.shape[0])
            for y in range(self.map.shape[1])
            if self.map[x, y].occupied
        ]
        if occupied:
            # One plot call for all markers instead of one call per cell.
            xs, ys = zip(*occupied)
            ax.plot(xs, ys, ".k")

    def shape(self):
        """Return ``(size_x, size_y)``."""
        return (self.size_x, self.size_y)

    def getnode(self, x, y):
        """Return the ``Node`` at grid coordinates ``(x, y)``."""
        return self.map[x, y]

    def get_unvisited_neighbors(self, node):
        """Yield ``(neighbor, step_cost)`` for free, unvisited 8-connected neighbors.

        Orthogonal steps cost 1 and diagonal steps cost sqrt(2). Cells outside
        the grid are skipped: otherwise ``x - 1 == -1`` would silently wrap to
        the opposite edge via NumPy negative indexing, and ``x + 1 == size_x``
        would raise IndexError.
        """
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                if dx == 0 and dy == 0:
                    continue
                nx, ny = node.x + dx, node.y + dy
                if not (0 <= nx < self.size_x and 0 <= ny < self.size_y):
                    continue
                nb = self.map[nx, ny]
                if nb.occupied or nb.visited:
                    continue
                yield (nb, np.sqrt(dx**2 + dy**2))

    def clear(self):
        """Reset per-node search state (visited flag, cost, predecessor)."""
        for node in self.map.flat:
            node.visited = False
            node.cost = float('inf')
            # Drop predecessor links left over from a previous search.
            node.prev = node
    

In [5]:
ax.clear()
gridmap = GridMap()


def animate(i):
    """Render the static obstacle map (the frame index is ignored)."""
    gridmap.plot(ax)


ani = animation.FuncAnimation(fig, func=animate, frames=1)
from IPython.display import HTML
HTML(ani.to_jshtml())

Map size 20 20


In [6]:
START = (3, 3)
STOP = (16, 16)

ax.clear()


def animate(i):
    """Draw the map, then the start (green circle) and stop (blue cross)."""
    gridmap.plot(ax)
    for point, fmt in ((START, "og"), (STOP, "xb")):
        ax.plot(*point, fmt)


ani = animation.FuncAnimation(fig, func=animate, frames=1)
from IPython.display import HTML
HTML(ani.to_jshtml())

In [7]:
import heapq

class Dijkstra():
    """Step-wise Dijkstra shortest-path search over a grid map.

    Each call to :meth:`run` settles at most one node, so the search can be
    animated one frame at a time. A node is marked ``visited`` only when it
    is popped from the queue (its cost is final at that point); until then
    its cost may still be lowered through a cheaper neighbor.
    """

    def __init__(self, gridmap, start, stop):
        """Prepare a search from *start* to *stop* (coordinate tuples)."""
        self.gridmap = gridmap
        self.start = start
        self.stop = stop
        # Reset per-node search state left over from any previous search.
        self.gridmap.clear()

        self.node_start = self.gridmap.getnode(*start)
        self.node_stop = self.gridmap.getnode(*stop)

        self.node_start.cost = 0
        # Heap entries are (cost, insertion_order, node). The cost is frozen
        # in the tuple at push time, so lowering a node's cost later never
        # corrupts the heap; superseded entries are discarded when they
        # surface. The insertion counter breaks ties without comparing nodes.
        self._push_count = 0
        self.nodes = []
        self._push(self.node_start)

        self.loop_count = 0

    def _push(self, node):
        """Queue *node* at its current cost."""
        heapq.heappush(self.nodes, (node.cost, self._push_count, node))
        self._push_count += 1

    def run(self, ax):
        """Settle the cheapest pending node and relax its neighbors.

        Settled nodes are plotted as green crosses, newly improved
        neighbors as cyan crosses.

        Returns:
            True once the stop node is settled, False if the queue is
            exhausted without reaching it, and None otherwise.
        """
        self.loop_count += 1

        # Skip superseded entries for nodes that were already settled.
        while self.nodes and self.nodes[0][2].visited:
            heapq.heappop(self.nodes)

        if not self.nodes:
            print("len(self.nodes) == 0")
            return False

        _, _, node = heapq.heappop(self.nodes)
        node.visited = True  # cost is final once popped
        ax.plot(node.x, node.y, "xg")

        if node is self.node_stop:
            print("Stop")
            return True

        for nb, cost in self.gridmap.get_unvisited_neighbors(node):
            new_cost = node.cost + cost
            if new_cost < nb.cost:
                nb.cost = new_cost
                nb.prev = node
                self._push(nb)
                ax.plot(nb.x, nb.y, "xc")
        return None

In [8]:
print(START)
print(STOP)
dijkstra = Dijkstra(gridmap, START, STOP)

ax.clear()
stop = False
path = []


def _trace_back(goal, origin):
    """Follow ``prev`` links from *goal* and return the (x, y) cells seen.

    The walk ends at *origin*, or at the first node whose ``prev`` is
    itself, whichever comes first. The list runs from *goal* backwards.
    """
    cells = []
    current = goal
    while True:
        cells.append((current.x, current.y))
        if current == origin or current == current.prev:
            return cells
        current = current.prev


def animate(i):
    """Advance the search by one step; draw the path once the stop is reached."""
    global stop, path
    # Frame 0 draws the static scene: map, start and stop markers.
    if i == 0:
        gridmap.plot(ax)
        ax.plot(*START, "og")
        ax.plot(*STOP, "xb")

    if stop or not dijkstra.run(ax):
        return
    stop = True
    if path:
        return
    path.extend(_trace_back(dijkstra.node_stop, dijkstra.node_start))
    print(path)
    ax.plot(*zip(*path))


ani = animation.FuncAnimation(fig, animate, frames=300, repeat=True, interval=30)
from IPython.display import HTML
HTML(ani.to_jshtml())

(3, 3)
(16, 16)
Stop
[(16, 16), (15, 15), (14, 14), (13, 13), (13, 12), (13, 11), (13, 10), (13, 9), (13, 8), (13, 7), (12, 6), (11, 7), (10, 8), (9, 9), (8, 10), (7, 10), (6, 10), (5, 9), (4, 8), (3, 7), (3, 6), (3, 5), (3, 4), (3, 3)]
