# Load necessary packages

In [1]:
import pandas as pd
import numpy as np

hex_salmon = '#F68F83'
hex_gold = '#BC9661'
hex_indigo = '#2D2E5F'
hex_maroon = '#8C4750'
hex_white = '#FAFAFA'
hex_blue = '#7EB5D2'

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.dates import DateFormatter
import matplotlib.dates as dates

import matplotlib.font_manager as font_manager
mpl.font_manager._rebuild()

mpl.rcParams['font.family'] = 'SF Mono'
mpl.rcParams['font.weight'] = 'medium'
mpl.rcParams['axes.titleweight'] = 'semibold'
mpl.rcParams['axes.labelweight'] = 'medium'
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=[hex_indigo, hex_salmon, hex_maroon])
mpl.rcParams["figure.titlesize"] = 'large'
mpl.rcParams["figure.titleweight"] = 'semibold'

from termcolor import colored

from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso, LogisticRegression, Ridge, ElasticNet, LassoCV, RidgeCV, ElasticNetCV
from sklearn.feature_selection import SelectFromModel
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import roc_auc_score, accuracy_score

import tensorflow as tf

# Organise data

## Import features

In [2]:
! pip install 'git+git://github.com/HR/github-clone#egg=ghclone' &> /dev/null

! ghclone https://github.com/timovijn/ElectricityPriceForecasting/tree/master/LSTM

zsh:1: command not found: ghclone


In [3]:
features = pd.read_pickle(f"./features.pkl")

display(features)

Unnamed: 0,ID3,VOL,MCP,LOAD,LOAD_F,LOAD_FE,ID3 (-4),ID3 (-5),ID3 (-6),ID3 (-7),...,HOD 14,HOD 15,HOD 16,HOD 17,HOD 18,HOD 19,HOD 20,HOD 21,HOD 22,HOD 23
2015-01-08 01:00:00+00:00,22.953776,439.5,32.32,9008.00,8505.25,502.75,29.934792,61.666667,61.118812,61.370370,...,0,0,0,0,0,0,0,0,0,0
2015-01-08 02:00:00+00:00,23.168355,261.5,31.10,8889.25,8222.25,667.00,29.853669,29.934792,61.666667,61.118812,...,0,0,0,0,0,0,0,0,0,0
2015-01-08 03:00:00+00:00,21.000000,420.5,30.17,8929.25,8122.25,807.00,24.012378,29.853669,29.934792,61.666667,...,0,0,0,0,0,0,0,0,0,0
2015-01-08 04:00:00+00:00,30.000000,460.6,24.54,9423.75,8323.50,1100.25,23.269810,24.012378,29.853669,29.934792,...,0,0,0,0,0,0,0,0,0,0
2015-01-08 05:00:00+00:00,30.000000,250.0,32.00,10884.50,9015.00,1869.50,22.953776,23.269810,24.012378,29.853669,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2018-12-30 14:00:00+00:00,53.790740,446.6,46.19,13842.50,15329.25,1486.75,76.370821,87.755884,78.709213,52.958116,...,1,0,0,0,0,0,0,0,0,0
2018-12-30 15:00:00+00:00,59.477646,131.6,47.64,14319.25,15644.50,1325.25,63.690401,76.370821,87.755884,78.709213,...,0,1,0,0,0,0,0,0,0,0
2018-12-30 16:00:00+00:00,59.883829,310.1,55.94,15120.75,16285.75,1165.00,56.170316,63.690401,76.370821,87.755884,...,0,0,1,0,0,0,0,0,0,0
2018-12-30 17:00:00+00:00,59.471501,220.9,58.40,14728.75,15555.75,827.00,51.675229,56.170316,63.690401,76.370821,...,0,0,0,1,0,0,0,0,0,0


## Select features

In [4]:
X = features[['ID3', 'LOAD']]
y = features[['ID3']]

lag_X = range(-72, -3, 1)
lag_y = range(0, 3, 1)

