In [None]:
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.patches as patches
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np

# `result` is not defined in this cell; it must already exist in the session.
fert_consumption = result

# Basin layers. basin_gdf's CRS is the reference every other layer is reprojected to.
basin_gdf = gpd.read_file(r"E:\nsurplus_paper\code\ns_paper\shape_file\basin_shaope_file\basin_indiat.shp")
basin_gdf_new = gpd.read_file(
    r"E:\nsurplus_paper\code\ns_paper\shape_file\basin_shaope_file\basin_level_5_hy_shedt.shp"
).to_crs(basin_gdf.crs)

# District polygons: rename the key columns so they match the merge keys used
# against fert_consumption ('State Name', 'Dist Name'), then reproject.
gdf = (
    gpd.read_file(r"E:\nsurplus_paper\code\ns_paper\shape_file\Census_2011 - Copy\new_2011_Dist_final_1_1_2.shp")
    .rename(columns={'State': 'State Name', 'Parent_dis': 'Dist Name'})
    .to_crs(basin_gdf.crs)
)

# State boundary outlines, reprojected to the same CRS.
boundary_gdf = gpd.read_file(
    r"E:\nsurplus_paper\code\ns_paper\shape_file\India-State-and-Country-Shapefile-Updated-Jan-2020-master (1)\India-State-and-Country-Shapefile-Updated-Jan-2020-master\India_State_Boundary.shp"
).to_crs(basin_gdf.crs)

# Function to calculate weighted mean nitrogen surplus for each basin
def calculate_weighted_mean(year, basin_gdf, object_id_col):
    """Return ``basin_gdf`` with an area-weighted mean N surplus per basin.

    The module-level district layer ``gdf`` is joined to the rows of
    ``fert_consumption`` for ``year``, intersected with ``basin_gdf``, and
    each district/basin piece is weighted by its area in EPSG:32644.

    Args:
        year: Value compared against the ``'Year'`` column of
            ``fert_consumption``.
        basin_gdf: GeoDataFrame of basin polygons.
        object_id_col: Name of the column in ``basin_gdf`` used to group the
            intersection pieces and to merge the result back.

    Returns:
        ``basin_gdf`` left-merged with a ``'mean_N_budjet_kg_ha'`` column.
        Basins that overlap no district with a value get NaN.
    """
    value_col = 'mean_N_budjet_kg_ha'

    fert_consumption_year = fert_consumption[fert_consumption['Year'] == year]
    merged_gdf = gdf.merge(fert_consumption_year, on=['State Name', 'Dist Name'], how='left')
    intersections = gpd.overlay(merged_gdf, basin_gdf, how='intersection')

    # Areas are measured in a projected CRS (UTM Zone 44N).
    pieces = intersections.to_crs(epsg=32644)
    pieces['intersection_area'] = pieces.geometry.area

    # Exclude pieces without a value. pandas' sum() skips NaN, so keeping them
    # would add their area to the denominator but nothing to the numerator,
    # pulling the basin mean toward zero (and giving 0 for all-NaN basins).
    pieces = pieces[pieces[value_col].notna()].copy()
    pieces['weighted_n_surplus'] = pieces[value_col] * pieces['intersection_area']

    totals = pieces.groupby(object_id_col)[['weighted_n_surplus', 'intersection_area']].sum()
    basin_stats = (
        (totals['weighted_n_surplus'] / totals['intersection_area'])
        .rename(value_col)
        .reset_index()
    )

    basin_stats_gdf = basin_gdf.merge(basin_stats, on=object_id_col, how='left')
    return basin_stats_gdf

# Function to calculate state averages
def calculate_state_averages(year):
    """Dissolve the district layer into states with the mean N surplus for ``year``.

    The module-level ``gdf`` is left-merged with the rows of
    ``fert_consumption`` whose ``'Year'`` equals ``year``, then dissolved by
    ``'State Name'`` taking the mean of ``'mean_N_budjet_kg_ha'``.

    Returns:
        The dissolved GeoDataFrame with ``'State Name'`` restored as a column.
    """
    is_year = fert_consumption['Year'] == year
    districts = gdf.merge(
        fert_consumption[is_year],
        on=['State Name', 'Dist Name'],
        how='left',
    )
    by_state = districts.dissolve(
        by='State Name',
        aggfunc={'mean_N_budjet_kg_ha': 'mean'},
    )
    return by_state.reset_index()

# Years to draw; one row of maps per year
selected_years = [1966, 1990, 2017]

# 3 x 4 grid of PlateCarree map axes with almost no gap between panels
fig, axes = plt.subplots(
    3, 4,
    figsize=(26, 20),
    subplot_kw=dict(projection=ccrs.PlateCarree()),
)
fig.subplots_adjust(wspace=0.01, hspace=0.01)

