# Introduction



# Steps
1- Release the particles (2_PT_seseflux)

2- Read the particles for each month

3- Reduce each particle track to its first intersection with the coastline, and delete the particles that do not interact with the coastline

4- Make a dataframe that only includes the particles' first intersection with the shoreline for each month

-(another thing we need to add is a group_number for the particles, in addition to the group_id, like 11, 12, 13; code 3)

# Import required Libraries

In [None]:
# Standard libraries
import os
import glob
import warnings
import gc

# Data manipulation and analysis libraries
import numpy as np
import pandas as pd
import dask.dataframe as dd
import dask.array as da
import xarray as xr
import geopandas as gpd
from shapely.geometry import Point
from netCDF4 import Dataset
# geopandas 
from shapely.geometry import Point
import geopandas as gpd

# Dask diagnostics and progress bar
from dask.diagnostics import ProgressBar

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import LinearSegmentedColormap, LogNorm
import cartopy.crs as ccrs
from pylag.processing.plot import FVCOMPlotter, create_figure, colourmap

# Helper Functions

In [None]:
# Define the function to sort the files based on the time
def sort_key(file):
    """Return the integer found after the last underscore of the file name.

    Only the base name of ``file`` is examined. The text between the final
    ``_`` and the following ``.`` is converted with ``int``. When that text is
    not an integer, ``float('inf')`` is returned so the file sorts last.
    """
    filename = os.path.basename(file)
    # Everything after the final underscore, cut at its first dot.
    tail = filename.rsplit('_', 1)[-1]
    digits = tail.partition('.')[0]
    try:
        return int(digits)
    except (IndexError, ValueError):
        # Name does not follow the "<prefix>_<number>.<ext>" pattern.
        return float('inf')


In [None]:
def sort_key(file):
    """Return the integer after the last underscore of the file name, with debug output.

    Prints the base name being inspected, then either the extracted number or
    the conversion error. Names without an integer between the final ``_`` and
    the next ``.`` return ``float('inf')`` so they sort after numbered files.
    """
    filename = os.path.basename(file)

    # Debug trace: show which name is being inspected
    print(f"🔍 Sorting file: {filename}")

    try:
        # Text after the final underscore, truncated at its first dot
        last_segment = filename.rsplit("_", 1)[-1]
        number = int(last_segment.partition('.')[0])
        print(f"✅ Extracted number: {number} from {filename}")
        return number
    except (IndexError, ValueError) as e:
        print(f"❌ ERROR processing filename '{filename}': {e}")
        # Unparseable names go to the end of the sort order
        return float('inf')


# Inputs

In [None]:
# Define paths for coastal wetland shapefiles
GIS_LAYERS = '/home/abolmaal/Data/GIS_layer'  # Add leading slash to make it an absolute path
CW_path = os.path.join(GIS_LAYERS, 'Coastalwetland/hitshoreline')

#CW_path = '/mnt/d/Users/abolmaal/Arcgis/NASAOceanProject/GIS_layer/Coastalwetland/hitshoreline/'

# Paths to specific coastal wetland shapefiles with different inundation levels that have 50 meter buffer from the shoreline
CW_avg_path = os.path.join(CW_path, 'Wetland_connected_avg_inundation_NAD1983_shorelineinteraction_50m_ExportFeatures.shp')
CW_low_path = os.path.join(CW_path, 'wetlands_connected_low_inundation_NAD1983_shorelineinteraction_50m_ExportFeatures.shp')
CW_high_path = os.path.join(CW_path, 'wetlands_connected_high_inundation_NAD1983_shorelineinteraction_50m_ExportFeatures.shp')
CW_surge_path = os.path.join(CW_path, 'wetlands_connected_surge_inundation_NAD1983_shorelineinteraction_50m_ExportFeatures.shp')

# Load coastal wetland shapefiles as GeoDataFrames
CW_avg = gpd.read_file(CW_avg_path)
CW_low = gpd.read_file(CW_low_path)
CW_high = gpd.read_file(CW_high_path)
CW_surge = gpd.read_file(CW_surge_path)

# Define the path to FVCOM model output files
data_dir = '/home/abolmaal/modelling/FVCOM/Huron/output'
files = glob.glob(os.path.join(data_dir, "updated_FVCOM_Huron_*.nc"))
files.sort(key=sort_key)

# Reading the Data

**If the time ranges in the FVCOM outputs overlap, we need to remove the overlapping times; if they do not, we can read the files with: datasets = xr.open_mfdataset(files, combine='by_coords', parallel=True)**

In [None]:
# Sort the file list by the number in each file name before iterating.
# NOTE: `files` is already sorted in the Inputs cell; this repeats the sort so
# the cell can be run on its own.
files.sort(key=sort_key)

# Print the first and last 'time' value of every file so overlapping time
# ranges between consecutive files can be spotted by eye.
for file in files:
    # The context manager closes each dataset once its range has been printed.
    with xr.open_dataset(file) as ds:
        if 'time' in ds.variables:
            print(f"File: {file} - Time range: {ds['time'].values[0]} to {ds['time'].values[-1]}")
        else:
            print(f"File: {file} - No 'time' variable found in the dataset.")

In [None]:
def remove_overlap(datasets):
    """Trim every dataset after the first so it starts at its second time value.

    The list is modified in place (positions 1 onward are replaced) and the
    same list object is returned. The first dataset is left untouched.
    """
    trimmed = []
    for later in datasets[1:]:
        second_step = later['time'][1]
        trimmed.append(later.sel(time=slice(second_step, None)))
    # Replace the tail only after every selection has succeeded.
    datasets[1:] = trimmed
    return datasets

# Open every FVCOM output file (in the sorted order of `files`) and trim the
# first time step from all but the first dataset.
# NOTE(review): the trim is unconditional; it is only correct if each file
# really repeats the last time step of the previous one. Confirm with the
# time ranges printed above.
# NOTE(review): these datasets are not closed anywhere in this cell.
datasets = [xr.open_dataset(file) for file in files]
datasets = remove_overlap(datasets)

# Concatenate along 'time'; from here on `datasets` is a single Dataset, not a list.
datasets = xr.concat(datasets, dim='time')

In [None]:
# Reading the FVCOM output files
# Open multiple NetCDF datasets with xarray, using chunks and parallel processing
#datasets = xr.open_mfdataset(files, combine='by_coords', parallel=True)

# Define the path to the FVCOM grid metrics file
grid_metrics_file_name = '/home/abolmaal/modelling/FVCOM/Huron/input/gridfile/grid_metrics_huron_senseflux_Seasonal.nc'


# Path to nutrient load data CSV and load the CSV
Direct_Nutrient_load = '/mnt/d/Users/abolmaal/Arcgis/NASAOceanProject/ZonalStats/StreamWatresheds_total_N_P.csv'
Direct_Nutrient = pd.read_csv(Direct_Nutrient_load)

# Check and rename columns in the CSV file to match NetCDF data requirements
required_columns = {'Group_id': 'group_id', 'WetLoad_TN_kgcellday': 'WetLoad_TN_kgcellday', 'WetLoad_TP_kgcellday': 'WetLoad_TP_kgcellday'}

# for col, new_col in required_columns.items():
# if col not in Direct_Nutrient.columns:
#         raise ValueError(f"CSV file is missing required column: '{col}'")
# Direct_Nutrient.rename(columns={'Group_id': 'group_id'}, inplace=True)


# Output

In [None]:
output_path = '/home/abolmaal/modelling/FVCOM/Huron/results'
output_figures = '/home/abolmaal/modelling/FVCOM/Huron/figures'

# Define the file names to the FVCOM model output
originalFVCOM = 'LakeHuronparticletracking_2023_original.shp'


# file name for Intersected FVCOM model output with coastal wetlands

Intersection_PTCW_Avg = 'Intersections_Avg_PTCW.csv'
Intersection_PTCW_high = 'Intersections_high_PTCW.csv'
Intersection_PTCW_low = 'Intersections_low_PTCW.csv'
Intersection_PTCW_surge = 'Intersections_surge_PTCW.csv'
###################################################################
GroupIdcount = 'group_id_counts.csv'
Non_Intersection_avg = 'Non_Intersection_avg.csv'
Non_Intersection_high = 'Non_Intersection_high.csv'
Non_Intersection_low = 'Non_Intersection_low.csv'
Non_Intersection_surge = 'Non_Intersection_surge.csv'

# Parameters

### Figure Parameters

In [None]:
# Figure parameters
# Custom colour ramps (pink and blue shades), each listed from light to dark
pink_shades = ['#fff5f7', '#ffebf0', '#ffd6e1', '#ffbfd4', '#ff99c1', '#ff6ea9', '#ff4c92', '#ff2171', '#b50d4e']
blue_shades = ['#f7fbff', '#deebf7', '#c6dbef', '#9ecae1', '#6baed6', '#4292c6', '#2171b5', '#084594']
# The same blue ramp from dark to light
blue_shades_reversed = blue_shades[::-1]

pink_cmap = LinearSegmentedColormap.from_list('custom_pink', pink_shades)
# Build the reversed map from the reversed list. It was previously built from
# `blue_shades`, so the "reversed" colormap was not reversed.
blue_cmap_reversed = LinearSegmentedColormap.from_list('custom_blue', blue_shades_reversed)

