In [1]:
import pandas as pd
import geopandas as gpd

In [12]:
# Load the lag analysis data
lag_analysis_cal = pd.read_csv('../data/features/q_buffer2_pair_lag.csv')

# Run the watershed correlation analysis with inset maps
# run_watershed_correlation_analysis()

In [21]:
no_lag = pd.read_csv('../data/features/q_buffer2_pair_delta_30m.csv')
no_lag.head()

Unnamed: 0,well_id,date,wte,gse,gage_id,well_lat,well_lon,gage_lat,gage_lon,wte_meters,q,bfd,delta_wte,delta_q
0,411605111481601,1932-08-08,4833.81,4840.0,10141000,41.267997,-111.805216,41.278277,-112.091887,1473.345288,38.0,1.0,0.0,0.0
1,411605111481601,1932-08-09,4833.807273,4840.0,10141000,41.267997,-111.805216,41.278277,-112.091887,1473.344457,38.0,1.0,-0.002727,0.0
2,411605111481601,1932-08-10,4833.80463,4840.0,10141000,41.267997,-111.805216,41.278277,-112.091887,1473.343651,34.0,1.0,-0.00537,-4.0
3,411605111481601,1932-08-11,4833.80207,4840.0,10141000,41.267997,-111.805216,41.278277,-112.091887,1473.342871,9.0,1.0,-0.00793,-29.0
4,411605111481601,1932-08-12,4833.799591,4840.0,10141000,41.267997,-111.805216,41.278277,-112.091887,1473.342115,9.0,1.0,-0.010409,-29.0


In [2]:
subbasin_gdf = gpd.read_file('../data/raw/hydrography/gsl_catchment.shp')
gage_df = pd.read_csv('../data/raw/hydrography/gsl_nwm_gage.csv')
well_gdf = gpd.read_file('../data/raw/hydrography/well_shp.shp')
stream_gdf = gpd.read_file('../data/raw/hydrography/gslb_stream.shp')
lake_gdf = gpd.read_file('../data/raw/hydrography/lake.shp')

In [7]:
reach_distances = pd.read_csv('../data/processed/well_reach_relationships_final.csv').rename(columns=str.lower)
reach_distances.head()


Unnamed: 0,well_id,reach_id,reach_elevation,distance_to_reach,downstream_gage
0,381033100000000.0,710579638.0,2018.0,15.997118,
1,381037100000000.0,710579638.0,2018.0,387.946514,
2,381152100000000.0,710579638.0,2018.0,642.085536,
3,381236100000000.0,710258231.0,2033.5,95.71774,
4,382113100000000.0,710549872.0,1770.0,375.764745,


In [13]:
well_gage = pd.read_csv('../data/processed/wells_with_catchment_info.csv')
well_gage.head()

Unnamed: 0,well_id,well_name,lat_dec,long_dec,gse,geometry,catchment_id,gage_id,gage_name
0,394618100000000.0,(D-12- 4)16ccb- 1,39.771627,-111.488248,5980.0,POINT (-111.4882484 39.7716267),710648999.0,10152000.0,SPANISH FORK NEAR LAKE SHORE - UTAH
1,394634100000000.0,(D-12- 4)16bcc- 1,39.776071,-111.488248,5960.0,POINT (-111.4882483 39.7760711),710647014.0,10152000.0,SPANISH FORK NEAR LAKE SHORE - UTAH
2,394643100000000.0,(D-12- 4)16bcb- 1,39.778571,-111.48797,5955.0,POINT (-111.4879704 39.77857106),710647014.0,10152000.0,SPANISH FORK NEAR LAKE SHORE - UTAH
3,394649100000000.0,(D-12- 4)17abd- 1,39.780238,-111.494915,5938.0,POINT (-111.4949154 39.78023766),710549798.0,10152000.0,SPANISH FORK NEAR LAKE SHORE - UTAH
4,394746100000000.0,(D-12- 4) 9bab- 1,39.796071,-111.48297,5977.0,POINT (-111.4829697 39.79607075),710647014.0,10152000.0,SPANISH FORK NEAR LAKE SHORE - UTAH


# plot the top 5 r squared wells

In [9]:
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
from scipy.stats import linregress
import warnings

warnings.filterwarnings('ignore')

# Set matplotlib style for clean, minimal maps
plt.style.use('default')
plt.rcParams['font.family'] = ['Arial', 'DejaVu Sans', 'Liberation Sans']
plt.rcParams['font.size'] = 10


def calculate_well_gage_correlations(lag_data, min_points=10):
    """
    Calculate R-squared values between each well and its associated gage
    using a lagged relationship: delta_q(t) vs delta_wte_lag1_year(t).

    This function:
    - Drops NaNs in both variables
    - Skips pairs with constant X or Y (to avoid linregress errors)
    - Returns a DataFrame with one row per well–gage pair
    """
    print("=== Calculating Well-Gage Correlations (ΔQ vs ΔWTE_lag1yr) ===")

    # Group by well and gage to calculate R-squared
    correlation_results = []
    grouped = lag_data.groupby(['well_id', 'gage_id'])
    print(f"Processing {len(grouped)} well-gage pairs...")

    skipped_low_n = 0
    skipped_constant_x = 0
    skipped_constant_y = 0

    for (well_id, gage_id), group in tqdm(grouped, desc="Calculating correlations"):
        # Remove NaN values in the variables used
        clean_data = group.dropna(subset=['delta_wte_lag1_year', 'delta_q'])

        # Enforce minimum observations
        if len(clean_data) < min_points:
            skipped_low_n += 1
            continue

        # Require variability in X and Y
        if clean_data['delta_wte_lag1_year'].nunique() <= 1:
            skipped_constant_x += 1
            continue
        if clean_data['delta_q'].nunique() <= 1:
            skipped_constant_y += 1
            continue

        try:
            # Linear regression of ΔQ(t) on ΔWTE(t-1yr)
            slope, intercept, r_value, p_value, std_err = linregress(
                clean_data['delta_wte_lag1_year'],
                clean_data['delta_q']
            )
            r_squared = r_value ** 2

            # Extract coordinates (assumed constant within the group)
            well_lat = clean_data['well_lat'].iloc[0]
            well_lon = clean_data['well_lon'].iloc[0]
            gage_lat = clean_data['gage_lat'].iloc[0]
            gage_lon = clean_data['gage_lon'].iloc[0]

            correlation_results.append({
                'well_id': well_id,
                'gage_id': gage_id,
                'r_squared': r_squared,
                'r_value': r_value,
                'p_value': p_value,
                'n_observations': len(clean_data),
                'well_lat': well_lat,
                'well_lon': well_lon,
                'gage_lat': gage_lat,
                'gage_lon': gage_lon
            })
        except Exception as e:
            # Defensive: this should be very rare after checks above
            print(f"Error calculating correlation for well {well_id}, gage {gage_id}: {e}")

    correlation_df = pd.DataFrame(correlation_results)
    print(f"Successfully calculated correlations for {len(correlation_df)} well-gage pairs")
    print(f"Skipped (n<{min_points}): {skipped_low_n} | constant X: {skipped_constant_x} | constant Y: {skipped_constant_y}")

    return correlation_df


def load_watershed_data():
    """Load watershed relationship data (terminal gage → upstream catchments)."""
    try:
        df = pd.read_csv('../data/processed/terminal_gage_upstream_catchments.csv')

        # Standardize column names
        rename_map = {}
        if 'Gage_ID' in df.columns:
            rename_map['Gage_ID'] = 'gage_id'
        if 'Upstream_Catchment_ID' in df.columns:
            rename_map['Upstream_Catchment_ID'] = 'upstream_catchment_id'
        if rename_map:
            df = df.rename(columns=rename_map)

        # Enforce types where possible
        if 'gage_id' in df.columns:
            df['gage_id'] = pd.to_numeric(df['gage_id'], errors='coerce').astype('Int64')
        if 'upstream_catchment_id' in df.columns:
            df['upstream_catchment_id'] = pd.to_numeric(df['upstream_catchment_id'], errors='coerce').astype('Int64')

        return df
    except Exception as e:
        print(f"Warning: Could not load watershed relationships: {e}")
        return None


def _ensure_point_in_subbasin_crs(lon, lat, subbasin_gdf):
    """
    Build a GeoSeries point at (lon, lat) and convert to the subbasin CRS if needed.
    """
    from shapely.geometry import Point
    gpt = gpd.GeoSeries([Point(lon, lat)], crs="EPSG:4326")
    try:
        if subbasin_gdf.crs is not None and subbasin_gdf.crs != "EPSG:4326":
            gpt = gpt.to_crs(subbasin_gdf.crs)
    except Exception as e:
        print(f"CRS transform failed; proceeding in EPSG:4326. Error: {e}")
    return gpt


def get_gage_terminal_basin(gage_id, subbasin_gdf, gage_df):
    """
    Get the terminal basin (catchment polygon) that contains the gage.
    If containment fails (e.g., due to topology or slight offsets), use nearest.
    Handles CRS properly.
    """
    gage_info = gage_df[gage_df['id'] == gage_id]
    if gage_info.empty:
        print(f"Warning: No gage info found for gage {gage_id}")
        return gpd.GeoDataFrame(geometry=[], crs=subbasin_gdf.crs)

    gage_lat = gage_info.iloc[0]['latitude']
    gage_lon = gage_info.iloc[0]['longitude']
    print(f"Debug: Looking for terminal basin for gage {gage_id} at ({gage_lon}, {gage_lat})")

    gpt = _ensure_point_in_subbasin_crs(gage_lon, gage_lat, subbasin_gdf)

    # Containment test
    try:
        containing = subbasin_gdf[subbasin_gdf.geometry.contains(gpt.iloc[0])]
        if len(containing) > 0:
            idx = containing.index[0]
            print(f"Debug: Found containing basin {idx} for gage {gage_id}")
            return subbasin_gdf.loc[[idx]]
    except Exception as e:
        print(f"Debug: Containment check failed for gage {gage_id}: {e}")

    # Fallback: nearest polygon (distance computed in subbasin CRS)
    print(f"Debug: No containing basin found, finding nearest for gage {gage_id}")
    try:
        distances = subbasin_gdf.geometry.distance(gpt.iloc[0])
        nearest_idx = distances.idxmin()
        nearest_dist = distances.min()
        print(f"Debug: Nearest basin {nearest_idx} at distance {nearest_dist} for gage {gage_id}")
        return subbasin_gdf.loc[[nearest_idx]]
    except Exception as e:
        print(f"Warning: Nearest basin selection failed for gage {gage_id}: {e}")
        return gpd.GeoDataFrame(geometry=[], crs=subbasin_gdf.crs)


