In [1]:
import xcdat
import metpy.interpolate
import numpy as np
import xarray as xr
import xgcm
import cdms2
import cdutil
import yappi
from functools import partial

# Local python import, see `geocat.py`
from geocat import pressure_from_hybrid
from geocat import geocat_interp_hybrid_to_pressure

Compare the performance of different vertical regridding implementations.
1. GeoCAT; Uses `metpy.interpolate.interpolate_1d` with `xr.apply_ufunc` allowing for parallelized operations.
2. MetPy
3. xgcm
4. cdutil

In [2]:
ds = xcdat.open_dataset('T_185001_201312.nc')

cdat_ds = cdms2.open('T_185001_201312.nc')

In [3]:
ds_subset = ds.isel(time=slice(0, 48))

cdat_subset_T = cdat_ds('T', time=slice(0, 48))
cdat_subset_PS = cdat_ds('PS', time=slice(0, 48))
cdat_subset_hybm = cdat_ds('hybm', time=slice(0, 48))
cdat_subset_hyam = cdat_ds('hyam', time=slice(0, 48))

In [4]:
pressure = pressure_from_hybrid(ds_subset['PS'], ds_subset['hyam'], ds_subset['hybm'], p0=100000.)

ds_subset['pressure'] = pressure

target_plev = np.array([92500, 80000])

In [5]:
def geocat_method(ds, target_plev):
    return geocat_interp_hybrid_to_pressure(ds['T'], ds['PS'], ds['hyam'], ds['hybm'], p0=100000., new_levels=target_plev)

In [6]:
def metpy_method(ds, target_plev):
    interp_axis = ds['T'].dims.index('lev')
    
    # Need to pass np.ndarray to func_interpolate, otherwise error `no implementation found for 'numpy.apply_along_axis'` is raised
    # Known issues https://github.com/Unidata/MetPy/issues/1889
    return metpy.interpolate.interpolate_1d(target_plev, ds['pressure'].data, ds['T'].data, axis=interp_axis)

In [7]:
def xgcm_method(ds, target_plev):
    grid = xgcm.Grid(ds, coords={'lev': {'center': 'lev'}}, periodic=False)
    
    return grid.transform(ds['T'], 'lev', target_plev, target_data=ds['pressure'])

In [8]:
def cdms2_method(var, ps, hybm, hyam, target_plev):
    pressure = cdutil.vertical.reconstructPressureFromHybrid(ps, hyam, hybm, 100000.)
    pressure.units = 'Pa'
    
    return cdutil.vertical.linearInterpolation(var, pressure, target_plev)

In [9]:
methods = {
    'geocat': partial(geocat_method, ds_subset, target_plev),
    'metpy': partial(metpy_method, ds_subset, target_plev),
    'xgcm': partial(xgcm_method, ds_subset, target_plev),
    'cdms2': partial(cdms2_method, cdat_subset_T, cdat_subset_PS, cdat_subset_hybm, cdat_subset_hyam, target_plev)
}

def time_method(name, method):
    print('Testing: ', name)
    yappi.set_clock_type("cpu")
    yappi.start()
    result = method()
    # yappi.get_func_stats().print_all()
    yappi.get_thread_stats().print_all()

for m in methods.items():
    time_method(*m)

Testing:  geocat





name           id     tid              ttot      scnt        
_MainThread    0      140560001046336  16.14532  23        
..tPollerUnix  3      140559381427776  0.008746  16        
Thread         2      140559932180032  0.002312  5         
..avingThread  1      140559800833600  0.000555  2         
Testing:  metpy

name           id     tid              ttot      scnt        
_MainThread    0      140560001046336  32.31218  39        
..tPollerUnix  3      140559381427776  0.021923  32        
Thread         2      140559932180032  0.003567  6         
..avingThread  1      140559800833600  0.000555  2         
Testing:  xgcm





name           id     tid              ttot      scnt        
_MainThread    0      140560001046336  33.93527  41        
..tPollerUnix  3      140559381427776  0.021974  33        
Thread         2      140559932180032  0.006012  8         
..avingThread  1      140559800833600  0.000555  2         
Testing:  cdms2

name           id     tid              ttot      scnt        
_MainThread    0      140560001046336  47.12235  56        
..tPollerUnix  3      140559381427776  0.022611  46        
Thread         2      140559932180032  0.006802  10        
..avingThread  1      140559800833600  0.000555  2         
