In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import xradar as xd

In [None]:
org_file = 'radar.nc'

In [None]:
!ncdump -h radar.nc

In [None]:
ds_org = xr.open_dataset(org_file)

In [None]:
ds_org

In [None]:
ds_org.isel(angle=0)['longitude']

In [None]:
dtree = xd.io.open_nexradlevel2_datatree(
    "/Users/syed44/Downloads/Others/Git_Stuff/PYART_CAPPI_TEST/KTLH_HELENE/KTLH20240926_233515_V06")

In [None]:
dtree.groups

In [None]:
dtree = dtree.xradar.georeference()

In [None]:
def get_geocoords(ds):
    """
    Attach WGS 84 geographic coordinates to a radar sweep.

    The sweep is georeferenced with xradar, and its Cartesian ``x``/``y``/``z``
    coordinates are transformed from the CRS reported by
    ``ds.xradar.get_crs()`` to EPSG:4326 with pyproj.

    Parameters
    ----------
    ds : xarray.Dataset
        Radar sweep dataset.

    Returns
    -------
    xarray.Dataset
        Georeferenced sweep with extra ``lon``, ``lat`` and ``alt``
        coordinates, each carrying xradar's standard attributes.
    """
    from pyproj import CRS, Transformer

    georef = ds.xradar.georeference()
    wgs84 = CRS.from_user_input(4326)  # EPSG:4326 (WGS 84)
    to_wgs84 = Transformer.from_crs(georef.xradar.get_crs(), wgs84)
    # EPSG:4326 declares (latitude, longitude) axis order, so without
    # always_xy the transformer returns latitude first.
    lat, lon, alt = to_wgs84.transform(georef.x, georef.y, georef.z)
    geo_coords = {
        "lon": (georef.x.dims, lon, xd.model.get_longitude_attrs()),
        "lat": (georef.y.dims, lat, xd.model.get_latitude_attrs()),
        "alt": (georef.z.dims, alt, xd.model.get_altitude_attrs()),
    }
    return georef.assign_coords(geo_coords)

In [None]:
# Add lon/lat/alt coordinates to every sweep of the datatree.
dtree = dtree.xradar.map_over_sweeps(get_geocoords)

In [None]:
# Convert the datatree into a single CF/1 dataset with xradar's `to_cf1`.
# NOTE(review): `filter_radar` is applied to `dtree` only in a later cell, so
# the `ds` created here is NOT reflectivity-filtered — confirm this is intended.
ds = dtree.xradar.to_cf1()

In [None]:
def filter_radar(ds, field="DBZH", threshold=0):
    """
    Mask every variable of a radar sweep where ``field`` is not above ``threshold``.

    Takes the sweep as its only required argument, so it can be passed
    directly to ``dtree.xradar.map_over_sweeps``. The defaults keep the
    original behaviour: only gates with ``DBZH > 0`` survive.

    Parameters
    ----------
    ds : xarray.Dataset
        Radar sweep.
    field : str, optional
        Name of the variable that is tested (default ``"DBZH"``).
    threshold : float, optional
        Gates with ``ds[field] <= threshold`` or NaN are masked (default 0).

    Returns
    -------
    xarray.Dataset
        ``ds.where(ds[field] > threshold)``.
    """
    return ds.where(ds[field] > threshold)

In [None]:
dtree = dtree.xradar.map_over_sweeps(filter_radar)

In [None]:
ds

In [None]:
ds = ds.where(ds.range<250e3, drop=True)

In [None]:
ds['DBZH'][:720].plot(x='lon', y='lat')

In [None]:
dtree['sweep_0']['DBZH'].plot(x='lon', y='lat')

In [None]:
nrays = dtree['sweep_0']['azimuth'].size
elevation_angles = np.unique(ds['fixed_angle'].values)
# np.unique(np.around(elevation_angles*100)/100)
elevation_angles

In [None]:
import pyart

In [None]:
radar = pyart.io.read("/Users/syed44/Git_Stuff/Git_Projects/EAPS539_PROJECT/KVNX20220625_231233_V06.nc")

The pulse repetition frequency follows from the maximum unambiguous range $R_{\max}$ and the speed of light $c$:

$$\text{PRF} = \frac{c}{2\,R_{\max}}$$

In [None]:
c =  3.0 * 10**8
max_range = radar.instrument_parameters['unambiguous_range']['data'].max()
PRF = c/(2*max_range)
PRF

In [None]:
PRT = 1 / np.float32(321.19916)
print(PRT)  # Output in seconds

In [None]:
def calculate_radial_velocity_error(WRADH, PRT, N=None, dwell_time=1.0):
    """
    Estimate the radial velocity error from the Doppler spectrum width.

    Uses ``sigma_v = WRADH / sqrt(2 * N)``.

    Parameters
    ----------
    WRADH : array-like
        Doppler spectrum width (m/s).
    PRT : float or array-like
        Pulse repetition time (s). Must be strictly positive.
    N : float or array-like, optional
        Number of independent samples. When omitted it is estimated as
        ``PRF * dwell_time`` with ``PRF = 1 / PRT``.
    dwell_time : float, optional
        Integration time (s), used only when ``N`` is omitted. The default of
        1.0 s reproduces the earlier ``N = PRF`` estimate.

    Returns
    -------
    array-like
        Estimated radial velocity error (sigma_v) in m/s, same container type
        as ``WRADH`` arithmetic produces.

    Raises
    ------
    ValueError
        If ``PRT`` or ``N`` is not strictly positive (would otherwise give
        inf/NaN silently).
    """
    if np.any(np.asarray(PRT) <= 0):
        raise ValueError("PRT must be strictly positive")
    if N is None:
        # PRF pulses per second times the dwell time; a rough estimate that
        # treats every pulse as an independent sample.
        N = (1 / PRT) * dwell_time
    if np.any(np.asarray(N) <= 0):
        raise ValueError("N (number of independent samples) must be strictly positive")
    return WRADH / np.sqrt(2 * N)

In [None]:
velocity_error = calculate_radial_velocity_error(ds['WRADH'], PRT)
velocity_error

In [None]:
ds['radial_velocity_err'] = velocity_error

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.plot(ds['VRADH'].mean('range'))
plt.plot(velocity_error.mean('range'))

In [None]:
def get_fm128(ds, nrays=None):
    """
    Repack a CF/1 radar volume's reflectivity into a (time, angle, degree,
    distance) dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Volume with a ``DBZH`` variable, a ``fixed_angle`` variable, a
        ``range`` dimension and a ``time`` coordinate. The rays are assumed to
        be stored sweep after sweep — TODO confirm for ``to_cf1`` output.
    nrays : int, optional
        Rays per sweep. When omitted it is inferred as
        ``DBZH.size // (n_angles * n_range_bins)`` instead of being read from
        a notebook-global variable.

    Returns
    -------
    xarray.Dataset
        Dataset with a ``reflectivity`` variable of shape
        ``(1, n_angles, nrays, n_range_bins)`` and coordinates:

        - ``time``: earliest time of ``ds``;
        - ``angle``: unique fixed angles;
        - ``degree``: ray-bin centres in degrees, assuming the rays evenly
          cover 360 degrees;
        - ``distance``: range-bin index ``0..n-1``.

    Raises
    ------
    ValueError
        If the reflectivity values cannot be reshaped into that layout.
    """
    angles = np.unique(ds['fixed_angle'].values)
    n_angle = angles.size
    n_distance = ds.sizes['range']
    reflectivity_values = np.asarray(ds['DBZH'].values)

    if nrays is None:
        levels = n_angle * n_distance
        nrays = reflectivity_values.size // levels if levels else 0
    expected_size = n_angle * nrays * n_distance
    if expected_size == 0 or reflectivity_values.size != expected_size:
        raise ValueError(
            "cannot reshape reflectivity of shape %s into (1, %d, %d, %d)"
            % (reflectivity_values.shape, n_angle, nrays, n_distance))
    reflectivity_values = reflectivity_values.reshape(
        (1, n_angle, nrays, n_distance))

    # Unique, evenly spaced ray-bin centres; the former
    # np.around(np.arange(nrays) + 0.5) used round-half-to-even and produced
    # duplicate coordinate values (0, 2, 2, 4, 4, ...).
    degrees = (np.arange(nrays) + 0.5) * (360.0 / nrays)
    # Plain numpy value rather than a 0-d DataArray inside the coordinate list.
    time = ds['time'].min().values

    ds_new = xr.Dataset(
        {'reflectivity': (('time', 'angle', 'degree', 'distance'),
                          reflectivity_values)},
        coords={
            'time': [time],
            'angle': angles,
            'degree': degrees,
            'distance': np.arange(n_distance),
        },
    )
    return ds_new

In [None]:
dd = get_fm128(ds)
dd

In [None]:
dd['reflectivity'].values.reshape([-1]).reshape([1, 14, 720, 992])

In [None]:
#### dd['reflectivity'].values.reshape([2, dd.angle.size, dd.degree.size, dd.distance.size])

In [None]:
ds#.sizes['azimuth']

In [None]:
ds_org

In [None]:
'''
description:    Python library to write WRFDA fm128_radar ascii files
license:        APACHE 2.0
author:         Ronald van Haren, NLeSC (r.vanharen@esciencecenter.nl)
'''

import numpy


class write_fm128_radar:
    '''
    Write radar data to an FM128_RADAR ascii file that can be used for
    WRFDA data assimilation. The whole file is written and closed by the
    constructor.

    :param radar_name: name of radar; a list/array of more than one name
        writes several radars, and then every other radar-specific argument
        (lat0 ... rv_err, date) is indexed per radar
    :param lat0: latitude of radar station [deg]
    :param lon0: longitude of radar station [deg]
    :param elv0: elevation of radar station [m]
    :param date: date of observation
    :param lat: latitude of measurement point [deg]
    :param lon: longitude of measurement point [deg]
    :param elv: elevation of measurement point [m]
    :param rf: reflectivity
    :param rf_qc: quality control flag reflectivity
    :param rf_err: error on reflectivity measurement
    :param rv: radial velocity
    :param rv_qc: quality control flag radial velocity
    :param rv_err: error on radial velocity
    :param outfile: output filename of FM128_RADAR ascii file
    :param single: True when all elevation levels share one 2-D lon/lat
        grid (lat[i, j]); one record per horizontal point then holds all
        its levels. False when every level has its own grid
        (lat[m, i, j]); one record is then written per (level, i, j).
    :type radar_name: str
    :type lat0: float
    :type lon0: float
    :type elv0: float
    :type date:  datetime.datetime
    :type lat: numpy.ndarray
    :type lon: numpy.ndarray
    :type elv: numpy.ndarray
    :type rf: numpy.ndarray
    :type rf_qc: numpy.ndarray
    :type rf_err: numpy.ndarray
    :type rv: numpy.ndarray
    :type rv_qc: numpy.ndarray
    :type rv_err: numpy.ndarray
    :type outfile: str
    :type single: bool
    '''
    # format of the line that starts every measurement record
    _RECORD_FMT = "%12s%3s%19s%2s%12.3f%2s%12.3f%2s%8.1f%2s%6i"

    def __init__(self, radar_name, lat0, lon0, elv0, date, lat,
                 lon, elv, rf, rf_qc, rf_err,
                 rv, rv_qc, rv_err, outfile='fm128_radar.out', single=True):
        multiple = (isinstance(radar_name, (list, numpy.ndarray))
                    and len(radar_name) > 1)
        if multiple:
            # multiple radars in output file
            nrad = len(radar_name)
            dstring = [d.strftime('%Y-%m-%d %H:%M:%S') for d in date]
        else:
            nrad = 1
        self.init_file(nrad, outfile)
        # close the file even when writing one of the radars fails
        try:
            if multiple:
                for r_int in range(nrad):
                    self._write_radar(radar_name[r_int], lat0[r_int],
                                      lon0[r_int], elv0[r_int], dstring[r_int],
                                      lat[r_int], lon[r_int], elv[r_int],
                                      rf[r_int], rf_qc[r_int], rf_err[r_int],
                                      rv[r_int], rv_qc[r_int], rv_err[r_int],
                                      single)
            else:
                self._write_radar(radar_name, lat0, lon0, elv0,
                                  date.strftime('%Y-%m-%d %H:%M:%S'),
                                  lat, lon, elv, rf, rf_qc, rf_err,
                                  rv, rv_qc, rv_err, single)
        finally:
            self.close_file()

    def _write_radar(self, radar_name, lat0, lon0, elv0, dstring, lat, lon,
                     elv, rf, rf_qc, rf_err, rv, rv_qc, rv_err, single):
        '''
        Write the header and all measurement records of one radar.

        Arguments are those of the constructor for a single radar, with the
        date already formatted as a string (dstring).
        '''
        max_levs = numpy.shape(elv)[0] if single else 1
        npoints = self.get_number_of_points(rf, single)
        self.write_header(radar_name, lon0, lat0, elv0, dstring, npoints,
                          max_levs)
        writer = self.write_data_single if single else self.write_data
        writer(dstring, lat, lon, elv0, elv, rv, rv_qc, rv_err,
               rf, rf_qc, rf_err)

    def init_file(self, nrad, outfile):
        '''
        Initialize output file

        :param nrad: number of radars in output file
        :param outfile: name of output file
        :type nrad: int
        :type outfile: str
        '''
        self.f = open(outfile, 'w')
        fmt = "%14s%3i"
        self.f.write(fmt % ("TOTAL RADAR = ", nrad))
        self.f.write("\n")
        self.f.write("%s" % ("#-----------------------------#"))
        self.f.write("\n")
        self.f.write("\n")

    def close_file(self):
        '''
        Close output file
        '''
        self.f.close()

    def write_header(self, radar_name, lon0, lat0, elv0, date, np, max_levs):
        '''
        Write the radar specific header to the output file

        :param radar_name: name of radar
        :param lat0: latitude of radar station [deg]
        :param lon0: longitude of radar station [deg]
        :param elv0: elevation of radar station [m]
        :param date: date of observation
        :param np: total number of measurement records for the radar
        :param max_levs: number of vertical levels
        :type radar_name: str
        :type lat0: float
        :type lon0: float
        :type elv0: float
        :type date:  str
        :type np: int
        :type max_levs: int
        '''
        # define header format
        fmt = "%5s%2s%12s%8.3f%2s%8.3f%2s%8.1f%2s%19s%6i%6i"
        name = 'RADAR'
        hor_spacing = ''
        self.f.write(fmt % (name, hor_spacing, radar_name, lon0, hor_spacing,
                            lat0, hor_spacing, elv0, hor_spacing, date,
                            np, max_levs))
        self.f.write("\n")
        self.f.write("%s" % (
            '#---------------------------------------------------------#'))
        self.f.write("\n")
        self.f.write("\n")

    @staticmethod
    def get_number_of_points(rf, single=True):
        '''
        Return the number of measurement records written for a radar.

        single=True (one record per horizontal point):
            - masked array: points with at least one unmasked level
            - normal array: all horizontal points
        single=False (one record per level and point):
            - masked array: all unmasked values
            - normal array: all values

        :param rf: (masked) array of reflectivity measurements
        :param single: same meaning as in the class constructor
        :type rf: numpy.ndarray
        :type single: bool
        :returns: total number of measurement records of the radar
        :rtype: int
        '''
        if single:
            try:
                # masked array
                return len(rf.count(axis=0).flatten().nonzero()[0])
            except AttributeError:
                # Fallback for a non-masked array
                return len(rf[0, :].flatten())
        try:
            return int(rf.count())
        except AttributeError:
            return int(numpy.size(rf))

    @staticmethod
    def get_levs_point(rf_data_point):
        '''
        Return the number of levels for a data point
        - For a masked array these are all points that are not masked
        - For a normal array these are all points

        :param rf_data_point: (masked) array of reflectivity data points
        :type rf_data_point: numpy.ndarray
        :returns: number of levels for a data point
        :rtype: int
        '''
        try:
            return rf_data_point.count()
        except AttributeError:
            try:
                return len(rf_data_point)
            except TypeError:
                return 1

    def _write_record_header(self, date, lat, lon, elv0, levs):
        '''Write the "FM-128 RADAR" line that starts one measurement record.'''
        hor_spacing = ''
        self.f.write(self._RECORD_FMT % ('FM-128 RADAR', hor_spacing, date,
                                         hor_spacing, lat, hor_spacing, lon,
                                         hor_spacing, elv0, hor_spacing,
                                         levs))
        self.f.write("\n")

    def _write_level(self, idx, elv, rv_data, rv_qc, rv_err,
                     rf_data, rf_qc, rf_err):
        '''Write the measurement line of the (level, i, j) index ``idx``.'''
        self.write_measurement_line('', elv[idx], rv_data[idx], rv_qc[idx],
                                    rv_err[idx], rf_data[idx], rf_qc[idx],
                                    rf_err[idx])

    def write_data(self, date, lat, lon, elv0, elv, rv_data, rv_qc, rv_err,
                   rf_data, rf_qc, rf_err):
        '''
        Write radar measurements where every level has its own lon/lat grid.
        One record (header line + one measurement line) is written for every
        (level, i, j) whose reflectivity is present.

        :param date: date of observation
        :param lat: latitude of measurement point [deg]
        :param lon: longitude of measurement point [deg]
        :param elv0: elevation of radar station [m]
        :param elv: elevation of measurement point [m]
        :param rv_data: radial velocity
        :param rv_qc: quality control flag radial velocity
        :param rv_err: error on radial velocity
        :param rf_data: reflectivity
        :param rf_qc: quality control flag reflectivity
        :param rf_err: error on reflectivity measurement
        :type date:  str
        :type lat: numpy.ndarray
        :type lon: numpy.ndarray
        :type elv0: float
        :type elv: numpy.ndarray
        :type rv_data: numpy.ndarray
        :type rv_qc: numpy.ndarray
        :type rv_err: numpy.ndarray
        :type rf_data: numpy.ndarray
        :type rf_qc: numpy.ndarray
        :type rf_err: numpy.ndarray
        '''
        # full boolean mask even when a masked array carries ``nomask``
        mask = numpy.ma.getmaskarray(rf_data)
        for idx in numpy.ndindex(*numpy.shape(lat)[:3]):
            levs = self.get_levs_point(rf_data[idx])
            if levs <= 0:
                # no reflectivity data at this point: skip the record
                continue
            self._write_record_header(date, lat[idx], lon[idx], elv0, levs)
            if not mask[idx]:
                self._write_level(idx, elv, rv_data, rv_qc, rv_err,
                                  rf_data, rf_qc, rf_err)

    def write_data_single(self, date, lat, lon, elv0, elv, rv_data, rv_qc,
                          rv_err, rf_data, rf_qc, rf_err):
        '''
        Write radar measurements where all levels share one lon/lat grid.
        One record is written per horizontal point with at least one level of
        reflectivity data, followed by one line per unmasked level.

        :param date: date of observation
        :param lat: latitude of measurement point [deg]
        :param lon: longitude of measurement point [deg]
        :param elv0: elevation of radar station [m]
        :param elv: elevation of measurement point [m]
        :param rv_data: radial velocity
        :param rv_qc: quality control flag radial velocity
        :param rv_err: error on radial velocity
        :param rf_data: reflectivity
        :param rf_qc: quality control flag reflectivity
        :param rf_err: error on reflectivity measurement
        :type date:  str
        :type lat: numpy.ndarray
        :type lon: numpy.ndarray
        :type elv0: float
        :type elv: numpy.ndarray
        :type rv_data: numpy.ndarray
        :type rv_qc: numpy.ndarray
        :type rv_err: numpy.ndarray
        :type rf_data: numpy.ndarray
        :type rf_qc: numpy.ndarray
        :type rf_err: numpy.ndarray
        '''
        # full boolean mask even when a masked array carries ``nomask``
        mask = numpy.ma.getmaskarray(rf_data)
        nlevs = numpy.shape(elv)[0]
        for i, j in numpy.ndindex(*numpy.shape(lat)[:2]):
            levs = self.get_levs_point(rf_data[:, i, j])
            if levs <= 0:
                # no level with reflectivity data: skip the record
                continue
            self._write_record_header(date, lat[i, j], lon[i, j], elv0, levs)
            for m in range(nlevs):
                if not mask[m, i, j]:
                    self._write_level((m, i, j), elv, rv_data, rv_qc, rv_err,
                                      rf_data, rf_qc, rf_err)

    def write_measurement_line(self, hor_spacing, elv,
                               rv_data, rv_qc, rv_err,
                               rf_data, rf_qc, rf_err):
        '''
        Write measurement line to output file

        :param hor_spacing: horizontal spacing
        :param elv: elevation of measurement point [m]
        :param rv_data: radial velocity
        :param rv_qc: quality control flag radial velocity
        :param rv_err: error on radial velocity
        :param rf_data: reflectivity
        :param rf_qc: quality control flag reflectivity
        :param rf_err: error on reflectivity measurement
        :type elv: float
        :type rv_data: float
        :type rv_qc: float
        :type rv_err: float
        :type rf_data: float
        :type rf_qc: float
        :type rf_err: float
        '''
        fmt_2 = "%3s%12.1f%12.3f%4i%12.3f%2s%12.3f%4i%12.3f%2s"
        self.f.write(fmt_2 % (hor_spacing, elv,
                              rv_data, rv_qc,
                              rv_err, hor_spacing,
                              rf_data, rf_qc,
                              rf_err, hor_spacing))
        self.f.write("\n")
        
'''
description:    Convert radar netCDF files to WRFDA FM-128 radar ascii files
license:        APACHE 2.0
author:         Ronald van Haren, NLeSC (r.vanharen@esciencecenter.nl)
'''

from netCDF4 import Dataset
from netCDF4 import num2date
import datetime
import numpy
# from fm128_radar.write_fm128_radar import *
from scipy.interpolate import griddata
# from geopy.distance import vincenty
from geopy.distance import geodesic




class convert_to_ascii:
    def __init__(self, filename, outfile, pdry, interpolate=False):
        self.outputfile = outfile
        self.pdry = pdry
        self.check_file_exists(filename, 'r')
        self.read_netcdf(filename)
        if interpolate:
            self.interpolate(interpolate)
            self.set_mask_interpolate()
            self.write_ascii(single=True)
            # self.plot_refl()
        else:
            self.set_mask()
            self.write_ascii(single=False)

    def plot_refl(self):
        '''
        Save maps of reflectivity, radial velocity and altitude for every level.

        For each level ``alt`` three PNG files are written to the working
        directory: ``rf_<alt>.png``, ``rv_<alt>.png`` and ``height_<alt>.png``.
        Requires the ``mpl_toolkits.basemap`` package.
        '''
        from mpl_toolkits.basemap import Basemap, cm
        import matplotlib.pyplot as plt
        # map extent covering all measurement points
        x1 = numpy.min(self.longitude)
        x2 = numpy.max(self.longitude)
        y1 = numpy.min(self.latitude)
        y2 = numpy.max(self.latitude)
        # latitude of true scale for the Mercator projection must be a
        # latitude: use the centre of the latitude range
        lat_ts = (y1 + y2) / 2
        # (attribute name, file prefix, colour range, colour bar label)
        panels = (('rf', 'rf_', 0, 50, 'dBz'),
                  ('rv', 'rv_', -20, 20, 'm/s'),
                  ('altitude', 'height_', 0, 10000, 'm'))

        for alt in range(0, len(self.altitude)):
            for attr, prefix, vmin, vmax, label in panels:
                fig = plt.figure()
                try:
                    bmap = Basemap(resolution='i', projection='merc',
                                   llcrnrlat=y1, urcrnrlat=y2,
                                   llcrnrlon=x1, urcrnrlon=x2, lat_ts=lat_ts)
                    bmap.drawcoastlines(linewidth=0.25)
                    bmap.drawcountries(linewidth=0.25)
                    bmap.drawmeridians(numpy.arange(0, 360, 30))
                    bmap.drawparallels(numpy.arange(-90, 90, 30))
                    x, y = bmap(self.longitude, self.latitude)
                    cs = bmap.pcolormesh(x, y, getattr(self, attr)[alt, :],
                                         cmap=cm.s3pcpn_l, vmin=vmin, vmax=vmax)
                    cbar = bmap.colorbar(cs, spacing='proportional',
                                         location='right', pad="5%")
                    cbar.set_label(label)
                    plt.savefig(prefix + str(alt) + '.png', bbox_inches='tight')
                finally:
                    # release the figure even if plotting fails
                    plt.close(fig)

    def check_file_exists(self, filepath, mode):
        ''' Check if a file exists and is accessible. '''
        try:
            f = open(filepath, mode)
            f.close()
        except IOError as e:
            raise IOError('File ' + filepath + ' is not accessible')

    def read_netcdf(self, netcdffile):
        '''
        Read the data from the netcdffile.

        The netcdf file must contain reflectivity, latitude, longitude,
        altitude and time variables plus the radar_longitude, radar_latitude,
        radar_height and radar_name global attributes. Missing optional
        variables are filled with constants shaped like the reflectivity:
        reflectivity_qc 0, reflectivity_err 2.0, radial_velocity -88,
        radial_velocity_qc 0, radial_velocity_err -888888.
        The first time step ([0, :]) of each data variable is used.

        :param netcdffile: path of the netCDF file
        :raises KeyError: if the file has no reflectivity variable
        '''
        ncfile = Dataset(netcdffile, 'r')
        # close the file on every exit path, including missing variables
        try:
            variables = ncfile.variables
            try:
                self.rf = variables['reflectivity'][0, :]
            except KeyError:
                raise KeyError('netCDF file ' + str(netcdffile) +
                               ' does not contain any reflectivity data')
            shape = numpy.shape(self.rf)

            def optional(name, fill):
                # first time step of ``name``, or a constant array if missing
                try:
                    return variables[name][0, :]
                except KeyError:
                    return numpy.full(shape, fill, dtype=float)

            self.rf_qc = optional('reflectivity_qc', 0.0)
            self.rf_err = optional('reflectivity_err', 2.0)
            self.rv = optional('radial_velocity', -88.0)
            self.rv_qc = optional('radial_velocity_qc', 0.0)
            self.rv_err = optional('radial_velocity_err', -888888.0)
            # lon/lat/altitude/time
            self.latitude = variables['latitude'][:]
            self.longitude = variables['longitude'][:]
            self.altitude = variables['altitude'][:]
            time_in = variables['time']
            # extract radar information from global attributes
            self.lon0 = float(ncfile.radar_longitude)
            self.lat0 = float(ncfile.radar_latitude)
            self.elv0 = float(ncfile.radar_height)
            self.radar_name = ncfile.radar_name
            # convert the first numeric time value to a date object
            self.time = num2date(time_in[0], calendar=time_in.calendar,
                                 units=time_in.units)
        finally:
            ncfile.close()

    def get_interpolate_grid(self, filename):
        '''
        Get the grid to interpolate to from a netcdf file containing XLAT/XLONG.

        If indexing [0, :] gives a 2-D field (presumably a leading time
        dimension, e.g. (time, south_north, west_east)), timestep 0 is used;
        if it gives a 1-D row, the variable is already 2-D and is returned
        whole.

        :param filename: netCDF file containing XLAT and XLONG
        :returns: (XLAT, XLONG)
        '''
        def first_step(variable):
            # timestep 0 for 3-D variables, the full field for 2-D ones
            field = variable[0, :]
            if len(numpy.shape(field)) == 1:
                field = variable[:]
            return field

        ncfile = Dataset(filename, 'r')
        try:
            XLAT = first_step(ncfile.variables['XLAT'])
            XLONG = first_step(ncfile.variables['XLONG'])
        finally:
            # the file handle was previously never closed
            ncfile.close()
        return XLAT, XLONG

    def interpolate(self, filename):
        '''
        Interpolate rf, rv and rv_err to the XLAT/XLONG grid of ``filename``
        using nearest-neighbour interpolation.

        The target grid has one level per beam elevation angle in ``angles``.
        The height of each beam above every target column is computed with an
        effective earth radius (ke * re). Afterwards:

        - rf/rv/rv_err/altitude have shape (len(angles),) + XLAT.shape;
        - latitude/longitude are XLAT/XLONG;
        - rf_err is reset to 2.0 and rf_qc/rv_qc to 0.

        :param filename: netCDF file containing XLAT/XLONG
        :type filename: str
        '''
        # get target grid from filename
        xlat, xlong = self.get_interpolate_grid(filename)
        # NOTE: radar position hardcoded to de Bilt for now
        # (self.lat0/self.lon0/self.elv0 are not used here)
        LAT_bilt = 52.10168
        LON_bilt = 5.17834
        height_bilt = 44.
        ke = 1.3333333333  # adjustment factor to account for refractivity
        re = 6370040.0  # radius Earth [m]
        # beam elevation angles [deg]; one target level per angle
        angles = [0.4, 0.8, 1.1, 2.0, 3.0, 4.5, 6.0, 8.0, 10.0, 12.0,
                  15.0, 20.0, 25.0]
        n_levels = len(angles)

        # source points (lon, lat, altitude)
        src = numpy.vstack((self.longitude.reshape(-1),
                            self.latitude.reshape(-1),
                            self.altitude.reshape(-1)))
        # set last measurement in range to nan so nearest neighbour does not
        # extrapolate beyond the radar range
        self.rf[:, :, -1] = numpy.nan  # reflectivity
        vals_rf = self.rf.reshape(-1)
        self.rv[:, :, -1] = numpy.nan  # radial velocity
        vals_rv = self.rv.reshape(-1)
        self.rv_err[:, :, -1] = numpy.nan  # error radial velocity
        vals_rv_err = self.rv_err.reshape(-1)

        # target horizontal coordinates, repeated once per level
        xlat_target = xlat.reshape(-1)
        xlon_target = xlong.reshape(-1)
        xtrg = numpy.tile(xlon_target, n_levels)
        ytrg = numpy.tile(xlat_target, n_levels)
        # horizontal distance of each target point from the radar [m]
        base_point = (LAT_bilt, LON_bilt)
        hor_dist = numpy.array([geodesic(base_point, (lat_pt, lon_pt)).m
                                for lat_pt, lon_pt in zip(xlat_target,
                                                          xlon_target)])
        # beam height for every (angle, target point)
        theta = numpy.deg2rad(numpy.asarray(angles))[:, numpy.newaxis]
        slant = hor_dist[numpy.newaxis, :] / numpy.cos(theta)
        beam_height = numpy.sqrt(slant**2 + (ke*re)**2
                                 + 2*slant*ke*re*numpy.sin(theta)) - ke*re
        self.altitude = (beam_height.reshape((n_levels,) + numpy.shape(xlat))
                         + height_bilt)
        ztrg = self.altitude.reshape(-1)
        trg = numpy.vstack((xtrg, ytrg, ztrg))
        target_shape = numpy.shape(self.altitude)
        # interpolate (using nearest neighbor)
        # scipy.interpolate.griddata is faster than wradlib.ipol
        self.rf = griddata(src.T, vals_rf, trg.T, method='nearest',
                           rescale=True).reshape(target_shape)
        self.rv = griddata(src.T, vals_rv, trg.T, method='nearest',
                           rescale=True).reshape(target_shape)
        self.rv_err = griddata(src.T, vals_rv_err, trg.T, method='nearest',
                               rescale=True).reshape(target_shape)
        # use the single lon/lat grid we interpolated to
        self.longitude = xlong
        self.latitude = xlat
        # set the reflectivity error to 2.0 dBz and reset the qc flags
        self.rf_err = 2.0 * numpy.ones(numpy.shape(self.rf))
        self.rf_qc = numpy.zeros(numpy.shape(self.rf))
        self.rv_qc = numpy.zeros(numpy.shape(self.rf))

    def set_mask(self):
        '''
        Mask reflectivity (``self.rf``) and radial velocity (``self.rv``).

        Steps:

        * Dry points (``rf < 0``) are randomly thinned. When
          ``0 < self.pdry <= 100``, ``self.pdry`` is first rescaled by
          (total points / dry points) and capped at 100. Then
          ``max_int = round(100 / pdry)`` is used with
          ``numpy.random.randint(0, max_int)``. A non-zero draw masks the point,
          so roughly ``(max_int - 1) / max_int`` of dry points are masked.
          With ``self.pdry == 0`` every dry point is masked. With any other
          falsy or out-of-range ``pdry``, ``max_int`` is 1 and no dry point is
          masked by the thinning.
        * Reflectivity below -30 is clipped to -30.
        * Radial velocities ``<= -24`` are flagged: ``rv`` and ``rv_err`` are
          set to -888888 and ``rv_qc`` to -88.
        * Points with ``self.altitude > 6000`` (units of ``self.altitude``,
          presumably metres -- TODO confirm) are masked.
        * Points with NaN reflectivity are masked.

        Side effects: ``self.rf`` and ``self.rv`` become ``numpy.ma`` masked
        arrays. ``self.rv_err`` and ``self.rv_qc`` are modified in place.
        ``self.pdry`` may be rescaled. NOTE(review): the rescale is cumulative if
        this method is called more than once.
        '''
        # --- random thinning of dry points -------------------------------
        max_int = 1
        if self.pdry and 0 < self.pdry <= 100:
            n_dry = numpy.sum(self.rf < 0)
            if n_dry > 0:
                # Rescale pdry to account for the share of dry points; cap at 100.
                self.pdry = min(100, self.rf.size / float(n_dry) * self.pdry)
            else:
                # No dry points: the correction factor is unbounded, so the
                # cap applies. The original code divided by zero here.
                self.pdry = 100
            max_int = int(numpy.round(1 / (self.pdry / 100.)))
        # Always draw, so the global RNG is consumed the same way on every call.
        mask = numpy.random.randint(0, max_int,
                                    size=self.rf.shape).astype(bool)

        # --- clip reflectivity / flag invalid radial velocity ------------
        self.rf[self.rf < -30] = -30
        # Evaluate once, before rv itself is overwritten.
        rv_invalid = self.rv <= -24
        self.rv_err[rv_invalid] = -888888
        self.rv_qc[rv_invalid] = -88
        self.rv[rv_invalid] = -888888

        # --- build the combined mask --------------------------------------
        if self.pdry != 0:
            mask_dry = mask & (self.rf < 0)
        else:
            mask_dry = self.rf < 0
        mask_altitude = self.altitude > 6000
        mask_nan = numpy.isnan(self.rf)
        combined = mask_dry | mask_altitude | mask_nan

        self.rf = numpy.ma.masked_where(combined, self.rf)
        self.rv = numpy.ma.masked_where(combined, self.rv)

    def set_mask_interpolate(self):
        '''
        Mask reflectivity and radial velocity of the interpolated fields.

        Works like ``set_mask``, with two differences:

        * the random dry-point thinning pattern is drawn once for
          ``self.rf[0, :]`` and repeated along the first axis of ``self.rf``
          (so every entry along that axis shares the same pattern);
        * points with ``rf <= 7`` (not ``rf < 0``) are treated as dry.

        Otherwise:

        * reflectivity below -30 is clipped to -30;
        * radial velocities ``<= -24`` are flagged (``rv`` and ``rv_err`` set to
          -888888, ``rv_qc`` to -88);
        * points with ``self.altitude > 6000`` and NaN reflectivity are masked.

        When ``0 < self.pdry <= 100``, ``self.pdry`` is rescaled by
        (points in ``rf[0, :]`` / points in ``rf[0, :]`` below 0), capped at 100.
        With ``self.pdry == 0`` every dry point is masked.

        Side effects: ``self.rf`` and ``self.rv`` become ``numpy.ma`` masked
        arrays. ``self.rv_err`` and ``self.rv_qc`` are modified in place.
        ``self.pdry`` may be rescaled.
        '''
        rf_first = self.rf[0, :]

        # --- random thinning pattern (shared along the first axis) -------
        max_int = 1
        if self.pdry and 0 < self.pdry <= 100:
            # NOTE(review): dry points are counted with ``rf < 0`` here but are
            # defined as ``rf <= 7`` further down -- confirm which is intended.
            n_dry = numpy.sum(rf_first < 0)
            if n_dry > 0:
                self.pdry = min(100, rf_first.size / float(n_dry) * self.pdry)
            else:
                # No dry points: the correction factor is unbounded, so the
                # cap applies. The original code divided by zero here.
                self.pdry = 100
            max_int = int(numpy.round(1 / (self.pdry / 100.)))
        mask_xy = numpy.random.randint(0, max_int,
                                       size=rf_first.shape).astype(bool)
        # Repeat the 2-D pattern along axis 0 (read-only view, same values as
        # stacking copies of mask_xy).
        mask = numpy.broadcast_to(mask_xy, self.rf.shape)

        # --- clip reflectivity / flag invalid radial velocity ------------
        self.rf[self.rf < -30] = -30
        rv_invalid = self.rv <= -24
        self.rv_err[rv_invalid] = -888888
        self.rv_qc[rv_invalid] = -88
        self.rv[rv_invalid] = -888888

        # --- build the combined mask --------------------------------------
        if self.pdry != 0:
            mask_dry = mask & (self.rf <= 7)
        else:
            mask_dry = self.rf <= 7
        mask_altitude = self.altitude > 6000
        mask_nan = numpy.isnan(self.rf)
        combined = mask_dry | mask_altitude | mask_nan

        self.rf = numpy.ma.masked_where(combined, self.rf)
        self.rv = numpy.ma.masked_where(combined, self.rv)

    def write_ascii(self, single=False):
        '''
        Write the stored radar fields to an fm128_radar ascii file.

        This is a thin wrapper around ``write_fm128_radar``. The instance
        attributes are forwarded positionally in a fixed order, and the target
        path comes from ``self.outputfile``.

        :param single: forwarded unchanged as ``single=`` to
            ``write_fm128_radar``.
        '''
        # Positional arguments, grouped only for readability; order matters.
        radar_args = (self.radar_name, self.lat0, self.lon0, self.elv0)
        coord_args = (self.time, self.latitude, self.longitude, self.altitude)
        rf_args = (self.rf, self.rf_qc, self.rf_err)
        rv_args = (self.rv, self.rv_qc, self.rv_err)
        write_fm128_radar(*radar_args, *coord_args, *rf_args, *rv_args,
                          outfile=self.outputfile, single=single)


# if __name__ == "__main__":
#     convert_to_ascii("radar.nc", 'radar_out.ascii', 0)

In [None]:
filname = "radar.nc"
# Build the converter for the radar file, then call plot_refl() on it.
# NOTE(review): the positional `0, False` arguments are defined by
# convert_to_ascii (outside this view); consider passing them by keyword.
converter = convert_to_ascii(filname, 'radar_out1.ascii', 0, False)
converter.plot_refl()

In [None]:
# Gaps between consecutive unique sweep start indices -- presumably the ray
# count of each sweep except the last.
sweep_starts = np.unique(ds.sweep_start_ray_index)
np.diff(sweep_starts)

In [None]:
# List the data variables of every sweep_* node in the tree.
for sweep_name in dtree.match("sweep_*"):
    sweep_node = dtree[sweep_name]
    print(sweep_node.data_vars, "\n")

In [None]:
# Display the `sweep_0` node of the tree.
dtree["sweep_0"]

In [None]:
# Keep only the sweeps whose azimuth dimension has exactly `nrays` rays.
# NOTE(review): `nrays` is not defined in the visible cells -- make sure an
# earlier cell sets it, otherwise this raises NameError on a fresh kernel.
subset_dict = {
    swp: dtree[swp]
    for swp in dtree.match("sweep_*")
    if dtree[swp]["azimuth"].size == nrays
}

# Create a new DataTree from the filtered dictionary
subset_tree = xr.DataTree.from_dict(subset_dict)


In [None]:
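# NOTE(review): disabled -- DataTree.update modifies the tree in place and
# returns None, so assigning its result would replace subset_tree with None.
# subset_tree = subset_tree.root.update(dtree.root)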

In [None]:
# Copy the root-level dataset of the full tree into the subset's root node
# (in-place; nothing is displayed).
root_ds = dtree.root.to_dataset()
subset_tree.update(root_ds)

In [None]:
# Carry every radar_* group of the full tree over into the subset tree.
radar_group_names = list(dtree.match("radar_*"))
for group_name in radar_group_names:
    subset_tree[group_name] = dtree[group_name]

In [None]:
# Convert the assembled subset tree with xradar's `to_cf1` accessor
# (CfRadial1-style dataset) and display it.
subset_cf1 = subset_tree.xradar.to_cf1()
subset_cf1

In [None]:
# Create a new DataTree with only the first 3 sweeps
subset_tree = xr.DataTree.from_dict({
    '/sweep_0': dtree['/sweep_0'],
    '/sweep_1': dtree['/sweep_1'],
    '/sweep_2': dtree['/sweep_2']
})


In [None]:
# Display the current subset tree (the 3-sweep tree from the previous cell).
subset_tree

In [None]:
# Display the full DataTree for comparison with the subset.
dtree