## This script needs to run on a large-memory node (at least 100 GB)
Probably 200 GB are needed for the 30N-30S band.

With 72 CPUs it is a good idea to set chunks={"time": 1, "lon": 1800}.

In [5]:
import xarray as xr
import matplotlib.pyplot as plt
import metcalc
import numpy as np
import scipy.interpolate as scint
import cartopy.crs as ccrs

In [None]:
# Root directory of the experiments; the functions below read from and write to
# f"{WORK_DIR}/{expid}/...", i.e. one sub-directory per experiment id.
WORK_DIR = "/work/mh0066/m300577/perpetual_jan/"

In [6]:
def interp1d_np(data, x, xi):
    """Linearly interpolate a single 1-D profile onto new coordinate values.

    Parameters
    ----------
    data : array-like, 1-D
        Profile values given at the coordinates ``x``.
    x : array-like, 1-D
        Coordinates of ``data``. Must be monotonic; both ascending and
        descending order are accepted.
    xi : array-like or scalar
        Coordinates at which the profile is evaluated.

    Returns
    -------
    numpy.ndarray or float
        Interpolated values; NaN wherever ``xi`` lies outside the range
        spanned by ``x``.
    """
    data = np.asarray(data)
    x = np.asarray(x)
    # np.interp does not check that x is increasing and silently returns
    # nonsense otherwise, so flip a descending coordinate (and its data).
    if x.ndim == 1 and x.size > 1 and x[0] > x[-1]:
        x = x[::-1]
        data = data[::-1]
    return np.interp(xi, x, data, left=np.nan, right=np.nan)


def xr_vertical_in(da, da_p, plevels):
    """Interpolate ``da`` from its ``height`` dimension onto the levels ``plevels``.

    ``interp1d_np`` is applied column by column through ``xr.apply_ufunc``,
    using ``da_p`` (same ``height`` dimension as ``da``) as the coordinate of
    each column. The pattern follows
    https://github.com/pydata/xarray/issues/3931 and
    http://xarray.pydata.org/en/stable/examples/apply_ufunc_vectorize_1d.html

    Parameters
    ----------
    da : xarray.DataArray
        Field to interpolate; must have a ``height`` dimension.
    da_p : xarray.DataArray
        Coordinate values of ``da`` along ``height`` (pressure in the calls
        made from this file).
    plevels : array-like, 1-D
        Target levels; they become the ``plev`` coordinate of the result.

    Returns
    -------
    xarray.DataArray
        ``da`` with ``height`` replaced by ``plev``.
    """
    ufunc_options = dict(
        # one entry per positional argument handed to interp1d_np
        input_core_dims=[["height"], ["height"], ["plev"]],
        # the function returns a single array along the new dimension
        output_core_dims=[["plev"]],
        # loop over every non-core dimension
        vectorize=True,
        dask="parallelized",
        # one dtype per output
        output_dtypes=[da.dtype],
    )
    # positional arguments are given in the order interp1d_np expects them
    on_plevels = xr.apply_ufunc(interp1d_np, da, da_p, plevels, **ufunc_options)
    on_plevels["plev"] = plevels
    return on_plevels



In [9]:

