In [None]:
# 设置matplotlib使用widget后端
%matplotlib widget

import geopandas as gpd
import matplotlib.pyplot as plt
import pandas as pd
import os
import numpy as np
from matplotlib.font_manager import FontProperties
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import shapely.geometry as sgeom
from shapely.geometry import Point, LineString, MultiLineString
from shapely.ops import nearest_points

# Use the "Microsoft YaHei" font family for all figure text. The titles, axis
# labels and legends below are Chinese; presumably this font is chosen so they
# render correctly (it must be installed on the machine - TODO confirm).
plt.rcParams['font.family'] = ['Microsoft YaHei']

##########################
# 1. 读取GIS数据函数
##########################
def load_gis_data(file_path, verbose=True):
    """Read a GIS vector file with GeoPandas.

    Args:
        file_path: Path of the file passed to ``gpd.read_file``.
        verbose: Accepted for call compatibility; this function never reads it.

    Returns:
        The object returned by ``gpd.read_file(file_path)``, or ``None`` when
        ``file_path`` does not exist (an error line is printed in that case).
    """
    if os.path.exists(file_path):
        return gpd.read_file(file_path)

    # Missing file: report it and let the caller decide how to continue.
    print(f"错误: 文件不存在: {file_path}")
    return None

##########################
# 2. 配置文件路径函数
##########################
def configure_paths(root_path):
    """Build the paths of the OSM shapefiles expected under ``root_path``.

    Args:
        root_path: Directory that should contain the shapefiles.

    Returns:
        dict: Maps ``'railways_file'``, ``'transport_file'`` and
        ``'roads_file'`` to ``root_path`` joined with the matching file name.
        A warning is printed for every path that does not exist; the path is
        still included in the result.
    """
    shapefiles = (
        ('railways_file', 'gis_osm_railways_free_1.shp'),
        ('transport_file', 'gis_osm_transport_a_free_1.shp'),
        ('roads_file', 'gis_osm_roads_free_1.shp'),
    )
    paths = {key: os.path.join(root_path, filename) for key, filename in shapefiles}

    # Warn only; callers receive every path whether or not it exists.
    for location in paths.values():
        if not os.path.exists(location):
            print(f"警告: 文件不存在: {location}")

    return paths

##########################
# 3. 读取Excel文件中的线路信息
##########################
def load_route_data(excel_path, sheet_name="线路信息汇总"):
    """Load one worksheet of the route workbook into a DataFrame.

    Args:
        excel_path: Path of the Excel workbook.
        sheet_name: Worksheet to read.

    Returns:
        The DataFrame produced by ``pd.read_excel``, or ``None`` when reading
        raises any ``Exception`` (the error is printed, not re-raised).
    """
    try:
        table = pd.read_excel(excel_path, sheet_name=sheet_name)
    except Exception as err:
        # Best-effort load: callers check for None instead of handling errors.
        print(f"读取Excel文件出错: {err}")
        return None
    return table

##########################
# 4. 处理数据
##########################
def process_route_data(df):
    """Split the route table into one DataFrame per value of '线路来源'.

    A warning is printed for each expected column that is absent. The check is
    advisory only: processing continues, and the '线路来源' column is still
    read afterwards regardless of the warnings.

    Args:
        df: Route table with one row per point.

    Returns:
        dict: Maps each distinct '线路来源' value, in order of first
        appearance, to an independent copy of the rows carrying that value.
    """
    expected_columns = (
        'segment_id', 'point_id', 'longitude', 'latitude',
        'osm_id', 'fclass', 'bridge', 'tunnel', '运行方向', '线路来源',
    )
    absent = [name for name in expected_columns if name not in df.columns]
    for name in absent:
        print(f"警告: 缺少必要的列 '{name}'")

    # .copy() keeps later edits to a per-route frame from touching `df`.
    return {
        source: df[df['线路来源'] == source].copy()
        for source in df['线路来源'].unique()
    }

