In [1]:
import intake
import requests
import aiohttp
import s3fs
import geopandas as gpd
from shapely.geometry import mapping
import xarray as xr
import numpy as np
import dask.array as da
import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs
from cartopy.feature import NaturalEarthFeature
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader

# Map changes in length of temperature/precipitation extremes in LOCA downscaled projections
## Jan 2025

In [2]:
# set parameters
data_preload = 0  # saves data to local zarr store to make access faster

data_savepath = '/glade/derecho/scratch/samantha/CAFifthAssessment/zarr_data/'

calc_prctile = 0  # do we need to calculate the percentiles associated with heat waves? 0 = no, 1 = yes

In [3]:
# Open catalog of climate data compiled for the Fifth Assessment Report
cat = intake.open_esm_datastore(
    'https://cadcat.s3.amazonaws.com/cae-collection.json'
)

# List set of models
models = cat.df.source_id.unique()

# note: HadGEM3-GC31-LL only has precipitation data for SSP245 and SSP585
# Filter the list to exclude this model
models = [s for s in models if (s != 'HadGEM3-GC31-LL' and s != 'ensmean' and s != 'ensmean' and s != 'ERA5')]
print(models)

['ACCESS-CM2', 'CESM2-LENS', 'CNRM-ESM2-1', 'EC-Earth3', 'EC-Earth3-Veg', 'FGOALS-g3', 'GFDL-ESM4', 'INM-CM5-0', 'IPSL-CM6A-LR', 'KACE-1-0-G', 'MIROC6', 'MPI-ESM1-2-HR', 'MRI-ESM2-0', 'TaiESM1', 'CESM2']


In [4]:
# Pull out all records associated with daily data
cat_access_hist = cat.search(
    activity_id = "LOCA2",
    experiment_id = "historical",
    table_id = "day"
)

# See which variables have daily information available
variable_ids = cat_access_hist.df['variable_id'].unique()
print(variable_ids)

['hursmax' 'hursmin' 'huss' 'pr' 'rsds' 'tasmax' 'tasmin' 'uas' 'vas'
 'wspeed']


In [5]:
# Loop over models, calculate exceedance lengths during the time periods specified in author guidance
# "historical" = 1981-2010
# "mid-century" = 2041-2070
# "late-century" = 2071-2100

data_hist = []
data_mid = []
data_late = []

for m in range(len(models)-1):
    thismod = models[m]
    print(thismod)

    # Output arrays
    histdata_thismod = []
    middata_thismod = []
    enddata_thismod = []

    # Extract data paths for this model: historical
    cat_access_hist = cat.search(
        activity_id = "LOCA2",
        source_id = thismod,
        experiment_id = "historical",
        variable_id = "tasmax",
        table_id = "day"
    )
    # Extract data paths for this model: SSP3-7.0
    cat_access_ssp = cat.search(
        activity_id = "LOCA2",
        source_id = thismod,
        experiment_id = "ssp370",
        variable_id = "tasmax",
        table_id = "day"
    )

    # Loop over data files, extract historical information    
    for f in range(len(cat_access_hist.df['path'])):
        print(cat_access_hist.df['path'][f])

        # Preload data if specified
        savepath_hist = data_savepath + cat_access_hist.df['source_id'][f] + '_' + cat_access_hist.df['member_id'][f] + '_' + cat_access_hist.df['table_id'][f] + '_' + cat_access_hist.df['variable_id'][f] + '_hist.zarr'  

        if data_preload == 1:
            ds = xr.open_zarr(cat_access_hist.df['path'][f], storage_options={'anon': True})

            # Historical period
            hist_data = ds.sel(time=slice('1981-01-01', '2010-12-31'))
            # Store in zarr for later use
            print("Saving file: " + savepath_hist)
            hist_data.to_zarr(savepath_hist, mode='w', safe_chunks=False)

        hist_data = xr.open_zarr(savepath_hist)

        # Append to list
        histdata_thismod.append(hist_data)

    # Loop over data files, extract SSP370 information
    for f in range(len(cat_access_ssp.df['path'])):
        print(cat_access_ssp.df['path'][f])

        # Preload data if specified
        savepath_mid = data_savepath + cat_access_ssp.df['source_id'][f] + '_' + cat_access_ssp.df['member_id'][f] + '_' + cat_access_ssp.df['table_id'][f] + '_' + cat_access_ssp.df['variable_id'][f] + '_mid.zarr'  
        savepath_end = data_savepath + cat_access_ssp.df['source_id'][f] + '_' + cat_access_ssp.df['member_id'][f] + '_' + cat_access_ssp.df['table_id'][f] + '_' + cat_access_ssp.df['variable_id'][f] + '_end.zarr'  

        if data_preload == 1:
            ds = xr.open_zarr(cat_access_ssp.df['path'][f], storage_options={'anon': True})   

            # Mid-century
            mid_data = ds.sel(time=slice('2041-01-01', '2070-12-31'))
            # Store in zarr for later use
            print("Saving file: " + savepath_mid)
            mid_data.to_zarr(savepath_mid, mode='w', safe_chunks=False)

            # End-century
            end_data = ds.sel(time=slice('2071-01-01', '2100-12-31'))
            # Store in zarr for later use
            print("Saving file: " + savepath_end)
            end_data.to_zarr(savepath_end, mode='w', safe_chunks=False)

        mid_data = xr.open_zarr(savepath_mid)
        end_data = xr.open_zarr(savepath_end)

        # Append to list
        middata_thismod.append(mid_data)
        enddata_thismod.append(end_data)

    # Make an xarray object with data from each ensemble member
    histdata_thismod = xr.concat(histdata_thismod, dim="member")
    middata_thismod = xr.concat(middata_thismod, dim="member")
    enddata_thismod = xr.concat(enddata_thismod, dim="member")

    # Append that xarray object to the total list
    data_hist.append(histdata_thismod)
    data_mid.append(middata_thismod)
    data_late.append(enddata_thismod)