X2 = pd.DataFrame(index = X.index, columns = pd.MultiIndex.from_product([['X'], lag_X, X.columns], names = ['Feature', 'Type', 'Lag']))
X2 = X2.rename_axis('Timestamp')

y2 = pd.DataFrame(index = y.index, columns = pd.MultiIndex.from_product([['y'], lag_y, y.columns], names = ['Feature', 'Type', 'Lag']))
y2 = y2.rename_axis('Timestamp')

frame = pd.merge(y2, X2, left_index = True, right_index = True)

X3 = pd.DataFrame(index = X.index)
y3 = pd.DataFrame(index = y.index)

for c in X.columns:
    for l in lag_X:
        X3[f'{c} ({l})'] = X[f'{c}'].shift(-l)

frame['X'] = X3.values

for c in y.columns:
    for l in lag_y:
        y3[f'{c} ({l})'] = y.shift(-l)

frame['y'] = y3.values

frame = frame.dropna()

display(frame)

Feature,y,y,y,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X
Type,0,1,2,-72,-72,-71,-71,-70,-70,-69,...,-8,-8,-7,-7,-6,-6,-5,-5,-4,-4
Lag,ID3,ID3,ID3,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,...,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,LOAD
Timestamp,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3
2015-01-11 03:00:00+00:00,24.922597,26.238903,27.002718,22.953776,23.168355,21.000000,30.000000,30.000000,43.153846,43.588694,...,12552.00,12653.50,13797.25,13601.25,13058.25,12296.75,11478.00,10813.25,10076.75,9351.00
2015-01-11 04:00:00+00:00,26.238903,27.002718,30.000000,23.168355,21.000000,30.000000,30.000000,43.153846,43.588694,43.537764,...,12653.50,13797.25,13601.25,13058.25,12296.75,11478.00,10813.25,10076.75,9351.00,8681.75
2015-01-11 05:00:00+00:00,27.002718,30.000000,35.000000,21.000000,30.000000,30.000000,43.153846,43.588694,43.537764,48.252186,...,13797.25,13601.25,13058.25,12296.75,11478.00,10813.25,10076.75,9351.00,8681.75,8249.50
2015-01-11 06:00:00+00:00,30.000000,35.000000,35.000000,30.000000,30.000000,43.153846,43.588694,43.537764,48.252186,48.683607,...,13601.25,13058.25,12296.75,11478.00,10813.25,10076.75,9351.00,8681.75,8249.50,8019.75
2015-01-11 08:00:00+00:00,35.000000,35.000000,35.000000,30.000000,43.153846,43.588694,43.537764,48.252186,48.683607,46.580903,...,13058.25,12296.75,11478.00,10813.25,10076.75,9351.00,8681.75,8249.50,8019.75,7929.75
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2018-12-30 12:00:00+00:00,56.170316,51.675229,53.790740,67.046715,68.057669,69.568080,73.739284,76.254406,67.158844,61.551377,...,10629.00,10633.25,10287.00,10155.75,10121.25,10276.00,10588.50,11064.00,11786.75,12507.50
2018-12-30 13:00:00+00:00,51.675229,53.790740,59.477646,68.057669,69.568080,73.739284,76.254406,67.158844,61.551377,57.997148,...,10633.25,10287.00,10155.75,10121.25,10276.00,10588.50,11064.00,11786.75,12507.50,13035.50
2018-12-30 14:00:00+00:00,53.790740,59.477646,59.883829,69.568080,73.739284,76.254406,67.158844,61.551377,57.997148,62.701536,...,10287.00,10155.75,10121.25,10276.00,10588.50,11064.00,11786.75,12507.50,13035.50,13448.25
2018-12-30 15:00:00+00:00,59.477646,59.883829,59.471501,73.739284,76.254406,67.158844,61.551377,57.997148,62.701536,60.599674,...,10155.75,10121.25,10276.00,10588.50,11064.00,11786.75,12507.50,13035.50,13448.25,13715.50


