# CMIP6 Intake

**The following steps are included in this script:**

1. Open the DKRZ catalog and save the CMIP6 catalog
2. Load the data with all variables via intake
3. Save the data as netCDF

Saving each variable separately is preferred; however, saving several variables in a single file is also possible.

In [None]:
# ========== Packages ==========
import xarray as xr
import intake
import dask
import os
import pandas as pd
import numpy as np

### Functions

In [None]:
from xmip.preprocessing import correct_lon, correct_units, parse_lon_lat_bounds, maybe_convert_bounds_to_vertex, maybe_convert_vertex_to_bounds

def pre_preprocessing(ds: xr.Dataset) -> xr.Dataset:
    """
    Clean up a single CMIP6 dataset; used as the ``preprocess`` hook of
    ``to_dataset_dict`` when the catalog selection is loaded.

    ``lon`` and ``lat`` are promoted to coordinates when present (they are
    sometimes wrongly stored as data variables), then the xmip helpers
    ``correct_units``, ``parse_lon_lat_bounds``,
    ``maybe_convert_bounds_to_vertex`` and ``maybe_convert_vertex_to_bounds``
    are applied in that order.

    Parameters:
    ds (xr.Dataset): Input dataset

    Returns:
    xr.Dataset: Deep copy of the preprocessed dataset
    """
    # set_coords keeps a name in ds.variables, so membership can be tested once up front.
    fixed = ds
    for axis_name in [name for name in ("lon", "lat") if name in ds.variables]:
        fixed = fixed.set_coords(axis_name)
    fixed = fixed.copy(deep=True)

    xmip_steps = (
        correct_units,
        parse_lon_lat_bounds,
        maybe_convert_bounds_to_vertex,
        maybe_convert_vertex_to_bounds,
    )
    for step in xmip_steps:
        fixed = step(fixed)

    return fixed.copy(deep=True)

In [None]:
def replace_coordinates(new_coords, replace_coords):
    """
    Helper function to align coordinates before merging.

    For each of ``lon``, ``lat`` and ``time``, copy the coordinate from
    ``new_coords`` onto ``replace_coords`` when the two differ. A coordinate
    missing from either dataset is skipped instead of raising ``KeyError``.
    For example, a dataset without a ``time`` axis is skipped.

    Args:
        new_coords (xr.Dataset): Dataset providing the reference coordinates.
        replace_coords (xr.Dataset): Dataset whose coordinates are overwritten;
            it is modified in place.

    Returns:
        xr.Dataset: ``replace_coords`` with the differing coordinates copied
        from ``new_coords``.
    """
    for coord in ('lon', 'lat', 'time'):
        if coord not in new_coords or coord not in replace_coords:
            continue
        if not new_coords[coord].equals(replace_coords[coord]):
            replace_coords[coord] = new_coords[coord]

    return replace_coords

