# Dijkstra算法

- 参考：<https://blog.csdn.net/qq_42730750/article/details/108700328>
- 绘图参考：<https://plotly.com/python/network-graphs/>

In [63]:
import numpy as np
import sys
import pandas as pd
from tqdm.notebook import tqdm

class Graph():
    """Undirected weighted graph with Dijkstra shortest-path queries.

    v1.0  undirected graph; Dijkstra shortest paths.
    v1.1  accepts an adjacency matrix (``G_matrix``) as input and can
          return an all-pairs shortest-distance matrix.

    The graph is stored as an edge list ``G_info`` whose entries are
    ``[node_a, node_b, weight]``.  A weight of NaN means "no edge".
    """

    def __init__(self, G_info=None, G_matrix=None):
        """Create a graph from an edge list or from an adjacency matrix.

        Parameters
        ----------
        G_info : list of [node_a, node_b, weight], optional
            Edge list, for example::

                G_info = [['A', 'B', 4],
                          ['A', 'C', 6],
                          ['A', 'D', 6],
                          ['B', 'C', 1],
                          ['B', 'E', 7],
                          ['C', 'D', 2],
                          ['C', 'E', 6],
                          ['C', 'F', 4],
                          ['D', 'F', 5],
                          ['E', 'F', 1],
                          ['E', 'G', 6],
                          ['F', 'G', 8]]
        G_matrix : pandas.DataFrame, optional
            Adjacency matrix.  When given it takes precedence: ``G_info``
            is rebuilt from it and any ``G_info`` argument is discarded.
        """
        self.G_info = G_info
        self.G_matrix = G_matrix
        if self.G_matrix is not None:
            self.G_info = self.G_matrix2G_info()

    # ------------------------------------------------------------------
    def _edges(self):
        """Return the edge list, or an empty list when none has been set."""
        return [] if self.G_info is None else self.G_info

    # ------------------------------------------------------------------
    def _build_adjacency(self):
        """Return ``{node: {neighbour: weight}}`` built from the edge list.

        When the same pair of nodes appears more than once, the first
        entry wins, which is the same edge ``get_distance`` reports.
        """
        adjacency = {}
        for edge in self._edges():
            node_a, node_b, weight = edge[0], edge[1], edge[2]
            adjacency.setdefault(node_a, {}).setdefault(node_b, weight)
            adjacency.setdefault(node_b, {}).setdefault(node_a, weight)
        return adjacency

    # ------------------------------------------------------------------
    def G_matrix2G_info(self):
        """Convert ``self.G_matrix`` into an edge list.

        Only the strict upper triangle is read and the column labels are
        used for both endpoints, so the matrix is taken to be square and
        symmetric with rows and columns in the same order.  Every pair is
        emitted, including pairs whose entry is NaN; those entries keep
        otherwise isolated labels in the node set.
        """
        labels = self.G_matrix.columns
        size = len(self.G_matrix)
        return [
            [labels[i], labels[j], self.G_matrix.iloc[i, j]]
            for i in range(size)
            for j in range(i + 1, size)
        ]

    def G_info2G_matrix(self):
        """Backward-compatible alias of :meth:`G_matrix2G_info`.

        Despite its name this returns an edge list built from
        ``self.G_matrix``; the name is kept for existing callers.
        """
        return self.G_matrix2G_info()

    # ------------------------------------------------------------------
    def get_G_matrix(self):
        """Build, store on ``self.G_matrix`` and return the adjacency matrix.

        The result is a symmetric DataFrame indexed by the sorted node
        labels; cells without an edge (including the diagonal) are NaN.
        """
        nodes = self.get_all_nodes()
        self.G_matrix = pd.DataFrame(np.nan, columns=nodes, index=nodes)
        for edge in self._edges():
            self.G_matrix.loc[edge[0], edge[1]] = edge[2]
            self.G_matrix.loc[edge[1], edge[0]] = edge[2]
        return self.G_matrix

    # ------------------------------------------------------------------
    def set_G_info(self, G_info):
        """Replace the edge list."""
        self.G_info = G_info

    # ------------------------------------------------------------------
    def get_G_info(self):
        """Return the edge list as stored."""
        return self.G_info

    # ------------------------------------------------------------------
    def remove_G_info(self, node, inplace=False):
        """Return the edge list without the edges that touch ``node``.

        Only the two endpoint slots are compared, so a weight that happens
        to equal ``node`` does not cause an edge to be dropped.

        Parameters
        ----------
        node : hashable
            Node to remove.
        inplace : bool, default False
            When true, also store the filtered list on ``self.G_info``.
        """
        remaining = [
            edge for edge in self._edges()
            if edge[0] != node and edge[1] != node
        ]
        if inplace:
            self.G_info = remaining
        return remaining

    # ------------------------------------------------------------------
    def get_all_nodes(self):
        """Return every node label in the edge list as a sorted list."""
        nodes = set()
        for edge in self._edges():
            nodes.add(edge[0])
            nodes.add(edge[1])
        return sorted(nodes)

    # ------------------------------------------------------------------
    def get_node_info(self, node):
        """Return the edges that have ``node`` as one of their endpoints."""
        return [
            edge for edge in self._edges()
            if edge[0] == node or edge[1] == node
        ]

    # ------------------------------------------------------------------
    def get_distance(self, start_node, end_node):
        """Return the weight of the direct edge between two nodes.

        Returns 0 when both arguments are the same node and NaN when the
        nodes are not directly connected.  Endpoints are matched in either
        order; the weight slot is never treated as a node label.
        """
        if start_node == end_node:
            return 0
        for edge in self._edges():
            forward = edge[0] == start_node and edge[1] == end_node
            backward = edge[0] == end_node and edge[1] == start_node
            if forward or backward:
                return edge[2]
        return np.nan

    # ------------------------------------------------------------------
    def get_nearest_node(self, node):
        """Return the lowest-weight edge incident to ``node``.

        >>> G.get_nearest_node('D')
        ['C', 'D', 3]

        Edges with a NaN weight are ignored.  On ties the edge listed
        first wins.  Returns ``[]`` when the node has no usable edge.
        """
        nearest = []
        for edge in self.get_node_info(node):
            weight = edge[2]
            if weight != weight:  # NaN marks "no edge"
                continue
            if not nearest or weight < nearest[2]:
                nearest = edge
        return nearest

    # ------------------------------------------------------------------
    def _dijkstra_iterator(self, U, S, path):
        """Settle every node of ``U``, updating ``U``, ``S`` and ``path`` in place.

        Parameters
        ----------
        U : dict
            Unsettled nodes mapped to their tentative distance (NaN when
            none is known yet).  Emptied by this call.
        S : dict
            Settled nodes mapped to their final distance.  Nodes are added
            in the order they are settled; unreachable nodes are added
            last with a NaN distance.
        path : dict
            Predecessor map: ``path[n]`` is the settled node through which
            the best known route to ``n`` arrives.

        The loop is iterative, so the graph size is not limited by the
        interpreter's recursion depth.
        """
        adjacency = self._build_adjacency()
        # Best distance found by this call for each unsettled node.  Kept
        # apart from U so that the value U was seeded with never blocks the
        # first relaxation (and therefore the first predecessor record).
        tentative = {}

        def relax(settled):
            base = S[settled]
            for neighbour, weight in adjacency.get(settled, {}).items():
                if neighbour not in U:
                    continue
                candidate = weight + base
                # A NaN candidate (missing edge or unreachable base) fails
                # this comparison and is skipped.  The strict "<" keeps the
                # earliest-settled predecessor on ties.
                if candidate < tentative.get(neighbour, float('inf')):
                    tentative[neighbour] = candidate
                    U[neighbour] = candidate
                    path[neighbour] = settled

        def rank(candidate_node):
            distance = U[candidate_node]
            # NaN sorts after every real distance.
            return float('inf') if distance != distance else distance

        for settled in list(S):
            relax(settled)
        while U:
            # min() returns the first minimum, so ties follow U's order.
            closest = min(U, key=rank)
            S[closest] = U.pop(closest)
            relax(closest)

    # ------------------------------------------------------------------
    def dijkstra_algo(self, start_node, end_node=None):
        """Run Dijkstra's algorithm from ``start_node``.

        Parameters
        ----------
        start_node : hashable
            Source node.
        end_node : hashable, optional
            Target node.

        Returns
        -------
        (S, path) when ``end_node`` is None
            ``S`` maps every node to its shortest distance from
            ``start_node`` (NaN when unreachable) in the order the nodes
            were settled; ``path`` maps each reachable node to its
            predecessor on a shortest route.
        (S, shortest_path) otherwise
            ``shortest_path`` is the list of nodes from ``start_node`` to
            ``end_node``.  When ``end_node`` is unreachable the list holds
            ``end_node`` alone.
        """
        # U: nodes whose shortest distance is not settled yet.
        # S: nodes whose shortest distance from start_node is final.
        # path: predecessor of each node on its shortest route.
        U = {
            node: np.nan
            for node in self.get_all_nodes()
            if node != start_node
        }
        S = {start_node: 0}
        path = {}
        self._dijkstra_iterator(U, S, path)
        if end_node is None:
            return S, path
        # Follow predecessors back from end_node; start_node has none, so
        # the walk stops there.
        backwards = [end_node]
        current = end_node
        while current in path:
            current = path[current]
            backwards.append(current)
        backwards.reverse()
        return S, backwards

    # ------------------------------------------------------------------
    def get_dijkstra_matrix(self):
        """Return the all-pairs shortest-distance matrix.

        Runs :meth:`dijkstra_algo` once per node and writes each result
        into a DataFrame indexed by the sorted node labels.  Unreachable
        pairs are NaN.
        """
        nodes = self.get_all_nodes()
        dijkstra_matrix = pd.DataFrame(np.nan, columns=nodes, index=nodes)
        for source in tqdm(nodes, desc='dijkstra_algo'):
            distances, _ = self.dijkstra_algo(start_node=source)
            for target, distance in distances.items():
                dijkstra_matrix.loc[source, target] = distance
        return dijkstra_matrix

