In [1]:
# general tools
import os
import sys
from glob import glob

# data tools
import re
import time
import h5py
import random
import numpy as np
from random import shuffle
from datetime import datetime, timedelta

#tf.config.run_functions_eagerly(True)

sys.path.insert(0, '/glade/u/home/ksha/NCAR/')
sys.path.insert(0, '/glade/u/home/ksha/NCAR/libs/')

from namelist import *
import data_utils as du
import model_utils as mu


ModuleNotFoundError: No module named 'data_utils'

In [None]:
# ==================== #
# Run configuration: round indices, seed value, model name prefixes, variable count.
weights_round = 0   # round index formatted into the "load" model prefix
save_round = 1      # round index formatted into the "save" model prefix
seeds = 777
# Original author's note on this line was "#False" -- presumably the prefix can be
# set to False to skip loading weights; TODO confirm against the code that uses it.
model_prefix_load = 'RE2_peak2_base' + str(weights_round)
model_prefix_save = 'RE2_peak2_base' + str(save_round)
# Both names carry the same value (15).
L_vars = 15
N_vars = L_vars
# ==================== #

In [None]:
# ----------------------------------------------------- #
# Collect pos and neg batch filenames

vers = ['v3', 'v4x', 'v4'] # HRRR v3, v4x, v4
leads = [2, 3, 4, 5, 6, 20, 21, 22, 23]

# Both dicts are keyed by '<ver>_lead<lead>' and hold sorted lists of .npy paths.
filenames_pos = {}
filenames_neg = {}

# Identify and separate pos / neg batch files
for ver in vers:
    # The batch directory depends only on the version, so pick it once per version.
    # Anything that is neither 'v3' nor 'v4' falls back to the v4x directory.
    if ver == 'v3':
        path_ = path_batch_v3
    elif ver == 'v4':
        path_ = path_batch_v4
    else:
        path_ = path_batch_v4x

    for lead in leads:
        batch_key = '{}_lead{}'.format(ver, lead)

        # Positive files: any name containing "pos" that ends in "lead<lead>.npy".
        pos_files = glob("{}*pos*lead{}.npy".format(path_, lead))
        pos_files.sort()
        # Negative files: only names containing "neg_neg_neg".
        neg_files = glob("{}*neg_neg_neg*lead{}.npy".format(path_, lead))
        neg_files.sort()

        filenames_pos[batch_key] = pos_files
        filenames_neg[batch_key] = neg_files

        print('{}, lead{}, pos: {}, neg: {}'.format(ver, lead, len(pos_files), len(neg_files)))

In [None]:
# ----------------------------------------------------- #
# Separate train and valid from pos / neg batches
# Each dict is keyed by '<ver>_lead<lead>', mirroring filenames_pos / filenames_neg.
filenames_pos_train, filenames_pos_valid = {}, {}
filenames_neg_train, filenames_neg_valid = {}, {}

for ver in vers:
    for lead in leads:
        batch_key = '{}_lead{}'.format(ver, lead)

        # mu.name_extract returns a (train, valid) pair for a list of filenames.
        pos_train, pos_valid = mu.name_extract(filenames_pos[batch_key])
        neg_train, neg_valid = mu.name_extract(filenames_neg[batch_key])

        counts = (len(pos_train), len(pos_valid), len(neg_train), len(neg_valid))
        print('pos train: {} pos valid: {} neg train: {} neg valid {}'.format(*counts))

        filenames_pos_train[batch_key] = pos_train
        filenames_pos_valid[batch_key] = pos_valid
        filenames_neg_train[batch_key] = neg_train
        filenames_neg_valid[batch_key] = neg_valid

In [None]:
# ------------------------------------------------------------------ #
# Merge train/valid and pos/neg batch files from multiple lead times
# The flat lists keep the (version, lead) iteration order of `vers` x `leads`.
pos_train_all, neg_train_all = [], []
pos_valid_all, neg_valid_all = [], []