##########################
# 5. 创建线路的几何数据
##########################
def create_route_geometries(route_data):
    """Turn the per-route point tables into LineString segments.

    For every route, the rows are grouped by ``segment_id`` (in order of first
    appearance) and ordered by ``point_id``. Segments with fewer than two
    points are skipped because they cannot form a line.

    Args:
        route_data: dict mapping a route type to its point DataFrame.

    Returns:
        dict: Maps each route type to a list of ``(LineString, properties)``
        tuples. ``properties`` is taken from the first point of the segment
        and holds ``osm_id``, ``fclass``, ``bridge``, ``tunnel``,
        ``direction`` (the '运行方向' column) and ``route_type``.
    """
    geometries = {}

    for route_type, df in route_data.items():
        built = []

        for segment_id in df['segment_id'].unique():
            ordered = df[df['segment_id'] == segment_id].sort_values('point_id')
            vertices = ordered[['longitude', 'latitude']].values
            if len(vertices) < 2:
                continue

            polyline = LineString(vertices)
            # Attributes are read from the first point of the segment only.
            head = ordered.iloc[0]
            built.append((polyline, {
                'osm_id': head['osm_id'],
                'fclass': head['fclass'],
                'bridge': head['bridge'],
                'tunnel': head['tunnel'],
                'direction': head['运行方向'],
                'route_type': route_type,
            }))

        geometries[route_type] = built

    return geometries

# 计算两点间的距离
def calculate_distance(point1, point2):
    """Return the great-circle distance between two lon/lat points, in metres.

    The distance is computed with the haversine formula on a sphere. Measuring
    planar distance after projecting to EPSG:3857 (Web Mercator) is not
    suitable here: that projection stretches lengths by roughly
    1 / cos(latitude), which overstates ground distance by about 17 % at
    31 degrees north.

    Args:
        point1: First point as ``[longitude, latitude]`` in decimal degrees
            (WGS84). Only the first two items are read.
        point2: Second point, same layout as ``point1``.

    Returns:
        float: Distance in metres on a sphere with the mean Earth radius
        (6 371 008.8 m). Spherical results typically differ from the WGS84
        ellipsoidal distance by a few tenths of a percent.
    """
    earth_radius_m = 6371008.8

    lon1 = np.radians(float(point1[0]))
    lat1 = np.radians(float(point1[1]))
    lon2 = np.radians(float(point2[0]))
    lat2 = np.radians(float(point2[1]))

    half_dlat = (lat2 - lat1) / 2.0
    half_dlon = (lon2 - lon1) / 2.0

    # Haversine term; clamp to [0, 1] so rounding noise cannot push the
    # square root / arcsine outside their domain for near-antipodal points.
    h = np.sin(half_dlat) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(half_dlon) ** 2
    h = min(1.0, max(0.0, float(h)))

    return float(2.0 * earth_radius_m * np.arcsin(np.sqrt(h)))

# 计算所有线路总长度的函数
def calculate_total_route_length(route_data):
    """Sum point-to-point distances over every segment of every route.

    Within each route the rows are grouped by ``segment_id`` and ordered by
    ``point_id``; consecutive points of a segment are measured with
    ``calculate_distance``. No distance is added between different segments.

    Args:
        route_data: dict mapping a route type to its point DataFrame.

    Returns:
        dict with three entries:
            ``'total'``: length summed over all routes,
            ``'by_route_type'``: length per route type,
            ``'by_structure'``: length split into ``'bridge'``, ``'tunnel'``
            and ``'normal'``.
        Values are in whatever unit ``calculate_distance`` returns.
    """
    structure_totals = {'bridge': 0, 'tunnel': 0, 'normal': 0}
    per_route = {}
    grand_total = 0

    for route_type, df in route_data.items():
        route_total = 0

        for segment_id in df['segment_id'].unique():
            ordered = df[df['segment_id'] == segment_id].sort_values('point_id')
            rows = [row for _, row in ordered.iterrows()]

            # Walk consecutive point pairs of this segment.
            for start, end in zip(rows, rows[1:]):
                step = calculate_distance(
                    [start['longitude'], start['latitude']],
                    [end['longitude'], end['latitude']],
                )
                route_total += step

                # The hop is classified by the flags of its END point;
                # the bridge flag takes precedence over the tunnel flag.
                if end['bridge'] == 'T':
                    kind = 'bridge'
                elif end['tunnel'] == 'T':
                    kind = 'tunnel'
                else:
                    kind = 'normal'
                structure_totals[kind] += step

        per_route[route_type] = route_total
        grand_total += route_total

    return {
        'total': grand_total,
        'by_route_type': per_route,
        'by_structure': structure_totals,
    }