In [None]:
def merge_source_id_data(ds_dict):
    """
    Merge datasets with the same source_id (name of the CMIP6 model), since CMIP6 data is stored in different table_ids.

    This is mainly used to merge the datasets for 'table_id' Amon and Lmon into a
    single xarray dataset, which makes later investigations easier. Other table_ids
    can be merged as well; be careful when the same variable exists in both datasets.

    A dataset whose table_id has already been merged for its source_id is not merged
    again. Examples are a second experiment or grid of the same table; such a
    dataset is skipped and a message is printed.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
                        and each value is the dataset itself. The datasets are not modified.

    Returns:
        dict: A merged dictionary with a single dataset for each CMIP6 model/source_id.
    """
    merged_dict = {}
    merged_tables = {}  # source_id -> set of table_ids already merged under that key

    for dataset_name, dataset in ds_dict.items():
        source_id = dataset.attrs['source_id']
        table_id = dataset.attrs['table_id']
        print(f"Merging dataset '{dataset_name}' with source_id '{source_id}' and table_id '{table_id}'...")

        if source_id not in merged_dict:
            merged_dict[source_id] = dataset
            merged_tables[source_id] = {table_id}
            print(f"Dataset '{dataset_name}' ('{table_id}') is saved in 'ds_dict'.")
            continue

        if table_id in merged_tables[source_id]:
            # Previously such datasets were dropped without any message.
            print(f"Dataset '{dataset_name}' ('{table_id}') skipped: table_id '{table_id}' is already merged for '{source_id}'.")
            continue

        base = merged_dict[source_id]
        base_name = base.attrs.get('intake_esm_dataset_key', source_id)
        base_table_id = base.attrs['table_id']
        first_merge = len(merged_tables[source_id]) == 1

        # Align lat, lon, time with the already merged data. A shallow copy keeps
        # replace_coordinates from overwriting coordinates of the caller's dataset.
        dataset = replace_coordinates(base, dataset.copy())

        with dask.config.set(**{'array.slicing.split_large_chunks': False}):
            merged_dict[source_id] = xr.merge([base, dataset])
        merged_tables[source_id].add(table_id)

        if first_merge:
            print(f"Datasets '{base_name}' ('{base_table_id}') and '{dataset_name}' ('{table_id}') are merged to 'ds_dict' with key '{source_id}'.")
        else:
            print(f"Dataset '{dataset_name}' ('{table_id}') is merged with 'ds_dict'.")

    return merged_dict

In [None]:
def _append_log(ds, message):
    """Record *message* in ds.attrs['log'], newest entry first, entries separated by ' // '."""
    if 'log' in ds.attrs:
        ds.attrs['log'] = f"{message} // {ds.attrs['log']}"
    else:
        ds.attrs['log'] = message
    return ds


def _rename_to_depth(ds, old_name):
    """Rename coordinate *old_name* to 'depth', first removing any existing 'depth' coordinate/dimension."""
    if 'depth' in ds.coords:
        ds = ds.drop_vars('depth')
    if 'depth' in ds.dims:
        # drop_dims also removes every variable defined along 'depth'
        ds = ds.drop_dims('depth')
    ds = ds.rename({old_name: 'depth'})
    print(f'{old_name} changed to depth for model {ds.source_id}')
    return _append_log(ds, f'Coordinate name changed from {old_name} to depth.')


def drop_redundant(ds_dict, drop_list):
    """
    Remove redundant coordinates and variables from datasets in a dictionary.

    The coordinates 'sdepth' and 'solth' are first renamed to 'depth'. Datasets
    containing 'mrsol' or 'tsl' keep their 'depth' coordinate even if it is listed
    in drop_list. Every change is recorded in the dataset's 'log' attribute.

    Parameters:
    ds_dict (dict): Dictionary containing dataset names as keys and xarray.Dataset objects as values.
                    It is updated in place and also returned.
    drop_list (list): List of redundant coordinate or variable names to be removed from the datasets.
                      It is not modified.

    Returns:
    dict: Dictionary with the same keys as the input ds_dict and modified xarray.Dataset objects with redundant elements removed.
    """
    for ds_name, ds_data in ds_dict.items():
        for old_name in ('sdepth', 'solth'):
            if old_name in ds_data.coords:
                ds_data = _rename_to_depth(ds_data, old_name)

        # Build a per-dataset list instead of removing 'depth' from the caller's
        # drop_list. Removing it there would also keep 'depth' in every later dataset.
        keep_depth = 'mrsol' in ds_data or 'tsl' in ds_data
        to_drop = [name for name in drop_list if not (keep_depth and name == 'depth')]

        for name in to_drop:
            if name in ds_data.coords:
                # squeeze() removes the length-1 dimension a dropped coordinate (e.g.
                # member_id) can leave behind; note that it squeezes all length-1 dims.
                ds_data = ds_data.drop_vars(name).squeeze()
                print(f'Dropped coordinate: {name}')
                ds_data = _append_log(ds_data, f'Dropped: {name}.')
            if name in ds_data.variables:
                ds_data = ds_data.drop_vars(name).squeeze()
                print(f'Dropped variable: {name}')
                ds_data = _append_log(ds_data, f'Dropped: {name}.')

        # Dimensions without a coordinate variable survive the loop above;
        # squeeze them away when their length is 1.
        for name in to_drop:
            if name in ds_data.dims:
                print(f"Coordinate {name} was not dropped.")
                if ds_data.sizes[name] == 1:
                    ds_data = ds_data.squeeze(name, drop=True)
                    print(f"Squeezed coordinate: {name}")
                    ds_data = _append_log(ds_data, f'Dropped: {name}.')

        ds_dict[ds_name] = ds_data

    return ds_dict