for ver in vers:
    for lead in leads:
        batch_key = '{}_lead{}'.format(ver, lead)
        pos_train_all.extend(filenames_pos_train[batch_key])
        neg_train_all.extend(filenames_neg_train[batch_key])
        pos_valid_all.extend(filenames_pos_valid[batch_key])
        neg_valid_all.extend(filenames_neg_valid[batch_key])

In [None]:
def neighbour_leads(lead):
    """Return the four-lead window around `lead`, wrapped once into 0..23.

    The window is [lead-2, lead-1, lead, lead+1]. An entry below 0 has 24
    added and is flagged -1; an entry above 23 has 24 subtracted and is
    flagged +1; all other entries are flagged 0.

    Returns
    -------
    (list, list)
        The wrapped window and the matching list of wrap flags.
    """
    window = []
    wrap_flags = []

    for hour in (lead - 2, lead - 1, lead, lead + 1):
        wrap = 0
        if hour < 0:
            hour = 24 + hour
            wrap = -1
        if hour > 23:
            hour = hour - 24
            wrap = +1
        window.append(hour)
        wrap_flags.append(wrap)

    return window, wrap_flags

In [None]:
def _shift_record_days(record, flag):
    """Shift `record` by one step along axis 0, in place, according to `flag`.

    flag == 0  : returned unchanged.
    flag == -1 : rows move one step towards higher indices; row 0 becomes NaN.
    flag == +1 : rows move one step towards lower indices; the last row becomes NaN.

    The NaN rows never count as events downstream because `NaN > 0` is False.
    """
    if flag == -1:
        # .copy() keeps the overlapping self-assignment safe.
        record[1:, ...] = record[:-1, ...].copy()
        record[0, ...] = np.nan
    elif flag == +1:
        record[:-1, ...] = record[1:, ...].copy()
        record[-1, ...] = np.nan
    return record


def _merge_window_labels(lead_window, flag_shift, file_pattern, dataset_name):
    """Build a 0/1 label array that is 1 wherever ANY lead in the window is > 0.

    Parameters
    ----------
    lead_window : sequence of leads, each formatted into `file_pattern`.
    flag_shift : sequence of -1 / 0 / +1 flags, one per lead (see `_shift_record_days`).
    file_pattern : file name under `save_dir_scratch` with one `{}` slot for the lead.
    dataset_name : name of the HDF5 dataset to read from each file.

    Returns
    -------
    float64 array of 0.0 / 1.0 values with the shape of the stored records.
    Assumes every file in the window stores an array of the same shape.
    """
    any_event = None

    for lead_temp, flag in zip(lead_window, flag_shift):
        with h5py.File(save_dir_scratch + file_pattern.format(lead_temp), 'r') as h5io:
            record_temp = h5io[dataset_name][...]

        record_temp = _shift_record_days(record_temp, flag)

        # NaN compares False against 0; silence the "invalid value" warning
        # that some NumPy versions emit for that comparison.
        with np.errstate(invalid='ignore'):
            has_event = record_temp > 0

        any_event = has_event if any_event is None else (any_event | has_event)

    return any_event.astype(np.float64)


# One entry per lead is appended to each tuple; every entry gets a new leading axis
# of length 1 so the tuples can later be concatenated along axis 0.
label_smooth_v3 = ()
label_smooth_v4x = ()
label_smooth_v4 = ()

for lead in leads:

    lead_window, flag_shift = neighbour_leads(lead)

    print('Collect HRRR v3 labels ...')

    record_v3 = _merge_window_labels(lead_window, flag_shift,
                                     'SPC_to_lead{}_72km_all.hdf', 'record_v3')
    label_smooth_v3 += (record_v3[None, ...],)

    print('... Done. Collect HRRR v4x labels ...')

    # v4x records live in their own file, unlike v3 / v4 which share the "_all" file.
    record_v4x = _merge_window_labels(lead_window, flag_shift,
                                      'SPC_to_lead{}_72km_v4x.hdf', 'record_v4x')
    label_smooth_v4x += (record_v4x[None, ...],)

    print('... Done. Collect HRRR v4 labels ...')

    record_v4 = _merge_window_labels(lead_window, flag_shift,
                                     'SPC_to_lead{}_72km_all.hdf', 'record_v4')
    label_smooth_v4 += (record_v4[None, ...],)

    print('... Done')