##########################
# 6. 创建站点几何数据
##########################
def create_station_geometries(station_coords):
    """Convert station coordinates into ``Point`` geometries.

    Args:
        station_coords: dict mapping a station name to an indexable pair whose
            item 0 becomes the point's x and item 1 its y.

    Returns:
        dict: Maps each station name to ``Point(coords[0], coords[1])``.
    """
    return {
        name: Point(position[0], position[1])
        for name, position in station_coords.items()
    }

##########################
# 7. 跨线运行线路组合（上海铁路网概览） (Figure 1)
##########################
def visualize_route_overview(route_geometries, station_geometries, railways_data=None):
    """Draw figure 1: every route coloured by its source, plus labelled stations.

    Args:
        route_geometries: dict mapping a route type to a list of
            ``(line, properties)`` tuples; only the lines are drawn here.
        station_geometries: dict mapping a station name to a point exposing
            ``.x`` and ``.y``.
        railways_data: Optional object with a ``.plot(ax=...)`` method, drawn
            first as a light-gray background layer.

    Returns:
        The current matplotlib figure (figure number 1).
    """
    plt.figure(1, figsize=(16, 12), clear=True)
    ax = plt.gca()

    # Background layer goes first so the routes are painted over it.
    if railways_data is not None:
        railways_data.plot(ax=ax, color='lightgray', linewidth=0.8, alpha=0.5)

    # (colour, line style) per known route source; anything else is solid gray.
    route_styles = {
        '上海17号线': ('green', '-'),
        '市域铁机场联络线': ('blue', '--'),
        '沪苏湖高速铁路': ('red', '-.'),
    }
    fallback_style = ('gray', '-')

    for route_type, segments in route_geometries.items():
        stroke, dash = route_styles.get(route_type, fallback_style)
        shapes = [geometry for geometry, _ in segments]
        if not shapes:
            continue
        layer = gpd.GeoDataFrame(geometry=shapes)
        layer.plot(ax=ax, color=stroke, linestyle=dash, linewidth=2.5, label=f'{route_type}')

    # Hand-tuned label offsets (in axis units) that keep station names from
    # overlapping; stations not listed fall back to (0.02, 0.02).
    label_offsets = {
        '国家会展中心':   (-0.075, 0.0),
        '上海松江站':    (-0.075, 0.0),
        '春申站':        (-0.025 - 0.04, 0.0),
        '上海虹桥客运站': (0.0,   0.025),
        '虹桥2号航站楼':  (0.025, 0.0),
        '中春路':        (0.025, 0.0)
    }
    default_offset = (0.02, 0.02)

    for station_name, location in station_geometries.items():
        # White disc with a black rim, kept above the lines via zorder.
        ax.scatter(location.x, location.y, s=75, color='white',
                   edgecolor='black', marker='o', zorder=10)

        shift_x, shift_y = label_offsets.get(station_name, default_offset)

        # arrowstyle '-' draws a plain connector without an arrow head.
        ax.annotate(
            station_name,
            xy=(location.x, location.y),
            xytext=(location.x + shift_x, location.y + shift_y),
            arrowprops={'arrowstyle': '-', 'color': 'black', 'lw': 1},
            bbox=None,
            fontsize=12,
            weight='bold'
        )

    ax.set_title('跨线运行线路组合（上海铁路网概览）', fontsize=16, pad=20)
    ax.set_xlabel('经度', fontsize=14)
    ax.set_ylabel('纬度', fontsize=14)

    # Fixed viewport of the overview map.
    ax.set_xlim(121.1, 121.7)
    ax.set_ylim(30.95, 31.30)

    # Add a manual legend entry for the station marker next to the route entries.
    legend_entries, _ = ax.get_legend_handles_labels()
    station_entry = plt.Line2D([0], [0], marker='o', color='w',
                               markerfacecolor='white', markeredgecolor='black',
                               markersize=8, label='站点')
    legend_entries.append(station_entry)
    ax.legend(handles=legend_entries, loc='upper right')
    plt.tight_layout()

    return plt.gcf()

