# LSTM v3

LSTM preceded by convolutional layers.

- Feed a sequence of full grids with one value per cell as input
- First process each input with one or several CNNs
- Then process the sequence with an LSTM
- Use batch normalization and dropout where appropriate

Later improvements:

- Smaller (i.e. finer) grid resolution, so that the CNNs can pick up more details
- Several features per cell instead of only one (precipitation)

**Note:** Keras already includes an implementation of a convolutional LSTM layer (`ConvLSTM2D`), see https://keras.io/layers/recurrent/#convlstm2d

## Dependencies

In [1]:
import plotly
import plotly.plotly as py
import plotly.graph_objs as go
import pandas as pd
 
import paths
from TRMM import TRMM
from ModelHelpers import ModelHelpers
from ModelTRMMv3 import ModelTRMMv3

# force autoreload of external modules on save
# (mode 2 of the autoreload extension reloads all modules before running a cell)
%load_ext autoreload
%autoreload 2
# render matplotlib figures inline in the cell output
%matplotlib inline
# enable plotly's offline notebook mode so that iplot() can render figures;
# connected=True loads plotly.js from the CDN instead of embedding it
plotly.offline.init_notebook_mode(connected=True)

Using TensorFlow backend.


## Data Wrangling

In [2]:
## training years and the single held-out test year
## (the end bound of range() is exclusive)
YEARS_TRAIN = range(1998, 2016)
YEARS_TEST = range(2016, 2017)

## all years that are loaded: from the first training year
## up to and including the last test year
YEARS = range(YEARS_TRAIN.start, YEARS_TEST.stop)

## months handed to the dataset loader as the pre-monsoon window
PRE_MONSOON = list(range(3, 6))

## template for the yearly prediction date
## ('{}' is presumably replaced by the year -- see generate_prediction_ts)
PREDICT_ON = '{}-05-11'

## should the data be aggregated => specify degrees
## otherwise use None to get the full 140x140 grid
AGGREGATION_RESOLUTION = 1.0

In [3]:
## --- Loading the dataset ---
onset_dates, onset_ts = ModelHelpers.load_onset_dates()
prediction_ts = ModelHelpers.generate_prediction_ts(PREDICT_ON, YEARS)


def filter_fun(df, year):
    """Filter one year's frame with that year's prediction timestamp.

    Looks up `prediction_ts[year]` and delegates to
    `ModelHelpers.filter_until`; passed to the loader as a callback.
    """
    cutoff = prediction_ts[year]
    return ModelHelpers.filter_until(df, cutoff)


## one entry per year (indexed as data_trmm[year] further below)
data_trmm = TRMM.load_dataset(
    YEARS,
    PRE_MONSOON,
    filter_fun=filter_fun,
    aggregation_resolution=AGGREGATION_RESOLUTION,
    bundled=False,
)

> Loading from cache...


In [4]:
# sanity check: show the last rows of the loaded onset-date table
onset_dates.tail()

Unnamed: 0,date,timestamp
106,2013-05-22T00:00:00+00:00,1369181000.0
107,2014-06-06T00:00:00+00:00,1402013000.0
108,2015-06-05T00:00:00+00:00,1433462000.0
109,2016-06-08T00:00:00+00:00,1465344000.0
110,2017-05-30T00:00:00+00:00,1496102000.0


In [5]:
# summary statistics for the 1998 frame; one column per timestamp
# (column labels look like epoch seconds -- TODO confirm)
data_trmm[1998].describe()

Unnamed: 0,888706800,888793200,888879600,888966000,889052400,889138800,889225200,889311600,889398000,889484400,...,894060000,894146400,894232800,894319200,894405600,894492000,894578400,894664800,894751200,894837600
count,1156.0,1156.0,1156.0,1156.0,1156.0,1156.0,1156.0,1156.0,1156.0,1156.0,...,1156.0,1156.0,1156.0,1156.0,1156.0,1156.0,1156.0,1156.0,1156.0,1156.0
mean,3.256022,6.992055,10.203929,5.802326,3.121004,2.658498,3.433042,1.105775,0.687662,1.354042,...,21.439245,13.763709,16.945637,24.21982,34.125647,35.610794,23.480407,30.55313,33.154218,78.702652
std,14.863695,26.598394,35.259176,20.076506,12.429782,9.792503,17.207129,6.844537,3.747201,8.637036,...,58.383859,39.741938,60.545809,66.353668,82.74785,80.647065,57.451589,90.352838,79.551429,252.047319
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.048,0.094898,0.09,0.73989,2.129305,2.722099,1.059,0.525,0.875747,0.784607
75%,0.156448,0.450719,3.258196,1.984283,0.065919,0.126729,0.0,0.0,0.0,0.0,...,9.976586,7.641604,6.828,12.72898,26.893797,32.368634,14.819206,13.35433,21.19119,18.982943
max,220.093811,495.556726,517.042692,316.860282,167.528366,110.581672,187.201502,108.649827,56.540738,159.564578,...,621.469453,473.365212,1071.545341,933.276034,1109.697169,668.348853,464.6895,1165.589954,605.189995,2514.74987


