In [142]:
# The purpose of this script is to generate a pandas dataframe for all candidate ndv/dv's of feature values calculated at each storm center at some time between the CCKW intersection time and up to X day(s) afterwards.
# Feature values are anomalized and plotted. Basic feature statistics are also calculated and explored. 

# INPUTS
# --------------------------------------------------------
# ds <- Rosi's aquaplanet dataset 
# selected_eq_dv_ndv_timeseries.csv <- timeseries of selected developers and nondevelopers for ANN ML analysis, generated by 02_select_storms_and_explore 
# ndv_time_criteria <- Defined in 02_select_storms_and_explore; how long do ndv's have to live for (in days) after crossing a CCKW crest?
# dt_after_intersection <- user defined variable (below). How long (hrs) after valid CCKW cross do we want to extract information at storm center for ANN effort? 
# grid_size <- user defined variable (below). Size of n x n grid box around storm center to average feature values over. 
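# n_bins = 10 <- number of bins (NOTE: this description was left unfinished, and n_bins is not referenced in the cells shown below - confirm its purpose or remove it)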


#The following two inputs were generated by Calc_derived_vars_and_find_latitude_averages_aquaplanet.ipynb
#vor850.nc -> calculates 850hPa vorticity over all dimensions in the aquaplanet simulation. 
#lat_avg_values_aquaplanet_with_vor850.nc -> find zonal means of aquaplanet variables (including derived vars such as 850hPa vorticity)


# OUTPUTS
# --------------------------------------------------------
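# df_for_ANN.csv <- dataframe consisting of all pertinent information for the ANN effort. We identify a point in time for each developer (relative to the CCKW intersection)
# and extract feature variables around the storm center at that point in time from the simulation dataset. 
# feature_list_for_ANN.txt <- dataframe consisting of select features that will be used for the ANN effort. 
# NOTE: the cells below write df_for_CNN_<config>.csv, one netCDF patch per selected row, and feature_data<config>.npy; the ANN file names above appear to be carried over from an earlier version of this notebook - TODO confirm.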



In [143]:
#Import relevant packages 

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import numpy as np
import seaborn as sns
from scipy.stats import zscore, wasserstein_distance
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
import papermill
from pathlib import Path
import os
import glob

In [144]:
# Input location: folder holding the data derived from the 3 km simulation
# (zonal-mean fields and vor850 are read from here).
data_dir = Path(
    "/glade/u/home/sjsharma/CCKW_Project/data_and_scripts_from_3km_simulation"
)

# Output location: the 02_outputs read by this notebook and the 03_outputs it
# writes both live under this folder.
save_dir = Path(
    "/glade/u/home/sjsharma/CCKW_Project/CNN/CNN_outputs_3km"
)

In [145]:
# Time series of the candidate developers / non-developers selected by 02_select_storms.
# valid_time is parsed to datetimes so it can be compared against Timestamps later on.
storm_timeseries_path = f'{save_dir}/02_outputs/selected_eq_dv_ndv_timeseries.csv'
df = pd.read_csv(storm_timeseries_path, parse_dates=['valid_time'])

# Non-developer criterion set in 02: how many days a non-developer has to survive
# after crossing a CCKW crest. The file holds a single number.
ndv_criteria_path = f'{save_dir}/02_outputs/ndv_time_criteria.txt'
with open(ndv_criteria_path, "r") as criteria_file:
    ndv_time_criteria = float(criteria_file.read().strip())


In [146]:
# Open Rosi's aquaplanet dataset for the chosen experiment.
# exp_name is also embedded in the output file names further down.
exp_name = 'TC_3km'
pth = f"/glade/campaign/mmm/dpm/rberrios/glade_scratch/MPAS_APE/aqua_sstmax10N/{exp_name}/"
fname = f"{pth}latlon/diags_gaussian_global_nospinup_r3600x1800.nc"
ds = xr.open_dataset(fname)



In [147]:
#How many hours after CCKW intersection do we want to take a data snapshot for CNN? 
# This offset is added (as a pd.Timedelta in hours) to the CCKW crossing time when picking the
# track point to sample, and it is embedded in the output file names.
# 0 means the snapshot is taken at the crossing time itself.
dt_after_intersection = 0

In [148]:
# Size of the CNN input patch, in grid cells, centered on the storm's lat/lon
# position at the snapshot time: (latitude cells, longitude cells).
lat_pixels, lon_pixels = 64, 64


In [149]:
# Grid spacing (degrees) taken from the first two latitude points, rounded to 3 decimals.
# It is used for the patch extent and for the configuration tag in the output file names.
first_lat = ds['lat'].isel(lat=0).item()
second_lat = ds['lat'].isel(lat=1).item()
dgrid = round(second_lat - first_lat, 3)

# Patch extent in degrees implied by the requested pixel counts.
lat_range = dgrid * lat_pixels
lon_range = dgrid * lon_pixels

# Configuration tag appended to every output file name. The raw tag contains the
# decimal point of dgrid, so '.' is swapped for 'p' to keep file names free of extra dots.
append_to_filename_raw = f'_{dt_after_intersection}_hrs_after_CCKW_cross_{exp_name}_simulation_{dgrid}_deg_grid_resolution_lat_{lat_pixels}_pixels_lon_{lon_pixels}_pixels'
append_to_filename = append_to_filename_raw.replace(".", "p")

In [150]:
# Names of all feature variables: every data variable in Rosi's dataset, followed by
# 'vor850' (an important derived quantity that is not stored in the dataset itself).
feature_vars = [*ds.data_vars, 'vor850']


In [151]:
# Zonal-mean (latitude-averaged) fields, including vor850, generated previously.
avg_ds_path = f'{data_dir}/lat_avg_values_aquaplanet_with_vor850.nc'
avg_ds = xr.open_dataset(avg_ds_path)

# 850 hPa vorticity field, stored separately from the main simulation dataset.
vor850_path = f'{data_dir}/vor850.nc'
vor850 = xr.open_dataset(vor850_path)

In [152]:
# Group the developer / non-developer time series by storm so each storm's track
# (all rows sharing one ID) can be processed on its own.
df_groupby = df.groupby(by="ID")

In [153]:
#This cell identifies the point along the storm track on/after the relevant CCKW intersection where data should be collected (some time b/w CCKW intersection and less than X day(s) afterwards).

#Because devs/non-devs have multiple CCKW intersections we must ensure we only grab the relevant ones.

# For devs, we grab the CCKW crossing that most closely precedes tropical cyclogenesis.

# For non-devs, there can be multiple valid CCKW crossings, but the storm needs to exist for X day(s) (as defined by ndv_time_criteria) after the cross.
# NOTE(review): the code only checks that a track point exists exactly X day(s) after the cross and that no other
# cross is flagged in between; it does not verify that there are no missing/skipped points inside that window.

# All positional indices below are row positions within one storm's group; using them as a proxy for
# "closest in time" assumes the rows of each storm are ordered by valid_time - TODO confirm upstream.


def _first_position_of_time(group, target_time):
    """Return the row position of the first track point whose valid_time equals target_time, or None if absent."""
    hits = np.where(group['valid_time'] == target_time)[0]
    return hits[0] if len(hits) else None


def _select_developer_row(ID, group, dt_hours):
    """Return the track row of a developer taken dt_hours after the closest CCKW cross preceding tc_genesis.

    Raises ValueError if the storm has no tc_genesis point, no CCKW cross before genesis,
    or no track point at the desired time.
    """
    idx_tcgen_all = np.where(group['tc_genesis'] == 1)[0]
    if len(idx_tcgen_all) == 0:
        raise ValueError(f"Developer storm ID {ID} has no point flagged as tc_genesis")
    idx_tcgen = idx_tcgen_all[0]

    idx_cckw_crosses = np.where(group['cckw_crest_cross'] == 1)[0]

    # all cckw crosses strictly prior to tc_genesis
    select_idx_cckw_crosses = idx_cckw_crosses[idx_cckw_crosses < idx_tcgen]
    if len(select_idx_cckw_crosses) == 0:
        raise ValueError(f"Developer storm ID {ID} has no CCKW crest cross before tc_genesis")

    # np.where returns positions in ascending order, so the last preceding cross is the closest one.
    # (The previous version took argmin over the *unfiltered* crosses and used it to index the
    # filtered array, which raised IndexError whenever a cross at/after genesis was closer.)
    idx_closest_preceding_cckw_cross_to_tcgen = select_idx_cckw_crosses[-1]

    cckw_cross_time = group['valid_time'].iloc[idx_closest_preceding_cckw_cross_to_tcgen]
    desired_time = cckw_cross_time + pd.Timedelta(hours=dt_hours)

    desired_time_idx = _first_position_of_time(group, desired_time)
    if desired_time_idx is None:
        raise ValueError(f"Storm ID {ID} does not contain a datapoint at the desired time {desired_time}")

    return group.iloc[desired_time_idx]


def _select_nondeveloper_rows(ID, group, dt_hours, window_days):
    """Return one track row per valid CCKW cross of a non-developer (possibly an empty list).

    A cross is valid when the storm has a track point exactly window_days after the cross and no
    other cross is flagged strictly between the cross and that point. Crosses without a track
    point at cross time + dt_hours are reported with print() and skipped.
    """
    rows = []
    cross_flags = group['cckw_crest_cross']

    for idx_cckw_cross in np.where(cross_flags == 1)[0]:
        cckw_cross_time = group['valid_time'].iloc[idx_cckw_cross]

        # the storm must still have a track point at the end of the required lifetime window
        time_window_max = cckw_cross_time + pd.Timedelta(days=window_days)
        time_window_max_idx = _first_position_of_time(group, time_window_max)
        if time_window_max_idx is None:
            continue

        # make sure there are no additional cckw crosses inside the window
        # NOTE(review): the slice stops before time_window_max_idx, so a cross flagged exactly at the
        # end of the window does not disqualify this cross - confirm that this is intended.
        cckw_cross_data_slice = cross_flags.iloc[idx_cckw_cross + 1:time_window_max_idx]
        if (cckw_cross_data_slice == 1).any():
            continue

        desired_time = cckw_cross_time + pd.Timedelta(hours=dt_hours)
        desired_time_idx = _first_position_of_time(group, desired_time)
        if desired_time_idx is None:
            print(f"Storm ID {ID} does not contain a datapoint at the desired time {desired_time}")
            continue

        rows.append(group.iloc[desired_time_idx])

    return rows


selected_rows = []

for ID, group in df_groupby:

    if group['developer'].iloc[0] == 1: #storm is a developer -> exactly one snapshot row
        selected_rows.append(_select_developer_row(ID, group, dt_after_intersection))

    else: #storm is a non-developer -> one snapshot row per valid CCKW cross
        selected_rows.extend(
            _select_nondeveloper_rows(ID, group, dt_after_intersection, ndv_time_criteria)
        )

df_for_CNN = pd.DataFrame(selected_rows).reset_index(drop=True)

# keep the (reset) row number as an explicit column; the patch files written later are named by it
df_for_CNN['index'] = df_for_CNN.index

In [154]:
#Let's also save off the df for CNN

# Create the target folder first: to_csv does not create missing directories, so a fresh
# save_dir would otherwise fail here (the netCDF cell further down already does the same).
config_df_dir = f'{save_dir}/03_outputs/CNN_config_dataframes'
os.makedirs(config_df_dir, exist_ok=True)

# NOTE: append_to_filename already starts with '_', hence the double underscore in the
# file name; it is kept unchanged so existing readers of this file still find it.
df_for_CNN.to_csv(f'{config_df_dir}/df_for_CNN_{append_to_filename}.csv')

In [155]:
# Display the selected snapshot rows as a quick sanity check (one row per storm/CCKW-cross sample).
df_for_CNN

Unnamed: 0,ID,fhr,valid_time,lon_TRACK,lat_TRACK,vor850,day_adj,tc_genesis,lon_track_hov_idx,day_track_hov_idx,cckw_crest_cross,suspect_point,developer,cckw_filteredpr_val,index
0,5,144.0,2000-04-05 00:00:00,11.652895,18.528280,1.493700,6.50,0.0,117.0,26.0,1.0,0.0,0,-0.058350,0
1,5,246.0,2000-04-09 06:00:00,358.736023,14.347824,1.518751,10.75,0.0,3587.0,43.0,1.0,0.0,0,-0.045556,1
2,81,84.0,2000-04-02 12:00:00,346.038757,12.146024,2.358891,4.00,0.0,3460.0,16.0,1.0,0.0,0,-0.037982,2
3,139,114.0,2000-04-03 18:00:00,278.686096,7.357641,2.415950,5.25,0.0,2787.0,21.0,1.0,0.0,1,0.342577,3
4,320,204.0,2000-04-07 12:00:00,129.846588,10.095192,2.404459,9.00,0.0,1298.0,36.0,1.0,0.0,0,0.225677,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
226,6279,2994.0,2000-08-01 18:00:00,91.697334,4.613536,1.294153,125.25,0.0,917.0,501.0,1.0,0.0,0,0.361769,226
227,6372,3054.0,2000-08-04 06:00:00,119.896683,6.640931,1.116741,127.75,0.0,1199.0,511.0,1.0,0.0,1,1.011430,227
228,6471,3102.0,2000-08-06 06:00:00,133.723282,8.651182,1.112677,129.75,0.0,1337.0,519.0,1.0,0.0,1,0.520501,228
229,6486,3138.0,2000-08-07 18:00:00,248.875549,4.874249,1.213878,131.25,0.0,2489.0,525.0,1.0,0.0,0,0.412700,229


In [156]:
# Cut a lat_pixels x lon_pixels patch (all dataset variables plus vor850) around every selected
# storm position and write each patch to its own netCDF file, named by the df_for_CNN row index.

# build the folder path (without the filename); [1:] drops the leading '_' of append_to_filename
out_dir = f"{save_dir}/03_outputs/netcdf_files/{append_to_filename[1:]}"
os.makedirs(out_dir, exist_ok=True)

# Grid dimensions come from the dataset itself instead of a hard-coded 3600.
n_lat = len(ds.indexes['lat'])
n_lon = len(ds.indexes['lon'])

# The box size is taken directly from the requested pixel counts. Converting pixels -> degrees ->
# pixels (with two differently rounded grid spacings) could truncate to the wrong number of cells.
lat_box_size = int(lat_pixels)
lon_box_size = int(lon_pixels)

for idx, row in df_for_CNN.iterrows():
    time = row['valid_time']
    lon_center = row['lon_TRACK']
    lat_center = row['lat_TRACK']

    # nearest grid indices to the storm center and to the snapshot time
    lat_center_idx = ds.indexes['lat'].get_indexer([lat_center], method="nearest")[0]
    lon_center_idx = ds.indexes['lon'].get_indexer([lon_center], method="nearest")[0]
    time_idx = ds.indexes['time'].get_indexer([time], method="nearest")[0]

    # Latitude does not wrap: a box that leaves the grid would silently come back empty or
    # truncated and only fail later when the patches are stacked, so stop here with a clear message.
    lat_start = lat_center_idx - lat_box_size // 2
    lat_stop = lat_start + lat_box_size
    if lat_start < 0 or lat_stop > n_lat:
        raise ValueError(
            f"Row {idx}: a {lat_box_size}-cell latitude box centered on lat {lat_center} "
            f"extends beyond the grid (lat index range {lat_start}..{lat_stop - 1})"
        )

    # Longitude is periodic: deal w/ wrap-around using the modulo operator.
    lon_start = lon_center_idx - lon_box_size // 2
    fixed_lon_indices = np.arange(lon_start, lon_start + lon_box_size) % n_lon

    sub_ds_main_vars = ds.isel(
        time=time_idx,
        lat=slice(lat_start, lat_stop),
        lon=fixed_lon_indices,
    )

    # NOTE(review): vor850 is indexed with positions computed on ds, which assumes both
    # datasets share the same time/lat/lon axes - confirm.
    sub_ds_vor850 = vor850.isel(
        time=time_idx,
        lat=slice(lat_start, lat_stop),
        lon=fixed_lon_indices,
    )

    sub_ds = xr.merge([sub_ds_main_vars, sub_ds_vor850])
    sub_ds.to_netcdf(f'{out_dir}/ds{append_to_filename}_df_index_{idx}.nc')


In [157]:
#Let's load up the netcdf datasets you just generated and then put them in an ordered list by index!

nc_files = glob.glob(f"{out_dir}/*.nc")

patches_by_index = {}

for nc_file in nc_files:

    # file names end in "..._df_index_<N>.nc"; recover N
    index_raw = nc_file.rsplit('_', 1)[-1]
    index_number = int(index_raw.split('.')[0])

    # load_dataset reads the patch into memory and closes the file again, so no file handles are
    # left open. The loop deliberately avoids the name `ds`, which holds the full simulation dataset.
    patches_by_index[index_number] = xr.load_dataset(nc_file)

# Every row of df_for_CNN must have exactly one patch. Stale files from an earlier run (or missing
# files) would otherwise silently misalign the patches with the dataframe.
expected_indices = set(range(len(df_for_CNN)))
found_indices = set(patches_by_index)
if found_indices != expected_indices:
    raise ValueError(
        f"Patch files in {out_dir} do not match df_for_CNN: "
        f"missing indices {sorted(expected_indices - found_indices)}, "
        f"unexpected indices {sorted(found_indices - expected_indices)}"
    )

nc_file_list_ordered = [patches_by_index[i] for i in range(len(df_for_CNN))]


In [158]:
# Turn every patch dataset into a (lat, lon, channel) array and stack them into one array.
all_patches = []

# The loop variable is deliberately not called `ds`, so the full simulation dataset is not clobbered.
for patch_ds in nc_file_list_ordered:

    # stack the data variables along a new "channel" axis and move it last -> (lat, lon, channel)
    arr = patch_ds.to_array("channel").transpose("lat", "lon", "channel").values

    all_patches.append(arr)

# stack into one big NumPy array of shape (N, lat, lon, channel)
X = np.stack(all_patches, axis=0)
print(X.shape)  # e.g. (231, 64, 64, 22) for the 64 x 64 pixel configuration

(231, 64, 64, 22)


In [159]:
# Save the stacked patch array. np.save does not create missing folders, so make sure the
# destination exists first (consistent with how the netCDF output folder is handled).
numpy_out_dir = f"{save_dir}/03_outputs/numpy_arrays_for_CNN"
os.makedirs(numpy_out_dir, exist_ok=True)
np.save(f"{numpy_out_dir}/feature_data{append_to_filename}.npy", X)