# Interactive Est Forecasting

This notebook provides interactive interfaces for using the trained models to view Est forecasts for the available input data.

**Not compatible with Safari, because the date picker relies on the HTML date input field.**

In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Let TensorFlow grow GPU memory on demand instead of reserving it all at start-up.
# Apply the setting to every visible GPU: indexing [0] raised IndexError on
# CPU-only machines, and TensorFlow rejects memory-growth settings that differ
# between GPU devices.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, enable=True)

import numpy as np
from sklearn import preprocessing
from scipy import stats

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib as mpl
from matplotlib.ticker import FormatStrFormatter
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and later
# releases dropped the old names; use whichever name this installation provides.
plt.style.use('seaborn-v0_8-talk' if 'seaborn-v0_8-talk' in plt.style.available else 'seaborn-talk')
mpl.rcParams['figure.figsize'] = [8, 8]

import pandas as pd
import tqdm

from bdl_tensorflow import *
from helper_functions import *
import itertools

import ipywidgets as widgets

In [18]:
# Raw hourly data sources, read in this order.
# NOTE(review): `flares`, `ext` and `dst` are loaded but not used to build
# `data` below; kept in case other cells rely on them.
omni, goes, flares, ext, cme, p10s = (
    pd.read_hdf(path)
    for path in (
        'data/omni_hourly_alldata_smallfilled.h5',
        'data/GOES_xrs_xhf_hourly_1986-2018.h5',
        'data/GOES_flare_locations_hourly_1975-2016.h5',
        'data/external_coefficients.h5',
        'data/cme_hourly_complete.h5',
        'data/dst_est_ist.h5',
    )
)

dst = omni['Dst (nT)']

# Most relevant solar-wind measurements from OMNI.
omni_use_cols = [
    'BX, nT (GSE, GSM)',
    'BY, nT (GSM)',
    'BZ, nT (GSM)',
    'SW Proton Density, N/cm^3',
    'SW Plasma Speed, km/s',
    'SW Plasma Temperature, K',
    'SW Plasma flow long. angle',
    'SW Plasma flow lat. angle',
]

# Both GOES X-ray channels (alternative: 'Xhf short channel', 'Xhf long channel').
goes_use_cols = [
    'short channel',
    'long channel',
]

# Flare data columns (not part of `data` below).
flares_use_cols = ['lat', 'lon', 'intensity']

# CME catalogue columns.
cme_use_cols = [
    'Central PA',
    'Width',
    'Linear Speed',
    '2nd order speed: initial',
    '2nd order speed: final',
    '2nd order speed: 20R',
    'Accel',
    'Mass',
    'Kinetic energy',
]

# p10 terms: only Est for now.
p10s_use_cols = ['est']

# First zonal external coefficient (not part of `data` below).
ext_use_cols = ['q10']

# Side-by-side join of the selected columns from each source.
data = pd.concat(
    [frame[cols] for frame, cols in [
        (omni, omni_use_cols),
        (goes, goes_use_cols),
        (p10s, p10s_use_cols),
        (cme, cme_use_cols),
    ]],
    axis=1,
)


In [3]:
# Model inputs: GOES X-ray, CME, solar-wind and the current Est value;
# model output: Est itself (out_cols is the same list object as p10s_use_cols).
in_cols = [*goes_use_cols, *cme_use_cols, *omni_use_cols, *p10s_use_cols]
out_cols = p10s_use_cols

data_in, data_out = data[in_cols].copy(), data[out_cols].copy()
input_dim, output_dim = data_in.shape[1], data_out.shape[1]

# number of rows in the joined frame and their timestamps
ndat, t = data.shape[0], data.index

In [4]:
# only run if using CME or x-ray data
data_in['short channel'] = np.log10(data_in['short channel']+1e-8)
data_in['long channel'] = np.log10(data_in['long channel']+1e-8)
data_in['Mass'] = np.log10(data_in['Mass'] + 1e13)
data_in['Kinetic energy'] = np.log10(data_in['Kinetic energy'] + 1e27)

In [5]:
# Largest forecast lead time in hours.
lahead_max = 6

