# CMIP6 Import Data

**The following steps are included in this script:**

1. Open the DKRZ catalog, save the CMIP6 catalog and browse it
2. Load important hydroecological data for each model separately, as the kernel dies otherwise
3. Merge data into one file for each model
4. Load an already saved netcdf file and update it with new variables
5. Save data to netcdf files

In [1]:
# ========== Packages ==========
import xarray as xr
import intake
import dask
import os
import pandas as pd
import numpy as np

In [2]:
# ========== Run all functions at the bottom ========

## 1. Open dkrz catalog and save CMIP6 catalog

In [2]:
# ----Open dkrz catalog----
cat = intake.open_catalog(["https://dkrz.de/s/intake"])

# ----Save CMIP6 catalog----
cat_cmip6 = cat.dkrz_cmip6_disk

## 2. Load important hydroecological data


#### Attributes
| 'source_id' | 'member_id' | 'variable_id' |
|:-----------:|:-----------:|:-----------:|
| 'BCC-CSM2-MR', 'CESM2', 'CNRM-CM6-1-HR', 'NorESM2-MM', 'SAM0-UNICON', 'TaiESM1'  | 'r1i1p1f1', 'r1i1p1f2' (for CNRM)   | 'pr', 'mrro', 'mrros', 'evspsbl', 'evspsblsoi', 'evspsblveg', 'tran', 'mrso', 'mrsos', 'mrsol', 'huss', 'hurs', 'lai', 'gpp', 'npp' | 

In [13]:
# ----Define attributes----
attrs = dict(
    
    experiment_id="historical", #ssp126 historical
  #  member_id=['r1i1p1f1' ,'r1i1p1f2', 'r1i1p2f1', 'r1i1p1f3'], #'r1i1p1f2'],,'r1i1p1f2', 'r1i1p2f1', 'r1i1p1f3'
   source_id=['CNRM-ESM2-1'], #'TaiESM1', 'AWI-ESM-1-1-LR', 'BCC-CSM2-MR', 'BCC-ESM1', 'CanESM5', 'CNRM-CM6-1', 'CNRM-CM6-1-HR', 'CNRM-ESM2-1','UKESM1-0-LL', 'CESM2', 'CESM2-FV2', 'CESM2-WACCM', 'NorESM2-MM'], 
    # table_id =['Amon', 'Lmon', 'Emon'], #'Amon', 'Lmon', 'Emon'
    variable_id=[
              #  'tas',
              #  'ps', #surface pressure
          #       'pr', # CESM2 has problems loading pr with other Amon data
          #      'mrro', 
               # 'mrros', 
          #      'evspsbl', 
               # 'evspsblsoi', 
               # 'evspsblveg', 
              #  'tran', 
               # 'mrso', 
               # 'mrsos', 
              #  'mrsol', 
              #  'huss', 
               #  'hurs',  # TaiESM1 has hurs only in daily resolution
              #  'lai', 
              # 'gpp', 
              #  'npp'
                'tsl'
    #     'sftlf' # land area fraction
    ]
  #  ,version = ['v20200623', 'v20200624'] #TaiESM1 has two versions for gpp. I select the newer version.
)

In [4]:
# ----Save data selection----
selection = cat_cmip6.search(**attrs)
#selection = cat_cmip6.search(require_all_on=["source_id"], **attrs) #require_all_on defines that source ID must include all important variables

In [5]:
# ----Set properties of pandas tables ----
pd.set_option('display.max_colwidth', None) #pd.reset_option('display.max_colwidth')
pd.set_option('display.max_rows', None) #pd.reset_option('display.max_rows', None)

# ----Print table with different attributes of selected data----
selection.df.groupby(
    [
       # "grid_label",
        "institution_id",
        "source_id",
     #   "version",
    #    'member_id',
       # "time_range",
        'experiment_id',
       
      #   'table_id'
    #    'variable_id'
    ]
)['variable_id'].unique().apply(list).to_frame()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,variable_id
institution_id,source_id,experiment_id,Unnamed: 3_level_1
CNRM-CERFACS,CNRM-ESM2-1,historical,[tsl]


In [None]:
# ========= Load selection in dictionary ========== (I always have to run the 'define attrs' cell again...)

