#### Prototype of the unified feature-environment extraction with a fixed box size

In [None]:
import os
import sys
import xarray as xr
import numpy as np
import matplotlib
import pandas as pd
from matplotlib import pyplot as plt
from datetime import datetime
from pathlib import Path
import cartopy.crs as ccrs
import warnings

In [None]:
# NOTE: no category or module filter is given, so this silences every warning
# issued after this point
warnings.filterwarnings('ignore')

In [None]:
def coordinates_processors(data):
    """
    Normalise the horizontal coordinate names and latitude order.

    Coordinates named 'longitude' / 'latitude' are renamed to 'lon' / 'lat'
    (each one independently, only if it is present), and the object is
    flipped along 'lat' when latitude is stored in decreasing order.

    Parameters
    ----------
    data : object exposing ``coords``, ``rename`` and ``isel``
        In this notebook an xarray Dataset or DataArray. After renaming it
        must expose a 'lat' coordinate that is also a dimension.

    Returns
    -------
    The input (possibly renamed / flipped); the input itself is not modified.
    """

    # rename only the names that actually exist, so an object carrying just
    # one of 'longitude' / 'latitude' does not fail on the missing one
    renames = {old: new
               for old, new in (('latitude', 'lat'), ('longitude', 'lon'))
               if old in data.coords}
    data2 = data.rename(renames) if renames else data

    # check if latitude is decreasing (needs at least two points to tell)
    if data2.lat.size > 1 and (data2.lat[1] - data2.lat[0]) < 0:
        data2 = data2.isel(lat=slice(None, None, -1))  # flip latitude accordingly

    return data2

In [None]:
class _VariableLocator(dict):
    """
    Dictionary of {variable name: data directory} that can also be called.

    ``locator(variable_name, path_dir)`` registers one entry and returns None,
    so ``featenv.locate_env_data('T', env_dir)`` and
    ``featenv.locate_env_data.update({'q': env_dir})`` both work, in any order
    and any number of times.
    """

    def __call__(self, variable_name, path_dir):
        self[variable_name] = path_dir


