# Water Stress + High-Growth Data Center Overlay

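This notebook overlays Aqueduct 4.0 projected water stress (future annual, 2050 `bau50` fields) with projected data center locations in the contiguous US. Interactive controls let you switch between growth projections (high growth by default) and market gravity scenarios.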

In [11]:
from pathlib import Path

import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import ipywidgets as widgets

# Input locations, resolved relative to the notebook's working directory
ROOT = Path('.')
AQUEDUCT_CSV = ROOT / 'datasets/Aqueduct40_waterrisk_download_Y2023M07D05/CVS/Aqueduct40_future_annual_y2023m07d05.csv'
AQUEDUCT_GDB = ROOT / 'datasets/Aqueduct40_waterrisk_download_Y2023M07D05/GDB/Aq40_Y2023D07M05.gdb'
PROJECTED_DC_DIR = ROOT / 'datasets/im3_projected_data_centers'

# Aqueduct water stress columns for the `bau50` scenario:
# `_r` is read as the value field, `_c` as the class code, `_l` as the class label
WATER_STRESS_VALUE_FIELD = 'bau50_ws_x_r'
WATER_STRESS_CODE_FIELD = 'bau50_ws_x_c'
WATER_STRESS_LABEL_FIELD = 'bau50_ws_x_l'

# Fail fast: stop at the first required input that is absent
for pth in (AQUEDUCT_CSV, AQUEDUCT_GDB, PROJECTED_DC_DIR):
    if pth.exists():
        continue
    raise FileNotFoundError(f'Missing required path: {pth}')

# Echo the resolved inputs so the run records what was read
print('Aqueduct CSV:', AQUEDUCT_CSV)
print('Aqueduct GDB:', AQUEDUCT_GDB)
print('Projected data center dir:', PROJECTED_DC_DIR)


Aqueduct CSV: datasets/Aqueduct40_waterrisk_download_Y2023M07D05/CVS/Aqueduct40_future_annual_y2023m07d05.csv
Aqueduct GDB: datasets/Aqueduct40_waterrisk_download_Y2023M07D05/GDB/Aq40_Y2023D07M05.gdb
Projected data center dir: datasets/im3_projected_data_centers


In [17]:
# Load water stress attributes and geometry, then join on pfaf_id
# Only the join key and the three water-stress columns are read from the CSV.
water_attr = pd.read_csv(
    AQUEDUCT_CSV,
    usecols=['pfaf_id', WATER_STRESS_VALUE_FIELD, WATER_STRESS_CODE_FIELD, WATER_STRESS_LABEL_FIELD],
).copy()
# Cast the key to int64 on both sides so the merge compares identical dtypes.
water_attr['pfaf_id'] = water_attr['pfaf_id'].astype('int64')

water_geom = gpd.read_file(AQUEDUCT_GDB, layer='future_annual')[['pfaf_id', 'geometry']].copy()
water_geom['pfaf_id'] = water_geom['pfaf_id'].astype('int64')

# Left merge keeps every geometry row; polygons without a CSV match get NaN attributes.
water_gdf = water_geom.merge(water_attr, on='pfaf_id', how='left')
# Normalise to EPSG:4326: label the frame if it carries no CRS, otherwise reproject.
water_gdf = water_gdf.set_crs('EPSG:4326') if water_gdf.crs is None else water_gdf.to_crs('EPSG:4326')

# Contiguous US bounding box (lon/lat)
US_MINX, US_MAXX, US_MINY, US_MAXY = -125.0, -66.5, 24.0, 49.8
# `.cx` selects by bounding-box coordinates, so the result is a rectangular window,
# not a clip to the national border.
water_us = water_gdf.cx[US_MINX:US_MAXX, US_MINY:US_MAXY].copy()

# Build a clean outline from dissolved stress polygons in the US study window
# NOTE(review): the outline traces the outer edge of the selected basins, which
# presumably extends past the political US border — confirm if that matters.
us_outline = water_us.dissolve().boundary

# Discrete class palette for all Aqueduct stress classes
# `stress_order` fixes the draw and legend order; -1 is the "arid" class.
stress_order = [-1, 0, 1, 2, 3, 4]
stress_labels = {
    -1: 'Arid and low water use',
    0: 'Low (<10%)',
    1: 'Low-medium (10-20%)',
    2: 'Medium-high (20-40%)',
    3: 'High (40-80%)',
    4: 'Extremely high (>80%)',
}
stress_colors = {
    -1: '#e8efe3',
    0: '#b8e186',
    1: '#7fbc41',
    2: '#fdae61',
    3: '#f46d43',
    4: '#d73027',
}

# Quick sanity summary of what was loaded and which classes occur in the window.
print('Water polygons (global):', len(water_gdf))
print('Water polygons (US view):', len(water_us))
print('Stress levels present in US:')
print(water_us[WATER_STRESS_LABEL_FIELD].value_counts().to_string())


Water polygons (global): 16395
Water polygons (US view): 1303
Stress levels present in US:
bau50_ws_x_l
Low (<10%)                535
Extremely high (>80%)     231
Medium-high (20-40%)      157
Low-medium (10-20%)       129
Arid and low water use    118
High (40-80%)             115


In [18]:
# Discover every growth-projection folder and read its market-gravity GeoJSON variants
projection_dirs = sorted(d for d in PROJECTED_DC_DIR.iterdir() if d.is_dir())
if len(projection_dirs) == 0:
    raise FileNotFoundError(f'No projection folders found in {PROJECTED_DC_DIR}')

# Nested lookup: projection name -> {gravity (int) -> GeoDataFrame in EPSG:4326}
dc_by_projection_and_gravity = {}
for proj_dir in projection_dirs:
    projection = proj_dir.name
    geojson_files = sorted(proj_dir.glob(f'{projection}_*_market_gravity.geojson'))
    if len(geojson_files) == 0:
        continue

    by_gravity = {}
    for fp in geojson_files:
        # File stem layout: <projection>_<gravity>_market_gravity
        token = fp.stem.split('_')[-3]
        gravity = int(token)
        gdf = gpd.read_file(fp)
        if gdf.crs is None:
            gdf = gdf.set_crs('EPSG:4326')
        else:
            gdf = gdf.to_crs('EPSG:4326')
        by_gravity[gravity] = gdf

    if len(by_gravity) > 0:
        dc_by_projection_and_gravity[projection] = by_gravity

if len(dc_by_projection_and_gravity) == 0:
    raise FileNotFoundError('No projected data center GeoJSON files were loaded.')

# Widget option lists: all projections, and the union of gravities seen in any projection
projection_values = sorted(dc_by_projection_and_gravity)
gravity_values = sorted(set().union(*dc_by_projection_and_gravity.values()))

print('Available growth projections:', projection_values)
for proj in projection_values:
    gravities = sorted(dc_by_projection_and_gravity[proj])
    print(f'{proj}: market gravities {gravities}')


Available growth projections: ['high_growth', 'higher_growth', 'low_growth', 'moderate_growth']
high_growth: market gravities [0, 25, 50, 75, 100]
higher_growth: market gravities [0, 25, 50, 75, 100]
low_growth: market gravities [0, 25, 50, 75, 100]
moderate_growth: market gravities [0, 25, 50, 75, 100]


In [19]:
def plot_overlay(projection: str, gravity: int, site_size: int = 30, show_basin_boundaries: bool = True) -> None:
    """Draw the water-stress map of the US window with projected data center sites on top.

    Reads the notebook globals ``dc_by_projection_and_gravity``, ``water_us``,
    ``us_outline``, the ``stress_*`` palettes and the ``US_*`` bounds.

    Parameters
    ----------
    projection : str
        Key into ``dc_by_projection_and_gravity`` (a growth-projection name).
    gravity : int
        Market gravity variant; if it is not available for ``projection`` the
        smallest available gravity is used instead.
    site_size : int
        Marker size passed to the site scatter plots.
    show_basin_boundaries : bool
        When True, thin white basin boundary lines are drawn over the fill.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    available_gravities = sorted(dc_by_projection_and_gravity[projection].keys())
    # Fall back silently to the first available gravity rather than raising.
    if gravity not in dc_by_projection_and_gravity[projection]:
        gravity = available_gravities[0]

    dc_gdf = dc_by_projection_and_gravity[projection][int(gravity)].copy()
    dc_us = dc_gdf.cx[US_MINX:US_MAXX, US_MINY:US_MAXY].copy()

    fig, ax = plt.subplots(figsize=(14, 8.5), facecolor='#f7f6f2')
    ax.set_facecolor('#eaf2f7')

    # Draw all stress classes explicitly so no level disappears from the map
    for code in stress_order:
        subset = water_us[water_us[WATER_STRESS_CODE_FIELD] == code]
        if len(subset) == 0:
            continue
        subset.plot(
            ax=ax,
            color=stress_colors[code],
            linewidth=0,
            alpha=0.85,
        )

    if show_basin_boundaries:
        water_us.boundary.plot(ax=ax, color='#ffffff', linewidth=0.15, alpha=0.35)

    # US outline on top
    us_outline.plot(ax=ax, color='#1f2937', linewidth=1.2, alpha=0.95, zorder=5)

    # Represent sites as points and tag each with the local stress class
    sites = dc_us.copy()
    sites['geometry'] = sites.geometry.representative_point()

    # Polygons without a stress code are dropped, so sites inside them stay unmatched.
    water_join = water_us[[WATER_STRESS_CODE_FIELD, WATER_STRESS_LABEL_FIELD, 'geometry']].dropna(subset=[WATER_STRESS_CODE_FIELD])
    # Left join keeps every site; unmatched sites get NaN and end up in 'Unknown'.
    sites = gpd.sjoin(sites, water_join, how='left', predicate='within')
    sites = sites.rename(columns={WATER_STRESS_CODE_FIELD: 'site_stress_code', WATER_STRESS_LABEL_FIELD: 'site_stress_label'})

    # Site styling: lower stress vs medium stress vs higher stress
    def site_group(code):
        """Map a stress class code to one of the four marker groups."""
        if pd.isna(code):
            return 'Unknown'
        code = int(code)
        if code >= 3:
            return 'High/Extremely high stress'
        if code == 2:
            return 'Medium-high stress'
        # Codes -1, 0 and 1 all land here.
        return 'Low to low-medium stress'

    sites['site_group'] = sites['site_stress_code'].apply(site_group)

    site_styles = {
        'Low to low-medium stress': {'marker': 'o', 'color': '#0f766e'},
        'Medium-high stress': {'marker': '^', 'color': '#b45309'},
        'High/Extremely high stress': {'marker': 'X', 'color': '#b91c1c'},
        'Unknown': {'marker': 's', 'color': '#6b7280'},
    }

    # One scatter call per group because each group has its own marker shape.
    for group_name, style in site_styles.items():
        group = sites[sites['site_group'] == group_name]
        if len(group) == 0:
            continue
        group.plot(
            ax=ax,
            marker=style['marker'],
            color=style['color'],
            markersize=site_size,
            edgecolor='white',
            linewidth=0.4,
            alpha=0.95,
            zorder=8,
        )

    # Legends
    stress_handles = [
        Line2D([0], [0], marker='s', linestyle='None', markerfacecolor=stress_colors[c], markeredgecolor='none', markersize=10, label=stress_labels[c])
        for c in stress_order
    ]
    site_handles = [
        Line2D([0], [0], marker=v['marker'], linestyle='None', markerfacecolor=v['color'], markeredgecolor='white', markersize=9, label=k)
        for k, v in site_styles.items()
    ]

    # The second ax.legend() call replaces the first as the Axes legend, so the
    # first legend is re-attached with add_artist to keep both visible.
    leg1 = ax.legend(handles=stress_handles, title='Water stress levels', loc='lower left', frameon=True, framealpha=0.94, fontsize=9, title_fontsize=10)
    leg2 = ax.legend(handles=site_handles, title='Proposed site stress class', loc='upper left', frameon=True, framealpha=0.94, fontsize=9, title_fontsize=10)
    ax.add_artist(leg1)

    ax.set_xlim(US_MINX, US_MAXX)
    ax.set_ylim(US_MINY, US_MAXY)
    ax.set_title(
        f'US Water Stress (Aqueduct 2050 BAU50) with Proposed Data Centers | Projection: {projection} | Market Gravity {gravity}',
        fontsize=13,
        pad=12,
    )
    ax.set_axis_off()
    plt.tight_layout()
    plt.show()


In [None]:
# Interactive controls for the water-stress overlay.
widgets.interact(
    plot_overlay,
    projection=widgets.ToggleButtons(
        options=projection_values,
        # Projection names are folder names such as 'high_growth'; the previous
        # check for 'high' could never match and always fell through to the fallback.
        value='high_growth' if 'high_growth' in projection_values else projection_values[0],
        description='Projection',
    ),
    gravity=widgets.SelectionSlider(
        options=gravity_values,
        value=25 if 25 in gravity_values else gravity_values[0],
        description='Market gravity',
        continuous_update=False,
    ),
    site_size=widgets.IntSlider(
        value=30,
        min=12,
        max=80,
        step=2,
        description='Site size',
        continuous_update=False,
    ),
    show_basin_boundaries=widgets.Checkbox(
        value=True,
        description='Show basin lines',
    ),
);


interactive(children=(ToggleButtons(description='Projection', options=('high_growth', 'higher_growth', 'low_gr…

In [27]:
# Renewable subregion heatmap layer + projected data center overlay
import re
import zipfile
import xml.etree.ElementTree as ET

# eGRID inputs: the data workbook and the subregion boundary KMZ
EGRID_XLSX = ROOT / 'grid data/egrid2023_data_rev2 (2).xlsx'
EGRID_SUBREGION_KMZ = ROOT / 'grid data/egrid2023_subregions.kmz'

# Name the missing path(s) in the error, as the first cell's input check does,
# so the reader does not have to guess which file is absent.
_missing_egrid = [str(p) for p in (EGRID_XLSX, EGRID_SUBREGION_KMZ) if not p.exists()]
if _missing_egrid:
    raise FileNotFoundError(
        'Missing eGRID workbook or subregion KMZ in grid data folder: ' + ', '.join(_missing_egrid)
    )

def _col_to_idx(col_ref: str) -> int:
    n = 0
    for ch in col_ref:
        if ch.isalpha():
            n = n * 26 + (ord(ch.upper()) - 64)
    return n - 1

def read_xlsx_sheet_no_openpyxl(xlsx_path: Path, sheet_name: str) -> pd.DataFrame:
    """Read one worksheet of an .xlsx workbook by parsing its XML parts directly.

    This avoids a dependency on openpyxl. The second row present in the sheet
    supplies the column names (the eGRID field codes) and every later row
    becomes a data row. Cell values are returned as the raw strings stored in
    the file; callers convert numeric columns themselves.

    Parameters
    ----------
    xlsx_path : Path
        Workbook to open.
    sheet_name : str
        Worksheet name exactly as listed in ``xl/workbook.xml``.

    Returns
    -------
    pd.DataFrame
        One row per sheet row after the field-code row; cells absent from the
        XML are filled with ``''``.

    Raises
    ------
    KeyError
        If the workbook has no sheet called ``sheet_name``.
    ValueError
        If the sheet has fewer than two rows, so there is no field-code row.
    """
    ns = {
        'm': 'http://schemas.openxmlformats.org/spreadsheetml/2006/main',
        'r': 'http://schemas.openxmlformats.org/officeDocument/2006/relationships',
    }
    rel_tag = '{http://schemas.openxmlformats.org/package/2006/relationships}Relationship'
    rid_attr = '{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id'

    with zipfile.ZipFile(xlsx_path) as zf:
        wb = ET.fromstring(zf.read('xl/workbook.xml'))
        rels = ET.fromstring(zf.read('xl/_rels/workbook.xml.rels'))
        rel_map = {r.attrib['Id']: r.attrib['Target'] for r in rels.findall(rel_tag)}

        sheet_rid = None
        for sh in wb.findall('m:sheets/m:sheet', ns):
            if sh.attrib['name'] == sheet_name:
                sheet_rid = sh.attrib[rid_attr]
                break
        if sheet_rid is None:
            raise KeyError(f'Sheet {sheet_name!r} not found in {xlsx_path.name}')

        target = rel_map[sheet_rid]
        if target.startswith('/'):
            # Package-absolute target (e.g. '/xl/worksheets/sheet1.xml'):
            # zip member names carry no leading slash.
            target = target.lstrip('/')
        elif not target.startswith('xl/'):
            # Relative targets are resolved against the workbook part's folder.
            target = 'xl/' + target

        shared = []
        if 'xl/sharedStrings.xml' in zf.namelist():
            sroot = ET.fromstring(zf.read('xl/sharedStrings.xml'))
            for si in sroot.findall('m:si', ns):
                # A shared string may be split across several rich-text runs.
                shared.append(''.join((t.text or '') for t in si.findall('.//m:t', ns)))

        sroot = ET.fromstring(zf.read(target))
        rows = []
        for row in sroot.findall('.//m:sheetData/m:row', ns):
            rec = {}
            for c in row.findall('m:c', ns):
                ref = c.attrib.get('r', 'A1')
                idx = _col_to_idx(''.join(ch for ch in ref if ch.isalpha()))
                cell_type = c.attrib.get('t')

                if cell_type == 'inlineStr':
                    # Inline strings keep their text under <is>, not <v>.
                    inline = c.find('m:is', ns)
                    rec[idx] = '' if inline is None else ''.join(
                        (t.text or '') for t in inline.findall('.//m:t', ns)
                    )
                    continue

                v = c.find('m:v', ns)
                if v is None or v.text is None:
                    rec[idx] = ''
                    continue

                # 's' cells store an index into the shared-strings table.
                rec[idx] = shared[int(v.text)] if cell_type == 's' else v.text
            rows.append(rec)

    if len(rows) < 2:
        raise ValueError(
            f'Sheet {sheet_name!r} in {xlsx_path.name} has {len(rows)} row(s); '
            'expected field codes on row 2'
        )

    # Pad every row to the widest row so the matrix is rectangular.
    width = max(max(r.keys(), default=0) for r in rows) + 1
    matrix = [[r.get(i, '') for i in range(width)] for r in rows]

    col_names = matrix[1]  # eGRID field codes on row 2
    df = pd.DataFrame(matrix[2:], columns=col_names)
    return df

def extract_subrgn_from_description(desc: str) -> str | None:
    text = str(desc)
    match = re.search(r'Subregion\s*</td>\s*<td[^>]*>\s*([^<\s]+)', text, flags=re.IGNORECASE)
    if not match:
        return None
    code = match.group(1).strip()
    if code in {'&lt;Null&gt;', '<Null>', 'NULL', 'null'}:
        return None
    return code

# Build renewable-percentage layer by eGRID subregion
# Sheet 'SRL23' supplies the subregion code, its name and the SRTRPR column.
sr = read_xlsx_sheet_no_openpyxl(EGRID_XLSX, 'SRL23')[['SUBRGN', 'SRNAME', 'SRTRPR']].copy()
sr['SUBRGN'] = sr['SUBRGN'].astype(str).str.strip()
# The reader returns strings, so coerce to numbers (unparseable values become NaN).
# NOTE(review): the x100 presumably converts a 0-1 fraction to percent — confirm against the workbook.
sr['renewable_pct'] = pd.to_numeric(sr['SRTRPR'], errors='coerce') * 100.0

subregion_gdf = gpd.read_file(EGRID_SUBREGION_KMZ)[['description', 'geometry']].copy()
# Normalise to EPSG:4326: label the frame if it carries no CRS, otherwise reproject.
subregion_gdf = subregion_gdf.set_crs('EPSG:4326') if subregion_gdf.crs is None else subregion_gdf.to_crs('EPSG:4326')
# The subregion code is parsed out of each feature's HTML description.
subregion_gdf['SUBRGN'] = subregion_gdf['description'].map(extract_subrgn_from_description)
subregion_gdf = subregion_gdf.dropna(subset=['SUBRGN'])

# Left merge keeps every polygon; codes missing from the workbook get NaN renewable_pct.
renewables_gdf = subregion_gdf.merge(sr[['SUBRGN', 'SRNAME', 'renewable_pct']], on='SUBRGN', how='left')
renewables_us = renewables_gdf.cx[US_MINX:US_MAXX, US_MINY:US_MAXY].copy()
us_outline_renewables = renewables_us.dissolve().boundary

# Sanity summary: polygon count, value range and how many polygons lack a value.
print('Renewables subregions in US view:', len(renewables_us))
print('Renewables % range:', f"{renewables_us['renewable_pct'].min():.1f}% to {renewables_us['renewable_pct'].max():.1f}%")
print('Missing renewable % values:', int(renewables_us['renewable_pct'].isna().sum()))

def plot_renewables_overlay(projection: str, gravity: int, site_size: int = 30, show_subregion_lines: bool = True) -> None:
    """Draw renewable generation share by eGRID subregion with projected sites on top.

    Reads the notebook globals ``dc_by_projection_and_gravity``,
    ``renewables_us``, ``us_outline_renewables`` and the ``US_*`` bounds.

    Parameters
    ----------
    projection : str
        Key into ``dc_by_projection_and_gravity`` (a growth-projection name).
    gravity : int
        Market gravity variant; if it is not available for ``projection`` the
        smallest available gravity is used instead.
    site_size : int
        Marker size passed to the site scatter plots.
    show_subregion_lines : bool
        When True, white subregion boundary lines are drawn over the fill.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    available_gravities = sorted(dc_by_projection_and_gravity[projection].keys())
    # Fall back silently to the first available gravity rather than raising.
    if gravity not in dc_by_projection_and_gravity[projection]:
        gravity = available_gravities[0]

    dc_gdf = dc_by_projection_and_gravity[projection][int(gravity)].copy()
    dc_us = dc_gdf.cx[US_MINX:US_MAXX, US_MINY:US_MAXY].copy()

    fig, ax = plt.subplots(figsize=(14, 8.5), facecolor='#f7f6f2')
    ax.set_facecolor('#ecf7f2')

    # Fixed 0-100 colour scale so maps stay comparable between widget settings.
    renewables_us.plot(
        column='renewable_pct',
        cmap='YlGn',
        vmin=0,
        vmax=100,
        linewidth=0,
        alpha=0.88,
        legend=True,
        ax=ax,
        legend_kwds={'label': 'Renewable generation share (%)', 'shrink': 0.72},
    )

    if show_subregion_lines:
        renewables_us.boundary.plot(ax=ax, color='white', linewidth=0.45, alpha=0.9, zorder=5)

    us_outline_renewables.plot(ax=ax, color='#1f2937', linewidth=1.25, alpha=0.95, zorder=6)

    # Sites as representative points, classified by local renewable share
    sites = dc_us.copy()
    sites['geometry'] = sites.geometry.representative_point()
    # Left join keeps every site; unmatched sites get NaN and end up in 'Unknown'.
    sites = gpd.sjoin(sites, renewables_us[['renewable_pct', 'geometry']], how='left', predicate='within')

    def renewable_group(val):
        """Bucket a renewable percentage into low / medium / high, or 'Unknown' for NaN."""
        if pd.isna(val):
            return 'Unknown'
        if val >= 50:
            return 'High renewables (>=50%)'
        if val >= 30:
            return 'Medium renewables (30-50%)'
        return 'Low renewables (<30%)'

    sites['renew_group'] = sites['renewable_pct'].apply(renewable_group)

    site_styles = {
        'Low renewables (<30%)': {'marker': 'X', 'color': '#b91c1c'},
        'Medium renewables (30-50%)': {'marker': '^', 'color': '#b45309'},
        'High renewables (>=50%)': {'marker': 'o', 'color': '#0f766e'},
        'Unknown': {'marker': 's', 'color': '#6b7280'},
    }

    # One scatter call per group because each group has its own marker shape.
    for group_name, style in site_styles.items():
        group = sites[sites['renew_group'] == group_name]
        if len(group) == 0:
            continue
        group.plot(
            ax=ax,
            marker=style['marker'],
            color=style['color'],
            markersize=site_size,
            edgecolor='white',
            linewidth=0.45,
            alpha=0.95,
            zorder=8,
        )

    # The legend lists every group, including ones with no sites in this view.
    site_handles = [
        Line2D([0], [0], marker=v['marker'], linestyle='None', markerfacecolor=v['color'], markeredgecolor='white', markersize=9, label=k)
        for k, v in site_styles.items()
    ]
    ax.legend(handles=site_handles, title='Site renewable context', loc='upper left', frameon=True, framealpha=0.94, fontsize=9, title_fontsize=10)

    ax.set_xlim(US_MINX, US_MAXX)
    ax.set_ylim(US_MINY, US_MAXY)
    ax.set_title(
        f'US Renewable Generation Share by eGRID Subregion + Proposed Data Centers | Projection: {projection} | Market Gravity {gravity}',
        fontsize=13,
        pad=12,
    )
    ax.set_axis_off()
    plt.tight_layout()
    plt.show()

# Interactive controls for the renewables overlay. Keyword order sets the widget
# order; each default falls back to the first option when the preferred value
# is not available. The trailing ';' suppresses the cell's repr output.
widgets.interact(
    plot_renewables_overlay,
    projection=widgets.ToggleButtons(
        options=projection_values,
        value='high_growth' if 'high_growth' in projection_values else projection_values[0],
        description='Projection',
    ),
    gravity=widgets.SelectionSlider(
        options=gravity_values,
        value=25 if 25 in gravity_values else gravity_values[0],
        description='Market gravity',
        continuous_update=False,
    ),
    site_size=widgets.IntSlider(
        value=30,
        min=12,
        max=80,
        step=2,
        description='Site size',
        continuous_update=False,
    ),
    show_subregion_lines=widgets.Checkbox(
        value=True,
        description='Show subregion lines',
    ),
);


FileNotFoundError: Missing eGRID workbook or subregion KMZ in grid data folder.

In [28]:
# Unified map: toggle background layer (water stress vs renewables)
def plot_toggle_background(
    background: str,
    projection: str,
    gravity: int,
    site_size: int = 30,
    show_boundaries: bool = True,
):
    """Draw projected data center sites over a selectable background layer.

    Reads the notebook globals ``dc_by_projection_and_gravity``, ``water_us``,
    ``us_outline``, the ``stress_*`` palettes and the ``US_*`` bounds. The
    renewables background additionally needs ``renewables_us`` and
    ``us_outline_renewables``, which the eGRID cell defines.

    Parameters
    ----------
    background : str
        ``'Water stress'`` selects the water-stress classes; any other value
        selects the renewable-share layer.
    projection : str
        Key into ``dc_by_projection_and_gravity`` (a growth-projection name).
    gravity : int
        Market gravity variant; if it is not available for ``projection`` the
        smallest available gravity is used instead.
    site_size : int
        Marker size passed to the site scatter plots.
    show_boundaries : bool
        When True, polygon boundary lines of the background layer are drawn.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    available_gravities = sorted(dc_by_projection_and_gravity[projection].keys())
    # Fall back silently to the first available gravity rather than raising.
    if gravity not in dc_by_projection_and_gravity[projection]:
        gravity = available_gravities[0]

    dc_gdf = dc_by_projection_and_gravity[projection][int(gravity)].copy()
    dc_us = dc_gdf.cx[US_MINX:US_MAXX, US_MINY:US_MAXY].copy()

    fig, ax = plt.subplots(figsize=(14, 8.5), facecolor='#f7f6f2')

    # Sites are reduced to one representative point each before the spatial join.
    sites = dc_us.copy()
    sites['geometry'] = sites.geometry.representative_point()

    if background == 'Water stress':
        ax.set_facecolor('#eaf2f7')

        # Draw each stress class explicitly so no level disappears from the map.
        for code in stress_order:
            subset = water_us[water_us[WATER_STRESS_CODE_FIELD] == code]
            if len(subset) == 0:
                continue
            subset.plot(ax=ax, color=stress_colors[code], linewidth=0, alpha=0.85)

        if show_boundaries:
            water_us.boundary.plot(ax=ax, color='#ffffff', linewidth=0.15, alpha=0.35)
        us_outline.plot(ax=ax, color='#1f2937', linewidth=1.2, alpha=0.95, zorder=6)

        # Left join keeps every site; unmatched sites get NaN and end up in 'Unknown'.
        water_join = water_us[[WATER_STRESS_CODE_FIELD, WATER_STRESS_LABEL_FIELD, 'geometry']].dropna(subset=[WATER_STRESS_CODE_FIELD])
        sites = gpd.sjoin(sites, water_join, how='left', predicate='within')
        sites = sites.rename(columns={WATER_STRESS_CODE_FIELD: 'site_metric'})

        def site_group(v):
            """Map a stress class code to a marker group."""
            if pd.isna(v):
                return 'Unknown'
            v = int(v)
            if v >= 3:
                return 'High/Extremely high stress'
            if v == 2:
                return 'Medium-high stress'
            return 'Low to low-medium stress'

        sites['site_group'] = sites['site_metric'].apply(site_group)

        bg_handles = [
            Line2D([0], [0], marker='s', linestyle='None', markerfacecolor=stress_colors[c], markeredgecolor='none', markersize=10, label=stress_labels[c])
            for c in stress_order
        ]
        bg_title = 'Water stress levels'

    else:
        ax.set_facecolor('#ecf7f2')

        # Fixed 0-100 colour scale; the colour bar serves as the background legend.
        renewables_us.plot(
            column='renewable_pct',
            cmap='YlGn',
            vmin=0,
            vmax=100,
            linewidth=0,
            alpha=0.88,
            legend=True,
            ax=ax,
            legend_kwds={'label': 'Renewable generation share (%)', 'shrink': 0.72},
        )

        if show_boundaries:
            renewables_us.boundary.plot(ax=ax, color='white', linewidth=0.45, alpha=0.9, zorder=5)
        us_outline_renewables.plot(ax=ax, color='#1f2937', linewidth=1.25, alpha=0.95, zorder=6)

        sites = gpd.sjoin(sites, renewables_us[['renewable_pct', 'geometry']], how='left', predicate='within')

        def site_group(v):
            """Bucket a renewable percentage into a marker group."""
            if pd.isna(v):
                return 'Unknown'
            if v >= 50:
                return 'High renewables (>=50%)'
            if v >= 30:
                return 'Medium renewables (30-50%)'
            return 'Low renewables (<30%)'

        sites['site_group'] = sites['renewable_pct'].apply(site_group)

        bg_handles = []
        bg_title = ''

    site_styles = {
        'Low to low-medium stress': {'marker': 'o', 'color': '#0f766e'},
        'Medium-high stress': {'marker': '^', 'color': '#b45309'},
        'High/Extremely high stress': {'marker': 'X', 'color': '#b91c1c'},
        'Low renewables (<30%)': {'marker': 'X', 'color': '#b91c1c'},
        'Medium renewables (30-50%)': {'marker': '^', 'color': '#b45309'},
        'High renewables (>=50%)': {'marker': 'o', 'color': '#0f766e'},
        'Unknown': {'marker': 's', 'color': '#6b7280'},
    }

    # Only groups that actually occur are drawn and listed; the set is built once.
    groups_in_view = set(sites['site_group'])
    present_groups = [g for g in site_styles if g in groups_in_view]
    for group_name in present_groups:
        style = site_styles[group_name]
        sites[sites['site_group'] == group_name].plot(
            ax=ax,
            marker=style['marker'],
            color=style['color'],
            markersize=site_size,
            edgecolor='white',
            linewidth=0.45,
            alpha=0.95,
            zorder=8,
        )

    site_handles = [
        Line2D([0], [0], marker=site_styles[k]['marker'], linestyle='None', markerfacecolor=site_styles[k]['color'], markeredgecolor='white', markersize=9, label=k)
        for k in present_groups
    ]
    site_legend = ax.legend(handles=site_handles, title='Site classification', loc='upper left', frameon=True, framealpha=0.94, fontsize=9, title_fontsize=10)

    if bg_handles:
        # A second ax.legend() call replaces the Axes legend, so the site legend
        # must be re-attached as a plain artist or it disappears.
        ax.legend(handles=bg_handles, title=bg_title, loc='lower left', frameon=True, framealpha=0.94, fontsize=9, title_fontsize=10)
        ax.add_artist(site_legend)

    ax.set_xlim(US_MINX, US_MAXX)
    ax.set_ylim(US_MINY, US_MAXY)
    ax.set_title(
        f'US {background} Background + Proposed Data Centers | Projection: {projection} | Market Gravity {gravity}',
        fontsize=13,
        pad=12,
    )
    ax.set_axis_off()
    plt.tight_layout()
    plt.show()

# Interactive controls for the unified map. Keyword order sets the widget order;
# each default falls back to the first option when the preferred value is not
# available. The trailing ';' suppresses the cell's repr output.
widgets.interact(
    plot_toggle_background,
    background=widgets.ToggleButtons(
        options=['Water stress', 'Renewables %'],
        value='Water stress',
        description='Background',
    ),
    projection=widgets.ToggleButtons(
        options=projection_values,
        value='high_growth' if 'high_growth' in projection_values else projection_values[0],
        description='Projection',
    ),
    gravity=widgets.SelectionSlider(
        options=gravity_values,
        value=25 if 25 in gravity_values else gravity_values[0],
        description='Market gravity',
        continuous_update=False,
    ),
    site_size=widgets.IntSlider(
        value=30,
        min=12,
        max=80,
        step=2,
        description='Site size',
        continuous_update=False,
    ),
    show_boundaries=widgets.Checkbox(
        value=True,
        description='Show boundaries',
    ),
);


interactive(children=(ToggleButtons(description='Background', options=('Water stress', 'Renewables %'), value=…

In [None]:
# Sustainability-adjusted siting index (IM3 + water stress + renewables)
import re
import zipfile
import xml.etree.ElementTree as ET

# Weighting presets for the blended index.
#   im3_weight / env_weight : split between the normalised siting score and the
#                             environmental score (each pair sums to 1).
#   w_water / w_renew       : split inside the environmental score between the
#                             water and renewables terms (each pair sums to 1).
PRESETS = {
    'Economic-first': {'im3_weight': 0.8, 'env_weight': 0.2, 'w_water': 0.5, 'w_renew': 0.5},
    'Balanced': {'im3_weight': 0.7, 'env_weight': 0.3, 'w_water': 0.6, 'w_renew': 0.4},
    'Sustainability-first': {'im3_weight': 0.5, 'env_weight': 0.5, 'w_water': 0.7, 'w_renew': 0.3},
}

def _resolve_existing_path(candidates):
    for cand in candidates:
        if Path(cand).exists():
            return Path(cand)
    return None

def _col_to_idx_local(col_ref: str) -> int:
    n = 0
    for ch in col_ref:
        if ch.isalpha():
            n = n * 26 + (ord(ch.upper()) - 64)
    return n - 1

def _read_xlsx_sheet_no_openpyxl_local(xlsx_path: Path, sheet_name: str) -> pd.DataFrame:
    """Read one worksheet of an .xlsx workbook by parsing its XML parts directly.

    No openpyxl is needed. The second row present in the sheet supplies the
    column names and every later row becomes a data row. Cell values are
    returned as the raw strings stored in the file.

    Parameters
    ----------
    xlsx_path : Path
        Workbook to open.
    sheet_name : str
        Worksheet name exactly as listed in ``xl/workbook.xml``.

    Returns
    -------
    pd.DataFrame
        One row per sheet row after the header row; cells absent from the XML
        are filled with ``''``.

    Raises
    ------
    KeyError
        If the workbook has no sheet called ``sheet_name``.
    ValueError
        If the sheet has fewer than two rows, so there is no header row.
    """
    ns = {
        'm': 'http://schemas.openxmlformats.org/spreadsheetml/2006/main',
        'r': 'http://schemas.openxmlformats.org/officeDocument/2006/relationships',
    }
    rel_tag = '{http://schemas.openxmlformats.org/package/2006/relationships}Relationship'
    rid_attr = '{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id'

    with zipfile.ZipFile(xlsx_path) as zf:
        wb = ET.fromstring(zf.read('xl/workbook.xml'))
        rels = ET.fromstring(zf.read('xl/_rels/workbook.xml.rels'))
        rel_map = {r.attrib['Id']: r.attrib['Target'] for r in rels.findall(rel_tag)}

        sheet_rid = None
        for sh in wb.findall('m:sheets/m:sheet', ns):
            if sh.attrib['name'] == sheet_name:
                sheet_rid = sh.attrib[rid_attr]
                break
        if sheet_rid is None:
            raise KeyError(f'Sheet {sheet_name!r} not found in {xlsx_path.name}')

        target = rel_map[sheet_rid]
        if target.startswith('/'):
            # Package-absolute target (e.g. '/xl/worksheets/sheet1.xml'):
            # zip member names carry no leading slash.
            target = target.lstrip('/')
        elif not target.startswith('xl/'):
            # Relative targets are resolved against the workbook part's folder.
            target = 'xl/' + target

        shared = []
        if 'xl/sharedStrings.xml' in zf.namelist():
            sroot = ET.fromstring(zf.read('xl/sharedStrings.xml'))
            for si in sroot.findall('m:si', ns):
                # A shared string may be split across several rich-text runs.
                shared.append(''.join((t.text or '') for t in si.findall('.//m:t', ns)))

        sroot = ET.fromstring(zf.read(target))
        rows = []
        for row in sroot.findall('.//m:sheetData/m:row', ns):
            rec = {}
            for c in row.findall('m:c', ns):
                ref = c.attrib.get('r', 'A1')
                idx = _col_to_idx_local(''.join(ch for ch in ref if ch.isalpha()))
                cell_type = c.attrib.get('t')

                if cell_type == 'inlineStr':
                    # Inline strings keep their text under <is>, not <v>.
                    inline = c.find('m:is', ns)
                    rec[idx] = '' if inline is None else ''.join(
                        (t.text or '') for t in inline.findall('.//m:t', ns)
                    )
                    continue

                v = c.find('m:v', ns)
                if v is None or v.text is None:
                    rec[idx] = ''
                    continue

                # 's' cells store an index into the shared-strings table.
                rec[idx] = shared[int(v.text)] if cell_type == 's' else v.text
            rows.append(rec)

    if len(rows) < 2:
        raise ValueError(
            f'Sheet {sheet_name!r} in {xlsx_path.name} has {len(rows)} row(s); '
            'expected field codes on row 2'
        )

    # Pad every row to the widest row so the matrix is rectangular.
    width = max(max(r.keys(), default=0) for r in rows) + 1
    matrix = [[r.get(i, '') for i in range(width)] for r in rows]
    # Row 2 holds the field codes; data starts on row 3.
    return pd.DataFrame(matrix[2:], columns=matrix[1])

def _extract_subrgn_from_description_local(desc: str) -> str | None:
    text = str(desc)
    match = re.search(r'Subregion\s*</td>\s*<td[^>]*>\s*([^<\s]+)', text, flags=re.IGNORECASE)
    if not match:
        return None
    code = match.group(1).strip()
    if code in {'&lt;Null&gt;', '<Null>', 'NULL', 'null'}:
        return None
    return code

def _build_renewables_us_local() -> gpd.GeoDataFrame:
    """Build the eGRID subregion layer with a ``renewable_pct`` column for the US window.

    The workbook and the KMZ are each looked up in ``grid data/`` and then in
    ``datasets/``. If either file is missing, a warning is printed and an
    empty GeoDataFrame with the expected columns is returned instead.
    """
    workbook_path = _resolve_existing_path([
        ROOT / 'grid data/egrid2023_data_rev2 (2).xlsx',
        ROOT / 'datasets/egrid2023_data_rev2 (2).xlsx',
    ])
    boundaries_path = _resolve_existing_path([
        ROOT / 'grid data/egrid2023_subregions.kmz',
        ROOT / 'datasets/egrid2023_subregions.kmz',
    ])
    if not (workbook_path is not None and boundaries_path is not None):
        print('Warning: eGRID workbook/KMZ not found; renewables will be treated as neutral (0.5).')
        return gpd.GeoDataFrame(columns=['renewable_pct', 'SUBRGN', 'geometry'], geometry='geometry', crs='EPSG:4326')

    # Tabular side: subregion code, name and share, with the share scaled by 100.
    shares = _read_xlsx_sheet_no_openpyxl_local(workbook_path, 'SRL23')[['SUBRGN', 'SRNAME', 'SRTRPR']].copy()
    shares['SUBRGN'] = shares['SUBRGN'].astype(str).str.strip()
    shares['renewable_pct'] = pd.to_numeric(shares['SRTRPR'], errors='coerce') * 100.0

    # Spatial side: polygons whose subregion code is parsed from the description.
    regions = gpd.read_file(boundaries_path)[['description', 'geometry']].copy()
    if regions.crs is None:
        regions = regions.set_crs('EPSG:4326')
    else:
        regions = regions.to_crs('EPSG:4326')
    regions['SUBRGN'] = regions['description'].map(_extract_subrgn_from_description_local)
    regions = regions.dropna(subset=['SUBRGN'])

    joined = regions.merge(shares[['SUBRGN', 'SRNAME', 'renewable_pct']], on='SUBRGN', how='left')
    return joined.cx[US_MINX:US_MAXX, US_MINY:US_MAXY].copy()

# Reuse the layer from an earlier cell when it exists and is non-empty; otherwise build it here.
if globals().get('renewables_us') is None or len(renewables_us) == 0:
    renewables_us = _build_renewables_us_local()

def minmax_norm(s: pd.Series) -> pd.Series:
    """Rescale a series linearly to the 0-1 range.

    Values are coerced to numbers first (unparseable entries become NaN and
    stay NaN). If there are no valid values, or all valid values are equal,
    every entry is set to the neutral value 0.5.
    """
    values = pd.to_numeric(s, errors='coerce')
    lo = values.min()
    hi = values.max()
    degenerate = pd.isna(lo) or pd.isna(hi) or hi == lo
    if degenerate:
        return pd.Series(0.5, index=values.index)
    span = hi - lo
    return (values - lo) / span

def assign_water_level_from_code(code) -> float:
    """Map an Aqueduct stress class code to a level from 1 to 5, where 5 is worst.

    Codes of 0 or below map to 1, higher codes map to ``code + 1`` capped at 5,
    and a missing code gives NaN.
    """
    if pd.isna(code):
        return float('nan')
    level = int(code) + 1
    return 1 if level <= 1 else min(5, level)

def assign_renewable_level_from_pct(pct) -> float:
    """Map a renewable percentage to a level from 1 to 4, where 4 is best.

    Bands: below 20 -> 1, below 35 -> 2, below 50 -> 3, otherwise 4.
    A missing percentage gives NaN.
    """
    if pd.isna(pct):
        return float('nan')
    # Walk the upper bounds in ascending order; the first one not reached sets the level.
    for level, upper_bound in enumerate((20, 35, 50), start=1):
        if pct < upper_bound:
            return level
    return 4

def build_sustainability_index(
    projection: str,
    gravity: int,
    preset_name: str = 'Balanced',
    use_penalty_model: bool = False,
    penalty_p: float = 0.25,
    renew_bonus_b: float = 0.10,
) -> pd.DataFrame:
    """Score and rank projected sites by siting score, water stress and renewables.

    Sites for the chosen projection and gravity are limited to the US window,
    reduced to representative points, and spatially joined to the water-stress
    polygons and (when available) the renewables polygons.

    Parameters
    ----------
    projection : str
        Key into ``dc_by_projection_and_gravity``.
    gravity : int
        Market gravity variant for that projection.
    preset_name : str
        Key into ``PRESETS`` giving the blend weights.
    use_penalty_model : bool
        False: ``final_index = im3_weight * im3_norm + env_weight * env_score``.
        True: ``final_index = im3_norm * (1 - penalty_p * water_bad)
        + renew_bonus_b * renew_good``.
    penalty_p : float
        Water penalty strength (penalty model only).
    renew_bonus_b : float
        Renewables bonus strength (penalty model only).

    Returns
    -------
    pd.DataFrame
        One row per joined site, sorted by ``final_index`` descending, with the
        intermediate score columns, ``model_type`` and an integer ``rank``.
        Sites whose ``final_index`` is NaN are ranked and sorted last.

    Raises
    ------
    KeyError
        If ``projection``, ``gravity`` or ``preset_name`` is unknown, or the
        site data lacks ``weighted_siting_score``.
    """
    cfg = PRESETS[preset_name]

    dc = dc_by_projection_and_gravity[projection][int(gravity)].copy()
    dc = dc.cx[US_MINX:US_MAXX, US_MINY:US_MAXY].copy()
    dc['geometry'] = dc.geometry.representative_point()

    water_join = water_us[[WATER_STRESS_CODE_FIELD, WATER_STRESS_LABEL_FIELD, 'geometry']].dropna(subset=[WATER_STRESS_CODE_FIELD])
    renew_join = renewables_us[['renewable_pct', 'SUBRGN', 'geometry']].dropna(subset=['renewable_pct']) if len(renewables_us) else renewables_us

    out = gpd.sjoin(dc, water_join, how='left', predicate='within')
    # Drop the join bookkeeping column so the second sjoin can add its own.
    out = out.drop(columns=['index_right'], errors='ignore')

    if len(renew_join):
        out = gpd.sjoin(out, renew_join, how='left', predicate='within')
        out = out.drop(columns=['index_right'], errors='ignore')
    else:
        # No renewables layer: leave the metric missing so it is scored as neutral below.
        out['renewable_pct'] = float('nan')
        out['SUBRGN'] = None

    out['im3_norm'] = minmax_norm(out['weighted_siting_score'])
    out['water_stress_level'] = out[WATER_STRESS_CODE_FIELD].apply(assign_water_level_from_code)
    out['renewable_level'] = out['renewable_pct'].apply(assign_renewable_level_from_pct)

    # Rescale both levels to 0-1 "goodness" (1 = least stressed / most renewable).
    out['water_good'] = (5 - out['water_stress_level']) / 4
    out['renew_good'] = (out['renewable_level'] - 1) / 3

    # Sites without a matched polygon are scored as neutral rather than dropped.
    out['water_good'] = out['water_good'].fillna(0.5)
    out['renew_good'] = out['renew_good'].fillna(0.5)

    out['env_score'] = cfg['w_water'] * out['water_good'] + cfg['w_renew'] * out['renew_good']

    if use_penalty_model:
        out['water_bad'] = (out['water_stress_level'] - 1) / 4
        out['water_bad'] = out['water_bad'].fillna(0.5)
        out['final_index'] = out['im3_norm'] * (1 - penalty_p * out['water_bad']) + renew_bonus_b * out['renew_good']
        out['model_type'] = f'Penalty model (p={penalty_p:.2f}, b={renew_bonus_b:.2f})'
    else:
        out['final_index'] = cfg['im3_weight'] * out['im3_norm'] + cfg['env_weight'] * out['env_score']
        out['model_type'] = (
            f"Weighted blend (IM3={cfg['im3_weight']:.1f}, Env={cfg['env_weight']:.1f}; "
            f"Water={cfg['w_water']:.1f}, Renew={cfg['w_renew']:.1f})"
        )

    # na_option='bottom' gives NaN scores the last ranks; with the default
    # ('keep') their rank is NaN and the integer cast raises.
    out['rank'] = out['final_index'].rank(method='min', ascending=False, na_option='bottom').astype(int)
    return out.sort_values('final_index', ascending=False).reset_index(drop=True)

def preview_sustainability_index(
    projection: str,
    gravity: int,
    preset_name: str = 'Balanced',
    use_penalty_model: bool = False,
    penalty_p: float = 0.25,
    renew_bonus_b: float = 0.10,
    top_n: int = 20,
):
    """Print a short summary of the sustainability index and show its top rows.

    All scoring arguments are forwarded unchanged to
    ``build_sustainability_index``; this function only reports on the result.

    Parameters
    ----------
    projection, gravity:
        Scenario selectors passed through to ``build_sustainability_index``.
    preset_name:
        Key of the weighting preset to use.
    use_penalty_model, penalty_p, renew_bonus_b:
        Switch and parameters for the penalty-style model.
    top_n:
        Number of leading rows of the ranked table to display.

    Returns
    -------
    None
        Output is printed / displayed as a side effect.
    """
    res = build_sustainability_index(
        projection=projection,
        gravity=gravity,
        preset_name=preset_name,
        use_penalty_model=use_penalty_model,
        penalty_p=penalty_p,
        renew_bonus_b=renew_bonus_b,
    )

    if res.empty:
        # `.iloc[0]` below would raise IndexError on an empty frame, which
        # would surface as a traceback inside the widget output.
        print('Projection:', projection, '| Gravity:', gravity, '| Rows:', 0)
        print('No candidate sites available for this selection.')
        return

    print('Model:', res['model_type'].iloc[0])
    print('Projection:', projection, '| Gravity:', gravity, '| Rows:', len(res))
    print('Caution: IM3 may already encode water/cooling effects, so high env weights can double-count.')
    print('Final index summary:', {
        'min': round(float(res['final_index'].min()), 4),
        'mean': round(float(res['final_index'].mean()), 4),
        'max': round(float(res['final_index'].max()), 4),
    })

    # Only show the columns that actually exist in the result, in this order.
    cols = [
        'rank', 'id', 'final_index', 'region', 'weighted_siting_score', 'im3_norm',
        'water_stress_level', 'renewable_level', 'water_good', 'renew_good',
        'env_score'
    ]
    cols = [c for c in cols if c in res.columns]
    display(res[cols].head(int(top_n)))

# Build an interactive control panel for preview_sustainability_index.
# Each widget is passed under the keyword name of the function parameter it drives.
# projection_values, gravity_values and PRESETS are defined earlier in the notebook.
widgets.interact(
    preview_sustainability_index,
    # Growth projection; starts on 'high_growth' when available, else the first option.
    projection=widgets.ToggleButtons(
        options=projection_values,
        value='high_growth' if 'high_growth' in projection_values else projection_values[0],
        description='Projection',
    ),
    # Market gravity scenario; starts on 25 when available, else the first option.
    gravity=widgets.SelectionSlider(
        options=gravity_values,
        value=25 if 25 in gravity_values else gravity_values[0],
        description='Market gravity',
        continuous_update=False,
    ),
    # Weighting preset, chosen by its key in PRESETS.
    preset_name=widgets.Dropdown(
        options=list(PRESETS.keys()),
        value='Balanced',
        description='Preset',
    ),
    # Switch between the weighted blend (unchecked) and the penalty model (checked).
    use_penalty_model=widgets.Checkbox(
        value=False,
        description='Use penalty model',
    ),
    # Penalty-model parameters (forwarded as penalty_p / renew_bonus_b).
    penalty_p=widgets.FloatSlider(
        value=0.25,
        min=0.0,
        max=0.6,
        step=0.05,
        description='Penalty p',
        continuous_update=False,
    ),
    renew_bonus_b=widgets.FloatSlider(
        value=0.10,
        min=0.0,
        max=0.3,
        step=0.02,
        description='Renew b',
        continuous_update=False,
    ),
    # How many top-ranked rows to display.
    top_n=widgets.IntSlider(
        value=20,
        min=5,
        max=100,
        step=5,
        description='Top N',
        continuous_update=False,
    ),
);  # trailing ';' keeps the cell from echoing the call's return value


interactive(children=(ToggleButtons(description='Projection', options=('high_growth', 'higher_growth', 'low_gr…

In [47]:
# Polygons by environmental score + points by sustainability score (robust rendering)
def plot_env_polygons_and_site_sustainability(
    projection: str,
    gravity: int,
    preset_name: str = 'Balanced',
    use_penalty_model: bool = False,
    penalty_p: float = 0.25,
    renew_bonus_b: float = 0.10,
    site_size: int = 24,
    label_regions: bool = False,
):
    """Draw region polygons coloured by environmental score with site points on top.

    Candidate sites for the chosen scenario are clipped to the US bounding box,
    reduced to representative points and spatially joined to the water-stress
    polygons (and to the renewables polygons when those are available).  Each
    site gets an ``env_score`` (preset-weighted mix of ``water_good`` and
    ``renew_good``) and a ``final_index`` (weighted blend or penalty model).
    Regions are the site geometries dissolved by ``region`` and are coloured by
    the mean ``env_score`` of their sites; points are coloured by ``final_index``.

    Parameters
    ----------
    projection, gravity:
        Keys into ``dc_by_projection_and_gravity`` selecting the scenario.
    preset_name:
        Key into ``PRESETS`` supplying the weights.
    use_penalty_model, penalty_p, renew_bonus_b:
        Switch and parameters for the penalty-style ``final_index``.
    site_size:
        Marker size passed to ``Axes.scatter``.
    label_regions:
        When true, write each region name at a representative point.

    Returns
    -------
    None
        The figure is shown via ``plt.show()``.
    """
    cfg = PRESETS[preset_name]

    # Scenario geometries clipped to the US bounding box. Computed once and
    # reused both for the site points and for the dissolved region polygons.
    site_shapes = dc_by_projection_and_gravity[projection][int(gravity)].copy()
    site_shapes = site_shapes.cx[US_MINX:US_MAXX, US_MINY:US_MAXY].copy()

    # Start from projected candidate sites, reduced to one point each
    dc = site_shapes.copy()
    dc['geometry'] = dc.geometry.representative_point()

    # Attach water class and renewables %
    water_join = water_us[[WATER_STRESS_CODE_FIELD, 'geometry']].dropna(subset=[WATER_STRESS_CODE_FIELD])
    scored = gpd.sjoin(dc, water_join, how='left', predicate='within').drop(columns=['index_right'], errors='ignore')

    # The renewables layer is optional: it may be undefined, None or empty.
    renew_join = None
    if 'renewables_us' in globals() and renewables_us is not None and len(renewables_us):
        renew_join = renewables_us[['renewable_pct', 'geometry']].dropna(subset=['renewable_pct'])

    if renew_join is not None and len(renew_join):
        scored = gpd.sjoin(scored, renew_join, how='left', predicate='within').drop(columns=['index_right'], errors='ignore')
    else:
        # Float NaN (not pd.NA) keeps the column numeric and matches the
        # fallback used by build_sustainability_index.
        scored['renewable_pct'] = float('nan')

    # Environmental-only score components
    scored['water_stress_level'] = scored[WATER_STRESS_CODE_FIELD].apply(assign_water_level_from_code)
    scored['water_good'] = (5 - scored['water_stress_level']) / 4

    scored['renewable_level'] = scored['renewable_pct'].apply(assign_renewable_level_from_pct)
    scored['renew_good'] = (scored['renewable_level'] - 1) / 3

    # Robust fallback so polygons never go blank from NaNs: unknown -> neutral 0.5
    scored['water_good'] = pd.to_numeric(scored['water_good'], errors='coerce').fillna(0.5)
    scored['renew_good'] = pd.to_numeric(scored['renew_good'], errors='coerce').fillna(0.5)

    scored['env_score'] = cfg['w_water'] * scored['water_good'] + cfg['w_renew'] * scored['renew_good']

    # Sustainability score for points
    scored['im3_norm'] = minmax_norm(scored['weighted_siting_score'])
    if use_penalty_model:
        scored['water_bad'] = ((scored['water_stress_level'] - 1) / 4).fillna(0.5)
        scored['final_index'] = scored['im3_norm'] * (1 - penalty_p * scored['water_bad']) + renew_bonus_b * scored['renew_good']
    else:
        scored['final_index'] = cfg['im3_weight'] * scored['im3_norm'] + cfg['env_weight'] * scored['env_score']

    # Region polygons + environmental score (polygon heatmap)
    region_env = (
        scored.groupby('region', dropna=False)['env_score']
        .mean()
        .rename('region_environment_score')
        .reset_index()
    )

    dc_region_geom = site_shapes[['region', 'geometry']].copy()
    region_geom = dc_region_geom.dissolve(by='region', as_index=False)
    region_map = region_geom.merge(region_env, on='region', how='left')
    region_map['region_environment_score'] = pd.to_numeric(region_map['region_environment_score'], errors='coerce').fillna(0.5)

    fig, ax = plt.subplots(figsize=(14, 8.5), facecolor='#f7f6f2')
    ax.set_facecolor('#eef6ef')

    region_map.plot(
        column='region_environment_score',
        cmap='RdYlGn',
        linewidth=0.8,
        edgecolor='#f5f5f5',
        alpha=0.95,
        legend=True,
        ax=ax,
        vmin=0,
        vmax=1,
        legend_kwds={'label': 'Combined environmental score (region mean)', 'shrink': 0.75},
        zorder=2,
    )

    # Site points heatmapped by sustainability score.
    # NOTE(review): the colour scale is fixed to [0, 1]; with the penalty model
    # final_index can exceed 1 and would then saturate at the top colour.
    pts = scored.dropna(subset=['geometry', 'final_index']).copy()
    sc = ax.scatter(
        pts.geometry.x, pts.geometry.y,
        c=pts['final_index'],
        cmap='plasma',
        vmin=0, vmax=1,
        s=site_size,
        edgecolors='white',
        linewidths=0.3,
        alpha=0.9,
        zorder=7,
    )
    cbar = fig.colorbar(sc, ax=ax, fraction=0.025, pad=0.01)
    cbar.set_label('Site sustainability score (final_index)')

    # The national outline is optional, like the renewables layer.
    if 'us_outline' in globals() and us_outline is not None:
        us_outline.plot(ax=ax, color='#1f2937', linewidth=1.2, alpha=0.95, zorder=9)

    if label_regions:
        labels = region_map.copy()
        labels['geometry'] = labels.geometry.representative_point()
        for _, row in labels.iterrows():
            # Empty geometries have no coordinates, so skip them along with None.
            if row.geometry is None or row.geometry.is_empty:
                continue
            ax.text(row.geometry.x, row.geometry.y, str(row['region']).title(), fontsize=7, ha='center', va='center', color='#1f2937', zorder=10)

    ax.set_xlim(US_MINX, US_MAXX)
    ax.set_ylim(US_MINY, US_MAXY)
    ax.set_title(
        f'Environmental Polygon Heatmap + Site Sustainability Points | Projection: {projection} | Market Gravity {gravity}',
        fontsize=13,
        pad=12,
    )
    ax.set_axis_off()
    plt.tight_layout()
    plt.show()

# Build an interactive control panel for plot_env_polygons_and_site_sustainability.
# Each widget is passed under the keyword name of the function parameter it drives.
# projection_values, gravity_values and PRESETS are defined earlier in the notebook.
widgets.interact(
    plot_env_polygons_and_site_sustainability,
    # Growth projection; starts on 'high_growth' when available, else the first option.
    projection=widgets.ToggleButtons(
        options=projection_values,
        value='high_growth' if 'high_growth' in projection_values else projection_values[0],
        description='Projection',
    ),
    # Market gravity scenario; starts on 25 when available, else the first option.
    gravity=widgets.SelectionSlider(
        options=gravity_values,
        value=25 if 25 in gravity_values else gravity_values[0],
        description='Market gravity',
        continuous_update=False,
    ),
    # Weighting preset, chosen by its key in PRESETS.
    preset_name=widgets.Dropdown(
        options=list(PRESETS.keys()),
        value='Balanced',
        description='Preset',
    ),
    # Switch between the weighted blend (unchecked) and the penalty model (checked).
    use_penalty_model=widgets.Checkbox(
        value=False,
        description='Use penalty model',
    ),
    # Penalty-model parameters (forwarded as penalty_p / renew_bonus_b).
    penalty_p=widgets.FloatSlider(
        value=0.25,
        min=0.0,
        max=0.6,
        step=0.05,
        description='Penalty p',
        continuous_update=False,
    ),
    renew_bonus_b=widgets.FloatSlider(
        value=0.10,
        min=0.0,
        max=0.3,
        step=0.02,
        description='Renew b',
        continuous_update=False,
    ),
    # Marker size for the site points (forwarded as site_size).
    site_size=widgets.IntSlider(
        value=24,
        min=8,
        max=80,
        step=2,
        description='Site size',
        continuous_update=False,
    ),
    # Toggle region name labels on the map.
    label_regions=widgets.Checkbox(
        value=False,
        description='Label regions',
    ),
);  # trailing ';' keeps the cell from echoing the call's return value


interactive(children=(ToggleButtons(description='Projection', options=('high_growth', 'higher_growth', 'low_gr…

In [34]:
# Export drought-stress CSVs (one file per growth scenario, no renewables column)
def _water_level_from_code(code):
    # Aqueduct class code to 1..5 drought factor (5 = worst)
    if pd.isna(code):
        return pd.NA
    c = int(code)
    return 1 if c <= 0 else min(5, c + 1)

# Destination folder for the per-scenario CSVs; created if it does not exist.
export_dir = ROOT / 'exports' / 'projected_site_drought_stress'
export_dir.mkdir(parents=True, exist_ok=True)

# Water-stress polygons with a known class code. This does not depend on the
# projection or gravity, so build it once instead of once per inner iteration.
water_join = water_us[[WATER_STRESS_CODE_FIELD, 'geometry']].dropna(subset=[WATER_STRESS_CODE_FIELD])

for projection in sorted(dc_by_projection_and_gravity.keys()):
    rows = []
    for gravity in sorted(dc_by_projection_and_gravity[projection].keys()):
        # Candidate sites for this scenario, clipped to the US bounding box and
        # reduced to one representative point each for the point-in-polygon join.
        dc = dc_by_projection_and_gravity[projection][gravity].copy()
        dc = dc.cx[US_MINX:US_MAXX, US_MINY:US_MAXY].copy()
        dc['geometry'] = dc.geometry.representative_point()

        # Left join keeps sites that fall outside every water polygon (code stays missing).
        out = gpd.sjoin(dc, water_join, how='left', predicate='within')
        out = out.drop(columns=['index_right'], errors='ignore')

        # Nullable Int64 so sites without a code are written as empty cells.
        out['drought_stress_factor_1_to_5'] = out[WATER_STRESS_CODE_FIELD].apply(_water_level_from_code).astype('Int64')

        keep = out[['id', 'drought_stress_factor_1_to_5']].copy()
        keep['growth_scenario'] = projection
        keep['market_gravity'] = gravity
        rows.append(keep)

    # One CSV per growth scenario, stacking all market-gravity variants.
    scenario_df = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame(columns=['id', 'drought_stress_factor_1_to_5', 'growth_scenario', 'market_gravity'])
    scenario_path = export_dir / f'{projection}_projected_sites_drought_stress.csv'
    scenario_df.to_csv(scenario_path, index=False)
    print(f'Wrote {scenario_path} ({len(scenario_df)} rows)')

print('Done. Export folder:', export_dir)


Wrote exports/projected_site_drought_stress/high_growth_projected_sites_drought_stress.csv (4545 rows)
Wrote exports/projected_site_drought_stress/higher_growth_projected_sites_drought_stress.csv (9350 rows)
Wrote exports/projected_site_drought_stress/low_growth_projected_sites_drought_stress.csv (1110 rows)
Wrote exports/projected_site_drought_stress/moderate_growth_projected_sites_drought_stress.csv (1640 rows)
Done. Export folder: exports/projected_site_drought_stress


In [46]:
# Export a US-only annual Aqueduct water stress CSV (smaller file)
# Reuse the Aqueduct paths declared in the notebook's configuration cell rather
# than repeating the same literals here, so the locations stay in one place.
annual_csv = AQUEDUCT_CSV
annual_gdb = AQUEDUCT_GDB

if not annual_csv.exists() or not annual_gdb.exists():
    raise FileNotFoundError('Missing annual CSV or GDB path for Aqueduct data.')

# Use basin geometry extent to identify US pfaf_id values:
# read the basin polygons, normalise to EPSG:4326, then keep those that
# intersect the US bounding box.
us_pfaff = gpd.read_file(annual_gdb, layer='future_annual')[['pfaf_id', 'geometry']].copy()
us_pfaff = us_pfaff.set_crs('EPSG:4326') if us_pfaff.crs is None else us_pfaff.to_crs('EPSG:4326')
us_pfaff = us_pfaff.cx[US_MINX:US_MAXX, US_MINY:US_MAXY]
us_pfaff_ids = set(pd.to_numeric(us_pfaff['pfaf_id'], errors='coerce').dropna().astype('int64').tolist())

# Full attribute table; pfaf_id is coerced to nullable Int64 so unparsable
# ids become missing instead of failing the membership test.
annual_df = pd.read_csv(annual_csv)
annual_df['pfaf_id'] = pd.to_numeric(annual_df['pfaf_id'], errors='coerce').astype('Int64')
waterstress_us = annual_df[annual_df['pfaf_id'].isin(us_pfaff_ids)].copy()

# Written next to the source CSV, in the same folder.
out_path = annual_csv.with_name('waterstress_us.csv')
waterstress_us.to_csv(out_path, index=False)

print('Wrote:', out_path)
print('Rows:', len(waterstress_us), '| Columns:', len(waterstress_us.columns))
print('Original rows:', len(annual_df), '| US rows:', len(waterstress_us))


Wrote: datasets/Aqueduct40_waterrisk_download_Y2023M07D05/CVS/waterstress_us.csv
Rows: 1303 | Columns: 182
Original rows: 16395 | US rows: 1303
