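# Compare Gridding Rate

Grid one NEXRAD volume (KSGF, 2018-06-12 08:31:09) onto the same Cartesian grid with Py-ART, Radarx and Wradlib, time each gridder, and compare the resulting max-CAPPI maps.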

In [None]:
import os
import pyart
import fsspec
import numpy as np
import radarx as rx
import xarray as xr
import xradar as xd
from osgeo import osr
import wradlib as wrl
import cmweather  # noqa
import cartopy.crs as ccrs
from datetime import datetime

In [None]:
def filter_radar(ds, field="reflectivity", vmin=0, vmax=70):
    """Mask a sweep wherever ``field`` lies outside the closed range [vmin, vmax].

    Parameters
    ----------
    ds : xarray.Dataset
        One radar sweep, as handed over by ``map_over_sweeps``.
    field : str, optional
        Variable used to build the mask (default ``"reflectivity"``).
    vmin, vmax : float, optional
        Inclusive lower / upper bounds of accepted values (defaults 0 and 70).

    Returns
    -------
    xarray.Dataset
        ``ds.where(mask)``. Gates where ``field`` is outside the range, or is
        NaN, become NaN in *every* data variable, not only in ``field``.
    """
    values = ds[field]
    # NaN compares False on both sides, so missing gates are masked as well.
    return ds.where((values >= vmin) & (values <= vmax))

## Grid Setup

In [None]:
# Shared Cartesian grid used by every gridder below.
# Limits are (min, max) offsets in metres; resolutions are grid spacings in metres.
x_lims, y_lims = (-300e3, 250e3), (-250e3, 300e3)  # west-east, south-north
z_lims = (0, 20e3)  # vertical extent
h_res, v_res = 2000, 500  # horizontal, vertical spacing

## Load data

In [None]:
# Read the volume directly from the NOAA NEXRAD Level-II S3 bucket, then keep a
# CF/Radial copy in the current working directory for xradar to open.
file = "s3://noaa-nexrad-level2/2018/06/12/KSGF/KSGF20180612_083109_V06"
filename = f"{os.path.basename(file)}.nc"
radar = pyart.io.read_nexrad_archive(file)
pyart.io.write_cfradial(filename, radar)

In [None]:
# Open the CF/Radial file written by the load cell as an xradar DataTree
# (sweeps are accessed below as groups such as "sweep_0").
dtree = xd.io.open_cfradial1_datatree(filename)

In [None]:
# radarx helper; presumably merges NEXRAD's duplicated (split-cut) sweeps into
# one sweep per elevation -- TODO confirm against the radarx docs.
# NOTE(review): reassigns `dtree`, so re-running this cell alone may not be idempotent.
dtree = rx.utils.combine_nexrad_sweeps(dtree)

In [None]:
# Apply `filter_radar` to every sweep: gates whose reflectivity is outside
# [0, 70] (or missing) become NaN in all variables of that sweep.
dtree = dtree.xradar.map_over_sweeps(filter_radar)

In [None]:
# Add Cartesian x/y/z coordinates to every sweep; they are used by the PPI
# plot (x="x", y="y") and as source coordinates by the wradlib gridder.
dtree = dtree.xradar.georeference()
# Show the group names of the tree as the cell output.
dtree.groups

In [None]:
# Quick-look PPI of the first sweep ("sweep_0"), drawn over the same x/y
# window that is gridded below so it can be compared with the gridded maps.
dtree["sweep_0"]["reflectivity"].plot(
    x="x",
    y="y",
    xlim=x_lims,
    ylim=y_lims,
    cmap="ChaseSpectral",
    levels=range(-10, 70),
)

## Py-ART

In [None]:
# Py-ART grids the raw `radar` object, which never went through `filter_radar`.
# Apply the same [0, 70] reflectivity window with a GateFilter so that all
# three gridders start from the same gates.
gatefilter = pyart.filters.GateFilter(radar)
gatefilter.exclude_outside("reflectivity", 0, 70, inclusive=True)

# Derive (nz, ny, nx) from the shared limits/resolutions instead of hard-coding
# it. Both ends are included, so the values above give (41, 276, 276):
# 41 vertical levels and 276 x 276 horizontal cells at 2 km spacing.
grid_shape = tuple(
    int(round((hi - lo) / res)) + 1
    for (lo, hi), res in ((z_lims, v_res), (y_lims, h_res), (x_lims, h_res))
)

tstart = datetime.now()
grid = pyart.map.grid_from_radars(
    (radar,),
    grid_shape=grid_shape,
    grid_limits=(z_lims, y_lims, x_lims),
    gatefilters=(gatefilter,),
    fields=["reflectivity"],
)

xg = grid.to_xarray()
print("Py-ART gridding took:", datetime.now() - tstart)
display(xg)

## Radarx

In [None]:
# Grid the filtered, georeferenced DataTree with radarx on the shared grid.
# The *_smth values are radarx per-axis smoothing settings (see the radarx docs).
tstart = datetime.now()
ds_rx = dtree.radarx.to_grid(
    data_vars=["reflectivity"],
    pseudo_cappi=False,
    # x axis: extent, spacing, smoothing
    x_lim=x_lims,
    x_step=h_res,
    x_smth=0.2,
    # y axis
    y_lim=y_lims,
    y_step=h_res,
    y_smth=0.2,
    # z axis
    z_lim=z_lims,
    z_step=v_res,
    z_smth=0.6,
)
print("Radarx gridding took:", datetime.now() - tstart)
display(ds_rx)

## Wradlib

In [None]:
# Grid setup
sitecoords = (
    dtree["sweep_0"].longitude.values,
    dtree["sweep_0"].latitude.values,
    dtree["sweep_0"].altitude.values,
)

proj = osr.SpatialReference()
proj.ImportFromEPSG(4326)
maxrange = 275e3
maxalt = 20e3
horiz_res = 2000
vert_res = 500
minalt = dtree.altitude.values
minelev = 0.2
maxelev = 21.0

# Create target 3D grid
trgxyz, trgshape = wrl.vpr.make_3d_grid(
    sitecoords, proj, maxrange, maxalt, horiz_res, vert_res
)

print(trgshape)

In [None]:
tstart = datetime.now()
data_var = "reflectivity"

# radarx helper applied to every sweep; x/y/z coordinates are read from the result below.
raw_dt = dtree.xradar.map_over_sweeps(rx.utils.get_geocoords)

# Flatten all sweeps into one point cloud. Coordinates and values are collected
# in the SAME loop, so row i of `xyz` always describes the gate holding `data[i]`.
xyz_list = []
data_list = []
for swp in raw_dt.match("sweep_*"):
    ds = raw_dt[swp].to_dataset()
    swp_xyz = xr.concat(
        [ds.coords[c].reset_coords(drop=True) for c in ("x", "y", "z")],
        "xyz",
    )
    xyz_list.append(swp_xyz.stack(npoints=("azimuth", "range")).transpose(..., "xyz"))
    data_list.append(ds[data_var].stack(npoints=("azimuth", "range")))
xyz = xr.concat(xyz_list, "npoints")
data = xr.concat(data_list, "npoints")

# Interpolate the point cloud onto the Cartesian 3-D target grid.
gridder = wrl.vpr.CAPPI(
    xyz.values,
    trgxyz,
    maxrange=maxrange,
    minelev=minelev,
    maxelev=maxelev,
)
vol = np.ma.masked_invalid(gridder(data.values).reshape(trgshape))

# Recover the 1-D axes of the target grid, which is laid out as (z, y, x).
trgx = trgxyz[:, 0].reshape(trgshape)[0, 0, :]
trgy = trgxyz[:, 1].reshape(trgshape)[0, :, 0]
trgz = trgxyz[:, 2].reshape(trgshape)[:, 0, 0]

ds_wrl = xr.DataArray(
    data=vol,
    coords={"z": trgz, "y": trgy, "x": trgx},
    dims=("z", "y", "x"),
    name=data_var,
).to_dataset()

# lon/lat are copied from the radarx grid BY POSITION. This is only meaningful
# if both grids share size, extent and spacing; the sizes are checked here.
# TODO: also compare the x/y coordinates themselves.
assert (ds_wrl.sizes["x"], ds_wrl.sizes["y"]) == (ds_rx.lon.size, ds_rx.lat.size), (
    "wradlib and radarx grids differ in size; cannot copy lon/lat by position"
)
ds_wrl["time"] = ds_rx["time"]
ds_wrl.attrs = ds_rx.attrs
ds_wrl["latitude"] = ds_rx["latitude"]
ds_wrl["longitude"] = ds_rx["longitude"]
ds_wrl["lon"] = xr.DataArray(ds_rx.lon.values, dims=["x"])
ds_wrl["lat"] = xr.DataArray(ds_rx.lat.values, dims=["y"])
ds_wrl = ds_wrl.set_coords(["lon", "lat"])
print("Wradlib gridding took:", datetime.now() - tstart)
display(ds_wrl)

## Plot

Let's plot the max-CAPPI of all three grids on the same map projection.

## Radarx

In [None]:
# Max-CAPPI of the radarx grid on a PlateCarree map with range rings.
ds_rx.radarx.plot_max_cappi(
    data_var="reflectivity",
    projection=ccrs.PlateCarree(),
    add_map=True,
    range_rings=True,
    vmin=-10,
    vmax=70,
)

## Py-ART

In [None]:
# Max-CAPPI of the Py-ART grid. squeeze() drops length-1 dimensions of the
# Py-ART dataset (presumably its single time step) before plotting.
xg.squeeze().radarx.plot_max_cappi(
    data_var="reflectivity",
    projection=ccrs.PlateCarree(),
    add_map=True,
    range_rings=True,
    vmin=-10,
    vmax=70,
)

## Wradlib

In [None]:
# Max-CAPPI of the wradlib grid, drawn with the same projection as the radarx
# and Py-ART plots so that the three maps can be compared directly.
ds_wrl.radarx.plot_max_cappi(
    data_var="reflectivity",
    vmin=-10,
    vmax=70,
    range_rings=True,
    add_map=True,
    projection=ccrs.PlateCarree(),
)