## 测试1

![测试1](测试1.png)

In [64]:
# Test graph 1 as an edge list: each entry is [node, node, weight].
G_info = [['A', 'B', 12],
            ['A', 'F', 16],
            ['A', 'G', 14],
            ['B', 'C', 10],
            ['C', 'D', 3],
            ['C', 'E', 5],
            ['C', 'F', 6],
            ['D', 'E', 4],    # variant to try: set the D-E weight to 9
            ['E', 'F', 2],
            ['E', 'G', 8],
            ['F', 'G', 9]]
G = Graph(G_info)
# Smoke-call each query method. Only the value of the last expression in
# the cell is displayed, so the calls below run but their results are
# discarded (get_G_matrix additionally stores the matrix on G.G_matrix).
G.get_all_nodes()
G.get_G_matrix()
G.get_node_info('D')
G.get_distance('E', 'D')
G.get_nearest_node('C')
# Displayed result: distances from 'D' to every node, and the D -> F route.
G.dijkstra_algo(start_node='D', end_node='F')

({'D': 0, 'C': 3.0, 'E': 4.0, 'F': 6.0, 'G': 12.0, 'B': 13.0, 'A': 22.0},
 ['D', 'E', 'F'])

In [65]:
# Adjacency matrix of test graph 1 (NaN where two nodes share no edge).
m = G.get_G_matrix()

