# Plot Regional Precipitation Differences on Map

This notebook creates a map of the gridded percent difference between observed and predicted precipitation changes across the whole domain, with boxes highlighting the 5 subregions.

**Note:** This notebook reads data from `/global/cfs/cdirs/m4334/sferrett/monsoon-pod/data/processed/`. The path is absolute, so the notebook can be run from any working directory.

## Import Packages

In [None]:
import warnings
import numpy as np
import xarray as xr
from numba import jit
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import pandas as pd

warnings.filterwarnings('ignore')

## Configuration

In [None]:
# Absolute directory of the processed data files (usable from any working directory)
FILEDIR = '/global/cfs/cdirs/m4334/sferrett/monsoon-pod/data/processed'

# The 5 subregions as (name, latmin, latmax, lonmin, lonmax).
# Order matters: the regional calculation assigns month pairs by position.
_REGION_BOUNDS = (
    ('Eastern Arabian Sea', 9., 19.5, 64., 72.),
    ('Central India', 18., 24., 76., 83.),
    ('Central Bay of Bengal', 9., 14.5, 86.5, 90.),
    ('Equatorial Indian Ocean', 5., 10., 62., 67.5),
    ('Konkan Coast', 15., 19.5, 69., 72.5),
)
REGIONS = {
    name: {'latmin': latmin, 'latmax': latmax, 'lonmin': lonmin, 'lonmax': lonmax}
    for name, latmin, latmax, lonmin, lonmax in _REGION_BOUNDS
}

# Binning parameters for B_L
BINPARAMS = {'bl': dict(min=-0.6, max=0.1, width=0.0025)}

# Month pairs: first 3 regions use (6,7), last 2 regions use (7,8)
MONTHPAIRS = [(6, 7), (7, 8)]

# Thresholds
PRTHRESH = 0.25    # precipitation threshold in mm/day
SAMPLETHRESH = 50  # minimum sample threshold for bin inclusion

## Load Data

In [None]:
def load(filename, filedir=FILEDIR, varlist=None):
    """Load a dataset from `filedir/filename` fully into memory.

    Args:
        filename: Name of the file inside `filedir`.
        filedir: Directory containing the file (defaults to FILEDIR).
        varlist: Optional list of variable names to keep. A falsy value
            (None or an empty list) keeps every variable.

    Returns:
        The (optionally subsetted) dataset returned by `.load()`.
    """
    filepath = f'{filedir}/{filename}'
    # The context manager closes the underlying file once the values have
    # been read into memory, instead of leaving the handle open.
    with xr.open_dataset(filepath) as ds:
        if varlist:
            ds = ds[varlist]
        return ds.load()

# Load the LR-ERA5/IMERG data (you can change to HR or GPCP as needed)
print("Loading data...")
data = load('LR_ERA5_IMERG_pr_bl_terms.nc', varlist=['pr', 'bl'])
print(f"Data loaded: {list(data.data_vars)}")
# `Dataset.sizes` is the name -> length mapping; `Dataset.dims` is slated to
# return only the dimension names, which `dict()` could not convert.
print(f"Dimensions: {dict(data.sizes)}")

## Helper Functions

In [None]:
def get_region(data, key, regions=REGIONS):
    """Return the lat/lon box of `data` for the region named `key`.

    The box limits are read from `regions[key]` ('latmin', 'latmax',
    'lonmin', 'lonmax') and applied as label-based slices on `lat` and `lon`.
    """
    bounds = regions[key]
    latrange = slice(bounds['latmin'], bounds['latmax'])
    lonrange = slice(bounds['lonmin'], bounds['lonmax'])
    return data.sel(lat=latrange, lon=lonrange)

def get_month(data, months):
    """Keep only the time steps of `data` whose calendar month is in `months`.

    `months` may be a single month number or a list/tuple of month numbers.
    """
    monthlist = months if isinstance(months, (list, tuple)) else [months]
    keep = data.time.dt.month.isin(monthlist)
    return data.sel(time=keep)

def get_bin_edges(key, binparams=BINPARAMS):
    """Return evenly spaced bin values for the variable `key`.

    Args:
        key: Variable name to look up in `binparams`.
        binparams: Mapping of variable name to a dict with 'min', 'max'
            and 'width'.

    Returns:
        1-D array starting at 'min', stepping by 'width', and extending to
        (at least) 'max'.
    """
    spec = binparams[key]
    lo, hi, width = spec['min'], spec['max'], spec['width']
    # Count the steps as an integer first. With a float step, the length of
    # np.arange(lo, hi + width, width) depends on rounding and can include a
    # spurious value beyond `hi`; the small tolerance absorbs that rounding.
    nsteps = int(np.ceil((hi - lo) / width - 1e-9))
    return lo + width * np.arange(nsteps + 1)