# Every dataset of the selection is passed through pre_preprocessing (defined in the
# "Functions" section at the bottom of the notebook) while it is loaded.
# NOTE(review): 'use_cftime', 'decode_times' and 'consolidated' are set on the dask
# configuration here. They look like xarray/zarr open options rather than dask settings,
# so presumably they have no effect in this form -- confirm, and if so pass them to
# to_dataset_dict() using the keyword the installed intake-esm version expects.
with dask.config.set(**{"use_cftime": True, "decode_times": True, 'consolidated': True}):
    ds_dict = selection.to_dataset_dict(preprocess=pre_preprocessing) 


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.source_id.experiment_id.table_id.grid_label'


## 3. Merge datasets into one file for each model

In [None]:
# =========== Drop redundant coordinates and variables ================

# Define redundant coordinates and variables
drop_list = ['member_id','nbnd', 'bnds', 'height', 'depth', 'lat_bnds', 'lon_bnds', 'time_bnds', 'time_bounds', 'depth_bnds', 'sdepth_bounds', 'depth_bounds', 'hist_interval', 'axis_nbounds'] #depth is not dropped for datasets with variable mrsol

# Drop redundant coordinates and variables
ds_dict = drop_redundant(ds_dict, drop_list)

In [None]:
ds_dict.keys()

In [None]:
ds_dict[list(ds_dict.keys())[0]]

In [None]:
# =========== Merge data if loading all data of one model is not possible ==============
# Name of second dictionary ds_dict_

# Only use this command if loading at once is not possible

#ds_dict[f"{list(ds_dict_.keys())[0]}_"]=ds_dict_[list(ds_dict_.keys())[0]]

In [None]:
# =========== Merge datasets with different table_id and same source_id ================

ds_dict = merge_source_id_data(ds_dict)

In [None]:
# =========== Compute monthly mean for daily datasets (optional: merge with rest of the datasets) =============

#ds_dict_monthly = daily_to_monthly(ds_dict) # optionally include ,ds_dict_merged

In [None]:
# =========== Check dictionary for consistency =============
ds_dict[list(ds_dict.keys())[0]]

## 4. Load already saved netcdf file and update with new variables

In [None]:
# ========= Load model which needs to be updated ==============
# The file name is built below as 'CMIP.<model>.<experiment_id>.nc' inside savepath.
# NOTE(review): this cell reads from '../../data/...', while the saving code in this
# notebook writes to '../data/...' -- confirm that both point to the same directory.
experiment_id = 'historical'
source_id = ['BCC-CSM2-MR']
savepath = f'../../data/CMIP6/{experiment_id}/preprocessed'

# ========= Use Dask to parallelize computations ==========
# Set without a `with` block, so the 'processes' scheduler stays active for the cells
# that follow (the saving code switches to 'threads' locally).
dask.config.set(scheduler='processes')

# ========= Create a helper function to open the dataset ========
def open_dataset(filename):
    """Open the netCDF file at `filename` with xarray and return the dataset."""
    ds = xr.open_dataset(filename)
    return ds

# ========= Create dictionary using a dictionary comprehension and Dask =======
# Result: {model name: dataset}. dask.compute returns a tuple, hence the trailing comma.
# NOTE(review): open_dataset is called directly inside the comprehension, so the files
# are opened one after another before dask.compute runs; it is not wrapped in
# dask.delayed. Presumably the opening itself is therefore not parallelised -- confirm.
ds_dict_merge, = dask.compute({model: open_dataset(os.path.join(savepath, f'CMIP.{model}.{experiment_id}.nc'))
                        for model in source_id})

In [None]:
# =========== Check dictionary for consistency =============

ds_dict_merge[list(ds_dict_merge.keys())[0]]

In [None]:
# ========= Create a dictionary with the computed monthly mean and the loaded model data ==========

# 'dataset_one' is the first dataset loaded from disk (ds_dict_merge);
# 'dataset_two' is the fourth entry (index 3) of ds_dict_merged.
# NOTE(review): ds_dict_merged is not assigned in any cell of this notebook (the merge
# cell above stores its result in ds_dict), so this cell raises a NameError on a fresh
# kernel. The hard-coded index 3 also requires at least four entries -- confirm which
# dictionary and which key are meant.
ds_dict_all = {}
ds_dict_all['dataset_one'] = ds_dict_merge[list(ds_dict_merge.keys())[0]]
ds_dict_all['dataset_two'] = ds_dict_merged[list(ds_dict_merged.keys())[3]]