In [67]:
# Rebuild the graph from the adjacency matrix instead of the edge list.
G = Graph(G_matrix=m)

In [68]:
# All-pairs shortest distances: one Dijkstra run per node.
G.get_dijkstra_matrix()

dijkstra_algo:   0%|          | 0/7 [00:00<?, ?it/s]

Unnamed: 0,A,B,C,D,E,F,G
A,0.0,12.0,22.0,22.0,18.0,16.0,14.0
B,12.0,0.0,10.0,13.0,15.0,16.0,23.0
C,22.0,10.0,0.0,3.0,5.0,6.0,13.0
D,22.0,13.0,3.0,0.0,4.0,6.0,12.0
E,18.0,15.0,5.0,4.0,0.0,2.0,8.0
F,16.0,16.0,6.0,6.0,2.0,0.0,9.0
G,14.0,23.0,13.0,12.0,8.0,9.0,0.0


In [190]:
import pandas as pd
# Load a labelled adjacency matrix; the first CSV column becomes the row index.
D = pd.read_csv('../adjacent_matrix.csv', index_col=0)
G = Graph(G_matrix=D)
# All-pairs shortest distances for the loaded graph.
res = G.get_dijkstra_matrix()

dijkstra_algo: 100%|██████████| 24/24 [00:20<00:00,  1.15it/s]


In [227]:
# get_dijkstra_matrix labels its result with the sorted node names, so
# reorder columns and rows back to the order used by D.
res[D.columns].loc[D.index]

Unnamed: 0,俄罗斯RTS,奥地利ATX,道琼斯工业平均,标普500,台湾加权,恒生指数,深证成指,沪深300,富时新加坡海峡时报,韩国综合指数,...,西班牙IBEX35,英国富时100,比利时BFX,上证指数,澳大利亚标普200指数,纳斯达克,墨西哥MXX,孟买Sensex30,日经225,瑞士SMI
俄罗斯RTS,0.0,0.913329,1.826116,1.830532,,,,,,,...,0.999823,0.912156,0.982807,,,1.87241,1.884527,,,1.520197
奥地利ATX,0.913329,0.0,0.997065,1.220548,,,,,,,...,0.700123,0.711322,0.67167,,,1.446754,1.611748,,,0.803534
道琼斯工业平均,1.826116,0.997065,0.0,0.223484,,,,,,,...,0.946975,0.919426,0.922912,,,0.449689,0.868304,,,0.966676
标普500,1.830532,1.220548,0.223484,0.0,,,,,,,...,0.955278,0.927172,0.928884,,,0.311484,0.848189,,,0.964931
台湾加权,,,,,0.0,0.904464,,,0.925313,0.817731,...,,,,,0.967976,,,1.877083,0.972565,
恒生指数,,,,,0.904464,0.0,,,0.786644,0.852915,...,,,,,0.945438,,,0.987217,0.941867,
深证成指,,,,,,,0.0,0.31284,,,...,,,,0.380368,,,,,,
沪深300,,,,,,,0.31284,0.0,,,...,,,,0.211989,,,,,,
富时新加坡海峡时报,,,,,0.925313,0.786644,,,0.0,0.866435,...,,,,,0.938335,,,0.951771,0.956733,
韩国综合指数,,,,,0.817731,0.852915,,,0.866435,0.0,...,,,,,0.927687,,,1.818206,0.891221,