# Define a list of green shades for the colormap
green_shades =  ['#e0ffe0', '#b3ffb3', '#80ff80', '#4dff4d', '#00e600', '#00cc00', '#009900', '#006600', '#003300']
# Create a custom green colormap
green_cmap = LinearSegmentedColormap.from_list('custom_green', green_shades)

# Replace pink_cmap with viridis and plasma
viridis_cmap = plt.colormaps['viridis']  # Updated to use new interface
plasma_cmap = plt.colormaps['plasma']  # Updated to use new interface

# Set up plotting parameters
font_size = 15
cmap = plt.colormaps['hsv_r']  # Fixed: using an existing colormap (hsv_r)

# Extent of the plot
extents = np.array([275, 277.69, 43, 46.3], dtype=float)

extents_ausable = np.array([276.5, 276.8, 45, 45.5], dtype=float)

# Some parameters for the Zonal Stats Fields
# Fields to calculate / Direct delivery to Watersheds
fieldDirectTN = 'WetLoad_TN_kgcellday'
fieldDirectTP = 'WetLoad_TP_kgcellday'


# Main Functions 

## 1-Make a geodataframe for particle tracking output files

### Select the variables 

In [None]:
import geopandas as gpd
import numpy as np
import pandas as pd
from dask.delayed import delayed
import xarray as xr
from shapely.geometry import Point
import dask  # Import Dask to access dask.compute
from dask import compute  # Import the specific compute function
import gc  # To manually collect garbage and free up memory
from dask.distributed import Client


# Limit the number of workers and threads
client = Client(n_workers=2, threads_per_worker=2)  # Adjust as needed

print(client)
# Function to remove overlap in time
def remove_overlap(datasets):
    """Drop the leading time step of every dataset except the first, in place.

    Each dataset from index 1 onward is replaced by a selection whose ``time``
    slice begins at that dataset's second time value. Returns the same list.
    """
    tail = datasets[1:]
    shortened = [item.sel(time=slice(item['time'][1], None)) for item in tail]
    datasets[1:] = shortened
    return datasets

# Function to process a single file and convert to GeoDataFrame
@delayed
def process_file(file):
    """Read one particle-tracking file and return its records as a GeoDataFrame.

    Wrapped in ``dask.delayed``, so calling it only builds a task. When the
    task runs it selects time, group_id, group_number, longitude and latitude,
    shifts longitudes above 180 down by 360, flattens the data to a table,
    builds point geometries, and reprojects from EPSG:4326 to EPSG:3174.
    """
    with xr.open_dataset(file) as ds:
        subset = ds[['time', 'group_id', 'group_number', 'longitude', 'latitude']]

        # Shift longitudes above 180 down by 360
        lon = subset['longitude'].values
        subset['longitude'].values = np.where(lon > 180, lon - 360, lon)

        # Flatten to a table with the dataset's index turned into columns
        frame = subset.to_dataframe().reset_index()

        # Point geometry from the lon/lat columns
        points = gpd.GeoSeries.from_xy(frame['longitude'], frame['latitude'])
        geo_frame = gpd.GeoDataFrame(frame, geometry=points)

        # Declare the CRS as EPSG:4326, then reproject to EPSG:3174
        geo_frame.set_crs('EPSG:4326', inplace=True, allow_override=True)
        geo_frame = geo_frame.to_crs('EPSG:3174')

        # Drop the intermediates before returning to keep peak memory down
        del ds, subset, frame, lon, points
        gc.collect()

        return geo_frame

# Build and compute the delayed per-file tasks a few files at a time so that
# only one batch of GeoDataFrames is being produced at once.
batch_size = 2  # Number of files computed together
batches = [files[i:i + batch_size] for i in range(0, len(files), batch_size)]

# Collects the computed GeoDataFrames of every batch, in file order
all_gdfs = []

# Process each batch separately
for batch in batches:
    all_delayed = [process_file(file) for file in batch]
    batch_gdfs = compute(*all_delayed)  # Runs the delayed tasks; returns a tuple of results
    all_gdfs.extend(batch_gdfs)  # Append the GeoDataFrames from this batch

    # Manually trigger garbage collection after each batch to release memory
    gc.collect()

# Combine the GeoDataFrames from all batches into one with a fresh 0..n-1 index.
# NOTE(review): remove_overlap is not applied on this path, so a time step that
# is repeated at a file boundary would appear twice here. Confirm whether the
# files overlap.
final_gdf = pd.concat(all_gdfs, ignore_index=True)

# CRS is already set to EPSG:3174 inside process_file; kept for reference only:
#final_gdf.set_crs('EPSG:4326', inplace=True, allow_override=True)
#final_gdf = final_gdf.to_crs('EPSG:3174')

# Final GeoDataFrame with all datasets
print(final_gdf)


In [None]:
# print the crs of the final_gdf
print("Final GeoDataFrame CRS:", final_gdf.crs)

In [None]:
# Suppress FutureWarnings related to pandas unique
warnings.filterwarnings("ignore", category=FutureWarning, module="xarray")

# Select relevant variables: time, group_id, group_number, longitude, latitude
selected_vars = datasets[['time', 'group_id', 'group_number','longitude', 'latitude']]

# Stack across 'time' and 'particles', dropping unwanted dimensions
#stacked_data = selected_vars.stack(particle_time=('time', 'particles')).drop_dims('dim_0', errors='ignore')

# Stack across 'time' and 'particles', dropping unwanted dimensions
stacked_vars = selected_vars.stack(particle_time=('time', 'particles'))
# Convert longitudes greater than 180 to the range -180 to 180

# If longitude values are greater than 180, subtract 360 to convert them to the range -180 to 180
selected_vars['longitude'].values = np.where(selected_vars['longitude'].values > 180, 
                                             selected_vars['longitude'].values - 360, 
                                             selected_vars['longitude'].values)


### Convert selected variable to pandas dataframe

In [None]:
selected_vars

### Convert dataframe to geodataframe


In [None]:
import multiprocessing as mp
from multiprocessing import pool

In [None]:
# Convert your Pandas DataFrame to a Dask DataFrame
PT_ddf = dd.from_pandas(PT_df, npartitions=2)  # Partition based on the number of cores

# Function to create geometry using GeoSeries.from_xy (efficient and vectorized)
def create_geometry(df):
    return gpd.GeoSeries.from_xy(df['longitude'], df['latitude'])

# Apply the function to create the geometry in each partition
PT_ddf['geometry'] = PT_ddf.map_partitions(create_geometry)

# Convert Dask DataFrame to GeoDataFrame (compute the result)
PT_gdf = gpd.GeoDataFrame(PT_ddf.compute(), geometry='geometry')

# Set CRS to EPSG:4326
PT_gdf.set_crs('EPSG:4326', inplace=True)

# Convert CRS to EPSG:3174 for Great Lakes Albers
PT_gdf = PT_gdf.to_crs('EPSG:3174')

In [None]:
# Check that longitude and latitude are valid
PT_df = PT_df.dropna(subset=['longitude', 'latitude'])

# Create geometry from valid points
geometry = [Point(xy) for xy in zip(PT_df['longitude'], PT_df['latitude'])]

# Create GeoDataFrame
PT_gdf = gpd.GeoDataFrame(PT_df, geometry=geometry)

# Set CRS to EPSG:4326
PT_gdf.set_crs('EPSG:4326', inplace=True)

# Convert CRS to EPSG:3174 for Great Lakes Albers
PT_gdf.to_crs('EPSG:3174', inplace=True)

In [None]:
PT_gdf = final_gdf.copy()

In [None]:
# Ensure the 'time' column is in datetime format
PT_df['time'] = pd.to_datetime(PT_df['time'], errors='coerce')

# Extract the month from the 'time' column
PT_df['month'] = PT_df['time'].dt.month

# Convert 'group_id' to integer (ensure the 'group_id' column exists and is valid)
PT_df['group_id'] = PT_df['group_id'].astype(int)

# Count the unique 'group_id' in each month
group_id_counts = PT_df.groupby('month')['group_id'].nunique().reset_index(name='unique_group_count')

# Create a DataFrame with all months (1 through 12)
all_months = pd.DataFrame({'month': range(1, 13)})

# Merge the result with all months to ensure all months are shown
group_id_counts_full = pd.merge(all_months, group_id_counts, on='month', how='left').fillna({'unique_group_count': 0})

# Print the result
print(group_id_counts_full)

In [None]:
# print the number of particles in each month
PT_df['time'] = PT_df.index.get_level_values('time')
PT_df['time'] = PT_df['time'].dt.month
PT_df['time'].value_counts()
print(PT_df['time'].value_counts())

In [None]:
#set the crs of cw_avg to 3174
CW_avg.to_crs('EPSG:3174', inplace=True)
CW_low.to_crs('EPSG:3174', inplace=True)
CW_high.to_crs('EPSG:3174', inplace=True)
CW_surge.to_crs('EPSG:3174', inplace=True)

In [None]:
print(CW_avg.crs)
print(CW_low.crs)
print(CW_high.crs)
print(CW_surge.crs)


## Find intersection of particle tracking with coastal wetlands