In [None]:
@jit(nopython=True)
def fast_1D_binned_stats(blidxs, prdata, nblbins, prthresh=0.25):
    """Accumulate per-bin sample counts and precipitation sums (numba JIT).

    Args:
        blidxs: Integer bin index for each sample, read in flat order
            alongside `prdata`.
        prdata: Precipitation values.
        nblbins: Number of bins.
        prthresh: Accepted for interface compatibility; not used here.

    Returns:
        (Q0, Q1): per-bin sample count and per-bin precipitation sum.
        Samples with an index outside [0, nblbins) or a non-finite
        precipitation value are skipped.
    """
    counts = np.zeros(nblbins)
    prsums = np.zeros(nblbins)
    for n in range(prdata.size):
        idx = blidxs.flat[n]
        value = prdata.flat[n]
        # Guard clauses: drop out-of-range bins and missing precipitation
        if idx < 0 or idx >= nblbins:
            continue
        if not np.isfinite(value):
            continue
        counts[idx] += 1
        prsums[idx] += value
    return counts, prsums

def calc_binned_stats(data, binparams=BINPARAMS, prthresh=PRTHRESH):
    """Bin precipitation by B_L and return per-bin counts and sums.

    Args:
        data: Dataset with `bl` and `pr` variables of matching shape.
        binparams: Binning parameters; `binparams['bl']` supplies 'min' and
            'width' for the index formula.
        prthresh: Passed through to `fast_1D_binned_stats`.

    Returns:
        Dataset on dimension 'bl' with `Q0` (sample count) and `Q1`
        (precipitation sum). The `bl` coordinate holds the values from
        `get_bin_edges`; each sample is assigned to the nearest of them.
    """
    blbins = get_bin_edges('bl', binparams)
    blmin = binparams['bl']['min']
    blwidth = binparams['bl']['width']

    # Nearest-bin index is floor(x + 0.5). Casting straight to int truncates
    # toward zero instead, which folds values up to 1.5 widths below `min`
    # into bin 0.
    rawidxs = np.floor((data.bl.values - blmin) / blwidth + 0.5)
    # Mark everything outside the bin range (NaN compares False, so missing
    # values are included) with -1 before the integer cast, which is
    # undefined for NaN and out-of-range floats.
    inrange = (rawidxs >= 0) & (rawidxs < blbins.size)
    blidxs = np.where(inrange, rawidxs, -1).astype(np.int32)

    Q0, Q1 = fast_1D_binned_stats(blidxs, data.pr.values, blbins.size, prthresh)

    ds = xr.Dataset()
    ds['Q0'] = ('bl', Q0)
    ds['Q1'] = ('bl', Q1)
    ds['bl'] = blbins
    return ds

def get_bin_mean_pr(stats, samplethresh=SAMPLETHRESH):
    """Return the mean precipitation per B_L bin (Q1 / Q0).

    Bins with no samples, or with fewer than `samplethresh` samples, are NaN.
    The result is a DataArray on the `bl` coordinate of `stats`.
    """
    # Work on a copy so the counts stored in `stats` are left untouched
    counts = np.array(stats.Q0.values)
    counts[counts == 0.0] = np.nan  # avoid dividing by zero for empty bins
    means = stats.Q1.values / counts
    means[counts < samplethresh] = np.nan
    return xr.DataArray(means, coords={'bl': stats.bl.values})

def get_pdf(stats):
    """Return the B_L probability density built from the bin counts `Q0`.

    Each count is divided by (total count * bin width), where the bin width
    is the spacing between the first two `bl` values.
    """
    edges = stats.bl.values
    counts = stats.Q0.values
    total = np.nansum(counts)
    binwidth = np.diff(edges)[0]
    density = counts / (total * binwidth)
    return xr.DataArray(density, coords={'bl': edges})

## Calculate Regional Mean Precipitation Changes

First, calculate the regional mean values for the 5 subregions to display in the boxes.

In [None]:
def calc_regional_pr_changes(regiondata, monthpair, binparams=BINPARAMS,
                            samplethresh=SAMPLETHRESH, prthresh=PRTHRESH):
    """
    Compute observed and POD-predicted precipitation changes between two months.

    The bin-mean precipitation is built once from all of `regiondata`; each
    month's prediction is that curve weighted by the month's own B_L PDF.

    Returns:
        obsprchange: Observed precipitation change (mm/day), second month
            minus first month of `monthpair`
        predprchange: Predicted precipitation change (mm/day), same order
    """
    # Bin-mean precipitation from everything that was passed in
    binmeanpr = get_bin_mean_pr(
        calc_binned_stats(regiondata, binparams, prthresh), samplethresh
    )

    def monthly_obs_and_pred(month):
        """Return (observed mean pr, predicted pr) for a single month."""
        subset = get_month(regiondata, month)
        substats = calc_binned_stats(subset, binparams, prthresh)
        pdf = get_pdf(substats)
        binwidth = np.diff(substats.bl.values)[0]
        # POD prediction: sum of bin-mean precipitation weighted by the PDF
        predicted = np.nansum(binmeanpr * pdf * binwidth)
        observed = subset.pr.mean(dim=['lat', 'lon', 'time']).values
        return observed, predicted

    monthly = [monthly_obs_and_pred(month) for month in monthpair]
    (obsfirst, predfirst), (obssecond, predsecond) = monthly[0], monthly[1]

    return obssecond - obsfirst, predsecond - predfirst

