# Economic 多场景空间图

## 0 Load data

In [23]:
import sys
import os
from pathlib import Path
import pandas as pd
import geopandas as gpd
import numpy as np
from shapely.geometry import Point
from matplotlib.patches import FancyArrowPatch
def find_project_root(start_path=None):
    """Locate the project root: the nearest directory containing both ``data`` and ``function``.

    The search begins at ``start_path`` (the current working directory when
    omitted) and inspects that directory plus up to four of its ancestors.

    Parameters
    ----------
    start_path : str or os.PathLike, optional
        Directory to start searching from. Defaults to ``Path.cwd()``.

    Returns
    -------
    pathlib.Path
        The resolved directory holding both sub-directories. When no such
        directory is found, the parent of the starting directory is returned
        as a best guess.
    """
    if start_path is None:
        start_path = Path.cwd()

    current = Path(start_path).resolve()

    # Five iterations: the starting directory itself and four ancestors.
    for _ in range(5):
        if (current / 'data').exists() and (current / 'function').exists():
            return current
        parent = current.parent
        if parent == current:  # reached the filesystem root
            break
        current = parent

    # Nothing matched: fall back to the parent of the directory the search
    # started from, so an explicit start_path is honoured here as well.
    return Path(start_path).parent

# Resolve the project layout once; the data directory hangs off the root.
project_root = find_project_root()
DATA_PATH = project_root / 'data'

# Put the project root at the front of sys.path (only once) so that imports
# can resolve modules located under it.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

print(f"项目根目录: {project_root}")
print(f"数据路径: {DATA_PATH}")



项目根目录: C:\Dev\Landuse_Zhong_clean
数据路径: C:\Dev\Landuse_Zhong_clean\data


In [24]:
# Load the economic results table and the US boundary layers.
# Paths are joined component by component so they resolve on every OS; a
# backslash-separated literal such as r'US_data\file.shp' only works on Windows.
df_economic = pd.read_csv(DATA_PATH / 'US_data' / 'df_economic.csv')
us_nation = gpd.read_file(DATA_PATH / 'US_data' / 'cb_2018_us_nation_5m.shp')
# NOTE(review): the state layer is read from DATA_PATH itself, not from
# DATA_PATH / 'US_data' like the other two files — confirm this is intended.
us_states = gpd.read_file(DATA_PATH / 'cb_2018_us_state_500k.shp')

# Geographic (EPSG:4326) copies of both boundary layers.
us_nation_4326 = us_nation.to_crs('EPSG:4326')
us_states_4326 = us_states.to_crs('EPSG:4326')


In [25]:
# 加载DEM数据
import rioxarray
from shapely.geometry import mapping

# Best-effort DEM load: any failure is reported and leaves dem_us as None.
try:
    dem = rioxarray.open_rasterio(DATA_PATH.parent / 'figure' / 'draw_shp' / 'DEM.nc', masked=True)
    # A raster without CRS metadata is tagged as EPSG:4326 before clipping.
    if dem.rio.crs is None:
        dem = dem.rio.write_crs("EPSG:4326")
    # Clip the DEM to the US national outline.
    dem_us = dem.rio.clip(us_nation.geometry.apply(mapping), us_nation.crs, drop=True)
except Exception as e:
    print(f"DEM数据加载失败: {e}")
    dem_us = None
else:
    print("DEM数据加载成功")


DEM数据加载成功


In [27]:
# Preview the economic table (the notebook renders a truncated head/tail view).
df_economic

Unnamed: 0,lat,lon,pv_category,pv_model,pv_scenario,policy_category,rcp_category,net_npv_usd,net_cost_usd,net_npv_usd_demand,analysis_year
0,33.595833,-117.587500,C1,REMIND 2.1,CEMICS_GDPgrowth_1p5,P2a,RCP2.6,-2.389517e+06,2.389517e+06,0.000000e+00,2020
1,33.629166,-117.579170,C1,REMIND 2.1,CEMICS_GDPgrowth_1p5,P2a,RCP2.6,-2.389462e+06,2.389462e+06,0.000000e+00,2020
2,33.720833,-117.737500,C1,REMIND 2.1,CEMICS_GDPgrowth_1p5,P2a,RCP2.6,-2.389492e+06,2.389492e+06,0.000000e+00,2020
3,33.904167,-117.820830,C1,REMIND 2.1,CEMICS_GDPgrowth_1p5,P2a,RCP2.6,-2.390356e+06,2.390356e+06,0.000000e+00,2020
4,33.920834,-117.620834,C1,REMIND 2.1,CEMICS_GDPgrowth_1p5,P2a,RCP2.6,-2.390268e+06,2.390268e+06,0.000000e+00,2020
...,...,...,...,...,...,...,...,...,...,...,...
74557215,47.262500,-68.379166,C7,TIAM-ECN 1.1,EN_NPi2100_COV,P1b,RCP8.5,4.192593e+06,4.705003e+06,-4.692993e+06,2050
74557216,47.262500,-68.370834,C7,TIAM-ECN 1.1,EN_NPi2100_COV,P1b,RCP8.5,4.216429e+06,4.705007e+06,-4.692929e+06,2050
74557217,47.270832,-68.387500,C7,TIAM-ECN 1.1,EN_NPi2100_COV,P1b,RCP8.5,4.146022e+06,4.705002e+06,-4.693090e+06,2050
74557218,47.270832,-68.370834,C7,TIAM-ECN 1.1,EN_NPi2100_COV,P1b,RCP8.5,4.144691e+06,4.705003e+06,-4.693083e+06,2050