In [None]:
# ========= Merge data ======================

ds_dict_all = merge_source_id_data(ds_dict_all)

In [None]:
# =========== Check dictionary for consistency =============

ds_dict_all[list(ds_dict_all.keys())[0]]

## 5. Store netcdf files

In [None]:
# Write every dataset of ds_dict_merged to
# '../data/CMIP6/<experiment_id>/<folder>/CMIP.<source_id>.<experiment_id>_tsl.nc'.
# Same steps as save_file() in the "Functions" section, but with a '_tsl' file suffix.
# NOTE(review): ds_dict_merged is not assigned in any cell of this notebook (the merge
# cell stores its result in ds_dict) -- confirm which dictionary should be saved here.
folder='preprocessed'

for key in ds_dict_merged.keys():
    ds_in = ds_dict_merged[key]
    # source_id and experiment_id are read from the dataset itself
    filename = f'CMIP.{ds_in.source_id}.{ds_in.experiment_id}_tsl.nc'
    savepath = f'../data/CMIP6/{ds_in.experiment_id}/{folder}'
    nc_out = os.path.join(savepath, filename)
    os.makedirs(savepath, exist_ok=True) 
    # Remove a previously stored file before writing the new one
    if os.path.exists(nc_out):
        os.remove(nc_out)
        print(f"File  with path: {nc_out} removed")
    # Save to netcdf file
    with dask.config.set(scheduler='threads'):
        ds_in.to_netcdf(nc_out)

In [None]:
# =========== Store file and remove any former one ==========

nc_out = save_file(ds_dict, folder='raw')

In [None]:
# =========== Check stored file ==============

xr.open_dataset(nc_out)

# Functions

In [6]:
from xmip.preprocessing import correct_lon, correct_units, parse_lon_lat_bounds, maybe_convert_bounds_to_vertex, maybe_convert_vertex_to_bounds

def pre_preprocessing(ds: xr.Dataset) -> xr.Dataset:
    """
    Prepare a freshly opened CMIP6 dataset for further processing.

    'lon' and 'lat' are promoted to coordinates when they exist as variables.
    Afterwards the xmip helpers correct_units, parse_lon_lat_bounds,
    maybe_convert_bounds_to_vertex and maybe_convert_vertex_to_bounds are
    applied in that order.

    Parameters:
    ds (xr.Dataset): Input dataset

    Returns:
    xr.Dataset: Deep copy of the preprocessed dataset
    """
    # Some datasets carry lon/lat as data variables; make them coordinates
    for axis_name in ("lon", "lat"):
        if axis_name in ds.variables:
            ds = ds.set_coords(axis_name)
    ds = ds.copy(deep=True)

    # xmip clean-up steps, applied one after another
    xmip_steps = (
        correct_units,
        parse_lon_lat_bounds,
        maybe_convert_bounds_to_vertex,
        maybe_convert_vertex_to_bounds,
    )
    for step in xmip_steps:
        ds = step(ds)

    return ds.copy(deep=True)