In [None]:
def _confirm_output_path(savepath, filename):
    """Create *savepath* and return the target path, asking before an existing file is replaced."""
    nc_out = os.path.join(savepath, filename)
    os.makedirs(savepath, exist_ok=True)
    if os.path.exists(nc_out):
        inp = input(f"Delete old file {filename} (y/n):")
        if inp.lower() in ["y"]:
            os.remove(nc_out)
            print(f"File  with path: {nc_out} removed")
        else:
            # Keep the old file and write to a scratch name instead.
            filename = "temp_file.nc"
            nc_out = os.path.join(savepath, filename)
            print(f"Filename change to {filename}")
    return nc_out


def _write_netcdf(ds, nc_out):
    """Write *ds* to *nc_out* using dask's threaded scheduler."""
    with dask.config.set(scheduler='threads'):
        ds.to_netcdf(nc_out)
    print(f"File with path: {nc_out} saved")


def save_file(save_file, folder, save_var=True):
    """
    Save files as netCDF below ``../../data/CMIP6``.

    Args:
        save_file (dict or dataset): Dictionary of xarray datasets, or a single dataset.
        folder (string): Name of the sub-folder the data is saved in (e.g. 'raw').
        save_var (boolean): If True, data is saved separately for each variable in
            ``<experiment_id>/<folder>/<var>/``. The exception is 'sftlf', which goes
            to ``landmask/<folder>/sftlf/``. If False, one file with all variables
            is saved per dataset in ``<experiment_id>/<folder>/``.

    Before an existing file is replaced, the user is asked for confirmation. If the
    user declines, the data is written to ``temp_file.nc`` in the same folder.

    Returns:
        nc_out: Path of the last file written (None if nothing was saved).
    """
    # A single Dataset is accepted as well as a dict of datasets.
    datasets = [save_file] if hasattr(save_file, 'data_vars') else list(save_file.values())

    nc_out = None
    for ds in datasets:
        if not save_var:
            filename = f'CMIP.{ds.source_id}.{ds.experiment_id}.nc'
            savepath = f'../../data/CMIP6/{ds.experiment_id}/{folder}'
            nc_out = _confirm_output_path(savepath, filename)
            _write_netcdf(ds, nc_out)
            continue

        # Loop-invariant: 'depth' is kept as a coordinate if any variable uses it.
        coordinates_to_keep = {'time', 'lat', 'lon'}
        if any('depth' in ds[name].dims for name in ds.variables):
            coordinates_to_keep.add('depth')

        for var in ds:
            # Dataset with only this variable (and its coordinates)
            ds_var = ds[[var]]
            ds_var = ds_var.set_coords(list(set(ds_var.variables).intersection(coordinates_to_keep)))

            if var == 'sftlf':
                # Land area fraction is stored in a shared landmask folder without experiment_id.
                savepath = f'../../data/CMIP6/landmask/{folder}/{var}/'
                filename = f'CMIP.{ds_var.source_id}.{var}.nc'
            else:
                savepath = f'../../data/CMIP6/{ds_var.experiment_id}/{folder}/{var}/'
                filename = f'CMIP.{ds_var.source_id}.{ds_var.experiment_id}.{var}.nc'

            nc_out = _confirm_output_path(savepath, filename)
            _write_netcdf(ds_var, nc_out)

    return nc_out