ACCESS-CM2
s3://cadcat/loca2/ucsd/access-cm2/historical/r1i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/access-cm2/historical/r2i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/access-cm2/historical/r3i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/access-cm2/ssp370/r1i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/access-cm2/ssp370/r2i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/access-cm2/ssp370/r3i1p1f1/day/tasmax/d03/
CESM2-LENS
s3://cadcat/loca2/ucsd/cesm2-lens/historical/r10i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/cesm2-lens/historical/r1i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/cesm2-lens/historical/r2i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/cesm2-lens/historical/r3i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/cesm2-lens/historical/r4i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/cesm2-lens/historical/r5i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/cesm2-lens/historical/r6i1p1f1/day/tasmax/d03/
s3://cadcat/loca2/ucsd/cesm2-lens/historical/r7i1p1f1/day/tasmax/d03/
s3://cadc

In [6]:
# Concatenate into xarray object
# data_hist = xr.concat(data_hist, dim="model")
# data_mid = xr.concat(data_mid, dim="model")
# data_late = xr.concat(data_late, dim="model")

In [8]:
# Calculate the temperature threshold needed to define heat waves

# Threshold definition used in Fourth Assessment: 98th percentile of observed,
# historical (1961-1990) daily maximum temperatures between April 1 and Oct 31.
# To make things simpler, I'll use the historical period defined in the author
# guidance: 1981-2010
# Also not sure if they actually used observations to derive the thresholds in
# the previous report, will have to check with other authors

if calc_prctile == 1:
    hist_thr = []
    for m in range(len(models)-1):
        print(models[m])
        histdata_thismod = data_hist[m]

        # Select only April-October
        histdata_aproct = histdata_thismod.sel(time=histdata_thismod.time.dt.month.isin([4, 5, 6, 7, 8, 9, 10]))

        # Concatenate all data arrays along the time dimension
        print("Concatenating data along the 'time' dimension for this model")
        all_data_hist = histdata_aproct.stack(time_member=("time", "member"))

        # Calculate the 98th percentile along the 'time_member' dimension (which is a combination of 'time' and 'member')
        print("Calculating the 98th percentile along the 'time_member' dimension")
        combined_dataset = all_data_hist.chunk({'time_member': -1})
        percentile_98th = combined_dataset.quantile(0.98, dim='time_member')
        hist_thr.append(percentile_98th)

    # Make xarray object with heat wave threshold temperatures
    hist_thr = xr.concat(hist_thr, dim="model")

    # Save xarray object for later use
    savepath_hw = data_savepath + 'allmodels_98thprctile_tasmax.zarr'  
    print("Saving file: " + savepath_hw)
    hist_thr.to_zarr(savepath_hw, mode='w', safe_chunks=False)

else:
    savepath_hw = data_savepath + 'allmodels_98thprctile_tasmax.zarr'  
    hist_thr = xr.open_zarr(savepath_hw)

In [None]:
# Sanity check: plot 98th percentile thresholds for a single model
#m=0
#print(models[m])
#histthr_thismod = hist_thr.isel(model=m)
#
#proj = ccrs.PlateCarree(central_longitude=180)
#
#fig = plt.figure(figsize=(15, 5))
#ax = plt.axes(projection=proj)
#ax.set_extent([-124, -116, 32, 37.5], crs=ccrs.PlateCarree())
#ax.gridlines()
#im=plt.pcolor(histthr_thismod.lon,histthr_thismod.lat,histthr_thismod.tasmax,
#              transform=ccrs.PlateCarree(),cmap='RdBu_r')#
#plt.title('98th percentile of temperature')
#coast_10m = cfeature.NaturalEarthFeature('physical', 'land', '10m',edgecolor='k', facecolor='none')
#ax.add_feature(coast_10m)
#ax.add_feature(cfeature.STATES, linewidth=0.5)
#ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,linewidth=2, color='gray', alpha=0.5, linestyle='--')
#plt.colorbar(im)

