In [1]:
import warnings
# NOTE(review): this silences every warning category for the whole kernel
# session, including numpy RuntimeWarnings raised by NaN handling further down.
warnings.filterwarnings("ignore") 

In [2]:
import numpy as np
import xesmf as xe
import xarray as xr

In [3]:
import zarr
import gcsfs

import time

# File-system handle for Google Cloud Storage opened with token='anon'
# (no credentials); used further down to build zarr store mappers.
gcs = gcsfs.GCSFileSystem(token='anon')

In [4]:
# functions for calculating potential density (dens0)

def smow(t):
    '''
    Evaluate a fifth-order polynomial in temperature using Horner's scheme.

    The input temperature is multiplied by 1.00024 before evaluation.
    Presumably this is the EOS-80 pure-water reference density (name and
    leading coefficient suggest kg/m^3 of Standard Mean Ocean Water) -- confirm.

    t : scalar or array-like supporting arithmetic (float, ndarray, DataArray)
    Returns an object of the same kind as `t * 1.00024`.
    '''
    coeffs = (999.842594, 6.793952e-2, -9.095290e-3, 1.001685e-4, -1.120083e-6,
              6.536332e-9)
    T68 = t * 1.00024

    # Start from the highest-order coefficient and fold in the lower ones,
    # one multiplication by T68 per step.
    acc = coeffs[-1]
    for coef in coeffs[-2::-1]:
        acc = coef + acc * T68
    return acc

def dens0(s, t):
    '''
    Density computed from salinity `s` and temperature `t` as

        smow(t) + B(T68) * s + C(T68) * s * s**0.5 + d * s**2

    where T68 = t * 1.00024 and B, C are polynomials in T68.
    Presumably the EOS-80 sea-water density at surface pressure -- confirm.

    s, t : scalars or array-likes supporting arithmetic and broadcasting.
    '''
    T68 = t * 1.00024
    b_coef = (8.24493e-1, -4.0899e-3, 7.6438e-5, -8.2467e-7, 5.3875e-9)
    c_coef = (-5.72466e-3, 1.0227e-4, -1.6546e-6)
    d_coef = 4.8314e-4

    # Temperature-dependent factors of the s and s**1.5 terms (Horner form).
    b_poly = (b_coef[0] + (b_coef[1] + (b_coef[2] + (b_coef[3] + b_coef[4] * T68) * T68) *
              T68) * T68)
    c_poly = c_coef[0] + (c_coef[1] + c_coef[2] * T68) * T68

    # Terms are accumulated left to right in the same order as the formula.
    rho = smow(t) + b_poly * s
    rho = rho + c_poly * s * s ** 0.5
    rho = rho + d_coef * s ** 2
    return rho

In [5]:
# functions for calculating the mixed layer depth (MLD)
def func_mld(dens_diff, depths):
    '''
    Mixed layer depth for a single vertical profile.

    dens_diff : 1-D array, density minus reference density minus threshold
                (den - den10 - 0.03), ordered like `depths`
    depths    : 1-D array of level depths

    Returns
      * nan when the first level is missing or already at/above the threshold;
      * the depth where dens_diff crosses zero, linearly interpolated between
        the first level with dens_diff > 0 and the level just above it;
      * otherwise (threshold never exceeded) the depth of the last level
        before the first NaN, or the deepest level if there is no NaN.
    '''
    top = dens_diff[0]
    if np.isnan(top) or top >= 0:
        return np.nan

    exceed = np.flatnonzero(dens_diff > 0)
    if exceed.size > 0:
        # Bracket the zero crossing with the two neighbouring levels.
        below = exceed[0]
        above = below - 1
        return np.interp(0, dens_diff[above:below + 1], depths[above:below + 1])

    # Threshold never exceeded: fall back to the last valid level.
    missing = np.flatnonzero(np.isnan(dens_diff))
    first_missing = missing[0] if missing.size > 0 else len(depths)
    return depths[first_missing - 1]