In [28]:
# Rename the policy label 'P2' to 'P2b'; only the policy_category column is touched.
df_economic['policy_category'] = df_economic['policy_category'].replace({'P2': 'P2b'})
df_economic['policy_category'].unique()

array(['P2a', 'P1b', 'P1a', 'P4', 'P2c', 'P1c', 'P2b', 'P1d', 'P3b',
       'P3c', 'P3a'], dtype=object)

## 1 Helper functions

In [144]:
## 1. Helper Functions

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import geopandas as gpd
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm, LinearSegmentedColormap
from pyproj import Transformer
from shapely.geometry import box, LineString
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from shapely.ops import unary_union
try:
    from shapely import make_valid
except Exception:
    make_valid = None
import matplotlib as mpl
import matplotlib.ticker as mticker

def _render_terrain(ax, dem_data, usa_bounds_main, transformer):
    """渲染地形"""
    qm = None
    dem_lon = dem_data.x.values
    dem_lat = dem_data.y.values
    lon_mask = (dem_lon >= usa_bounds_main['lon_min']) & (dem_lon <= usa_bounds_main['lon_max'])
    lat_mask = (dem_lat >= usa_bounds_main['lat_min']) & (dem_lat <= usa_bounds_main['lat_max'])
    
    if lon_mask.any() and lat_mask.any():
        lon_idx = np.where(lon_mask)[0]
        lat_idx = np.where(lat_mask)[0]
        
        # 抽稀栅格
        step_x = max(1, len(lon_idx)//1800)
        step_y = max(1, len(lat_idx)//900)
        lon_sub = lon_idx[::step_x]
        lat_sub = lat_idx[::step_y]
        elev = dem_data.squeeze().values[np.ix_(lat_sub, lon_sub)]

        mask = ~np.isnan(elev)
        if mask.any():
            lo = np.nanpercentile(elev, 35)
            elev[elev < lo] = np.nan

            # 地形配色
            terrain_colors = ListedColormap(["#ffffff", "#eef3ef", "#dfe8e0",
                                           "#d2ddcf", "#c7d1c1", "#b7c2af"])
            terrain_colors.set_bad((0, 0, 0, 0))
            qs = np.nanpercentile(elev, [35, 50, 65, 78, 88, 98])
            norm = BoundaryNorm(qs, terrain_colors.N)

            lon_grid, lat_grid = np.meshgrid(dem_lon[lon_sub], dem_lat[lat_sub])
            gx, gy = transformer.transform(lon_grid, lat_grid)            
            qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,
                              shading='auto', alpha=0.55, antialiased=False, zorder=1)
            qm.set_clip_on(True)
    
    return qm


def _draw_glow_outline_outside(ax, gdf_albers, line_color="#444444", line_width=0.8,
                              widths_km=(0, 30, 50), alphas=(0.95, 0.5, 0.3), 
                              colors=("#08336e", "#105ca4", "#3888c0")):
    """Draw concentric "glow" bands outside an outline, then the outline itself.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    gdf_albers : geopandas.GeoDataFrame
        Outline geometries. Band radii are ``widths_km * 1000``, so the CRS
        units are presumably metres — confirm against the caller.
    line_color, line_width
        Style of the boundary line drawn on top of the bands.
    widths_km, alphas, colors : sequences of equal length
        Outer radius, opacity and fill colour of each band.

    Returns
    -------
    list
        One entry per non-empty band: whatever ``GeoSeries.plot`` returned.
    """
    outline = gdf_albers.copy()
    merged = unary_union(outline.geometry)
    # Repair invalid geometry; buffer(0) is the fallback when make_valid is unavailable.
    merged = make_valid(merged) if make_valid is not None else merged.buffer(0)

    # (radius in CRS units, alpha, colour), drawn from the innermost band outward.
    bands = sorted(
        zip([km*1000.0 for km in widths_km], alphas, colors),
        key=lambda band: band[0],
    )

    previous_shape = merged.buffer(0, join_style=2, cap_style=2)
    glow_meshes = []

    for order, (radius, alpha, color) in enumerate(bands):
        grown_shape = merged.buffer(radius, join_style=2, cap_style=2)
        # Each band is the area gained since the previous (smaller) buffer.
        band_geom = grown_shape.difference(previous_shape)
        if not band_geom.is_empty:
            drawn = gpd.GeoSeries([band_geom], crs=outline.crs).plot(
                ax=ax, color=color, alpha=alpha, edgecolor="none",
                zorder=8.2 + order*0.01)
            glow_meshes.append(drawn)
        previous_shape = grown_shape

    outline.boundary.plot(ax=ax, color=line_color, linewidth=line_width, zorder=9.6)
    return glow_meshes


def _draw_background_map(ax, usa_bounds_main):
    """Draw a white land-mass backdrop around the map extent (best effort).

    Reads ``figure/draw_shp/ne_110m_land.shp`` relative to the current working
    directory, clips it to the bounds padded by 3 degrees, repairs the
    geometries, reprojects to ESRI:102003 and plots it beneath the other
    layers. Any failure is printed and otherwise ignored, so a missing
    shapefile never aborts the figure.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    usa_bounds_main : dict
        Bounding box with keys ``'lon_min'``, ``'lon_max'``, ``'lat_min'``
        and ``'lat_max'``.
    """
    try:
        bbox_expanded = box(
            usa_bounds_main['lon_min'] - 3, usa_bounds_main['lat_min'] - 3,
            usa_bounds_main['lon_max'] + 3, usa_bounds_main['lat_max'] + 3
        )
        # Joined from components so the path also resolves outside Windows.
        # NOTE(review): this is relative to the working directory, whereas the
        # DEM is loaded from DATA_PATH.parent / 'figure' / 'draw_shp' — confirm
        # both are meant to point at the same folder.
        land_path = Path('figure') / 'draw_shp' / 'ne_110m_land.shp'
        helper_gdf = gpd.read_file(land_path).to_crs(4326)
        helper_gdf = helper_gdf.clip(bbox_expanded)
        
        if 'featurecla' in helper_gdf.columns:
            helper_gdf = helper_gdf[helper_gdf['featurecla'].str.contains('Land', na=False)]
        
        # Repair invalid geometries; buffer(0) is the fallback without make_valid.
        if make_valid is not None:
            helper_gdf['geometry'] = helper_gdf.geometry.map(make_valid)
        else:
            helper_gdf['geometry'] = helper_gdf.geometry.buffer(0)
            
        # Keep polygonal parts only and split multi-part geometries.
        helper_gdf = helper_gdf[helper_gdf.geometry.geom_type.isin(['Polygon', 'MultiPolygon'])]
        helper_gdf = helper_gdf.explode(index_parts=False, ignore_index=True)
        helper_gdf_proj = helper_gdf.to_crs('ESRI:102003')
        helper_gdf_proj['geometry'] = helper_gdf_proj.buffer(0)
        helper_gdf_proj.plot(ax=ax, facecolor='white', edgecolor='lightgray',
                           linewidth=0.3, alpha=1, zorder=0.5)
    except Exception as e:
        print(f"Helper map error: {e}")

def _add_colorbar(fig, ax, scatter, var_bins, variable_name, unit, var_values, colors, policy_category=None):
    """Add a horizontal colorbar below the map and a decile summary chart above it.

    The colorbar lives in its own axes at a fixed figure height, is ticked at
    every second entry of ``var_bins`` and annotated with "TopNN" labels.
    When both ``var_values`` and ``colors`` are given, a small chart is drawn
    above the colorbar showing the mean of ``var_values`` inside each decile
    (bars plus a connecting line), and dashed guide lines join every colorbar
    tick to the bottom edge of that chart.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives the extra axes.
    ax : matplotlib.axes.Axes
        Main map axes; only its position is read.
    scatter : matplotlib mappable
        Mappable whose colormap and norm drive the colorbar.
    var_bins : sequence of float
        Bin edges; ``var_bins[1::2]`` become the colorbar ticks.
    variable_name : str
        Selects the tick-label format (``'avg_npv'``, ``'predicted_prob'``,
        ``'Expectation_net_benefit'`` or anything else) and the caption.
    unit : str
        Unit text used in the chart's y label and, for some variables, in the
        caption.
    var_values : numpy.ndarray or None
        Plotted values; NaNs are dropped before the decile statistics.
    colors : sequence or None
        Fill colours; ``colors[0]`` .. ``colors[9]`` are used for the ten bars.
    policy_category : str, optional
        Appended to the default caption when given.

    Returns
    -------
    tuple
        ``(cbar_ax, cbar)``: the colorbar axes and the ``Colorbar`` object.
    """
    from matplotlib.lines import Line2D
    from matplotlib.patches import FancyArrowPatch

    def add_horizontal_cbar_equiv(fig, ax, mappable, var_bins, variable_name, unit_display=None, var_values=None, colors=None):
        """Build the colorbar axes, its labels and (optionally) the decile chart."""
        # Horizontal placement is derived from the main map; the vertical
        # position is a fixed figure fraction near the bottom.
        pos = ax.get_position()
        left = pos.x0 + 0.18 * pos.width
        bottom = 0.1
        width = 0.65 * pos.width
        height = 0.035

        cax = fig.add_axes([left, bottom, width, height])
        cax.set_in_layout(False)
        cax.set_zorder(50)
        cb = fig.colorbar(mappable, cax=cax, orientation='horizontal')

        # Tick every second bin edge.
        tick_vals = var_bins[1::2]
        cb.set_ticks(tick_vals)

        # Tick-label format depends on the variable; two variables are shown in thousands.
        if variable_name == 'avg_npv':
            tick_labels = [f'{v/1_000:.1f}' for v in tick_vals]
        elif variable_name == 'predicted_prob':
            tick_labels = [f'{v:.2f}' for v in tick_vals]
        elif variable_name == 'Expectation_net_benefit':
            tick_labels = [f'{v/1_000:.2f}' for v in tick_vals]
        else:
            tick_labels = [f'{v:.0f}' for v in tick_vals]
        cb.set_ticklabels(tick_labels)

        # Thin ticks and outline.
        cax.tick_params(axis='x', which='major', length=2.5, width=0.5, pad=1, labelsize=5)
        cax.tick_params(axis='x', which='minor', length=1.5, width=0.4, label1On=False)
        cb.outline.set_linewidth(0.7)

        # "TopNN" labels above the ticks: Top90, Top70, ..., Top10 from left to right.
        percentile_labels = [f'Top{p}' for p in range(10, 101, 20)][::-1]
        major_tick_locs = cb.get_ticks()
        for loc, label in zip(major_tick_locs, percentile_labels):
            # Convert the tick's data coordinate to an axes fraction so the
            # text can sit above the colorbar.
            x_disp = cax.transData.transform((loc, 0))[0]
            x_frac = cax.transAxes.inverted().transform((x_disp, 0))[0]
            cax.text(x_frac, 1.3, label, transform=cax.transAxes,
                    ha='center', va='bottom', fontsize=5, clip_on=False)
        cax.text(1.02, 1.3, "(%)", ha='center', va='bottom', fontsize=5, 
                fontweight='bold', transform=cax.transAxes)

        # ========== decile bar + line chart above the colorbar ==========
        if var_values is not None and colors is not None:
            chart_height = 0.11
            chart_bottom = bottom + height + 0.05

            # Same left/width as the colorbar so the two stay aligned.
            chart_ax = fig.add_axes([left, chart_bottom, width, chart_height])
            chart_ax.set_in_layout(False)
            chart_ax.set_zorder(51)
            chart_ax.grid(False)

            clean_values = var_values[~np.isnan(var_values)]
            # Defined up front so the tick logic below also works when there
            # is no finite data at all.
            bar_heights = []

            if len(clean_values) > 0:
                # Ten decile groups: 0-10%, 10-20%, ..., 90-100%.
                percentiles = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
                bin_edges = np.percentile(clean_values, percentiles)
                bar_positions = [(percentiles[i] + percentiles[i+1]) / 2 for i in range(10)]

                for i in range(10):
                    # Both edges are inclusive, so a value lying exactly on an
                    # edge contributes to both neighbouring deciles.
                    in_bin = (clean_values >= bin_edges[i]) & (clean_values <= bin_edges[i+1])
                    # Bar height is the raw (un-normalised) mean of the decile.
                    bar_heights.append(np.mean(clean_values[in_bin]))
                    chart_ax.bar(bar_positions[i], bar_heights[i], width=10,
                                color=colors[i], alpha=0.4, edgecolor='none')

                # Line through the decile means, drawn above the bars.
                chart_ax.plot(bar_positions, bar_heights, color='#333333', linewidth=1.2, alpha=0.8, zorder=10)

                # The x axis spans 0-100 (percentile).
                chart_ax.set_xlim(0, 100)

                # y range follows the decile means, with 5% padding below and 20% above.
                y_min = min(bar_heights)
                y_max = max(bar_heights)
                y_range = y_max - y_min
                chart_ax.set_ylim(y_min - 0.05 * y_range if y_range > 0 else y_min - 0.1, 
                                 y_max + 0.2 * y_range if y_range > 0 else y_max + 0.1)

                # No x ticks or labels; keep a solid baseline instead.
                chart_ax.set_xticks([])
                chart_ax.set_xticklabels([])
                chart_ax.tick_params(axis='x', which='both', length=0, labelsize=0)
                chart_ax.spines['bottom'].set_visible(True)
                chart_ax.spines['bottom'].set_linewidth(1.0)
                chart_ax.spines['bottom'].set_color('#333333')

                # Solid arrow head just beyond the right end of the x axis
                # (x in data units, y in axes units).
                arrow_solid = FancyArrowPatch(
                    posA=(101, 0), posB=(104, 0),
                    transform=chart_ax.get_xaxis_transform(),  
                    arrowstyle='simple',
                    color='black', linewidth=0, mutation_scale=8, zorder=20
                )
                arrow_solid.set_clip_on(False)
                chart_ax.add_patch(arrow_solid)
            else:
                # No finite data: use a unit range that matches the
                # placeholder ticks set below.
                chart_ax.set_ylim(0, 1.1)

            # Keep only the left and bottom spines.
            chart_ax.spines['top'].set_visible(False)
            chart_ax.spines['right'].set_visible(False)
            chart_ax.spines['left'].set_visible(True)
            chart_ax.spines['bottom'].set_visible(True)

            # y ticks at the lowest, middle and highest decile mean. Labels are
            # shown in thousands with two decimals in every case.
            if len(bar_heights) > 0:
                y_min = min(bar_heights)
                y_max = max(bar_heights)
                y_range = y_max - y_min
                if y_range > 0:
                    y_ticks = [y_min, y_min + y_range / 2, y_max]
                else:
                    y_ticks = [y_min]
                chart_ax.set_yticks(y_ticks)
                chart_ax.set_yticklabels([f'{v/1000:.2f}' for v in y_ticks], fontsize=5)
            else:
                # Placeholder ticks when there is nothing to plot.
                chart_ax.set_yticks([0, 0.5, 1.0])
                chart_ax.set_yticklabels(['0.00', '0.50', '1.00'], fontsize=5)

            chart_ax.tick_params(axis='y', which='major', length=2.5, width=0.5, pad=2, labelsize=5)
            chart_ax.tick_params(axis='y', which='minor', length=1.5, width=0.4)

            chart_ax.set_ylabel(f'Mean \n{unit}', fontsize=5, fontweight='bold', labelpad=2)

            # Upward arrow on top of the y axis, placed in axes coordinates so
            # it is independent of the sign and scale of the data.
            arrow_y = FancyArrowPatch(
                posA=(0, 1.0), posB=(0, 1.18),  
                transform=chart_ax.transAxes,   
                arrowstyle='simple',
                color='black', linewidth=0, mutation_scale=8, zorder=20
            )
            arrow_y.set_clip_on(False)
            chart_ax.add_patch(arrow_y)

            chart_ax.patch.set_visible(False)

            # Dashed vertical guides from each colorbar tick up to the chart.
            if len(clean_values) > 0:
                y_start = cax.get_position().y1      # top edge of the colorbar
                y_end = chart_ax.get_position().y0   # bottom edge of the chart
                for loc in major_tick_locs:
                    # Tick position: colorbar data coords -> display -> figure fraction.
                    x_tick_display, _ = cax.transData.transform((loc, 0))
                    x_tick_fig, _ = fig.transFigure.inverted().transform((x_tick_display, 0))
                    line = Line2D([x_tick_fig, x_tick_fig], [y_start, y_end],
                                transform=fig.transFigure, color="gray", 
                                lw=0.8, alpha=0.4, linestyle='--', zorder=45)
                    fig.add_artist(line)

        # Best effort: attach the colorbar patch as clip path to its children;
        # artists that reject it are left unchanged.
        for ch in cax.get_children():
            try:
                ch.set_clip_path(cax.patch)
            except Exception:
                pass

        return cax, cb

    # Caption shown underneath the colorbar.
    if variable_name == 'net_npv_usd':
        unit_display = f'Economic viability {unit} ' 
    elif variable_name == 'Expectation_net_benefit':
        unit_display = f'{unit} '
    else:
        if policy_category:
            unit_display = f'Economic viability ({policy_category})'
        else:
            unit_display = 'Economic viability '

    cbar_ax, cbar = add_horizontal_cbar_equiv(fig, ax, scatter, var_bins, variable_name, unit_display, var_values, colors)
    cbar_ax.text(0.5, -1.6, unit_display, ha='center', va='top', fontsize=5, 
                fontweight='bold', transform=cbar_ax.transAxes)
    return cbar_ax, cbar


def _draw_graticule_top_labels(ax, lon_ticks, lat_ticks, usa_bounds_main, proj_fwd, proj_inv):
    """Draw a curved lon/lat graticule with labels along the top and left edges.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes in projected coordinates; its current limits define the frame.
    lon_ticks, lat_ticks : iterable of float
        Meridians and parallels (degrees) to draw.
    usa_bounds_main : dict
        Accepted for call compatibility; not used by this function.
    proj_fwd : object
        Provides ``transform(lon, lat) -> (x, y)`` into the map projection.
    proj_inv : object
        Provides ``transform(x, y) -> (lon, lat)`` back to degrees.

    Notes
    -----
    Longitude labels sit above the frame; when a meridian does not cross the
    top edge, the x position of its crossing with the bottom edge is used.
    Latitude labels are drawn on the left edge only. Labels carry a
    hemisphere suffix (W/E, N/S) chosen from the sign of the tick value.
    """
    def extent_lonlat_from_axes(ax, pad_deg=3):
        """Return the padded lon/lat ranges covered by the four frame edges."""
        xmin, xmax = ax.get_xlim(); ymin, ymax = ax.get_ylim()
        xs = np.linspace(xmin, xmax, 512)
        ys = np.linspace(ymin, ymax, 512)
        
        b_lon, b_lat = proj_inv.transform(xs, np.full_like(xs, ymin))
        t_lon, t_lat = proj_inv.transform(xs, np.full_like(xs, ymax))
        l_lon, l_lat = proj_inv.transform(np.full_like(ys, xmin), ys)
        r_lon, r_lat = proj_inv.transform(np.full_like(ys, xmax), ys)
        
        lon_min = np.nanmin([b_lon.min(), t_lon.min(), l_lon.min(), r_lon.min()]) - pad_deg
        lon_max = np.nanmax([b_lon.max(), t_lon.max(), l_lon.max(), r_lon.max()]) + pad_deg
        lat_min = np.nanmin([b_lat.min(), t_lat.min(), l_lat.min(), r_lat.min()]) - pad_deg
        lat_max = np.nanmax([b_lat.max(), t_lat.max(), l_lat.max(), r_lat.max()]) + pad_deg
        
        return (lon_min, lon_max), (lat_min, lat_max)

    def _project_xy(xy):
        """Project an (n, 2) array of lon/lat pairs to map coordinates."""
        xs, ys = proj_fwd.transform(xy[:,0], xy[:,1])
        return np.column_stack([xs, ys])

    def _choose_point_on_edge(geom):
        """Reduce an intersection result to a single point, or None."""
        if geom.is_empty:
            return None
        if geom.geom_type == "Point":
            return geom
        if geom.geom_type.startswith("Multi") or geom.geom_type == "GeometryCollection":
            pts = [g for g in getattr(geom, "geoms", []) if g.geom_type == "Point"]
            return pts[0] if pts else None
        if geom.geom_type == "LineString":
            return geom.interpolate(0.5, normalized=True)
        return None

    def _degree_label(value, positive_suffix, negative_suffix):
        """Format a tick such as -115 -> '115°W' or 30 -> '30°N'."""
        degrees = int(round(value))
        if degrees == 0:
            return "0°"
        suffix = negative_suffix if degrees < 0 else positive_suffix
        return f"{abs(degrees)}°{suffix}"

    # Frame edges used to locate the label positions.
    xmin, xmax = ax.get_xlim(); ymin, ymax = ax.get_ylim()
    left = LineString([(xmin, ymin), (xmin, ymax)])
    bottom = LineString([(xmin, ymin), (xmax, ymin)])
    top = LineString([(xmin, ymax), (xmax, ymax)])

    # Lon/lat range to sample the grid lines over.
    lon_ext, lat_ext = extent_lonlat_from_axes(ax, pad_deg=3)

    # Meridians (curved in the projection) and their labels.
    xticks = []
    for lon in lon_ticks:
        lats = np.linspace(lat_ext[0], lat_ext[1], 1000)
        xy = _project_xy(np.column_stack([np.full_like(lats, lon), lats]))
        line = LineString(xy)

        ax.plot(xy[:,0], xy[:,1], lw=0.3, color="#cfcfcf", zorder=0.6, alpha=0.7)

        # Prefer the crossing with the top edge; fall back to the bottom edge.
        p = _choose_point_on_edge(line.intersection(top))
        if p is None:
            p = _choose_point_on_edge(line.intersection(bottom))
        if p is not None:
            # Skip labels that would land within 2% of the axis width of an
            # already placed one.
            min_distance = 0.02 * (ax.get_xlim()[1] - ax.get_xlim()[0])
            if not xticks or min([abs(p.x - x) for x in xticks]) > min_distance:
                xticks.append(p.x)
                ax.text(p.x, ymax + (ymax - ymin) * 0.02, _degree_label(lon, "E", "W"),
                       ha='center', va='bottom', fontsize=5, zorder=20, color='#666666',
                       bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.7, 
                               edgecolor='none'))

    # Parallels (curved in the projection) and their labels.
    for lat in lat_ticks:
        lons = np.linspace(lon_ext[0], lon_ext[1], 1000)
        xy = _project_xy(np.column_stack([lons, np.full_like(lons, lat)]))
        line = LineString(xy)

        ax.plot(xy[:,0], xy[:,1], lw=0.3, color="#cfcfcf", zorder=0.6, alpha=0.7)

        # Labels on the left edge only.
        p = _choose_point_on_edge(line.intersection(left))
        if p is not None:
            ax.text(xmin - (xmax - xmin) * 0.02, p.y, _degree_label(lat, "N", "S"),
                   ha='right', va='center', fontsize=5, zorder=20, color='#666666',
                   bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.7,
                           edgecolor='none'))