## Split train and test

In [5]:
X_train, X_test, y_train, y_test = train_test_split(
    frame['X'],
    frame['y'],
    test_size = 0.3,
    random_state = 0,
    shuffle = False)

X_train.columns = pd.MultiIndex.from_product([['X'], lag_X, X.columns], names = ['Lag', 'Feature', 'Lag'])
y_train.columns = pd.MultiIndex.from_product([['y'], lag_y, y.columns], names = ['Lag', 'Feature', 'Lag'])

frame_train = pd.merge(y_train, X_train, left_index = True, right_index = True)

X_test.columns = pd.MultiIndex.from_product([['X'], lag_X, X.columns], names = ['Lag', 'Feature', 'Lag'])
y_test.columns = pd.MultiIndex.from_product([['y'], lag_y, y.columns], names = ['Lag', 'Feature', 'Lag'])

frame_test = pd.merge(y_test, X_test, left_index = True, right_index = True)

print()
print(f'Train input', frame_train['X'].shape, 'output', frame_train['y'].shape)
print()
print(f'Test input', frame_test['X'].shape, 'output', frame_test['y'].shape)
print()

display(frame_train)


Train input (23639, 138) output (23639, 3)

Test input (10132, 138) output (10132, 3)



Lag,y,y,y,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X
Feature,0,1,2,-72,-72,-71,-71,-70,-70,-69,...,-8,-8,-7,-7,-6,-6,-5,-5,-4,-4
Lag,ID3,ID3,ID3,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,...,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,LOAD
Timestamp,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3
2015-01-11 03:00:00+00:00,24.922597,26.238903,27.002718,22.953776,23.168355,21.000000,30.000000,30.000000,43.153846,43.588694,...,12552.00,12653.50,13797.25,13601.25,13058.25,12296.75,11478.00,10813.25,10076.75,9351.00
2015-01-11 04:00:00+00:00,26.238903,27.002718,30.000000,23.168355,21.000000,30.000000,30.000000,43.153846,43.588694,43.537764,...,12653.50,13797.25,13601.25,13058.25,12296.75,11478.00,10813.25,10076.75,9351.00,8681.75
2015-01-11 05:00:00+00:00,27.002718,30.000000,35.000000,21.000000,30.000000,30.000000,43.153846,43.588694,43.537764,48.252186,...,13797.25,13601.25,13058.25,12296.75,11478.00,10813.25,10076.75,9351.00,8681.75,8249.50
2015-01-11 06:00:00+00:00,30.000000,35.000000,35.000000,30.000000,30.000000,43.153846,43.588694,43.537764,48.252186,48.683607,...,13601.25,13058.25,12296.75,11478.00,10813.25,10076.75,9351.00,8681.75,8249.50,8019.75
2015-01-11 08:00:00+00:00,35.000000,35.000000,35.000000,30.000000,43.153846,43.588694,43.537764,48.252186,48.683607,46.580903,...,13058.25,12296.75,11478.00,10813.25,10076.75,9351.00,8681.75,8249.50,8019.75,7929.75
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2017-10-29 07:00:00+00:00,15.897972,28.726407,42.253968,47.569562,50.342117,53.282098,51.046393,47.998597,45.133267,40.789998,...,13412.75,12756.50,12091.75,11434.00,10751.00,10082.50,9754.25,9540.25,9401.50,9437.75
2017-10-29 08:00:00+00:00,28.726407,42.253968,39.981959,50.342117,53.282098,51.046393,47.998597,45.133267,40.789998,38.877500,...,12756.50,12091.75,11434.00,10751.00,10082.50,9754.25,9540.25,9401.50,9437.75,9595.25
2017-10-29 09:00:00+00:00,42.253968,39.981959,37.668182,53.282098,51.046393,47.998597,45.133267,40.789998,38.877500,40.095831,...,12091.75,11434.00,10751.00,10082.50,9754.25,9540.25,9401.50,9437.75,9595.25,9990.00
2017-10-29 10:00:00+00:00,39.981959,37.668182,39.994760,51.046393,47.998597,45.133267,40.789998,38.877500,40.095831,58.901518,...,11434.00,10751.00,10082.50,9754.25,9540.25,9401.50,9437.75,9595.25,9990.00,10511.75