def get_gage_watersheds(gage_id, subbasin_gdf, gage_df, terminal_relationships=None):
    """
    Get the terminal basin and all upstream basins for a gage.
    - If relationships are provided, select upstream catchments by linkno/LINKNO.
    - The terminal basin is excluded from the upstream set.
    """
    terminal_basin = get_gage_terminal_basin(gage_id, subbasin_gdf, gage_df)
    print(f"Debug: Gage {gage_id} - Terminal basin found: {not terminal_basin.empty}")

    upstream_basins = gpd.GeoDataFrame(geometry=[], crs=subbasin_gdf.crs)

    if terminal_relationships is not None and 'gage_id' in terminal_relationships.columns:
        # Identify the linkno column
        linkno_col = None
        for cand in ['linkno', 'LINKNO', 'LinkNo', 'LINK_NO']:
            if cand in subbasin_gdf.columns:
                linkno_col = cand
                break

        if linkno_col is None:
            print("Warning: No linkno-like column found in subbasin_gdf; upstream basins cannot be resolved.")
            return terminal_basin, upstream_basins

        # Get upstream catchments list for this gage
        upstream_catchments = terminal_relationships.loc[
            terminal_relationships['gage_id'] == gage_id, 'upstream_catchment_id'
        ].dropna().astype(int).tolist()

        print(f"Debug: Gage {gage_id} - Upstream catchments: {len(upstream_catchments)}")

        if upstream_catchments:
            upstream_basins = subbasin_gdf[subbasin_gdf[linkno_col].astype(int).isin(upstream_catchments)].copy()
            # Remove terminal basin if included
            if not terminal_basin.empty:
                terminal_linkno = int(terminal_basin.iloc[0][linkno_col])
                upstream_basins = upstream_basins[upstream_basins[linkno_col].astype(int) != terminal_linkno]

    return terminal_basin, upstream_basins


def create_clean_correlation_maps_with_watersheds(
    correlation_df,
    lag_data,
    save_dir='../reports/figures/gage_well_correlations_watershed'
):
    """
    Create clean correlation maps with watershed boundaries and overview inset.

    Key changes:
    - Iterate over gages found in the *input lag_data*, not only those that survived correlation,
      so gages with 0 valid wells still get a QA map (terminal basin + star).
    - Defensive plotting when no wells exist.
    """
    print("=== Creating Clean Correlation Maps with Watersheds ===")
    os.makedirs(save_dir, exist_ok=True)

    # Load geographic data
    try:
        subbasin_gdf = gpd.read_file('../data/raw/hydrography/gsl_catchment.shp')
        stream_gdf = gpd.read_file('../data/raw/hydrography/gslb_stream.shp')
        lake_gdf = gpd.read_file('../data/raw/hydrography/lake.shp')
        gage_df = pd.read_csv('../data/raw/hydrography/gsl_nwm_gage.csv')

        # Load watershed relationships
        terminal_relationships = load_watershed_data()

        print("✅ Geographic data loaded successfully")
        print(f"Subbasin columns: {list(subbasin_gdf.columns)}")
        print(f"Subbasin CRS: {subbasin_gdf.crs}")
        print(f"Total catchments: {len(subbasin_gdf)}")
    except Exception as e:
        print(f"❌ Error loading geographic data: {e}")
        return

    # Use all gages from the input lag dataset to ensure coverage
    unique_gages = pd.Series(lag_data['gage_id'].dropna().unique()).astype(int).tolist()
    print(f"Creating watershed maps for {len(unique_gages)} gages...")

    for gage_id in tqdm(unique_gages, desc="Creating watershed correlation maps"):
        try:
            result = create_single_watershed_map(
                gage_id,
                correlation_df,
                subbasin_gdf,
                stream_gdf,
                lake_gdf,
                gage_df,
                terminal_relationships,
                save_dir
            )
            if result is False:
                print(f"⚠️ Skipped map creation for Gage {gage_id}")
        except Exception as e:
            print(f"❌ Failed to create map for Gage {gage_id}: {e}")
            import traceback
            traceback.print_exc()

    print(f"✅ All watershed correlation maps saved to: {save_dir}")


def add_labels_with_leader_lines(ax, wells_data, gage_lon, gage_lat, max_distance=0.05):
    """
    Add R² labels with leader lines to avoid overlaps.
    If wells_data is empty, this function does nothing.
    """
    if wells_data is None or wells_data.empty:
        return

    positions = []
    for i, (_, well) in enumerate(wells_data.iterrows()):
        number = i + 1
        well_x, well_y = well['well_lon'], well['well_lat']

        # Try different angles around the well to find non-overlapping position
        angles = [45, 135, 315, 225, 90, 270, 0, 180]
        best_pos = None
        min_conflict = float('inf')

        for angle in angles:
            angle_rad = np.radians(angle)
            label_x = well_x + max_distance * np.cos(angle_rad)
            label_y = well_y + max_distance * np.sin(angle_rad)

            conflict_count = 0
            gage_dist = np.sqrt((label_x - gage_lon)**2 + (label_y - gage_lat)**2)
            if gage_dist < max_distance * 0.6:
                conflict_count += 2

            for other_pos in positions:
                other_dist = np.sqrt((label_x - other_pos[0])**2 + (label_y - other_pos[1])**2)
                if other_dist < max_distance * 0.7:
                    conflict_count += 1

            if conflict_count < min_conflict:
                min_conflict = conflict_count
                best_pos = (label_x, label_y, angle)

        if best_pos:
            label_x, label_y, angle = best_pos
            positions.append((label_x, label_y))

            ax.plot([well_x, label_x], [well_y, label_y],
                    color='black', linewidth=1.2, alpha=0.7, zorder=11)

            ha = 'left' if angle < 180 else 'right'
            va = 'bottom' if 45 <= angle <= 135 else 'top'
            ax.text(label_x, label_y,
                    f'{number}: R² = {well["r_squared"]:.3f}',
                    ha=ha, va=va,
                    fontsize=9, fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white',
                              alpha=0.95, edgecolor='black', linewidth=1),
                    zorder=12)
        else:
            # Fallback to simple offset
            offset_x = 0.025 if i % 2 == 0 else -0.025
            offset_y = 0.020 if i < 3 else -0.010
            ax.text(well_x + offset_x, well_y + offset_y,
                    f'{number}: R² = {well["r_squared"]:.3f}',
                    ha='left' if offset_x > 0 else 'right',
                    va='bottom' if offset_y > 0 else 'top',
                    fontsize=9, fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white',
                              alpha=0.95, edgecolor='black', linewidth=1),
                    zorder=12)


def create_single_watershed_map(
    gage_id,
    correlation_df,
    subbasin_gdf,
    stream_gdf,
    lake_gdf,
    gage_df,
    terminal_relationships,
    save_dir
):
    """
    Create a single clean correlation map with watershed boundaries and inset map.

    Defensive behaviors:
    - If gage has 0 valid wells, still plot the terminal basin + gage star, and a stats box (Wells: 0).
    - If R² is constant (or a single value), avoid normalizer/colorbar errors.
    """
    print(f"\n=== Processing Gage {gage_id} ===")

    # Well correlations available for this gage
    gage_wells = correlation_df[correlation_df['gage_id'] == gage_id].copy()

    # Gage metadata
    gage_info = gage_df[gage_df['id'] == gage_id]
    if gage_info.empty:
        print(f"No gage info found for gage {gage_id}")
        return False

    gage_name = gage_info.iloc[0].get('name', f'Gage {gage_id}')
    gage_lat = gage_info.iloc[0]['latitude']
    gage_lon = gage_info.iloc[0]['longitude']

    # Watersheds: terminal + upstream
    terminal_basin, upstream_basins = get_gage_watersheds(
        gage_id, subbasin_gdf, gage_df, terminal_relationships
    )

    # Sort wells by R-squared and get top 5
    if not gage_wells.empty:
        gage_wells = gage_wells.sort_values('r_squared', ascending=False)
        top_5_wells = gage_wells.head(5)
    else:
        top_5_wells = gage_wells  # empty

    # Figure layout
    fig = plt.figure(figsize=(15, 10))
    fig.patch.set_facecolor('white')
    ax_main = plt.subplot2grid((10, 10), (0, 0), colspan=7, rowspan=10)
    ax_main.set_facecolor('white')

    # Plot watershed polygons
    if not terminal_basin.empty:
        terminal_basin.plot(ax=ax_main, color='#B8860B', alpha=0.8,
                            edgecolor='#8B7355', linewidth=2, zorder=1)
        print(f"✅ Plotted terminal basin for gage {gage_id}")
    else:
        print(f"⚠️ No terminal basin found for gage {gage_id}")

    if not upstream_basins.empty:
        upstream_basins.plot(ax=ax_main, color='#F5E6D3', alpha=0.6,
                             edgecolor='#D2B48C', linewidth=1.5, zorder=1)
        print(f"✅ Plotted {len(upstream_basins)} upstream basins for gage {gage_id}")

    # Clip streams/lakes by watershed extent if possible
    all_basins = pd.concat([terminal_basin, upstream_basins], ignore_index=True)
    if not all_basins.empty:
        watershed_union = all_basins.unary_union
        local_streams = stream_gdf[stream_gdf.geometry.intersects(watershed_union)]
        local_lakes = lake_gdf[lake_gdf.geometry.intersects(watershed_union)]
    else:
        # Fallback: extent around the gage and any wells (if exist)
        all_lons = [gage_lon] + (gage_wells['well_lon'].tolist() if not gage_wells.empty else [])
        all_lats = [gage_lat] + (gage_wells['well_lat'].tolist() if not gage_wells.empty else [])
        min_lon, max_lon = min(all_lons), max(all_lons)
        min_lat, max_lat = min(all_lats), max(all_lats)
        buffer = max(max_lon - min_lon, max_lat - min_lat) * 0.3 if (max_lon > min_lon and max_lat > min_lat) else 0.1

        from shapely.geometry import box
        extent_box = box(min_lon - buffer, min_lat - buffer, max_lon + buffer, max_lat + buffer)
        local_streams = stream_gdf[stream_gdf.geometry.intersects(extent_box)]
        local_lakes = lake_gdf[lake_gdf.geometry.intersects(extent_box)]
        print(f"Using extent-based clipping for gage {gage_id}")

    # Base layers
    if not local_streams.empty:
        local_streams.plot(ax=ax_main, color='#4A90E2', linewidth=1.5, alpha=0.8, zorder=2)
    if not local_lakes.empty:
        local_lakes.plot(ax=ax_main, color='#E6F3FF', alpha=0.8, edgecolor='#4A90E2', linewidth=0.8, zorder=2)

    # Plot ALL wells as small gray circles first
    ax_main.scatter(
        gage_wells['well_lon'],
        gage_wells['well_lat'],
        c='lightgray',
        s=50,
        alpha=0.6,
        edgecolor='gray',
        linewidth=0.5,
        zorder=3
    )

    # Plot wells colored by R-squared using viridis colormap (yellow-green-blue)
    norm = plt.Normalize(vmin=gage_wells['r_squared'].min(), vmax=gage_wells['r_squared'].max())
    scatter = ax_main.scatter(
        gage_wells['well_lon'],
        gage_wells['well_lat'],
        c=gage_wells['r_squared'],
        cmap='viridis',  # Yellow-green-blue colormap
        s=90,
        alpha=0.9,
        edgecolor='black',
        linewidth=0.8,
        zorder=4,
        norm=norm
    )

    # Plot gage as large bright yellow star - NO LABELS
    ax_main.scatter(gage_lon, gage_lat,
              color='#FFD700',  # Bright yellow/gold
              marker='*',       # Star
              s=500,           # Much larger size
              edgecolor='black',  # Black edge
              linewidth=2,
              zorder=10)

    # Add numbers for top 5 wells
    for i, (_, well) in enumerate(top_5_wells.iterrows()):
        number = i + 1

        # Larger red marker for top 5 wells
        ax_main.scatter(well['well_lon'], well['well_lat'],
                  color='red',
                  s=180,
                  marker='o',
                  edgecolor='white',
                  linewidth=2,
                  zorder=8,
                  alpha=0.95)

        # Add number inside the marker
        ax_main.text(well['well_lon'], well['well_lat'],
               str(number),
               ha='center', va='center',
               fontsize=11, fontweight='bold',
               color='white',
               zorder=9)

    # Add R-squared labels with leader lines
    add_labels_with_leader_lines(ax_main, top_5_wells, gage_lon, gage_lat)

    # Set main map extent
    if not all_basins.empty:
        bounds = all_basins.total_bounds
        buffer = max(bounds[2] - bounds[0], bounds[3] - bounds[1]) * 0.1
        ax_main.set_xlim(bounds[0] - buffer, bounds[2] + buffer)
        ax_main.set_ylim(bounds[1] - buffer, bounds[3] + buffer)
    else:
        # Fallback to well and gage extent
        all_lons = list(gage_wells['well_lon']) + [gage_lon]
        all_lats = list(gage_wells['well_lat']) + [gage_lat]
        lon_buffer = (max(all_lons) - min(all_lons)) * 0.2
        lat_buffer = (max(all_lats) - min(all_lats)) * 0.2
        ax_main.set_xlim(min(all_lons) - lon_buffer, max(all_lons) + lon_buffer)
        ax_main.set_ylim(min(all_lats) - lat_buffer, max(all_lats) + lat_buffer)

    # Style main map
    ax_main.set_title(f'Well–Gage Correlations (Lagged): {gage_name}\nGage ID: {gage_id}',
                      fontsize=14, fontweight='bold', pad=15)
    ax_main.set_xlabel('Longitude', fontsize=11)
    ax_main.set_ylabel('Latitude', fontsize=11)
    for spine in ax_main.spines.values():
        spine.set_visible(False)
    ax_main.grid(True, alpha=0.3, linestyle='-', linewidth=0.5, color='lightgray')
    ax_main.tick_params(labelsize=9)

    # Inset
    ax_inset = plt.subplot2grid((10, 10), (0, 7), colspan=3, rowspan=3)
    ax_inset.set_facecolor('white')
    plot_overview_inset(ax_inset, subbasin_gdf, stream_gdf, lake_gdf,
                        terminal_basin, upstream_basins, gage_lon, gage_lat, gage_id)

    # Colorbar only if we actually plotted colored wells
    if scatter is not None:
        cbar_ax = plt.subplot2grid((10, 10), (4, 7), colspan=3, rowspan=1)
        cbar = fig.colorbar(scatter, cax=cbar_ax, orientation='horizontal')
        cbar.set_label('R² (Correlation Strength)', fontsize=10, fontweight='bold')
        cbar.ax.tick_params(labelsize=8)

    # Stats + legend
    add_statistics_and_legend(fig, ax_main, gage_wells, gage_id, terminal_basin, upstream_basins)

    # Save
    safe_name = gage_name.replace('/', '_').replace('\\', '_')[:40]
    filename = f"gage_{gage_id}_watershed_{safe_name}.png"
    filepath = os.path.join(save_dir, filename)
    plt.savefig(filepath, dpi=300, bbox_inches='tight',
                facecolor='white', edgecolor='none', pad_inches=0.1)
    plt.close()

    print(f"✅ Saved map for gage {gage_id}: {filename}")
    return True


