# CMIP6 Preprocess Data

**The following steps are included in this script:**

1. Load netCDF files
2. Regrid data to 1x1° format
3. Create consistent time coordinates
4. Convert units
5. Apply landmask
6. Compute Soil Moisture for top 1 and 2 meters by interpolation

Save and replace netcdf files

In [1]:
# ========== Packages ==========
import xarray as xr
import pandas as pd
import numpy as np
import xesmf as xe
import intake
import dask
import os

import matplotlib.pyplot as plt

%matplotlib inline

### 1. Load netCDF files

In [2]:
# ========= Define period, models and path ==============
# CMIP6 experiment to process; also used to build the input folder and file names
experiment_id = 'historical'
# Models to process. Other model names that have been used in this list:
# 'TaiESM1', 'AWI-ESM-1-1-LR', 'BCC-CSM2-MR', 'BCC-ESM1', 'CanESM5', 'CNRM-CM6-1',
# 'CNRM-CM6-1-HR', 'UKESM1-0-LL', 'CESM2', 'CESM2-FV2', 'CESM2-WACCM', 'NorESM2-MM'
source_id = ['CESM2-WACCM']
# Folder containing one raw netCDF file per model
savepath = f'../data/CMIP6/{experiment_id}/raw'

# ========= Use Dask to parallelize computations ==========
dask.config.set(scheduler='processes')

# ========= Create a helper function to open the dataset ========
def open_dataset(filename):
    """Open the netCDF file at `filename` with xarray and return the dataset."""
    ds = xr.open_dataset(filename)
    return ds

# ========= Create dictionary using a dictionary comprehension and Dask =======
# Builds {model name: dataset} from files named CMIP.<model>.<experiment_id>.nc.
# NOTE(review): open_dataset is a plain function (not wrapped in dask.delayed), so the
# files are opened while the dict is being built - confirm whether dask.compute is
# expected to parallelise anything here.
ds_dict, = dask.compute({model: open_dataset(os.path.join(savepath, f'CMIP.{model}.{experiment_id}.nc'))
                        for model in source_id})

In [16]:
# =========== Drop redundant coordinates and variables ================
# Requires drop_redundant from the "Functions" section to have been run first.

# Define redundant coordinates and variables
# ('depth' is listed here, but drop_redundant keeps it for datasets that contain mrsol or tsl)
drop_list = ['member_id','nbnd', 'bnds', 'height', 'depth', 'lat_bnds', 'lon_bnds', 'time_bnds', 'time_bounds', 'depth_bnds', 'sdepth_bounds', 'depth_bounds', 'hist_interval', 'axis_nbounds']

# Drop redundant coordinates and variables
ds_dict = drop_redundant(ds_dict, drop_list)

In [None]:
#ds_dict[list(ds_dict.keys())[0]]['tsl'] = ds_dict_tsl[list(ds_dict_tsl.keys())[0]].tsl

In [17]:
# ========= Have a look into the dictionary =======
print(list(ds_dict.keys()))
ds_dict[list(ds_dict.keys())[0]]

['CESM2-WACCM']


In [None]:
# ========= Have a look into the dictionary =======
print(list(ds_dict_tsl.keys()))
ds_dict_tsl[list(ds_dict_tsl.keys())[0]]

### 2. Regrid data to 1x1° format

In [None]:
# ========== Regridding ===========
ds_dict = regrid(ds_dict, method="patch") #"bilinear", "nearest_s2d", "patch", "conservative", "conservative_normed"

### 3. Convert units

In [50]:
# ========== Convert units ============
# Requires set_units from the "Functions" section to have been run first.

# Target unit per variable. A new unit conversion must first be implemented in set_units,
# which raises ValueError for a unit pair it does not know and for a variable that is
# missing from a dataset. Commented-out entries are variables that are currently not converted.
conv_units = {'pr': 'mm/day',
            'evspsbl': 'mm/day',
            #'evspsblsoi': 'mm/day', 
            #'evspsblveg': 'mm/day', 
            'mrro': 'mm/day', 
            #'mrros': 'mm/day',
            'gpp': 'gC/m²/day', 
            #'npp': 'gC/m²/day',
            'tran': 'mm/day'
            }

ds_dict = set_units(ds_dict, conv_units)

Unit of pr converted from kg/m²/s to mm/day.
Unit of evspsbl converted from kg/m²/s to mm/day.
Unit of mrro converted from kg/m²/s to mm/day.
Unit of gpp converted from kg/m²/s to gC/m²/day.
Unit of tran converted from kg/m²/s to mm/day.


### 4. Create consistent coordinates

