In [4]:
%%capture
!pip install plotnine

# https://github.com/vcerqueira/blog/tree/main/posts/dl_for_forecasting
Pada peramalan lstm disini menghasilkan beberapa variabel output dan multistep 
denga multivariate input multivariate output multistep

import pandas as pd
from plotnine import *

from src.tde import (time_delay_embedding,
                     from_3d_to_matrix,
                     from_matrix_to_3d)

# https://github.com/vcerqueira/blog/tree/main/data
data = pd.read_csv('https://raw.githubusercontent.com/vcerqueira/blog/refs/heads/main/data/wine_sales.csv', parse_dates=['date'])
data.set_index('date', inplace=True)

N_FEATURES = data.shape[1]
N_LAGS = 3
HORIZON = 2

plot_df = data.reset_index().melt('date')
plot_df['Type'] = 'Sales of different types of wine'
plot = \
    ggplot(plot_df) + \
    aes(x='date',
        y='np.log(value)',
        group='variable',
        color='variable') + \
    theme_538(base_family='Palatino', base_size=12) + \
    theme(plot_margin=0.2,
          axis_text=element_text(size=10),
          axis_text_x=element_text(angle=0, size=8),
          legend_title=element_blank(),
          legend_position='right')

plot += geom_line()
plot += facet_wrap('~ Type')
plot = \
    plot + \
    xlab('') + \
    ylab('Wine Sales (Log)') + \
    ggtitle('')

# print(plot)

# plot.save('mv_line_plot.pdf', height=5, width=8)


# transforming each variable into a matrix format
mat_by_variable = []
for col in data:
    col_df = time_delay_embedding(data[col], n_lags=N_LAGS, horizon=HORIZON)
    mat_by_variable.append(col_df)

# concatenating all variables
mat_df = pd.concat(mat_by_variable, axis=1).dropna()

# target_var = 'Sparkling'
# defining target (Y) and explanatory variables (X)
predictor_variables = mat_df.columns.str.contains('\(t\-|\(t\)')
# target_variables = mat_df.columns.str.contains(f'{target_var}\(t\+')
target_variables = mat_df.columns.str.contains('\(t\+')
X = mat_df.iloc[:, predictor_variables]
Y = mat_df.iloc[:, target_variables]

X_3d = from_matrix_to_3d(X)
Y_3d = from_matrix_to_3d(Y)

# Defining the LSTM ##################################################

from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import (Dense,
                          LSTM,
                          TimeDistributed,
                          RepeatVector,
                          Dropout)

model = Sequential()
model.add(LSTM(8, activation='relu', input_shape=(N_LAGS, N_FEATURES)))
model.add(Dropout(.2))
model.add(RepeatVector(HORIZON))
model.add(LSTM(4, activation='relu', return_sequences=True))
model.add(Dropout(.2))
model.add(TimeDistributed(Dense(N_FEATURES)))
model.compile(optimizer='adam', loss='mse')
model.summary()

######################################################################


X_train, X_valid, Y_train, Y_valid = train_test_split(X_3d, Y_3d, test_size=.2, shuffle=False)

model.fit(X_train, Y_train, epochs=500, validation_data=(X_valid, Y_valid))

preds = model.predict_on_batch(X_valid)

preds_df = from_3d_to_matrix(preds, Y.columns)
print(preds_df)
# preds_df = pd.DataFrame(preds, columns=Y.columns)



In [17]:
preds_df

Unnamed: 0,Fortified(t+1),Fortified(t+2),Drywhite(t+1),Drywhite(t+2),Sweetwhite(t+1),Sweetwhite(t+2),Red(t+1),Red(t+2),Rose(t+1),Rose(t+2),Sparkling(t+1),Sparkling(t+2)
0,2401.414795,2401.414795,2328.205078,2328.205078,1136.50415,1136.50415,75.574501,75.574501,1894.0271,1894.0271,195.698273,195.698273
1,2181.877197,2253.247803,2080.595947,2145.443115,1037.787842,1072.176758,66.954521,69.002708,1719.902466,1776.054932,180.017685,186.101227
2,1941.666504,1958.115601,1857.908447,1872.854126,922.696045,930.621826,59.847382,60.319435,1530.782227,1543.723999,159.821686,161.223785
3,1913.172363,1913.172363,1883.008789,1883.008789,901.482849,901.482849,61.486816,61.486816,1509.844971,1509.844971,154.180664,154.180664
4,1147.769531,1264.058716,1084.009766,1189.670044,547.881836,603.914307,34.531132,37.868393,904.637695,996.131042,95.462318,105.374664
5,2276.905762,2276.905762,2223.77832,2223.77832,1075.233887,1075.233887,72.420685,72.420685,1796.319458,1796.319458,184.537094,184.537094
6,2279.586182,2279.586182,2221.862061,2221.862061,1077.164185,1077.164185,72.287903,72.287903,1798.302002,1798.302002,185.039734,185.039734
7,2455.333252,2455.333252,2384.833496,2384.833496,1161.363037,1161.363037,77.489914,77.489914,1936.670654,1936.670654,199.813049,199.813049
8,2512.288818,2512.288818,2440.405762,2440.405762,1188.245239,1188.245239,79.30864,79.30864,1981.592773,1981.592773,204.427032,204.427032
9,2706.15918,2706.15918,2643.043701,2643.043701,1277.771729,1277.771729,86.14872,86.14872,2134.894287,2134.894287,219.283707,219.283707