class ds_feature_environment:
    """
    Extract fixed-size lat/lon boxes of environmental and feature variables
    centred on tracked features (tracks) or on single feature objects.

    Typical use, as in this notebook: set the descriptive attributes, call
    ``create_featenv_directory``, register variables through
    ``locate_env_data`` / ``locate_feature_data``, load the track or object
    catalogue, then call one of the ``get_*_vars_*`` methods.
    """

    __version__ = "1.0beta"

    def __init__(self):

        self.name = None                          # name of the feature-environment dataset
        self.track_data = None                    # xarray dataset
        self.object_data = None                   # xarray dataset
        self.env_data = None                      # xarray dataset
        self.feature_data_sources = None          # e.g., ERA5, GPM-IMERG+MERGE-IR
        self.environmental_data_sources = None    # e.g., ERA5
        self.track_frequency = None               # hourly
        self.env_frequency = None                 # hourly
        self.lon_env = None                       # longitude of the env. data
        self.lat_env = None                       # latitude of the env. data
        self.lon_feature = None                   # longitude of the feature data
        self.lat_feature = None                   # latitude of the feature data
        self.feature_track = None                 # truthy: features are tracks; falsy: single objects
        self.feature_mask = None                  # truthy: a '2D_mask' sub-directory is created
        self.track_dir = None
        self.object_dir = None
        self.featmask_dir = None
        self.env2d_dir = None
        self.env3d_dir = None
        self.envderive_dir = None
        # callable registries: usable before the first registration (len() == 0)
        # and safe to call more than once
        self.locate_env_data = _VariableLocator()       # {env. variable name: directory}
        self.locate_feature_data = _VariableLocator()   # {feature variable name: directory}

    def create_featenv_directory(self, path_dir):
        """
        Create the output directory tree under ``path_dir`` and record it.

        Layout::

            path_dir/feature_catalogs/{track|object}[/2D_mask]
            path_dir/environment_catalogs/{VARS_2D, VARS_3D, VARS_derived}

        'track' is used when ``self.feature_track`` is truthy, 'object'
        otherwise; '2D_mask' is added when ``self.feature_mask`` is truthy.
        Missing parents are created and existing directories are left
        untouched, so the call can be repeated safely. ``path_dir`` may be a
        ``str`` or a ``Path``.

        Sets ``track_dir`` (always recorded; the directory is only created in
        track mode), ``object_dir`` (object mode only), ``featmask_dir``,
        ``env2d_dir``, ``env3d_dir`` and ``envderive_dir``.
        """

        main_dir = Path(path_dir)
        if main_dir.exists():
            print('the given directory already exists ...')
        else:
            print('generate feature-environment data directory...')
            print('Create main directoy: {}'.format(main_dir))

        featcats_dir = main_dir / 'feature_catalogs'
        envcats_dir = main_dir / 'environment_catalogs'
        feat_dir = featcats_dir / ('track' if self.feature_track else 'object')

        self.track_dir = featcats_dir / 'track'
        self.object_dir = None if self.feature_track else feat_dir
        self.featmask_dir = feat_dir / '2D_mask' if self.feature_mask else None
        self.env2d_dir = envcats_dir / 'VARS_2D'
        self.env3d_dir = envcats_dir / 'VARS_3D'
        self.envderive_dir = envcats_dir / 'VARS_derived'

        wanted = [main_dir, featcats_dir, envcats_dir, feat_dir, self.featmask_dir,
                  self.env2d_dir, self.envderive_dir, self.env3d_dir]
        for directory in wanted:
            if directory is None or directory.is_dir():
                continue
            # pathlib instead of a shell 'mkdir': creates parents, copes with
            # spaces in the path and raises on failure instead of failing silently
            directory.mkdir(parents=True, exist_ok=True)
            print(directory)

    def load_track_data(self, file_path):
        """Open the track catalogue at ``file_path``, store it in ``self.track_data`` and return it."""
        self.track_data = xr.open_dataset(file_path)

        return self.track_data

    def load_object_data(self, file_path):
        """Open the object catalogue at ``file_path``, store it in ``self.object_data`` and return it."""
        self.object_data = xr.open_dataset(file_path)

        return self.object_data

    def get_track_info(self, track_number):
        """Return the entry of ``self.track_data`` selected by ``tracks=track_number``."""

        track_info = self.track_data.sel(tracks=track_number)

        return track_info

    def get_object_info(self, object_id):
        """Return the entry of ``self.object_data`` selected by ``object_id=object_id``."""

        obj_info = self.object_data.sel(object_id=object_id)

        return obj_info

    # ------------------------------------------------------------------
    # private helpers shared by the extraction methods
    # ------------------------------------------------------------------
    @staticmethod
    def _centroid_arrays(info):
        """Return (lat, lon) centroid values of ``info``, longitude converted to 0-360."""
        lat_cen = info.meanlat.values
        lon_cen = info.meanlon
        lon_cen = lon_cen.where(lon_cen >= 0, lon_cen + 360)  # converting to 0-360
        return lat_cen, lon_cen.values

    @staticmethod
    def _count_valid_times(base_times):
        """Return the number of leading entries before the first NaT (all of them if there is none)."""
        nat_idx = np.flatnonzero(np.isnat(base_times))
        return int(nat_idx[0]) if nat_idx.size else len(base_times)

    @staticmethod
    def _to_datetime(time64):
        """Convert a numpy datetime64 (scalar or 0-d array) to a naive ``datetime``."""
        stamp = pd.Timestamp(np.asarray(time64, dtype='datetime64[ns]')[()])
        if stamp is pd.NaT:
            raise ValueError('base_time is NaT; cannot determine the data file to load')
        return stamp.to_pydatetime()

    @staticmethod
    def _extract_box(data_file, lon_center, lat_center, time_sel,
                     lat_range, lon_range, p_level=None):
        """
        Cut a ``lat_range`` x ``lon_range`` box centred on the grid point of
        ``data_file`` nearest to (``lon_center``, ``lat_center``), at the time
        step nearest to ``time_sel``, and re-index it on integer 'x' / 'y'
        dimensions. 'lon' / 'lat' stay attached as coordinates along x / y;
        the 'time' coordinate is dropped.

        The box is not wrapped across the first/last longitude of the grid.
        """
        # find the nearest grid point for the centroid
        idx_lon = np.argmin(np.abs(data_file.lon.values - lon_center))
        idx_lat = np.argmin(np.abs(data_file.lat.values - lat_center))
        lon_cen_reset = data_file.lon[idx_lon]
        lat_cen_reset = data_file.lat[idx_lat]

        box = data_file.sel(lat=slice(lat_cen_reset - lat_range/2, lat_cen_reset + lat_range/2),
                            lon=slice(lon_cen_reset - lon_range/2, lon_cen_reset + lon_range/2))
        box = box.sel(time=time_sel, method='nearest')

        if p_level is not None:  # for 3-D data only, with vertical dim. named "level"
            # select before interpolating: the interpolation below acts on
            # lon/lat only, so the result is the same and less data is processed
            box = box.sel(level=p_level)

        # fixed number of grid points per box; round instead of truncating so
        # float32 grid spacings (e.g. 0.1) do not lose a point
        dlon = (data_file.lon[1] - data_file.lon[0]).values
        dlat = (data_file.lat[1] - data_file.lat[0]).values
        n_lon = int(np.rint(lon_range / dlon)) + 1
        n_lat = int(np.rint(lat_range / dlat)) + 1
        box_xy = box.interp(lon=np.linspace(box.lon.min(), box.lon.max(), n_lon),
                            lat=np.linspace(box.lat.min(), box.lat.max(), n_lat))

        # converting lat-lon into x-y (grid point) coordinates
        box_xy = box_xy.assign_coords(x=("lon", np.arange(len(box_xy.lon))),
                                      y=("lat", np.arange(len(box_xy.lat))))
        return box_xy.swap_dims({'lon': 'x', 'lat': 'y'}).drop_vars('time')

    @staticmethod
    def _assemble_track(data_chunk, time_chunk, **merge_kwargs):
        """Stack per-time-step datasets along an integer 'time' dim and attach 'base_time'."""
        if not data_chunk:
            raise ValueError('the selected track has no valid base_time entries')
        n_steps = len(data_chunk)
        data_chunk_xr = xr.concat(data_chunk, dim=pd.Index(range(n_steps), name='time'))
        ds_basetime_xr = xr.Dataset(data_vars=dict(base_time=(['time'], time_chunk)),
                                    coords=dict(time=(['time'], range(n_steps))))
        return xr.merge([data_chunk_xr, ds_basetime_xr], **merge_kwargs)

    # ------------------------------------------------------------------
    # extraction methods
    # ------------------------------------------------------------------
    def get_environment_vars_track(self, track_id, lat_range, lon_range, p_level=None):
        """
        Extract boxes of every registered environmental variable along a track.

        For each time step of track ``track_id`` (up to the first NaT in
        ``base_time``) the file
        ``<registered dir>/<year>/era-5.<var>.<year>.<month>.nc`` is opened for
        each registered variable and a ``lat_range`` x ``lon_range`` box
        around the track centroid is extracted.

        Parameters
        ----------
        track_id : label passed to ``get_track_info``
        lat_range, lon_range : box size, in the units of the lat/lon coordinates
        p_level : optional value(s) for ``sel(level=...)``

        Returns
        -------
        xarray Dataset with dims (time, [level,] y, x) plus a 'base_time' variable.

        Raises
        ------
        ValueError : no environmental variable registered, or the track has no
            valid time step.

        Side effect: stores the lon/lat of the last opened file in
        ``self.lon_env`` / ``self.lat_env``.
        """

        if len(self.locate_env_data) == 0:
            raise ValueError("No environmental data located. Please call locate_env_data() first")

        track_info = self.get_track_info(track_number=track_id)
        lat_cen, lon_cen = self._centroid_arrays(track_info)  # centroids along the track

        # the track ends at the first NaT (or uses every slot if there is none)
        base_times = track_info.base_time.values
        n_steps = self._count_valid_times(base_times)

        data_chunk = []
        time_chunk = []
        grid = None

        for t in range(n_steps):

            time_sel = self._to_datetime(base_times[t])

            # determine the env_data to be loaded
            year = str(time_sel.year)
            month = str(time_sel.month).zfill(2)

            data_var = []
            for var, var_dir in self.locate_env_data.items():

                filename = Path(var_dir) / year / 'era-5.{}.{}.{}.nc'.format(var, year, month)
                data_file = coordinates_processors(xr.open_dataset(filename))
                grid = data_file

                data_var.append(self._extract_box(data_file, lon_cen[t], lat_cen[t], time_sel,
                                                  lat_range, lon_range, p_level=p_level))

            data_chunk.append(xr.merge(data_var))  # merge variables into one xr.dataset
            time_chunk.append(time_sel)

        # add base_time into the dataset
        data_track_xr = self._assemble_track(data_chunk, time_chunk)

        # save lat/lon into self
        self.lon_env = grid.lon
        self.lat_env = grid.lat

        return data_track_xr

    def get_environment_vars_single(self, object_id, lat_range, lon_range, p_level=None):
        """
        Extract boxes of every registered environmental variable around one object.

        Same file convention and box extraction as
        ``get_environment_vars_track``, for the single time/centroid of
        ``object_id`` in ``self.object_data``.

        Returns
        -------
        xarray Dataset with dims ([level,] y, x); it carries no 'time' coordinate.

        Raises
        ------
        ValueError : no environmental variable registered, or ``base_time`` is NaT.
        """

        if len(self.locate_env_data) == 0:
            raise ValueError("No environmental data located. Please call locate_env_data() first")

        obj_info = self.get_object_info(object_id=object_id)
        lat_cen, lon_cen = self._centroid_arrays(obj_info)  # object centroid

        time_sel = self._to_datetime(obj_info.base_time.values)

        # determine the env_data to be loaded
        year = str(time_sel.year)
        month = str(time_sel.month).zfill(2)

        data_var = []
        for var, var_dir in self.locate_env_data.items():

            filename = Path(var_dir) / year / 'era-5.{}.{}.{}.nc'.format(var, year, month)
            data_file = coordinates_processors(xr.open_dataset(filename))

            data_var.append(self._extract_box(data_file, lon_cen, lat_cen, time_sel,
                                              lat_range, lon_range, p_level=p_level))

        data_var_merged = xr.merge(data_var)  # merge variables into one xr.dataset

        return data_var_merged

    def get_feature_vars_track(self, track_id, lat_range, lon_range):
        """
        Extract boxes of every registered feature variable along a track.

        For each time step the file
        ``<registered dir>/<year>0101.0000_<year+1>0101.0000/mcstrack_<yyyymmdd>_<hh>30.nc``
        is opened, the variable is interpolated onto the environmental grid
        (``self.lon_env`` / ``self.lat_env``) and a ``lat_range`` x
        ``lon_range`` box around the track centroid is extracted.

        Returns
        -------
        xarray Dataset with dims (time, y, x) plus a 'base_time' variable.

        Raises
        ------
        ValueError : no feature variable registered, the environmental grid is
            not set yet, or the track has no valid time step.

        Side effect: stores the lon/lat of the last regridded field in
        ``self.lon_feature`` / ``self.lat_feature``.
        """

        if len(self.locate_feature_data) == 0:
            raise ValueError("No feature data located. Please call locate_feature_data() first")

        if self.lon_env is None or self.lat_env is None:
            raise ValueError("Environmental grid unknown. Please call "
                             "get_environment_vars_track() first")

        track_info = self.get_track_info(track_number=track_id)
        lat_cen, lon_cen = self._centroid_arrays(track_info)  # centroids along the track

        # the track ends at the first NaT (or uses every slot if there is none)
        base_times = track_info.base_time.values
        n_steps = self._count_valid_times(base_times)

        data_chunk = []
        time_chunk = []
        grid = None

        for t in range(n_steps):

            time_sel = self._to_datetime(base_times[t])

            # determine the feature data to be loaded
            year = str(time_sel.year)
            month = str(time_sel.month).zfill(2)
            day = str(time_sel.day).zfill(2)
            hour = str(time_sel.hour).zfill(2)

            data_var = []
            for var, var_dir in self.locate_feature_data.items():

                filename = (Path(var_dir)
                            / '{}0101.0000_{}0101.0000'.format(year, int(year)+1)
                            / 'mcstrack_{}{}{}_{}30.nc'.format(year, month, day, hour))
                data_file = xr.open_dataset(filename)[var]  # get the specified variable
                data_file = coordinates_processors(data_file)
                # regrid the feature field onto the env. data grid
                data_file = data_file.interp(lon=self.lon_env, lat=self.lat_env)
                grid = data_file

                data_var.append(self._extract_box(data_file, lon_cen[t], lat_cen[t], time_sel,
                                                  lat_range, lon_range))

            data_chunk.append(xr.merge(data_var))  # merge variables into one xr.dataset
            time_chunk.append(time_sel)

        # add base_time into the dataset
        data_track_xr = self._assemble_track(data_chunk, time_chunk, compat='override')

        # save lat/lon into self
        self.lon_feature = grid.lon
        self.lat_feature = grid.lat

        return data_track_xr

