## Import Packages

In [1]:
import glob
import warnings
import numpy as np
import xarray as xr
from xgcm import Grid
from datetime import datetime
warnings.filterwarnings('ignore')

## User-Defined Fields

In [3]:
# Provenance: written into the global `history` attribute of every output file.
AUTHOR   = 'Savannah L. Ferretti'
EMAIL    = 'savannah.ferretti@uci.edu'

# Where raw model files are read from and where processed files are written to.
FILEDIR  = '/ocean/projects/atm200007p/sferrett/Repos/monsoon-pr/data/raw/models'
SAVEDIR  = '/ocean/projects/atm200007p/sferrett/Repos/monsoon-pr/data/interim'

# Time selection: calendar years 2000 through 2014, months June through August.
YEARS    = list(range(2000,2015))
MONTHS   = [6,7,8]

# Spatial selection passed to `subset` as (min, max) slice bounds.
LATRANGE = (0.,30.)
LONRANGE = (50.,90.)

# Target levels for the vertical interpolation (labelled hPa in the output files).
LEVS     = [
    500.,550.,600.,650.,700.,750.,775.,800.,
    825.,850.,875.,900.,925.,950.,975.,1000.,
]

# Models to process. Every entry is currently commented out, so the processing
# loop does nothing until at least one name is uncommented.
MODELS   = [
    # 'AWI-ESM-1-1-LR',
    # 'BCC-CSM2-MR',
    # 'CESM2',
    # 'CMCC-CM2-SR5',
    # 'CMCC-ESM2',
    # 'CanESM5',
    # 'FGOALS-g3',
    # 'GISS-E2-1-G',
    # 'IITM-ESM',
    # 'MIROC-ES2L',
    # 'MIROC6',
    # 'MPI-ESM-1-2-HAM',
    # 'MPI-ESM1-2-HR',
    # 'MPI-ESM1-2-LR',
    # 'MRI-ESM2-0',
    # 'NESM3',
    # 'NorESM2-MM',
    # 'SAM0-UNICON',
    # 'TaiESM1',
]

## Functions

In [None]:
def load(model,varname,filedir=FILEDIR):
    """Open all raw files for one model/variable and harmonize the time step.

    Files matching ``{filedir}/{varname}*{model}*.nc`` are opened together with
    ``xr.open_mfdataset``. The dataset's ``table_id`` global attribute is then
    compared with the table expected for the variable:

    - ``'pr'`` with a table other than ``'3hr'``: timestamps are floored to
      3 hours and samples sharing a floored timestamp are averaged.
    - ``'ta'`` / ``'hus'`` with a table other than ``'6hrLev'``: timestamps are
      floored to 6 hours and the first sample of each bin is kept.

    Any other combination is returned as opened.

    Args:
        model: Model name, used as a substring in the file-name pattern.
        varname: Variable name, used as the file-name prefix.
        filedir: Directory that holds the raw NetCDF files.

    Returns:
        The opened (and possibly time-regridded) ``xarray.Dataset``.

    Raises:
        FileNotFoundError: If no file matches the pattern.
        KeyError: If the opened dataset has no ``table_id`` attribute.
    """
    pattern = f'{filedir}/{varname}*{model}*.nc'
    files   = sorted(glob.glob(pattern))
    if not files:
        # Fail with the offending pattern instead of open_mfdataset's generic
        # "no files to open"; FileNotFoundError is still an OSError.
        raise FileNotFoundError(f'No files for model {model!r}, variable {varname!r} match {pattern!r}')
    data = xr.open_mfdataset(files)
    frequency = data.attrs['table_id']
    if varname == 'pr' and frequency != '3hr':
        data.coords['time'] = data.time.dt.floor('3H')
        data = data.groupby('time').mean()
    elif varname in ('ta','hus') and frequency != '6hrLev':
        data.coords['time'] = data.time.dt.floor('6H')
        data = data.groupby('time').first()
    return data