def xr_func_mld(dens):
    '''
    Apply func_mld to every profile of a density DataArray.

    dens : xarray.DataArray with a "lev" dimension (depth); may be backed by
           numpy or dask.

    Returns a DataArray with the "lev" dimension removed, holding the mixed
    layer depth (NaN where func_mld returns NaN). The computation stays lazy
    for dask-backed input.
    '''
    # dask="parallelized" requires each core dimension to live in a single
    # chunk; rechunk dask-backed input along "lev" so apply_ufunc does not
    # raise. Numpy-backed input (chunks is None) is left untouched.
    if dens.chunks is not None:
        dens = dens.chunk({"lev": -1})

    dens10 = dens.interp(lev = 10, method = 'linear')  # density at 10m
    dens_diff = dens - dens10 - 0.03               # density differences 

    # func_mld returns NaN and interpolated depths, so the output must be a
    # floating dtype even when the "lev" coordinate is stored as integers.
    lev_dtype = dens_diff.lev.dtype
    if np.issubdtype(lev_dtype, np.floating):
        out_dtype = lev_dtype
    else:
        out_dtype = np.dtype(float)

    mld = xr.apply_ufunc(
        func_mld, 
        dens_diff,
        dens_diff.lev, 
        input_core_dims = [["lev"], ["lev"]], 
        vectorize = True,
        dask = "parallelized",
        output_dtypes = [out_dtype],
    )
    return mld

In [6]:
from intake import open_catalog
# Pangeo master intake catalog; `col` is the CMIP6 (GCS) collection that the
# following cells query with col.search(...).
url = "https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/master.yaml"
col = open_catalog(url).climate.cmip6_gcs()

In [7]:
models = ['CESM2-WACCM','GFDL-CM4','GFDL-ESM4','IPSL-CM6A-LR','GISS-E2-1-G','GISS-E2-1-G-CC','MIROC-ES2L',
          'NorCPM1', 'NorESM2-LM','MPI-ESM1-2-HR','MPI-ESM1-2-LR','UKESM1-0-LL','CNRM-ESM2-1','ACCESS-ESM1-5',
          'CanESM5-CanOE','CanESM5', 'EC-Earth3']

# models = ['GFDL-CM4']   # single-model list for quick checks

# Search terms shared by every query: monthly ocean salinity and potential
# temperature from the historical experiment.
query = dict(variable_id = ['so','thetao'], table_id = 'Omon', experiment_id = 'historical')

# Print the coordinates of thetao for each model so the grid layout
# (1-D lat/lon vs. curvilinear) can be inspected.
for model in models:
    print('{}:'.format(model))

    # Prefer the regridded ('gr') output, fall back to the native grid ('gn').
    dslist = col.search(source_id = model, grid_label = 'gr', **query)
    if len(list(dslist)) == 0:
        print('no gr')
        dslist = col.search(source_id = model, grid_label = 'gn', **query)

    keys = list(dslist)
    if len(keys) == 0:
        print('no data')
        continue

    ds = dslist[keys[0]](zarr_kwargs = {'consolidated': True, 'decode_times':True}).to_dask()
    if not ('so' in ds and 'thetao' in ds):
        print('no t/s data')
        continue

    # First ensemble member only, restricted to 1991-2010.
    ds = ds.isel(member_id = 0)
    ds = ds.reset_coords('member_id', drop = True)
    ds = ds.sel(time = slice('1991', '2010'))

    print(ds.thetao.coords)

CESM2-WACCM:
Coordinates:
  * time     (time) object 1991-01-15 12:00:00 ... 2010-12-15 12:00:00
  * lat      (lat) float64 -89.5 -88.5 -87.5 -86.5 -85.5 ... 86.5 87.5 88.5 89.5
  * lev      (lev) float64 0.0 10.0 20.0 30.0 ... 4e+03 4.5e+03 5e+03 5.5e+03
  * lon      (lon) float64 0.5 1.5 2.5 3.5 4.5 ... 355.5 356.5 357.5 358.5 359.5
GFDL-CM4:
Coordinates:
  * lat      (lat) float64 -89.5 -88.5 -87.5 -86.5 -85.5 ... 86.5 87.5 88.5 89.5
  * lev      (lev) float64 2.5 10.0 20.0 32.5 ... 5e+03 5.5e+03 6e+03 6.5e+03
  * lon      (lon) float64 0.5 1.5 2.5 3.5 4.5 ... 355.5 356.5 357.5 358.5 359.5
  * time     (time) object 1991-01-16 12:00:00 ... 2010-12-16 12:00:00