In [None]:
# =========== Create consistent time coordinate ==========
# Requires consis_time from the "Functions" section to have been run first.

# Define reference dataset with desired time coordinate
# (the raw NorESM2-MM historical file is used as the time reference)
ref_ds = xr.open_dataset(f'../data/CMIP6/historical/raw/CMIP.NorESM2-MM.historical.nc')

# Apply time coordinate on dictionary: every dataset whose time coordinate differs
# from the reference gets the reference time coordinate assigned
ds_dict = consis_time(ds_dict, ref_ds)

### 5. Apply landmask

In [None]:
# ========== Create Landmask of each Dataset ========

In [38]:
landmask_path = def_landmask(ds_dict)

Time variable is already in the requested format
Grid of AWI-ESM-1-1-LR already 1x1°
Unique values: [-9223372036854775808                    0                    1
                    2                    3                    4
                    5                    6]
File  with path: ../data/CMIP6/historical/CMIP.historical.landmask_AWI-ESM-1-1-LR.nc removed


In [39]:
# ========== Check new landmask ===============
xr.open_dataarray(landmask_path)

In [None]:
# =========== Apply landmask =============

In [None]:
# Load the landmask belonging to the first model in ds_dict (selected via its source_id attribute).
# NOTE(review): def_landmask writes its file into the '<experiment>/landmask/' subfolder, while
# this path has no 'landmask/' component - confirm which location is current.
landmask = xr.open_dataarray(f'../data/CMIP6/historical/CMIP.historical.landmask_{ds_dict[list(ds_dict.keys())[0]].source_id}.nc')

In [None]:
# Apply the landmask to every dataset by multiplication and store the result back in ds_dict.
for i, (name, ds) in enumerate(ds_dict.items()):
    masked_ds = ds * landmask
    # Re-attach the dataset-level and per-variable attributes of the unmasked dataset
    masked_ds.attrs = ds.attrs
    for var in ds.variables:
        masked_ds[var].attrs = ds[var].attrs
    ds_dict[name] = masked_ds

In [None]:
ds_dict[list(ds_dict.keys())[0]]

In [27]:
# =========== Apply landmask =============

In [43]:
landmask = xr.open_dataarray(f'../data/CMIP6/historical/CMIP.historical.landmask_nan_1.nc')

In [44]:
# Apply the previously loaded `landmask` to every dataset by multiplication
# and store the result back in ds_dict.
for i, (name, ds) in enumerate(ds_dict.items()):
    masked_ds = ds * landmask
    # Re-attach the dataset-level and per-variable attributes of the unmasked dataset
    masked_ds.attrs = ds.attrs
    for var in ds.variables:
        masked_ds[var].attrs = ds[var].attrs
    ds_dict[name] = masked_ds

In [45]:
ds_dict[list(ds_dict.keys())[0]].tran.isel(time=1).plot()

NameError: name 'ds_dict_' is not defined

In [None]:
ds_dict[list(ds_dict.keys())[0]].tran.isel(time=1).plot()

### 6. Compute Soil Moisture/Temperature for top 1 and 2 meters by interpolation and isolate liquid soil moisture

In [None]:
# ============ Compute soil moisture/temperature for top 1 and 2 meters =============
# Choose the layered variable to interpolate and the names of the derived variables.
# Run this cell once with the 'tsl' names and once with the 'mrsol' names
# ('mrsol', 'mrsol1m', 'mrsol2m') if both are needed further below.
var = 'tsl'
var_1m = 'tsl1m'
var_2m = 'tsl2m'

for name, ds in ds_dict.items():
    # Orientation of the depth axis, read from the 'positive' attribute.
    # NOTE(review): `depth` is computed but not used by the interpolation below, so
    # datasets whose depth axis is not "down" are still interpolated on their original
    # depth coordinate - confirm whether such datasets occur.
    if ds["depth"].attrs["positive"] == "down":
        depth = ds["depth"]
    else:
        depth = -ds["depth"]

    # Linearly interpolate the layered variable at 1 m and at 2 m depth
    # (depth values are in the units of the dataset's depth coordinate).
    # The 2 m variable was previously interpolated at 1 m by mistake.
    for target_var, target_depth in ((var_1m, 1), (var_2m, 2)):
        ds[target_var] = ds[var].interp(depth=target_depth, method="linear")

In [None]:
# Get only 1 and 2m data
# For each target depth, build a profile consisting of all original layers of `var`
# that lie above the target depth plus the value interpolated at the target depth,
# and store it under var_1m / var_2m with its own depth dimension.
dict_1m = {}
dict_2m = {}

depth=1.0

