In [1]:
# This notebook was run with ESMPy version 8.0.0
import ESMF
import xesmf as xe
import numpy as np
import time
import scipy.sparse

In [7]:
# Convenience function for constructing an ESMF Mesh

def create_mesh(lons, lats, lon_res, lat_res):
    """Build a 2-D spherical (degrees) ESMF Mesh of quadrilateral cells.

    Parameters
    ----------
    lons, lats : numpy.ndarray
        Node longitudes / latitudes in degrees, each of shape
        ``(lat_res + 1, lon_res + 1)`` -- the layout returned by
        ``np.meshgrid(lon_nodes, lat_nodes)``.
    lon_res, lat_res : int
        Number of cells along the column (longitude) and row (latitude) axes.

    Returns
    -------
    ESMF.Mesh
        Mesh with ``(lat_res+1)*(lon_res+1)`` nodes and ``lat_res*lon_res``
        QUAD elements. Node and element IDs are 1-based and in order: node
        ``k`` is grid point ``(row k // (lon_res+1), column k % (lon_res+1))``
        and element ``e`` is cell ``(row e // lon_res, column e % lon_res)``.

    Raises
    ------
    ValueError
        If ``lons`` or ``lats`` does not have shape ``(lat_res+1, lon_res+1)``;
        the connectivity arithmetic below depends on exactly that layout.
    """
    expected_shape = (lat_res + 1, lon_res + 1)
    if np.shape(lons) != expected_shape or np.shape(lats) != expected_shape:
        raise ValueError(
            f"lons and lats must have shape {expected_shape}, "
            f"got {np.shape(lons)} and {np.shape(lats)}"
        )

    mesh = ESMF.Mesh(parametric_dim=2, spatial_dim=2, coord_sys=ESMF.constants.CoordSys.SPH_DEG)
    num_node = (lat_res + 1) * (lon_res + 1)
    num_elem = lat_res * lon_res

    nodeId = np.arange(1, num_node + 1)
    # Interleaved (lon, lat) pairs in row-major grid order.
    nodeCoord = np.column_stack([np.ravel(lons), np.ravel(lats)]).ravel()
    nodeOwner = np.zeros(num_node)

    elemId = np.arange(1, num_elem + 1)
    elemType = np.array([ESMF.MeshElemType.QUAD] * num_elem)
    # 0-based node index of each cell's (row, col) corner, then its four corners
    # in index order (r,c), (r,c+1), (r+1,c+1), (r+1,c): counter-clockwise when
    # longitude increases along columns and latitude along rows.
    rows, cols = np.divmod(np.arange(num_elem), lon_res)
    corner0 = rows * (lon_res + 1) + cols
    corner_offsets = np.array([0, 1, lon_res + 2, lon_res + 1])
    elemConn = (corner0[:, np.newaxis] + corner_offsets).ravel()

    mesh.add_nodes(num_node, nodeId, nodeCoord, nodeOwner)
    mesh.add_elements(num_elem, elemId, elemType, elemConn)
    return mesh

In [3]:
# Create a mesh that covers the chosen latitudes with a desired number of cells