In [None]:
def preprocess(data,shape):
    """Reduce a dataset to its standard dimensions and normalize its coordinates.

    Every dimension outside the expected set is dropped (together with the
    variables defined on it), the time coordinate is converted to a datetime
    index when it is not one already, the remaining coordinates are cast to
    float, and the result is sorted along and transposed to the expected
    dimension order.

    Args:
        data: Dataset exposing ``time``, ``lat``, ``lon`` (and ``lev`` for 4D).
        shape: ``'3D'`` for (time, lat, lon) or ``'4D'`` for (time, lat, lon, lev).

    Returns:
        The cleaned dataset.

    Raises:
        ValueError: If ``shape`` is neither ``'3D'`` nor ``'4D'``.
    """
    if shape == '3D':
        dims = ['time','lat','lon']
    elif shape == '4D':
        dims = ['time','lat','lon','lev']
    else:
        # Without this guard `dims` stays unbound and the failure surfaces
        # later as an UnboundLocalError that hides the real cause.
        raise ValueError(f"shape must be '3D' or '4D', got {shape!r}")
    data = data.drop_dims(set(data.dims)-{*dims})
    for dim in dims:
        if dim == 'time':
            # dtype kind 'M' is numpy datetime64; anything else goes through the
            # index's own converter (presumably a cftime calendar index - confirm).
            if data.coords[dim].dtype.kind != 'M':
                data.coords[dim] = data.indexes[dim].to_datetimeindex()
        else:
            data.coords[dim] = data.coords[dim].astype(float)
            # Longitude wrapping to [-180, 180) is intentionally disabled:
            # if dim == 'lon' and (data.coords[dim].min() >= 0 and data.coords[dim].max() <= 360):
            #     data.coords[dim] = ((data.coords[dim]+180)%360)-180
    data = data.sortby(dims).transpose(*dims)
    return data

def subset(data,years=YEARS,months=MONTHS,latrange=LATRANGE,lonrange=LONRANGE):
    """Keep only the requested years/months and the requested lat/lon box.

    Args:
        data: Object with ``time``, ``lat`` and ``lon`` coordinates.
        years: Calendar years to keep.
        months: Calendar months to keep.
        latrange: Arguments for the latitude ``slice`` (min, max).
        lonrange: Arguments for the longitude ``slice`` (min, max).

    Returns:
        The time- and space-subsetted object.
    """
    inyears  = data['time.year'].isin(years)
    inmonths = data['time.month'].isin(months)
    intime   = data.sel(time=inyears&inmonths)
    # Label-based slices: both bounds are inclusive.
    box = dict(lat=slice(*latrange),lon=slice(*lonrange))
    return intime.sel(**box)

def interpolate(data,varname,levs=LEVS):
    """Interpolate a model-level variable onto fixed target levels.

    The method depends on ``data.lev.attrs['standard_name']``:

    - ``'alevel'``: ``-lev`` is used as the vertical coordinate and the variable
      is interpolated along it with ``DataArray.interp``, extrapolating
      beyond the available range.
    - ``'atmosphere_hybrid_sigma_pressure_coordinate'`` or
      ``'atmosphere_sigma_coordinate'``: a pressure field is rebuilt from the
      coefficients present in the dataset (``a*p0 + b*ps``, ``ap + b*ps`` or
      ``ptop + lev*(ps - ptop)``), divided by 100, and the variable is
      remapped onto ``levs`` with ``xgcm.Grid.transform(method='log')``.

    The result is renamed ``'t'`` for ``'ta'`` and ``'q'`` for ``'hus'``.

    Args:
        data: Preprocessed 4D dataset holding ``varname`` and the coefficients.
        varname: Name of the variable to interpolate.
        levs: Target levels, in the units of the rebuilt vertical coordinate.

    Returns:
        The interpolated ``DataArray`` with a ``lev`` dimension equal to ``levs``.

    Raises:
        ValueError: If the vertical coordinate type is unsupported, or if a
            sigma/hybrid dataset carries none of ``p0``, ``ap`` or ``ptop``.
    """
    dims = ['time','lat','lon','lev']
    levtype = data.lev.attrs['standard_name']
    if levtype == 'alevel':
        # Shallow copy: the coordinate is replaced on the copy, so the caller's
        # dataset is untouched and a second call does not flip the sign again.
        data = data.copy()
        # NOTE(review): presumably -lev is pressure in the units of `levs` for
        # these models - confirm against the raw files.
        data['lev'] = -1*data.lev
        data = data.sortby(dims).transpose(*dims)
        vardata = {data[varname].name:([*data[varname].dims],data[varname].data)}
        coords  = {'time':data.time.data,'lat':data.lat.data,'lon':data.lon.data,'lev':data.lev.data}
        data = xr.Dataset(vardata,coords)
        # Use the caller's `levs`, not the module-level default.
        interped = data[varname].interp(lev=levs,kwargs={'fill_value':'extrapolate'})
    elif levtype in ('atmosphere_hybrid_sigma_pressure_coordinate','atmosphere_sigma_coordinate'):
        if 'p0' in data.variables:
            p = data.a*data.p0 + data.b*data.ps
        elif 'ap' in data.variables:
            p = data.ap + data.b*data.ps
        elif 'ptop' in data.variables:
            p = data.ptop + data.lev*(data.ps-data.ptop)
        else:
            raise ValueError("cannot rebuild pressure: dataset has none of 'p0', 'ap' or 'ptop'")
        # Same /100 conversion the notebook applies to `ps` before labelling it hPa.
        p = p/100
        p = p.sortby(dims).transpose(*dims)
        vardata = {data[varname].name:([*data[varname].dims],data[varname].data),'p':([*p.dims],p.data)}
        coords  = {'time':data.time.data,'lat':data.lat.data,'lon':data.lon.data,'lev':data.lev.data}
        data = xr.Dataset(vardata,coords)
        grid = Grid(data,coords={'Z':{'center':'lev'}},periodic=False)
        # The transform names the new axis after the target data ('p'), so it
        # is renamed back to 'lev'.
        interped = grid.transform(data[varname],'Z',np.array(levs),target_data=data.p,method='log',mask_edges=False).rename({'p':'lev'})
    else:
        raise ValueError(f'unsupported vertical coordinate standard_name: {levtype!r}')
    if varname == 'ta':
        interped.name = 't'
    elif varname == 'hus':
        interped.name = 'q'
    return interped

