## Import Packages

In [1]:
import intake
import warnings
import numpy as np
import xarray as xr
from xgcm import Grid
from datetime import datetime
from xmip.preprocessing import combined_preprocessing
warnings.filterwarnings('ignore')

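## User-Defined Fields

Set the provenance metadata (`AUTHOR`, `EMAIL`), the output directory (`SAVEDIR`), the time window (`YEARS`, `MONTHS`), the region (`LATRANGE`, `LONRANGE`), the target pressure levels in hPa (`LEVS`), and the CMIP6 models to process (`MODELS`). Every entry in `MODELS` is currently commented out, so the processing loop below does nothing until at least one model is uncommented.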

In [1]:
# Provenance metadata written into every output file's `history` attribute
AUTHOR   = 'Savannah L. Ferretti'
EMAIL    = 'savannah.ferretti@uci.edu'
# Directory the netCDF outputs are written to
SAVEDIR  = '/ocean/projects/atm200007p/sferrett/Repos/monsoon-pr/data/interim'
# Time window: years 2000-2014 inclusive, months June-August
YEARS    = list(range(2000,2015))
MONTHS   = [6,7,8]
# Region as (min, max) bounds: latitude in °N, longitude in °E
LATRANGE = (0.,30.) 
LONRANGE = (50.,90.)
# Target pressure levels (hPa) for vertical interpolation
LEVS     = [500.,550.,600.,650.,700.,750.,775.,800.,825.,850.,875.,900.,925.,950.,975.,1000.]
# CMIP6 source_ids to process; uncomment entries to include them.
# With every entry commented out, MODELS is empty and nothing is processed.
MODELS  = [
    # 'AWI-ESM-1-1-LR',
    # 'BCC-CSM2-MR',
    # 'CESM2',
    # 'CMCC-CM2-SR5',
    # 'CMCC-ESM2',
    # 'CanESM5',
    # 'FGOALS-g3',
    # 'GISS-E2-1-G',
    # 'IITM-ESM',
    # 'MIROC-ES2L',
    # 'MIROC6',
    # 'MPI-ESM-1-2-HAM',
    # 'MPI-ESM1-2-HR',
    # 'MPI-ESM1-2-LR',
    # 'MRI-ESM2-0',
    # 'NESM3',
    # 'NorESM2-MM',
    # 'SAM0-UNICON',
    # 'TaiESM1',
]

## Functions

In [None]:
def get_data_dict(subset):
    """Open every dataset in an intake-esm search result, without aggregation.

    Zarr stores are opened consolidated with cftime-decoded times, and each
    dataset is passed through xmip's `combined_preprocessing`.

    Returns the dict produced by `subset.to_dataset_dict`.
    """
    return subset.to_dataset_dict(
        aggregate=False,
        zarr_kwargs={'consolidated':True,'use_cftime':True},
        preprocess=combined_preprocessing,
    )

def load(model,varname):
    """Open one model's CMIP6 historical 6hrLev dataset for a single variable.

    Searches the CMIP6 catalog at the URL below for activity 'CMIP',
    experiment 'historical', table '6hrLev', native grid ('gn').

    Parameters
    ----------
    model : str
        CMIP6 source_id, e.g. 'CESM2'.
    varname : str
        CMIP6 variable_id, e.g. 'hus' or 'ta'.

    Returns
    -------
    The last dataset in the dict returned by `get_data_dict`.

    Raises
    ------
    ValueError
        If the catalog search yields no datasets.
    """
    url  = 'https://storage.googleapis.com/cmip6/cmip6-pgf-ingestion-test/catalog/catalog.json'
    catalog = intake.open_esm_datastore(url)
    query = dict(activity_id='CMIP',
                 experiment_id='historical',
                 source_id=model,
                 table_id='6hrLev',
                 variable_id=varname,
                 grid_label='gn')
    # Named to avoid shadowing the module-level `subset` function.
    search_result = catalog.search(**query)
    modeldict = get_data_dict(search_result)
    if not modeldict:
        raise ValueError(f'No CMIP historical 6hrLev data found for source_id={model!r}, variable_id={varname!r}')
    # With aggregate=False there may be several entries (presumably one per
    # matching store/member); the last one is returned.
    # NOTE(review): confirm which member is intended when several match.
    return list(modeldict.values())[-1]

In [3]:
def preprocess(data,shape):
    """Standardize a dataset's dimensions, coordinates, and ordering.

    Parameters
    ----------
    data : xarray.Dataset
        Dataset as returned by `load`, with 'time', 'y', 'x' (and, for 4D,
        'lev') dimensions.
    shape : {'3D', '4D'}
        '3D' keeps (time, y, x); '4D' keeps (time, y, x, lev).

    Returns
    -------
    xarray.Dataset
        Length-1 dims squeezed out, all other extra dims dropped, and the
        'lon'/'lat' variables dropped. Time is converted to datetime64 when it
        is not already. Non-time dims are cast to float. Everything is sorted
        ascending and transposed in the order above, and 'y'/'x' are renamed
        to 'lat'/'lon'.

    Raises
    ------
    ValueError
        If `shape` is not '3D' or '4D'.
    """
    dims_by_shape = {'3D':['time','y','x'],'4D':['time','y','x','lev']}
    if shape not in dims_by_shape:
        raise ValueError(f"shape must be '3D' or '4D', got {shape!r}")
    dims = dims_by_shape[shape]
    data = data.squeeze()
    # Drop every dimension (with the variables on it) that is not kept. Then
    # drop the existing 'lon'/'lat' variables; x/y are renamed to lon/lat below.
    data = data.drop_dims(set(data.dims)-set(dims)).drop_vars(['lon','lat'])
    for dim in dims:
        if dim == 'time':
            # Non-datetime64 time index (e.g. cftime) -> pandas datetime64 so
            # 'time.year' / 'time.month' work downstream.
            if data.coords[dim].dtype.kind != 'M':
                data.coords[dim] = data.indexes[dim].to_datetimeindex()
        else:
            data.coords[dim] = data.coords[dim].astype(float)
    data = data.sortby(dims).transpose(*dims)
    data = data.rename({'y':'lat','x':'lon'})
    return data

def subset(data,years=YEARS,months=MONTHS,latrange=LATRANGE,lonrange=LONRANGE):
    """Restrict `data` to the requested years, months, and lat/lon box.

    years, months : iterables of ints matched against time.year / time.month.
    latrange, lonrange : bounds unpacked into `slice(...)` for label selection.
        Assumes lat/lon are sorted ascending.
    """
    keep_time  = data['time.year'].isin(years) & data['time.month'].isin(months)
    lat_window = slice(*latrange)
    lon_window = slice(*lonrange)
    return data.sel(time=keep_time).sel(lat=lat_window,lon=lon_window)

def interpolate(data,varname,levs=None):
    """Interpolate a model-level variable onto fixed pressure levels.

    Parameters
    ----------
    data : xarray.Dataset
        Output of `preprocess(..., shape='4D')`. It must contain `varname`,
        'ps', and the vertical-coordinate terms its 'lev' standard_name needs:
        a/b/p0 or ap/b for hybrid sigma-pressure, ptop for sigma.
    varname : str
        Variable to interpolate, e.g. 'ta' or 'hus'.
    levs : sequence of float, optional
        Target pressure levels (hPa). Defaults to the module-level LEVS.

    Returns
    -------
    xarray.DataArray
        The variable on the target levels, with the level dimension named
        'lev'. It is named 't' for 'ta' and 'q' for 'hus'; other names are
        left unchanged.

    Raises
    ------
    ValueError
        If the 'lev' standard_name is not a supported vertical coordinate.
    """
    if levs is None:
        levs = LEVS
    levtype = data.lev.attrs['standard_name']
    # Build a 3-D pressure field; the /100 converts the model pressure to hPa.
    if levtype == 'atmosphere_hybrid_sigma_pressure_coordinate':
        if 'p0' in data.variables:
            p = (data.a*data.p0 + data.b*data.ps)/100
        else:
            p = (data.ap + data.b*data.ps)/100
    elif levtype == 'atmosphere_sigma_coordinate':
        p = (data.ptop + data.lev*(data.ps-data.ptop))/100
    elif levtype == 'alevel':
        # Broadcast the 1-D level values over this dataset's own time/lat/lon.
        # NOTE(review): the negated 'lev' values are used directly as pressure;
        # confirm their sign and units, since method='log' needs positive values.
        p = (-1*data.lev).expand_dims({'time':data.time,'lat':data.lat,'lon':data.lon},(1,2,3))
    else:
        raise ValueError(f'Unsupported vertical coordinate standard_name: {levtype!r}')
    dims = ['time','lat','lon','lev']
    p = p.sortby(dims).transpose(*dims)
    # Rebuild a minimal dataset holding only the variable, pressure, and coords.
    var = data[varname]
    vardata = {var.name:(list(var.dims),var.data),'p':(list(p.dims),p.data)}
    coords  = {dim:data[dim].data for dim in dims}
    data = xr.Dataset(vardata,coords)
    grid = Grid(data,coords={'Z':{'center':'lev'}},periodic=False)
    interped = grid.transform(data[varname],'Z',np.array(levs),target_data=data.p,method='log',mask_edges=False).rename({'p':'lev'})
    interped.name = {'ta':'t','hus':'q'}.get(varname,interped.name)
    return interped

In [None]:
def dataset(data,longname,units,model,frequency,author=AUTHOR,email=EMAIL):
    """Wrap a named DataArray in a Dataset with descriptive metadata.

    The Dataset has coordinates time/lat/lon, plus lev when `data` has a 'lev'
    dimension. The variable gets `long_name`/`units` attributes, and each
    coordinate gets a `long_name` (plus units for lat, lon, and lev). Global
    attributes record `source` (model), `frequency`, and a `history` line with
    today's date, author, and email.
    """
    name = data.name
    dim_names = ['time','lat','lon']
    if 'lev' in data.dims:
        dim_names.append('lev')
    ds = xr.Dataset(
        {name:([*data.dims],data.data)},
        {dim:data[dim].data for dim in dim_names},
    )
    ds[name].attrs = {'long_name':longname,'units':units}
    axis_attrs = {
        'time':{'long_name':'Time'},
        'lat':{'long_name':'Latitude','units':'°N'},
        'lon':{'long_name':'Longitude','units':'°E'},
        'lev':{'long_name':'Pressure level','units':'hPa'},
    }
    for dim in dim_names:
        ds[dim].attrs = axis_attrs[dim]
    stamp = datetime.today().strftime("%Y-%m-%d")
    ds.attrs = {'source':model,'frequency':frequency,
                'history':f'Created on {stamp} by {author} ({email})'}
    return ds
    
def save(data,model,savedir=SAVEDIR):
    """Compute `data` and write it to '{savedir}/{model}_{varname}.nc'.

    `varname` is the first data variable of the Dataset. Any existing file at
    that path is overwritten (mode='w'). Returns the result of `to_netcdf`.
    """
    varname = list(data)[0]
    path = f'{savedir}/{model}_{varname}.nc'
    loaded = data.compute()
    return loaded.to_netcdf(path,mode='w')

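## Process & Save Variables

For each model in `MODELS`, load 6-hourly historical specific humidity (`hus`) and air temperature (`ta`) from the CMIP6 cloud catalog. Interpolate both to the `LEVS` pressure levels and restrict them to `YEARS`, `MONTHS`, `LATRANGE`, and `LONRANGE`. Save each variable, plus surface pressure (`ps`, in hPa), as `{SAVEDIR}/{model}_{varname}.nc`, where `varname` is `q`, `t`, or `ps`.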

In [5]:
# For each model, write specific humidity (q), air temperature (t), and
# surface pressure (ps). Each dataset is restricted to YEARS/MONTHS and the
# lat/lon box right after preprocessing. The vertical interpolation works
# column by column, so subsetting first gives the same values while avoiding
# interpolating the full global, full-period record.
for model in MODELS:

    ds = subset(preprocess(load(model,'hus'),shape='4D'))
    q  = interpolate(ds,varname='hus')
    q  = dataset(q,longname='Specific humidity',units='kg/kg',frequency='6-hourly',model=model)
    save(q,model)
    del ds,q
    
    # The 'ta' dataset is kept after saving t because it also supplies ps.
    ds = subset(preprocess(load(model,'ta'),shape='4D'))
    t  = interpolate(ds,varname='ta')
    t  = dataset(t,longname='Air temperature',units='K',frequency='6-hourly',model=model)
    save(t,model)
    del t

    # Divided by 100 to match the 'hPa' units label (as in interpolate's pressure).
    ps = ds.ps/100
    ps = dataset(ps,longname='Surface pressure',units='hPa',frequency='6-hourly',model=model)
    save(ps,model)
    del ds,ps