## 2 Plot scenarios

In [145]:

def plot_single_variable(
    merged_data_for_plot, 
    us_states_gdf, 
    dem_data, 
    variable_name,
    title=None, 
    unit=None, 
    colors=None, 
    cmap_type='sequential',
    add_north_arrow_and_scalebar=True,
    add_embedding_points=False,
    figsize=None,
    policy_category=None,
    us_nation_gdf=None
):
    """
    Draw one decile-coloured point map of ``variable_name`` over the contiguous US.

    Points are projected to ESRI:102003 and coloured by the decile of their
    value; terrain, state borders, a glowing national outline, a curved
    graticule and a colorbar with a decile summary chart are added.

    Parameters:
    -----------
    merged_data_for_plot : DataFrame
        Must contain ``'lon'``, ``'lat'`` and ``variable_name`` columns.
    us_states_gdf : GeoDataFrame
        State boundaries (any CRS; reprojected internally).
    dem_data : xarray.DataArray or None
        Elevation raster handed to ``_render_terrain``.
    variable_name : str
        Column to plot.
    title : str, optional
        Accepted for call compatibility; currently not drawn.
    unit : str, optional
        Unit text forwarded to the colorbar helper.
    colors : list
        At least 11 colours used to build the colormap; required.
    cmap_type : str, default 'sequential'
        Accepted for call compatibility; currently unused.
    add_north_arrow_and_scalebar : bool, default True
        Accepted for call compatibility; currently unused.
    add_embedding_points : bool, default False
        Accepted for call compatibility; currently unused.
    figsize : tuple, optional
        Figure size in inches; defaults to 85 mm x 80 mm.
    policy_category : str, optional
        Forwarded to the colorbar caption.
    us_nation_gdf : GeoDataFrame, optional
        National outline. When omitted, the notebook-level ``us_nation_4326``
        is used.

    Returns:
    --------
    fig, ax : matplotlib figure and axes

    Raises:
    -------
    ValueError
        If ``colors`` is missing or has fewer than 11 entries. This is checked
        before any figure is created.

    Notes:
    ------
    Updates the global ``plt.rcParams`` (font sizes and family) as a side effect.
    """
    # ==================== 0. Validate inputs before creating a figure ====================
    # Failing here, before plt.subplots, avoids leaving a half-drawn figure open.
    if colors is None or len(colors) < 11:
        raise ValueError(f"colors 列表必须有11个颜色，当前有 {len(colors) if colors else 0} 个")

    var_values = np.round(np.asarray(merged_data_for_plot[variable_name].values, dtype=float), 3)
    # Eleven decile edges (0th, 10th, ..., 100th percentile).
    var_bins = np.nanpercentile(var_values, [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100])

    if us_nation_gdf is None:
        us_nation_gdf = us_nation_4326

    # ==================== 1. Figure setup ====================
    plt.rcParams.update({
        'font.size': 5, 'axes.titlesize': 5, 'axes.labelsize': 5,
        'xtick.labelsize': 5, 'ytick.labelsize': 5, 'legend.fontsize': 5,
        'font.family': 'Arial'
    })

    if figsize is None:
        figsize = (85/25.4, 80/25.4)
    fig, ax = plt.subplots(1, 1, figsize=figsize)  
    ax.set_autoscale_on(False)
    fig.patch.set_facecolor('white')     

    # ==================== 2. Geographic bounds ====================
    usa_bounds_main = {'lon_min': -125, 'lon_max': -65, 'lat_min': 24, 'lat_max': 51}
    bbox = box(usa_bounds_main['lon_min'], usa_bounds_main['lat_min'],
               usa_bounds_main['lon_max'], usa_bounds_main['lat_max'])
    
    # State borders, clipped in lon/lat and then projected.
    us_states_bound = us_states_gdf.to_crs(epsg=4326).clip(bbox)
    us_states_albers = us_states_bound.to_crs('ESRI:102003')
    
    # National outline, handled the same way.
    us_nation_bound = us_nation_gdf.to_crs(epsg=4326).clip(bbox)
    us_nation_albers = us_nation_bound.to_crs('ESRI:102003')

    # ==================== 3. Coordinate transformers ====================
    # One forward transformer serves both the data and the graticule.
    proj_fwd = Transformer.from_crs("EPSG:4326", "ESRI:102003", always_xy=True)
    proj_inv = Transformer.from_crs("ESRI:102003", "EPSG:4326", always_xy=True)

    # ==================== 4. Map extent ====================
    xmin, ymin, xmax, ymax = us_states_albers.total_bounds

    # 2% margin on every side plus extra room at the bottom for the colorbar.
    margin_x = (xmax - xmin) * 0.02  
    margin_y = (ymax - ymin) * 0.02
    colorbar_space = (ymax - ymin) * 0.05  
    
    xmin = xmin - margin_x
    xmax = xmax + margin_x
    ymin = ymin - margin_y - colorbar_space
    ymax = ymax + margin_y

    ax.set_position([0.09, 0.17, 0.9, 0.9])
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    # ==================== 5. Project the data points ====================
    x_proj, y_proj = proj_fwd.transform(merged_data_for_plot['lon'].values, 
                                        merged_data_for_plot['lat'].values)

    # ==================== 6. Terrain ====================
    qm = _render_terrain(ax, dem_data, usa_bounds_main, proj_fwd)

    # ==================== 7. Outlines ====================
    _draw_glow_outline_outside(ax, us_nation_albers)
    us_states_albers.plot(ax=ax, color='none', edgecolor='black', linewidth=0.4, alpha=0.4, zorder=9)

    # ==================== 8. Hide the axes frame, draw the graticule ====================
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    lon_ticks = np.arange(-115, -65, 10)
    lat_ticks = [ 30,35,40,45]
    _draw_graticule_top_labels(ax, lon_ticks, lat_ticks, usa_bounds_main, proj_fwd, proj_inv)

    # ==================== 9. Background map (disabled) ====================
    # _draw_background_map(ax, usa_bounds_main)

    # ==================== 10. Data layer ====================
    # NOTE(review): the 11 percentile edges define 10 bins while the colormap
    # has 11 levels; confirm that the bin -> colour assignment made by
    # BoundaryNorm matches the colors[i] lookup used for the decile bars in
    # _add_colorbar.
    cmap = LinearSegmentedColormap.from_list(f"{variable_name}_cmap", colors, N=11)
    norm = BoundaryNorm(var_bins, ncolors=cmap.N)
    
    # The scatter carries cmap and norm; the colorbar is built from this mappable.
    scatter = ax.scatter(x_proj, y_proj, c=var_values, cmap=cmap, norm=norm,
                        s=0.3, alpha=1, edgecolors='none', zorder=5)

    # ==================== 11. Colorbar and decile chart ====================
    _add_colorbar(fig, ax, scatter, var_bins, variable_name, unit, var_values, colors, policy_category=policy_category)

    # ==================== 12. Final touches ====================
    fig.canvas.draw()          
    if qm is not None:
        # Clip the terrain to the axes and rasterise it for vector output.
        qm.set_clip_path(ax.patch)
        qm.set_rasterized(True)

    return fig, ax