##########################
# 8. 跨线运行路径线路类型概览（局部细节图） (Figure 2)
##########################
def visualize_cross_line_structure(route_geometries, route_data, railways_data=None, station_geometries=None):
    """Draw figure 2: segments coloured by structure type, with statistics boxes.

    Segments are classified as bridge, tunnel or ground from their
    ``properties``. Two text boxes report point counts and lengths per
    structure type, and the length of every route. The axes use an equal
    aspect ratio.

    Args:
        route_geometries: dict mapping a route type to a list of
            ``(line, properties)`` tuples.
        route_data: dict mapping a route type to its point DataFrame; used
            for the point counts and passed to ``calculate_total_route_length``.
        railways_data: Optional object with a ``.plot(ax=...)`` method, drawn
            first as a light-gray background layer.
        station_geometries: Optional dict mapping a station name to a point
            exposing ``.x`` and ``.y``.

    Returns:
        The current matplotlib figure (figure number 2).
    """
    plt.figure(2, figsize=(16, 12), clear=True)
    ax = plt.gca()

    # Background layer goes first so the routes are painted over it.
    if railways_data is not None:
        railways_data.plot(ax=ax, color='lightgray', linewidth=0.8, alpha=0.5)

    # Bucket every segment by structure type; the bridge flag wins over tunnel.
    lines_by_kind = {'normal': [], 'bridge': [], 'tunnel': []}
    for segments in route_geometries.values():
        for line, properties in segments:
            if properties['bridge'] == 'T':
                lines_by_kind['bridge'].append(line)
            elif properties['tunnel'] == 'T':
                lines_by_kind['tunnel'].append(line)
            else:
                lines_by_kind['normal'].append(line)

    # Point counts: bridge and tunnel are counted independently, and "normal"
    # is whatever remains of the total.
    total_points = 0
    bridge_points = 0
    tunnel_points = 0
    for frame in route_data.values():
        total_points += len(frame)
        bridge_points += len(frame[frame['bridge'] == 'T'])
        tunnel_points += len(frame[frame['tunnel'] == 'T'])
    normal_points = total_points - bridge_points - tunnel_points

    # Lengths: divided by 1000 for display as kilometres.
    length_stats = calculate_total_route_length(route_data)
    total_length_km = length_stats['total'] / 1000
    bridge_length_km = length_stats['by_structure']['bridge'] / 1000
    tunnel_length_km = length_stats['by_structure']['tunnel'] / 1000
    normal_length_km = length_stats['by_structure']['normal'] / 1000
    route_lengths = {
        route: length / 1000
        for route, length in length_stats['by_route_type'].items()
    }

    # Draw order: ground, then bridge, then tunnel; empty buckets are skipped.
    layers = (
        ('normal', 'blue', '地面段'),
        ('bridge', 'red', '高架桥段'),
        ('tunnel', 'green', '地下隧道段'),
    )
    for kind, stroke, caption in layers:
        members = lines_by_kind[kind]
        if members:
            gpd.GeoDataFrame(geometry=members).plot(
                ax=ax, color=stroke, linewidth=2.5, label=caption
            )

    if station_geometries:
        # Hand-tuned label offsets (in axis units) for this zoomed-in view;
        # stations not listed fall back to (0.02, 0.02).
        label_offsets = {
            '国家会展中心':   (-0.02 - 0.02, 0.0),
            '上海松江站':    (-0.0125 - 0.02, 0.0),
            '春申站':        (-0.025,  0.0),
            '上海虹桥客运站': (0.0,    0.025),
            '虹桥2号航站楼':  (0.025,  0.0),
            '中春路':        (0.025,  0.0)
        }
        default_offset = (0.02, 0.02)

        for station_name, location in station_geometries.items():
            ax.scatter(location.x, location.y, s=75, color='white',
                       edgecolor='black', marker='o', zorder=10)
            shift_x, shift_y = label_offsets.get(station_name, default_offset)
            # arrowstyle '-' draws a plain connector without an arrow head.
            ax.annotate(
                station_name,
                xy=(location.x, location.y),
                xytext=(location.x + shift_x, location.y + shift_y),
                arrowprops={'arrowstyle': '-', 'color': 'black', 'lw': 1},
                bbox=None,
                fontsize=12,
                weight='bold'
            )

    def share(part, whole):
        # Percentage of `whole`; 0 when there is nothing to divide by.
        return part / whole * 100 if whole > 0 else 0

    bridge_pct = share(bridge_points, total_points)
    tunnel_pct = share(tunnel_points, total_points)
    normal_pct = share(normal_points, total_points)

    bridge_length_pct = share(bridge_length_km, total_length_km)
    tunnel_length_pct = share(tunnel_length_km, total_length_km)
    normal_length_pct = share(normal_length_km, total_length_km)

    # Main statistics box: point counts first, then lengths.
    summary_lines = [
        "线路构成统计:",
        f"总点数: {total_points}",
        f"地面段: {normal_points}点 ({normal_pct:.1f}%)",
        f"高架桥段: {bridge_points}点 ({bridge_pct:.1f}%)",
        f"地下隧道段: {tunnel_points}点 ({tunnel_pct:.1f}%)",
        "\n线路长度统计 (km):",
        f"总长度: {total_length_km:.2f} km",
        f"地面段: {normal_length_km:.2f} km ({normal_length_pct:.1f}%)",
        f"高架桥段: {bridge_length_km:.2f} km ({bridge_length_pct:.1f}%)",
        f"地下隧道段: {tunnel_length_km:.2f} km ({tunnel_length_pct:.1f}%)"
    ]
    textstr = '\n'.join(summary_lines)

    # Secondary box: one line per route with its length.
    route_length_text = '\n'.join(
        f"{route}: {length:.2f} km" for route, length in route_lengths.items()
    )
    textstr2 = f"各线路长度:\n{route_length_text}"

    # Both boxes are positioned in axes-fraction coordinates (transAxes).
    ax.text(0.1, 0.55, textstr, transform=ax.transAxes, fontsize=12,
            verticalalignment='top',
            bbox={'boxstyle': 'round', 'facecolor': 'wheat', 'alpha': 0.7})
    ax.text(0.6, 0.25, textstr2, transform=ax.transAxes, fontsize=12,
            verticalalignment='top',
            bbox={'boxstyle': 'round', 'facecolor': 'lightblue', 'alpha': 0.7})

    ax.set_title('跨线运行路径线路类型概览（局部细节图）', fontsize=16, pad=20)
    ax.set_xlabel('经度', fontsize=14)
    ax.set_ylabel('纬度', fontsize=14)

    # Fixed viewport of the detail map.
    ax.set_xlim(121.1846, 121.3892)
    ax.set_ylim(30.9568, 31.25)

    # Equal scaling on both axes, adjusting the axes box to fit.
    ax.set_aspect('equal', 'box')

    ax.legend(loc='upper right')
    plt.tight_layout()
    return plt.gcf()