## Scaling

In [6]:
frame_train_unscaled = frame_train
frame_test_unscaled = frame_test

y_scaler = StandardScaler()
y_scaler.fit(frame_train['y'])

frame_train['y'] = y_scaler.transform(frame_train['y'])
frame_test['y'] = y_scaler.transform(frame_test['y'])

X_scaler = StandardScaler()
X_scaler.fit(frame_train['X'])

frame_train['X'] = X_scaler.transform(frame_train['X'])
frame_test['X'] = X_scaler.transform(frame_test['X'])

display(frame_train)

display(frame_test)

Lag,y,y,y,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X
Feature,0,1,2,-72,-72,-71,-71,-70,-70,-69,...,-8,-8,-7,-7,-6,-6,-5,-5,-4,-4
Lag,ID3,ID3,ID3,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,...,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,LOAD
Timestamp,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3
2015-01-11 03:00:00+00:00,-0.883980,-0.798055,-0.748213,-1.015689,-1.001737,-1.143602,-0.555128,-0.555159,0.304953,0.333353,...,0.097109,0.138955,0.610075,0.529413,0.305839,-0.007720,-0.344854,-0.618579,-0.921875,-1.220787
2015-01-11 04:00:00+00:00,-0.798000,-0.748162,-0.552428,-1.001658,-1.143530,-0.555063,-0.555128,0.305043,0.333390,0.330022,...,0.138915,0.610038,0.529349,0.305774,-0.007783,-0.344911,-0.618616,-0.921884,-1.220751,-1.496400
2015-01-11 05:00:00+00:00,-0.748109,-0.552380,-0.225824,-1.143448,-0.555003,-0.555063,0.305069,0.333480,0.330059,0.638308,...,0.610006,0.529311,0.305706,-0.007856,-0.344983,-0.618679,-0.921925,-1.220761,-1.496360,-1.674410
2015-01-11 06:00:00+00:00,-0.552329,-0.225780,-0.225824,-0.554931,-0.555003,0.305110,0.333506,0.330149,0.638351,0.666520,...,0.529277,0.305662,-0.007931,-0.345065,-0.618759,-0.921997,-1.220808,-1.496371,-1.674368,-1.769026
2015-01-11 08:00:00+00:00,-0.225734,-0.225780,-0.225824,-0.554931,0.305152,0.333546,0.330175,0.638451,0.666563,0.529019,...,0.305625,-0.007981,-0.345146,-0.618848,-0.922085,-1.220887,-1.496422,-1.674379,-1.768983,-1.806090
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2017-10-29 07:00:00+00:00,-1.473459,-0.635571,0.248012,0.593955,0.775207,0.967429,0.821202,0.621867,0.434394,0.150340,...,0.451637,0.181379,-0.092363,-0.363187,-0.644397,-0.919629,-1.054739,-1.142824,-1.199955,-1.185062
2017-10-29 08:00:00+00:00,-0.635519,0.248050,0.099602,0.775254,0.967458,0.821229,0.621892,0.434488,0.150373,0.025278,...,0.181339,-0.092416,-0.363268,-0.644486,-0.919717,-1.054814,-1.142870,-1.199964,-1.185026,-1.120199
2017-10-29 09:00:00+00:00,0.248088,0.099642,-0.051536,0.967502,0.821261,0.621924,0.434513,0.150458,0.025308,0.104947,...,-0.092460,-0.363327,-0.644573,-0.919813,-1.054906,-1.142947,-1.200010,-1.185036,-1.120165,-0.957632
2017-10-29 10:00:00+00:00,0.099682,-0.051494,0.100439,0.821307,0.621960,0.434551,0.150485,0.025389,0.104979,1.334690,...,-0.363376,-0.644638,-0.919905,-1.055005,-1.143042,-1.200089,-1.185082,-1.120174,-0.957600,-0.742764


