# Imports and setting up viz

NB: use conda environment `env1` on Mac and `lam1env` on spirit (Python 3.12).

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# import personal tools
import sys
sys.path.append('../../python_tools/')
from tools import *
from tools_mapping import *

In [None]:
# Natural Earth 10m river and lake centerlines: semi-transparent black edges, no fill.
rivers = cfeature.NaturalEarthFeature(
    'physical',
    'rivers_lake_centerlines',
    '10m',
    edgecolor=(0, 0, 0, 0.3),
    facecolor='none',
)

# Load files

LAM output analysis.

Irrigated (irr) and non-irrigated (no-irr) simulations.

## Area and period selection

In [None]:
# Bounding box, currently disabled:
# lon_min, lon_max = -13, 6
# lat_min, lat_max = 32, 49
# NOTE(review): lon_min / lon_max / lat_min / lat_max are still used by the
# loading cells below; presumably they come from the star-imported tools
# modules -- confirm.

# Analysis period (ISO dates).
date_min, date_max = '2010-01-01', '2022-12-31'

## Sims

In [None]:
# Output folders of the two simulations (LAM_1500_60 set).
# Alternative LAM_1000_40 set, kept for reference:
# noirr_dir = '../../../JZ_simu_outputs/LAM/LAM_1000_40/noirr_2010_2022'
# irr_dir = '../../../JZ_simu_outputs/LAM/LAM_1000_40/irr_2010_2022'
irr_dir = '../../../JZ_simu_outputs/LAM/LAM_1500_60/irr'
noirr_dir = '../../../JZ_simu_outputs/LAM/LAM_1500_60/noirr'

In [None]:
# Input switch used by the loading cells: True reads the SRF/TS_MO files,
# False reads the SRF/MO history files.
TS_flag = False

In [None]:
# Open the no-irrigation simulation output (file set chosen by TS_flag).
if TS_flag:
    filename = f'{noirr_dir}/*/SRF/TS_MO/*.nc'
else:
    filename = f'{noirr_dir}/*/SRF/MO/*sechiba_history.nc'

ORCnoirr0 = xr.open_mfdataset(filename)

# Rename the time axis, attach plotting metadata, then crop to the lon/lat box.
ORCnoirr = ORCnoirr0.rename({'time_counter': 'time'})
ORCnoirr.attrs.update(name='no_irr', plot_color="red")
ORCnoirr = ORCnoirr.sel(
    lon=slice(lon_min, lon_max),
    lat=slice(lat_min, lat_max),
)

# Scale snowmelt by the number of seconds in a day and relabel it as mm/day.
ORCnoirr['snowmelt'] = ORCnoirr['snowmelt'] * 24 * 3600
ORCnoirr['snowmelt'].attrs['units'] = 'mm/day'

# Share of snowmelt in (snowmelt + rain), in percent.
_snowmelt_plus_rain = ORCnoirr['snowmelt'] + ORCnoirr['rain']
ORCnoirr['snow_contrib'] = ORCnoirr['snowmelt'] / _snowmelt_plus_rain * 100
ORCnoirr['snow_contrib'].attrs['units'] = '%'

ORCnoirr

In [None]:
# Spin-up output of the no-irrigation run (history files directly under SRF/).
filename = f'{noirr_dir}/*/SRF/*history.nc'
spinup_noirr = xr.open_mfdataset(filename).rename({'time_counter': 'time'})

# Same metadata and spatial crop as the main no_irr dataset.
spinup_noirr.attrs.update(name='no_irr', plot_color="red")
spinup_noirr = spinup_noirr.sel(
    lon=slice(lon_min, lon_max),
    lat=slice(lat_min, lat_max),
)

spinup_noirr



In [None]:
# Open the irrigated simulation output (file set chosen by TS_flag).
# Both branches must read from irr_dir: pointing the TS_MO branch at noirr_dir
# would silently load the no-irrigation run under the 'irr' name.
if TS_flag:
    filename = '{}/*/SRF/TS_MO/*.nc'.format(irr_dir)