for i, (name, ds) in enumerate(ds_dict.items()):
    # Deepest original depth value that is still shallower than the target depth
    max_depth_below = ds.depth.where(ds.depth < depth).max().compute().values

    # Give the interpolated field a length-1 'depth' dimension if it has none,
    # so that it can be concatenated along 'depth'
    if not 'depth' in ds_dict[name][var_1m].dims:
        ds1m = ds[var_1m].expand_dims({'depth': [depth]})
    else:
        ds1m  = ds[var_1m]

    # Layers up to max_depth_below followed by the interpolated layer
    dict_1m[name], = dask.compute(xr.concat([ds_dict[name][var].sel(depth=slice(None, max_depth_below)), 
                                    ds1m], dim='depth'))

for i, (name, ds) in enumerate(ds_dict.items()):
    # Rename the dimension so it does not clash with the dataset's full 'depth' axis
    dict_1m[name]=dict_1m[name].rename({'depth': 'depth_1m'})

    ds[var_1m] = dict_1m[name]

# Same procedure for the 2 m profile
depth=2.0

for i, (name, ds) in enumerate(ds_dict.items()):
    max_depth_below = ds.depth.where(ds.depth < depth).max().compute().values

    if not 'depth' in ds_dict[name][var_2m].dims:
        ds2m = ds[var_2m].expand_dims({'depth': [depth]})
    else:
        ds2m  = ds[var_2m]

    dict_2m[name], = dask.compute(xr.concat([ds_dict[name][var].sel(depth=slice(None, max_depth_below)), 
                                    ds2m], dim='depth'))

for i, (name, ds) in enumerate(ds_dict.items()):
    dict_2m[name]=dict_2m[name].rename({'depth': 'depth_2m'})

    ds[var_2m] = dict_2m[name]

In [None]:
# ============ First compute soil temperature for top 1 and 2 meters =============
# (the cells below read 'tsl1m' and 'tsl2m', so these must exist in the datasets)

# Set the freezing point of water in the soil (in Kelvin)
freezing_point = 273.15

In [None]:
# Create a boolean mask that is True where the 1 m soil temperature profile (tsl1m)
# is less than or equal to freezing_point, and store it as 'tsl1mfrozen'
for i, (name, ds) in enumerate(ds_dict.items()):
    frozen_soil_mask = ds["tsl1m"] <= freezing_point
    ds_dict[name]["tsl1mfrozen"]= frozen_soil_mask

In [None]:
# Same mask for the 2 m soil temperature profile (tsl2m), stored as 'tsl2mfrozen'
for i, (name, ds) in enumerate(ds_dict.items()):
    frozen_soil_mask = ds["tsl2m"] <= freezing_point
    ds_dict[name]["tsl2mfrozen"]= frozen_soil_mask

In [None]:
# Calculate the liquid soil moisture per layer:
# keep mrsol1m only where the soil is not flagged as frozen (other cells become NaN)
for i, (name, ds) in enumerate(ds_dict.items()):
    liquid_soil_moisture_per_layer = ds["mrsol1m"].where(~ds_dict[name]['tsl1mfrozen'])
    ds_dict[name]['mrsol1m_liquid'] = liquid_soil_moisture_per_layer

In [None]:
# Same for the 2 m profile: keep mrsol2m only where the soil is not flagged as frozen
for i, (name, ds) in enumerate(ds_dict.items()):
    liquid_soil_moisture_per_layer = ds["mrsol2m"].where(~ds_dict[name]['tsl2mfrozen'])
    ds_dict[name]['mrsol2m_liquid'] = liquid_soil_moisture_per_layer

In [None]:
# Compute cumulative liquid soil moisture 
# (sum of the liquid per-layer values over the 'depth_1m' dimension, stored as 'lmrso_1m')

for i, (name, ds) in enumerate(ds_dict.items()):
    ds_dict[name]['lmrso_1m'] = ds['mrsol1m_liquid'].sum(dim='depth_1m')

In [None]:
# Sum of the liquid per-layer values over the 'depth_2m' dimension, stored as 'lmrso_2m'
for i, (name, ds) in enumerate(ds_dict.items()):
    ds_dict[name]['lmrso_2m'] = ds['mrsol2m_liquid'].sum(dim='depth_2m')

In [None]:
# Describe the new variable lmrso_1m.
# Only the first dataset in ds_dict is annotated here.
ds_dict[list(ds_dict.keys())[0]].lmrso_1m.attrs = {'standard_name': 'mass_content_of_liquid_water_in_1m_soil_column',
                                               'long_name': 'Total Liquid Soil Moisture Content of 1 m Column',
                                                'comment': 'The mass per unit area  (summed over all soil layers until 1 m depth) of liquid water.',
                                                'units': 'kg/m²'
                                               }

