# Apply MDM and QDM to CMIP6 daily temperatures
- Daily temperature data is loaded using the OSDF protocol and discovered using an intake-ESM catalog
- We apply three different bias-correction techniques to pairs of CMIP6 models, treating one of the models in the pair as pseudo-observations
- The bias-correction techniques are Moment Delta Mapping (MDM), Quantile Delta Mapping (QDM) + sort, and a shift in DMT

In [1]:
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import dask_jobqueue
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm
import cartopy.io.shapereader as shpreader
import cartopy.feature as cfeature
import intake
import fsspec
import xarray_regrid 
#import seaborn as sns
import s3fs
import cftime
import pandas as pd

  from tqdm.autonotebook import tqdm


In [2]:
import dask 
from dask_jobqueue import PBSCluster
from dask.distributed import Client
from dask.distributed import performance_report

In [3]:
# Decide whether to re-calculate everything
RECALC = True
#
pi_year  = 1865
eoc_year = 2085
#
chic_lat  = 41.8781
chic_lon  = (360-87.6298)%360
ben_lat   = 12.9716
ben_lon   = 77.5946
#
lustre_scratch   = "/lustre/desc1/scratch/harshah"
gdex_url     =  'https://data.gdex.ucar.edu/'
catalog_url = gdex_url +  'd850001/catalogs/osdf/cmip6-aws/cmip6-osdf-zarr.json'
# catalog_url = 'https://cmip6-pds.s3.amazonaws.com/pangeo-cmip6.json'
gdex_data    = '/gdex/data/special_projects/harshah/osdf_data/'
#
tmean_path  = gdex_data + 'tmean/'
tmax_path   = gdex_data + 'tmax/'
tmin_path   = gdex_data + 'tmin/'

In [4]:
# Create a PBS cluster object
# Create a PBS cluster object
# Each PBS job runs one single-threaded dask worker (cores=1, processes=1) with 8 GiB of memory,
# submitted to the 'casper' queue with a 5-hour walltime.
cluster = PBSCluster(
    job_name = 'dask-wk25-mdm',
    cores = 1,
    memory = '8GiB',
    processes = 1,
    local_directory = lustre_scratch+'/dask/spill',   # where workers spill to disk
    log_directory = lustre_scratch + '/dask/logs/',
    resource_spec = 'select=1:ncpus=1:mem=8GB',        # raw PBS request; keep in step with cores/memory
    queue = 'casper',
    walltime = '5:00:00',
    interface = 'ext'                                  # network interface the workers bind to
)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 44233 instead


In [5]:
# Create the client to load the Dashboard
# Connect this session to the PBS cluster; the cluster repr shows the dashboard link.
client = Client(cluster)

In [6]:
# Request n_workers single-core workers (one PBS job each) and block until all of them have started
n_workers =8
cluster.scale(n_workers)
client.wait_for_workers(n_workers = n_workers)
cluster

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/44233/status,Workers: 8
Total threads: 8,Total memory: 64.00 GiB

0,1
Comm: tcp://128.117.208.96:46363,Workers: 8
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/44233/status,Total threads: 8
Started: Just now,Total memory: 64.00 GiB

0,1
Comm: tcp://128.117.208.177:45503,Total threads: 1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/34123/status,Memory: 8.00 GiB
Nanny: tcp://128.117.208.177:44513,
Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-8hgrjkd3,Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-8hgrjkd3
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 2.0%,Last seen: Just now
Memory usage: 126.80 MiB,Spilled bytes: 0 B
Read bytes: 8.88 kiB,Write bytes: 16.75 kiB

0,1
Comm: tcp://128.117.208.176:35901,Total threads: 1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/45381/status,Memory: 8.00 GiB
Nanny: tcp://128.117.208.176:36873,
Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-gr8dznrz,Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-gr8dznrz
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 6.0%,Last seen: Just now
Memory usage: 124.84 MiB,Spilled bytes: 0 B
Read bytes: 182.35 kiB,Write bytes: 33.98 kiB

0,1
Comm: tcp://128.117.208.176:36827,Total threads: 1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/44089/status,Memory: 8.00 GiB
Nanny: tcp://128.117.208.176:38703,
Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-ytz13nm3,Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-ytz13nm3
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 4.0%,Last seen: Just now
Memory usage: 126.93 MiB,Spilled bytes: 0 B
Read bytes: 24.42 kiB,Write bytes: 29.61 kiB

0,1
Comm: tcp://128.117.208.177:39887,Total threads: 1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/37797/status,Memory: 8.00 GiB
Nanny: tcp://128.117.208.177:40663,
Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-9zialadk,Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-9zialadk
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 4.0%,Last seen: Just now
Memory usage: 122.89 MiB,Spilled bytes: 0 B
Read bytes: 182.55 kiB,Write bytes: 0.92 MiB

0,1
Comm: tcp://128.117.208.178:34743,Total threads: 1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/46503/status,Memory: 8.00 GiB
Nanny: tcp://128.117.208.178:33491,
Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-ionxasxh,Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-ionxasxh
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 0.0%,Last seen: Just now
Memory usage: 51.62 MiB,Spilled bytes: 0 B
Read bytes: 5.55 MiB,Write bytes: 2.91 MiB

0,1
Comm: tcp://128.117.208.177:35667,Total threads: 1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/46679/status,Memory: 8.00 GiB
Nanny: tcp://128.117.208.177:38027,
Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-gtqltgk1,Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-gtqltgk1
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 2.0%,Last seen: Just now
Memory usage: 124.93 MiB,Spilled bytes: 0 B
Read bytes: 9.18 kiB,Write bytes: 16.31 kiB

0,1
Comm: tcp://128.117.208.176:37345,Total threads: 1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/42863/status,Memory: 8.00 GiB
Nanny: tcp://128.117.208.176:44025,
Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-kb88d1h0,Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-kb88d1h0
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 4.0%,Last seen: Just now
Memory usage: 128.91 MiB,Spilled bytes: 0 B
Read bytes: 175.36 kiB,Write bytes: 28.19 kiB

0,1
Comm: tcp://128.117.208.177:38735,Total threads: 1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/41273/status,Memory: 8.00 GiB
Nanny: tcp://128.117.208.177:35781,
Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-48zunia2,Local directory: /lustre/desc1/scratch/harshah/dask/spill/dask-scratch-space/worker-48zunia2
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 4.0%,Last seen: Just now
Memory usage: 126.88 MiB,Spilled bytes: 0 B
Read bytes: 181.90 kiB,Write bytes: 0.99 MiB


In [7]:
# calculate global means
def get_lat_name(ds):
    """Return the name of the latitude coordinate of `ds`.

    'lat' is preferred over 'latitude' when both exist. Raises RuntimeError
    when neither name is among the coordinates.
    """
    candidates = ('lat', 'latitude')
    found = next((name for name in candidates if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found

def global_mean(ds):
    """Latitude-weighted mean of `ds` over every dimension except 'quantile'.

    The weights are cos(latitude), rescaled so their mean is 1. A plain `.mean()`
    of the weighted data is then an area-weighted average on a regular lat-lon grid.
    """
    lat_values = ds[get_lat_name(ds)]
    cos_lat = np.cos(np.deg2rad(lat_values))
    cos_lat = cos_lat / cos_lat.mean()
    reduce_dims = set(ds.dims)
    reduce_dims.discard('quantile')
    return (ds * cos_lat).mean(reduce_dims)

def detrend_data(ds, central_year):
    """Remove a linear trend along 'year' from `ds`, anchored at `central_year`.

    A degree-1 polynomial is fitted along 'year' at every point, and
    ``slope * (year - central_year)`` is subtracted. Values at `central_year` are
    therefore unchanged.

    Expects an xarray.DataArray with a 'year' dimension. In this notebook the other
    dims are presumably 'source_id', 'dayofyear', 'lat' and 'lon'. For a Dataset,
    polyfit names its outputs '<var>_polyfit_coefficients', so the lookup below
    would fail.
    """
    fit = ds.polyfit(dim='year', deg=1)
    slope = fit.polyfit_coefficients.sel(degree=1)
    linear_trend = slope * (ds['year'] - central_year)
    return ds - linear_trend

## Section 2: Load Data

In [8]:
# Open the intake-ESM catalogue of CMIP6 zarr stores; its repr summarises unique values per facet
col = intake.open_esm_datastore(catalog_url)
col

Unnamed: 0,unique
activity_id,18
institution_id,36
source_id,88
experiment_id,170
member_id,657
table_id,37
variable_id,709
grid_label,10
zstore,522217
dcpp_init_year,60


In [9]:
# `var_name` selects the data variable from the saved zarr stores, `folder_path` is where
# outputs are written and read back, and `variable` is the catalogue search term.
var_name    = 'tas'
folder_path = tmean_path
variable    = ['tas'] #Other variables of interest: 'tasmax', 'tasmin'

In [10]:
# 2. Search for daily temperature 
expts = ['ssp370','historical']

query = dict(
    experiment_id=expts,
    table_id='day',
    variable_id= variable,
    member_id = 'r1i1p1f1',
    #activity_id = 'CMIP',
    
)

col_subset = col.search(require_all_on=["source_id"], **query)

col_subset.df.groupby("source_id")[["experiment_id", "variable_id", "table_id","member_id"]].nunique()

Unnamed: 0_level_0,experiment_id,variable_id,table_id,member_id
source_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ACCESS-CM2,2,1,1,1
AWI-CM-1-1-MR,2,1,1,1
BCC-CSM2-MR,2,1,1,1
BCC-ESM1,2,1,1,1
CESM2-WACCM,2,1,1,1
CMCC-CM2-SR5,2,1,1,1
CMCC-ESM2,2,1,1,1
CanESM5,2,1,1,1
EC-Earth3,2,1,1,1
EC-Earth3-AerChem,2,1,1,1


In [11]:
# Catalogue rows matching the search (one zarr store URL per row in `zstore`)
df = col_subset.df
# model_counts = df.groupby('source_id').size()
# print(model_counts)
df.head()

Unnamed: 0,activity_id,institution_id,source_id,experiment_id,member_id,table_id,variable_id,grid_label,zstore,dcpp_init_year,version
0,CMIP,CSIRO-ARCCSS,ACCESS-CM2,historical,r1i1p1f1,day,tas,gn,osdf:///aws-opendata/us-west-2/cmip6-pds/CMIP6...,,20191108
1,ScenarioMIP,CSIRO-ARCCSS,ACCESS-CM2,ssp370,r1i1p1f1,day,tas,gn,osdf:///aws-opendata/us-west-2/cmip6-pds/CMIP6...,,20191108
2,CMIP,AWI,AWI-CM-1-1-MR,historical,r1i1p1f1,day,tas,gn,osdf:///aws-opendata/us-west-2/cmip6-pds/CMIP6...,,20181218
3,ScenarioMIP,AWI,AWI-CM-1-1-MR,ssp370,r1i1p1f1,day,tas,gn,osdf:///aws-opendata/us-west-2/cmip6-pds/CMIP6...,,20190529
4,CMIP,BCC,BCC-CSM2-MR,historical,r1i1p1f1,day,tas,gn,osdf:///aws-opendata/us-west-2/cmip6-pds/CMIP6...,,20181126


In [12]:
# Activities present in the search results (AerChemMIP appears next to CMIP and ScenarioMIP)
df['activity_id'].unique()

array(['CMIP', 'ScenarioMIP', 'AerChemMIP'], dtype=object)

In [13]:
# Keep only rows with CMIP (historical) or ScenarioMIP (ssp370) for consistency. 
df_filtered = col_subset.df[col_subset.df['activity_id'].isin(['CMIP', 'ScenarioMIP'])]

print("Filtered DataFrame shape:", df_filtered.shape)
# print("Filtered activity_id values:", df_filtered['activity_id'])

Filtered DataFrame shape: (53, 11)


In [14]:
# Facet counts after filtering; a model with experiment_id == 1 lost one of its experiments to the filter
df_filtered.groupby("source_id")[["experiment_id", "variable_id", "table_id","activity_id"]].nunique()

Unnamed: 0_level_0,experiment_id,variable_id,table_id,activity_id
source_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ACCESS-CM2,2,1,1,2
AWI-CM-1-1-MR,2,1,1,2
BCC-CSM2-MR,2,1,1,2
BCC-ESM1,1,1,1,1
CESM2-WACCM,2,1,1,2
CMCC-CM2-SR5,2,1,1,2
CMCC-ESM2,2,1,1,2
CanESM5,2,1,1,2
EC-Earth3,2,1,1,2
EC-Earth3-AerChem,2,1,1,2


In [None]:
%%time
# dsets = col_subset.to_dataset_dict(storage_options={'anon': 'True'})
# Let intake-esm open every store in the search results (the key format is printed below).
# NOTE(review): `dsets` is reassigned in the next cell, which opens the stores manually per
# (source_id, experiment_id); here this result is only used to print its keys.
dsets = col_subset.to_dataset_dict()
print(f"\nDataset dictionary keys:\n {dsets.keys()}")


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.table_id.grid_label'


In [None]:
def drop_all_bounds(ds):
    """Return `ds` without coordinates whose names contain '_bounds' or '_bnds'."""
    bound_markers = ('_bounds', '_bnds')
    bounds_coords = [name for name in ds.coords
                     if any(marker in name for marker in bound_markers)]
    return ds.drop_vars(bounds_coords)

def open_dset(df):
    """Open the single zarr store referenced by a one-row catalogue DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        Catalogue rows for one (source_id, experiment_id) group. Must contain exactly
        one row, whose ``zstore`` column holds the store URL.

    Returns
    -------
    xarray.Dataset
        Dataset opened with ``xr.open_zarr`` (consolidated metadata), with
        '_bounds' / '_bnds' coordinates removed.

    Raises
    ------
    ValueError
        If ``df`` does not contain exactly one row, e.g. when a model provides the
        same experiment on more than one grid.
    """
    # Explicit check instead of `assert`: it survives `python -O` and names the offending rows.
    if len(df) != 1:
        raise ValueError(
            f"Expected exactly one catalogue row, got {len(df)}: {list(df.zstore)}")
    # The osdf:// URL is resolved through fsspec (PelicanFS); no storage options are passed.
    # For direct S3 access instead: fsspec.get_mapper(url, anon=True)
    ds = xr.open_zarr(fsspec.get_mapper(df.zstore.values[0]), consolidated=True)
    return drop_all_bounds(ds)

def open_delayed(df):
    """Wrap ``open_dset(df)`` in a dask.delayed task so stores can be opened in parallel."""
    delayed_open = dask.delayed(open_dset)
    return delayed_open(df)

from collections import defaultdict
dsets = defaultdict(dict)

# Group the activity-filtered rows (CMIP historical + ScenarioMIP ssp370) rather than the raw
# search results. Otherwise runs published under other activities (AerChemMIP, see the
# activity table above) are mixed in.
for group, df in df_filtered.groupby(by=['source_id', 'experiment_id']):
    dsets[group[0]][group[1]] = open_delayed(df)

# The activity filter can leave a model with only one experiment. The processing loop indexes
# every entry of `expts`, so keep only models that still have all of them.
dsets = defaultdict(dict, {source_id: by_expt for source_id, by_expt in dsets.items()
                           if all(expt in by_expt for expt in expts)})

In [None]:
%%time
# Trigger computation
# Runs every delayed open_dset task on the cluster. dsets_ is a plain nested dict
# {source_id: {experiment_id: xarray.Dataset}} of lazily opened (dask-backed) datasets.
dsets_ = dask.compute(dict(dsets))[0]

In [None]:
#Define coarse grid to regrid on 1 *1 degree card
ds_out = xr.Dataset({'lat': (['lat'], np.arange(-90, 91, 1.5)),
                    'lon': (['lon'], np.arange(0, 361, 1.5))})
ds_out

In [None]:
def drop_feb29(ds):
    """Put `ds` on a 365-day (no-leap) calendar unless it already uses '360_day'.

    The calendar is read from ``ds.time.encoding``. The model name and that calendar
    are printed as a progress log. Datasets on a '360_day' calendar are returned
    unchanged; every other dataset goes through ``convert_calendar('365_day')``,
    which drops 29 February.
    """
    time_calendar = ds.time.encoding.get('calendar', None)
    print(ds.attrs['source_id'], time_calendar)
    if time_calendar == '360_day':
        return ds
    return ds.convert_calendar('365_day')


def to_daily(ds):
    """Reshape the 1-D 'time' axis of `ds` into separate 'year' and 'dayofyear' dimensions.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a 'time' coordinate of numpy datetime64 or cftime values.
        Any other time type is cast to ``datetime64[ns]`` first.

    Returns
    -------
    xarray.Dataset
        The same data with 'time' replaced by a 'year' x 'dayofyear' grid. A
        (year, dayofyear) pair absent from the input becomes missing after unstacking.

    The input dataset is not modified.
    """
    first_time = ds['time'].values[0]
    if not isinstance(first_time, (np.datetime64, cftime.datetime)):
        # Build a new object instead of assigning ds['time'] in place, which would
        # silently mutate the caller's dataset.
        ds = ds.assign_coords(time=ds['time'].astype('datetime64[ns]'))

    year      = ds.time.dt.year
    dayofyear = ds.time.dt.dayofyear

    # attach year / day-of-year as auxiliary coordinates along 'time'
    ds = ds.assign_coords(year=("time", year.data), dayofyear=("time", dayofyear.data))

    # turn 'time' into a (year, dayofyear) MultiIndex and unstack it into two dimensions
    return ds.set_index(time=("year", "dayofyear")).unstack("time")


def extract_data(ds, pi_years=(1850, 1879), eoc_years=(2071, 2100)):
    """
    Extract the pre-industrial and end-of-century year windows from 'ds'.

    Parameters:
    - ds (xarray.Dataset): Input dataset with a 'year' dimension (as produced by to_daily).
    - pi_years (tuple of int): Inclusive (first, last) years of the pre-industrial window.
    - eoc_years (tuple of int): Inclusive (first, last) years of the end-of-century window.

    Returns:
    - xarray.Dataset: Both windows concatenated along 'year'. No spatial subsetting is done.
      If 'ds' covers only one window, the other selection is empty and only the covered
      years are returned. If it covers neither, the result has zero years.
    """
    windows = [ds.sel(year=slice(*pi_years)), ds.sel(year=slice(*eoc_years))]
    return xr.concat(windows, dim='year')

def is_leap(year):
    """Return True when `year` is a Gregorian leap year.

    Years divisible by 400 are leap years, other centuries are not, and otherwise
    every fourth year is.
    """
    if year % 400 == 0:
        return True
    if year % 100 == 0:
        return False
    return year % 4 == 0


In [None]:
quants = np.linspace(0, 1.0, 30)   # 30 evenly spaced quantile levels from 0 (min) to 1 (max)

def compute_quantiles(ds, quantiles=quants):
    """Quantiles of `ds` across the 'year' dimension.

    'year' is first rechunked into a single chunk, because a dask-backed quantile
    needs the whole core dimension in one chunk. With ``skipna=False``, any NaN
    along 'year' yields NaN.
    """
    single_year_chunk = {'year': -1}
    return ds.chunk(single_year_chunk).quantile(quantiles, dim='year', skipna=False)

def regrid(ds, ds_out):
    """Nearest-neighbour regrid `ds` onto the grid of `ds_out` via the `.regrid` accessor.

    'experiment_id' and 'source_id' are read from the input before regridding and
    written back onto the result, in case the regridder drops attributes.
    """
    kept_attrs = {key: ds.attrs[key] for key in ('experiment_id', 'source_id')}
    regridded = ds.regrid.nearest(ds_out)
    regridded.attrs.update(kept_attrs)
    return regridded

def process_data(ds, quantiles=quants):
    """Prepare one model/experiment dataset for the analysis.

    Steps: calendar conversion (drop_feb29), reshaping time into 'year' x 'dayofyear'
    (to_daily), selection of the pre-industrial and end-of-century windows
    (extract_data), and regridding onto the module-level ``ds_out`` grid.

    Prints a message and returns None when no years fall in the windows, or when
    fewer than 365 days remain (e.g. a 360-day calendar). ``quantiles`` is accepted
    but not used.
    """
    prepared = drop_feb29(ds)
    prepared = to_daily(prepared)
    prepared = extract_data(prepared)

    if len(prepared['year']) == 0:
        print("The dataset is empty. Skipping...")
        return None

    if len(prepared['dayofyear']) < 365:
        print('The dataset has less than 365 days. Skipping ..')
        return None

    return regrid(prepared, ds_out=ds_out)

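## Section 3: Computations
- Process each model's historical and ssp370 daily data: convert to a 365-day calendar, reshape time into (year, dayofyear), keep the 1850–1879 and 2071–2100 windows, and regrid to a common 1.5° grid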

In [None]:
%%time
if RECALC:
    # NOTE(review): dask.diagnostics.ProgressBar tracks dask's local schedulers; with the distributed
    # Client connected above it may display nothing (the dashboard shows task progress) — confirm.
    with progress.ProgressBar():
    
        # NOTE(review): expt_da is not used in this cell.
        expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
                               coords={'experiment_id': expts})
    
        # Initialize an Empty Dictionary for Aligned Datasets:
        # layout is {source_id: {experiment_id: processed xarray.Dataset}}
        dsets_aligned = {}
    
        # Iterate Over dsets_ Dictionary:
    
        for k, v in tqdm(dsets_.items()):
            # Initialize a dictionary for this source_id
            dsets_aligned[k] = {}
            
            # Set when any experiment of this model is unusable; the whole model is then dropped
            skip_source_id = False
    
            for expt in expts:
                # calendar conversion, (year, dayofyear) reshape, window selection and regridding
                ds = v[expt].pipe(process_data)
    
                # Check if the dataset is empty and skip this source_id if so
                # (process_data also returns None for < 365 days, despite the "empty" wording below)
                if ds is None:
                    print(f"Skipping {expt} for {k} because the dataset is empty")
                    skip_source_id = True
                    break
                
                # Store the dataset in the dictionary
                # dsets_aligned[k][expt] = ds
                # Compute the dataset and store it in the dictionary
                # (.compute() loads the processed values into this notebook's memory)
                dsets_aligned[k][expt] = ds.compute()
                print(dsets_aligned[k][expt])
    
            # Remove partially filled entries so every remaining model has all experiments
            if skip_source_id:
                del dsets_aligned[k]
                continue

In [None]:
# dsets_aligned.keys()

In [None]:
%%time
if RECALC:
    # Stack all models along a new 'source_id' dimension, one dataset per period
    source_ids = list(dsets_aligned.keys())
    source_da = xr.DataArray(source_ids, dims='source_id', name='source_id',
                         coords={'source_id': source_ids})
    # pre-industrial window comes from the historical runs; reset_coords(drop=True)
    # discards non-index coordinates before concatenating
    final_ds_pi = xr.concat([ds['historical'].reset_coords(drop=True)
                                     for ds in dsets_aligned.values()],
                                    dim=source_da)
    
    # end-of-century window comes from the ssp370 runs
    final_ds_eoc = xr.concat([ds['ssp370'].reset_coords(drop=True)
                                 for ds in dsets_aligned.values()],
                                dim=source_da)
    # NOTE(review): Jupyter only auto-displays a top-level last expression; inside `if` this shows nothing.
    final_ds_eoc

In [None]:
%%time
# Persist the concatenated datasets. Only do this when they were (re)computed above: with
# RECALC = False, final_ds_pi / final_ds_eoc do not exist yet and the stores on disk are reused.
if RECALC:
    final_ds_pi.to_zarr(folder_path  +'cmip6_pi_daily.zarr',mode='w')
    final_ds_eoc.to_zarr(folder_path +'cmip6_eoc_daily.zarr',mode='w')

In [None]:
# Re-open the stored pre-industrial and end-of-century data lazily from disk and keep only the
# `var_name` DataArray, so the cells below work from the zarr stores.
final_ds_pi  = xr.open_zarr(folder_path+'cmip6_pi_daily.zarr')
final_ds_eoc = xr.open_zarr(folder_path+'cmip6_eoc_daily.zarr')
final_ds_pi  = final_ds_pi[var_name]
final_ds_eoc = final_ds_eoc[var_name]
final_ds_eoc

### Detrend data and save

In [None]:
%%time
# Remove each grid point's linear trend in 'year', anchored at the period's reference year
ds_pi_det  = detrend_data(final_ds_pi,pi_year)
ds_eoc_det = detrend_data(final_ds_eoc,eoc_year)
# One chunk per model spanning 30 years (the full length of each window)
ds_eoc_det = ds_eoc_det.chunk({'year':30,'source_id':1})
ds_pi_det = ds_pi_det.chunk({'year':30,'source_id':1})
ds_eoc_det

In [None]:
# %%time
# Save the detrended pre-industrial data; the "Check if detrending worked" cells re-open this store.
# Guarded by RECALC so runs with RECALC = False reuse the store already on disk.
# (.rename gives the DataArray a name so it can become a Dataset variable.)
if RECALC:
    ds_pi_det.rename(var_name).to_dataset().to_zarr(folder_path  +'cmip6_pi_ann_detrended.zarr',mode='w')

In [None]:
# %%time
# Save the detrended end-of-century data; the "Check if detrending worked" cells re-open this store.
# Guarded by RECALC so runs with RECALC = False reuse the store already on disk.
# (.rename gives the DataArray a name so it can become a Dataset variable.)
if RECALC:
    ds_eoc_det.rename(var_name).to_dataset().to_zarr(folder_path +'cmip6_eoc_ann_detrended.zarr',mode='w')

### Check if detrending worked

In [None]:
# Re-open the detrended stores from folder_path (they must already exist on disk) and keep the
# `var_name` DataArray.
# NOTE(review): the store names say 'ann' while the data here are daily (year x dayofyear) and
# the plots below select `dayofyear` — confirm these stores hold this notebook's daily output.
ds_pi_det  = xr.open_zarr(folder_path  +'cmip6_pi_ann_detrended.zarr')
ds_eoc_det = xr.open_zarr(folder_path +'cmip6_eoc_ann_detrended.zarr')
#
ds_pi_det  = ds_pi_det[var_name]
ds_eoc_det = ds_eoc_det[var_name]

In [None]:
%%time
# Chicago grid cell, day-of-year 2, multi-model mean: detrended vs. original pre-industrial series.
# Detrending subtracts slope * (year - pi_year), so the two curves should coincide at pi_year.
ds_pi_det.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot(label='detrended')
final_ds_pi.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot(label='original')
plt.legend();

In [None]:
%%time
# Chicago grid cell, day-of-year 2, multi-model mean: detrended vs. original end-of-century series.
# Detrending subtracts slope * (year - eoc_year), so the two curves should coincide at eoc_year.
ds_eoc_det.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot(label='detrended')
final_ds_eoc.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot(label='original')
plt.legend();