def create_lat_bound_mesh(lower, upper, resolution=400):
    """Build a single-row ESMF Mesh of quads spanning latitudes lower..upper around the globe.

    Parameters
    ----------
    lower, upper : float
        Latitude edges of the band in degrees. They may be given in either
        order; they are sorted so the element corner order stays consistent.
    resolution : int
        Number of quad cells along longitude. Cell corners sit at
        ``np.linspace(-180, 180, resolution + 1)``.

    Returns
    -------
    ESMF.Mesh
        Mesh with ``2 * (resolution + 1)`` nodes and ``resolution`` QUAD
        elements, with 1-based IDs. Node ``2k`` is ``(lon_k, lower)`` and
        node ``2k + 1`` is ``(lon_k, upper)``.
    """
    lower, upper = min(lower, upper), max(lower, upper)
    num_node = 2 * (resolution + 1)
    num_elem = resolution

    # Two nodes per corner longitude: first at `lower`, then at `upper`.
    lon_nodes = np.repeat(np.linspace(-180, 180, resolution + 1), 2)
    lat_nodes = np.tile(np.array([lower, upper], dtype=float), resolution + 1)
    nodeCoord = np.column_stack([lon_nodes, lat_nodes]).ravel()  # interleaved (lon, lat)
    nodeId = np.arange(1, num_node + 1)
    nodeOwner = np.zeros(num_node)

    elemId = np.arange(1, num_elem + 1)
    # One QUAD per element, as a flat 1-D array matching elemId.
    elemType = np.array([ESMF.MeshElemType.QUAD] * num_elem)
    # Cell k corners: (lon_k, upper), (lon_k, lower), (lon_k+1, lower), (lon_k+1, upper),
    # i.e. counter-clockwise in the (lon, lat) plane.
    elemConn = (2 * np.arange(num_elem)[:, np.newaxis] + np.array([1, 0, 2, 3])).ravel()

    mesh = ESMF.Mesh(parametric_dim=2, spatial_dim=2, coord_sys=ESMF.constants.CoordSys.SPH_DEG)
    mesh.add_nodes(num_node, nodeId, nodeCoord, nodeOwner)
    mesh.add_elements(num_elem, elemId, elemType, elemConn)
    return mesh

In [4]:
# Define weighted statistics

def mean(data, weights):
    """Weighted arithmetic mean: ``sum(data * weights) / sum(weights)``.

    `data` and `weights` only need to broadcast against each other; e.g. a
    1-D `data` of length n with a ``(1, n)`` row of weights.
    """
    weighted_total = np.sum(data * weights)
    return weighted_total / weights.sum()

def w_median(data, weights):
    """Weighted median of `data`.

    Both `data` and `weights` are flattened, so `weights` may carry extra
    singleton axes (e.g. a ``(1, n)`` row) and `data` may be any shape with
    the same number of elements.

    The result is the smallest sorted value whose cumulative weight strictly
    exceeds half of the total weight. When the cumulative weight hits exactly
    half (e.g. two equal weights), the upper value is returned. Zero-weight
    entries can never be selected. Weights are assumed non-negative.

    Raises
    ------
    ValueError
        If `data` and `weights` differ in size, `data` is empty, or the
        weights do not sum to a positive total.
    """
    values = np.ravel(data)
    flat_weights = np.ravel(weights)
    if values.size != flat_weights.size:
        raise ValueError(
            f"data and weights must have the same number of elements, "
            f"got {values.size} and {flat_weights.size}"
        )

    order = np.argsort(values)
    sorted_values = values[order]
    cum_weights = flat_weights[order].cumsum()
    # `not x > 0` also rejects a NaN total.
    if cum_weights.size == 0 or not cum_weights[-1] > 0:
        raise ValueError("weighted median needs non-empty data with a positive total weight")

    # First index past the half-way mark; guaranteed to exist since total > total/2.
    idx = np.argmax(cum_weights > cum_weights[-1] / 2)
    return sorted_values[idx]

In [5]:
# Apply a weighted statistic over a given latitude bound

def lat_stat(field, bounds, stat=mean, resolution=400):
    """Apply a weighted statistic to an ESMF Field over a latitude band.

    Each source element is weighted by the area of its overlap with the band
    mesh from ``create_lat_bound_mesh(bounds[0], bounds[1], resolution)``.
    That overlap is obtained from first-order conservative regridding.

    Parameters
    ----------
    field : ESMF.Field
        Source field defined on mesh elements; ``field.data`` is used as a
        flat array of per-element values.
    bounds : tuple
        ``(lower, upper)`` band latitudes in degrees.
    stat : callable
        Called as ``stat(field.data, weights)`` with ``weights`` of shape
        ``(1, field.data.size)``.
    resolution : int
        Number of band-mesh cells around the full longitude range. Results
        vary with this value; presumably the cell edges between corners only
        approximate constant-latitude lines.
        NOTE(review): confirm against ESMF's SPH_DEG edge convention.

    Returns
    -------
    Whatever `stat` returns.
    """
    final_weights = _lat_band_weights(field, bounds, resolution)
    return stat(field.data, final_weights)