In [14]:
# pivot a single column of the 1998 frame into a latitude x longitude grid;
# 894837600 is the last column label listed in the describe() output for 1998
data_trmm[1998][894837600].unstack()

longitude,63.375,64.375,65.375,66.375,67.375,68.375,69.375,70.375,71.375,72.375,...,87.375,88.375,89.375,90.375,91.375,92.375,93.375,94.375,95.375,96.375
latitude,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
6.125,849.779974,791.189987,1094.339981,1221.839985,789.179977,998.549992,1213.949978,1139.429955,1147.979959,1385.279999,...,83.219999,49.589998,73.049998,155.909995,57.719998,143.489995,353.099993,109.38,139.547018,430.566095
7.125,323.249996,393.899986,540.839983,770.999981,666.809994,1214.519955,1519.559937,1219.169979,706.679993,1023.869984,...,275.969995,150.419999,294.569993,312.96,138.959993,4.89,9.93,21.39,38.789999,172.289993
8.125,126.989995,167.699995,145.079998,205.829996,224.249993,1315.109974,1808.27993,1454.460007,1226.309956,1435.859978,...,223.289993,136.799995,17.759999,92.399999,29.699999,1.56,0.48,0.3,0.0,14.61
9.125,10.86,5.97,17.189999,61.589997,76.799997,346.289991,584.639994,984.389977,1215.329983,1073.069984,...,103.410001,77.399998,18.09,3.27,0.57,0.15,1.77,0.0,0.0,7.08
10.125,1.71,3.18,8.43,8.34,49.859998,47.369998,239.789995,531.119982,396.09,1010.069961,...,8.97,30.029998,4.83,1.02,10.08,3.51,0.45,0.0,0.0,0.0
11.125,4.56,0.78,10.05,7.05,1.56,5.04,35.549999,113.639997,103.469996,180.689995,...,0.33,0.69,1.92,4.35,17.43,10.026772,2.100164,0.0,0.0,1.26
12.125,8.55,21.989999,7.59,2.31,0.0,0.18,2.28,8.97,20.789999,26.369999,...,1.23,18.84,12.18,0.72,26.819999,63.811042,1.285232,0.81,0.84,0.03
13.125,3.42,8.31,28.769999,1.83,0.0,0.0,0.0,0.0,0.51,2.58,...,1.53,6.09,6.51,0.42,11.13,1.38,0.0,0.0,11.91,1.92
14.125,0.3,2.88,8.76,9.66,0.15,0.0,0.0,0.0,0.0,0.0,...,0.99,0.0,0.0,0.09,0.0,0.0,0.0,0.0,0.0,0.0
15.125,0.0,0.09,1.53,1.71,0.0,0.0,0.0,0.0,0.0,0.0,...,0.87,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [6]:
# report the total number of null values for every loaded year
print('year', 'null_count')
for year in YEARS:
    yearly_frame = data_trmm[year]
    # first sum() counts nulls per column, second sum() adds the columns up
    print(year, yearly_frame.isnull().sum().sum())

year null_count
1998 0
1999 0
2000 0
2001 0
2002 0
2003 0
2004 0
2005 0
2006 0
2007 0
2008 0
2009 0
2010 0
2011 0
2012 0
2013 0
2014 0
2015 0
2016 0


In [7]:
## how many epochs for each tuning
EPOCHS = 100

## early stopping (enabled if >0):
## training will stop if val_loss increased for more than PATIENCE times in a row
PATIENCE = 30

## settings shared by every tuning
_COMMON = {'epochs': EPOCHS, 'patience': PATIENCE}

## per-tuning overrides; each one is merged on top of the shared settings,
## an empty dict means "shared settings only"
_OVERRIDES = [
    # --- baseline / batch size ---
    {},
    {'batch_size': 4},
    # --- dropout ---
    {'dropout': 0.7, 'dropout_conv': 0.5},
    {'dropout': 0.4, 'dropout_conv': 0.3},
    # --- optimizer / learning rate ---
    {'optimizer': 'rmsprop', 'learning_rate': 0.01},
    {'optimizer': 'rmsprop', 'learning_rate': 0.003},
    {'optimizer': 'adam'},
    {'optimizer': 'sgd'},
    # --- loss / batch normalization ---
    {'numerical_loss': 'mean_absolute_error'},
    {'batch_norm': False},
    # --- convolution filters / kernels ---
    {'num_filters': (64, 32, 16)},
    {'num_filters': (16, 16, 16), 'kernel_dims': (6, 5, 4)},
    {'num_filters': (8, 16, 32), 'kernel_dims': (3, 5, 7)},
    {'num_filters': (16, 16, 16), 'kernel_dims': (4, 5, 6)},
    # --- pooling ---
    {'pool_dims': (2, 2, 2, 0)},
    {'pool_dims': (0, 2, 0, 2)},
    # --- pooling combined with activations ---
    {'pool_dims': (2, 2, 2, 0), 'dense_activation': 'tanh'},
    # NOTE(review): exact duplicate of the previous entry -- intentional?
    {'pool_dims': (2, 2, 2, 0), 'dense_activation': 'tanh'},
    {'pool_dims': (2, 2, 2, 0), 'recurrent_activation': 'tanh'},
    {'pool_dims': (0, 0, 0, 3), 'recurrent_activation': 'tanh'},
    # NOTE(review): exact duplicate of the previous entry -- intentional?
    {'pool_dims': (0, 0, 0, 3), 'recurrent_activation': 'tanh'},
    {'pool_dims': (0, 0, 0, 3), 'recurrent_activation': None},
    {
        'pool_dims': (2, 2, 2, 0),
        'dense_activation': 'tanh',
        'recurrent_activation': None,
        'num_filters': (16, 8, 4),
        'kernel_dims': (6, 4, 2),
    },
    # --- recurrent dropout / dense layer size ---
    {'pool_dims': (2, 2, 2, 0), 'dense_activation': 'tanh', 'dropout_recurrent': 0.3},
    {'pool_dims': (2, 2, 2, 0), 'dense_activation': 'tanh', 'dense_nodes': 512},
]

