# Calculate Percentage of Different Severities of Drought using SPI; AGCD and CORDEX Simulations

Using drought thresholds as defined in McKee et al. (1993): https://www.droughtmanagement.info/literature/AMS_Relationship_Drought_Frequency_Duration_Time_Scales_1993.pdf

In [None]:
# Import modules
%matplotlib inline
%run /g/data/w40/ri9247/code/aus_precip_benchmarking/master_functions_bmf.ipynb
import fnmatch
import xarray as xr
import pandas as pd
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as mticker
import matplotlib.gridspec as gridspec

## Define Spatiotemporal Boundaries and Masks

In [None]:
# Spatial extent: latitude/longitude bounds covering all of Australia
lat_slice = slice(-44.5, -10)
lon_slice = slice(112, 156.25)

# Analysis period
time_slice = slice("1976-01-01", "2005-12-30")

# SPI accumulation periods to evaluate, passed to the helpers as `iscale`:
#   1 -> 3-month SPI, 2 -> 6-month SPI, 3 -> 12-month SPI
# (a single period can be used instead, e.g. `iscale = 2`; use None when the variable is not SPI)
iscale_list = [1, 2, 3]

# Seasonal subset as a list of month numbers; None means no seasonal breakdown
season = None
season_name = 'Annual'

# Combined quality mask on the AUS-44i grid (the file name indicates ocean points are excluded),
# cut down to the region of interest above
qc_mask_ds = xr.open_dataset('/g/data/w40/ri9247/CORDEX-Australasia/data/obs/AUS-44i_grid/no_indices/mask/AUS-44i_combined_quality_mask_no_oceans.nc')
qc_mask = qc_mask_ds.station_mask.sel(lat=lat_slice, lon=lon_slice)

## Define Keywords for Data and Data Paths

Keywords needed are bracketed {}. <br><br>
**RCM** database is organized as follows: <br>
**Climpact Indices:** <br>
parent_directory/{time_period}/{index-keyword}/{variable}_{time_average}_dataset_file.nc
<br><br>
**Observations** database is organized as follows: <br>
obs_parent_directory/{grid_type}/climpact/{variable}/{variable}_{time_average}_agcd_historical_v1_1950-2020.nc <br><br>
**Keyword Options:** <br>
grid_type: 'AUS-44i_grid', 'native_grid', 'one_degree_grid' <br>
time_period: 'historical', 'rcp85' <br>
variable: See list of Climpact Indices at: https://climpact-sci.org/indices; must be all-lowercase <br>
time_average: 'ANN', 'MON'

### Using Climpact Indices

In [None]:
# Keywords that are substituted into the dataset paths below
time_period = 'historical'
variable = 'spi'
time_average = 'MON'
grid_type = 'AUS-44i_grid'

# Root directories of the RCM index archive and of the observational archive
model_master_path = '/g/data/ks32/CLEX_Data/CORDEX_Australasia_Indices/v1-0/'
obs_master_path = '/g/data/w40/ri9247/CORDEX-Australasia/data/obs/'

# Directory containing every RCM file for this period/index
model_data_path = f'{model_master_path}{time_period}/{variable}/'

# Single AGCD observational file for this grid/index/time average
agcd_data_path = (
    f'{obs_master_path}{grid_type}/climpact/{variable}/'
    f'{variable}_{time_average}_agcd_historical_v1_1950-2020.nc'
)

## Define list for subset of models

In [None]:
# GCM/RCM combinations retained for analysis. The selection depends on each simulation's
# performance against the Minimum Standards Metrics; update it for other applications.
# Names are "<GCM>   <RCM>" (three spaces), matching the keys used for the model time series.
subset_names = [
    "ACCESS1-0   CCAM-1704",
    "ACCESS1-0   CCAM-2008",
    "ACCESS1-0   WRF360J",
    "ACCESS1-0   WRF360K",
    "CanESM2   CCAM-2008",
    "CanESM2   WRF360J",
    "CNRM-CM5   CCAM-1704",
    "GFDL-ESM2M   CCAM-1704",
    "GFDL-ESM2M   CCAM-2008",
    "HadGEM2-CC   CCAM-1704",
    "HadGEM2-ES   CCLM5-0-15",
    "HadGEM2-ES   REMO2015",
    "MIROC5   CCAM-1704",
    "MIROC5   CCAM-2008",
    "MPI-ESM-LR   CCLM5-0-15",
    "MPI-ESM-LR   REMO2015",
    "MPI-ESM-MR   RegCM4-7",
    "NorESM1-M   CCAM-1704",
    "NorESM1-M   RegCM4-7",
    "NorESM1-M   REMO2015",
]