def _lat_band_weights(field, bounds, resolution):
    """Return a ``(1, field.data.size)`` array of each source element's overlap area with the band."""
    mesh = create_lat_bound_mesh(bounds[0], bounds[1], resolution)
    target_f = ESMF.Field(mesh, meshloc=ESMF.MeshLoc.ELEMENT)
    area_f = ESMF.Field(mesh, meshloc=ESMF.MeshLoc.ELEMENT)
    try:
        regridder = ESMF.Regrid(field, target_f, regrid_method=ESMF.RegridMethod.CONSERVE,
                                ignore_degenerate=False,
                                unmapped_action=ESMF.UnmappedAction.IGNORE,
                                norm_type=ESMF.NormType.DSTAREA,
                                factors=True)
        try:
            weights = regridder.get_weights_dict(deep_copy=True)
        finally:
            regridder.destroy()
        area_f.get_area()
        # flatten() copies the areas out of the Field before it is destroyed.
        dst_areas = area_f.data.flatten()
    finally:
        # Release the native ESMF objects; fields first, then the mesh they live on.
        area_f.destroy()
        target_f.destroy()
        mesh.destroy()

    # (band cell x source element) weight matrix; ESMF indices are 1-based, hence -1.
    matrix = scipy.sparse.csr_matrix(
        (weights["weights"], (weights["row_dst"] - 1, weights["col_src"] - 1)),
        shape=(resolution, field.data.size),
    )
    # DSTAREA weights are overlap fractions of each band cell. Scaling by the band-cell
    # area and summing over band cells gives each source element's area inside the band.
    return (matrix.T @ dst_areas)[np.newaxis, :]

In [10]:
# Construct the ESMF Field to be regridded

# Source grid: lat_res x lon_res cells covering [-extent, extent] degrees on both axes.
lon_res, lat_res = 1000, 1000
extent = 25

lons, lats = np.meshgrid(
    np.linspace(-extent, extent, num=lon_res + 1),
    np.linspace(-extent, extent, num=lat_res + 1),
)
src_mesh = create_mesh(lons, lats, lon_res, lat_res)

# Each element holds its own 0-based element index as its value.
src_field = ESMF.Field(src_mesh, meshloc=ESMF.MeshLoc.ELEMENT)
src_field.data[:] = np.arange(len(src_field.data))

In [12]:
# Perform regridding for different resolutions and different statistics

# Latitude band (lower, upper) in degrees shared by every statistic below.
band = (-15, 0)

for res in [4, 10, 40, 100, 400]:
    print(f"Resolution: {res}")
    print("")
    # perf_counter is monotonic and intended for measuring elapsed intervals.
    t = time.perf_counter()
    print("grid method")
    print(f"Mean: {lat_stat(src_field, band, resolution=res)}")
    print(f"Time taken: {time.perf_counter()-t}")
    print("")
    t = time.perf_counter()
    print(f"Median: {lat_stat(src_field, band, stat=w_median, resolution=res)}")
    print(f"Time taken: {time.perf_counter()-t}")
    print("")
    print("")

Resolution: 4

grid method
Mean: 350372.7970952477
Time taken: 5.165668964385986

Median: 350805.0
Time taken: 4.7773473262786865


Resolution: 10

grid method
Mean: 350376.9337357486
Time taken: 4.789089679718018

Median: 350809.0
Time taken: 4.796271085739136


Resolution: 40

grid method
Mean: 350561.66135593783
Time taken: 4.8612220287323

Median: 350993.0
Time taken: 4.826596021652222


Resolution: 100

grid method
Mean: 350815.3744050493
Time taken: 5.869265794754028

Median: 351244.0
Time taken: 5.701635122299194


Resolution: 400

grid method
Mean: 350859.2353558041
Time taken: 10.692472219467163

Median: 351287.0
Time taken: 10.56484341621399