In [122]:
import networkx as nx
# Build the same graph in networkx: nodes from the matrix labels, weighted
# edges from the edge list. Note this rebinds G, replacing the custom Graph
# instance with an nx.Graph.
G = nx.Graph()
for i in m.columns:
    G.add_node(i)
for i in G_info:
    G.add_edge(i[0], i[1], weight=i[2])
# Disabled drawing code. NOTE(review): it references `pylab`, which is not
# imported in any visible cell.
# pos = nx.shell_layout(G)
# nx.draw(G, pos, 
#         with_labels=True,
#         node_color='white',
#         edge_color='red',
#        node_size=400,
#        alpha=0.5)
# pylab.title('network', fontsize=15)
# pylab.show()

In [189]:
# Per-column count of cells where tau and D differ, after replacing NaN with
# 0 on both sides so missing entries compare equal.
# NOTE(review): `tau` is not defined in any visible cell — presumably a
# matrix left over from earlier kernel state; confirm its origin.
(tau.fillna(0) != D.fillna(0)).sum()

俄罗斯RTS         0
奥地利ATX         0
道琼斯工业平均        0
标普500          0
台湾加权           0
恒生指数           0
深证成指           0
沪深300          0
富时新加坡海峡时报      0
韩国综合指数         0
荷兰AEX          0
德国DAX          0
富时意大利MIB       0
法国CAC40        0
西班牙IBEX35      0
英国富时100        0
比利时BFX         0
上证指数           0
澳大利亚标普200指数    0
纳斯达克           0
墨西哥MXX         0
孟买Sensex30     0
日经225          0
瑞士SMI          0
dtype: int64

In [160]:
# Same comparison without fillna. NaN != NaN evaluates to True element-wise,
# so cells that are NaN in both frames are counted as differences here.
# NOTE(review): `tau` is not defined in any visible cell.
(tau != D).sum()

俄罗斯RTS         15
奥地利ATX         13
道琼斯工业平均        11
标普500          12
台湾加权           18
恒生指数           17
深证成指           21
沪深300          21
富时新加坡海峡时报      17
韩国综合指数         18
荷兰AEX          10
德国DAX          10
富时意大利MIB       12
法国CAC40        10
西班牙IBEX35      12
英国富时100        10
比利时BFX         10
上证指数           21
澳大利亚标普200指数    18
纳斯达克           15
墨西哥MXX         15
孟买Sensex30     21
日经225          18
瑞士SMI          13
dtype: int64

In [158]:
# Small networkx cross-check. Edges are given as parallel arrays:
# row[k] -- col[k] with weight value[k].
row = np.array([0,0,0,1,2,3,6])
col = np.array([1,2,3,4,5,6,7])
value = np.array([1,2,1,8,1,3,5])
G = nx.Graph()
for i in range(np.size(row)):
    G.add_weighted_edges_from([(row[i], col[i], value[i])])
# Shortest distance from node 7 to node 3 (7 -> 6 costs 5, 6 -> 3 costs 3).
nx.dijkstra_path_length(G, source=7, target=3)

8

<bound method Graph.adjacency of <networkx.classes.graph.Graph object at 0x7fdd2032e7c0>>

In [91]:
# NOTE(review): `p` is not defined in any visible cell. The recorded output
# is "TypeError: 'int' object is not iterable", so p held an int when this
# ran — presumably a scalar path length; confirm before reusing this cell.
for i in p:
    print(i)

TypeError: 'int' object is not iterable

## 测试2

![测试2](测试2.png)

In [None]:
# Test graph 2 as an edge list: each entry is [node, node, weight].
G_info = [['A', 'B', 4],
        ['A', 'C', 6],
        ['A', 'D', 6],
        ['B', 'C', 1],
        ['B', 'E', 7],
        ['C', 'D', 2],
        ['C', 'E', 6],    
        ['C', 'F', 4],
        ['D', 'F', 5],
        ['E', 'F', 1],
        ['E', 'G', 6],
        ['F', 'G', 8]]
G = Graph(G_info)
# Print the distances from 'A' to every node and the A -> G route.
print(G.dijkstra_algo(start_node='A', end_node='G'))