## Get Data File Paths

#### Models

In [None]:
# Every model file of the full ensemble for this index (the helper stores them in a pandas DataFrame)
model_paths = get_model_files(model_data_path, time_average)

# Keep only the simulations named in subset_names
model_paths_subset = get_model_files_subset(model_paths, subset_names)

# Show full path strings (no column truncation) whenever a DataFrame is displayed
pd.set_option('display.max_colwidth', None)

# Display the selected paths to confirm the correct files were picked up
model_paths_subset

## Get Weighted Spatial Average for each SPI Temporal Interval at Each Time Step

#### Obs - AGCD

In [None]:
# Observational (AGCD) SPI time series, keyed 'obs_3mon' / 'obs_6mon' / 'obs_12mon'
obs_ts_dict = {}

for iscale in iscale_list:
    # Weighted spatial average of AGCD SPI at each time step for this accumulation period
    obs_ts = get_weighted_spatial_average_at_default_time_step(
        agcd_data_path, variable, time_slice, lat_slice, lon_slice, season, iscale, qc_mask
    )

    # Accumulation length in months for this iscale (1 -> 3, 2 -> 6, 3 -> 12);
    # any other iscale value is computed but not stored
    months = {1: 3, 2: 6, 3: 12}.get(iscale)
    if months is not None:
        obs_ts_dict[f'obs_{months}mon'] = obs_ts

# Option to print dictionary keys
#print(obs_ts_dict.keys())

#### Models

In [None]:
# Model SPI time series for each accumulation period, keyed by simulation name
model_3mon_ts_dict = {}
model_6mon_ts_dict = {}
model_12mon_ts_dict = {}

# Destination dictionary for each iscale value (1 -> 3-month, 2 -> 6-month, 3 -> 12-month SPI).
# Any other iscale value is still computed but not stored.
model_dict_by_iscale = {1: model_3mon_ts_dict, 2: model_6mon_ts_dict, 3: model_12mon_ts_dict}

for iscale in iscale_list:
    target_dict = model_dict_by_iscale.get(iscale)

    # Loop through RCM simulations. Column 0 of each row is the simulation name (used as the
    # dictionary key) and column 1 the file path. Read them from `row` itself: iterrows()
    # yields index *labels*, which are not valid `.iloc` positions once the subset DataFrame
    # keeps a filtered (non-0..n-1) index.
    for _, row in model_paths_subset.iterrows():
        model_name, model_path = row.iloc[0], row.iloc[1]

        # Weighted spatial average of the model SPI at each time step
        model_ts = get_weighted_spatial_average_at_default_time_step(
            model_path, variable, time_slice, lat_slice, lon_slice, season, iscale, qc_mask
        )

        if target_dict is not None:
            target_dict[model_name] = model_ts

# Option to print dictionary keys for one dictionary
print(model_12mon_ts_dict.keys())

## Calculate Percentage of Each Category of Drought across the time series (McKee et al. 1993)

### Define specific thresholds for drought categories

In [None]:
# Drought categories evaluated in this notebook (WMO 2012 naming).
# McKee et al. (1993) additionally define 'mild' (0 to -0.99), which is not used here.
drought_cats = ['moderate', 'severe', 'extreme']

# SPI bounds of each category, as tabulated to two decimal places
moderate_max, moderate_min = -1.00, -1.49
severe_max, severe_min = -1.50, -1.99

# 'extreme' has no lower bound: it covers every value at or below -2.00
extreme_max = -2.00

### Obs

In [None]:
# Percentage of time steps in each drought category for every observational SPI series.
# The categories are contiguous half-open intervals, so no value falls between two of them.
# The unrounded spatial-mean SPI can take values such as -1.495, which lie between the
# two-decimal table bounds -1.50 and -1.49:
#   moderate:  severe_max  < SPI <= moderate_max    i.e. (-1.50, -1.00]
#   severe:    extreme_max < SPI <= severe_max      i.e. (-2.00, -1.50]
#   extreme:                 SPI <= extreme_max     i.e. (-inf,  -2.00]
obs_rows = []
for obs_ts, obs_spi in obs_ts_dict.items():
    obs_rows.append({
        'dataset_name': f'{obs_ts}',
        'moderate': (((obs_spi > severe_max) & (obs_spi <= moderate_max)).mean() * 100).item(),
        'severe': (((obs_spi > extreme_max) & (obs_spi <= severe_max)).mean() * 100).item(),
        'extreme': ((obs_spi <= extreme_max).mean() * 100).item(),
    })