else:
    # NOTE(review): the no_irr cell uses the narrower '*sechiba_history.nc'
    # pattern; confirm that '*history.nc' does not pick up extra files here.
    filename = '{}/*/SRF/MO/*history.nc'.format(irr_dir)

ORCirr0 = xr.open_mfdataset(filename)

# Rename the time axis, attach plotting metadata, then crop to the lon/lat box.
ORCirr = ORCirr0.rename({'time_counter':'time'})
ORCirr.attrs['name'] = 'irr'
ORCirr.attrs['plot_color'] = "#0C5DA5"
ORCirr = ORCirr.sel(lon=slice(lon_min,lon_max),lat=slice(lat_min,lat_max))

# Scale snowmelt by the number of seconds in a day and relabel it as mm/day.
ORCirr['snowmelt'] = ORCirr['snowmelt'] * 24 * 3600
ORCirr['snowmelt'].attrs['units'] = 'mm/day'

ORCirr

In [None]:
# Spin-up output of the irrigated run (history files directly under SRF/).
filename = f'{irr_dir}/*/SRF/*history.nc'
spinup_irr = xr.open_mfdataset(filename).rename({'time_counter': 'time'})

# Plotting metadata, then crop to the lon/lat box.
spinup_irr.attrs.update(name='irr', plot_color="blue")
spinup_irr = spinup_irr.sel(
    lon=slice(lon_min, lon_max),
    lat=slice(lat_min, lat_max),
)
spinup_irr

In [None]:
if not TS_flag:
    # Disabled diagnostics, kept for reference:
    # ORCirr['irrig_deficit'] = ORCirr['netirrig'] - ORCirr['irrigation']
    # ORCirr['irrig_deficit'].attrs['units'] = 'mm/day'
    # ORCirr['irrig_frac'] = ORCirr['irrigmap_dyn']/ORCirr['Areas']

    # Irrigated fraction of each cell in percent. It is always derived from
    # the irrigated run and stored in both datasets (no_irr first, then irr).
    for _target in (ORCnoirr, ORCirr):
        _target['irrig_frac'] = ORCirr['irrigmap_dyn'] / ORCirr['Areas'] * 100
        _target['irrig_frac'].attrs['units'] = '%'

## Obs

In [None]:
# Monthly filtered CCI soil-moisture file (written by the screening cells
# at the end of this notebook).
outfilename = '../../../obs/CCI-SM/cci_sm_monthly_filtered.nc'
newcci = xr.open_mfdataset(outfilename)

# Give the flag variable a units attribute ('-').
newcci['flag'].attrs['units'] = '-'

newcci

## Interpolation and masks

In [None]:
# Boolean masks on the model grid:
#   irr_mask - irrigated fraction (irrigmap_dyn / Areas) above 5 %
#   con_mask - continental fraction above 95 %
#   ip_mask  - cells inside the iberian_peninsula polygon
irr_mask = (ORCirr['irrigmap_dyn'] / ORCirr['Areas']) > 0.05
con_mask = ORCnoirr['Contfrac'] > 0.95
ip_mask = polygon_to_mask(ORCnoirr, iberian_peninsula)

In [None]:
# Restrict both simulations to continental cells inside the Iberian Peninsula.
ip_ORCnoirr = (
    ORCnoirr
    .where(con_mask)
    .where(ip_mask)
)
ip_ORCirr = (
    ORCirr
    .where(con_mask)
    .where(ip_mask)
)

In [None]:
# Offset applied to the spin-up time axis: 5 * 365.25 days.
_spinup_shift = pd.to_timedelta(5 * 365.25, unit='D')

# no_irr: mask the spin-up, move it back in time, then prepend it to the main run.
ip_spinup_noirr = spinup_noirr.where(con_mask).where(ip_mask)
ip_spinup_noirr['time'] = ip_spinup_noirr['time'] - _spinup_shift
ip_long_noirr = xr.concat([ip_spinup_noirr, ip_ORCnoirr], dim='time')

# irr: same treatment.
ip_spinup_irr = spinup_irr.where(con_mask).where(ip_mask)
ip_spinup_irr['time'] = ip_spinup_irr['time'] - _spinup_shift
ip_long_irr = xr.concat([ip_spinup_irr, ip_ORCirr], dim='time')
ip_long_irr

