In [1]:
# Import packages 
import pandas as pd
import numpy as np
import netCDF4
import h5netcdf
import xarray as xr
from os.path import join, exists
import joblib
from glob import glob
import datetime as dt
import sys, os
import pyresample
import itertools
from pathlib import Path
import pickle

#Filters
from scipy.ndimage import uniform_filter, maximum_filter, gaussian_filter

#Custom Packages
sys.path.append('/home/samuel.varga/python_packages/WoF_post') #WoF post package
sys.path.append('/home/samuel.varga/python_packages/wofs_ml_severe/')
sys.path.append('/home/samuel.varga/python_packages/MontePython/')
sys.path.append('/home/samuel.varga/projects/deep_learning/')

from wofs.post.utils import (
    save_dataset,
    load_multiple_nc_files,
)
from main.dl_2to6_data_pipeline import get_files, load_dataset
from collections import ChainMap


lookup_file: /home/samuel.varga/python_packages/WoF_post/wofs/data/psadilookup.dat


In [2]:
#Get list of Patch files - convert cases to datetime
path_base = f'/work/samuel.varga/data/2to6_hr_severe_wx/DEEP_LEARNING/SummaryFiles/'
file_base = f'wofs_DL2TO6_16_16_data.feather'
meta_file_base = f'wofs_DL2TO6_16_16_meta.feather'
out_path = '/work/samuel.varga/data/2to6_hr_severe_wx/DEEP_LEARNING/'

In [3]:
dates=[d for d in os.listdir(path_base) if '.txt' not in d]

paths=[] #Valid paths for worker function
bad_paths=[]
for d in dates:
    if d[4:6] !='05': 
        continue

    times = [t for t in os.listdir(join(path_base, d)) if 'basemap' not in t] #Init time

    for t in times:
        path = join(path_base, d , t)
        if exists(join(path,file_base)):
            paths.append(path)
print(paths[0])
print(f'Num Total Paths: {len(paths)} ')

/work/samuel.varga/data/2to6_hr_severe_wx/DEEP_LEARNING/SummaryFiles/20200518/1800
Num Total Paths: 1154 


In [4]:
#Check files to see where bad MRMS data, drop cases from list of files
for path in paths:
    ds = xr.load_dataset(join(join(path_base, path), file_base))
    if np.any(ds['MESH_severe__4km'].values<0) or np.any(ds['MRMS_DZ'].values<0):
        print('Bad path found - Missing Data')
        bad_paths.append(path)
        paths.remove(path)
    elif np.any(ds['MRMS_DZ'].values > 10**35):
        print('Bad path found - MRMS DZ Values exceed expected range')
        bad_paths.append(path)
        paths.remove(path)
    ds.close()
print(f'Num Paths w/ usable data: {len(paths)}') 

Bad path found - MRMS DZ Values exceed expected range
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Bad path found - Missing Data
Num Paths w/ usable data: 1129


In [11]:
#Convert remaining files into train/validation/test based on day
temp_paths=[path.split('/')[-2][0:8]+path.split('/')[-1] for path in paths] #Different domains on the same day are treated as identical for the purposes of T/T split
dates=[pd.to_datetime(path, format=f'%Y%m%d%H%M') for path in temp_paths]

#Split into train/test
from sklearn.model_selection import KFold as kfold, train_test_split
import random

all_dates = np.unique([date.strftime('%Y%m%d') for date in dates])
random.Random(42).shuffle(all_dates)
train_dates, test_dates = train_test_split(all_dates, test_size=0.3)
print('Training Dates:')
print(train_dates)

print('Testing Dates:')
print(test_dates)

#Split training set into 5 folds
train_folds = kfold(n_splits = 5, random_state=42, shuffle=True).split(train_dates)

with open(f'/work/samuel.varga/data/dates_split_deep_learning.pkl', 'wb') as date_file:
    pickle.dump({'train_dates':train_dates,'test_dates':test_dates}, date_file)