### 1. Open the DKRZ catalog and save the CMIP6 catalog

In [20]:
# NOTE(review): this cell reads `cat`, which is only created in the following cell.
# The recorded output is a TypeError from intake-esm ('esmcol_obj'), presumably an
# intake / intake-esm version mismatch -- TODO confirm.
cat.dkrz_cmip6_disk

TypeError: esm_datastore.__init__() missing 1 required positional argument: 'esmcol_obj'

In [None]:
# ----Open dkrz catalog----
# Open the DKRZ intake catalog from its URL.
cat = intake.open_catalog(["https://dkrz.de/s/intake"])

# ----Save CMIP6 catalog----
# CMIP6 disk catalog; it is filtered below with cat_cmip6.search(**attrs).
cat_cmip6 = cat.dkrz_cmip6_disk

### 2. Load important hydroecological data


#### Attributes
| 'source_id' | 'member_id' | 'variable_id' |
|:-----------:|:-----------:|:-----------:|
| 'BCC-CSM2-MR', 'CESM2', 'CNRM-CM6-1-HR', 'NorESM2-MM', 'SAM0-UNICON', 'TaiESM1'  | 'r1i1p1f1', 'r1i1p1f2' (for CNRM)   | 'pr', 'mrro', 'mrros', 'evspsbl', 'evspsblsoi', 'evspsblveg', 'tran', 'mrso', 'mrsos', 'mrsol', 'huss', 'hurs', 'lai', 'gpp', 'npp' | 

In [None]:
# ----Define attributes----
# Search facets passed to cat_cmip6.search(**attrs): each key is a catalog column,
# each value the accepted value(s). Commented-out entries are alternative selections.
attrs = dict(
    
    #experiment_id = "historical", #ssp126 historical
    #member_id = ['r1i1p1f1'], #,'r1i1p1f2', 'r1i1p2f1', 'r1i1p1f3'],
    #institution_id = ["AWI"],
    #[TaiESM1, BCC-CSM2-MR,  CanESM5, CNRM-CM6-1, CNRM-ESM2-1, IPSL-CM6A-LR, UKESM1-0-LL, MPI-ESM1-2-LR, CESM2-WACCM, NorESM2-MM]
   source_id = ['BCC-CSM2-MR'], #'IPSL-CM6A-LR'], #, 'TaiESM1', 'AWI-ESM-1-1-LR', 'BCC-CSM2-MR', 'BCC-ESM1', 'CanESM5', 'CNRM-CM6-1', 'CNRM-CM6-1-HR', 'CNRM-ESM2-1','UKESM1-0-LL', 'CESM2', 'CESM2-FV2', 'CESM2-WACCM', 'CESM2-WACCM-FV2', 'MPI-ESM1-2-LR', 'NorESM2-MM'], 
    table_id =['Amon', 'Lmon', 'Emon'], #'Amon', 'Lmon', 'Emon'
    variable_id=[
              #  'tas',
                'ps', #surface pressure
              #   'pr', # CESM2 has problems loading pr with other Amon data
              #  'mrro', 
              #  'mrros', 
              #  'evspsbl', 
              #  'evspsblsoi', 
              #  'evspsblveg', 
              #  'tran', 
              #  'mrso', 
              #  'mrsos', 
              # 'mrsol', 
              #  'huss', 
             # 'hurs',  # TaiESM1 has hurs only in daily resolution
              #  'lai', 
              # 'gpp', 
             #   'npp',
              # 'tsl'
        #'sftof'
        #'sftlf' # land area fraction
    ]
  #  ,version = ['v20200623', 'v20200624'] #TaiESM1 has two versions for gpp. I select the newer version.
)