# Confirmation that this cell ran and the plotting function is defined.
print("Helper functions loaded successfully!")

Helper functions loaded successfully!


In [146]:
## 绘制每个政策类别的经济地图

# Policy categories to map: all eleven labels, including P4 and P1d.
policy_categories_renamed = ['P1a', 'P1b', 'P1c', 'P2a', 'P2b', 'P2c', 'P3a', 'P3b', 'P3c','P4','P1d']
# Every category is plotted; slice this list (e.g. [:9]) to draw a subset only.
policy_categories_to_plot = list(policy_categories_renamed)

print(f"将绘制以下政策类别: {policy_categories_to_plot}")
print(f"共 {len(policy_categories_to_plot)} 张图")

# One 11-colour palette per policy family, keyed by the first two characters
# of the category label.
color_schemes = {
    "P1": [
        "#08306b", "#0f4a90", "#1a67b0", "#3a85c9", "#64a6df",
        "#9bc7f0", "#d1e6f9", "#f6e3c0", "#e8b96a", "#cc8429", "#9c5400"
    ],
    "P2": [
        "#00441b", "#0a5b27", "#19743a", "#2e8c50", "#4ca56b",
        "#78c193", "#b2dcbd", "#e0efe4", "#d4c7eb", "#a287d3", "#6a43a6"
    ],
    "P3": [
        "#fff200",  # bright yellow
        "#ffd94e", "#ffc16e", "#ffab92", "#f798b5", "#e184d1",
        "#c06fdc", "#9b58dd", "#6f41d6", "#4027a9", "#2e0c6f"   # deep purple
    ],
    "P4": [
        "#003f5c", "#235173", "#49648b", "#6d77a6", "#928ac0",
        "#b5a2d4", "#dac8e4", "#f9eedf", "#dfc68d", "#b18e3b", "#7d5c00"
    ]
}