In [None]:
# Describe the new variable lmrso_2m.
# Only the first dataset in ds_dict is annotated here.
ds_dict[list(ds_dict.keys())[0]].lmrso_2m.attrs = {'standard_name': 'mass_content_of_liquid_water_in_2m_soil_column',
                                               'long_name': 'Total Liquid Soil Moisture Content of 2 m Column',
                                                'comment': 'The mass per unit area  (summed over all soil layers until 2 m depth) of liquid water.',
                                                'units': 'kg/m²'
                                               }

In [None]:
ds_dict[list(ds_dict.keys())[0]]

In [None]:
# Remove the intermediate 1 m / 2 m helper variables and depth coordinates from the first
# dataset in ds_dict, keeping the derived lmrso_1m / lmrso_2m, then squeeze the result.
# Only the first dataset is cleaned up here.
ds_dict[list(ds_dict.keys())[0]] = ds_dict[list(ds_dict.keys())[0]].drop(['depth_1m', 'depth_2m', 'tsl1m', 'tsl2m', 'mrsol1m', 'mrsol2m', 'tsl1mfrozen', 'tsl2mfrozen', 'mrsol1m_liquid','mrsol2m_liquid']).squeeze()

### 7. Add new variables

In [None]:
# Water Use Efficiency
# Adds 'wue' = gpp / tran to every dataset.
# NOTE(review): no 'units' attribute is set and cells where tran is zero are not
# treated specially - confirm whether that is intended.
for i, (name, ds) in enumerate(ds_dict.items()):
    ds_dict[name]['wue'] = ds['gpp']/ds['tran']
    ds_dict[name]['wue'].attrs = {'long_name': 'Water Use Efficiency (GPP/Tr)'}
    


## Save netCDF files

TODO: Decide whether old files should be removed. Add a prompt that asks for `y` to delete the old file or `n` to keep a copy.

In [None]:
# Save each entry of land_dict as the generic landmask file.
# NOTE(review): land_dict is not defined in the cells shown in this notebook - confirm where
# it comes from. The file name does not depend on `key`, so every entry is written to the
# same path and only the last one remains. This cell also rebinds the global `savepath`
# that the loading cell at the top sets.
for key in land_dict.keys():
    ds_in = land_dict[key]
    filename = f'CMIP.historical.landmask_nan_1.nc'
    savepath = f'../data/CMIP6/historical'
    nc_out = os.path.join(savepath, filename)
    os.makedirs(savepath, exist_ok=True) 
    # Remove an existing file before writing the new one
    if os.path.exists(nc_out):
        os.remove(nc_out)
        print(f"File  with path: {nc_out} removed")
    # Save to netcdf file
    with dask.config.set(scheduler='threads'):
        ds_in.to_netcdf(nc_out)

In [None]:
nc_out = save_file(ds_dict, folder='preprocessed')

In [None]:
#test if data is correct
xr.open_dataset(nc_out)

# Functions

In [4]:
def regrid(ds_dict, method='bilinear'):
    """
    Regrid every dataset of a dictionary onto a common 1x1° lat/lon grid with xesmf.

    Datasets whose 'lat' and 'lon' coordinates already equal the target grid are kept
    as they are; all others are regridded and replaced in ``ds_dict``.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
                        and each value is the dataset itself.
        method (str): Interpolation method for xesmf, by default 'bilinear'. Other options are
                        "nearest_s2d", "patch", "conservative" or "conservative_normed".

    Returns:
        dict: The input dictionary, with regridded datasets stored under the same keys.

    Raises:
        ValueError: If a dataset has no 'lat' or no 'lon' coordinate.
    """
    # Target grid: latitudes -90..89 and longitudes -180..179 in steps of 1
    target_grid = xr.Dataset(
        {
            'lat': (['lat'], np.arange(-90, 90, 1), {'units': 'degrees_north'}),
            'lon': (['lon'], np.arange(-180, 180, 1), {'units': 'degrees_east'}),
        }
    )

    for name, ds in ds_dict.items():
        if 'lat' not in ds.coords or 'lon' not in ds.coords:
            raise ValueError(f"No lat and lon in dataset '{name}'.")

        # Nothing to do if the dataset already sits on the target grid
        on_target_grid = ds['lat'].equals(target_grid['lat']) and ds['lon'].equals(target_grid['lon'])
        if on_target_grid:
            ds_dict[name] = ds
            print(f"Grid of {name} already 1x1°")
            continue

        regridder = xe.Regridder(ds, target_grid, method, ignore_degenerate=True, periodic=True)
        regridded = regridder(ds, keep_attrs=True)

        # Carry over the dataset attributes and record the method that was used
        regridded.attrs.update(ds.attrs)
        regridded.attrs["Regridding method"] = method

        ds_dict[name] = regridded

    return ds_dict