# Build the table in one step. Repeatedly concatenating onto an empty frame would leave
# object-dtype columns and is deprecated in recent pandas.
obs_master_drought_df = pd.DataFrame(obs_rows, columns=['dataset_name', 'moderate', 'severe', 'extreme'])

# Display the observational summary table
obs_master_drought_df

### Models
Calculate the percentage of each model time series that falls in each drought category, separately for each SPI averaging period.

### 3-Month SPI

In [None]:
# 3-Month SPI

# Percentage of time steps in each drought category for every model series.
# The categories are contiguous half-open intervals, so unrounded values such as -1.495
# (between the table bounds -1.50 and -1.49) are not dropped:
#   moderate:  severe_max  < SPI <= moderate_max    i.e. (-1.50, -1.00]
#   severe:    extreme_max < SPI <= severe_max      i.e. (-2.00, -1.50]
#   extreme:                 SPI <= extreme_max     i.e. (-inf,  -2.00]
model_3mon_rows = []
for model_name, model_spi in model_3mon_ts_dict.items():
    model_3mon_rows.append({
        'dataset_name': f'{model_name}',
        'moderate': (((model_spi > severe_max) & (model_spi <= moderate_max)).mean() * 100).item(),
        'severe': (((model_spi > extreme_max) & (model_spi <= severe_max)).mean() * 100).item(),
        'extreme': ((model_spi <= extreme_max).mean() * 100).item(),
    })

# Build the table in one step. Repeated concat onto an empty frame would give
# object-dtype columns.
model_3mon_master_drought_df = pd.DataFrame(model_3mon_rows, columns=['dataset_name', 'moderate', 'severe', 'extreme'])

# Display the 3-month model summary table
model_3mon_master_drought_df

### 6-Month SPI

In [None]:
# 6-Month SPI

# Percentage of time steps in each drought category for every model series.
# The categories are contiguous half-open intervals, so unrounded values such as -1.495
# (between the table bounds -1.50 and -1.49) are not dropped:
#   moderate:  severe_max  < SPI <= moderate_max    i.e. (-1.50, -1.00]
#   severe:    extreme_max < SPI <= severe_max      i.e. (-2.00, -1.50]
#   extreme:                 SPI <= extreme_max     i.e. (-inf,  -2.00]
model_6mon_rows = []
for model_name, model_spi in model_6mon_ts_dict.items():
    model_6mon_rows.append({
        'dataset_name': f'{model_name}',
        'moderate': (((model_spi > severe_max) & (model_spi <= moderate_max)).mean() * 100).item(),
        'severe': (((model_spi > extreme_max) & (model_spi <= severe_max)).mean() * 100).item(),
        'extreme': ((model_spi <= extreme_max).mean() * 100).item(),
    })

# Build the table in one step. Repeated concat onto an empty frame would give
# object-dtype columns.
model_6mon_master_drought_df = pd.DataFrame(model_6mon_rows, columns=['dataset_name', 'moderate', 'severe', 'extreme'])

# Display the 6-month model summary table
model_6mon_master_drought_df

### 12-Month SPI

In [None]:
# 12-Month SPI

# Percentage of time steps in each drought category for every model series.
# The categories are contiguous half-open intervals, so unrounded values such as -1.495
# (between the table bounds -1.50 and -1.49) are not dropped:
#   moderate:  severe_max  < SPI <= moderate_max    i.e. (-1.50, -1.00]
#   severe:    extreme_max < SPI <= severe_max      i.e. (-2.00, -1.50]
#   extreme:                 SPI <= extreme_max     i.e. (-inf,  -2.00]
model_12mon_rows = []
for model_name, model_spi in model_12mon_ts_dict.items():
    model_12mon_rows.append({
        'dataset_name': f'{model_name}',
        'moderate': (((model_spi > severe_max) & (model_spi <= moderate_max)).mean() * 100).item(),
        'severe': (((model_spi > extreme_max) & (model_spi <= severe_max)).mean() * 100).item(),
        'extreme': ((model_spi <= extreme_max).mean() * 100).item(),
    })

# Build the table in one step. Repeated concat onto an empty frame would give
# object-dtype columns.
model_12mon_master_drought_df = pd.DataFrame(model_12mon_rows, columns=['dataset_name', 'moderate', 'severe', 'extreme'])