In [None]:
def calculate_first_intersections(gdf, CW_avg):
    """
    Find, month by month, the first intersection of each particle group with a wetland layer,
    and compute the percentage of tracked groups that intersect.

    Months are processed in chronological order. Within a month the intersecting
    records are ordered by 'group_id', 'group_number' and 'time', and the first record
    of every 'group_id' is kept. The ('group_id', 'group_number') pairs that were kept
    are removed from the data before the next month is processed, so the same pair is
    never reported twice.

    Parameters:
    - gdf: GeoDataFrame containing particle data with a datetime 'time' column (or index level)
      and 'group_id' and 'group_number' columns.
    - CW_avg: GeoDataFrame of the wetland polygons to find intersections with.

    Returns:
    - first_intersections: DataFrame with one row per 'group_id' for every month in which it
      intersected (empty DataFrame when nothing intersects).
    - percentage_intersecting: float, 100 * (groups with an intersection, summed over months)
      / (groups tracked, summed over months); 0.0 when nothing was tracked.
    """

    # Step 1: Ensure 'time' is a column rather than an index level
    if 'time' in gdf.index.names:
        gdf = gdf.reset_index(drop=False)

    # Step 2: Month of every record, kept beside the frame so the input is not modified.
    # The full frame is not sorted up front; only the (much smaller) set of
    # intersecting records is ordered below.
    months = gdf['time'].dt.to_period('M')

    monthly_first = []
    total_particles_tracked = 0
    total_particles_intersecting = 0

    # Step 3: Walk the months chronologically so that "first" really means earliest
    for month in sorted(months.dropna().unique()):
        in_month = (months == month).to_numpy()
        monthly_gdf = gdf[in_month]

        # Count the groups tracked in this month
        total_particles_tracked += monthly_gdf['group_id'].nunique()

        # Spatial join: keep only the records that touch a wetland polygon
        monthly_intersections = gpd.sjoin(monthly_gdf, CW_avg, how='inner', predicate='intersects')
        if monthly_intersections.empty:
            continue

        # Order by time within each particle so .first() picks its earliest record
        monthly_intersections = monthly_intersections.sort_values(
            by=['group_id', 'group_number', 'time'])

        # NOTE(review): grouping is by 'group_id' only, while removal below is per
        # ('group_id', 'group_number') pair. Confirm this matches the intended
        # definition of a particle.
        first_month_intersections = monthly_intersections.groupby('group_id').first().reset_index()
        monthly_first.append(first_month_intersections)

        # Step 4: Count the groups that intersected this month
        particles_with_intersection = first_month_intersections[['group_id', 'group_number']]
        total_particles_intersecting += particles_with_intersection['group_id'].nunique()

        # Remove the pairs that have now intersected so they are not reported again
        # in a later month.
        all_pairs = pd.MultiIndex.from_frame(gdf[['group_id', 'group_number']])
        hit_pairs = pd.MultiIndex.from_frame(particles_with_intersection)
        still_tracked = ~all_pairs.isin(hit_pairs)
        gdf = gdf[still_tracked]
        months = months[still_tracked]

    # Step 5: Combine the monthly results once, instead of growing a frame in the loop
    if monthly_first:
        first_intersections = pd.concat(monthly_first, ignore_index=True)
    else:
        first_intersections = pd.DataFrame()

    # Step 6: Calculate the percentage of groups that intersect for the first time
    if total_particles_tracked > 0:
        percentage_intersecting = (total_particles_intersecting / total_particles_tracked) * 100
        print(f"Percentage of particles intersecting for the first time: {percentage_intersecting:.2f}%")
    else:
        percentage_intersecting = 0.0

    return first_intersections, percentage_intersecting

## Calculate Monthly Intersections Separately

In [None]:
import pandas as pd
import geopandas as gpd

def calculate_monthly_first_intersections(gdf, CW_avg, output_csv="monthly_first_intersections.csv"):
    """
    For every calendar month in the dataset, keep the earliest wetland intersection of each
    particle, save all of them to a single CSV file, and calculate the percentage of that
    month's particles that intersected.

    Each month is evaluated independently: a particle (one 'group_id' / 'group_number'
    pair) that intersects in several months contributes one row per month. The input
    GeoDataFrame is not modified.

    Parameters:
    - gdf: GeoDataFrame with 'time', 'group_id', 'group_number', 'geometry'
      ('time' may also be an index level)
    - CW_avg: GeoDataFrame representing the wetland shapefile
    - output_csv: Optional output CSV filename

    Returns:
    - monthly_first_intersections_df: DataFrame of first intersections per particle and
      month (empty when nothing intersects)
    - summary_stats_df: DataFrame with 'month' and 'percentage' columns; months without
      any intersection are not listed
    - (None, None) when there is no 'time' column or index level
    """

    # ✅ Step 1: Reset index if needed
    if 'time' in gdf.index.names:
        print("🔄 Resetting 'time' from index to column...")
        gdf = gdf.reset_index()

    if 'time' not in gdf.columns:
        print("❌ ERROR: 'time' column is missing!")
        return None, None

    # ✅ Ensure CRS matches
    if gdf.crs != CW_avg.crs:
        print("🔄 Reprojecting gdf to match CW_avg CRS...")
        gdf = gdf.to_crs(CW_avg.crs)

    # ✅ Parsed timestamps and months are kept beside the frame instead of being
    # written into it, so the caller's GeoDataFrame keeps its original columns.
    times = pd.to_datetime(gdf['time'], errors='coerce')
    months = times.dt.to_period("M")

    # ✅ Precompute wetland union geometry
    wetland_union = CW_avg.geometry.union_all()

    # ✅ Initialize storage
    monthly_results = []
    summary_stats = []

    # ✅ Loop over each unique month (unparseable timestamps are skipped)
    for month in sorted(months.dropna().unique()):
        print(f"📅 Processing {month}...")

        in_month = (months == month).to_numpy()
        monthly_gdf = gdf[in_month]
        total_particles = monthly_gdf[['group_id', 'group_number']].drop_duplicates().shape[0]

        hits = monthly_gdf['geometry'].intersects(wetland_union).to_numpy()
        if not hits.any():
            print(f"⚠️ {month}: No intersections found.")
            continue

        # Only the intersecting rows are copied; they receive the parsed timestamps
        intersecting_particles = monthly_gdf[hits].assign(time=times[in_month][hits])

        # Earliest intersecting record of every particle in this month
        first_intersections = (
            intersecting_particles
            .sort_values(by=['group_id', 'group_number', 'time'])
            .groupby(['group_id', 'group_number'])
            .first()
            .reset_index()
        )

        total_first = first_intersections.shape[0]
        percentage = (total_first / total_particles) * 100 if total_particles > 0 else 0

        print(f"✅ {month}: {total_first} particles first-time intersected ({percentage:.2f}%)")

        first_intersections['month'] = str(month)
        monthly_results.append(first_intersections)
        summary_stats.append({'month': str(month), 'percentage': percentage})

    # ✅ Combine all results
    if monthly_results:
        monthly_first_intersections_df = pd.concat(monthly_results, ignore_index=True)
    else:
        monthly_first_intersections_df = pd.DataFrame()

    # ✅ Convert summary list to DataFrame; the columns exist even when the list is empty
    summary_stats_df = pd.DataFrame(summary_stats, columns=['month', 'percentage'])

    # ✅ Save to CSV
    monthly_first_intersections_df.to_csv(output_csv, index=False)
    print(f"✅ Monthly first-time intersections saved to {output_csv}")

    # ✅ Return both
    return monthly_first_intersections_df, summary_stats_df


In [None]:
intersections_df, monthly_percentages_df = calculate_monthly_first_intersections(PT_gdf, CW_avg)


In [None]:
import matplotlib.pyplot as plt