In [None]:
#mask on irrigated areas only
# irr_ORCirr=ORCirr.where(irr_mask)
# irr_ORCnoirr=ORCnoirr.where(irr_mask)

# ip_irr_ORCirr=ip_ORCirr.where(irr_mask)
# ip_irr_ORCnoirr=ip_ORCnoirr.where(irr_mask)

In [None]:
# Interpolate the CCI observations onto the coordinates of the no_irr dataset.
# NOTE(review): interp_like acts on every shared coordinate, presumably
# including time -- confirm the monthly time stamps of both datasets match.
cci_iORC = newcci.interp_like(ORCnoirr)
# Keep only the cells inside the Iberian Peninsula mask.
ip_cci_iORC = cci_iORC.where(ip_mask)

In [None]:
# ebro_mask = polygon_to_mask(ORCnoirr, ebro)
# ebro_ORCirr = ip_ORCirr.where(ebro_mask)
# ebro_ORCnoirr = ip_ORCnoirr.where(ebro_mask)
# ebro_irr_ORCirr = ebro_ORCirr.where(irr_mask)
# ebro_irr_ORCnoirr = ebro_ORCnoirr.where(irr_mask)

# Maps

In [None]:
# Averaged map of one variable.
ds = ORCirr
var = 'evap'

# No explicit colour limits (manual alternative: vmin, vmax = 0, 1).
vmin = vmax = None

# No explicit title or colorbar label.
# title = 'off'
# clabel = "Share of surface withdrawals (%)"
title = None
clabel = None

cmap = wet

map_ave(
    ds,
    var,
    vmin=vmin,
    vmax=vmax,
    title=title,
    clabel=clabel,
    cmap=cmap,
    poly=None,
)

In [None]:
# Averaged map of one variable over the Iberian Peninsula.
ds = ip_ORCnoirr
var = 'mrsos'

# Manual colour limits (set both to None to drop them).
vmin, vmax = -1, 1

# No explicit title or colorbar label.
# title = 'off'
# clabel = "Share of surface withdrawals (%)"
title = None
clabel = None

cmap = wet

map_ave(
    ds,
    var,
    vmin=vmin,
    vmax=vmax,
    title=title,
    clabel=clabel,
    cmap=cmap,
    poly=None,
)

In [None]:
# Map of the averaged difference between two datasets (irr and no_irr).
ds1, ds2 = ip_ORCirr, ip_ORCnoirr
var = 'mrsos'

# Symmetric colour limits (set both to None to drop them).
vmin, vmax = -1.5, 1.5

# title = 'off'
title = None
clabel = None
cmap = emb_neutral

map_diff_ave(
    ds1,
    ds2,
    var,
    cmap=cmap,
    title=title,
    clabel=clabel,
    vmin=vmin,
    vmax=vmax,
)

In [None]:
# Map of the averaged relative difference between two datasets (irr and no_irr).
ds1, ds2 = ip_ORCirr, ip_ORCnoirr
var = 'mrsos'

# Symmetric colour limits (set both to None to drop them).
vmin, vmax = -10, 10

# title = 'off'
# clabel = 'Slow reservoir difference (%)'
title = None
clabel = None

map_rel_diff_ave(
    ds1,
    ds2,
    var,
    cmap=emb_neutral,
    title=title,
    clabel=clabel,
    vmin=vmin,
    vmax=vmax,
)

In [None]:
# Seasonal maps of the normalized soil-moisture difference (model minus CCI).
var = 'norm_sm'
ds1 = normalize_sm(ip_ORCnoirr, 'mrsos', '2010-01-01', '2019-12-31')
ds2 = normalize_sm(ip_cci_iORC, 'sm', '2010-01-01', '2019-12-31')

# Symmetric colour limits (set both to None to drop them).
min_value, max_value = -1.5, 1.5
cmap = emb_neutral