# Work with a 1-D array of axes so panels can be addressed by a single index
axes = axes.flatten()

# One colour scale shared by every panel; out-of-range values get fixed colours
vmin = 0
vmax = 200
cmap = plt.get_cmap('YlGnBu').copy()
cmap.set_under('lightyellow')
cmap.set_over('darkblue')

for row_idx, year in enumerate(selected_years):
    # The four panels of this year's row: district, state, basin, sub-basin
    first = row_idx * 4
    district_ax, state_ax, basin_ax, subbasin_ax = axes[first:first + 4]

    # Keyword arguments shared by all four choropleth layers
    fill_kwargs = dict(
        column='mean_N_budjet_kg_ha', legend=False,
        cmap=cmap, vmin=vmin, vmax=vmax, alpha=1,
        missing_kwds={"color": "white"},
    )
    # Keyword arguments shared by the black outline layers
    outline_kwargs = dict(edgecolor='black', linewidth=0.5, alpha=0.7, zorder=10)

    # --- District panel ---
    fert_consumption_year = fert_consumption[fert_consumption['Year'] == year]
    merged_gdf = gdf.merge(fert_consumption_year, on=['State Name', 'Dist Name'], how='left')

    print(f"Year {year} - NaN values in 'mean_N_budjet_kg_ha':", merged_gdf['mean_N_budjet_kg_ha'].isna().sum())

    merged_gdf.plot(ax=district_ax, **fill_kwargs)
    merged_gdf.boundary.plot(ax=district_ax, **outline_kwargs)
    boundary_gdf.plot(ax=district_ax, edgecolor='lightgrey', facecolor='none', alpha=0.5, zorder=9)

    # --- State panel ---
    state_stats = calculate_state_averages(year)
    state_stats.plot(ax=state_ax, **fill_kwargs)
    boundary_gdf.plot(ax=state_ax, facecolor='none', **outline_kwargs)

    # --- Basin panel ---
    basin_stats_gdf = calculate_weighted_mean(year, basin_gdf, 'OBJECTID_1')
    basin_stats_gdf.plot(ax=basin_ax, **fill_kwargs)
    basin_gdf.boundary.plot(ax=basin_ax, **outline_kwargs)

    # --- Sub-basin panel ---
    basin_stats_gdf_new = calculate_weighted_mean(year, basin_gdf_new, 'OBJECTID')
    basin_stats_gdf_new.plot(ax=subbasin_ax, **fill_kwargs)
    basin_gdf_new.boundary.plot(ax=subbasin_ax, **outline_kwargs)

    # Background features, map extent and tick labels for each panel in the row
    for ax_idx in range(first, first + 4):
        panel = axes[ax_idx]
        panel.add_feature(cfeature.LAND, edgecolor='none', facecolor='grey')
        panel.add_feature(cfeature.COASTLINE, edgecolor='black')
        panel.add_feature(cfeature.OCEAN, color='lightblue', alpha=0.3)
        panel.set_extent([68, 98, 6, 37], crs=ccrs.PlateCarree())

        # Invisible gridlines, used only for their tick labels
        gl = panel.gridlines(draw_labels=True, color='none')
        gl.top_labels = False
        gl.left_labels = False
        # Latitude labels only on the right-most column (indices 3, 7, 11)
        gl.right_labels = ax_idx % 4 == 3
        # Longitude labels only on the bottom row (indices 8 and up)
        gl.bottom_labels = ax_idx >= 8

# Column headings go on the first four axes (the top row)
column_titles = ['District', 'State', 'Basin', 'Sub-Basin']
for col_idx, heading in enumerate(column_titles):
    axes[col_idx].set_title(heading, fontsize=20, fontweight='bold')

# Year labels, rotated, to the left of the first axes in each row
for row_idx, year_label in enumerate(selected_years):
    left_ax = axes[row_idx * 4]
    left_ax.annotate(
        year_label,
        xy=(0, 0.5),
        xycoords='axes fraction',
        xytext=(-left_ax.yaxis.labelpad - 15, 0),
        textcoords='offset points',
        ha='right',
        va='center',
        rotation=90,
        fontsize=20,
        fontweight='bold',
    )

# Place a common horizontal colorbar in its own axes below the map grid
cbar_ax = fig.add_axes([0.15, 0.06, 0.7, 0.02])

# The maps were drawn with explicit vmin/vmax, so build a standalone mappable
# with the same colormap and normalisation to drive the colorbar.
norm = colors.Normalize(vmin=vmin, vmax=vmax)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
# Give the mappable an empty data array through the public setter rather than
# assigning the private `_A` attribute.
sm.set_array([])

cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal')
cbar.set_label('N surplus (kg ha$^{-1}$ yr$^{-1}$)', fontsize=18, fontweight='bold')
cbar.ax.tick_params(labelsize=14)

plt.show()