# Display the 12-month model summary table
model_12mon_master_drought_df

## Plot Time series and fill areas that fall within threshold

In [None]:
# Names of every plotted dataset (AGCD observations first, then the model subset),
# in two orderings used to lay out the figure panels.

# Ordered alphabetically by driving GCM
dataset_names = [
    "AGCD",
    "ACCESS1-0   CCAM-1704",
    "ACCESS1-0   CCAM-2008",
    "ACCESS1-0   WRF360J",
    "ACCESS1-0   WRF360K",
    "CanESM2   CCAM-2008",
    "CanESM2   WRF360J",
    "CNRM-CM5   CCAM-1704",
    "GFDL-ESM2M   CCAM-1704",
    "GFDL-ESM2M   CCAM-2008",
    "HadGEM2-CC   CCAM-1704",
    "HadGEM2-ES   CCLM5-0-15",
    "HadGEM2-ES   REMO2015",
    "MIROC5   CCAM-1704",
    "MIROC5   CCAM-2008",
    "MPI-ESM-LR   CCLM5-0-15",
    "MPI-ESM-LR   REMO2015",
    "MPI-ESM-MR   RegCM4-7",
    "NorESM1-M   CCAM-1704",
    "NorESM1-M   RegCM4-7",
    "NorESM1-M   REMO2015",
]

# Ordered alphabetically by RCM, matching the organization of the tables in the paper
dataset_names_rcm_sorted = [
    "AGCD",
    "ACCESS1-0   CCAM-1704",
    "CNRM-CM5   CCAM-1704",
    "GFDL-ESM2M   CCAM-1704",
    "HadGEM2-CC   CCAM-1704",
    "MIROC5   CCAM-1704",
    "NorESM1-M   CCAM-1704",
    "ACCESS1-0   CCAM-2008",
    "CanESM2   CCAM-2008",
    "GFDL-ESM2M   CCAM-2008",
    "MIROC5   CCAM-2008",
    "HadGEM2-ES   CCLM5-0-15",
    "MPI-ESM-LR   CCLM5-0-15",
    "MPI-ESM-MR   RegCM4-7",
    "NorESM1-M   RegCM4-7",
    "HadGEM2-ES   REMO2015",
    "MPI-ESM-LR   REMO2015",
    "NorESM1-M   REMO2015",
    "ACCESS1-0   WRF360J",
    "CanESM2   WRF360J",
    "ACCESS1-0   WRF360K",
]

In [None]:
# Subplot indices (row-major order in a 7-row x 3-column grid of 21 panels) whose axes
# get labels. Update these if the number of datasets or the grid shape changes.
y_label = list(range(0, 18, 3))  # left column, excluding the bottom-left panel
x_label = [19, 20]               # bottom row, excluding the bottom-left panel
both_label = [18]                # bottom-left panel: both axes labelled

### Define helper to find contiguous periods to shade for each drought category

In [None]:
# Define function to find the contiguous periods that need to be shaded, from a boolean mask
def fill_vertical_columns(boolean_fill_mask):
    """Return the start/end indices of every contiguous run of True values in a 1-D mask.

    Parameters
    ----------
    boolean_fill_mask : array-like of bool
        1-D mask that is True at the time steps to shade.

    Returns
    -------
    numpy.ndarray
        Integer array of shape (n_runs, 2). Each row is a half-open pair ``(start, end)``
        such that ``boolean_fill_mask[start:end]`` is a maximal run of True values, i.e.
        ``start`` is the first True index of the run and ``end - 1`` the last.
        An empty or all-False mask gives an array of shape (0, 2).
    """
    import numpy as np

    mask = np.asarray(boolean_fill_mask, dtype=bool)

    # Nothing to shade (also avoids indexing mask[0] on an empty array)
    if mask.size == 0:
        return np.empty((0, 2), dtype=np.intp)

    # For boolean input np.diff is True wherever mask[i] != mask[i + 1]; the run boundary
    # is therefore the index *after* the switch, hence the + 1
    region_to_shade = np.flatnonzero(np.diff(mask)) + 1

    # A run touching either end of the series has no switch on that side,
    # so add the array edges explicitly
    if mask[0]:
        region_to_shade = np.r_[0, region_to_shade]
    if mask[-1]:
        region_to_shade = np.r_[region_to_shade, mask.size]

    # Boundaries now alternate start, end, start, end, ... -> pair them up
    return region_to_shade.reshape(-1, 2)

### 3-Month SPI

