# In [None]:
'''
* [分享一个流水穿透工具](https://mp.weixin.qq.com/s/_jGlWMwvfFennjx2PTqsCg)

'''
# -*- coding: utf-8 -*-
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
from pyvis.network import Network

def cmap_to_hex(cmap, value):
    """Map *value* through the colormap *cmap* and return an ``#rrggbb`` string.

    Args:
        cmap: A callable returning a colour tuple of floats in [0, 1], e.g. the
            matplotlib colormap used for node colouring. Only the first three
            (RGB) channels are read. Any alpha channel is ignored because the
            returned hex string has no alpha component.
        value: The scalar passed to ``cmap``.

    Returns:
        Lower-case ``#rrggbb`` hex string. Channels are truncated (not rounded)
        to integers in 0..255.
    """
    r, g, b = (int(channel * 255) for channel in cmap(value)[:3])
    return f"#{r:02x}{g:02x}{b:02x}"

def is_time_ordered(date_list):
    """Return True when no element of *date_list* is greater than its successor.

    Equal neighbours are allowed, so this is a non-decreasing check.
    Empty and single-element sequences count as ordered.
    """
    neighbours = zip(date_list, date_list[1:])
    return not any(earlier > later for earlier, later in neighbours)

df = pd.read_excel('资金穿透模板.xlsm','数据源')
#-----------------------------------------------------------------------------------------
# Multi-edge directed graph: parallel edges keep separate per-date transfers
# between the same pair of accounts.
G = nx.MultiDiGraph()
# Totals per (own account 本方账户, counterparty 对方账户, date 日期).
grouped = df.groupby(['本方账户', '对方账户', '日期'])['支出'].sum()  # outflows (支出)
grouped2 = df.groupby(['本方账户', '对方账户', '日期'])['收入'].sum()  # inflows (收入)

# Edge width = amount * 3 / (max - min), intended max width 3.
# The range covers BOTH series so inflow edges use the same scale as outflow
# edges. Previously only outflows were used, letting inflow widths exceed 3.
max_value = max(grouped.max(), grouped2.max())
min_value = min(grouped.min(), grouped2.min())
value_range = max_value - min_value
if value_range > 0:
    weight_factor = 3 / value_range
elif max_value:
    # All amounts are equal. Dividing by the zero range would yield an infinite
    # factor and int() below would raise OverflowError; draw every edge at width 3.
    weight_factor = 3 / max_value
else:
    weight_factor = 0  # every total is zero, so no edges are added below

# Outflows become own -> counterparty edges; inflows become counterparty -> own
# edges. Outflow edges are added first, preserving the original edge order.
for series, is_inflow in ((grouped, False), (grouped2, True)):
    for (own_account, other_account, dated), value in series.items():
        date = dated.strftime('%Y-%m-%d')  # zero-padded, so string order is chronological
        if value == 0:
            continue  # zero total: nothing to draw
        if is_inflow:
            source, target = other_account, own_account
        else:
            source, target = own_account, other_account
        weight = int(value * weight_factor)  # edge width proportional to amount
        amount_str = "{:,.2f}".format(value)  # thousands separators, 2 decimals
        G.add_edge(str(source), str(target), amount=amount_str,
                   prev=str(source) + '→' + str(target) + ':' + str(amount_str) + ',时间:' + date,
                   weight=weight, date=date)


# Node size is proportional to out-degree (number of outgoing transfer edges).
nodesize = dict(G.out_degree)
# default=1 keeps an empty graph (no non-zero transactions) from raising
# ValueError here. With no nodes the value is never used as a divisor.
max_Ns = max(nodesize.values(), default=1)
node_sizes = {}  # NOTE(review): unused in this file — candidate for removal
node_size_scale = 800  # pyvis `value` given to the node(s) with the highest out-degree

# PageRank score drives node colour intensity on the YlOrRd colormap.
edgecount = nx.pagerank(G, alpha=0.85)
cmap = plt.get_cmap('YlOrRd')

net = Network(height='800px', width='100%', directed=True, notebook=True, filter_menu=True)

# Collect, from every start node, one shortest path to each node reachable from
# it. Single-node paths, if any, are skipped by the per-node filtering later.
path_list = []
for node in G.nodes:
    try:
        paths = nx.shortest_path(G, source=node)
    except nx.NetworkXException:
        # Best-effort: skip a start node networkx cannot route from. Unlike the
        # former bare `except:`, unrelated errors (KeyboardInterrupt, NameError, ...)
        # are no longer silently swallowed.
        continue
    path_list.extend(paths.values())

# For each node, build its hover text: the collected shortest paths that contain
# it and whose per-hop dates never decrease along the path (a time-consistent
# pass-through of funds). Then add the node to the pyvis network.
max_pagerank = max(edgecount.values(), default=1)  # hoisted: constant for all nodes
for node in G.nodes:
    # All collected paths that contain this node.
    path_node_list = [path for path in path_list if node in path]

    node_path_str = []
    # Iterate over EVERY path. The original range(len - 1) silently dropped the last one.
    for path in path_node_list:
        if len(path) < 2:
            continue  # a single-node path has no hops whose dates could be checked
        date_list = []
        for st, ed in zip(path, path[1:]):
            # {edge key: attribute dict} for every parallel st -> ed edge.
            edge_data = G.get_edge_data(st, ed)
            # The earliest transfer between the two accounts represents this hop.
            # Dates are 'YYYY-MM-DD' strings, so min() is chronological.
            date_list.append(min(attrs['date'] for attrs in edge_data.values()))
        # is_time_ordered is True for a single hop, so this one check covers
        # both the two-node case and the multi-hop chronological check.
        if is_time_ordered(date_list):
            node_path_str.append(path)

    node_path_title = '\n'.join(' -> '.join(map(str, node_path)) for node_path in node_path_str)
    net.add_node(node, value=nodesize[node] / max_Ns * node_size_scale,
                 color=cmap_to_hex(cmap, edgecount[node] / max_pagerank),
                 alpha=0.8, label=node, title=node_path_title)

# Add every transfer edge. Hover text is the full description and the visible
# label is "date:amount"; width comes from the amount-based weight.
for src, dst, attrs in G.edges(data=True):
    edge_label = f"{attrs['date']}:{attrs['amount']}"
    net.add_edge(src, dst, title=str(attrs['prev']), width=attrs['weight'],
                 color='black', label=edge_label)

# ForceAtlas2-based physics layout, with physics controls shown in the page.
net.force_atlas_2based(spring_length=500, overlap=1, spring_strength=0.001)
net.show_buttons(filter_=['physics'])
#---------------------------------------------------------
net.show('资金穿透.html')