## Application 
- Paths of feature track and environmental variables
- Call module and save the dataset
- Plotting 

#### Call feature-environment module

In [None]:
# instantiate the feature-environment module and describe the dataset
featenv = ds_feature_environment()
print('version: ', featenv.__version__)
featenv.name = 'MCS_FLEXTRKR'
featenv.feature_data_sources = 'GPM-IMERG; MERGE-IR'
featenv.environmental_data_sources = 'ERA5'
featenv.track_frequency = 'hourly'
featenv.env_frequency = 'hourly'
featenv.feature_track = True
featenv.feature_mask = True

# create directories according to the above descriptions
main_dir = Path('/scratch/wmtsai/featenv_test/{}'.format(featenv.name))
featenv.create_featenv_directory(main_dir)

print("Feature data sources:", featenv.feature_data_sources)
print("Environmental data sources:", featenv.environmental_data_sources)

# 1. locate environment / feature variables: variable names, direct paths
env_dir = Path('/neelin2020/ERA-5/NC_FILES/')
feat_dir = Path('/neelin2020/mcs_flextrkr/')

# the first variable goes through the call; the others are added to the mapping
featenv.locate_env_data('T', env_dir)
for env_var in ('q', 'ua', 'va', 'omega'):
    featenv.locate_env_data.update({env_var: env_dir})