In [72]:
# Define threshold exceedance function
def exceed_length(thresh, time_series):
   # Create a boolean mask where tasmax > threshold
    mask = time_series > thresh
    # Convert to numerical values
    mask = mask.astype(int)

    # Identify the periods of consecutive True values
    diff_ts = mask.diff("time", 1).compute()

    exceedance_start = diff_ts.where(diff_ts == 1, drop=True).time
    exceedance_end = diff_ts.where(diff_ts == -1, drop=True).time

    # If the time series starts with a True (exceedance), we add it to the start
    if mask.values[0]:
        exceedance_start = xr.concat(mask.time[0], exceedance_start, dim="time")

    # If the time series ends with a True (exceedance), we add it to the end
    if mask.values[-1]:
        exceedance_end = xr.concat(exceedance_end, mask.time[-1], dim="time")

    meanlen = []
    # Now calculate the length of each exceedance period
    for start, end in zip(exceedance_start, exceedance_end):
        exceedance_length = end - start + 1
        # Store the exceedance length
        meanlen.append(exceedance_length)

    # Average over all events
    if len(meanlen) > 0:
        meanlen = xr.concat(meanlen, dim="event").mean(dim="event") / (86400*1e9)
    else:
        meanlen = mask[0]
        meanlen.values = np.nan

    # Convert from nanoseconds to days and return average value
    #return meanlen.values / (86400 * 1e9)
    return meanlen.astype('timedelta64[D]').item()

In [43]:
thresh=histthr_thismod_repeated[0,:,200,450]

time_series = histdata_thismod.isel(member=mem)[:,200,450]

test = exceed_length(thresh, time_series)

  return meanlen.astype('timedelta64[D]').item() / (86400*1e9)


In [None]:
# Now go through each model and each ensemble, determine the length of
# heat waves during each time period
import warnings
warnings.filterwarnings("ignore")

hwlen_arr = []

#for m in range(len(models)):
for m in range(0, 1):
    print(models[m])
    histdata_thismod = data_hist[m].tasmax
    histthr_thismod = hist_thr.isel(model=m).tasmax

    # Chunk the time dimension appropriately 
    histdata_thismod = histdata_thismod.chunk({'time': 100})  # Example chunk size

    hwlen_thismod = []
    for mem in range(0, 1):
        # Get the data for the current member
        member_data = histdata_thismod.isel(member=mem)

        # Expand histthr_thismod along the time dimension to match member_data
        histthr_thismod_expanded = histthr_thismod.expand_dims(time=member_data.time)
        print('data expanded')

        # Repeat histthr_thismod along the time dimension to match histdata_thismod
        histthr_thismod_repeated = histthr_thismod_expanded.broadcast_like(member_data)
        print('data repeated')

        # Chunk the time dimension appropriately 
        histthr_thismod_repeated = histthr_thismod_repeated.chunk({'time': 100})  # Same chunking for threshold

        # Apply the function pointwise over the grid
        heatwave_lens = []
        for lat_idx in range(histdata_thismod.shape[1]):
            for lon_idx in range(histdata_thismod.shape[2]):
                threshold = histthr_thismod_repeated.isel(lat=lat_idx, lon=lon_idx)
                data = member_data.isel(lat=lat_idx, lon=lon_idx)

                # Compute the heatwave length for this grid point
                heatwave_len = exceed_length(threshold, data)
                heatwave_lens.append(heatwave_len)

        # Store the results as an xarray.DataArray
        heatwave_lens_array = np.array(heatwave_lens).reshape(histdata_thismod.shape[1:])
        hwlen_thismod.append(heatwave_lens_array)

    ## Use vectorized function to determine thresholds
    # hwlen_thismod = []
    # for mem in range(0, 1):
    #    heatwave_lens = xr.apply_ufunc(exceed_length, histthr_thismod_repeated, histdata_thismod.isel(member=mem), vectorize=True, input_core_dims=[['time'], ['time']], dask="allowed")
    #    hwlen_thismod.append(heatwave_lens)

    # Concatenate the results across the member dimension and compute the mean
    heatwave_lens = xr.concat(hwlen_thismod, dim="member").mean(dim="member")
    hwlen_arr.append(heatwave_lens)

ACCESS-CM2
data expanded
data repeated


In [None]:
member_data

In [None]:
# Make xarray object, take the mean across models
hwlen_arr = xr.concat(hwlen_arr, dim="model").mean(dim="model")

In [None]:
# Plot results: historical 99th percentile precip
proj = ccrs.PlateCarree(central_longitude=180)

fig = plt.figure(figsize=(15, 5))
ax = plt.axes(projection=proj)
ax.set_extent([-124, -116, 32, 37.5], crs=ccrs.PlateCarree())    
ax.gridlines()
im=plt.pcolor(hwlen_arr.lon,hwlen_arr.lat,hwlen_arr.tasmax,transform=ccrs.PlateCarree(),cmap='BrBG')#,vmin=-100,vmax=100)#
plt.title('Heat wave length (days)')
coast_10m = cfeature.NaturalEarthFeature('physical', 'land', '10m',edgecolor='k', facecolor='none')
ax.add_feature(coast_10m)
ax.add_feature(cfeature.STATES, linewidth=0.5)
ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,linewidth=2, color='gray', alpha=0.5, linestyle='--')
plt.colorbar(im) 