# Drop the last `lahead_max` rows so each input row can be paired with a target
# `lahead` hours later (cf. the commented-out `data_out[lahead:]` below).
# An explicit end index is used because `[0:-lahead_max]` would silently
# return nothing if lahead_max were ever set to 0.
n_in = len(data_in) - lahead_max
data_in_arr = data_in.iloc[:n_in].values.astype(np.float32)
# data_out_arr = data_out[lahead:].values.astype(np.float32)

t_in = t[:n_in]

In [6]:
# normalize
# Rescale every input feature to [0, 1].
# NOTE(review): the scaler is fit on the inputs loaded in this notebook; the
# forecasts match training only if the models were trained with the same
# scaling. Confirm, or load the training-time scaler instead.
scaler_input = preprocessing.MinMaxScaler(feature_range=(0, 1))

# The LSTMs expect (samples, timesteps, features), with a single time step per sample.
data_in_scaled = np.reshape(scaler_input.fit_transform(data_in_arr), (-1, 1, input_dim))

In [7]:
# Layer widths. The model-building cell uses only lstm_dim_1, lstm_dim_2 and
# dense_dim_2; the others are kept for alternative architectures.
lstm_dim_1, lstm_dim_2, lstm_dim_3, lstm_dim_4 = 20, 10, 5, 10
dense_dim_1, dense_dim_2, dense_dim_3, dense_dim_4 = 5, 10, 5, 5

In [8]:
def _build_forecast_model():
    """Return an untrained copy of the probabilistic LSTM used for every lead time.

    The layer sizes come from the hyper-parameter cell and must match the
    architecture the checkpoints were saved from for ``load_weights`` to restore
    them. The last layer turns the two dense outputs into a Gaussian: output 0
    is the mean, and output 1 is mapped through softplus (plus a 1e-3 floor)
    to the standard deviation.
    """
    return tf.keras.models.Sequential([
        tf.keras.layers.LSTM(lstm_dim_1, return_sequences=True, input_shape=(1, input_dim)),
        tf.keras.layers.LSTM(lstm_dim_2, return_sequences=False),
        tf.keras.layers.Dense(dense_dim_2,
                              activation='tanh'),
        tf.keras.layers.Dense(2,
                              kernel_initializer=tf.keras.initializers.Constant(0),
                              bias_initializer=tf.keras.initializers.Constant([0, 20])),
        tfp.layers.DistributionLambda(
            lambda params: tfd.Normal(loc=params[..., 0:1],
                                      scale=1e-3 + tf.math.softplus(1.0 * params[..., 1:])))
    ])


# Load the trained models for the 1..lahead_max hour ahead forecasts; models[k]
# is the (k+1)-hour model. The checkpoint name appears to encode the layer sizes
# (L20L10D10 = lstm_dim_1, lstm_dim_2, dense_dim_2), so it is derived from the
# hyper-parameters instead of being hard-coded.
models = []
for lead_hours in range(1, lahead_max + 1):
    model = _build_forecast_model()
    model_name = 'Estout_XrayCMESWEst_Gaussian_mod_L%dL%dD%d_t%d' % (
        lstm_dim_1, lstm_dim_2, dense_dim_2, lead_hours)
    model_weights_path = 'models/' + model_name + '/cp.ckpt'
    model.load_weights(model_weights_path)
    models.append(model)

In [9]:
# One frozen scipy normal distribution per loaded model (i.e. per lead time),
# built from the predicted Gaussian location and scale for every input row.
# Iterating over `models` keeps this in sync with however many models were loaded.
data_out_preds = []
for forecast_model in models:
    post_model = forecast_model(data_in_scaled)

    # make distributions from learned parameters
    data_out_preds.append(stats.norm(loc=post_model.loc.numpy(), scale=post_model.scale.numpy()))

## Plot 1-6 hour forecast

For a selected date and hour, this widget plots the Est forecast from one to six hours ahead. The user can also specify how many hours of context (observed Est and the selected inputs) to show before the forecast, and which input data to plot alongside it.