def plot_overview_inset(ax, subbasin_gdf, stream_gdf, lake_gdf,
                        terminal_basin, upstream_basins, gage_lon, gage_lat, gage_id):
    """
    Plot the basin-wide overview inset with highlighted terminal/upstream basins.
    """
    # All catchments
    subbasin_gdf.plot(ax=ax, color='#FAFAFA', edgecolor='#B0B0B0',
                      linewidth=0.6, alpha=0.9)

    # Lakes for context
    lake_gdf.plot(ax=ax, color='#D6EAF8', alpha=0.9,
                  edgecolor='#3498DB', linewidth=0.6)

    # Highlight watersheds
    if not terminal_basin.empty:
        terminal_basin.plot(ax=ax, color='#CD853F', alpha=0.8,
                            edgecolor='#8B7355', linewidth=1.5)
    if not upstream_basins.empty:
        upstream_basins.plot(ax=ax, color='#F5E6D3', alpha=0.7,
                             edgecolor='#D2B48C', linewidth=1.2)

    # Gage star (small)
    ax.scatter(gage_lon, gage_lat, color='#FFD700', marker='*', s=150,
               edgecolor='black', linewidth=1.5, zorder=5)

    # Full basin extent
    basin_bounds = subbasin_gdf.total_bounds
    ax.set_xlim(basin_bounds[0], basin_bounds[2])
    ax.set_ylim(basin_bounds[1], basin_bounds[3])

    # Style
    ax.set_title('Basin Overview', fontsize=10, fontweight='bold', pad=8)
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(2)
        spine.set_edgecolor('black')
    ax.set_aspect('equal')


def add_statistics_and_legend(fig, ax_main, gage_wells, gage_id, terminal_basin, upstream_basins):
    """
    Add statistics box and legend to the figure.
    Works whether or not wells exist for the gage.
    """
    # Stats text
    watershed_info = ""
    if not terminal_basin.empty:
        watershed_info += f"Terminal Basin: 1\n"
    if not upstream_basins.empty:
        watershed_info += f"Upstream Basins: {len(upstream_basins)}\n"

    if not gage_wells.empty:
        max_r2 = gage_wells['r_squared'].max()
        mean_r2 = gage_wells['r_squared'].mean()
        wells_n = len(gage_wells)
    else:
        max_r2 = float('nan')
        mean_r2 = float('nan')
        wells_n = 0

    stats_text = (f"{watershed_info}"
                  f"Wells: {wells_n}\n"
                  f"Max R²: {max_r2:.3f}" if wells_n > 0 else f"{watershed_info}Wells: 0")

    if wells_n > 0:
        stats_text += f"\nMean R²: {mean_r2:.3f}"

    ax_main.text(0.02, 0.98, stats_text, transform=ax_main.transAxes,
                 bbox=dict(boxstyle="round,pad=0.4", facecolor='white',
                           alpha=0.95, edgecolor='black', linewidth=1),
                 verticalalignment='top', fontsize=10,
                 fontweight='bold', zorder=15)

    # Legend
    legend_elements = [
        plt.Line2D([0], [0], marker='*', color='w',
                   markerfacecolor='#FFD700', markersize=18,
                   markeredgecolor='black', markeredgewidth=1.5,
                   label=f'Gage {gage_id}', linestyle='None'),
    ]

    if wells_n > 0:
        legend_elements.extend([
            plt.Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='red', markersize=10,
                       markeredgecolor='white', markeredgewidth=1,
                       label='Top 5 Wells (numbered)', linestyle='None'),
            plt.Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='lightgray', markersize=8,
                       markeredgecolor='gray', markeredgewidth=0.5,
                       label='Other Wells', linestyle='None')
        ])

    if not terminal_basin.empty:
        legend_elements.append(
            plt.Rectangle((0, 0), 1, 1, facecolor='#B8860B', edgecolor='#8B7355',
                          alpha=0.8, label='Terminal Basin (Gage Location)')
        )
    if not upstream_basins.empty:
        legend_elements.append(
            plt.Rectangle((0, 0), 1, 1, facecolor='#F5E6D3', edgecolor='#D2B48C',
                          alpha=0.6, label='Upstream Basins')
        )

    ax_main.legend(handles=legend_elements, loc='lower left',
                   bbox_to_anchor=(0.02, 0.02), fontsize=9,
                   frameon=True, fancybox=False, shadow=False,
                   edgecolor='black', facecolor='white', framealpha=0.95)


def create_summary_statistics(correlation_df):
    """
    Create summary statistics and save CSVs.
    """
    print("=== Creating Summary Statistics ===")
    correlation_df.to_csv('../data/processed/well_gage_correlations.csv', index=False)

    if not correlation_df.empty:
        gage_summary = correlation_df.groupby('gage_id').agg({
            'r_squared': ['count', 'mean', 'std', 'max', 'min'],
            'n_observations': ['mean', 'sum']
        }).round(4)
        gage_summary.columns = ['_'.join(col) for col in gage_summary.columns]
        gage_summary.to_csv('../data/processed/gage_correlation_summary.csv')

        print("✅ Summary statistics saved")
        print(f"Overall correlation statistics:")
        print(f"  Mean R²: {correlation_df['r_squared'].mean():.4f}")
        print(f"  Median R²: {correlation_df['r_squared'].median():.4f}")
        print(f"  Max R²: {correlation_df['r_squared'].max():.4f}")
        print(f"  Wells with R² > 0.1: {(correlation_df['r_squared'] > 0.1).sum()}")
        print(f"  Wells with R² > 0.5: {(correlation_df['r_squared'] > 0.5).sum()}")
    else:
        print("⚠️ correlation_df is empty; no summary CSVs created.")


def run_watershed_correlation_analysis():
    """
    Run the complete watershed correlation analysis.

    Steps:
    1) Load lag dataset (ΔQ vs ΔWTE_lag1yr input)
    2) Compute per-well–gage correlations (skip constants)
    3) Render per-gage maps for ALL gages in the input lag data (even with 0 wells)
    4) Save summary statistics
    """
    print("🚀 Starting Watershed Correlation Analysis")

    # Load lag data
    try:
        lag_data = pd.read_csv('../data/features/q_buffer2_pair_lag.csv')
        print(f"✅ Loaded lag data with {len(lag_data):,} records")
    except Exception as e:
        print(f"❌ Error loading lag data: {e}")
        return

    # Calculate correlations (ΔQ vs ΔWTE_lag1yr)
    correlation_df = calculate_well_gage_correlations(lag_data, min_points=10)
    if correlation_df is None:
        print("❌ No correlations calculated (None). Check your data.")
        return

    # Create watershed maps (iterate over all gages in lag_data)
    create_clean_correlation_maps_with_watersheds(correlation_df, lag_data)

    # Create summary stats
    create_summary_statistics(correlation_df)

    print("🎉 Watershed Correlation Analysis Complete!")


if __name__ == "__main__":
    try:
        run_watershed_correlation_analysis()
        print("🎉 Watershed analysis completed successfully!")
    except Exception as e:
        print(f"❌ Error during analysis: {e}")
        import traceback
        traceback.print_exc()