In [40]:
def dataset(data,longname,units,model,frequency,author=AUTHOR,email=EMAIL):
    """Wrap a DataArray in a fresh Dataset with standardized metadata.

    Only the array values and the ``time``/``lat``/``lon`` (plus ``lev`` when
    present) coordinate values are carried over; all attributes are rewritten.

    Args:
        data: Named DataArray to wrap.
        longname: ``long_name`` attribute for the variable.
        units: ``units`` attribute for the variable.
        model: Stored as the global ``source`` attribute.
        frequency: Stored as the global ``frequency`` attribute.
        author: Author name recorded in the global ``history`` attribute.
        email: Contact address recorded in the global ``history`` attribute.

    Returns:
        A single-variable ``xarray.Dataset``.
    """
    varname = data.name
    haslev  = 'lev' in data.dims
    coordnames = ['time','lat','lon','lev'] if haslev else ['time','lat','lon']
    coords = {name:data[name].data for name in coordnames}
    out = xr.Dataset({varname:([*data.dims],data.data)},coords)
    # Variable and coordinate metadata.
    out[varname].attrs = dict(long_name=longname,units=units)
    out.time.attrs = dict(long_name='Time')
    out.lat.attrs  = dict(long_name='Latitude',units='°N')
    out.lon.attrs  = dict(long_name='Longitude',units='°E')
    if haslev:
        out.lev.attrs = dict(long_name='Pressure level',units='hPa')
    # Global metadata, stamped with today's date.
    created = datetime.today().strftime("%Y-%m-%d")
    out.attrs = dict(source=model,frequency=frequency,
                     history=f'Created on {created} by {author} ({email})')
    return out
    
def save(data,model,savedir=SAVEDIR):
    """Compute a dataset and write it to ``{savedir}/{model}_{varname}.nc``.

    ``varname`` is the first key of the dataset. An existing file at the
    target path is overwritten (``mode='w'``).

    Args:
        data: Dataset to write.
        model: Model name used as the file-name prefix.
        savedir: Output directory.

    Returns:
        Whatever ``Dataset.to_netcdf`` returns for a path target.
    """
    varname  = [*data.keys()][0]
    filepath = f'{savedir}/{model}_{varname}.nc'
    computed = data.compute()
    return computed.to_netcdf(filepath,mode='w')

## Process & Save Variables

In [41]:
for model in MODELS:

    # Specific humidity: model levels -> target levels, then subset and write.
    husraw = preprocess(load(model,'hus'),shape='4D')
    q = interpolate(husraw,varname='hus')
    q = dataset(q,longname='Specific humidity',units='kg/kg',frequency='6-hourly',model=model)
    save(subset(q),model)
    del husraw,q

    # Air temperature: same pipeline as specific humidity.
    taraw = preprocess(load(model,'ta'),shape='4D')
    t = interpolate(taraw,varname='ta')
    t = dataset(t,longname='Air temperature',units='K',frequency='6-hourly',model=model)
    save(subset(t),model)
    del t

    # Surface pressure is taken from the temperature files and scaled by 1/100.
    ps = taraw.ps/100
    ps = dataset(ps,longname='Surface pressure',units='hPa',frequency='6-hourly',model=model)
    save(subset(ps),model)
    del taraw,ps

    # Precipitation: scale by 86400 and replace anything that is not >= 0 with 0.
    prraw = preprocess(load(model,'pr'),shape='3D')
    pr = prraw.pr*86400
    pr = pr.where(pr>=0,0)
    pr = dataset(pr,longname='Precipitation flux',units='mm/day',frequency='3-hourly mean',model=model)
    save(subset(pr),model)
    del prraw,pr