In [None]:
# label smoothing operations

In [None]:
# Join the per-lead label arrays of each HRRR version along axis 0
# (the length-1 leading axis that was added to every entry of label_smooth_*).
label_concat_v3, label_concat_v4x, label_concat_v4 = (
    np.concatenate(per_lead_labels, axis=0)
    for per_lead_labels in (label_smooth_v3, label_smooth_v4x, label_smooth_v4)
)

In [None]:
# TODO: unfinished placeholder. `...` is the Python Ellipsis object, not an array,
# so the later `label_final[day, indx, indy]` lookups raise TypeError until this is
# replaced with the real label array.
# NOTE(review): presumably this should be derived from label_concat_v3 / _v4x / _v4
# and be indexable as [day, indx, indy] -- confirm the intended layout.
label_final = ...

In [None]:
def filename_to_loc(filenames):
    """Parse grid indices and day index out of batch filenames.

    For every name, all runs of digits are extracted; counting from the end,
    the 2nd-to-last run is taken as indy, the 3rd-to-last as indx and the
    4th-to-last as day (the last run is ignored).

    Parameters
    ----------
    filenames : iterable of str

    Returns
    -------
    (indx, indy, day) : three NumPy arrays, one element per filename.
    """
    digit_runs = [re.findall(r'\d+', fname) for fname in filenames]

    indy_out = [int(runs[-2]) for runs in digit_runs]
    indx_out = [int(runs[-3]) for runs in digit_runs]
    day_out = [int(runs[-4]) for runs in digit_runs]

    return np.array(indx_out), np.array(indy_out), np.array(day_out)

In [None]:
# Parse (indx, indy, day) index arrays out of every merged filename list.
# Positive batches: train, then valid.
indx_pos_train, indy_pos_train, day_pos_train = filename_to_loc(pos_train_all)
indx_pos_valid, indy_pos_valid, day_pos_valid = filename_to_loc(pos_valid_all)

# Negative batches: train, then valid.
indx_neg_train, indy_neg_train, day_neg_train = filename_to_loc(neg_train_all)
indx_neg_valid, indy_neg_valid, day_neg_valid = filename_to_loc(neg_valid_all)

In [None]:
# Look up the label of every batch file by its (day, indx, indy) position.
# The three index arrays of each group have one entry per filename, so each
# result has one entry per file along its first axis.
# Positive batches: train, then valid.
y_pos_train = label_final[day_pos_train, indx_pos_train, indy_pos_train]
y_pos_valid = label_final[day_pos_valid, indx_pos_valid, indy_pos_valid]

# Negative batches: train, then valid.
y_neg_train = label_final[day_neg_train, indx_neg_train, indy_neg_train]
y_neg_valid = label_final[day_neg_valid, indx_neg_valid, indy_neg_valid]

In [None]:
# ----------------------------------------------------------------- #
# Load valid files for model training

# Subsample the validation set: every 130th negative and every 13th positive batch.
filename_valid = neg_valid_all[::130] + pos_valid_all[::13]

# Targets must be CONCATENATED in the same order as the filenames above.
# `+` on NumPy arrays is element-wise addition (or a broadcast error), not
# concatenation, so it would not yield one target per validation file.
VALID_target = np.concatenate((y_neg_valid[::130], y_pos_valid[::13]), axis=0)

L_valid = len(filename_valid)
assert len(VALID_target) == L_valid, 'validation targets and filenames are misaligned'
print('number of validation batches: {}'.format(L_valid))

# One 64 x 64 patch with L_vars channels per validation file.
VALID_input_64 = np.empty((L_valid, 64, 64, L_vars))

for i, name in enumerate(filename_valid):
    data = np.load(name)
    # Copy the selected channels of the stored batch into consecutive output channels.
    for k, c in enumerate(ind_pick_from_batch):
        VALID_input_64[i, ..., k] = data[..., c]