Lag,y,y,y,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X
Feature,0,1,2,-72,-72,-71,-71,-70,-70,-69,...,-8,-8,-7,-7,-6,-6,-5,-5,-4,-4
Lag,ID3,ID3,ID3,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,...,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,LOAD
Timestamp,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3
2017-10-29 12:00:00+00:00,0.100518,0.172282,0.432890,0.434644,0.150576,0.025466,0.105090,1.334869,1.021697,0.842035,...,-0.920036,-1.055175,-1.143240,-1.200288,-1.185256,-1.120296,-0.957651,-0.742743,-0.463831,-0.247444
2017-10-29 13:00:00+00:00,0.172321,0.432925,0.781620,0.150634,0.025514,0.105136,1.334889,1.021809,0.842082,0.255591,...,-1.055237,-1.143317,-1.200386,-1.185358,-1.120390,-0.957723,-0.742781,-0.463838,-0.247421,-0.171257
2017-10-29 14:00:00+00:00,0.432960,0.781651,1.650408,0.025575,0.105183,1.334901,1.021831,0.842188,0.255626,-0.341605,...,-1.143380,-1.200464,-1.185456,-1.120491,-0.957813,-0.742848,-0.463872,-0.247428,-0.171235,-0.144077
2017-10-29 15:00:00+00:00,0.781680,1.650428,1.734947,0.105242,1.334922,1.021852,0.842211,0.255714,-0.341581,-0.631051,...,-1.200529,-1.185534,-1.120587,-0.957910,-0.742931,-0.463931,-0.247458,-0.171241,-0.144055,-0.123383
2017-10-29 16:00:00+00:00,1.650443,1.734966,1.153214,1.334960,1.021879,0.842237,0.255740,-0.341512,-0.631034,-0.633566,...,-1.185598,-1.120664,-0.958003,-0.743023,-0.464007,-0.247511,-0.171270,-0.144061,-0.123362,-0.172287
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2018-12-30 12:00:00+00:00,1.157089,0.863447,1.001604,1.867580,1.933660,2.032422,2.305205,2.469669,1.874722,1.507970,...,-0.694942,-0.693137,-0.835679,-0.889644,-0.903758,-0.839938,-0.711173,-0.515316,-0.217668,0.079131
2018-12-30 13:00:00+00:00,0.863475,1.001632,1.373078,1.933687,2.032429,2.305191,2.469682,1.874861,1.508031,1.275552,...,-0.693192,-0.835748,-0.889736,-0.903853,-0.840025,-0.711239,-0.515350,-0.217674,0.079150,0.296574
2018-12-30 14:00:00+00:00,1.001658,1.373101,1.399610,2.032454,2.305192,2.469663,1.874877,1.508158,1.275608,1.583182,...,-0.835806,-0.889807,-0.903945,-0.840118,-0.711322,-0.515411,-0.217703,0.079145,0.296589,0.466553
2018-12-30 15:00:00+00:00,1.373120,1.399633,1.372677,2.305212,2.469660,1.874875,1.508176,1.275727,1.583244,1.445736,...,-0.889866,-0.904017,-0.840209,-0.711413,-0.515488,-0.217756,0.079120,0.296585,0.466567,0.576613


In [7]:
X_train = list()

for index, row in frame_train['X'].iterrows():
    X_train.extend(row.tolist())

X_train = np.array(X_train)

X_train = X_train.reshape((len(frame_train), len(lag_X), len(X.columns)))

X_test = list()

for index, row in frame_test['X'].iterrows():
    X_test.extend(row.tolist())

X_test = np.array(X_test)

X_test = X_test.reshape((len(frame_test), len(lag_X), len(X.columns)))