diff = ds1[var] - ds2[var]
# NOTE(review): the title reads attrs['name'] and attrs['units'] from the
# normalized datasets; confirm normalize_sm keeps or sets these attributes.
title = (
    f"{var} diff, {ds1.attrs['name']} vs {ds2.attrs['name']} "
    f"({ds1[var].attrs['units']})"
)
plotvar = diff

map_seasons(
    plotvar,
    cmap=cmap,
    vmin=min_value,
    vmax=max_value,
    title=title,
    hex=False,
)

### Discrete colormap

In [None]:
# Discrete values to represent; each one gets its own colour.
data_values = np.array([0, 1, 2, 4, 8, 16, 32, 64, 128])
N = len(data_values)  # number of discrete values (and colours)

# ----------------------------------------------------
# 1. Define the colormap
# ----------------------------------------------------

# Sample N colours from a built-in colormap.
# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap takes the same (name, lut) arguments.
cmap_base = plt.get_cmap('Spectral', N)

# Turn the sampled colours into a ListedColormap.
cmap = mcolors.ListedColormap(cmap_base(np.arange(N)))

# ----------------------------------------------------
# 2. Define the boundaries (norm)
# ----------------------------------------------------

# Boundaries sit between consecutive values, giving N + 1 edges in total.
# Inner edges are the midpoints (v_i + v_{i+1}) / 2.
bounds = (data_values[:-1] + data_values[1:]) / 2

# Outer edges mirror the nearest inner edge around the first / last value:
#   first edge = v_0 - (bounds[0] - v_0)
#   last edge  = v_last + (v_last - bounds[-1])
bounds = np.insert(bounds, 0, data_values[0] - (bounds[0] - data_values[0]))
bounds = np.append(bounds, data_values[-1] + (data_values[-1] - bounds[-1]))

# Map each interval between two edges to one colour.
norm = mcolors.BoundaryNorm(bounds, cmap.N)

# ----------------------------------------------------
# 3. Preview the colorbar on a dummy figure
# ----------------------------------------------------

fig, ax = plt.subplots(figsize=(8, 2))

# A ScalarMappable stands in for an image so the colorbar can be drawn alone.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # the mappable needs an array; an empty one is enough

cbar = fig.colorbar(
    sm,
    ax=ax,
    boundaries=bounds,        # interval edges
    ticks=data_values,        # ticks at the original values
    spacing='uniform',        # every colour band gets the same length
    orientation='horizontal',
)

cbar.set_label('Discrete Data Values (Unit)')
ax.set_title('Discrete Colormap for Specific Values', pad=20)

# Hide the dummy axes; only the colorbar is of interest.
ax.set_visible(False)

plt.show()

