In [None]:
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.patches as patches
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# Inputs: district polygons plus the state-boundary overlay drawn on top of them.
# `fert_consumption` is NOT loaded here -- it must already exist in the session and
# provide the columns 'State Name', 'Dist Name', 'Year' and 'mean_N_budjet_kg_ha'
# (all used by the plotting loop below).
district_shapefile = r"E:\nsurplus_paper\code\ns_paper\shape_file\Census_2011 - Copy\new_2011_Dist_final_1_1_2.shp"
state_boundary_shapefile = r"E:\nsurplus_paper\code\ns_paper\shape_file\India-State-and-Country-Shapefile-Updated-Jan-2020-master (1)\India-State-and-Country-Shapefile-Updated-Jan-2020-master\India_State_Boundary.shp"

gdf = gpd.read_file(district_shapefile)
boundary_gdf = gpd.read_file(state_boundary_shapefile)

# The per-year merge joins on ['State Name', 'Dist Name'], so give the shapefile's
# key columns the same names that fert_consumption uses.
key_column_aliases = {'State': 'State Name', 'Parent_dis': 'Dist Name'}
gdf = gdf.rename(columns=key_column_aliases)

# Years to map, one panel each, filled row by row into the grid below.
selected_years = [1970, 1980, 1990, 2000, 2010, 2017]

# A 2 x 3 grid of PlateCarree (plain lon/lat) map panels with almost no gap between them.
# NOTE(review): plotting the shapefiles straight onto PlateCarree axes assumes they are
# in geographic lon/lat coordinates -- confirm gdf.crs / boundary_gdf.crs.
fig, axes = plt.subplots(
    nrows=2,
    ncols=3,
    figsize=(15, 10),
    subplot_kw={'projection': ccrs.PlateCarree()},
)
fig.subplots_adjust(hspace=0.01, wspace=0.01)
axes = axes.flatten()  # 1-D, so panel k is simply axes[k]

# One colour scale shared by every panel and by the colorbar.
vmin, vmax = 0, 200

# Defensive copy: the set_under/set_over tweaks below must not touch the shared
# registered 'YlGnBu' colormap.
cmap = plt.get_cmap('YlGnBu').copy()
cmap.set_under('lightyellow')  # colour for values below vmin
cmap.set_over('darkblue')      # colour for values above vmax

# Generate one choropleth panel per selected year.
# Layout assumption: `axes` is the flattened 2 x 3 grid, so this must match the
# `ncols` passed to plt.subplots.
n_panel_cols = 3

for i, year in enumerate(selected_years):
    ax = axes[i]

    # Attach this year's values to every district polygon. The left merge keeps
    # districts with no data; their value becomes NaN.
    fert_consumption_year = fert_consumption[fert_consumption['Year'] == year]
    merged_gdf = gdf.merge(fert_consumption_year, on=['State Name', 'Dist Name'], how='left')

    # Choropleth on the shared scale; NaN (no data) districts are drawn white.
    merged_gdf.plot(column='mean_N_budjet_kg_ha', ax=ax, legend=False,
                    cmap=cmap, vmin=vmin, vmax=vmax, alpha=1,
                    missing_kwds={"color": "white"})

    # State outlines above everything else.
    boundary_gdf.plot(ax=ax, edgecolor='black', facecolor='none', alpha=0.5, zorder=10)

    # Basemap features. Cartopy feature artists do not sit below the district
    # polygons by default, and because they are added after the choropleth the
    # opaque grey LAND fill would be painted over the data. zorder=0 keeps the
    # filled land/ocean layers underneath the map values.
    ax.add_feature(cfeature.LAND, edgecolor='grey', facecolor='grey', zorder=0)
    ax.add_feature(cfeature.COASTLINE, edgecolor='grey')
    ax.add_feature(cfeature.OCEAN, color='lightblue', alpha=1, zorder=0)

    # Frame each panel. add_patch clips patches to the axes area by default, and a
    # large pad (in axes-fraction units) puts the whole outline outside that area,
    # where it is never drawn. Instead, draw the frame on the axes edge, unclipped
    # and above the map layers.
    box = patches.FancyBboxPatch((0, 0), 1, 1, transform=ax.transAxes,
                                 boxstyle='round,pad=0,rounding_size=0.02',
                                 edgecolor='black', facecolor='none',
                                 linewidth=2, clip_on=False, zorder=11)
    ax.add_patch(box)

    # Year label in the top-right corner of the panel.
    ax.annotate(f'{year}', xy=(0.95, 0.95), xycoords='axes fraction', fontsize=16,
                fontweight='bold', ha='right', va='top')
    ax.set_extent([68, 98, 6, 37], crs=ccrs.PlateCarree())

    # Invisible gridlines used only for their tick labels: latitude labels on the
    # first column, longitude labels on the bottom row, never on top/right.
    gl = ax.gridlines(draw_labels=True, color='none')
    gl.top_labels = False
    gl.right_labels = False
    gl.left_labels = i % n_panel_cols == 0
    gl.bottom_labels = i >= len(axes) - n_panel_cols

# One horizontal colorbar below all panels. It reproduces the colour mapping used
# by every panel (same cmap, vmin and vmax), so it is valid for all of them.
cbar_ax = fig.add_axes([0.15, 0.06, 0.7, 0.02])  # [left, bottom, width, height] in figure fraction
norm = colors.Normalize(vmin=vmin, vmax=vmax)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # public API; avoids writing the private `_A` attribute

# extend='both' adds end triangles in the colormap's under/over colours, which
# the maps use for values outside [vmin, vmax]. Without them those colours would
# appear on the maps with no legend entry.
cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal', extend='both')
cbar.set_label('N surplus (kg ha$^{-1}$ yr$^{-1}$)', fontsize=18, fontweight='bold')
cbar.ax.tick_params(labelsize=14)  # colorbar tick-label size

# Render the figure (inline when running in Jupyter).
plt.show()