# Variable configuration passed to plot_single_variable.
variable_config = {
    'name': 'avg_npv',
    'title': 'Economic Viability', 
    'unit': 'kUSD ha$^{-1}$'
}

# Figure size: 60 mm x 60 mm, converted to inches.
figsize_mm = (60, 60)
figsize_inches = (figsize_mm[0]/25.4, figsize_mm[1]/25.4)

# NOTE(review): the directory name is spelled 'Supplymentary'; it is kept
# unchanged because it is a runtime output path.
output_dir = 'Supplymentary_figure'
os.makedirs(output_dir, exist_ok=True)

for idx, policy in enumerate(policy_categories_to_plot):
    print(f"\n正在绘制政策类别: {policy} ({idx+1}/{len(policy_categories_to_plot)})")
    
    # Labels starting with P1, P2 or P3 use their family palette; everything
    # else uses the P4 palette.
    colors = color_schemes.get(policy[:2], color_schemes['P4'])
    
    # Rows of this policy category for analysis year 2050.
    economic_policy = df_economic[
        (df_economic['policy_category'] == policy) & 
        (df_economic['analysis_year'] == 2050)
    ]
    
    if economic_policy.empty:
        print(f"  警告: 政策类别 {policy} 没有数据，跳过")
        continue
    
    # Mean net NPV per (lat, lon) location.
    avg_npv_policy = economic_policy.groupby(['lat', 'lon'])['net_npv_usd'].mean().reset_index()
    avg_npv_policy = avg_npv_policy.rename(columns={'net_npv_usd': 'avg_npv'})

    fig, ax = plot_single_variable(
        avg_npv_policy, 
        us_states_4326, 
        dem_us,
        variable_config['name'],
        variable_config['title'],
        variable_config['unit'],
        colors,
        figsize=figsize_inches,
        policy_category=policy
    )

    # Always release the figure, even if annotating or saving fails.
    try:
        # Panel label (a, b, c, ...) following the position in the list.
        panel_label = chr(97 + idx)  # 97 is the code point of 'a'
        fig.text(0.01, 0.99, panel_label, ha='left', va='top', fontsize=7, fontweight='bold',
                 bbox=dict(facecolor='white', alpha=0.7, pad=0.2, lw=0), zorder=100)

        filename_png = f"{output_dir}/economic_policy_{policy}_60mm.png"
        filename_pdf = f"{output_dir}/economic_policy_{policy}_60mm.pdf"

        # PNG with a white background, PDF with a transparent one.
        fig.savefig(filename_png, dpi=300, facecolor='White')
        fig.canvas.draw()
        fig.savefig(filename_pdf, dpi=300, facecolor='None')

        print(f"  图片已保存: {filename_png}")
    finally:
        plt.close(fig)