featenv.locate_feature_data('cloudtracknumber_nomergesplit', feat_dir)
for feat_var in ('precipitation', 'tb'):
    featenv.locate_feature_data.update({feat_var: feat_dir})

print('Environmental data located: \n',featenv.locate_env_data)
print('Feature data located: \n',featenv.locate_feature_data)

#### 1. feature tracks: MCS tracks as example 
(timestamps for each track, lat_centroid, lon_centroid)

In [None]:
# load feature tracks: global MCS tracks from FLEXTRKR for `processed_year`
processed_year = 2014
track_dir = Path('/neelin2020/mcs_flextrkr/mcs_stats/')
track_file = 'mcs_tracks_final_extc_{}0101.0000_{}0101.0000.nc'.format(processed_year, processed_year+1)
track_data = featenv.load_track_data(track_dir / track_file)

# 2. a subset of MCSs whose first centroid (times=0) lies over the
#    tropical Indian Ocean (50-90, -10,10)
meanlon = track_data.meanlon.sel(times=0)
meanlat = track_data.meanlat.sel(times=0)
cond1 = (meanlon >= 50) & (meanlon <=90)
cond2 = (meanlat >= -10) & (meanlat <=10)
# positions of the tracks that satisfy both conditions
track_sub = np.where((cond1 & cond2).values)[0]