🚀 Starting Watershed Correlation Analysis
✅ Loaded lag data with 1,493,879 records
=== Calculating Well-Gage Correlations (ΔQ vs ΔWTE_lag1yr) ===
Processing 879 well-gage pairs...


Calculating correlations: 100%|██████████| 879/879 [00:00<00:00, 1666.19it/s]


Successfully calculated correlations for 861 well-gage pairs
Skipped (n<10): 16 | constant X: 2 | constant Y: 0
=== Creating Clean Correlation Maps with Watersheds ===
✅ Geographic data loaded successfully
Subbasin columns: ['ogc_fid', 'linkno', 'geometry']
Subbasin CRS: EPSG:4326
Total catchments: 7013
Creating watershed maps for 6 gages...


Creating watershed correlation maps:   0%|          | 0/6 [00:00<?, ?it/s]findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont


=== Processing Gage 10141000 ===
Debug: Looking for terminal basin for gage 10141000 at (-112.091887, 41.278277)
Debug: Found containing basin 6179 for gage 10141000
Debug: Gage 10141000 - Terminal basin found: True
Debug: Gage 10141000 - Upstream catchments: 359
✅ Plotted terminal basin for gage 10141000


findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberati

✅ Plotted 358 upstream basins for gage 10141000


findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberati

✅ Saved map for gage 10141000: gage_10141000_watershed_WEBER RIVER NEAR PLAIN CITY - UT.png

=== Processing Gage 10163000 ===
Debug: Looking for terminal basin for gage 10163000 at (-111.69937, 40.237732)
Debug: Found containing basin 6959 for gage 10163000
Debug: Gage 10163000 - Terminal basin found: True
Debug: Gage 10163000 - Upstream catchments: 124
✅ Plotted terminal basin for gage 10163000
✅ Plotted 123 upstream basins for gage 10163000


findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberati

✅ Saved map for gage 10163000: gage_10163000_watershed_PROVO RIVER AT PROVO - UT.png

=== Processing Gage 10152000 ===
Debug: Looking for terminal basin for gage 10152000 at (-111.726039, 40.150232)
Debug: Found containing basin 6851 for gage 10152000
Debug: Gage 10152000 - Terminal basin found: True
Debug: Gage 10152000 - Upstream catchments: 110
✅ Plotted terminal basin for gage 10152000
✅ Plotted 109 upstream basins for gage 10152000


findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberati

✅ Saved map for gage 10152000: gage_10152000_watershed_SPANISH FORK NEAR LAKE SHORE - UTAH.png

=== Processing Gage 10143500 ===
Debug: Looking for terminal basin for gage 10143500 at (-111.862993, 40.916334)
Debug: Found containing basin 4334 for gage 10143500
Debug: Gage 10143500 - Terminal basin found: True
Debug: Gage 10143500 - Upstream catchments: 1
✅ Plotted terminal basin for gage 10143500


findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberati

✅ Saved map for gage 10143500: gage_10143500_watershed_CENTERVILLE CREEK ABV. DIV NEAR CENTERVI.png

=== Processing Gage 10126000 ===
Debug: Looking for terminal basin for gage 10126000 at (-112.100782, 41.576321)
Debug: Found containing basin 3506 for gage 10126000
Debug: Gage 10126000 - Terminal basin found: True
Debug: Gage 10126000 - Upstream catchments: 1261
✅ Plotted terminal basin for gage 10126000


findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberati

✅ Plotted 1260 upstream basins for gage 10126000


findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberati

✅ Saved map for gage 10126000: gage_10126000_watershed_BEAR RIVER NEAR CORINNE - UT.png

=== Processing Gage 10168000 ===
Debug: Looking for terminal basin for gage 10168000 at (-111.90188, 40.663836)
Debug: Found containing basin 6256 for gage 10168000
Debug: Gage 10168000 - Terminal basin found: True
Debug: Gage 10168000 - Upstream catchments: 6
✅ Plotted terminal basin for gage 10168000
✅ Plotted 5 upstream basins for gage 10168000


findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberation Sans' not found.
findfont: Font family 'Liberati

✅ Saved map for gage 10168000: gage_10168000_watershed_LITTLE COTTONWOOD CREEK @ JORDAN RIVER N.png
✅ All watershed correlation maps saved to: ../reports/figures/gage_well_correlations_watershed
=== Creating Summary Statistics ===
✅ Summary statistics saved
Overall correlation statistics:
  Mean R²: 0.0671
  Median R²: 0.0152
  Max R²: 0.8630
  Wells with R² > 0.1: 174
  Wells with R² > 0.5: 16
🎉 Watershed Correlation Analysis Complete!
🎉 Watershed analysis completed successfully!





# plot horizontal distance

In [36]:
# Fixed watershed distance mapping code with proper styling (using well_gage data)
import os
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
from shapely.geometry import Point
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')

def create_watershed_distance_maps_styled():
    """
    Create watershed maps for each gage using the same style as correlation maps.
    Uses well_gage data for well-gage relationships.
    """
    print("=== Creating Watershed Distance Maps (Styled Version) ===")
    save_dir = "../reports/figures/watershed_distance_maps_styled"
    os.makedirs(save_dir, exist_ok=True)

    # Load all datasets
    print("Loading datasets...")

    # Geographic data
    subbasin_gdf = gpd.read_file('../data/raw/hydrography/gsl_catchment.shp')
    stream_gdf = gpd.read_file('../data/raw/hydrography/gslb_stream.shp')
    lake_gdf = gpd.read_file('../data/raw/hydrography/lake.shp')
    gage_df = pd.read_csv('../data/raw/hydrography/gsl_nwm_gage.csv')

    # Well data
    reach_distances = pd.read_csv('../data/processed/well_reach_relationships_final.csv')
    well_locations = pd.read_csv('../data/processed/well_reach.csv')[['well_id', 'well_lat', 'well_lon']].drop_duplicates()

    # CHANGED: Use well_gage data for well-gage relationships
    well_gage_data = pd.read_csv('../data/processed/wells_with_catchment_info.csv')
    well_gage_pairs = well_gage_data[['well_id', 'gage_id']].dropna().drop_duplicates()
    print(f"Using well_gage data for well-gage relationships: {len(well_gage_pairs)} unique pairs")

    # Load watershed relationships
    try:
        terminal_relationships = pd.read_csv('../data/processed/terminal_gage_upstream_catchments.csv')
        print("✅ Loaded terminal relationships")
    except FileNotFoundError:
        print("⚠️ Terminal relationships file not found, using simplified approach")
        terminal_relationships = None

    print("✅ All datasets loaded")

    # Merge data using numeric IDs
    print("Merging data using numeric approach...")

    # Convert to numeric
    reach_distances['well_id_numeric'] = pd.to_numeric(reach_distances['Well_ID'], errors='coerce')
    well_locations['well_id_numeric'] = pd.to_numeric(well_locations['well_id'], errors='coerce')
    well_gage_pairs['well_id_numeric'] = pd.to_numeric(well_gage_pairs['well_id'], errors='coerce')

    # Merge step by step
    merge1 = pd.merge(reach_distances, well_locations, on='well_id_numeric', how='inner')
    print(f"After adding coordinates: {len(merge1)} wells")

    final_data = pd.merge(merge1, well_gage_pairs, on='well_id_numeric', how='inner')
    print(f"Final merged data: {len(final_data)} wells with gage relationships")

    if len(final_data) == 0:
        print("❌ No data after merging - cannot create maps")
        return

    # Ensure CRS consistency
    subbasin_gdf = subbasin_gdf.to_crs("EPSG:4326")
    stream_gdf = stream_gdf.to_crs("EPSG:4326")
    lake_gdf = lake_gdf.to_crs("EPSG:4326")

    # Get unique gages (process all gages)
    unique_gages = final_data['gage_id'].dropna().unique()
    print(f"Creating maps for {len(unique_gages)} gages: {unique_gages}")

    # Create maps
    for gage_id in tqdm(unique_gages, desc="Creating styled maps"):
        try:
            gage_wells = final_data[final_data['gage_id'] == gage_id].copy()
            print(f"\nProcessing Gage {gage_id} with {len(gage_wells)} wells")

            if len(gage_wells) == 0:
                continue

            # Create the map
            create_single_styled_map(gage_id, gage_wells, subbasin_gdf, stream_gdf,
                                   lake_gdf, gage_df, terminal_relationships, save_dir)

        except Exception as e:
            print(f"❌ Error for gage {gage_id}: {e}")
            import traceback
            traceback.print_exc()

    print(f"✅ Maps saved to: {save_dir}")

def get_gage_watersheds_styled(gage_id, subbasin_gdf, gage_df, terminal_relationships=None):
    """Get terminal and upstream basins for a gage - styled version."""
    # Get gage info
    gage_info = gage_df[gage_df['id'] == gage_id]
    if gage_info.empty:
        return gpd.GeoDataFrame(geometry=[], crs=subbasin_gdf.crs), gpd.GeoDataFrame(geometry=[], crs=subbasin_gdf.crs)

    gage_lat = gage_info.iloc[0]['latitude']
    gage_lon = gage_info.iloc[0]['longitude']

    # Create gage point
    gage_point = Point(gage_lon, gage_lat)
    gage_gdf = gpd.GeoDataFrame([1], geometry=[gage_point], crs="EPSG:4326")
    gage_gdf = gage_gdf.to_crs(subbasin_gdf.crs)
    gage_point = gage_gdf.geometry.iloc[0]

    # Find terminal basin (containing the gage)
    containing = subbasin_gdf[subbasin_gdf.geometry.contains(gage_point)]
    if len(containing) > 0:
        terminal_basin = containing.iloc[[0]]
    else:
        # Fallback: find nearest basin
        distances = subbasin_gdf.geometry.distance(gage_point)
        nearest_idx = distances.idxmin()
        terminal_basin = subbasin_gdf.iloc[[nearest_idx]]

    # Get upstream basins if relationships are available
    upstream_basins = gpd.GeoDataFrame(geometry=[], crs=subbasin_gdf.crs)

    if terminal_relationships is not None:
        # Standardize column names
        if 'Gage_ID' in terminal_relationships.columns:
            gage_col = 'Gage_ID'
            upstream_col = 'Upstream_Catchment_ID'
        else:
            gage_col = 'gage_id'
            upstream_col = 'upstream_catchment_id'

        # Find linkno column in subbasin_gdf
        linkno_col = None
        for col in ['linkno', 'LINKNO', 'LinkNo', 'LINK_NO']:
            if col in subbasin_gdf.columns:
                linkno_col = col
                break

        if linkno_col and gage_col in terminal_relationships.columns:
            # Get upstream catchments for this gage
            upstream_catchments = terminal_relationships.loc[
                terminal_relationships[gage_col] == gage_id, upstream_col
            ].dropna().astype(int).tolist()

            if upstream_catchments:
                upstream_basins = subbasin_gdf[
                    subbasin_gdf[linkno_col].astype(int).isin(upstream_catchments)
                ].copy()

                # Remove terminal basin if it's included in upstream list
                if not terminal_basin.empty:
                    terminal_linkno = int(terminal_basin.iloc[0][linkno_col])
                    upstream_basins = upstream_basins[
                        upstream_basins[linkno_col].astype(int) != terminal_linkno
                    ]

    return terminal_basin, upstream_basins