def plot_monthly_intersection_barchart(summary_df, title="Monthly % of First-Time Intersections", save_path=None):
    """
    Creates a bar chart showing the percentage of particles that intersected
    the coastal wetland for the first time each month, saves it and shows it.

    Parameters:
    - summary_df: DataFrame with 'month' and 'percentage' columns
    - title: Title for the plot
    - save_path: Optional path of the image file to write. Defaults to
      'Monthly_First_Time_Intersections_maytoOct_2.png' inside `output_figures`.

    Returns:
    - None (displays the plot); nothing is drawn when summary_df is None or empty
    """
    # The producer of summary_df can return None or an empty frame
    if summary_df is None or summary_df.empty:
        print("No monthly intersection percentages to plot.")
        return

    if save_path is None:
        save_path = os.path.join(output_figures, 'Monthly_First_Time_Intersections_maytoOct_2.png')

    # Ensure the data is sorted by month
    summary_df = summary_df.sort_values(by="month")

    # Plot setup
    fig, ax = plt.subplots(figsize=(12, 6))
    bars = ax.bar(summary_df["month"], summary_df["percentage"], color='skyblue', edgecolor='black')

    # Annotate each bar with the percentage, slightly above its top edge
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2, height + 0.5, f"{height:.1f}%",
                ha='center', va='bottom', fontsize=10)

    ax.tick_params(axis='x', labelrotation=45)
    ax.set_ylabel("Intersection Percentage (%) ", fontsize=12)
    ax.set_xlabel("Month", fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(axis='y', linestyle='--', alpha=0.6)
    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()



In [None]:
plot_monthly_intersection_barchart(monthly_percentages_df)


### Average Inundation

In [None]:
first_intersections_avgInun, percentage_intersecting = calculate_first_intersections(PT_gdf, CW_avg)


In [None]:
# calculate_monthly_first_intersections returns (intersections, summary). Unpack
# it so that monthly_first_intersections_avg is the DataFrame that is later
# written with .to_csv, not a tuple.
monthly_first_intersections_avg, monthly_percentages_avg = calculate_monthly_first_intersections(PT_gdf, CW_avg, "monthly_first_intersections.csv")


In [None]:
# Count the particle records per calendar month without modifying PT_gdf.
# Overwriting PT_gdf['time'] with month numbers would break the later
# calculate_first_intersections calls, which use gdf['time'].dt.
if 'time' in PT_gdf.index.names:
    particle_times = PT_gdf.index.get_level_values('time')
else:
    particle_times = PT_gdf['time']
particle_months = pd.Series(pd.DatetimeIndex(particle_times).month, name='month')
print(particle_months.value_counts().sort_index())

### High Inundation

In [None]:
first_intersections_HighInun, percentage_intersecting = calculate_first_intersections(PT_gdf, CW_high)


### Low Inundation

In [None]:
first_intersections_LowInun, percentage_intersecting = calculate_first_intersections(PT_gdf, CW_low)


### Surge Inundation

In [None]:
first_intersections_SurgeInun, percentage_intersecting = calculate_first_intersections(PT_gdf, CW_surge)


# Save Outputs

In [None]:
# Optional: Save the first_intersections and group_id counts to CSV files
first_intersections_avgInun.to_csv(os.path.join(output_path, Intersection_PTCW_Avg), index=False)

first_intersections_HighInun.to_csv(os.path.join(output_path, Intersection_PTCW_high), index=False)

first_intersections_LowInun.to_csv(os.path.join(output_path, Intersection_PTCW_low), index=False)

first_intersections_SurgeInun.to_csv(os.path.join(output_path, Intersection_PTCW_surge), index=False)

In [None]:
monthly_first_intersections_avg.to_csv(os.path.join(output_path,'monthly_first_intersections_avg.csv'), index=False)

In [None]:
# read the csv files from output path
first_intersections_avgInun = pd.read_csv(os.path.join(output_path, Intersection_PTCW_Avg))
first_intersections_highInun = pd.read_csv(os.path.join(output_path, Intersection_PTCW_high))
first_intersections_lowInun = pd.read_csv(os.path.join(output_path, Intersection_PTCW_low))
first_intersections_surgeInun = pd.read_csv(os.path.join(output_path, Intersection_PTCW_surge))


In [None]:
# group_id_counts.to_csv(os.path.join(output_path, GroupIdcount), index=False)

# # Optional: Save the filtered GeoDataFrame (after removing non-intersecting particles)
# gdf.to_csv(os.path.join(output_path, Non_Intersection), index=False)

# print("First intersections, group_id counts, filtered particles, and percentage calculation saved.")

# 3-Add NP load to the particles

#### You can look at DirectNutrientload Directory, ZonalStats code to see how we obtain Zonal Stats 

## Function to merge Intersection with NP Load

In [None]:
def merged_ZonalStats(first_intersections_avgInun, Direct_Nutrient, tn_field=None, tp_field=None):
    """
    Merges particle tracking data with nutrient load data and shares each group's
    nutrient load equally between the rows of that group.

    Parameters:
    - first_intersections_avgInun: DataFrame containing particle tracking data with a 'group_id' column.
    - Direct_Nutrient: DataFrame containing nutrient load data with 'group_id' and the two load columns.
    - tn_field: Name of the nitrogen load column; defaults to the module-level `fieldDirectTN`.
    - tp_field: Name of the phosphorus load column; defaults to the module-level `fieldDirectTP`.

    Returns:
    - merged_data: left merge of the particle rows with the nutrient table, in which both
      load columns are divided by the number of rows the group has in
      first_intersections_avgInun. Rows whose group has no nutrient record keep NaN loads.
    """
    if tn_field is None:
        tn_field = fieldDirectTN
    if tp_field is None:
        tp_field = fieldDirectTP

    # Step 1: Number of particle rows per group_id
    group_id_counts = first_intersections_avgInun['group_id'].value_counts()

    # Step 2: Merge particle data with nutrient load data (every particle row is kept)
    merged_data = pd.merge(first_intersections_avgInun, Direct_Nutrient, on='group_id', how='left')

    # Step 3: Divide each load column by its group's row count, exactly once.
    # The nitrogen column used to be divided twice, giving load / count**2.
    rows_per_group = merged_data['group_id'].map(group_id_counts)
    for field in dict.fromkeys((tn_field, tp_field)):
        merged_data[field] = merged_data[field] / rows_per_group

    return merged_data


## merging the Intersections for different lake levels

In [None]:
# Merge particle tracking data_average inundation with nutrient load data and adjust nutrient loads
merged_data_avg = merged_ZonalStats(first_intersections_avgInun, Direct_Nutrient)


In [None]:
# Merge particle tracking data_high inundation with nutrient load data and adjust nutrient loads
merged_data_high = merged_ZonalStats(first_intersections_highInun, Direct_Nutrient)

In [None]:
# Merge particle tracking data_low inundation with nutrient load data and adjust nutrient loads
merged_data_low = merged_ZonalStats(first_intersections_lowInun, Direct_Nutrient)

In [None]:
# Merge particle tracking data_surge inundation with nutrient load data and adjust nutrient loads
merged_data_surge = merged_ZonalStats(first_intersections_surgeInun, Direct_Nutrient)

# Plots

### Plot particles return to Coastal Wetlands

In [None]:
def plot_combined_wetload_distribution(
    merged_data, 
    grid_metrics_file_name, 
    extents, 
    cmap_tn, 
    cmap_tp, 
    font_size=15, 
    title="Combined Wetload Distribution"
):
    """
    Creates a combined hexbin plot showing the nitrogen and phosphorus load
    (columns `fieldDirectTN` and `fieldDirectTP`) over a bathymetry background.

    Parameters:
    - merged_data: DataFrame containing 'longitude', 'latitude' and the two load columns.
    - grid_metrics_file_name: Path to the NetCDF grid metrics file; its 'h' variable is plotted as depth.
    - extents: Geographic extents passed to the plotter [xmin, xmax, ymin, ymax].
    - cmap_tn: Colormap for the nitrogen load layer.
    - cmap_tp: Colormap for the phosphorus load layer.
    - font_size: Integer representing the font size in the plot.
    - title: Optional string to set a custom plot title.

    Returns:
    - None; displays a combined plot.
    """

    # Collect coordinates and wetload data for plotting
    all_coords = np.array(list(zip(merged_data['longitude'], merged_data['latitude'])))
    wetload_tn = merged_data[fieldDirectTN].values
    wetload_tp = merged_data[fieldDirectTP].values

    # Create the figure and axis using the FVCOM plotter
    fig, ax = create_figure(figure_size=(26., 26.), projection=ccrs.PlateCarree(), font_size=font_size, bg_color='gray')

    # Load bathymetry data from NetCDF; 'h' is negated so depths plot as negative values
    with Dataset(grid_metrics_file_name, 'r') as ds:
        bathy = -ds.variables['h'][:]

    # Configure plotter and plot bathymetry
    plotter = FVCOMPlotter(grid_metrics_file_name, geographic_coords=True, font_size=font_size)
    ax, plot = plotter.plot_field(ax, bathy, extents=extents, add_colour_bar=True, cb_label='Depth (m)', vmin=-60., vmax=0., cmap='Blues', zorder=0)
    plotter.draw_grid(ax, linewidth=1.0)

    # Hexbin layer for the nitrogen load, summed per cell on a log colour scale.
    # The caller's cmap_tn is honoured (it used to be replaced by 'viridis').
    # NOTE(review): gridsize is 800 here but 80 for the phosphorus layer; confirm
    # that the difference is intended.
    hb_tn = ax.hexbin(
        all_coords[:, 0], 
        all_coords[:, 1], 
        C=wetload_tn, 
        gridsize=800, 
        cmap=cmap_tn, 
        norm=LogNorm(), 
        reduce_C_function=np.sum, 
        zorder=40, 
        alpha=0.6, 
        label="WetLoad_TN",
        color = 'green'
    )

    # Hexbin layer for the phosphorus load, summed per cell on a log colour scale
    hb_tp = ax.hexbin(
        all_coords[:, 0], 
        all_coords[:, 1], 
        C=wetload_tp, 
        gridsize=80, 
        cmap=cmap_tp, 
        norm=LogNorm(), 
        reduce_C_function=np.sum, 
        zorder=40, 
        alpha=0.6, 
        label="WetLoad_TP",
        color = 'pink'
    )

    # Add a labelled colour bar per load so the two scales can be told apart
    cbar_tn = fig.colorbar(hb_tn, ax=ax, pad=0.05, fraction=0.05)
    cbar_tn.set_label('WetLoad_TN', fontsize=font_size)

    cbar_tp = fig.colorbar(hb_tp, ax=ax, pad=0.14, fraction=0.05)
    cbar_tp.set_label('WetLoad_TP', fontsize=font_size)

    # Set axis labels and title
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.set_title(title, fontsize=font_size)

    # Show legend
    ax.legend(loc='upper right')

    # Show the plot
    plt.show()


In [None]:
def plot_wetload_distribution_TN_with_three_insets(
    merged_data,
    grid_metrics_file_name,
    main_extents,
    zoom_extents_1,
    zoom_extents_2,
    zoom_extents_3,
    color_map,
    font_size=15,
    title=None,
    inset_position_1=(0.6, 0.6, 0.25, 0.25),
    inset_position_2=(0.2, 0.6, 0.25, 0.25),
    inset_position_3=(0.4, 0.2, 0.25, 0.25),
    colorbar_outside=True,
    output_file=None,
):
    """
    Plot the 'WetLoad_TN_kgcellday' column as a log-scaled hexbin map over
    bathymetry, with three zoomed-in inset maps, then save and show the figure.

    NOTE(review): a later cell defines a function with this same name, so when
    the notebook is run top to bottom that later definition is the one used.

    Parameters:
    - merged_data: table providing the 'longitude', 'latitude' and
      'WetLoad_TN_kgcellday' columns.
    - grid_metrics_file_name: path to the NetCDF grid metrics file; its 'h'
      variable is negated and drawn as the bathymetry background.
    - main_extents: extents of the main map, passed to FVCOMPlotter.plot_field.
    - zoom_extents_1/2/3: [xmin, xmax, ymin, ymax] of each inset; also used to
      draw the highlight rectangle on the main map.
    - color_map: colormap for the hexbin layer.
    - font_size: base font size; the figure title uses font_size + 5 and the
      inset titles font_size - 4.
    - title: figure title; a default nitrogen title is used when empty/None.
    - inset_position_1/2/3: (left, bottom, width, height) of each inset in
      figure-fraction coordinates.
    - colorbar_outside: selects between the two colour-bar placements.
    - output_file: path of the saved PNG. Defaults to
      output_figures + '/WetLoadDistribution_AvgInun_Nitrogen_Zoombox.png'.

    Returns:
    - None; saves the figure and displays it.
    """

    # Coordinates and load values for the hexbin layers
    lon = np.asarray(merged_data['longitude'])
    lat = np.asarray(merged_data['latitude'])
    wetload_tn = merged_data['WetLoad_TN_kgcellday'].values

    # Create the figure and main axis
    fig, ax_main = plt.subplots(figsize=(26., 26.), subplot_kw={'projection': ccrs.PlateCarree()})
    fig.suptitle(title if title else 'Indirect Annual Nitrogen Load to Coastal Wetlands', fontsize=font_size + 5)

    # Load bathymetry data from NetCDF (sign flipped so depths are negative)
    with Dataset(grid_metrics_file_name, 'r') as ds:
        bathy = -ds.variables['h'][:]

    # Configure plotter for bathymetry
    plotter = FVCOMPlotter(grid_metrics_file_name, geographic_coords=True, font_size=font_size)

    # Main map: bathymetry, grid, then the summed-per-bin load on a log colour scale
    ax_main, _ = plotter.plot_field(
        ax_main, bathy, extents=main_extents, add_colour_bar=True, cb_label='Depth(m)', vmin=-60., vmax=0., cmap='Blues', zorder=0
    )
    plotter.draw_grid(ax_main, linewidth=0.5)
    hb_main = ax_main.hexbin(
        lon,
        lat,
        C=wetload_tn,
        gridsize=50,
        cmap=color_map,
        norm=LogNorm(),
        reduce_C_function=np.sum,
        zorder=40
    )

    # Colour bar for the main map
    if colorbar_outside:
        cbar_main = fig.colorbar(hb_main, ax=ax_main, pad=0.1)
    else:
        cbar_main = fig.colorbar(hb_main, ax=ax_main, shrink=0.8, location="right", pad=0.02)
    cbar_main.set_label('Nitrogen Load (kg/cell/day)', fontsize=font_size)
    ax_main.set_xlabel('Longitude', fontsize=font_size)
    ax_main.set_ylabel('Latitude', fontsize=font_size)

    # One inset per (position, extent, colour) triple
    inset_positions = [inset_position_1, inset_position_2, inset_position_3]
    zoom_extents = [zoom_extents_1, zoom_extents_2, zoom_extents_3]
    colors = ['red', 'blue', 'green']

    for i, (inset_pos, zoom_extent, color) in enumerate(zip(inset_positions, zoom_extents, colors), start=1):
        # Add the inset to this figure explicitly rather than to the implicit current figure
        ax_inset = fig.add_axes(inset_pos, projection=ccrs.PlateCarree())
        ax_inset.set_extent(zoom_extent, crs=ccrs.PlateCarree())

        # Inset bathymetry and grid
        ax_inset, _ = plotter.plot_field(
            ax_inset, bathy, extents=zoom_extent, add_colour_bar=False, cb_label=None, vmin=-60., vmax=0., cmap='Blues', zorder=0
        )
        plotter.draw_grid(ax_inset, linewidth=0.5)

        # Same hexbin layer as the main map, viewed through the inset extent
        ax_inset.hexbin(
            lon,
            lat,
            C=wetload_tn,
            gridsize=50,
            cmap=color_map,
            norm=LogNorm(),
            reduce_C_function=np.sum,
            zorder=40
        )

        # Inset title shares its colour with the highlight rectangle below
        ax_inset.set_title(f"Enlarge view {i}", fontsize=font_size - 4, color=color, fontweight='bold')

        # Outline the zoomed region on the main map
        rect = plt.Rectangle(
            (zoom_extent[0], zoom_extent[2]),
            zoom_extent[1] - zoom_extent[0],
            zoom_extent[3] - zoom_extent[2],
            linewidth=2, edgecolor=color, facecolor='none', transform=ccrs.PlateCarree(), zorder=50
        )
        ax_main.add_patch(rect)

    # Finalise the layout BEFORE saving so the file on disk matches the displayed figure
    if output_file is None:
        output_file = output_figures + '/WetLoadDistribution_AvgInun_Nitrogen_Zoombox.png'
    plt.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()




In [None]:
def plot_wetload_distribution_TN_with_three_insets(
    merged_data,
    grid_metrics_file_name,
    main_extents,
    zoom_extents_1,
    zoom_extents_2,
    zoom_extents_3,
    color_map,
    font_size=15,
    title=None,
    inset_position_1=(0.30, 0.50, 0.25, 0.25),  # (left, bottom, width, height)
    inset_position_2=(0.1, 0.40, 0.18, 0.18),  # (left, bottom, width, height)
    inset_position_3=(0.50, 0.25, 0.18, 0.18),  # (left, bottom, width, height)
    colorbar_outside=True,
    output_file=None,
):
    """
    Plot the 'WetLoad_TN_kgcellday' column as a log-scaled hexbin map over
    bathymetry, with three zoomed-in inset maps, then save and show the figure.

    Each zoomed region is outlined on the main map, and each inset is framed,
    in the same colour as that inset's title (red, blue, green in order).

    Parameters:
    - merged_data: table providing the 'longitude', 'latitude' and
      'WetLoad_TN_kgcellday' columns.
    - grid_metrics_file_name: path to the NetCDF grid metrics file; its 'h'
      variable is negated and drawn as the bathymetry background.
    - main_extents: extents of the main map, passed to FVCOMPlotter.plot_field.
    - zoom_extents_1/2/3: [xmin, xmax, ymin, ymax] of each inset; also used to
      draw the highlight rectangle on the main map.
    - color_map: colormap for the hexbin layer.
    - font_size: base font size; titles, labels and ticks are offset from it.
    - title: figure title; a default nitrogen title is used when empty/None.
    - inset_position_1/2/3: (left, bottom, width, height) of each inset in
      figure-fraction coordinates.
    - colorbar_outside: selects between the two colour-bar placements.
    - output_file: path of the saved PNG. Defaults to
      output_figures + '/WetLoadDistribution_AvgInun_Nitrogen_Zoombox.png';
      pass a distinct path per dataset to avoid overwriting earlier figures.

    Returns:
    - None; saves the figure and displays it.
    """

    # Coordinates and load values for the hexbin layers
    lon = np.asarray(merged_data['longitude'])
    lat = np.asarray(merged_data['latitude'])
    wetload_tn = merged_data['WetLoad_TN_kgcellday'].values

    # Create the figure and main axis
    fig, ax_main = plt.subplots(figsize=(26., 26.), subplot_kw={'projection': ccrs.PlateCarree()})
    fig.suptitle(title if title else 'Indirect Annual Nitrogen Load to Coastal Wetlands', fontsize=font_size + 5)

    # Load bathymetry data from NetCDF (sign flipped so depths are negative)
    with Dataset(grid_metrics_file_name, 'r') as ds:
        bathy = -ds.variables['h'][:]

    # Configure plotter for bathymetry
    plotter = FVCOMPlotter(grid_metrics_file_name, geographic_coords=True, font_size=font_size)

    # Main map: bathymetry, grid, then the summed-per-bin load on a log colour scale
    ax_main, _ = plotter.plot_field(
        ax_main, bathy, extents=main_extents, add_colour_bar=True, cb_label='Depth(m)', vmin=-60., vmax=0., cmap='Blues', zorder=0
    )
    plotter.draw_grid(ax_main, linewidth=0.5)
    hb_main = ax_main.hexbin(
        lon,
        lat,
        C=wetload_tn,
        gridsize=50,
        cmap=color_map,
        norm=LogNorm(),
        reduce_C_function=np.sum,
        zorder=40
    )

    # Colour bar for the main map, with enlarged label and tick fonts
    if colorbar_outside:
        cbar_main = fig.colorbar(hb_main, ax=ax_main, pad=0.1)
    else:
        cbar_main = fig.colorbar(hb_main, ax=ax_main, shrink=0.8, location="right", pad=0.15)
    cbar_main.set_label('Nitrogen Load (kg/cell/day)', fontsize=font_size + 5)
    cbar_main.ax.tick_params(labelsize=font_size + 2)

    ax_main.set_xlabel('Longitude', fontsize=font_size)
    ax_main.set_ylabel('Latitude', fontsize=font_size)

    # One inset per (position, extent, colour) triple
    inset_positions = [inset_position_1, inset_position_2, inset_position_3]
    zoom_extents = [zoom_extents_1, zoom_extents_2, zoom_extents_3]
    colors = ['red', 'blue', 'green']

    for i, (inset_pos, zoom_extent, color) in enumerate(zip(inset_positions, zoom_extents, colors), start=1):
        # Add the inset to this figure explicitly rather than to the implicit current figure
        ax_inset = fig.add_axes(inset_pos, projection=ccrs.PlateCarree())
        ax_inset.set_extent(zoom_extent, crs=ccrs.PlateCarree())

        # Inset bathymetry and grid
        ax_inset, _ = plotter.plot_field(
            ax_inset, bathy, extents=zoom_extent, add_colour_bar=False, cb_label=None, vmin=-60., vmax=0., cmap='Blues', zorder=0
        )
        plotter.draw_grid(ax_inset, linewidth=0.5)

        # Same hexbin layer as the main map, viewed through the inset extent
        ax_inset.hexbin(
            lon,
            lat,
            C=wetload_tn,
            gridsize=50,
            cmap=color_map,
            norm=LogNorm(),
            reduce_C_function=np.sum,
            zorder=40
        )

        # Inset title shares its colour with the two rectangles below
        ax_inset.set_title(f"Enlarge view {i}", fontsize=font_size - 4, color=color, fontweight='bold')

        # Outline the zoomed region on the main map (lon/lat data coordinates)
        rect = plt.Rectangle(
            (zoom_extent[0], zoom_extent[2]),
            zoom_extent[1] - zoom_extent[0],
            zoom_extent[3] - zoom_extent[2],
            linewidth=2, edgecolor=color, facecolor='none', transform=ccrs.PlateCarree(), zorder=50
        )
        ax_main.add_patch(rect)

        # Frame the inset itself. inset_pos is in figure-fraction coordinates, so the
        # rectangle must use the figure transform and be attached to the figure;
        # adding it to ax_main without a transform would place it in lon/lat space.
        inset_box = plt.Rectangle(
            (inset_pos[0], inset_pos[1]),
            inset_pos[2],
            inset_pos[3],
            linewidth=3, edgecolor=color, facecolor='none', linestyle='-',
            transform=fig.transFigure, zorder=60
        )
        fig.add_artist(inset_box)

    # Finalise the layout BEFORE saving so the file on disk matches the displayed figure
    if output_file is None:
        output_file = output_figures + '/WetLoadDistribution_AvgInun_Nitrogen_Zoombox.png'
    plt.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()


## Plot annual nitrogen load with enlarged inset maps

In [None]:
# Nitrogen load map with three enlarged insets for the average-inundation data.
plot_wetload_distribution_TN_with_three_insets(
    merged_data=merged_data_avg,
    grid_metrics_file_name=grid_metrics_file_name,
    main_extents=[275, 279, 43, 46.3],  # Main extent
    # Zoom extents follow [xmin, xmax, ymin, ymax]; longitude bounds are
    # given in increasing order (they were previously listed max-first).
    zoom_extents_1=[276.5, 276.8, 44.8, 45.5],
    zoom_extents_2=[276, 276.5, 43.58, 44],
    zoom_extents_3=[277, 277.5, 43.5, 44],
    color_map=plasma_cmap,  # Colormap for nitrogen
    font_size=24,
    inset_position_1=[0.30, 0.50, 0.25, 0.25],  # [left, bottom, width, height]
    inset_position_2=[0.1, 0.40, 0.18, 0.18],  # [left, bottom, width, height]
    inset_position_3=[0.50, 0.25, 0.18, 0.18],  # [left, bottom, width, height]
    colorbar_outside=True,
    title="Indirect Nitrogen Load to Coastal Wetlands with Average Inundation(kg/cell/day)"
)
#lt.savefig(output_figures + '/WetLoadDistribution_AvgInun_Nitrogen_Nov-Apr.png', dpi=300, bbox_inches='tight')

In [None]:
# Echo the directory that the plotting functions prepend to their saved figure file names.
output_figures

In [None]:
# Nitrogen load map with three enlarged insets for the high-inundation data.
plot_wetload_distribution_TN_with_three_insets(
    merged_data=merged_data_high,
    grid_metrics_file_name=grid_metrics_file_name,
    main_extents=[275, 279, 43, 46.3],  # Main extent
    # Zoom extents follow [xmin, xmax, ymin, ymax]; longitude bounds are
    # given in increasing order (they were previously listed max-first).
    zoom_extents_1=[276.5, 276.8, 44.8, 45.5],
    zoom_extents_2=[276, 276.5, 43.58, 44],
    zoom_extents_3=[277, 277.5, 43.5, 44],
    color_map=plasma_cmap,  # Colormap for nitrogen
    font_size=24,
    inset_position_1=[0.30, 0.50, 0.25, 0.25],  # [left, bottom, width, height]
    inset_position_2=[0.1, 0.40, 0.18, 0.18],  # [left, bottom, width, height]
    inset_position_3=[0.50, 0.25, 0.18, 0.18],  # [left, bottom, width, height]
    colorbar_outside=True,
    title="Indirect Annual Nitrogen Load to Coastal Wetlands with High Inundation( kg/cell/day)")


In [None]:
# Nitrogen load map with three enlarged insets for the low-inundation data.
plot_wetload_distribution_TN_with_three_insets(
    merged_data=merged_data_low,
    grid_metrics_file_name=grid_metrics_file_name,
    main_extents=[275, 279, 43, 46.3],  # Main extent
    # Zoom extents follow [xmin, xmax, ymin, ymax]; longitude bounds are
    # given in increasing order (they were previously listed max-first).
    zoom_extents_1=[276.5, 276.8, 44.8, 45.5],
    zoom_extents_2=[276, 276.5, 43.58, 44],
    zoom_extents_3=[277, 277.5, 43.5, 44],
    color_map=plasma_cmap,  # Colormap for nitrogen
    font_size=24,
    inset_position_1=[0.30, 0.50, 0.25, 0.25],  # [left, bottom, width, height]
    inset_position_2=[0.1, 0.40, 0.18, 0.18],  # [left, bottom, width, height]
    inset_position_3=[0.50, 0.25, 0.18, 0.18],  # [left, bottom, width, height]
    colorbar_outside=True,
    title="Indirect Annual Nitrogen Load to Coastal Wetlands with Low Inundation( kg/cell/day)")

## Plot phosphorus load with enlarged inset maps

In [None]:
def plot_wetload_distribution_TP_with_three_insets(
    merged_data,
    grid_metrics_file_name,
    main_extents,
    zoom_extents_1,
    zoom_extents_2,
    zoom_extents_3,
    color_map,
    font_size=15,
    title=None,
    inset_position_1=(0.6, 0.6, 0.25, 0.25),
    inset_position_2=(0.2, 0.6, 0.25, 0.25),
    inset_position_3=(0.4, 0.2, 0.25, 0.25),
    colorbar_outside=True,
    output_file=None,
):
    """
    Plot the 'WetLoad_TP_kgcellday' column as a log-scaled hexbin map over
    bathymetry, with three zoomed-in inset maps, then save and show the figure.

    Each zoomed region is outlined on the main map in the same colour as the
    matching inset title (red, blue, green in order). Inset axis ticks and
    labels are removed.

    Parameters:
    - merged_data: table providing the 'longitude', 'latitude' and
      'WetLoad_TP_kgcellday' columns.
    - grid_metrics_file_name: path to the NetCDF grid metrics file; its 'h'
      variable is negated and drawn as the bathymetry background.
    - main_extents: extents of the main map, passed to FVCOMPlotter.plot_field.
    - zoom_extents_1/2/3: [xmin, xmax, ymin, ymax] of each inset; also used to
      draw the highlight rectangle on the main map.
    - color_map: colormap for the hexbin layer.
    - font_size: base font size; titles, labels and ticks are offset from it.
    - title: figure title; a default phosphorus title is used when empty/None.
    - inset_position_1/2/3: (left, bottom, width, height) of each inset in
      figure-fraction coordinates.
    - colorbar_outside: selects between the two colour-bar placements.
    - output_file: path of the saved PNG. Defaults to
      output_figures + '/WetLoadDistribution_AvgInun_Phosphorus_Zoombox.png';
      pass a distinct path per dataset to avoid overwriting earlier figures.

    Returns:
    - None; saves the figure and displays it.
    """

    # Coordinates and load values for the hexbin layers
    lon = np.asarray(merged_data['longitude'])
    lat = np.asarray(merged_data['latitude'])
    wetload_tp = merged_data['WetLoad_TP_kgcellday'].values

    # Create the figure and main axis
    fig, ax_main = plt.subplots(figsize=(26., 26.), subplot_kw={'projection': ccrs.PlateCarree()})
    fig.suptitle(title if title else 'Indirect Annual Phosphorus Load to Coastal Wetlands', fontsize=font_size + 5)

    # Load bathymetry data from NetCDF (sign flipped so depths are negative)
    with Dataset(grid_metrics_file_name, 'r') as ds:
        bathy = -ds.variables['h'][:]

    # Configure plotter for bathymetry
    plotter = FVCOMPlotter(grid_metrics_file_name, geographic_coords=True, font_size=font_size)

    # Main map: bathymetry, grid, then the summed-per-bin load on a log colour scale
    ax_main, _ = plotter.plot_field(
        ax_main, bathy, extents=main_extents, add_colour_bar=True, cb_label='Depth (m)', vmin=-60., vmax=0., cmap='Blues', zorder=0
    )
    plotter.draw_grid(ax_main, linewidth=0.5)
    hb_main = ax_main.hexbin(
        lon,
        lat,
        C=wetload_tp,
        gridsize=50,
        cmap=color_map,
        norm=LogNorm(),
        reduce_C_function=np.sum,
        zorder=40
    )

    # Colour bar for the main map, with enlarged label and tick fonts
    if colorbar_outside:
        cbar_main = fig.colorbar(hb_main, ax=ax_main, pad=0.1)
    else:
        cbar_main = fig.colorbar(hb_main, ax=ax_main, shrink=0.8, location="right", pad=0.02)
    cbar_main.set_label('Phosphorus Load (kg/cell/day)', fontsize=font_size + 5)
    cbar_main.ax.tick_params(labelsize=font_size + 2)

    ax_main.set_xlabel('Longitude', fontsize=font_size)
    ax_main.set_ylabel('Latitude', fontsize=font_size)

    # One inset per (position, extent, colour) triple
    inset_positions = [inset_position_1, inset_position_2, inset_position_3]
    zoom_extents = [zoom_extents_1, zoom_extents_2, zoom_extents_3]
    colors = ['red', 'blue', 'green']

    for i, (inset_pos, zoom_extent, color) in enumerate(zip(inset_positions, zoom_extents, colors), start=1):
        # Add the inset to this figure explicitly rather than to the implicit current figure
        ax_inset = fig.add_axes(inset_pos, projection=ccrs.PlateCarree())
        ax_inset.set_extent(zoom_extent, crs=ccrs.PlateCarree())

        # Inset bathymetry and grid
        ax_inset, _ = plotter.plot_field(
            ax_inset, bathy, extents=zoom_extent, add_colour_bar=False, cb_label=None, vmin=-60., vmax=0., cmap='Blues', zorder=0
        )
        plotter.draw_grid(ax_inset, linewidth=0.5)

        # Same hexbin layer as the main map, viewed through the inset extent
        ax_inset.hexbin(
            lon,
            lat,
            C=wetload_tp,
            gridsize=50,
            cmap=color_map,
            norm=LogNorm(),
            reduce_C_function=np.sum,
            zorder=40
        )

        # Inset title shares its colour with the highlight rectangle below
        ax_inset.set_title(f"Enlarged View {i}", fontsize=font_size - 4, color=color, fontweight='bold')

        # Remove axis ticks & labels from the inset
        ax_inset.set_xticks([])
        ax_inset.set_yticks([])
        ax_inset.set_xlabel('')
        ax_inset.set_ylabel('')

        # Outline the zoomed region on the main map
        rect = plt.Rectangle(
            (zoom_extent[0], zoom_extent[2]),
            zoom_extent[1] - zoom_extent[0],
            zoom_extent[3] - zoom_extent[2],
            linewidth=2, edgecolor=color, facecolor='none', transform=ccrs.PlateCarree(), zorder=50
        )
        ax_main.add_patch(rect)

    # Finalise the layout BEFORE saving so the file on disk matches the displayed figure
    if output_file is None:
        output_file = output_figures + '/WetLoadDistribution_AvgInun_Phosphorus_Zoombox.png'
    plt.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# Phosphorus load map with three enlarged insets for the average-inundation data.
plot_wetload_distribution_TP_with_three_insets(
    merged_data=merged_data_avg,
    grid_metrics_file_name=grid_metrics_file_name,
    main_extents=[275, 279, 43, 46.3],  # Main extent
    # Zoom extents follow [xmin, xmax, ymin, ymax]; longitude bounds are
    # given in increasing order (they were previously listed max-first).
    zoom_extents_1=[276.5, 276.8, 44.8, 45.5],
    zoom_extents_2=[276, 276.5, 43.58, 44],
    zoom_extents_3=[277, 277.5, 43.5, 44],
    color_map=plasma_cmap,  # Colormap for the phosphorus hexbin layer
    font_size=24,
    inset_position_1=[0.30, 0.50, 0.25, 0.25],  # [left, bottom, width, height]
    inset_position_2=[0.1, 0.40, 0.18, 0.18],  # [left, bottom, width, height]
    inset_position_3=[0.50, 0.25, 0.18, 0.18],  # [left, bottom, width, height]
    colorbar_outside=True,
    # Unit matches the plotted 'WetLoad_TP_kgcellday' column and the colour-bar label.
    title="Indirect Annual Phosphorus Load to Coastal Wetlands with Average Inundation (kg/cell/day)"
)

## Plot Nitrogen return to Coastal Wetlands

In [None]:
def plot_wetload_distribution_TN(merged_data, grid_metrics_file_name, extents, color_map, font_size=15, title=None, output_file=None):
    """
    Plot a nitrogen load column as a log-scaled hexbin map over bathymetry,
    then save and show the figure.

    Parameters:
    - merged_data: DataFrame containing 'longitude', 'latitude' and the column
      named by the module-level variable `fieldDirectTN`.
    - grid_metrics_file_name: Path to the NetCDF file containing bathymetry data
      (its 'h' variable is negated before plotting).
    - extents: List defining the geographic extents for plotting [xmin, xmax, ymin, ymax].
    - color_map: Colormap for the hexbin layer.
    - font_size: Integer representing the font size in the plot.
    - title: Optional string to set a custom plot title.
    - output_file: Optional path of the saved PNG. Defaults to
      output_figures + '/WetLoadDistribution_SurgeInun_Nitrogen.png'; pass a
      distinct path per dataset to avoid overwriting earlier figures.

    Returns:
    - None; saves the figure and displays it.
    """

    # Coordinates and load values for the hexbin layer
    lon = np.asarray(merged_data['longitude'])
    lat = np.asarray(merged_data['latitude'])
    wetload_tn = merged_data[fieldDirectTN].values

    # Create the figure and axis using the FVCOM plotter
    fig, ax = create_figure(figure_size=(26., 26.), projection=ccrs.PlateCarree(), font_size=font_size, bg_color='gray')

    # Load bathymetry data from NetCDF (sign flipped so depths are negative)
    with Dataset(grid_metrics_file_name, 'r') as ds:
        bathy = -ds.variables['h'][:]

    # Configure plotter and plot bathymetry
    plotter = FVCOMPlotter(grid_metrics_file_name, geographic_coords=True, font_size=font_size)
    ax, _ = plotter.plot_field(ax, bathy, extents=extents, add_colour_bar=True, cb_label='Depth (m)', vmin=-60., vmax=0., cmap='Blues', zorder=0)
    plotter.draw_grid(ax, linewidth=1.0)

    # Hexbin layer: values are summed per bin and coloured on a log scale
    hb = ax.hexbin(lon, lat, C=wetload_tn, gridsize=50, cmap=color_map, norm=LogNorm(), reduce_C_function=np.sum, zorder=40)

    # Add color bar for the nitrogen load
    cbar = fig.colorbar(hb, ax=ax, pad=0.1)
    cbar.set_label('Annual Indirect Nitrogen Load (kg/cell/day)', fontsize=font_size)

    # Set axis labels and title
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.set_title(title if title else 'Indirect Annual Nitrogen Load to Coastal Wetlands with Low Inundation (kg^2/area)', fontsize=font_size)

    # Save, then show the plot
    if output_file is None:
        output_file = output_figures + '/WetLoadDistribution_SurgeInun_Nitrogen.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# Nitrogen hexbin map for the average-inundation data.
# NOTE(review): the title quotes kg²/area, while plot_wetload_distribution_TN
# labels its colour bar in kg/cell/day — confirm which unit is intended.
plot_wetload_distribution_TN(
    merged_data=merged_data_avg,
    title='Annual Indirect Nitrogen Load to Coastal Wetlands with Average Inundation in 2023 (kg²/area)',
    grid_metrics_file_name=grid_metrics_file_name,
    extents=extents,
    color_map=plasma_cmap,
    font_size=12,
)

#plt.savefig(output_path + '/WetLoadDistribution_AvgInun.png', dpi=300, bbox_inches='tight')

In [None]:
# Nitrogen hexbin map for the high-inundation data.
# NOTE(review): the title quotes kg²/area, while plot_wetload_distribution_TN
# labels its colour bar in kg/cell/day — confirm which unit is intended.
plot_wetload_distribution_TN(
    merged_data=merged_data_high,
    title='Annual Indirect Nitrogen Load to Coastal Wetlands with High Inundation in 2023 (kg²/area)',
    grid_metrics_file_name=grid_metrics_file_name,
    extents=extents,
    color_map=plasma_cmap,
    font_size=12,
)

In [None]:
# Nitrogen hexbin map for the low-inundation data.
# NOTE(review): the title quotes kg²/area, while plot_wetload_distribution_TN
# labels its colour bar in kg/cell/day — confirm which unit is intended.
plot_wetload_distribution_TN(
    merged_data=merged_data_low,
    title='Annual Indirect Nitrogen Load to Coastal Wetlands with Low Inundation in 2023 (kg²/area)',
    grid_metrics_file_name=grid_metrics_file_name,
    extents=extents,
    color_map=plasma_cmap,
    font_size=12,
)

In [None]:
# Nitrogen hexbin map for the surge-inundation data.
# NOTE(review): the title quotes kg²/area, while plot_wetload_distribution_TN
# labels its colour bar in kg/cell/day — confirm which unit is intended.
plot_wetload_distribution_TN(
    merged_data=merged_data_surge,
    title='Annual Indirect Nitrogen Load to Coastal Wetlands with Surge Inundation in 2023 (kg²/area)',
    grid_metrics_file_name=grid_metrics_file_name,
    extents=extents,
    color_map=plasma_cmap,
    font_size=12,
)

## Plot Phosphorus return to coastal Wetlands 

In [None]:
# NOTE(review): bare expression that only displays `col`; its definition is not
# visible in this section — confirm the cell is still needed.
col

In [None]:
def plot_wetload_distribution_TP(merged_data, grid_metrics_file_name, extents, colourmap, font_size=15, title=None, output_file=None):
    """
    Plot a phosphorus load column as a log-scaled hexbin map over bathymetry,
    then save and show the figure.

    Parameters:
    - merged_data: DataFrame containing 'longitude', 'latitude' and the column
      named by the module-level variable `fieldDirectTP`.
    - grid_metrics_file_name: Path to the NetCDF file containing bathymetry data
      (its 'h' variable is negated before plotting).
    - extents: List defining the geographic extents for plotting [xmin, xmax, ymin, ymax].
    - colourmap: Colormap for the hexbin phosphorus load data.
    - font_size: Integer representing the font size in the plot.
    - title: Optional string to set a custom plot title.
    - output_file: Optional path of the saved PNG. Defaults to
      output_figures + '/WetLoadDistribution_LowInun_PH.png'; pass a distinct
      path per dataset to avoid overwriting earlier figures.

    Returns:
    - None; saves the figure and displays it.
    """

    # Coordinates and load values for the hexbin layer
    lon = np.asarray(merged_data['longitude'])
    lat = np.asarray(merged_data['latitude'])
    wetload_tp = merged_data[fieldDirectTP].values

    # Create the figure and axis using the FVCOM plotter
    fig, ax = create_figure(figure_size=(26., 26.), projection=ccrs.PlateCarree(), font_size=font_size, bg_color='gray')

    # Load bathymetry data from NetCDF (sign flipped so depths are negative)
    with Dataset(grid_metrics_file_name, 'r') as ds:
        bathy = -ds.variables['h'][:]

    # Configure plotter and plot bathymetry
    plotter = FVCOMPlotter(grid_metrics_file_name, geographic_coords=True, font_size=font_size)
    ax, _ = plotter.plot_field(ax, bathy, extents=extents, add_colour_bar=True, cb_label='Depth (m)', vmin=-60., vmax=0., cmap='Blues', zorder=0)
    plotter.draw_grid(ax, linewidth=1.0)

    # Hexbin layer: values are summed per bin and coloured on a log scale.
    # Use the caller-supplied colormap; the global viridis_cmap used to be
    # hard-wired here, silently ignoring the `colourmap` argument.
    hb = ax.hexbin(lon, lat, C=wetload_tp, gridsize=50, cmap=colourmap, norm=LogNorm(), reduce_C_function=np.sum, zorder=40)

    # Add color bar for the phosphorus load
    cbar = fig.colorbar(hb, ax=ax, pad=0.1)
    cbar.set_label('Annual Indirect Phosphorus Load (kg/cell/day)', fontsize=font_size)

    # Set axis labels and title
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.set_title(title if title else 'Annual Indirect Phosphorus Load to Coastal Wetlands with Average Inundation in 2023 (kg²/area)', fontsize=font_size, pad=20)

    # Save, then show the plot
    if output_file is None:
        output_file = output_figures + '/WetLoadDistribution_LowInun_PH.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# Phosphorus hexbin map for the average-inundation data.
# NOTE(review): the title quotes kg²/area, while plot_wetload_distribution_TP
# labels its colour bar in kg/cell/day — confirm which unit is intended.
plot_wetload_distribution_TP(
    merged_data=merged_data_avg,
    title='Annual Indirect Phosphorus Load to CW with Average Inundation in 2023(kg²/area)',
    grid_metrics_file_name=grid_metrics_file_name,
    extents=extents,
    colourmap=viridis_cmap,
    font_size=12,
)


In [None]:
# Phosphorus hexbin map for the high-inundation data.
# NOTE(review): the title quotes kg²/area, while plot_wetload_distribution_TP
# labels its colour bar in kg/cell/day — confirm which unit is intended.
plot_wetload_distribution_TP(
    merged_data=merged_data_high,
    title='Annual Indirect Phosphorus Load to CW with High Inundation in 2023(kg²/area)',
    grid_metrics_file_name=grid_metrics_file_name,
    extents=extents,
    colourmap=viridis_cmap,
    font_size=12,
)

In [None]:
# Phosphorus hexbin map for the low-inundation data.
# NOTE(review): the title quotes kg²/area, while plot_wetload_distribution_TP
# labels its colour bar in kg/cell/day — confirm which unit is intended.
plot_wetload_distribution_TP(
    merged_data=merged_data_low,
    title='Annual Indirect Phosphorus Load to CW with Low Inundation in 2023(kg²/area)',
    grid_metrics_file_name=grid_metrics_file_name,
    extents=extents,
    colourmap=viridis_cmap,
    font_size=12,
)

In [None]:
# Phosphorus hexbin map for the surge-inundation data.
# NOTE(review): the title quotes kg²/area, while plot_wetload_distribution_TP
# labels its colour bar in kg/cell/day — confirm which unit is intended.
plot_wetload_distribution_TP(
    merged_data=merged_data_surge,
    title='Annual Indirect Phosphorus Load to CW with Surge Inundation in 2023(kg²/area)',
    grid_metrics_file_name=grid_metrics_file_name,
    extents=extents,
    colourmap=viridis_cmap,
    font_size=12,
)

In [None]:
def plot_CW_avg_with_bathy(CW_avg, grid_metrics_file_name, extents, font_size=15, title=None):
    """
    Scatter the coastal-wetland start positions on top of a bathymetry map,
    save the figure as 'CW_avg_with_bathy.png' under `output_figures`, and show it.

    Parameters:
    - CW_avg: table providing the 'start_lon' and 'start_lat' columns.
    - grid_metrics_file_name: path to the NetCDF file whose 'h' variable
      (negated) is drawn as the bathymetry background.
    - extents: geographic extents handed to the plotter, [xmin, xmax, ymin, ymax].
    - font_size: font size used for the figure, title and legend.
    - title: optional custom title; a default title is used when empty/None.

    Returns:
    - None; the figure is saved and displayed.
    """

    # Pair up (lon, lat) for every wetland point
    lon_lat_pairs = zip(CW_avg['start_lon'], CW_avg['start_lat'])
    wetland_xy = np.array(list(lon_lat_pairs))

    # Figure and axis on a white background
    fig, ax = create_figure(
        figure_size=(26., 26.),
        projection=ccrs.PlateCarree(),
        font_size=font_size,
        bg_color='white',
    )

    # Read the depth field and flip its sign
    with Dataset(grid_metrics_file_name, 'r') as nc:
        depth_field = -nc.variables['h'][:]

    # Bathymetry background
    fvcom_plotter = FVCOMPlotter(grid_metrics_file_name, geographic_coords=True, font_size=font_size)
    ax, _ = fvcom_plotter.plot_field(
        ax,
        depth_field,
        extents=extents,
        add_colour_bar=True,
        cb_label='Depth (m)',
        vmin=-60.,
        vmax=0.,
        cmap='Blues',
        zorder=0,
    )

    # Wetland points drawn above the bathymetry
    ax.scatter(
        wetland_xy[:, 0],
        wetland_xy[:, 1],
        c='green',
        marker='o',
        s=10,
        label='Average Lake Level Extent of Coastal Wetlands',
        zorder=5,
    )

    # Axis labels and title (fall back to the default title when none is given)
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    heading = title if title else 'Coastal Wetlands (CW_avg) with Bathymetry'
    ax.set_title(heading, fontsize=font_size, pad=20)

    # Legend, save, show
    ax.legend(fontsize=font_size)
    plt.savefig(output_figures + '/CW_avg_with_bathy.png', dpi=300, bbox_inches='tight')
    plt.show()



In [None]:
# Map the coastal-wetland points over bathymetry for the listed extents.
plot_CW_avg_with_bathy(
    CW_avg=CW_avg,
    title='Coastal Wetlands extent with average lake levels',
    grid_metrics_file_name=grid_metrics_file_name,
    extents=[275, 277.69, 43, 46.3],
    font_size=15,
)