In [5]:
def consis_time(ds_dict, ref_ds):
    """
    Give all datasets the time coordinate of a reference dataset.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
                        and each value is the dataset itself. The datasets are modified in place.
        ref_ds (xarray): A xarray dataset whose ``time`` is used as the common time coordinate.

    Returns:
        dict: The input dictionary; datasets whose time differed now carry the reference time.
    """
    reference_time = ref_ds.time

    for ds in ds_dict.values():
        if ds['time'].equals(reference_time):
            print('Time variable is already in the requested format')
        else:
            # Overwrite the dataset's own time coordinate with the reference one
            ds['time'] = reference_time

    return ds_dict

In [6]:
def set_units(ds_dict, conv_units):
    """
    Convert the units of selected variables in every dataset of a dictionary.

    Supported conversions are 'kg/m²/s' to 'mm/day' and 'kg/m²/s' to 'gC/m²/day'.
    The datasets are modified in place; all attributes of a converted variable are
    kept and only its 'units' attribute is changed.

    Args:
        ds_dict (dict): A dictionary of xarray datasets.
        conv_units (dict): Maps variable names to the requested unit.

    Returns:
        dict: The input dictionary with converted variables.

    Raises:
        ValueError: If a variable is missing from a dataset or no conversion is
                    defined for its current and requested unit.
    """
    # Multipliers per (current unit, requested unit); they are applied one after
    # another so the arithmetic equals a chained multiplication
    factors_by_units = {
        ('kg/m²/s', 'gC/m²/day'): (1000, 60, 60, 24),
        ('kg/m²/s', 'mm/day'): (60, 60, 24),
    }

    for ds in ds_dict.values():
        for var, new_unit in conv_units.items():
            if var not in ds.variables:
                raise ValueError(f"No variable '{var}' in ds_dict.")

            old_unit = ds[var].units
            if new_unit == old_unit:
                print('Unit already in the requested format')
                continue

            factors = factors_by_units.get((old_unit, new_unit))
            if factors is None:
                raise ValueError(f"No unit conversion for variable '{var}' specified.")

            # Keep existing attributes and only modify the units attribute
            attrs = ds[var].attrs
            attrs['units'] = new_unit

            converted = ds[var]
            for factor in factors:
                converted = converted * factor
            ds[var] = converted
            ds[var].attrs = attrs

            print(f"Unit of {var} converted from {old_unit} to {ds[var].units}.")

    return ds_dict

In [7]:
def drop_coords(ds_dict, drop_coords_list):
    """
    Drop the listed coordinates from every dataset of a dictionary.

    Args:
        ds_dict (dict): Dictionary of xarray datasets.
        drop_coords_list (list): Names of the coordinates to drop where present.

    Returns:
        dict: The input dictionary holding the reduced (and squeezed) datasets.
    """
    for ds_name in ds_dict:
        reduced = ds_dict[ds_name]
        print('Coordinates in dataset: ', list(reduced.coords))

        for coord in drop_coords_list:
            if coord not in reduced.coords:
                continue
            reduced = reduced.drop(coord).squeeze()
            print(f'Dropped coordinate: {coord}')

        # Update the dictionary with the modified dataset
        ds_dict[ds_name] = reduced

    return ds_dict

In [8]:
def _write_dataset(ds_in, folder):
    """Write one dataset to ../data/CMIP6/<experiment_id>/<folder>/ and return the file path."""
    filename = f'CMIP.{ds_in.source_id}.{ds_in.experiment_id}.nc'
    savepath = f'../data/CMIP6/{ds_in.experiment_id}/{folder}'
    nc_out = os.path.join(savepath, filename)
    os.makedirs(savepath, exist_ok=True)

    # Remove an existing file first so the new file replaces it
    if os.path.exists(nc_out):
        os.remove(nc_out)
        print(f"File  with path: {nc_out} removed")

    # Save to netcdf file
    with dask.config.set(scheduler='threads'):
        ds_in.to_netcdf(nc_out)

    return nc_out