def plot_overview_inset_styled(ax, subbasin_gdf, stream_gdf, lake_gdf,
                              terminal_basin, upstream_basins, gage_lon, gage_lat, gage_id):
    """Plot the basin-wide overview inset - styled version."""
    # All catchments
    subbasin_gdf.plot(ax=ax, color='#FAFAFA', edgecolor='#B0B0B0',
                      linewidth=0.6, alpha=0.9)

    # Lakes for context
    lake_gdf.plot(ax=ax, color='#D6EAF8', alpha=0.9,
                  edgecolor='#3498DB', linewidth=0.6)

    # Highlight watersheds
    if not terminal_basin.empty:
        terminal_basin.plot(ax=ax, color='#CD853F', alpha=0.8,
                            edgecolor='#8B7355', linewidth=1.5)
    if not upstream_basins.empty:
        upstream_basins.plot(ax=ax, color='#F5E6D3', alpha=0.7,
                             edgecolor='#D2B48C', linewidth=1.2)

    # Gage star (small)
    ax.scatter(gage_lon, gage_lat, color='#FFD700', marker='*', s=150,
               edgecolor='black', linewidth=1.5, zorder=5)

    # Full basin extent
    basin_bounds = subbasin_gdf.total_bounds
    ax.set_xlim(basin_bounds[0], basin_bounds[2])
    ax.set_ylim(basin_bounds[1], basin_bounds[3])

    # Style
    ax.set_title('Basin Overview', fontsize=10, fontweight='bold', pad=8)
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(2)
        spine.set_edgecolor('black')
    ax.set_aspect('equal')

def create_single_styled_map(gage_id, gage_wells, subbasin_gdf, stream_gdf, lake_gdf,
                           gage_df, terminal_relationships, save_dir):
    """Create a single map using the correlation map style."""

    # Get gage info
    gage_info = gage_df[gage_df['id'] == gage_id]
    if gage_info.empty:
        print(f"No gage info for {gage_id}")
        return False

    gage_name = gage_info.iloc[0].get('name', f'Gage {gage_id}')
    gage_lat = gage_info.iloc[0]['latitude']
    gage_lon = gage_info.iloc[0]['longitude']

    # Get watersheds: terminal + upstream
    terminal_basin, upstream_basins = get_gage_watersheds_styled(
        gage_id, subbasin_gdf, gage_df, terminal_relationships
    )

    if terminal_basin.empty:
        print(f"No terminal basin found for gage {gage_id}")
        return False

    print(f"Found terminal basin and {len(upstream_basins)} upstream basins for gage {gage_id}")

    # Sort wells by distance (closest first)
    gage_wells = gage_wells.sort_values('Distance_to_Reach')

    # Figure layout - SAME AS CORRELATION MAPS
    fig = plt.figure(figsize=(15, 10))
    fig.patch.set_facecolor('white')
    ax_main = plt.subplot2grid((10, 10), (0, 0), colspan=7, rowspan=10)
    ax_main.set_facecolor('white')

    # Plot watershed polygons - SAME COLORS AS CORRELATION MAPS
    if not terminal_basin.empty:
        terminal_basin.plot(ax=ax_main, color='#B8860B', alpha=0.8,
                            edgecolor='#8B7355', linewidth=2, zorder=1)
        print(f"✅ Plotted terminal basin for gage {gage_id}")

    if not upstream_basins.empty:
        upstream_basins.plot(ax=ax_main, color='#F5E6D3', alpha=0.6,
                             edgecolor='#D2B48C', linewidth=1.5, zorder=1)
        print(f"✅ Plotted {len(upstream_basins)} upstream basins for gage {gage_id}")

    # Clip streams/lakes by watershed extent
    all_basins = pd.concat([terminal_basin, upstream_basins], ignore_index=True)
    if not all_basins.empty:
        watershed_union = all_basins.unary_union
        local_streams = stream_gdf[stream_gdf.geometry.intersects(watershed_union)]
        local_lakes = lake_gdf[lake_gdf.geometry.intersects(watershed_union)]
    else:
        # Fallback: extent around the gage and wells
        all_lons = [gage_lon] + gage_wells['well_lon'].tolist()
        all_lats = [gage_lat] + gage_wells['well_lat'].tolist()
        min_lon, max_lon = min(all_lons), max(all_lons)
        min_lat, max_lat = min(all_lats), max(all_lats)
        buffer = max(max_lon - min_lon, max_lat - min_lat) * 0.3 if (max_lon > min_lon and max_lat > min_lat) else 0.1

        from shapely.geometry import box
        extent_box = box(min_lon - buffer, min_lat - buffer, max_lon + buffer, max_lat + buffer)
        local_streams = stream_gdf[stream_gdf.geometry.intersects(extent_box)]
        local_lakes = lake_gdf[lake_gdf.geometry.intersects(extent_box)]

    # Base layers - SAME AS CORRELATION MAPS
    if not local_streams.empty:
        local_streams.plot(ax=ax_main, color='#4A90E2', linewidth=1.5, alpha=0.8, zorder=2)
    if not local_lakes.empty:
        local_lakes.plot(ax=ax_main, color='#E6F3FF', alpha=0.8, edgecolor='#4A90E2', linewidth=0.8, zorder=2)

    # Plot wells colored by distance - NO TOP 5 HIGHLIGHTING
    if len(gage_wells) > 0:
        norm = plt.Normalize(
            vmin=gage_wells['Distance_to_Reach'].min(),
            vmax=gage_wells['Distance_to_Reach'].max()
        )

        scatter = ax_main.scatter(
            gage_wells['well_lon'],
            gage_wells['well_lat'],
            c=gage_wells['Distance_to_Reach'],
            cmap='viridis',  # Same as correlation maps
            s=90,
            alpha=0.9,
            edgecolor='black',
            linewidth=0.8,
            zorder=4,
            norm=norm
        )

    # Plot gage as large bright yellow star - SAME AS CORRELATION MAPS
    ax_main.scatter(gage_lon, gage_lat,
              color='#FFD700',  # Bright yellow/gold
              marker='*',       # Star
              s=500,           # Much larger size
              edgecolor='black',  # Black edge
              linewidth=2,
              zorder=10)

    # Set main map extent - SAME AS CORRELATION MAPS
    if not all_basins.empty:
        bounds = all_basins.total_bounds
        buffer = max(bounds[2] - bounds[0], bounds[3] - bounds[1]) * 0.1
        ax_main.set_xlim(bounds[0] - buffer, bounds[2] + buffer)
        ax_main.set_ylim(bounds[1] - buffer, bounds[3] + buffer)
    else:
        # Fallback to well and gage extent
        all_lons = list(gage_wells['well_lon']) + [gage_lon]
        all_lats = list(gage_wells['well_lat']) + [gage_lat]
        lon_buffer = (max(all_lons) - min(all_lons)) * 0.2
        lat_buffer = (max(all_lats) - min(all_lats)) * 0.2
        ax_main.set_xlim(min(all_lons) - lon_buffer, max(all_lons) + lon_buffer)
        ax_main.set_ylim(min(all_lats) - lat_buffer, max(all_lats) + lat_buffer)

    # Style main map - SAME AS CORRELATION MAPS - FIXED TITLE LINEBREAK
    ax_main.set_title(f'Wells by Distance to Reach:\n {gage_name}\nGage ID: {int(gage_id)}',
                      fontsize=14, fontweight='bold', pad=15)
    ax_main.set_xlabel('Longitude', fontsize=11)
    ax_main.set_ylabel('Latitude', fontsize=11)
    for spine in ax_main.spines.values():
        spine.set_visible(False)
    ax_main.grid(True, alpha=0.3, linestyle='-', linewidth=0.5, color='lightgray')
    ax_main.tick_params(labelsize=9)

    # Inset - SAME AS CORRELATION MAPS
    ax_inset = plt.subplot2grid((10, 10), (0, 7), colspan=3, rowspan=3)
    ax_inset.set_facecolor('white')
    plot_overview_inset_styled(ax_inset, subbasin_gdf, stream_gdf, lake_gdf,
                               terminal_basin, upstream_basins, gage_lon, gage_lat, gage_id)

    # Colorbar - SAME LAYOUT AS CORRELATION MAPS
    if len(gage_wells) > 0:
        cbar_ax = plt.subplot2grid((10, 10), (4, 7), colspan=3, rowspan=1)
        cbar = fig.colorbar(scatter, cax=cbar_ax, orientation='horizontal')
        cbar.set_label('Distance to Reach (meters)', fontsize=10, fontweight='bold')
        cbar.ax.tick_params(labelsize=8)

    # Add statistics and legend - SAME STYLE AS CORRELATION MAPS
    add_distance_statistics_and_legend_styled(fig, ax_main, gage_wells, gage_id,
                                            terminal_basin, upstream_basins)

    # Save
    safe_name = gage_name.replace('/', '_').replace('\\', '_')[:40]
    filename = f"gage_{gage_id}_distance_{safe_name}.png"
    filepath = os.path.join(save_dir, filename)
    plt.savefig(filepath, dpi=300, bbox_inches='tight',
                facecolor='white', edgecolor='none', pad_inches=0.1)
    plt.close()

    print(f"✅ Saved map for gage {gage_id}: {filename}")
    return True

def add_distance_statistics_and_legend_styled(fig, ax_main, gage_wells, gage_id,
                                            terminal_basin, upstream_basins):
    """Add statistics box and legend - SAME STYLE AS CORRELATION MAPS."""
    # Stats text
    watershed_info = ""
    if not terminal_basin.empty:
        watershed_info += f"Terminal Basin: 1\n"
    if not upstream_basins.empty:
        watershed_info += f"Upstream Basins: {len(upstream_basins)}\n"

    if not gage_wells.empty:
        min_dist = gage_wells['Distance_to_Reach'].min()
        max_dist = gage_wells['Distance_to_Reach'].max()
        mean_dist = gage_wells['Distance_to_Reach'].mean()
        wells_n = len(gage_wells)
    else:
        min_dist = float('nan')
        max_dist = float('nan')
        mean_dist = float('nan')
        wells_n = 0

    stats_text = (f"{watershed_info}"
                  f"Wells: {wells_n}\n"
                  f"Min Distance: {min_dist:.1f}m\n"
                  f"Max Distance: {max_dist:.1f}m" if wells_n > 0 else f"{watershed_info}Wells: 0")

    if wells_n > 0:
        stats_text += f"\nMean Distance: {mean_dist:.1f}m"

    ax_main.text(0.02, 0.98, stats_text, transform=ax_main.transAxes,
                 bbox=dict(boxstyle="round,pad=0.4", facecolor='white',
                           alpha=0.95, edgecolor='black', linewidth=1),
                 verticalalignment='top', fontsize=10,
                 fontweight='bold', zorder=15)

    # Legend - SAME AS CORRELATION MAPS
    legend_elements = [
        plt.Line2D([0], [0], marker='*', color='w',
                   markerfacecolor='#FFD700', markersize=18,
                   markeredgecolor='black', markeredgewidth=1.5,
                   label=f'Gage {gage_id}', linestyle='None'),
    ]

    if wells_n > 0:
        legend_elements.extend([
            plt.Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='lightgray', markersize=8,
                       markeredgecolor='gray', markeredgewidth=0.5,
                       label='Wells (colored by distance)', linestyle='None')
        ])

    if not terminal_basin.empty:
        legend_elements.append(
            plt.Rectangle((0, 0), 1, 1, facecolor='#B8860B', edgecolor='#8B7355',
                          alpha=0.8, label='Terminal Basin (Gage Location)')
        )
    if not upstream_basins.empty:
        legend_elements.append(
            plt.Rectangle((0, 0), 1, 1, facecolor='#F5E6D3', edgecolor='#D2B48C',
                          alpha=0.6, label='Upstream Basins')
        )

    ax_main.legend(handles=legend_elements, loc='lower left',
                   bbox_to_anchor=(0.02, 0.02), fontsize=9,
                   frameon=True, fancybox=False, shadow=False,
                   edgecolor='black', facecolor='white', framealpha=0.95)