In [7]:
def drop_redundant(ds_dict, drop_list):
    """
    Remove redundant coordinates and variables from datasets in a dictionary.

    For every dataset:
      * a coordinate named 'sdepth' is renamed to 'depth' (an existing 'depth'
        is dropped first),
      * 'depth' is kept if the dataset contains 'mrsol' or 'tsl', because these
        variables need the depth axis,
      * every other name of drop_list that exists in the dataset is dropped and
        the dataset is squeezed afterwards.

    The decision to keep 'depth' is taken per dataset and drop_list itself is
    not modified, so the caller can reuse the same list for other models.

    Parameters:
    ds_dict (dict): Dictionary containing dataset names as keys and xarray.Dataset objects as values.
    drop_list (list): List of redundant coordinate or variable names to be removed from the datasets.

    Returns:
    dict: The same dictionary object as ds_dict, with each value replaced by the cleaned xarray.Dataset.
    """
    depth_dependent_vars = ('mrsol', 'tsl')

    # Only existing keys are reassigned below, so iterating over items() is safe
    for ds_name, ds_data in ds_dict.items():

        if 'sdepth' in ds_data.coords:
            if 'depth' in ds_data.dims:
                ds_data = ds_data.drop_vars('depth')
            ds_data = ds_data.rename({'sdepth': 'depth'})
            print(f'sdepth changed to depth for model {ds_data.source_id}')

        # Build a per-dataset list instead of removing 'depth' from the caller's list
        keep_depth = any(var in ds_data for var in depth_dependent_vars)
        names_to_drop = [
            name for name in drop_list
            if not (keep_depth and name == 'depth')
        ]

        for name in names_to_drop:
            if name in ds_data.coords:
                ds_data = ds_data.drop_vars(name).squeeze()
                print(f'Dropped coordinate: {name}')
            elif name in ds_data.variables:
                ds_data = ds_data.drop_vars(name).squeeze()
                print(f'Dropped variable: {name}')

        # Store the result even when nothing was dropped (e.g. only sdepth was renamed)
        ds_dict[ds_name] = ds_data

    return ds_dict

In [8]:
def replace_coordinates(new_coords, replace_coords):
    """
    Helper function to align coordinates of two datasets before merging.

    For each of 'lon', 'lat' and 'time', the entry of replace_coords is
    overwritten with the one from new_coords whenever the two are not equal.
    replace_coords is modified in place.

    Args:
        new_coords (xr dataset): Dataset that provides the reference coordinates.
        replace_coords (xr dataset): Dataset whose coordinates are replaced where they differ.

    Returns:
        replace_coords (xr dataset): The same object that was passed in, with the coordinates taken from new_coords.
    """
    for axis in ('lon', 'lat', 'time'):
        reference = new_coords[axis]
        if reference.equals(replace_coords[axis]):
            # Already identical, nothing to replace
            continue
        replace_coords[axis] = reference

    return replace_coords

In [9]:
def merge_source_id_data(ds_dict):
    """
    Merge datasets with the same source_id (name of the CMIP6 model) as CMIP6 data is stored in different table id's.
    This function is mainly used to merge two different xarray datasets for 'table_id' Amon and Lmon into a single
    xarray dataset as this makes future investigations easier. Other table_id's can also be merged; however, be
    careful when the same variable exists in both datasets.

    The first dataset of a source_id is stored as is. Every further dataset of that source_id is merged into it
    with xr.merge if its table_id differs from the table_id attribute of the stored dataset; before merging, its
    lon/lat/time coordinates are replaced by those of the stored dataset (see replace_coordinates). A dataset
    whose table_id equals the stored one is NOT merged; a message is printed so that the skip is visible.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
                        and each value is the dataset itself. Each dataset needs the attributes
                        'source_id' and 'table_id'.

    Returns:
        dict: A merged dictionary with a single dataset for each CMIP6 model/source_id.
    """
    merged_dict = {}

    for dataset_name, dataset in ds_dict.items():
        source_id = dataset.attrs['source_id']
        table_id = dataset.attrs['table_id']
        print(f"Merging dataset '{dataset_name}' with source_id '{source_id}' and table_id '{table_id}'...")

        if source_id not in merged_dict:
            merged_dict[source_id] = dataset
            print(f"Dataset '{dataset_name}' ('{table_id}') is saved in 'merged_dict'.")
            continue

        existing = merged_dict[source_id]
        merg_model_table_id = existing.attrs['table_id']

        if table_id == merg_model_table_id:
            # Same model and same table_id: not merged, but no longer dropped silently
            print(f"Dataset '{dataset_name}' ('{table_id}') is skipped: 'merged_dict' already holds "
                  f"table_id '{table_id}' for source_id '{source_id}'.")
            continue

        # Only used for the message below; fall back to the source_id if the attribute is missing
        merg_model_name = existing.attrs.get('intake_esm_dataset_key', source_id)

        # Replace coordinates lat, lon, time of dataset only when different to datasets in merged_dict
        dataset = replace_coordinates(existing, dataset)

        # Merge data
        with dask.config.set(**{'array.slicing.split_large_chunks': False}):
            merged_dict[source_id] = xr.merge([existing, dataset])

        if len(merged_dict) == 1:
            print(f"Datasets '{merg_model_name}' ('{merg_model_table_id}') and '{dataset_name}' ('{table_id}') are merged to 'merged_dict' with key '{source_id}'.")
        else:
            print(f"Datasets '{dataset_name}' ('{table_id}') is merged with 'merged_dict'.")

    return merged_dict