print("\n所有政策类别的地图绘制完成！")

将绘制以下政策类别: ['P1a', 'P1b', 'P1c', 'P2a', 'P2b', 'P2c', 'P3a', 'P3b', 'P3c', 'P4', 'P1d']
共 11 张图

正在绘制政策类别: P1a (1/11)


  qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,


  图片已保存: Supplymentary_figure/economic_policy_P1a_60mm.png

正在绘制政策类别: P1b (2/11)


  qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,


  图片已保存: Supplymentary_figure/economic_policy_P1b_60mm.png

正在绘制政策类别: P1c (3/11)


  qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,


  图片已保存: Supplymentary_figure/economic_policy_P1c_60mm.png

正在绘制政策类别: P2a (4/11)


  qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,


  图片已保存: Supplymentary_figure/economic_policy_P2a_60mm.png

正在绘制政策类别: P2b (5/11)


  qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,


  图片已保存: Supplymentary_figure/economic_policy_P2b_60mm.png

正在绘制政策类别: P2c (6/11)


  qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,


  图片已保存: Supplymentary_figure/economic_policy_P2c_60mm.png

正在绘制政策类别: P3a (7/11)


  qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,


  图片已保存: Supplymentary_figure/economic_policy_P3a_60mm.png

正在绘制政策类别: P3b (8/11)


  qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,


  图片已保存: Supplymentary_figure/economic_policy_P3b_60mm.png

正在绘制政策类别: P3c (9/11)


  qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,


  图片已保存: Supplymentary_figure/economic_policy_P3c_60mm.png

正在绘制政策类别: P4 (10/11)


  qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,


  图片已保存: Supplymentary_figure/economic_policy_P4_60mm.png

正在绘制政策类别: P1d (11/11)


  qm = ax.pcolormesh(gx, gy, elev, cmap=terrain_colors, norm=norm,


  图片已保存: Supplymentary_figure/economic_policy_P1d_60mm.png

所有政策类别的地图绘制完成！