# Run the styled version
create_watershed_distance_maps_styled()

=== Creating Watershed Distance Maps (Styled Version) ===
Loading datasets...
Using well_gage data for well-gage relationships: 2880 unique pairs
✅ Loaded terminal relationships
✅ All datasets loaded
Merging data using numeric approach...
After adding coordinates: 8752 wells
Final merged data: 2880 wells with gage relationships
Creating maps for 10 gages: [10152000. 10141000. 10153100. 10163000. 10168000. 10143500. 10126000.
 10142000. 10172952. 10172860.]


Creating styled maps:   0%|          | 0/10 [00:00<?, ?it/s]


Processing Gage 10152000.0 with 92 wells
Found terminal basin and 109 upstream basins for gage 10152000.0
✅ Plotted terminal basin for gage 10152000.0
✅ Plotted 109 upstream basins for gage 10152000.0


Creating styled maps:  10%|█         | 1/10 [00:21<03:11, 21.28s/it]

✅ Saved map for gage 10152000.0: gage_10152000.0_distance_SPANISH FORK NEAR LAKE SHORE - UTAH.png

Processing Gage 10141000.0 with 732 wells
Found terminal basin and 358 upstream basins for gage 10141000.0
✅ Plotted terminal basin for gage 10141000.0
✅ Plotted 358 upstream basins for gage 10141000.0


Creating styled maps:  20%|██        | 2/10 [00:44<02:59, 22.46s/it]

✅ Saved map for gage 10141000.0: gage_10141000.0_distance_WEBER RIVER NEAR PLAIN CITY - UT.png

Processing Gage 10153100.0 with 16 wells
Found terminal basin and 16 upstream basins for gage 10153100.0
✅ Plotted terminal basin for gage 10153100.0
✅ Plotted 16 upstream basins for gage 10153100.0


Creating styled maps:  30%|███       | 3/10 [01:03<02:27, 21.00s/it]

✅ Saved map for gage 10153100.0: gage_10153100.0_distance_HOBBLE CREEK AT 1650 WEST AT SPRINGVILLE.png

Processing Gage 10163000.0 with 471 wells
Found terminal basin and 123 upstream basins for gage 10163000.0
✅ Plotted terminal basin for gage 10163000.0
✅ Plotted 123 upstream basins for gage 10163000.0


Creating styled maps:  40%|████      | 4/10 [01:24<02:04, 20.70s/it]

✅ Saved map for gage 10163000.0: gage_10163000.0_distance_PROVO RIVER AT PROVO - UT.png

Processing Gage 10168000.0 with 73 wells
Found terminal basin and 5 upstream basins for gage 10168000.0
✅ Plotted terminal basin for gage 10168000.0
✅ Plotted 5 upstream basins for gage 10168000.0


Creating styled maps:  50%|█████     | 5/10 [01:45<01:45, 21.09s/it]

✅ Saved map for gage 10168000.0: gage_10168000.0_distance_LITTLE COTTONWOOD CREEK @ JORDAN RIVER N.png

Processing Gage 10143500.0 with 20 wells
Found terminal basin and 0 upstream basins for gage 10143500.0
✅ Plotted terminal basin for gage 10143500.0


Creating styled maps:  60%|██████    | 6/10 [02:03<01:19, 19.83s/it]

✅ Saved map for gage 10143500.0: gage_10143500.0_distance_CENTERVILLE CREEK ABV. DIV NEAR CENTERVI.png

Processing Gage 10126000.0 with 1461 wells
Found terminal basin and 1260 upstream basins for gage 10126000.0
✅ Plotted terminal basin for gage 10126000.0
✅ Plotted 1260 upstream basins for gage 10126000.0


Creating styled maps:  70%|███████   | 7/10 [02:39<01:15, 25.19s/it]

✅ Saved map for gage 10126000.0: gage_10126000.0_distance_BEAR RIVER NEAR CORINNE - UT.png

Processing Gage 10142000.0 with 13 wells
Found terminal basin and 0 upstream basins for gage 10142000.0
✅ Plotted terminal basin for gage 10142000.0


Creating styled maps:  80%|████████  | 8/10 [02:56<00:45, 22.68s/it]

✅ Saved map for gage 10142000.0: gage_10142000.0_distance_FARMINGTON CR ABV DIV NR FARMINGTON - UT.png

Processing Gage 10172952.0 with 1 wells
Found terminal basin and 0 upstream basins for gage 10172952.0
✅ Plotted terminal basin for gage 10172952.0


Creating styled maps:  90%|█████████ | 9/10 [03:14<00:21, 21.04s/it]

✅ Saved map for gage 10172952.0: gage_10172952.0_distance_DUNN CREEK NEAR PARK VALLEY - UT.png

Processing Gage 10172860.0 with 1 wells
Found terminal basin and 11 upstream basins for gage 10172860.0
✅ Plotted terminal basin for gage 10172860.0
✅ Plotted 11 upstream basins for gage 10172860.0


Creating styled maps: 100%|██████████| 10/10 [03:34<00:00, 21.40s/it]

✅ Saved map for gage 10172860.0: gage_10172860.0_distance_WARM CREEK NEAR GANDY - UT.png
✅ Maps saved to: ../reports/figures/watershed_distance_maps_styled





In [26]:

# Count the number of unique well IDs for each gage in the well_gage dataframe
unique_well_counts = well_gage.groupby('gage_id')['well_id'].nunique().reset_index()
unique_well_counts.columns = ['gage_id', 'unique_well_count']

# Display the result
unique_well_counts


Unnamed: 0,gage_id,unique_well_count
0,10126000.0,1461
1,10141000.0,732
2,10142000.0,13
3,10143500.0,20
4,10152000.0,92
5,10153100.0,16
6,10163000.0,471
7,10168000.0,73
8,10172860.0,1
9,10172952.0,1


# plot vertical distance

In [30]:
well_elev = pd.read_csv('../data/raw/groundwater/GSLB_1900-2023_wells_with_aquifers.csv')
well_elev.head()

Unnamed: 0,Well_ID,Well_Name,lat_dec,long_dec,GSE,AquiferID,Aquifer_Name,State
0,381033113480701,(C-30-18)25aad- 1,38.175796,-113.80275,7098.0,1,GSL Basin,UT
1,381037113474001,(C-30-17)30bab- 1,38.176306,-113.7955,7193.0,1,GSL Basin,UT
2,381152113442801,(C-30-17)15cab- 1,38.197833,-113.741167,6550.0,1,GSL Basin,UT
3,381236113485601,(C-30-18)12cdb- 1,38.210028,-113.8155,7190.0,1,GSL Basin,UT
4,382113113435401,(C-28-17)22dda- 1,38.353571,-113.732473,5775.0,1,GSL Basin,UT


In [37]:
# Watershed vertical distance mapping code with proper styling (using well_gage data)
import os
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
from shapely.geometry import Point
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')

def create_watershed_vertical_distance_maps_styled():
    """
    Create watershed maps for each gage using the same style as correlation maps.
    Uses well_gage data for well-gage relationships.
    Colors wells by vertical distance (well elevation - reach elevation).
    """
    print("=== Creating Watershed Vertical Distance Maps (Styled Version) ===")
    save_dir = "../reports/figures/watershed_vertical_distance_maps_styled"
    os.makedirs(save_dir, exist_ok=True)

    # Load all datasets
    print("Loading datasets...")

    # Geographic data
    subbasin_gdf = gpd.read_file('../data/raw/hydrography/gsl_catchment.shp')
    stream_gdf = gpd.read_file('../data/raw/hydrography/gslb_stream.shp')
    lake_gdf = gpd.read_file('../data/raw/hydrography/lake.shp')
    gage_df = pd.read_csv('../data/raw/hydrography/gsl_nwm_gage.csv')

    # Well data
    reach_distances = pd.read_csv('../data/processed/well_reach_relationships_final.csv')
    well_locations = pd.read_csv('../data/processed/well_reach.csv')[['well_id', 'well_lat', 'well_lon']].drop_duplicates()

    # CHANGED: Use well_gage data for well-gage relationships
    well_gage_data = pd.read_csv('../data/processed/wells_with_catchment_info.csv')
    well_gage_pairs = well_gage_data[['well_id', 'gage_id', 'gse']].dropna().drop_duplicates()
    
    # Convert gage_id to integer (remove decimal points)
    well_gage_pairs['gage_id'] = well_gage_pairs['gage_id'].astype(int)
    
    print(f"Using well_gage data for well-gage relationships: {len(well_gage_pairs)} unique pairs")

    # Load watershed relationships
    try:
        terminal_relationships = pd.read_csv('../data/processed/terminal_gage_upstream_catchments.csv')
        print("✅ Loaded terminal relationships")
    except FileNotFoundError:
        print("⚠️ Terminal relationships file not found, using simplified approach")
        terminal_relationships = None

    print("✅ All datasets loaded")

    # Merge data using numeric IDs
    print("Merging data using numeric approach...")

    # Convert to numeric
    reach_distances['well_id_numeric'] = pd.to_numeric(reach_distances['Well_ID'], errors='coerce')
    well_locations['well_id_numeric'] = pd.to_numeric(well_locations['well_id'], errors='coerce')
    well_gage_pairs['well_id_numeric'] = pd.to_numeric(well_gage_pairs['well_id'], errors='coerce')

    # Merge step by step
    merge1 = pd.merge(reach_distances, well_locations, on='well_id_numeric', how='inner')
    print(f"After adding coordinates: {len(merge1)} wells")

    final_data = pd.merge(merge1, well_gage_pairs, on='well_id_numeric', how='inner')
    print(f"Final merged data: {len(final_data)} wells with gage relationships")

    if len(final_data) == 0:
        print("❌ No data after merging - cannot create maps")
        return

    # Calculate vertical distance (well elevation - reach elevation)
    # Convert GSE from feet to meters (1 ft = 0.3048 m)
    final_data['well_elev_m'] = final_data['gse'] * 0.3048
    final_data['reach_elev_m'] = final_data['Reach_Elevation'] * 0.3048
    final_data['vertical_distance_m'] = final_data['well_elev_m'] - final_data['reach_elev_m']

    print(f"Calculated vertical distances - Range: {final_data['vertical_distance_m'].min():.1f}m to {final_data['vertical_distance_m'].max():.1f}m")

    # Ensure CRS consistency
    subbasin_gdf = subbasin_gdf.to_crs("EPSG:4326")
    stream_gdf = stream_gdf.to_crs("EPSG:4326")
    lake_gdf = lake_gdf.to_crs("EPSG:4326")

    # Get unique gages (process all gages)
    unique_gages = final_data['gage_id'].dropna().unique()
    print(f"Creating maps for {len(unique_gages)} gages: {unique_gages}")

    # Create maps
    for gage_id in tqdm(unique_gages, desc="Creating styled maps"):
        try:
            gage_wells = final_data[final_data['gage_id'] == gage_id].copy()
            print(f"\nProcessing Gage {gage_id} with {len(gage_wells)} wells")

            if len(gage_wells) == 0:
                continue

            # Create the map
            create_single_styled_map_vertical(gage_id, gage_wells, subbasin_gdf, stream_gdf,
                                   lake_gdf, gage_df, terminal_relationships, save_dir)

        except Exception as e:
            print(f"❌ Error for gage {gage_id}: {e}")
            import traceback
            traceback.print_exc()

    print(f"✅ Maps saved to: {save_dir}")