GFDL-ESM4:
Coordinates:
  * lat      (lat) float64 -89.5 -88.5 -87.5 -86.5 -85.5 ... 86.5 87.5 88.5 89.5
  * lev      (lev) float64 2.5 10.0 20.0 32.5 ... 5e+03 5.5e+03 6e+03 6.5e+03
  * lon      (lon) float64 0.5 1.5 2.5 3.5 4.5 ... 355.5 356.5 357.5 358.5 359.5
  * time     (time) object 1991-01-16 12:00:00 ... 2010-12-16 12:

In [8]:
# Target grid for regridding: 1-degree spacing, latitude -90..90 and
# longitude 0..360 (both end points included).
ds_out = xr.Dataset(
    {
        'lat': (['lat'], np.arange(-90, 91, 1.0)),
        'lon': (['lon'], np.arange(0, 361, 1.0)),
    }
)

In [14]:
models = ['CESM2-WACCM','GFDL-CM4','GFDL-ESM4','IPSL-CM6A-LR','GISS-E2-1-G','GISS-E2-1-G-CC','MIROC-ES2L',
          'NorCPM1', 'NorESM2-LM','MPI-ESM1-2-HR','MPI-ESM1-2-LR','UKESM1-0-LL','CNRM-ESM2-1','ACCESS-ESM1-5',
          'CanESM5-CanOE','CanESM5', 'EC-Earth3']

# Models that are subset directly with 1-D lat/lon selection (no regridding).
no_regrid_models = ('CESM2-WACCM', 'GFDL-CM4', 'GFDL-ESM4', 'GISS-E2-1-G', 'GISS-E2-1-G-CC')

# model name -> lazy DataArray of the monthly-climatology MLD, averaged over
# the 45-50N / 210-220E box. Nothing is computed until .load() is called.
mld_dict = {}

for model in models:
    if model == 'IPSL-CM6A-LR':
        # This model is read straight from its zarr stores instead of the
        # catalog; its depth dimension is called "olevel".
        mappert = gcs.get_mapper('gs://cmip6/CMIP6/CMIP/IPSL/IPSL-CM6A-LR/historical/r1i1p1f1/Omon/thetao/gn/v20180803')
        mappers = gcs.get_mapper('gs://cmip6/CMIP6/CMIP/IPSL/IPSL-CM6A-LR/historical/r1i1p1f1/Omon/so/gn/v20180803')
        
        dst = xr.open_zarr(mappert, consolidated = True, decode_times = True)
        dss = xr.open_zarr(mappers, consolidated = True, decode_times = True)
        
        dst = dst.thetao.sel(time = slice('1991', '2010'))
        dss = dss.so.sel(time = slice('1991', '2010'))
        
        dst = dst.rename({"olevel": "lev"})
        dss = dss.rename({"olevel": "lev"})

        # No combined dataset exists in this branch; take grid coordinates
        # from the temperature array itself.
        grid_src = dst
        
    else:
        # Prefer the regridded ('gr') output, fall back to the native grid ('gn').
        dslist = col.search(variable_id = ['so','thetao'], table_id = 'Omon', experiment_id = 'historical',
                            source_id = model, grid_label = 'gr')
        if len(list(dslist)) == 0:
            dslist = col.search(variable_id = ['so','thetao'], table_id = 'Omon', experiment_id = 'historical', 
                                source_id = model, grid_label = 'gn')

        # Skip models the catalog cannot provide instead of failing with an
        # IndexError / AttributeError part-way through the loop.
        keys = list(dslist)
        if len(keys) == 0:
            print('{}: no data'.format(model))
            continue
        
        ds = dslist[keys[0]](zarr_kwargs = {'consolidated': True, 'decode_times':True}).to_dask()
        if not ('so' in ds and 'thetao' in ds):
            print('{}: no t/s data'.format(model))
            continue

        # First ensemble member only, restricted to 1991-2010.
        ds = ds.isel(member_id = 0)
        ds = ds.reset_coords('member_id', drop = True)
        ds = ds.sel(time = slice('1991', '2010'))
            
        dst = ds.thetao
        dss = ds.so
        grid_src = ds
            
    if model in no_regrid_models:
        dst_np = dst.sel(lat=slice(45,50)).sel(lon=slice(210,220))
        dss_np = dss.sel(lat=slice(45,50)).sel(lon=slice(210,220))
    else:
        # Build the source-grid description from whichever coordinate names
        # this model uses. grid_src always belongs to the current model, so a
        # dataset left over from a previous iteration is never read.
        if 'nav_lat' in dst.coords:
            ds_in = xr.Dataset({"lat": grid_src.nav_lat, "lon": grid_src.nav_lon})
        elif 'latitude' in dst.coords:
            ds_in = xr.Dataset({"lat": grid_src.latitude, "lon": grid_src.longitude})
        else:
            ds_in = xr.Dataset({"lat": grid_src.lat, "lon": grid_src.lon})
        
        regrid = xe.Regridder(ds_in, ds_out, 'bilinear', periodic=True, ignore_degenerate=True)
        dst_gr = regrid(dst)
        dss_gr = regrid(dss)
        
        dst_np = dst_gr.sel(lat=slice(45,50)).sel(lon=slice(210,220)) 
        dss_np = dss_gr.sel(lat=slice(45,50)).sel(lon=slice(210,220))
    
    dens0_np = dens0(dss_np, dst_np)
    mld_np = xr_func_mld(dens0_np)
    # NOTE(review): unweighted mean over lat, then over lon. If some grid
    # points are NaN this mean-of-means differs from a joint mean over both
    # dimensions -- confirm this is intended.
    mld_np_avg = mld_np.mean('lat').mean('lon')
    
    # Monthly climatology (12 values) over the selected years.
    mld_np_season = mld_np_avg.groupby('time.month').mean('time')
    mld_dict[model] = mld_np_season