# update track_data with a small subset
featenv.track_data = track_data.isel(tracks=track_sub)

#### 2. "get_environment_vars" and return the individual feature-env data

In [None]:
%%time
# extract feat-env data for individual tracks
sample_tracks = featenv.track_data.tracks.values[:3] # take the first 3 tracks as example
ds_merged = []
for track in sample_tracks:
    box_size = dict(lat_range=10, lon_range=10)
    ds_env_vars = featenv.get_environment_vars_track(track_id=track, **box_size)
    ds_feat_vars = featenv.get_feature_vars_track(track_id=track, **box_size)
    # compat='override': env. and feature boxes differ by small float differences. TBD
    ds_vars = xr.merge([ds_env_vars, ds_feat_vars], compat='override')
    ds_merged.append(ds_vars)
ds_merged_xr = xr.concat(ds_merged, dim=pd.Index(sample_tracks, name='tracks'))

#### 3. Save datasets of individual feature tracks/feature objects 
- create directories of feature catalogs and environmental variables
- feature track in the standard format (time, lat_centroid, lon_centroid)
- environmental variables: 2-D / 3-D and subdirectories named by the variable short names

In [None]:
%%time
# save feature and environmental variables accordingly:
# variables with a 'level' dimension go to the 3-D directory, all other
# variables with more than two dimensions go to the 2-D directory
for var in ds_merged_xr.keys():

    if var == 'base_time':
        continue

    ds = ds_merged_xr[var]
    if len(ds.dims) <= 2:
        # no output directory is defined for such variables; skip them instead
        # of hitting an unbound `out_dir` or reusing the previous variable's
        print('skip {}: no output directory for dims {}'.format(var, ds.dims))
        continue

    out_dir = featenv.env3d_dir if 'level' in ds.dims else featenv.env2d_dir

    print(out_dir)
    ds.to_netcdf(out_dir / '{}_{}_merged.nc'.format(featenv.name, var), encoding={var: {'dtype': 'float32'}})
    print('save file: {}_{}_merged.nc'.format(featenv.name, var))