In [10]:
def daily_to_monthly(ds_dict_daily, ds_dict_merged=1):
    """
    Compute monthly means from daily data and optionally merge them with the rest of the model data.
    The reason for this function is that the TaiESM1 model has no monthly data for the variable hurs.
    It can be applied to any model that has a similar issue.

    Args:
        ds_dict_daily (dict): Dictionary of xarray datasets holding the missing variable in daily resolution.
        ds_dict_merged (dict, optional): Dictionary (keyed by source_id) of xarray datasets with all other
                                            variables in monthly resolution. If left at its default of 1,
                                            only the monthly mean is calculated.

    Returns:
        ds_dict_monthly: Dictionary keyed by source_id with the monthly means (no ds_dict_merged given).

            or

        ds_dict_all: Dictionary keyed by source_id with the monthly means merged into the datasets
                     of ds_dict_merged.
    """
    # The default value 1 acts as the "nothing to merge with" marker
    monthly_only = bool(ds_dict_merged == 1)

    ds_dict_monthly = {}
    for daily in ds_dict_daily.values():
        model = daily.attrs['source_id']
        # Sort by time, then average per month (frequency '1MS')
        monthly = daily.sortby('time').resample(time='1MS').mean()
        if not monthly_only:
            # Replace coordinates of dataset when different to datasets in ds_dict_merged
            monthly = replace_coordinates(ds_dict_merged[model], monthly)
        ds_dict_monthly[model] = monthly

    if monthly_only:
        return ds_dict_monthly

    # Merge computed monthly average with rest of model dict
    ds_dict_all = {}
    for model, monthly in ds_dict_monthly.items():
        with dask.config.set(**{'array.slicing.split_large_chunks': False}):
            ds_dict_all[model] = xr.merge([monthly, ds_dict_merged[model]])

    return ds_dict_all

In [11]:
def save_file(save_file, folder):
    """
    Save one dataset, or every dataset of a dictionary, as a netCDF file.

    Each dataset is written to
    '../data/CMIP6/<experiment_id>/<folder>/CMIP.<source_id>.<experiment_id>.nc',
    where source_id and experiment_id are read from the dataset. The target
    directory is created if needed and an existing file at that path is
    removed before writing.

    Args:
        save_file (dict or dataset): Dictionary of xarray datasets or a single xarray dataset.
        folder (string): Name of folder data is saved in.

    Returns:
        nc_out: Path the data was saved to (for a dictionary: the path of the last dataset written).

    Raises:
        ValueError: If save_file is neither a dictionary nor an xarray dataset,
                    or if it is an empty dictionary.
    """
    if isinstance(save_file, dict):
        datasets = list(save_file.values())
        if not datasets:
            # Nothing would be written and there would be no path to return
            raise ValueError("No datasets to save: 'save_file' is an empty dictionary.")
    elif isinstance(save_file, xr.Dataset):
        datasets = [save_file]
    else:
        raise ValueError(
            f"Invalid type '{type(save_file).__name__}' for 'save_file': "
            "expected a dictionary of xarray datasets or an xarray dataset."
        )

    for ds_in in datasets:
        filename = f'CMIP.{ds_in.source_id}.{ds_in.experiment_id}.nc'
        savepath = f'../data/CMIP6/{ds_in.experiment_id}/{folder}'
        nc_out = os.path.join(savepath, filename)
        os.makedirs(savepath, exist_ok=True)
        # Remove any former file so that the new one replaces it
        if os.path.exists(nc_out):
            os.remove(nc_out)
            print(f"File  with path: {nc_out} removed")
        # Save to netcdf file
        with dask.config.set(scheduler='threads'):
            ds_in.to_netcdf(nc_out)

    return nc_out