In [22]:
# Start a Dask Gateway cluster; the lazy MLD arrays built in the processing
# loop are computed once a client is connected.
from dask_gateway import Gateway
# from dask_gateway import GatewayCluster
gateway = Gateway()
# # cluster = GatewayCluster()
options = gateway.cluster_options()

options.worker_cores = 4 # cores per worker; 16 was also tried
options.worker_memory = 16 # memory per worker; 32 was also tried (units set by the gateway -- presumably GiB, confirm)

cluster = gateway.new_cluster(options)

# Fixed-size cluster; adaptive scaling is the commented-out alternative.
# cluster.adapt(minimum = 2, maximum = 150)
cluster.scale(100)

client = cluster.get_client() 
# Last expression of the cell: displays the cluster widget.
cluster

VBox(children=(HTML(value='<h2>GatewayCluster</h2>'), HBox(children=(HTML(value='\n<div>\n<style scoped>\n    …

In [24]:
import os

# open(..., 'w') does not create missing directories, so make sure the
# output folder exists before writing the first file.
os.makedirs('mld_seasonal', exist_ok = True)

# Compute each model's monthly MLD climatology and write it to a text file,
# one value per line.
for model in mld_dict:
    print(model)
    mld_season = mld_dict[model]
    mld_season = mld_season.load()     # triggers the actual computation
    savef = 'mld_seasonal/{}.txt'.format(model)
    with open(savef, 'w') as sf:
        for idata in mld_season.values:
            sf.write(str(idata) +"\n")
    # Pause between models -- presumably to let the cluster settle; confirm
    # whether this is still needed.
    time.sleep(20)

CESM2-WACCM
GFDL-CM4
GFDL-ESM4
IPSL-CM6A-LR
GISS-E2-1-G
GISS-E2-1-G-CC
MIROC-ES2L
NorCPM1
NorESM2-LM
MPI-ESM1-2-HR
MPI-ESM1-2-LR
UKESM1-0-LL
CNRM-ESM2-1
ACCESS-ESM1-5
CanESM5-CanOE
CanESM5
EC-Earth3


In [25]:
# Disconnect the client first, then shut the gateway cluster down.
client.close()
cluster.close()