def save_file(save_file, folder):
    """
    Save files as netCDF.

    Each dataset is written to
    ``../data/CMIP6/<experiment_id>/<folder>/CMIP.<source_id>.<experiment_id>.nc``,
    where ``source_id`` and ``experiment_id`` are read from the dataset. An existing
    file at that path is removed before writing.

    Args:
        save_file (dict or dataset): Dictionary of xarray datasets or a single dataset.
        folder (string): Name of folder data is saved in.

    Returns:
        nc_out: Path where the data was saved (the last written file for a dictionary).

    Raises:
        ValueError: If ``save_file`` is neither a dictionary nor an xarray dataset,
                    or if the dictionary is empty.
    """
    if isinstance(save_file, dict):
        datasets = list(save_file.values())
    elif isinstance(save_file, xr.Dataset):
        datasets = [save_file]
    else:
        raise ValueError(
            f"Invalid input type '{type(save_file).__name__}' specified; "
            "expected a dictionary of datasets or an xarray dataset."
        )

    if not datasets:
        raise ValueError("No datasets to save: the dictionary is empty.")

    nc_out = None
    for ds_in in datasets:
        nc_out = _write_dataset(ds_in, folder)

    return nc_out

#done = False
#while not done:
#    inp = input("Do we need to make a classlist first? (y/n):")
#    if inp.lower() in ["y", "n"]:
#        done = True
        #rest of your code

In [9]:
def plot_sm_profile(ds_depth, save_fig=False, xlim_bound=3, ylim_bound=1000):
    """
    Plot one soil moisture profile per dataset into a single figure.

    Each entry is drawn as a dashed line with scatter markers in the same colour,
    using the entry's 'depth' values on the x-axis.

    Args:
        ds_depth (dict): A dictionary mapping names to xarray objects holding the mean
                         soil water content per layer (mrsol) along 'depth'.
        save_fig (bool): If True, save the figure to
                         '../results/CMIP6/soil_moisture_profile.png'. Default is False.
        xlim_bound (float): Upper limit of the x-axis. Default is 3.
        ylim_bound (float): Upper limit of the y-axis. Default is 1000.
    """
    fig, ax = plt.subplots(figsize=(30, 15))

    plt.xlim(0, xlim_bound)
    plt.ylim(0, ylim_bound)

    # Marker size of the scatter points
    marker_size = 150

    for name, ds in ds_depth.items():
        profile = ds.squeeze()

        lines = ax.plot(profile['depth'], profile.variable, linestyle='--', label=f"{name}")

        # Draw the markers in the colour matplotlib picked for the line
        line_color = lines[0].get_color()
        profile.plot.scatter(x='depth', y='variable', s=marker_size, c=line_color, ax=ax, label=None)

    plt.legend()

    if save_fig:
        fig.savefig('../results/CMIP6/soil_moisture_profile.png', dpi=300)

In [10]:
def soil_moisture_profile(ds_dict, plot_fig=True, save_fig=False, xlim_bound=3, ylim_bound=1000):
    """
    Compute and optionally plot the mean soil moisture profile of each dataset.

    For every dataset, mrsol is averaged over 'time' and then over 'lon' and 'lat'
    (NaNs skipped, attributes kept). The result is keyed by the dataset's source_id.

    Args:
        ds_dict (dict): A dictionary of xarray datasets containing the soil water
                        content per layer (mrsol).
        plot_fig (bool): If True, plot the figure. Default is True.
        save_fig (bool): If True, save the figure to a file. Default is False.
                         plot_fig has to be True as well to save the figure.
        xlim_bound (float): A value to set the max for the x-axis. Default is 3.
        ylim_bound (float): A value to set the max for the y-axis. Default is 1000.

    Returns:
        dict: A dictionary mapping each dataset's source_id to its mean profile.
    """
    ds_depth = {}

    for ds in ds_dict.values():
        mean_time = ds.mrsol.mean("time", keep_attrs=True, skipna=True)
        ds_depth[ds.source_id] = mean_time.mean(("lon", "lat"), keep_attrs=True, skipna=True)

    if plot_fig:
        # Forward the axis limits as well; they were previously accepted but ignored
        plot_sm_profile(ds_depth, save_fig=save_fig, xlim_bound=xlim_bound, ylim_bound=ylim_bound)

    return ds_depth

In [11]:
def _attach_profile_to_depth(ds_dict, var_name, depth, new_dim):
    """Replace `var_name` in every dataset by the mrsol layers above `depth` plus `var_name` itself."""
    profiles = {}

    for name, ds in ds_dict.items():
        # Deepest original depth value that is still shallower than the target depth
        max_depth_below = ds.depth.where(ds.depth < depth).max().compute().values

        layer = ds[var_name]
        if 'depth' not in layer.dims:
            # Give the layer a length-1 'depth' dimension so it can be concatenated
            layer = layer.expand_dims({'depth': [depth]})

        upper_layers = ds.mrsol.sel(depth=slice(None, max_depth_below))
        profiles[name], = dask.compute(xr.concat([upper_layers, layer], dim='depth'))

    for name, ds in ds_dict.items():
        profiles[name] = profiles[name].rename({'depth': new_dim})
        ds[var_name] = profiles[name]