Training Dates:
['20220527' '20220531' '20210512' '20210520' '20230523' '20220506'
 '20200528' '20190502' '20230516' '20200521' '20190530' '20210510'
 '20220529' '20220518' '20190518' '20230524' '20210527' '20230518'
 '20230512' '20200504' '20230521' '20210517' '20230503' '20200518'
 '20210519' '20220520' '20190501' '20200526' '20210514' '20200527'
 '20210524' '20230530' '20190514' '20200519' '20200522' '20220502'
 '20200505' '20230504' '20220525' '20230502' '20210525' '20210521'
 '20230511' '20230517' '20220524' '20220509' '20190526' '20210507'
 '20220513' '20210504' '20200515' '20190523' '20190506' '20210528'
 '20190515' '20190520' '20210503' '20190529' '20230505' '20220510'
 '20200507' '20190513' '20220503' '20220512' '20210523' '20200508'
 '20190521' '20210518' '20230526' '20190516' '20200506' '20230510'
 '20200520' '20220511' '20220519']
Testing Dates:
['20210505' '20230501' '20220526' '20220516' '20190510' '20220530'
 '20230531' '20210506' '20220504' '20220528' '20190524' '202305

In [6]:
for i, (train_index, val_index) in enumerate(train_folds):
    print(f'Rotation: {i}')
    print(train_index, val_index)
    print(len(list(np.array(paths)[np.isin(np.array([date.strftime('%Y%m%d') for date in dates]), train_dates[train_index])])))
    print(len(list(np.array(paths)[np.isin(np.array([date.strftime('%Y%m%d') for date in dates]), train_dates[val_index])])))

Rotation: 0
[ 1  2  3  5  6  7  8 11 13 14 15 16 17 19 20 21 22 23 24 25 26 27 29 30
 31 32 33 36 37 38 39 40 41 43 44 45 46 47 48 50 51 52 53 54 55 56 57 58
 59 60 62 65 66 67 68 70 71 72 73 74] [ 0  4  9 10 12 18 28 34 35 42 49 61 63 64 69]
638
153
Rotation: 1
[ 0  1  2  3  4  6  8  9 10 11 12 13 14 15 17 18 19 20 21 23 24 25 26 27
 28 29 32 34 35 36 37 38 41 42 43 46 48 49 50 51 52 53 54 55 57 59 60 61
 62 63 64 65 67 68 69 70 71 72 73 74] [ 5  7 16 22 30 31 33 39 40 44 45 47 56 58 66]
639
152
Rotation: 2
[ 0  1  2  4  5  7  9 10 11 12 14 15 16 18 20 21 22 23 24 26 27 28 29 30
 31 32 33 34 35 37 39 40 41 42 43 44 45 46 47 48 49 51 52 55 56 57 58 59
 60 61 63 64 65 66 67 68 69 70 71 73] [ 3  6  8 13 17 19 25 36 38 50 53 54 62 72 74]
640
151
Rotation: 3
[ 0  1  2  3  4  5  6  7  8  9 10 12 13 14 16 17 18 19 20 21 22 23 25 28
 29 30 31 33 34 35 36 37 38 39 40 42 44 45 47 49 50 51 52 53 54 56 58 59
 60 61 62 63 64 65 66 69 70 71 72 74] [11 15 24 26 27 32 41 43 46 48 55 57 67 68 73]
613


In [7]:
def format_metadata(meta_data_list):
    '''Reformats the metadata to appease the duplicate index errors.'''
    '''Args: meta_data_list: list of opened datasets'''
    meta = {}
    for v in meta_data_list[0].variables:
        #print(v)
        if v in ['run_date','init_time','patch_no']:
            meta[v] = np.append(np.array([]), [ x[v].values for x in meta_data_list])
        else:
            meta[v] = (['patch','NY_ind','NX_ind'],np.reshape(np.append(np.array([]), [x[v].values for x in meta_data_list]), (10*len(meta_data_list),16,16)))
        #print(np.shape(meta[v]))

    #Open NC file, add vars, save
    meta_ds = xr.Dataset(meta)
    return meta_ds

In [8]:
def save_rotation_nc(rot_num, train_ind, val_ind, unique_dates, path_list, date_list, out_path=out_path):
    '''rot_num: int - rotation number
        train_ind: list - list of indices for training folds - indices correspond to day in training_dates
        val_ind: list - list of indices for validation folds - indices correspond to day in training_dates
        unique_dates: list - list of unique dates in training set
        path_list: list - list of file paths of length N that contain directory info and init time
        date_list: list - list of dates of length N, with each date being YYYYmmdd for the corresponding path in path_list
    '''
    #Get list of paths for current rotation
    training_paths=list(np.array(path_list)[np.isin(np.array([date.strftime('%Y%m%d') for date in date_list]), unique_dates[train_ind])])
    validation_paths=list(np.array(path_list)[np.isin(np.array([date.strftime('%Y%m%d') for date in date_list]), unique_dates[val_ind])])
    
    #Add the filename to each of the paths
    print('Appending Filenames')
    training_file_paths = [join(path, file_base) for path in training_paths[:10]]
    training_meta_paths=[join(path, meta_file_base) for path in training_paths[:10]]
    validation_file_paths = [join(path, file_base) for path in validation_paths[:10]]
    validation_meta_paths=[join(path, meta_file_base) for path in validation_paths[:10]]
    
    
    #Create Training Data
    print(f'Saving training data for Rot {rot_num}')
    ds = [xr.open_dataset(f) for f in training_file_paths]
    ds = xr.concat(ds, dim='patch_no')

    #Save mean/variance for use in scaling 
    mean = np.array([np.nanmean(ds[v]) for v in ds.variables if 'severe' not in v])
    var = np.array([np.nanvar(ds[v]) for v in ds.variables if 'severe' not in v])
    #with open(f'/work/samuel.varga/data/2to6_hr_severe_wx/DEEP_LEARNING/scaling/rot_{rot_num}_scaling.pkl', 'wb') as scale_file:
    #    pickle.dump({'mean':mean,'var':var}, scale_file)
    
    #ds.to_netcdf(join(out_path, f'wofs_dl_severe__2to6hr__rot_{rot_num}__training_data'))
    ds.close()
    
    print(f'Saving metadata for Rot {rot_num}')
    meta_ds = [xr.open_dataset(f) for f in training_meta_paths]
    meta_ds = format_metadata(meta_ds)
    #meta_ds.to_netcdf(join(out_path, f'wofs_dl_severe__2to6hr__rot_{rot_num}__training_meta'))
    meta_ds.close()
    
    #Create validation data
    print(f'Saving validation data for Rot {rot_num}')
    ds = [xr.open_dataset(f) for f in validation_file_paths]
    ds = xr.concat(ds, dim='patch_no')
    #ds.to_netcdf(join(out_path, f'wofs_dl_severe__2to6hr__rot_{rot_num}__validation_data'))
    ds.close()
    
    print(f'Saving metadata for Rot {rot_num}')
    meta_ds = [xr.open_dataset(f) for f in validation_meta_paths]
    meta_ds = format_metadata(meta_ds)
    #meta_ds.to_netcdf(join(out_path, f'wofs_dl_severe__2to6hr__rot_{rot_num}__validation_meta'))
    meta_ds.close()
                          
    return None

In [12]:
#Save training folds:
for i, (train_ind, val_ind) in enumerate(train_folds):
    d = save_rotation_nc(i, train_ind, val_ind, train_dates, paths, dates)

Appending Filenames
Saving training data for Rot 0
Saving metadata for Rot 0
Saving validation data for Rot 0
Saving metadata for Rot 0
Appending Filenames
Saving training data for Rot 1
Saving metadata for Rot 1
Saving validation data for Rot 1
Saving metadata for Rot 1
Appending Filenames
Saving training data for Rot 2
Saving metadata for Rot 2
Saving validation data for Rot 2
Saving metadata for Rot 2
Appending Filenames
Saving training data for Rot 3
Saving metadata for Rot 3
Saving validation data for Rot 3
Saving metadata for Rot 3
Appending Filenames
Saving training data for Rot 4
Saving metadata for Rot 4
Saving validation data for Rot 4
Saving metadata for Rot 4


In [13]:
#Save testing set
testing_paths=list(np.array(paths)[np.isin(np.array([date.strftime('%Y%m%d') for date in dates]), test_dates)])
testing_file_paths = [join(path, file_base) for path in testing_paths[:10]]
testing_meta_paths=[join(path, meta_file_base) for path in testing_paths[:10]]


print(f'Saving testing data')
ds = [xr.open_dataset(f) for f in testing_file_paths]
ds = xr.concat(ds, dim='patch_no')
#ds.to_netcdf(join(out_path, f'wofs_dl_severe__2to6hr__testing_data'))
ds.close()
    
print(f'Saving testing metadata')
meta_ds = [xr.open_dataset(f) for f in testing_meta_paths]
meta_ds = format_metadata(meta_ds)
#meta_ds.to_netcdf(join(out_path, f'wofs_dl_severe__2to6hr__testing_meta'))
meta_ds.close()

Saving testing data
Saving testing metadata