##########################
# 9. 主函数
##########################
def main():
    """Load the route workbook and GIS layers, then show both figures.

    Steps: resolve the shapefile paths, read the railway layer, read the route
    worksheet (returning early when that fails), build route and station
    geometries, draw figure 1 (overview) and figure 2 (structure detail), and
    finally call ``plt.show()``.
    """
    excel_path = r"E:\ResearchDocuments\ROS2WithSPCK\docs\线路规划.xlsx"
    shp_root_path = r"E:\ResearchDocuments\ROS2WithSPCK\GIS_Track\路网GIS数据集\shanghai-latest-free.shp"

    # Station name -> [longitude, latitude]
    station_coords = {
        '上海虹桥客运站': [121.3133074, 31.1959353],
        '国家会展中心':   [121.2882604, 31.1935644],
        '虹桥2号航站楼':  [121.3199359, 31.1961374],
        '中春路':        [121.3299336, 31.1519780],
        '春申站':        [121.3499051, 31.0797510],
        '上海松江站':    [121.2258998, 30.9885458]
    }

    shapefile_paths = configure_paths(shp_root_path)
    # May be None when the railway shapefile is missing; both plotting
    # functions skip the background layer in that case.
    railways_data = load_gis_data(shapefile_paths['railways_file'], verbose=False)

    route_df = load_route_data(excel_path, sheet_name="线路信息汇总")
    if route_df is None:
        # Nothing to draw without the route table.
        return

    route_data = process_route_data(route_df)
    route_geometries = create_route_geometries(route_data)
    station_geometries = create_station_geometries(station_coords)

    # Figure 1: overview by route source.
    visualize_route_overview(route_geometries, station_geometries, railways_data)

    # Figure 2: structure-type detail view.
    visualize_cross_line_structure(
        route_geometries,
        route_data,
        railways_data=railways_data,
        station_geometries=station_geometries,
    )

    plt.show()

# Run the whole pipeline only when this module is executed as the main program.
if __name__ == "__main__":
    main()