In [8]:
# step1 = []
# step2 = []
# step3 = []

# for index, row in frame_train.iterrows():
#     step2 = []
#     for l in lag_X:
#         step1 = []
#         for c in X.columns:
#             step1.append(row['X'][f'{c}'][l])
#         step2.append(step1)
#     step3.append(step2)

# X_train = step3

# X_train = np.array(X_train)

In [9]:
# step1 = []
# step2 = []
# step3 = []

# for index, row in frame_test.iterrows():
#     step2 = []
#     for l in lag_X:
#         step1 = []
#         for c in X.columns:
#             step1.append(row['X'][f'{c}'][l])
#         step2.append(step1)
#     step3.append(step2)

# X_test = step3

# X_test = np.array(X_test)

# Learning

## Create model

In [10]:
LATENT_DIM = 5
BATCH_SIZE = 32
EPOCHS = 10

In [11]:
from keras.models import Sequential
from keras.layers import Dense, GRU
from keras.optimizers import SGD, Adam
from keras.utils.vis_utils import plot_model
from keras.layers import GRU, Dense, RepeatVector, TimeDistributed, Flatten

In [12]:
T = len(lag_X)
HORIZON = len(lag_y)

model = Sequential()

model.add(GRU(LATENT_DIM, input_shape = (T, len(X.columns))))

model.add(RepeatVector(HORIZON))

model.add(GRU(LATENT_DIM, return_sequences = True))

model.add(TimeDistributed(Dense(1)))

model.add(Flatten())

In [13]:
model.compile(optimizer = 'Adam', loss = 'mse')