def preproc(expid,
            var,
            stream,
            months=["04","05","06","07","08","09"],
            levels=np.arange(15.,73.,1.),
            plevels=np.array([1000,950,925,900,850,775,700,650,600,550,500,450,400,350,300,275,250,225,200,175,150,125,100])*100.,
            high_res=True,
            lat_band=20,
            chunks = {"time": 1},
            correct_dims=False
           ) :
    """Interpolate a 3-D model-level variable to pressure levels, month by month.

    For every month the variable and the matching ``pfull`` files are opened,
    cut to the latitude band ``[-lat_band, lat_band]``, interpolated from the
    ``height`` dimension onto ``plevels`` and written to two NetCDF files in
    ``{WORK_DIR}/{expid}/``: the full time series and its time mean
    (suffix ``_timemean``).

    Parameters
    ----------
    expid : str
        Experiment id; used as sub-directory name and file-name prefix.
    var : str
        Name of the variable inside the input files.
    stream : int or str
        Output stream number that is part of the input file name. Per the
        original note: 3 for ta and hus, 2 for "us" and va ("us" presumably
        means ua -- TODO confirm).
    months : sequence of str
        Two-digit month strings of the year 1979 to process.
    levels : array-like
        ``height`` labels that are kept; only applied when ``high_res`` is true.
    plevels : array-like
        Target levels for the interpolation (default values are the listed
        numbers times 100, presumably Pa -- TODO confirm against pfull).
    high_res : bool
        Selects the resolution tag of the input files ("01x01" if true,
        otherwise "1x1").
    lat_band : number
        Half-width of the latitude band; the selection uses
        ``slice(-lat_band, lat_band)`` (assumes an ascending ``lat``
        coordinate -- TODO confirm).
    chunks : dict
        Dask chunking handed to ``xr.open_mfdataset``.
    correct_dims : bool
        If true, overwrite ``lat``/``lon`` of the variable with those of
        ``pfull``; needed when the coordinates do not match exactly.

    Returns
    -------
    None
        Results are written to disk only.
    """
    res = "01x01" if high_res else "1x1"

    for mon in months :

        print("month: " + mon)

        var_files = f"{WORK_DIR}/{expid}/{expid}_atm_3d_{str(stream)}_ml_1979{mon}??T000000Z_{var}_{res}.nc"

        print(var_files)

        # NOTE(review): this pattern has a double underscore after expid, unlike
        # var_files above -- confirm that it matches the file names on disk.
        p_files = f"{WORK_DIR}/{expid}/{expid}__atm_3d_1_ml_1979{mon}??T000000Z_pfull_{res}.nc"

        # Same path as the original string concatenation, without the extension.
        out_base = f"{WORK_DIR}/{expid}/{expid}_{var}_{lat_band}N-{lat_band}S_1979{mon}"

        # Open both datasets as context managers so that their file handles are
        # released at the end of every month instead of piling up over the loop.
        with xr.open_mfdataset(var_files, concat_dim="time", parallel=True, combine="nested", chunks=chunks) as var_ds, \
             xr.open_mfdataset(p_files, concat_dim="time", parallel=True, combine="nested", chunks=chunks) as p_ds:

            print(var_ds)

            var_tropics = var_ds[var].sel(lat=slice(-lat_band, lat_band))
            p_tropics = p_ds.pfull.sel(lat=slice(-lat_band, lat_band))

            if high_res :
                var_tropics = var_tropics.sel(height=levels)
                p_tropics = p_tropics.sel(height=levels)

            # Nothing is loaded yet; the selections above stay lazy.
            print("load data, select latbands and heights...")

            if correct_dims : # this is needed if for some reason the coordinates do not match perfectly...
                var_tropics["lat"] = p_tropics.lat
                var_tropics["lon"] = p_tropics.lon

            print("vertical interpolation...")

            # compute() pulls the whole interpolated field into memory; this is
            # the memory-hungry step of the function.
            var_int = xr_vertical_in(var_tropics, p_tropics, plevels).to_dataset(name=var).compute()
            var_int.to_netcdf(f"{out_base}.nc")

            print("time mean...")
            var_int.mean("time").to_netcdf(f"{out_base}_timemean.nc")


In [11]:

def preproc_2d(expid,
            var,
            months=["04","05","06","07","08","09"],
            lat_band=20
            ) :
    """Cut a 2-D variable to a latitude band and store it, month by month.

    For every month all files matching
    ``{expid}_atm_2d_ml_1979{mon}??T000000Z_*{var}*.nc`` in
    ``{WORK_DIR}/{expid}/`` are concatenated along ``time``, restricted to
    ``[-lat_band, lat_band]`` and written to two NetCDF files in the same
    directory: the full time series and its time mean (suffix ``_timemean``).

    Parameters
    ----------
    expid : str
        Experiment id; used as sub-directory name and file-name prefix.
    var : str
        Name of the variable inside the input files (also part of the
        file-name pattern).
    months : sequence of str
        Two-digit month strings of the year 1979 to process.
    lat_band : number
        Half-width of the latitude band; the selection uses
        ``slice(-lat_band, lat_band)`` (assumes an ascending ``lat``
        coordinate -- TODO confirm).

    Returns
    -------
    None
        Results are written to disk only.
    """
    for mon in months :

        print("month: " + mon)

        var_files = f"{WORK_DIR}/{expid}/{expid}_atm_2d_ml_1979{mon}??T000000Z_*{var}*.nc"

        # Same path as the original string concatenation, without the extension.
        out_base = f"{WORK_DIR}/{expid}/{expid}_{var}_{lat_band}N-{lat_band}S_1979{mon}"

        # Context manager closes the input files once the data are in memory,
        # so file handles do not accumulate over the month loop.
        with xr.open_mfdataset(var_files, concat_dim="time", parallel=True, combine="nested") as var_ds:

            lazy_band = var_ds[var].sel(lat=slice(-lat_band, lat_band))

            print("load data, select latbands and heights...")

            var_tropics = lazy_band.compute()

        var_tropics.to_netcdf(f"{out_base}.nc")

        print("time mean...")
        var_tropics.mean("time").to_netcdf(f"{out_base}_timemean.nc")


!!

For 3d:

- pass `stream` = 1, 2 or 3
- use preproc()

For 3d tendencies:

- use preproc_tend()

For 2d:

- use preproc_2d()

!!



In [12]:
# Two-digit month strings ("02" to "09") handed to the preprocessing functions via `months`.
all_months = ["02","03","04","05","06","07","08","09"]

In [14]:
# Extract "prw" of experiment dap0013-dc for latitudes -20..20, for every month in all_months.
preproc_2d("dap0013-dc","prw",lat_band=20,months=all_months)

month: 02
load data, select latbands and heights...
time mean...
month: 03
load data, select latbands and heights...
time mean...
month: 04
load data, select latbands and heights...
time mean...
month: 05
load data, select latbands and heights...
time mean...
month: 06
load data, select latbands and heights...
time mean...
month: 07
load data, select latbands and heights...
time mean...
month: 08
load data, select latbands and heights...
time mean...
month: 09
load data, select latbands and heights...
time mean...