In [None]:
def nice_map_discrete(
    plotvar,
    ax,
    cmap_discrete,
    norm_discrete,
    clabel=None,
    cbar_on=True,
    xloc=8,
    yloc=9,
    left_labels=True,
    poly=None,
    tick_labels=None,
):
    """
    Plot a map of ``plotvar`` using a discrete colormap and boundary norm.

    Args:
        plotvar: 2D data array exposing ``.plot.pcolormesh`` (xarray-style).
        ax: Cartopy axes to draw on.
        cmap_discrete: Colormap holding one colour per category.
        norm_discrete: Norm exposing a ``boundaries`` array (e.g. a
            Matplotlib ``BoundaryNorm``); it defines the colour bins, so no
            vmin/vmax is passed to the plot.
        clabel: Colorbar label.
        cbar_on: Whether to draw a colorbar.
        xloc, yloc: Maximum number of gridline ticks along x and y.
        left_labels: Whether to draw latitude labels on the left side.
        poly: Optional polygon forwarded to ``plot_polygon``.
        tick_labels: Optional sequence with one label per colour bin, e.g.
            the original discrete values. When omitted, each bin is labelled
            with its centre value.

    Returns:
        The object returned by ``plotvar.plot.pcolormesh``.

    Raises:
        ValueError: If ``tick_labels`` does not hold one entry per bin.
    """
    # --- 1. Cartopy / axes setup ---
    ax.coastlines()
    ax.add_feature(cfeature.RIVERS)
    gl = ax.gridlines(draw_labels=True, dms=False, x_inline=False, y_inline=False, alpha=0.8)
    gl.right_labels = False
    gl.left_labels = left_labels
    gl.top_labels = False
    gl.xlocator = plt.MaxNLocator(xloc)
    gl.ylocator = plt.MaxNLocator(yloc)

    # --- 2. Plot ---
    # The colorbar is only requested when wanted, instead of being drawn
    # and removed afterwards.
    plot_obj = plotvar.plot.pcolormesh(
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap=cmap_discrete,
        norm=norm_discrete,
        add_colorbar=cbar_on,
    )

    # --- 3. Colorbar: one tick at the centre of every colour bin ---
    if cbar_on:
        cbar = plot_obj.colorbar
        edges = norm_discrete.boundaries
        bin_centres = (edges[:-1] + edges[1:]) / 2

        if tick_labels is None:
            # Bin centres are not integers in general (e.g. 2.25, 4.5), so
            # they are printed compactly rather than truncated to int.
            labels = [f'{centre:g}' for centre in bin_centres]
        else:
            labels = [str(value) for value in tick_labels]
            if len(labels) != len(bin_centres):
                raise ValueError(
                    f'tick_labels has {len(labels)} entries but the norm '
                    f'defines {len(bin_centres)} bins'
                )

        cbar.set_label(clabel)
        cbar.set_ticks(bin_centres)
        # No formatter is installed after this call: a new major formatter
        # would replace the fixed labels set here.
        cbar.set_ticklabels(labels)

    # --- 4. Optional polygon overlay ---
    if poly is not None:
        plot_polygon(poly, ax)

    plt.tight_layout()

    return plot_obj

In [None]:
# Map one time step of the CCI flag with the discrete colormap defined above.
fig = plt.figure(figsize=(7.5, 4))
ax = plt.axes(projection=ccrs.PlateCarree())

# Third time step (index 2).
# Alternative: plotvar = newcci['flag'].mean(dim='time')
plotvar = newcci['flag'].isel(time=2)

nice_map_discrete(plotvar, ax, cmap, norm)

# Time series

In [None]:
# Averaged time series for the long (spin-up + run) datasets.
var = 'sm'

ds1 = ip_cci_iORC
# Long series; use ip_ORCnoirr / ip_ORCirr for the main run only.
ds2 = ip_long_noirr
ds3 = ip_long_irr

# ds_list = [ds3]
ds_list = [ds2, ds3]

year_min, year_max = 2007, 2022

# title = 'off'
title = None

# No explicit y-limits (manual alternative: vmin, vmax = 0, 22).
vmin = vmax = None

time_series_ave(
    ds_list,
    var,
    ds_colors=True,
    year_min=year_min,
    year_max=year_max,
    title=title,
)
# seasonal_cycle_ave( ds_list, var, ds_colors=True, vmin=vmin, vmax=vmax,year_min=year_min, year_max=year_max, title=title)

In [None]:
# Time series and seasonal cycle for observations plus both simulations.
# Reuses ds1, ds2, ds3 and var from the previous cell.
ds_list = [ds1, ds2, ds3]

year_min, year_max = 2010, 2022

# title = 'off'
title = None

# ylabel = "Irrigation (mm/d)"
_plot_options = dict(ds_colors=True, year_min=year_min, year_max=year_max, title=title)
time_series_ave(ds_list, var, **_plot_options)
seasonal_cycle_ave(ds_list, var, **_plot_options)
plt.grid()

In [None]:
# Several variables of one dataset on the same axes.

# Alternative set: varlist = ['mrsos', 'humtot']
varlist = ['streamr', 'slowr', 'fastr']

# Long series (use ip_ORCnoirr for the main run only).
ds1 = ip_long_noirr

year_min, year_max = 2007, 2022

# title = 'off'
title = None

fig = plt.figure(figsize=(7.5, 4))
ax = plt.axes()

# One spatially averaged curve per variable.
for var in varlist:
    plotvar = ds1[var].mean(dim=['lon', 'lat'])
    nice_time_plot(plotvar, ax)
plt.legend()

# Figures