model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
gru (GRU)                    (None, 5)                 135       
_________________________________________________________________
repeat_vector (RepeatVector) (None, 3, 5)              0         
_________________________________________________________________
gru_1 (GRU)                  (None, 3, 5)              180       
_________________________________________________________________
time_distributed (TimeDistri (None, 3, 1)              6         
_________________________________________________________________
flatten (Flatten)            (None, 3)                 0         
Total params: 321
Trainable params: 321
Non-trainable params: 0
_________________________________________________________________


In [14]:
model.fit(X_train,
          np.array(frame_train['y']),
          batch_size=BATCH_SIZE,
          epochs=EPOCHS,
        #   validation_data=(valid_inputs['X'], valid_inputs['target']),
        #   callbacks=[earlystop],
          verbose=1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7feffe1643d0>

In [15]:
predictions = pd.DataFrame(index = frame_test.index, columns = pd.MultiIndex.from_product([['Prediction'], y.columns, lag_y], names = ['Type', 'Feature', 'Lag']))

frame_test = pd.merge(frame_test, predictions, left_index = True, right_index = True)

frame_test['Prediction'] = model.predict(X_test)

frame_test['Prediction'] = y_scaler.inverse_transform(frame_test['Prediction'])

frame_test['y'] = y_scaler.inverse_transform(frame_test['y'])

display(frame_test)

Lag,y,y,y,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,Prediction,Prediction,Prediction
Feature,0,1,2,-72,-72,-71,-71,-70,-70,-69,...,-7,-6,-6,-5,-5,-4,-4,ID3,ID3,ID3
Lag,ID3,ID3,ID3,ID3,LOAD,ID3,LOAD,ID3,LOAD,ID3,...,LOAD,ID3,LOAD,ID3,LOAD,ID3,LOAD,0,1,2
Timestamp,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3
2017-10-29 12:00:00+00:00,39.994760,41.094024,45.084266,0.434644,0.150576,0.025466,0.105090,1.334869,1.021697,0.842035,...,-1.200288,-1.185256,-1.120296,-0.957651,-0.742743,-0.463831,-0.247444,32.864662,32.401371,32.235504
2017-10-29 13:00:00+00:00,41.094024,45.084266,50.422995,0.150634,0.025514,0.105136,1.334889,1.021809,0.842082,0.255591,...,-1.185358,-1.120390,-0.957723,-0.742781,-0.463838,-0.247421,-0.171257,30.855253,30.592043,30.599752
2017-10-29 14:00:00+00:00,45.084266,50.422995,63.723309,0.025575,0.105183,1.334901,1.021831,0.842188,0.255626,-0.341605,...,-1.120491,-0.957813,-0.742848,-0.463872,-0.247428,-0.171235,-0.144077,30.943216,30.578112,30.544212
2017-10-29 15:00:00+00:00,50.422995,63.723309,65.017521,0.105242,1.334922,1.021852,0.842211,0.255714,-0.341581,-0.631051,...,-0.957910,-0.742931,-0.463931,-0.247458,-0.171241,-0.144055,-0.123383,31.268072,30.687077,30.543612
2017-10-29 16:00:00+00:00,63.723309,65.017521,56.111738,1.334960,1.021879,0.842237,0.255740,-0.341512,-0.631034,-0.633566,...,-0.743023,-0.464007,-0.247511,-0.171270,-0.144061,-0.123362,-0.172287,30.656311,30.076653,29.983606
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2018-12-30 12:00:00+00:00,56.170316,51.675229,53.790740,1.867580,1.933660,2.032422,2.305205,2.469669,1.874722,1.507970,...,-0.889644,-0.903758,-0.839938,-0.711173,-0.515316,-0.217668,0.079131,38.546581,37.636208,37.153088
2018-12-30 13:00:00+00:00,51.675229,53.790740,59.477646,1.933687,2.032429,2.305191,2.469682,1.874861,1.508031,1.275552,...,-0.903853,-0.840025,-0.711239,-0.515350,-0.217674,0.079150,0.296574,38.680786,37.802254,37.293896
2018-12-30 14:00:00+00:00,53.790740,59.477646,59.883829,2.032454,2.305192,2.469663,1.874877,1.508158,1.275608,1.583182,...,-0.840118,-0.711322,-0.515411,-0.217703,0.079145,0.296589,0.466553,37.534847,36.551907,36.015961
2018-12-30 15:00:00+00:00,59.477646,59.883829,59.471501,2.305212,2.469660,1.874875,1.508176,1.275727,1.583244,1.445736,...,-0.711413,-0.515488,-0.217756,0.079120,0.296585,0.466567,0.576613,38.583588,37.485489,36.834534


# Results

## Plot prediction

In [16]:
import plotly.express as px
import plotly.graph_objects as go

fig1 = go.Scatter(      x = frame_test.index,
                        y = frame_test['y'][0]['ID3'],
                        name = 'Actual',
                        # color = hex_maroon
                        # title = "Log of Appliance Energy Consumption in Wh vs Time"
                    )

fig2 = go.Scatter(      x = frame_test.index,
                        y = frame_test['Prediction']['ID3'][0],
                        name = 'Predicted',
                        # color = hex_gold
                        # title = "Log of Appliance Energy Consumption in Wh vs Time"
                    )

data = [fig1, fig2]

fig = go.Figure(data = data)

fig.update_layout(      title = 'Forecast of test set',
                        xaxis_title = 'Timestamp',
                        yaxis_title = 'ID3 (€)')

fig.show()

## Metrics

In [17]:
def smape(A, F):
    return 100/len(A) * np.sum(2 * np.abs(F - A) / (np.abs(A) + np.abs(F)))

print(smape(frame_test['y'][0]['ID3'], frame_test['Prediction']['ID3'][0]))
print(smape(frame_test['y'][1]['ID3'], frame_test['Prediction']['ID3'][1]))
print(smape(frame_test['y'][2]['ID3'], frame_test['Prediction']['ID3'][2]))

39.97994890163044
41.39233838824775
41.41276997047666


In [18]:
from sklearn.metrics import mean_absolute_error

display(mean_absolute_error(frame_test['y'][0]['ID3'], frame_test['Prediction']['ID3'][0]))

18.7331313607741