In [None]:
# 3-month SPI: one panel per dataset (AGCD first, then each RCM simulation)
fig = plt.figure(figsize=(17, 22))
fig.suptitle(' 3-Month SPI (1976-2005)', fontsize=16, y=0.91)

# 21 datasets need 21 cells: use a 7 x 3 grid, which is also the row-major layout
# assumed by the panel indices in y_label / x_label / both_label
n_cols = 3
gs = gridspec.GridSpec(7, n_cols)

# Create the panels in row-major order, keyed by dataset name
axs = {}
for panel, dataset_name in enumerate(dataset_names):
    row, col = divmod(panel, n_cols)
    axs[dataset_name] = fig.add_subplot(gs[row, col])

# Panel titles and a common SPI range
for name, ax in axs.items():
    ax.set_title(name, fontsize=13)
    ax.set_ylim(-2.5, 2.5)

# Observations (AGCD)
obs_spi = obs_ts_dict['obs_3mon']
axs['AGCD'].plot(obs_spi.time.values, obs_spi.values.round(2), color='black', linestyle='-')

# Models: each is plotted against its own time axis (converted to datetime64[ns])
# instead of assuming it matches the AGCD time axis
for model_name, model_spi in model_3mon_ts_dict.items():
    model_time = model_spi['time'].values.astype("datetime64[ns]")
    axs[model_name].plot(model_time, model_spi.values.round(2), color='black', linestyle='-')

# Fixed SPI tick positions: tick labels only line up correctly when the ticks are fixed
spi_ticks = [-2.0, -1.0, 0, 1.0, 2.0]
spi_tick_labels = ['-2.0', '-1.0', '0', '1.0', '2.0']

# Show tick labels only on the left column and bottom row
for i, ax in enumerate(fig.axes):
    show_y = i in y_label or i in both_label
    show_x = i in x_label or i in both_label

    ax.set_yticks(spi_ticks)
    if show_y:
        ax.set_yticklabels(spi_tick_labels)
        ax.set_ylabel("SPI", fontsize=12)
    else:
        ax.tick_params(axis='y', labelleft=False)

    # Locators/formatters must not be shared between axes, so create new ones per axis
    ax.xaxis.set_major_locator(mdates.YearLocator(5))  # a tick every 5 years
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    if show_x:
        ax.set_xlabel("year", fontsize=11)
    else:
        ax.tick_params(axis='x', labelbottom=False)

### 6-Month SPI

In [None]:
# 6-month SPI: one panel per dataset (AGCD first, then each RCM simulation)
fig = plt.figure(figsize=(17, 22))
fig.suptitle(' 6-Month SPI (1976-2005)', fontsize=16, y=0.91)

# 21 datasets need 21 cells: use a 7 x 3 grid, which is also the row-major layout
# assumed by the panel indices in y_label / x_label / both_label
n_cols = 3
gs = gridspec.GridSpec(7, n_cols)

# Create the panels in row-major order, keyed by dataset name
axs = {}
for panel, dataset_name in enumerate(dataset_names):
    row, col = divmod(panel, n_cols)
    axs[dataset_name] = fig.add_subplot(gs[row, col])

# Panel titles and a common SPI range
for name, ax in axs.items():
    ax.set_title(name, fontsize=13)
    ax.set_ylim(-2.5, 2.5)

# Observations (AGCD)
obs_spi = obs_ts_dict['obs_6mon']
axs['AGCD'].plot(obs_spi.time.values, obs_spi.values.round(2), color='black', linestyle='-')

# Models: each is plotted against its own time axis (converted to datetime64[ns])
# instead of assuming it matches the AGCD time axis
for model_name, model_spi in model_6mon_ts_dict.items():
    model_time = model_spi['time'].values.astype("datetime64[ns]")
    axs[model_name].plot(model_time, model_spi.values.round(2), color='black', linestyle='-')

# Fixed SPI tick positions: tick labels only line up correctly when the ticks are fixed
spi_ticks = [-2.0, -1.0, 0, 1.0, 2.0]
spi_tick_labels = ['-2.0', '-1.0', '0', '1.0', '2.0']

# Show tick labels only on the left column and bottom row
for i, ax in enumerate(fig.axes):
    show_y = i in y_label or i in both_label
    show_x = i in x_label or i in both_label

    ax.set_yticks(spi_ticks)
    if show_y:
        ax.set_yticklabels(spi_tick_labels)
        ax.set_ylabel("SPI", fontsize=12)
    else:
        ax.tick_params(axis='y', labelleft=False)

    # Locators/formatters must not be shared between axes, so create new ones per axis
    ax.xaxis.set_major_locator(mdates.YearLocator(5))  # a tick every 5 years
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    if show_x:
        ax.set_xlabel("year", fontsize=11)
    else:
        ax.tick_params(axis='x', labelbottom=False)