In [None]:
# ----Save data selection----
# Filter the CMIP6 catalog with the facets defined in `attrs`.
selection = cat_cmip6.search(**attrs)
#selection = cat_cmip6.search(require_all_on=["source_id"], **attrs) #require_all_on defines that source ID must include all important variables

In [None]:
# ----Set properties of pandas tables ----
# Show full cell contents and all rows so the overview table is not truncated.
pd.set_option('display.max_colwidth', None) #pd.reset_option('display.max_colwidth')
pd.set_option('display.max_rows', None) #pd.reset_option('display.max_rows', None)

# ----Print table with different attributes of selected data----
# One row per institution/model/member/experiment, listing the variable_ids found.
selection.df.groupby(
    [
       # "grid_label",
        "institution_id",
        "source_id",
     #   "version",
        'member_id',
       # "time_range",
        'experiment_id',
    
      #   'table_id'
    #    'variable_id'
    ]
)['variable_id'].unique().apply(list).to_frame()

In [None]:
##### ========= Load selection in dictionary ========== (I always have to run the 'define attrs' cell again...)
# Open every entry of `selection` as an xarray dataset, running pre_preprocessing on each.
# NOTE(review): "use_cftime", "decode_times" and "consolidated" are xarray/zarr open
# options, not dask configuration keys. Setting them via dask.config.set most likely
# has no effect. If they are needed, pass them through intake-esm's open kwargs
# (e.g. `xarray_open_kwargs=` in recent intake-esm releases) -- TODO confirm against
# the installed version.
with dask.config.set(**{"use_cftime": True, "decode_times": True, 'consolidated': True}):
    ds_dict = selection.to_dataset_dict(preprocess=pre_preprocessing) 

In [14]:
# =========== Drop redundant coordinates and variables ================

# Define redundant coordinates and variables
drop_list = ['member_id','type','nbnd', 'bnds', 'height', 'depth', 'lat_bnds', 'lon_bnds', 'time_bnds', 'time_bounds', 'depth_bnds', 'sdepth_bounds', 'depth_bounds', 'hist_interval', 'axis_nbounds'] #depth is not dropped for datasets with variable mrsol or tsl

# Drop redundant coordinates and variables
ds_dict = drop_redundant(ds_dict, drop_list)

Dropped coordinate: member_id
Dropped coordinate: bnds
Dropped variable: lat_bnds
Dropped variable: lon_bnds
Dropped variable: time_bnds


In [15]:
# =========== Merge datasets with different table_id and same source_id ================
# After this call ds_dict is keyed by source_id instead of the intake dataset key.
ds_dict = merge_source_id_data(ds_dict)

Merging dataset 'CMIP.AWI-ESM-1-1-LR.historical.Lmon.gn' with source_id 'AWI-ESM-1-1-LR' and table_id 'Lmon'...
Dataset 'CMIP.AWI-ESM-1-1-LR.historical.Lmon.gn' ('Lmon') is saved in 'ds_dict'.


In [17]:
# Quick check of the distinct values of 'tran' in the first dataset of ds_dict.
# NOTE(review): 'tran' is currently commented out in `attrs`, so this raises KeyError
# unless it is selected again.
unique_values = ds_dict[list(ds_dict.keys())[0]]['tran'].to_series().unique()
print(unique_values)

[ 0.0000000e+00            nan -3.4177360e-06 ... -6.1483814e-07
 -4.0925994e-07 -2.3161016e-07]


### 3. Save data as netCDF

In [14]:
# =========== Store file and remove any former one ==========
# save_var defaults to True, so one file per variable is written; nc_out is the
# path returned by save_file (the last file written).
nc_out = save_file(ds_dict, folder='raw')

File with path: ../../data/CMIP6/historical/raw/sftof/CMIP.CESM2-WACCM.historical.sftof.nc saved


In [43]:
# =========== Check stored file ==============
# Re-open the last written file to inspect its contents.
xr.open_dataset(nc_out)