# Calculate precipitation changes for all regions
print("Calculating regional mean precipitation changes...\n")

results = {}
for i, region_name in enumerate(REGIONS):
    # First 3 regions use June-July, last 2 use July-August
    if i < 3:
        monthpair, monthpair_str = MONTHPAIRS[0], 'June-July'
    else:
        monthpair, monthpair_str = MONTHPAIRS[1], 'July-August'

    # Get region data for the month pair
    regiondata = get_month(get_region(data, region_name), monthpair)

    # Calculate changes
    obs_change, pred_change = calc_regional_pr_changes(
        regiondata, monthpair, BINPARAMS, SAMPLETHRESH, PRTHRESH
    )

    # Percent difference: (observed - predicted) / |observed| * 100.
    # It is undefined when the observed change is zero; report NaN (as the
    # gridded calculation does) rather than 0%, which would read as perfect
    # agreement.
    if obs_change != 0:
        pct_diff = ((obs_change - pred_change) / abs(obs_change)) * 100
    else:
        pct_diff = np.nan

    results[region_name] = {
        'observed': obs_change,
        'predicted': pred_change,
        'pct_diff': pct_diff,
        'monthpair': monthpair_str
    }

    print(f"{region_name} ({monthpair_str}):")
    print(f"  Observed change:  {obs_change:6.3f} mm/day")
    print(f"  Predicted change: {pred_change:6.3f} mm/day")
    print(f"  Percent diff:     {pct_diff:6.2f}%\n")

## Calculate Gridded Precipitation Changes

Now calculate the percent differences at every grid point across the entire domain.

In [None]:
def calc_gridded_pr_changes(data, monthpair, binparams=BINPARAMS,
                           samplethresh=SAMPLETHRESH, prthresh=PRTHRESH):
    """
    Calculate observed and POD-predicted precipitation changes at each grid point.

    Every grid point is treated as a 1x1 region and handed to
    `calc_regional_pr_changes`, so the gridded and regional numbers come from
    the same code path.

    Args:
        data: Dataset with `pr` and `bl` on `lat`, `lon` and `time` dimensions.
        monthpair: The two months to difference (second minus first).
        binparams: Binning parameters passed through to the binning helpers.
        samplethresh: Minimum samples per bin for the bin mean to be used.
        prthresh: Precipitation threshold passed through to the binning helpers.

    Returns:
        obs_change_map: 2D array of observed precipitation changes
        pred_change_map: 2D array of predicted precipitation changes
        pct_diff_map: 2D array of percent differences,
            (observed - predicted) / |observed| * 100
        All three are (lat, lon) arrays; points without valid data, or where
        the calculation failed, are NaN.
    """
    nlat, nlon = len(data.lat), len(data.lon)
    obs_change_map = np.full((nlat, nlon), np.nan)
    pred_change_map = np.full((nlat, nlon), np.nan)
    nfailed = 0

    print(f"Calculating gridded changes for months {monthpair}...")

    for i in range(nlat):
        if i % 10 == 0:
            print(f"  Processing latitude {i+1}/{nlat}...")

        for j in range(nlon):
            # Length-1 selections keep the lat/lon dimensions, which the
            # regional calculation averages over.
            point_data = data.isel(lat=[i], lon=[j])

            # Skip if no valid data
            if point_data.pr.isnull().all() or point_data.bl.isnull().all():
                continue

            monthpair_data = get_month(point_data, monthpair)

            # Best effort per point: a failing point stays NaN, but only
            # ordinary errors are absorbed (not KeyboardInterrupt or
            # SystemExit) and they are counted instead of silently dropped.
            try:
                obs_change, pred_change = calc_regional_pr_changes(
                    monthpair_data, monthpair, binparams, samplethresh, prthresh
                )
            except Exception:
                nfailed += 1
                continue

            obs_change_map[i, j] = obs_change
            pred_change_map[i, j] = pred_change

    if nfailed:
        print(f"  Skipped {nfailed} grid point(s) where the calculation failed.")

    # Percent difference is only defined where both changes exist and the
    # observed change is non-zero; everything else stays NaN.
    pct_diff_map = np.full((nlat, nlon), np.nan)
    valid_mask = (obs_change_map != 0) & ~np.isnan(obs_change_map) & ~np.isnan(pred_change_map)
    pct_diff_map[valid_mask] = ((obs_change_map[valid_mask] - pred_change_map[valid_mask]) /
                                 np.abs(obs_change_map[valid_mask])) * 100

    return obs_change_map, pred_change_map, pct_diff_map