def get_gage_watersheds_styled(gage_id, subbasin_gdf, gage_df, terminal_relationships=None):
    """Get terminal and upstream basins for a gage - styled version."""
    # Get gage info
    gage_info = gage_df[gage_df['id'] == gage_id]
    if gage_info.empty:
        return gpd.GeoDataFrame(geometry=[], crs=subbasin_gdf.crs), gpd.GeoDataFrame(geometry=[], crs=subbasin_gdf.crs)

    gage_lat = gage_info.iloc[0]['latitude']
    gage_lon = gage_info.iloc[0]['longitude']

    # Create gage point
    gage_point = Point(gage_lon, gage_lat)
    gage_gdf = gpd.GeoDataFrame([1], geometry=[gage_point], crs="EPSG:4326")
    gage_gdf = gage_gdf.to_crs(subbasin_gdf.crs)
    gage_point = gage_gdf.geometry.iloc[0]

    # Find terminal basin (containing the gage)
    containing = subbasin_gdf[subbasin_gdf.geometry.contains(gage_point)]
    if len(containing) > 0:
        terminal_basin = containing.iloc[[0]]
    else:
        # Fallback: find nearest basin
        distances = subbasin_gdf.geometry.distance(gage_point)
        nearest_idx = distances.idxmin()
        terminal_basin = subbasin_gdf.iloc[[nearest_idx]]

    # Get upstream basins if relationships are available
    upstream_basins = gpd.GeoDataFrame(geometry=[], crs=subbasin_gdf.crs)

    if terminal_relationships is not None:
        # Standardize column names
        if 'Gage_ID' in terminal_relationships.columns:
            gage_col = 'Gage_ID'
            upstream_col = 'Upstream_Catchment_ID'
        else:
            gage_col = 'gage_id'
            upstream_col = 'upstream_catchment_id'

        # Find linkno column in subbasin_gdf
        linkno_col = None
        for col in ['linkno', 'LINKNO', 'LinkNo', 'LINK_NO']:
            if col in subbasin_gdf.columns:
                linkno_col = col
                break

        if linkno_col and gage_col in terminal_relationships.columns:
            # Get upstream catchments for this gage
            upstream_catchments = terminal_relationships.loc[
                terminal_relationships[gage_col] == gage_id, upstream_col
            ].dropna().astype(int).tolist()

            if upstream_catchments:
                upstream_basins = subbasin_gdf[
                    subbasin_gdf[linkno_col].astype(int).isin(upstream_catchments)
                ].copy()

                # Remove terminal basin if it's included in upstream list
                if not terminal_basin.empty:
                    terminal_linkno = int(terminal_basin.iloc[0][linkno_col])
                    upstream_basins = upstream_basins[
                        upstream_basins[linkno_col].astype(int) != terminal_linkno
                    ]

    return terminal_basin, upstream_basins

def plot_overview_inset_styled(ax, subbasin_gdf, stream_gdf, lake_gdf,
                              terminal_basin, upstream_basins, gage_lon, gage_lat, gage_id):
    """Plot the basin-wide overview inset - styled version."""
    # All catchments
    subbasin_gdf.plot(ax=ax, color='#FAFAFA', edgecolor='#B0B0B0',
                      linewidth=0.6, alpha=0.9)

    # Lakes for context
    lake_gdf.plot(ax=ax, color='#D6EAF8', alpha=0.9,
                  edgecolor='#3498DB', linewidth=0.6)

    # Highlight watersheds
    if not terminal_basin.empty:
        terminal_basin.plot(ax=ax, color='#CD853F', alpha=0.8,
                            edgecolor='#8B7355', linewidth=1.5)
    if not upstream_basins.empty:
        upstream_basins.plot(ax=ax, color='#F5E6D3', alpha=0.7,
                             edgecolor='#D2B48C', linewidth=1.2)

    # Gage star (small)
    ax.scatter(gage_lon, gage_lat, color='#FFD700', marker='*', s=150,
               edgecolor='black', linewidth=1.5, zorder=5)

    # Full basin extent
    basin_bounds = subbasin_gdf.total_bounds
    ax.set_xlim(basin_bounds[0], basin_bounds[2])
    ax.set_ylim(basin_bounds[1], basin_bounds[3])

    # Style
    ax.set_title('Basin Overview', fontsize=10, fontweight='bold', pad=8)
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(2)
        spine.set_edgecolor('black')
    ax.set_aspect('equal')

def create_single_styled_map_vertical(gage_id, gage_wells, subbasin_gdf, stream_gdf, lake_gdf,
                           gage_df, terminal_relationships, save_dir):
    """Create a single map using the correlation map style, colored by vertical distance."""

    # Get gage info
    gage_info = gage_df[gage_df['id'] == gage_id]
    if gage_info.empty:
        print(f"No gage info for {gage_id}")
        return False

    gage_name = gage_info.iloc[0].get('name', f'Gage {gage_id}')
    gage_lat = gage_info.iloc[0]['latitude']
    gage_lon = gage_info.iloc[0]['longitude']

    # Get watersheds: terminal + upstream
    terminal_basin, upstream_basins = get_gage_watersheds_styled(
        gage_id, subbasin_gdf, gage_df, terminal_relationships
    )

    if terminal_basin.empty:
        print(f"No terminal basin found for gage {gage_id}")
        return False

    print(f"Found terminal basin and {len(upstream_basins)} upstream basins for gage {gage_id}")

    # Sort wells by vertical distance (highest above reach first)
    gage_wells = gage_wells.sort_values('vertical_distance_m', ascending=False)

    # Figure layout - SAME AS CORRELATION MAPS
    fig = plt.figure(figsize=(15, 10))
    fig.patch.set_facecolor('white')
    ax_main = plt.subplot2grid((10, 10), (0, 0), colspan=7, rowspan=10)
    ax_main.set_facecolor('white')

    # Plot watershed polygons - SAME COLORS AS CORRELATION MAPS
    if not terminal_basin.empty:
        terminal_basin.plot(ax=ax_main, color='#B8860B', alpha=0.8,
                            edgecolor='#8B7355', linewidth=2, zorder=1)
        print(f"✅ Plotted terminal basin for gage {gage_id}")

    if not upstream_basins.empty:
        upstream_basins.plot(ax=ax_main, color='#F5E6D3', alpha=0.6,
                             edgecolor='#D2B48C', linewidth=1.5, zorder=1)
        print(f"✅ Plotted {len(upstream_basins)} upstream basins for gage {gage_id}")

    # Clip streams/lakes by watershed extent
    all_basins = pd.concat([terminal_basin, upstream_basins], ignore_index=True)
    if not all_basins.empty:
        watershed_union = all_basins.unary_union
        local_streams = stream_gdf[stream_gdf.geometry.intersects(watershed_union)]
        local_lakes = lake_gdf[lake_gdf.geometry.intersects(watershed_union)]
    else:
        # Fallback: extent around the gage and wells
        all_lons = [gage_lon] + gage_wells['well_lon'].tolist()
        all_lats = [gage_lat] + gage_wells['well_lat'].tolist()
        min_lon, max_lon = min(all_lons), max(all_lons)
        min_lat, max_lat = min(all_lats), max(all_lats)
        buffer = max(max_lon - min_lon, max_lat - min_lat) * 0.3 if (max_lon > min_lon and max_lat > min_lat) else 0.1

        from shapely.geometry import box
        extent_box = box(min_lon - buffer, min_lat - buffer, max_lon + buffer, max_lat + buffer)
        local_streams = stream_gdf[stream_gdf.geometry.intersects(extent_box)]
        local_lakes = lake_gdf[lake_gdf.geometry.intersects(extent_box)]

    # Base layers - SAME AS CORRELATION MAPS
    if not local_streams.empty:
        local_streams.plot(ax=ax_main, color='#4A90E2', linewidth=1.5, alpha=0.8, zorder=2)
    if not local_lakes.empty:
        local_lakes.plot(ax=ax_main, color='#E6F3FF', alpha=0.8, edgecolor='#4A90E2', linewidth=0.8, zorder=2)

    # Plot wells colored by vertical distance
    if len(gage_wells) > 0:
        norm = plt.Normalize(
            vmin=gage_wells['vertical_distance_m'].min(),
            vmax=gage_wells['vertical_distance_m'].max()
        )

        scatter = ax_main.scatter(
            gage_wells['well_lon'],
            gage_wells['well_lat'],
            c=gage_wells['vertical_distance_m'],
            cmap='RdYlBu_r',  # Red-Yellow-Blue reversed (red=high, blue=low)
            s=90,
            alpha=0.9,
            edgecolor='black',
            linewidth=0.8,
            zorder=4,
            norm=norm
        )

    # Plot gage as large bright yellow star - SAME AS CORRELATION MAPS
    ax_main.scatter(gage_lon, gage_lat,
              color='#FFD700',  # Bright yellow/gold
              marker='*',       # Star
              s=500,           # Much larger size
              edgecolor='black',  # Black edge
              linewidth=2,
              zorder=10)

    # Set main map extent - SAME AS CORRELATION MAPS
    if not all_basins.empty:
        bounds = all_basins.total_bounds
        buffer = max(bounds[2] - bounds[0], bounds[3] - bounds[1]) * 0.1
        ax_main.set_xlim(bounds[0] - buffer, bounds[2] + buffer)
        ax_main.set_ylim(bounds[1] - buffer, bounds[3] + buffer)
    else:
        # Fallback to well and gage extent
        all_lons = list(gage_wells['well_lon']) + [gage_lon]
        all_lats = list(gage_wells['well_lat']) + [gage_lat]
        lon_buffer = (max(all_lons) - min(all_lons)) * 0.2
        lat_buffer = (max(all_lats) - min(all_lats)) * 0.2
        ax_main.set_xlim(min(all_lons) - lon_buffer, max(all_lons) + lon_buffer)
        ax_main.set_ylim(min(all_lats) - lat_buffer, max(all_lats) + lat_buffer)

    # Style main map - SAME AS CORRELATION MAPS
    ax_main.set_title(f'Wells by Vertical Distance:\n {gage_name}\nGage ID: {gage_id}',
                      fontsize=14, fontweight='bold', pad=15)
    ax_main.set_xlabel('Longitude', fontsize=11)
    ax_main.set_ylabel('Latitude', fontsize=11)
    for spine in ax_main.spines.values():
        spine.set_visible(False)
    ax_main.grid(True, alpha=0.3, linestyle='-', linewidth=0.5, color='lightgray')
    ax_main.tick_params(labelsize=9)

    # Inset - SAME AS CORRELATION MAPS
    ax_inset = plt.subplot2grid((10, 10), (0, 7), colspan=3, rowspan=3)
    ax_inset.set_facecolor('white')
    plot_overview_inset_styled(ax_inset, subbasin_gdf, stream_gdf, lake_gdf,
                               terminal_basin, upstream_basins, gage_lon, gage_lat, gage_id)

    # Colorbar - SAME LAYOUT AS CORRELATION MAPS
    if len(gage_wells) > 0:
        cbar_ax = plt.subplot2grid((10, 10), (4, 7), colspan=3, rowspan=1)
        cbar = fig.colorbar(scatter, cax=cbar_ax, orientation='horizontal')
        cbar.set_label('Vertical Distance (Well - Reach) [meters]', fontsize=10, fontweight='bold')
        cbar.ax.tick_params(labelsize=8)

    # Add statistics and legend - SAME STYLE AS CORRELATION MAPS
    add_vertical_distance_statistics_and_legend_styled(fig, ax_main, gage_wells, gage_id,
                                            terminal_basin, upstream_basins)

    # Save
    safe_name = gage_name.replace('/', '_').replace('\\', '_')[:40]
    filename = f"gage_{int(gage_id)}_vertical_distance_{safe_name}.png"
    filepath = os.path.join(save_dir, filename)
    plt.savefig(filepath, dpi=300, bbox_inches='tight',
                facecolor='white', edgecolor='none', pad_inches=0.1)
    plt.close()

    print(f"✅ Saved map for gage {gage_id}: {filename}")
    return True