In [None]:
# Figure: seasonal cycle of normalized surface soil moisture (irr, no_irr,
# CCI) and maps of the normalized bias of each simulation against CCI.
# The grid is 3x3; this cell fills the bottom row only.
fig = plt.figure(figsize=(18, 12))
gs = gridspec.GridSpec(3, 3, width_ratios=[1.1, 1, 1.2], height_ratios=[1, 1, 1])

# Colour limits per variable family.
vmin_p, vmax_p = -1, 1
vmin_et, vmax_et = -1, 1
vmin_sm, vmax_sm = -1.5, 1.5

# Normalized SSM
date_min = '2010-01-01'
date_max = '2019-12-31'  # to match with GPCC (full period ends '2022-12-31')
_period = slice(date_min, date_max)
var = 'norm_sm'

# Select the period, then compute the normalized SSM over that same period.
ds_noirr = normalize_sm(ip_ORCnoirr.sel(time=_period), 'mrsos', date_min, date_max)
ds_irr = normalize_sm(ip_ORCirr.sel(time=_period), 'mrsos', date_min, date_max)
ds_obs = normalize_sm(ip_cci_iORC.sel(time=_period), 'sm', date_min, date_max)


def _mean_monthly_cycle(dataset):
    """Spatial mean of `var`, then mean per calendar month."""
    return dataset[var].mean(dim=['lon', 'lat']).groupby('time.month').mean(dim='time')


# Seasonal cycle
ax7 = fig.add_subplot(gs[2, 0])
vmin_sm_ts, vmax_sm_ts = -1.3, 1.3
ylabel = "Normalized SSM (-)"

plotvar1, color1, label1 = _mean_monthly_cycle(ds_obs), 'black', 'CCI'
nice_time_plot(plotvar1, ax7, vmin=vmin_sm_ts, label=label1, color=color1, ylabel=ylabel)

plotvar2, color2, label2 = _mean_monthly_cycle(ds_noirr), 'red', 'no_irr'
nice_time_plot(plotvar2, ax7, vmin=vmin_sm_ts, label=label2, color=color2, ylabel=ylabel)

# Only this last call also passes the upper limit.
plotvar3, color3, label3 = _mean_monthly_cycle(ds_irr), 'blue', 'irr'
nice_time_plot(plotvar3, ax7, vmin=vmin_sm_ts, vmax=vmax_sm_ts, label=label3, color=color3, ylabel=ylabel)

ax7.set_title('(g) Mean seasonnal cycle (2010-2019)')
ax7.set_xticks(np.arange(1, 13))
ax7.set_xticklabels(months_name_list)

# Bias maps: both panels share colormap, limits and colorbar label.
cmap = emb_neutral
vmin, vmax = vmin_sm, vmax_sm
clabel = "Normalized SSM bias (-)"

ax8 = fig.add_subplot(gs[2, 1], projection=ccrs.PlateCarree())
plotvar = (ds_noirr[var] - ds_obs[var]).mean(dim='time')
nice_map(plotvar, ax8, cmap, vmin, vmax, clabel=clabel, cbar_on=False)
ax8.set_title('(h) no_irr - CCI')

ax9 = fig.add_subplot(gs[2, 2], projection=ccrs.PlateCarree())
plotvar = (ds_irr[var] - ds_obs[var]).mean(dim='time')
nice_map(plotvar, ax9, cmap, vmin, vmax, clabel=clabel, left_labels=False)
ax9.set_title('(i) irr - CCI')

plt.tight_layout()

# Screening and masking CCI data

In [None]:
# Deliberate halt: `stop` is not defined anywhere in the visible notebook, so
# evaluating it presumably raises NameError and keeps "Run All" from reaching
# the CCI preprocessing cells below -- confirm the star imports do not define it.
stop

In [None]:
# Daily ESA CCI soil-moisture files (COMBINED product, v09.1).
# Older source: filename = '../../../obs/CCI-SM_old/C3S*.nc'
filename = '../../../obs/CCI-SM/esacci/soil_moisture/data/daily_files/COMBINED/v09.1/*/ESACCI*.nc'