### 12-Month SPI

In [None]:
# 12-month SPI: one panel per dataset (AGCD first, then RCMs grouped by RCM), with
# drought periods shaded by category
fig = plt.figure(figsize=(17, 22))
fig.suptitle(' 12-Month SPI (1976-2005)', fontsize=16, y=0.91)

# 7 x 3 grid = 21 panels; y_label / x_label / both_label assume this row-major layout
n_cols = 3
gs = gridspec.GridSpec(7, n_cols, width_ratios=[1.5, 1.5, 1.5], height_ratios=[0.8] * 7)

# Create the panels in row-major order, keyed by dataset name
axs = {}
for panel, dataset_name in enumerate(dataset_names_rcm_sorted):
    row, col = divmod(panel, n_cols)
    axs[dataset_name] = fig.add_subplot(gs[row, col])

# Panel titles and a common SPI range
for name, ax in axs.items():
    ax.set_title(name, fontsize=14)
    ax.set_ylim(-2.5, 2.5)


def _plot_spi_with_drought_shading(ax, time, data):
    """Plot an SPI series on *ax* with a brown zero line and shading for each drought category.

    *time* holds the x values and *data* the (two-decimal rounded) SPI values of equal length.
    The categories are contiguous intervals, so every rounded value from -1.00 downward is
    shaded by exactly one category.
    """
    ax.plot(time, data, color='black', linestyle='-')
    ax.axhline(y=0, color='sienna', linestyle='-')

    category_masks = [
        (data <= moderate_max) & (data > severe_max),   # moderate: -1.00 .. -1.49
        (data <= severe_max) & (data > extreme_max),    # severe:   -1.50 .. -1.99
        data <= extreme_max,                            # extreme:  -2.00 and below
    ]
    for ts_mask, color, label in zip(category_masks,
                                     ['orange', 'crimson', 'darkred'],
                                     ['Moderate', 'Severe', 'Extreme']):
        # One vertical band per contiguous run of time steps in this category
        for start, end in fill_vertical_columns(ts_mask):
            # Defensive: skip empty/degenerate spans, where time[end - 1] would wrap or invert
            if end <= start:
                continue
            ax.axvspan(time[start], time[end - 1], color=color, alpha=0.2, label=label)


# Observations (AGCD)
obs_spi = obs_ts_dict['obs_12mon']
_plot_spi_with_drought_shading(axs['AGCD'], obs_spi.time.values, obs_spi.values.round(2))

# Models: every simulation uses its own time axis converted to datetime64[ns]
for model_name, model_spi in model_12mon_ts_dict.items():
    _plot_spi_with_drought_shading(
        axs[model_name],
        model_spi['time'].values.astype("datetime64[ns]"),
        model_spi.values.round(2),
    )

# Fixed SPI tick positions: tick labels only line up correctly when the ticks are fixed
spi_ticks = [-2.0, -1.0, 0, 1.0, 2.0]
spi_tick_labels = ['-2.0', '-1.0', '0', '1.0', '2.0']

# SPI tick labels on every panel; axis titles and year labels only on the left/bottom edges
for i, ax in enumerate(fig.axes):
    ax.set_yticks(spi_ticks)
    ax.set_yticklabels(spi_tick_labels)
    if i in y_label or i in both_label:
        ax.set_ylabel("SPI", fontsize=12)

    # Locators/formatters must not be shared between axes, so create new ones per axis
    ax.xaxis.set_major_locator(mdates.YearLocator(5))  # a tick every 5 years
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    if i in x_label or i in both_label:
        ax.set_xlabel("year", fontsize=11)
    else:
        ax.tick_params(axis='x', labelbottom=False)

# Colour key for the shaded drought categories
fig.text(0.36, 0.08, "Drought Categories: ", rotation='horizontal', fontsize=12, color='black', weight='bold')
fig.text(0.48, 0.08, "Moderate", rotation='horizontal', fontsize=12, color='darkorange', weight='bold')
fig.text(0.55, 0.08, "Severe", rotation='horizontal', fontsize=12, color='crimson', weight='bold')
fig.text(0.61, 0.08, "Extreme", rotation='horizontal', fontsize=12, color="darkred", weight='bold')