def ms_1_and_2m(ds_dict):
    """
    Build soil moisture profiles down to 1 m and 2 m depth for every dataset.

    'mrsol100cm' / 'mrsol200cm' are first renamed to 'mrsol1m' / 'mrsol2m' where present.
    Then 'mrsol1m' and 'mrsol2m' are replaced by the mrsol layers above 1 m (2 m)
    concatenated with the existing 1 m (2 m) value, along the new dimensions
    'depth_1m' and 'depth_2m'. The datasets are modified in place.

    Args:
        ds_dict (dict): Dictionary of xarray datasets containing mrsol, depth and the
                        1 m / 2 m soil moisture variables.

    Returns:
        None
    """
    renames = (('mrsol100cm', 'mrsol1m'), ('mrsol200cm', 'mrsol2m'))
    for old_name, new_name in renames:
        for name, ds in ds_dict.items():
            if old_name in ds:
                ds[new_name] = ds[old_name]
                ds_dict[name] = ds.drop(old_name)

    _attach_profile_to_depth(ds_dict, 'mrsol1m', 1.0, 'depth_1m')
    _attach_profile_to_depth(ds_dict, 'mrsol2m', 2.0, 'depth_2m')

    return

In [12]:
def plot_diff_mrsol(ds_dict, time_idx=100, lat_idx=90, lon_idx=200):
    """
    Plot mrsol, mrsol1m and mrsol2m of one grid cell against depth and save the figure.

    Only the first dataset in ``ds_dict`` is plotted. The figure is saved to
    '../results/CMIP6/<name>_mrsol_+_1m_+_2m_time<t>_lat<lat>_lon<lon>.png', where
    <name> is the key of that dataset. The file name previously always said 'TaiESM1'.

    Args:
        ds_dict (dict): Dictionary of xarray datasets containing mrsol, mrsol1m and mrsol2m.
        time_idx (int): Integer index along 'time' of the plotted time step. Default is 100.
        lat_idx (int): Integer index along 'lat' of the plotted grid cell. Default is 90.
        lon_idx (int): Integer index along 'lon' of the plotted grid cell. Default is 200.
    """
    fig, ax = plt.subplots(figsize=(30, 15))

    plt.xlim(0, 10)

    # Define the marker size for the plot
    marker_size = 150
    name = list(ds_dict.keys())[0]
    ds = ds_dict[name]

    # Variable to plot together with the name of its depth dimension
    profiles = (('mrsol', 'depth'), ('mrsol1m', 'depth_1m'), ('mrsol2m', 'depth_2m'))

    for var_name, depth_dim in profiles:
        data_to_plot = ds[var_name].isel(time=time_idx, lat=lat_idx, lon=lon_idx).squeeze()

        # Label each line with its variable so the legend entries can be told apart
        data_lines = ax.plot(data_to_plot[depth_dim], data_to_plot, linestyle='--', label=f"{name} {var_name}")

        # Draw the markers in the colour matplotlib picked for the line
        data_color = data_lines[0].get_color()
        data_to_plot.plot.scatter(x=depth_dim, y='variable', s=marker_size, c=data_color, ax=ax, label=None)

    plt.legend()

    fig.savefig(f'../results/CMIP6/{name}_mrsol_+_1m_+_2m_time{time_idx}_lat{lat_idx}_lon{lon_idx}.png', dpi=300)

In [13]:
def find_first_datapoint(ds_dict, variable):
    """
    Find and print the first time step at which a variable holds actual values.

    A time step counts as having actual values if at least one lat/lon cell is both
    non-NaN and non-zero. Only the first dataset in ``ds_dict`` is inspected.
    Both conditions are checked on the same cell; looking for the first non-NaN and
    the first non-zero time step separately can point to a step that holds only zeros.

    Args:
        ds_dict (dict): Dictionary of xarray datasets.
        variable (str): Name of the variable to inspect. Presumably it has the
                        dimensions time, lat and lon only - TODO confirm for layered
                        variables.

    Returns:
        tuple or None: (time index, time value) of the first time step with actual
        values, or None if no such time step exists.
    """
    data = ds_dict[list(ds_dict.keys())[0]][variable]

    # True for every time step with at least one non-NaN, non-zero cell
    has_values = (data.notnull() & (data != 0)).any(dim=["lat", "lon"])

    if not bool(has_values.any()):
        # argmax would report index 0 here, which would be misleading
        print("No time step with actual values found")
        return None

    # argmax on a boolean array returns the position of the first True
    first_actual_values_time_index = int(has_values.argmax().values)

    first_actual_values_time = data.time.isel(time=first_actual_values_time_index).values

    print("First time index with actual values:", first_actual_values_time_index)
    print("First time with actual values:", first_actual_values_time)

    return first_actual_values_time_index, first_actual_values_time