In [91]:
def plot_forecast_1_6(date, hours, dat_in_plot, context):
    """Plot observed Est before a chosen origin time together with the 1-N hour forecasts.

    Uses the notebook globals ``data`` (joined raw frame), ``out_cols`` (Est
    column), ``t_in`` (times that have model inputs) and ``data_out_preds``
    (one frozen ``scipy.stats.norm`` per lead time, one row per entry of ``t_in``).

    Parameters
    ----------
    date : datetime.date or None
        Day picked in the date widget; ``None`` until the user selects one.
    hours : int
        Hour of day added to ``date`` to form the forecast origin ``t0``.
    dat_in_plot : sequence of str
        Columns of ``data`` to plot in extra panels over the context window.
    context : int
        Hours of observations (Est and selected inputs) shown before ``t0``;
        clipped to the history available in ``data``.
    """
    if date is None:
        print('Please select a date.')
        return

    # origin time and its positional index in the data
    t0 = pd.Timestamp(date) + pd.Timedelta(hours=int(hours))
    idx = np.flatnonzero(data.index == t0)
    if len(idx) == 0:
        print('Requested date not found. Please select another.')
        return
    i0 = int(idx[0])

    # predictions only exist for rows that have model inputs
    if i0 >= len(t_in):
        print('No forecast available for the requested date. Please select another.')
        return

    # if we have selected a very young date, only show the history that exists
    context = min(max(int(context), 0), i0)

    n_leads = len(data_out_preds)
    # NOTE(review): assumes the lead-k model applied to the inputs at t0 forecasts
    # Est at t0 + k hours (cf. the `data_out[lahead:]` alignment) - confirm.
    t_fc = t0 + pd.to_timedelta(np.arange(1, n_leads + 1), unit='h')
    fc_mean = np.array([np.ravel(pred.mean())[i0] for pred in data_out_preds])
    fc_std = np.array([np.ravel(pred.std())[i0] for pred in data_out_preds])

    n_panels = 1 + len(dat_in_plot)
    fig, axes = plt.subplots(n_panels, 1, sharex=True, squeeze=False,
                             figsize=(8, 3 * n_panels))
    axes = axes[:, 0]

    # observed Est over the context window and the forecast horizon (for comparison)
    est_obs = data[out_cols[0]].iloc[i0 - context:i0 + n_leads + 1]
    ax_est = axes[0]
    ax_est.plot(est_obs.index, est_obs.values, 'k.-', label='Observed')
    ax_est.plot(t_fc, fc_mean, 'o-', color='C1', label='Forecast mean')
    ax_est.fill_between(t_fc, fc_mean - fc_std, fc_mean + fc_std,
                        color='C1', alpha=0.3, label=r'Forecast $\pm1\sigma$')
    ax_est.set_ylabel('Est')
    ax_est.set_title('1-%d hour ahead forecast from %s' % (n_leads, t0))
    ax_est.legend(loc='best')

    # selected inputs over the context window, up to the forecast origin
    for ax, col in zip(axes[1:], dat_in_plot):
        series = data[col].iloc[i0 - context:i0 + 1]
        ax.plot(series.index, series.values)
        ax.set_ylabel(col)

    for ax in axes:
        ax.axvline(t0, color='gray', linestyle='--', linewidth=1)
    axes[-1].set_xlabel('Time')
    fig.autofmt_xdate()
    plt.show()

**TODO:** The remaining problem is finding the periods where all input datasets overlap. This will require reworking the data-preprocessing code that generated the training batches.

In [94]:
# Day of the forecast origin (None until the user picks one).
date_widget = widgets.DatePicker(
    description='Pick a Date',
    disabled=False
)

# Plain IntText ignores min/max, so use the bounded variant to keep the hour in 0-23.
hours_widget = widgets.BoundedIntText(value=0, min=0, max=23, step=1, description='Hour: ')

dat_in_widget = widgets.SelectMultiple(
                    options=list(data),
                    value=['SW Plasma Speed, km/s'],
                    rows=20,
                    description='Data to plot',
                    disabled=False
)

# Hours of observations shown before the origin. The previous max=23 contradicted
# the default of 96; allow up to 30 days of context.
context_widget = widgets.BoundedIntText(value=96, min=0, max=24 * 30, step=1, description='Hours before: ')

w = widgets.interact_manual(plot_forecast_1_6,
                            date=date_widget,
                            dat_in_plot=dat_in_widget,
                            hours=hours_widget,
                            context=context_widget)

interactive(children=(DatePicker(value=None, description='Pick a Date'), IntText(value=0, description='Hour: '…

## Plot k hour ahead forecast

This widget produces output closer to the figures in the manuscript, showing the k-hour ahead forecast for a user-selected k in [1, 6].