def add_vertical_distance_statistics_and_legend_styled(fig, ax_main, gage_wells, gage_id,
                                            terminal_basin, upstream_basins):
    """Add statistics box and legend - SAME STYLE AS CORRELATION MAPS."""
    # Stats text
    watershed_info = ""
    if not terminal_basin.empty:
        watershed_info += f"Terminal Basin: 1\n"
    if not upstream_basins.empty:
        watershed_info += f"Upstream Basins: {len(upstream_basins)}\n"

    if not gage_wells.empty:
        min_vert_dist = gage_wells['vertical_distance_m'].min()
        max_vert_dist = gage_wells['vertical_distance_m'].max()
        mean_vert_dist = gage_wells['vertical_distance_m'].mean()
        wells_n = len(gage_wells)
    else:
        min_vert_dist = float('nan')
        max_vert_dist = float('nan')
        mean_vert_dist = float('nan')
        wells_n = 0

    stats_text = (f"{watershed_info}"
                  f"Wells: {wells_n}\n"
                  f"Min Vert Dist: {min_vert_dist:.1f}m\n"
                  f"Max Vert Dist: {max_vert_dist:.1f}m" if wells_n > 0 else f"{watershed_info}Wells: 0")

    if wells_n > 0:
        stats_text += f"\nMean Vert Dist: {mean_vert_dist:.1f}m"

    ax_main.text(0.02, 0.98, stats_text, transform=ax_main.transAxes,
                 bbox=dict(boxstyle="round,pad=0.4", facecolor='white',
                           alpha=0.95, edgecolor='black', linewidth=1),
                 verticalalignment='top', fontsize=10,
                 fontweight='bold', zorder=15)

    # Legend - SAME AS CORRELATION MAPS
    legend_elements = [
        plt.Line2D([0], [0], marker='*', color='w',
                   markerfacecolor='#FFD700', markersize=18,
                   markeredgecolor='black', markeredgewidth=1.5,
                   label=f'Gage {int(gage_id)}', linestyle='None'),
    ]

    if wells_n > 0:
        legend_elements.extend([
            plt.Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='red', markersize=8,
                       markeredgecolor='black', markeredgewidth=0.5,
                       label='Wells (higher above reach)', linestyle='None'),
            plt.Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='blue', markersize=8,
                       markeredgecolor='black', markeredgewidth=0.5,
                       label='Wells (lower above reach)', linestyle='None')
        ])

    if not terminal_basin.empty:
        legend_elements.append(
            plt.Rectangle((0, 0), 1, 1, facecolor='#B8860B', edgecolor='#8B7355',
                          alpha=0.8, label='Terminal Basin (Gage Location)')
        )
    if not upstream_basins.empty:
        legend_elements.append(
            plt.Rectangle((0, 0), 1, 1, facecolor='#F5E6D3', edgecolor='#D2B48C',
                          alpha=0.6, label='Upstream Basins')
        )

    ax_main.legend(handles=legend_elements, loc='lower left',
                   bbox_to_anchor=(0.02, 0.02), fontsize=9,
                   frameon=True, fancybox=False, shadow=False,
                   edgecolor='black', facecolor='white', framealpha=0.95)

# Run the vertical distance version
create_watershed_vertical_distance_maps_styled()

=== Creating Watershed Vertical Distance Maps (Styled Version) ===
Loading datasets...
Using well_gage data for well-gage relationships: 2880 unique pairs
✅ Loaded terminal relationships
✅ All datasets loaded
Merging data using numeric approach...
After adding coordinates: 8752 wells
Final merged data: 2880 wells with gage relationships
Calculated vertical distances - Range: 703.8m to 1898.8m
Creating maps for 10 gages: [10152000 10141000 10153100 10163000 10168000 10143500 10126000 10142000
 10172952 10172860]


Creating styled maps:   0%|          | 0/10 [00:00<?, ?it/s]


Processing Gage 10152000 with 92 wells
Found terminal basin and 109 upstream basins for gage 10152000
✅ Plotted terminal basin for gage 10152000
✅ Plotted 109 upstream basins for gage 10152000


Creating styled maps:  10%|█         | 1/10 [00:22<03:21, 22.44s/it]

✅ Saved map for gage 10152000: gage_10152000_vertical_distance_SPANISH FORK NEAR LAKE SHORE - UTAH.png

Processing Gage 10141000 with 732 wells
Found terminal basin and 358 upstream basins for gage 10141000
✅ Plotted terminal basin for gage 10141000
✅ Plotted 358 upstream basins for gage 10141000


Creating styled maps:  20%|██        | 2/10 [00:45<03:04, 23.06s/it]

✅ Saved map for gage 10141000: gage_10141000_vertical_distance_WEBER RIVER NEAR PLAIN CITY - UT.png

Processing Gage 10153100 with 16 wells
Found terminal basin and 16 upstream basins for gage 10153100
✅ Plotted terminal basin for gage 10153100
✅ Plotted 16 upstream basins for gage 10153100


Creating styled maps:  30%|███       | 3/10 [01:05<02:30, 21.52s/it]

✅ Saved map for gage 10153100: gage_10153100_vertical_distance_HOBBLE CREEK AT 1650 WEST AT SPRINGVILLE.png

Processing Gage 10163000 with 471 wells
Found terminal basin and 123 upstream basins for gage 10163000
✅ Plotted terminal basin for gage 10163000
✅ Plotted 123 upstream basins for gage 10163000


Creating styled maps:  40%|████      | 4/10 [01:28<02:11, 22.00s/it]

✅ Saved map for gage 10163000: gage_10163000_vertical_distance_PROVO RIVER AT PROVO - UT.png

Processing Gage 10168000 with 73 wells
Found terminal basin and 5 upstream basins for gage 10168000
✅ Plotted terminal basin for gage 10168000
✅ Plotted 5 upstream basins for gage 10168000


Creating styled maps:  50%|█████     | 5/10 [01:49<01:48, 21.61s/it]

✅ Saved map for gage 10168000: gage_10168000_vertical_distance_LITTLE COTTONWOOD CREEK @ JORDAN RIVER N.png

Processing Gage 10143500 with 20 wells
Found terminal basin and 0 upstream basins for gage 10143500
✅ Plotted terminal basin for gage 10143500


Creating styled maps:  60%|██████    | 6/10 [02:06<01:20, 20.00s/it]

✅ Saved map for gage 10143500: gage_10143500_vertical_distance_CENTERVILLE CREEK ABV. DIV NEAR CENTERVI.png

Processing Gage 10126000 with 1461 wells
Found terminal basin and 1260 upstream basins for gage 10126000
✅ Plotted terminal basin for gage 10126000
✅ Plotted 1260 upstream basins for gage 10126000


Creating styled maps:  70%|███████   | 7/10 [02:40<01:13, 24.65s/it]

✅ Saved map for gage 10126000: gage_10126000_vertical_distance_BEAR RIVER NEAR CORINNE - UT.png

Processing Gage 10142000 with 13 wells
Found terminal basin and 0 upstream basins for gage 10142000
✅ Plotted terminal basin for gage 10142000


Creating styled maps:  80%|████████  | 8/10 [02:57<00:44, 22.19s/it]

✅ Saved map for gage 10142000: gage_10142000_vertical_distance_FARMINGTON CR ABV DIV NR FARMINGTON - UT.png

Processing Gage 10172952 with 1 wells
Found terminal basin and 0 upstream basins for gage 10172952
✅ Plotted terminal basin for gage 10172952


Creating styled maps:  90%|█████████ | 9/10 [03:14<00:20, 20.68s/it]

✅ Saved map for gage 10172952: gage_10172952_vertical_distance_DUNN CREEK NEAR PARK VALLEY - UT.png

Processing Gage 10172860 with 1 wells
Found terminal basin and 11 upstream basins for gage 10172860
✅ Plotted terminal basin for gage 10172860
✅ Plotted 11 upstream basins for gage 10172860


Creating styled maps: 100%|██████████| 10/10 [03:34<00:00, 21.45s/it]

✅ Saved map for gage 10172860: gage_10172860_vertical_distance_WARM CREEK NEAR GANDY - UT.png
✅ Maps saved to: ../reports/figures/watershed_vertical_distance_maps_styled





# plot mutual information