In [14]:
def def_landmask(ds_dict):
    """
    Create a landmask from the lai variable of the first dataset and save it as netCDF.

    The datasets first get the time coordinate of the NorESM2-MM reference file. The lai
    field of the first dataset is regridded, cast to int and turned into a mask that is
    1 where the cast value is >= 0 and NaN elsewhere. The mask is written to
    '../data/CMIP6/<experiment_id>/landmask/CMIP.<experiment_id>.landmask_<source_id>.nc',
    replacing an existing file.

    Args:
        ds_dict (dict): Dictionary of xarray datasets; the first entry must contain lai.

    Returns:
        str: Path of the saved landmask file.
    """
    # Set the time coordinate to the same as the other data sets, using the
    # reference dataset with the desired time coordinate
    ref_ds = xr.open_dataset('../data/CMIP6/historical/raw/CMIP.NorESM2-MM.historical.nc')
    ds_dict = consis_time(ds_dict, ref_ds)

    first_key = list(ds_dict)[0]
    first_ds = ds_dict[first_key]

    # Use lai as land variable (check beforehand that it has values at all land points)
    regridded = regrid({first_key: first_ds.lai})
    landmask = regridded[list(regridded)[0]].astype(int)

    # NOTE(review): the recorded output of this notebook lists -9223372036854775808 among
    # the unique values, presumably NaN cells after the cast to int - confirm that this
    # also holds on the platform the notebook is run on, since the mask relies on it.
    print("Unique values:", np.unique(landmask.isel(time=0).values))

    # Keep values greater than or equal to zero; negative values become NaN
    landmask = landmask.where(landmask >= 0)

    # Set every remaining (non-NaN) value to 1
    landmask = landmask.where(landmask.isnull(), 1)

    # Save landmask, replacing an existing file
    experiment = first_ds.experiment_id
    filename = f'CMIP.{experiment}.landmask_{first_ds.source_id}.nc'
    savepath = f'../data/CMIP6/{experiment}/landmask/'
    nc_out = os.path.join(savepath, filename)
    os.makedirs(savepath, exist_ok=True)
    if os.path.exists(nc_out):
        os.remove(nc_out)
        print(f"File  with path: {nc_out} removed")
    landmask.to_netcdf(nc_out)

    return nc_out

In [15]:
def drop_redundant(ds_dict, drop_list):
    """
    Remove redundant coordinates and variables from datasets in a dictionary.

    A coordinate named 'sdepth' is renamed to 'depth' first (dropping an existing
    'depth' variable if 'depth' is a dimension). For datasets that contain 'mrsol' or
    'tsl', 'depth' is kept even if it is listed in ``drop_list``. The decision is made
    per dataset and ``drop_list`` itself is left unchanged.

    Parameters:
    ds_dict (dict): Dictionary containing dataset names as keys and xarray.Dataset objects as values.
    drop_list (list): List of redundant coordinate or variable names to be removed from the datasets.

    Returns:
    dict: Dictionary with the same keys as the input ds_dict and modified xarray.Dataset objects with redundant elements removed.
    """
    for ds_name, ds_data in ds_dict.items():

        if 'sdepth' in ds_data.coords:
            if 'depth' in ds_data.dims:
                ds_data = ds_data.drop_vars('depth')
            ds_data = ds_data.rename({'sdepth': 'depth'})
            print(f'sdepth changed to depth for model {ds_data.source_id}')

        # Layered variables need their depth coordinate. Build a per-dataset list
        # instead of removing 'depth' from the caller's list.
        keep_depth = 'mrsol' in ds_data or 'tsl' in ds_data
        names_to_drop = [name for name in drop_list if not (keep_depth and name == 'depth')]

        for name in names_to_drop:
            if name in ds_data.coords:
                ds_data = ds_data.drop_vars(name).squeeze()
                print(f'Dropped coordinate: {name}')
            elif name in ds_data.variables:
                ds_data = ds_data.drop_vars(name).squeeze()
                print(f'Dropped variable: {name}')

        # Update the dictionary with the modified dataset (also when nothing was
        # dropped, so that a renamed sdepth is not lost)
        ds_dict[ds_name] = ds_data

    return ds_dict