cci = xr.open_mfdataset(filename)
cci.attrs.update(name='CCI', plot_color='black')

# The latitude slice runs from lat_max to lat_min, unlike the model datasets.
cci = cci.sel(
    lon=slice(lon_min, lon_max),
    lat=slice(lat_max, lat_min),
)

# Scale sm into mm-labelled variables named after the model outputs.
# NOTE(review): the factors 2 and 0.1 are presumably layer depths in metres -- confirm.
for _name, _factor in (('humtot', 2), ('mrsos', 0.1)):
    cci[_name] = cci['sm'] * _factor * 1000
    cci[_name].attrs['units'] = 'mm'

cci

In [None]:
# Monthly means of the daily CCI data, labelled at month start ('MS').
cci_monthly = cci.resample(time='MS').mean()
# Move each time stamp forward by 14 days (from the 1st to the 15th of the month).
cci_monthly['time'] = cci_monthly['time'] + pd.Timedelta(days=14)
cci_monthly

In [None]:
# Export the unscreened monthly means to NetCDF.
outfilename='../../../obs/CCI-SM/cci_sm_monthly.nc'
cci_monthly.to_netcdf(outfilename)

In [None]:
def count_var_value(
    ds: xr.Dataset,
    var: str,
    threshold: float,
    dims=('time', 'lon', 'lat'),
) -> xr.DataArray:
    """
    Count how many records of ``ds[var]`` are strictly above ``threshold``.

    With the default ``dims`` the count is summed over time, lon and lat, so
    the result is a single total (read it with ``.values``), not a per-cell
    map. Pass ``dims='time'`` to get a (lat, lon) map of counts instead.

    Args:
        ds: Dataset containing the variable.
        var: Name of the variable to check (e.g. 'flag').
        threshold: Value that records must strictly exceed to be counted.
        dims: Dimension name, or sequence of names, to sum over.

    Returns:
        An xarray.DataArray holding the exceedance count, reduced over
        ``dims``.
    """
    # True wherever the value exceeds the threshold. NaN compares as False,
    # so missing values are never counted.
    exceedance_mask = ds[var] > threshold

    # Summing the 0/1 mask over the requested dimensions gives the count.
    sum_dims = dims if isinstance(dims, str) else list(dims)
    return exceedance_mask.astype(int).sum(dim=sum_dims)

In [None]:
# Number of CCI flag records above each threshold.
# Further candidate thresholds: 2, 4, 8, 16, 32, 64, 128
for threshold in (-1, 0, 1):
    test = count_var_value(newcci, 'flag', threshold)
    print(f'threshold: {threshold} : {test.values}')

In [None]:
# Number of sm_uncertainty records above each threshold.
for threshold in (0, 0.01, 0.06):
    test = count_var_value(newcci, 'sm_uncertainty', threshold)
    print(f'threshold: {threshold} : {test.values}')

In [None]:
# Number of mrsos records above each threshold.
for threshold in (-1e10, -1, 0):
    test = count_var_value(newcci, 'mrsos', threshold)
    print(f'threshold: {threshold} : {test.values}')

In [None]:
# Screening criteria applied to the daily CCI data:
#   flag <= 0 and sm_uncertainty < 0.06
quality_mask = cci['flag'] <= 0
uncertainty_mask = cci['sm_uncertainty'] < 0.06

# Keep only the records passing both criteria.
screened_cci = (
    cci
    .where(quality_mask)
    .where(uncertainty_mask)
)
screened_cci



In [None]:
# Monthly means of the screened daily data, labelled at month start ('MS').
# NOTE(review): unlike cci_monthly, no 14-day offset is added to the time
# stamps here -- confirm this is intended before interpolating onto model time.
monthly_screened = screened_cci.resample(time='MS').mean()
monthly_screened

In [None]:
# Export the screened monthly means to NetCDF.
# NOTE(review): this is the same path that the "Obs" cell opens as `newcci`;
# presumably that dataset should be closed before overwriting -- confirm.
outfilename='../../../obs/CCI-SM/cci_sm_monthly_filtered.nc'
monthly_screened.to_netcdf(outfilename)