# Figure: Training data maps

In [None]:
# Libraries
import os
import numpy as np
import geopandas as gpd
import dask.dataframe as dd
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from shapely.geometry import box
import cartopy.crs as ccrs

In [None]:
# Directories (relative paths, resolved against the current working directory):
#   dir02      - holds the training database 'df_dbase.parquet'
#   dir06      - destination of the saved figure ('pdf/' and 'png/' subfolders)
#   dir_nearth - Natural Earth shapefiles used for land and coastline backdrops
dir02, dir06, dir_nearth = (
    '../paper_deficit/output/02_dbase/',
    '../paper_deficit/output/06_eval/',
    '../data/naturalearth/',
)

---

In [None]:
# Load the Natural Earth 1:110m coastline and land layers (read in that order).
# Each layer lives in its own sub-folder of dir_nearth with a same-named .shp.
coastline110, land110 = (
    gpd.read_file(os.path.join(dir_nearth, folder, shapefile))
    for folder, shapefile in (
        ('ne_110m_coastline', 'ne_110m_coastline.shp'),
        ('ne_110m_land', 'ne_110m_land.shp'),
    )
)

In [None]:
# Load the training database lazily with dask, keep only the coordinates and
# the two training-flag columns, then materialise the result in memory.
df = (
    dd.read_parquet(os.path.join(dir02, 'df_dbase.parquet'))
    [['lat', 'lon', 'train_prim', 'train_secd']]
    .compute()
)

In [None]:
def plot_train_grid(ax, scen, ymin=-60, ymax=90, xmin=-180, xmax=180, add_points=False, dpi=600):
    """
    Plots a training grid on a map, showing the density of training points within 1°x1° grid cells.

    Every training point is binned into the cell whose lower-left corner is
    (floor(lon), floor(lat)). Unlike a spatial join with the ``within``
    predicate, this also counts points lying exactly on a cell edge (e.g. at
    integer coordinates), which ``within`` would silently drop. Only cells that
    contain at least one point are drawn.

    The function reads the module-level objects ``df`` (columns ``lat``,
    ``lon`` and ``train_<scen>``), ``land110`` and ``coastline110``.

    Parameters:
        ax (matplotlib.axes.Axes): The matplotlib Axes instance with a Cartopy projection.
        scen (str): The scenario identifier; rows where ``train_<scen>`` is True are used.
        ymin, ymax, xmin, xmax (float): The latitudinal and longitudinal bounds of the map.
        add_points (bool): Whether to overlay individual training points on the map.
        dpi (int): Not used; kept only for backward compatibility. Set the
            resolution when creating or saving the figure instead.

    Raises:
        ValueError: If the scenario has no training points with finite coordinates.
    """
    flag_col = f'train_{scen}'

    # Select the scenario's training points. ``.eq(True)`` keeps the semantics
    # of an ``== True`` filter (missing values count as "not selected").
    df_scen = df.loc[df[flag_col].eq(True), ['lat', 'lon', flag_col]]
    lon = df_scen['lon'].to_numpy(dtype=float)
    lat = df_scen['lat'].to_numpy(dtype=float)

    # Points without usable coordinates cannot be assigned to any cell
    finite = np.isfinite(lon) & np.isfinite(lat)
    lon, lat = lon[finite], lat[finite]
    if lon.size == 0:
        raise ValueError(f"No training points with valid coordinates for scenario '{scen}'.")

    # Lower-left corner of each point's 1°x1° cell. Points on the eastern or
    # northern edge of the globe (lon == 180, lat == 90) go into the last cell
    # instead of a cell outside the valid coordinate range.
    cell_x = np.where(lon == 180, 179.0, np.floor(lon))
    cell_y = np.where(lat == 90, 89.0, np.floor(lat))

    # Count the number of points in each occupied grid cell
    corners, counts = np.unique(np.column_stack((cell_x, cell_y)), axis=0, return_counts=True)
    grid = gpd.GeoDataFrame(
        {'count': counts},
        geometry=[box(x, y, x + 1, y + 1) for x, y in corners],
        crs='EPSG:4326',
    )

    # One colour normalisation shared by the grid cells and the colorbar
    vmin, vmax = counts.min(), counts.max()
    cmap = 'OrRd'

    # Land is filled with the lightest OrRd colour so that areas without
    # training cells blend into the low end of the colour scale.
    land110.plot(ax=ax, transform=ccrs.PlateCarree(), color='#fff7ec', linewidth=0.5)
    # Plot training grid cell count
    grid.plot(ax=ax, column='count', cmap=cmap, vmin=vmin, vmax=vmax,
              legend=False, transform=ccrs.PlateCarree())
    # Plot coastline on top
    coastline110.plot(ax=ax, transform=ccrs.PlateCarree(), color='#000000', linewidth=0.5)

    # Optionally overlay individual training points
    if add_points:
        points = gpd.GeoDataFrame(geometry=gpd.points_from_xy(lon, lat), crs='EPSG:4326')
        points.plot(ax=ax, transform=ccrs.PlateCarree(), markersize=0.1, marker='+')

    # Set the map extent and remove axes
    ax.set_extent((xmin, xmax, ymin, ymax), ccrs.PlateCarree())
    ax.axis('off')

    # Colorbar attached to the figure that owns ``ax``
    mappable = plt.cm.ScalarMappable(norm=colors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap)
    cbar = ax.get_figure().colorbar(mappable, ax=ax, shrink=0.65)
    cbar.set_label('Training points in 1°x1° grid cells')

In [None]:
# Plot the training-point density maps for both scenarios, one above the other
fig = plt.figure(figsize=(9, 9), dpi=600)
fig.set_facecolor('#ffffff')

projection = ccrs.Robinson(central_longitude=10)
ax0 = fig.add_subplot(2, 1, 1, projection=projection, aspect='auto')
ax1 = fig.add_subplot(2, 1, 2, projection=projection, aspect='auto')

plot_train_grid(ax0, 'prim')
plot_train_grid(ax1, 'secd')

# Panel titles (each panel's colorbar is added by plot_train_grid)
ax0.set_title('Pristine land assumption', fontsize=14)
ax1.set_title('Low human influence assumption', fontsize=14)

# Save as PDF and PNG into dir06/pdf and dir06/png, creating them if missing
for ext in ('pdf', 'png'):
    out_dir = os.path.join(dir06, ext)
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f'figs10_training_data_maps.{ext}'),
                bbox_inches='tight', dpi=600)