# Gridded changes are computed for the June-July pair (MONTHPAIRS[0]) over the whole domain
print("\nCalculating gridded precipitation changes across the entire domain...")
obs_map_67, pred_map_67, pct_diff_map_67 = calc_gridded_pr_changes(
    data, monthpair=MONTHPAIRS[0]
)

# The June-July percent-difference field is the one drawn on the map
pct_diff_map = pct_diff_map_67

print("\nDone calculating gridded changes!")

## Plot Map with Gridded Percent Differences and Subregion Boxes

In [None]:
# Create figure with a cartopy map axis; one PlateCarree instance serves as
# both the map projection and the data transform
platecarree = ccrs.PlateCarree()
fig = plt.figure(figsize=(16, 10))
ax = fig.add_subplot(1, 1, 1, projection=platecarree)

# Set extent to cover the monsoon region
ax.set_extent([60, 95, 0, 30], crs=platecarree)

# Diverging colormap centered at 0 for the gridded percent differences
levels = np.linspace(-100, 100, 21)
cmap = plt.cm.RdBu_r  # Red for positive (obs>pred), Blue for negative (pred>obs)

# Plot the filled contours
cf = ax.contourf(data.lon, data.lat, pct_diff_map,
                 levels=levels, cmap=cmap, extend='both',
                 transform=platecarree, alpha=0.8)

# Add colorbar
cbar = fig.colorbar(cf, ax=ax, orientation='horizontal', pad=0.05, shrink=0.7)
cbar.set_label('Percent Difference: (Observed - Predicted) / |Observed| × 100%',
               fontsize=12, fontweight='bold')

# Add map features
ax.add_feature(cfeature.COASTLINE, linewidth=1, edgecolor='black')
ax.add_feature(cfeature.BORDERS, linewidth=0.5, linestyle=':', edgecolor='gray')
gl = ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False,
                  linewidth=0.5, alpha=0.5, linestyle='--', color='gray')
gl.top_labels = False
gl.right_labels = False

# Plot each subregion box
for region_name, coords in REGIONS.items():
    pct_diff = results[region_name]['pct_diff']

    # Draw rectangle with thick black border
    rect = patches.Rectangle(
        (coords['lonmin'], coords['latmin']),
        coords['lonmax'] - coords['lonmin'],
        coords['latmax'] - coords['latmin'],
        linewidth=3,
        edgecolor='black',
        facecolor='none',
        transform=platecarree
    )
    ax.add_patch(rect)

    # Label goes at the center of the box
    center_lon = (coords['lonmin'] + coords['lonmax']) / 2
    center_lat = (coords['latmin'] + coords['latmax']) / 2

    # Format the label - show the regional percent difference with its sign
    label_text = f"{pct_diff:+.1f}%"

    ax.text(center_lon, center_lat, label_text,
            horizontalalignment='center',
            verticalalignment='center',
            fontsize=14,
            fontweight='bold',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='white',
                     alpha=0.9, edgecolor='black', linewidth=2),
            transform=platecarree)

# Add title
ax.set_title('Gridded Percent Difference Between Observed and Predicted Precipitation Changes\n' +
             'LR-ERA5/IMERG (June-July)',
             fontsize=14, fontweight='bold', pad=20)

plt.tight_layout()

# Optional: Save the figure. Keep this before plt.show(): once the figure has
# been shown it may already be closed, and saving afterwards can write an
# empty image.
# fig.savefig('gridded_precip_differences_map.png', dpi=300, bbox_inches='tight')

plt.show()

## Summary Table

In [None]:
# Build one formatted row per region from the regional results
summary_data = [
    {
        'Region': region_name,
        'Month Pair': result['monthpair'],
        'Observed Δ (mm/day)': f"{result['observed']:.3f}",
        'Predicted Δ (mm/day)': f"{result['predicted']:.3f}",
        'Percent Difference (%)': f"{result['pct_diff']:+.2f}"
    }
    for region_name, result in results.items()
]

df = pd.DataFrame(summary_data)

# Print the table framed by horizontal rules
rule = "="*100
print("\n" + rule)
print("SUMMARY TABLE: Observed vs. Predicted Precipitation Changes")
print(rule)
print(df.to_string(index=False))
print(rule)