#### 4. Simple demonstration of data capability

In [None]:
# re-open the merged specific-humidity file written by the save cell; the file
# name is built from featenv.name in the same way as when it was written
data_sample = xr.open_dataset(featenv.env3d_dir / '{}_q_merged.nc'.format(featenv.name))

In [None]:
# display the dataset loaded in the previous cell
data_sample

In [None]:
# integrate q along 'level' and scale by 100/9.8
# (presumably hPa -> Pa divided by g, giving column water vapor — confirm units)
q_column = data_sample.q.integrate('level')
cwv = 100/9.8*q_column
# average over the box (x, y) and plot what remains
cwv_composite = cwv.mean(dim=('x','y'))
cwv_composite.plot(cmap='terrain_r')

#### 5. add environmental variables
- In addition to standard outputs, users can use the module function to extract variables from external data sources (gridded data)


#### II. feature objects: time, lat_centroid, lon_centroid
- e.g., co-occurring features 

In [None]:
# demo catalogue of two feature objects: (id, time, lat_centroid, lon_centroid)
obj_id = np.array([1,2])
base_time = np.array([datetime(2020,1,1,0), datetime(2020,1,1,0)])
meanlat = np.array([13.5, -10]) # lat of feature centroid
meanlon = np.array([50, 160])   # lon of feature centroid

# every variable is indexed by 'object_id'
feature_obj = xr.Dataset(
    data_vars={
        'base_time': (['object_id'], base_time),
        'meanlon': (['object_id'], meanlon),
        'meanlat': (['object_id'], meanlat),
    },
    coords={'object_id': (['object_id'], obj_id)},
)

featenv.object_data = feature_obj

#### use "get_environment_vars" and return the individual feature-env data

In [None]:
# extract feat-env data for a single object
# environmental variables around the demo object with object_id=1,
# using lat_range = lon_range = 10
ds_obj_out = featenv.get_environment_vars_single(object_id=1, lat_range=10, lon_range=10)

In [None]:
# map of q at level=1000 in the box extracted around the object
fig, ax = plt.subplots(1,1,figsize=(5,4))
cp = ax.pcolormesh(ds_obj_out.lon, ds_obj_out.lat, ds_obj_out.sel(level=1000).q)
plt.colorbar(cp)
# get_environment_vars_single drops the 'time' coordinate, so the title is
# taken from the object's base_time in the catalogue instead
obj_time = featenv.object_data.base_time.sel(object_id=1).values
ax.set_title(str(obj_time))
ax.grid(ls=':')
plt.show()

#### 3. Save datasets of individual feature tracks/feature objects 

In [None]:
out_dir = Path('/scratch/wmtsai/test_ground/')
# NOTE(review): `ds_track_out` is not defined anywhere in the visible notebook —
# presumably left over from an earlier session; confirm which dataset
# (e.g. ds_merged_xr or ds_obj_out) is meant to be written here
ds_track_out.to_netcdf(out_dir / 'featenv_dataset.test.nc')

In [None]:
# display the track dataset currently held by the module
featenv.track_data