## the configurations that get trained: a fresh dict per tuning, shared keys first
TUNINGS = [dict(_COMMON, **override) for override in _OVERRIDES]

In [8]:
# split the loaded data by year into training and test sets
X_train, y_train, X_test, y_test, unstacked = ModelTRMMv3.train_test_split(
    data_trmm,
    prediction_ts,
    onset_ts,
    years_train=YEARS_TRAIN,
    years_test=YEARS_TEST,
    numerical=True,
)

In [11]:
# run every configuration in TUNINGS on the training data
models_num = ModelHelpers.run_configs(
    ModelTRMMv3,
    TUNINGS,
    X_train,
    y_train,
    numerical=True,
    version='T3-num-1.0deg',
)

## TODO: evaluation of the models above
## every entry of models_num describes one tuning:
##   models_num[0]['model']          trained model class of the first tuning
##   models_num[0]['model'].model    its keras model instance
##   models_num[0]['model'].history  its training history
##   models_num[0]['config']         the parameters it was trained with
## then e.g. predict the test set and compare with y_test
first_model = models_num[0]['model']
print(first_model.model.predict(X_test))

> Training config {'epochs': 100, 'patience': 30}
ep-100_bat-1_opt-rmsprop_lr-None_dpt-0.6_dptR-None_dptC-0.4_ls-mean_squared_error_flt-(32, 16, 8)_krn-(7, 5, 3)_pool-(0, 0, 0, 4)_pad-same_pat-30_norm-True_dnsNod-1024_dnsAct-relu_recAct-relu_splt-0.1
> Training config {'epochs': 100, 'patience': 30, 'batch_size': 4}
ep-100_bat-4_opt-rmsprop_lr-None_dpt-0.6_dptR-None_dptC-0.4_ls-mean_squared_error_flt-(32, 16, 8)_krn-(7, 5, 3)_pool-(0, 0, 0, 4)_pad-same_pat-30_norm-True_dnsNod-1024_dnsAct-relu_recAct-relu_splt-0.1
> Training config {'epochs': 100, 'patience': 30, 'dropout': 0.7, 'dropout_conv': 0.5}
ep-100_bat-1_opt-rmsprop_lr-None_dpt-0.7_dptR-None_dptC-0.5_ls-mean_squared_error_flt-(32, 16, 8)_krn-(7, 5, 3)_pool-(0, 0, 0, 4)_pad-same_pat-30_norm-True_dnsNod-1024_dnsAct-relu_recAct-relu_splt-0.1
> Training config {'epochs': 100, 'patience': 30, 'dropout': 0.4, 'dropout_conv': 0.3}
ep-100_bat-1_opt-rmsprop_lr-None_dpt-0.4_dptR-None_dptC-0.3_ls-mean_squared_error_flt-(32, 16, 8)_krn-(7, 

KeyboardInterrupt: 

In [None]:
# collect the training and validation loss curves, one list entry per trained model
losses = [entry['model'].history['loss'] for entry in models_num]
val_losses = [entry['model'].history['val_loss'] for entry in models_num]

In [None]:
# Plot the training and validation loss curves of all tunings in one figure.
# The y axis is log-scaled so that curves of very different magnitude stay
# readable; title and axis labels let the figure stand on its own.
layout = go.Layout(
    title='Training vs. validation loss per tuning',
    # x is the position within each history list (presumably one entry per epoch)
    xaxis=dict(title='Epoch'),
    yaxis=dict(
        title='Loss (log scale)',
        type='log',
        autorange=True
    )
)

# one trace per tuning; '#i' is the index of the tuning in models_num
loss_traces = [
    go.Scatter(y=curve, mode='lines', name='Loss #{}'.format(i))
    for i, curve in enumerate(losses)
]
val_traces = [
    go.Scatter(y=curve, mode='lines', name='Val. #{}'.format(i))
    for i, curve in enumerate(val_losses)
]

# all training traces first, then all validation traces (legend order)
fig = go.Figure(data=loss_traces + val_traces, layout=layout)

plotly.offline.iplot(fig)