# In [None]:
import pandas as pd
import geopandas as gpd
import transbigdata as tbd
import warnings
warnings.filterwarnings('ignore')

import networkx as nx
from pandarallel import pandarallel
from itertools import islice
import numpy as np
import os
pandarallel.initialize(progress_bar=True)

def get_od_path(edge, node, station2node_dict, gdtollnode, beta=1.0, output_dir='output/', n_jobs=-1,
                n_paths=1, sample_size=10000):
    """
    Find the shortest path(s) between station OD pairs and save, per OD pair,
    the traversed edges, the path lengths and the path choice probabilities.

    Parameters:
    - edge: DataFrame of road segments with columns 'u' (start node id),
      'v' (end node id), 'length' (segment length) and 'edge_id'.
    - node: DataFrame of nodes; must contain an 'id' column.
    - station2node_dict: dict mapping a station id to a node id of the graph.
    - gdtollnode: DataFrame of stations; must contain an 'id' column.
    - beta: coefficient of the choice model. The probability of path k is
      exp(-beta * L_k) / sum_j exp(-beta * L_j) over the paths of one OD pair.
    - output_dir: directory the CSV files are written to (created if missing).
    - n_jobs: number of parallel workers. 1 runs serially in this process;
      -1 (or None) keeps the worker count pandarallel is already configured
      with; other negative values mean cpu_count + 1 + n_jobs; a positive
      value re-initialises pandarallel with that many workers.
    - n_paths: how many shortest simple paths to keep per OD pair (default 1,
      in which case every probability is 1.0).
    - sample_size: number of OD pairs drawn at random (with replacement, then
      de-duplicated) from all station pairs. None processes every OD pair.

    Returns:
    - None. For every processed OD pair (x, y) that has at least one path,
      two files are written into output_dir:
        od_dis_table_{x}_{y}.csv  edge_id / cumulative length / path_id /
                                  probability for every edge of every path
        od_length_{x}_{y}.csv     total length and probability of every path

    Raises:
    - ValueError: if n_jobs is 0.
    - KeyError: if a station id is missing from station2node_dict.
    """
    if n_jobs == 0:
        raise ValueError('n_jobs == 0 has no meaning; use 1 for serial or -1 for all cores')

    # An empty output_dir means "current directory"; makedirs('') would fail.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Build the directed, length-weighted road graph.
    G = nx.DiGraph()
    G.add_nodes_from(node['id'].tolist())
    G.add_weighted_edges_from(edge[['u', 'v', 'length']].values)

    od_keys = ['station_id_x', 'station_id_y']

    def get_shortest_paths(start_station, end_station):
        """Return up to n_paths shortest simple paths as lists of int node ids."""
        start_node = station2node_dict[start_station]
        end_node = station2node_dict[end_station]
        try:
            candidates = nx.shortest_simple_paths(G, source=start_node, target=end_node, weight='weight')
            found = list(islice(candidates, n_paths))
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            # A disconnected or unmapped OD pair must not abort the whole run;
            # the caller skips OD pairs without paths.
            found = []
        return [list(map(int, path)) for path in found]

    def save_paths_and_lengths(station_id_x, station_id_y, paths, beta):
        """Write the edge table and the length/probability table of one OD pair."""
        # One record per consecutive node pair (u, v); 'id' is the position
        # of the segment inside its path.
        segments = [
            {'u': u, 'v': v, 'id': seq,
             'station_id_x': station_id_x, 'station_id_y': station_id_y,
             'path_id': path_id}
            for path_id, path in enumerate(paths)
            for seq, (u, v) in enumerate(zip(path[:-1], path[1:]))
        ]
        if not segments:
            # Both stations map to the same node: there is no edge to report,
            # and merging a column-less frame would raise.
            return

        od1_tmp = pd.DataFrame(segments)

        # Attach the length and id of every traversed edge.
        od1_tmp = pd.merge(od1_tmp, edge[['u', 'v', 'length', 'edge_id']], on=['u', 'v'])
        od1_tmp['length'] = od1_tmp['length'].astype(int)

        # Order edges by path and by position in the path so the running sum
        # below is the distance travelled from the origin.
        od1_tmp = od1_tmp.sort_values(by=od_keys + ['path_id', 'id'])
        od1_tmp['cumsumlength'] = od1_tmp.groupby(od_keys + ['path_id'])['length'].cumsum()

        od_dis_table = od1_tmp[['station_id_x', 'station_id_y', 'edge_id', 'cumsumlength', 'path_id']]

        # Total length of a path is the largest cumulative length on it.
        od_length = od1_tmp.groupby(od_keys + ['path_id'])['cumsumlength'].max().reset_index()
        od_length.rename(columns={'cumsumlength': 'length'}, inplace=True)

        # Raw logit terms, kept as output columns.
        utility = -beta * od_length['length']
        od_length['exp_neg_beta_length'] = np.exp(utility)
        od_length['sum_exp_neg_beta_length'] = (
            od_length.groupby(od_keys)['exp_neg_beta_length'].transform('sum')
        )

        # The raw terms underflow to 0 for long paths (exp(-745) is already 0
        # in float64), which would give 0/0 = NaN. Shifting by the best
        # utility of each OD pair leaves the ratio unchanged but keeps the
        # denominator >= 1.
        od_group = [od_length['station_id_x'], od_length['station_id_y']]
        shifted = np.exp(utility - utility.groupby(od_group).transform('max'))
        od_length['probability'] = shifted / shifted.groupby(od_group).transform('sum')

        # Copy the path probability onto every edge row of that path.
        probabilities = od_length[['station_id_x', 'station_id_y', 'path_id', 'probability']]
        od_dis_table_updated = pd.merge(
            od_dis_table,
            probabilities,
            on=['station_id_x', 'station_id_y', 'path_id'],
            how='left'
        )

        od_dis_table_updated.to_csv(
            os.path.join(output_dir, f'od_dis_table_{station_id_x}_{station_id_y}.csv'), index=False)
        od_length.to_csv(
            os.path.join(output_dir, f'od_length_{station_id_x}_{station_id_y}.csv'), index=False)

    def process_od_pair(row):
        """Compute and save the paths of the OD pair held in one row."""
        station_id_x = row['station_id_x']
        station_id_y = row['station_id_y']
        paths = get_shortest_paths(station_id_x, station_id_y)
        if paths:
            save_paths_and_lengths(station_id_x, station_id_y, paths, beta)

    # Build the OD table: every ordered pair of distinct stations, via a
    # cross join on a constant key.
    stations = gdtollnode[['id']].rename(columns={'id': 'station_id'}).assign(flag=1)
    od = pd.merge(stations, stations, on='flag')[od_keys]
    od = od[od['station_id_x'] != od['station_id_y']].reset_index(drop=True)
    if od.empty:
        return None

    if sample_size is not None:
        # Random subset of OD pairs. Sampling is with replacement, so drop the
        # repeats: they would only recompute and rewrite identical files.
        od = od.sample(n=sample_size, replace=True).drop_duplicates().reset_index(drop=True)

    print('获取OD的出行路径')
    if n_jobs == 1:
        for _, row in od.iterrows():
            process_od_pair(row)
    else:
        if n_jobs is not None and n_jobs != -1:
            if n_jobs > 0:
                workers = n_jobs
            else:
                workers = max(1, (os.cpu_count() or 1) + 1 + n_jobs)
            pandarallel.initialize(nb_workers=workers, progress_bar=True)
        # Rows are distributed over pandarallel's worker processes.
        od.parallel_apply(process_od_pair, axis=1)

    return None