In [None]:
# ===============================
# 0. Libraries
# ===============================
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt

In [None]:
# ===============================
# 1. Generate complex synthetic time series
# ===============================
def create_complex_time_series(n_samples=2000, seq_len=50, n_features=3, seed=42):
    """Generate a synthetic multivariate series and its sliding-window dataset.

    Each feature column is the sum of:
      * a sine wave plus a half-frequency cosine (frequency drawn from [0.01, 0.1),
        amplitude from [0.5, 2.0); the cosine uses half that amplitude),
      * a linear upward trend (slope drawn from [0.0005, 0.002)),
      * a 0.5-amplitude sinusoid whose integer period is drawn from [30, 100),
      * Gaussian noise with standard deviation 0.2,
      * positive spikes (heights in [1, 3)) on int(1% of the length) time steps.

    Args:
        n_samples: number of (window, target) pairs to build; must be >= 0.
        seq_len: length of each input window; must be >= 1.
        n_features: number of feature columns; must be >= 1.
        seed: seed passed to ``np.random.seed``. Note this reseeds NumPy's
            *global* random state as a side effect.

    Returns:
        X: float array (n_samples, seq_len, n_features); ``X[i] == series[i:i + seq_len]``.
        y: float array (n_samples,); ``y[i] == series[i + seq_len, 0]``, i.e. the
            value of feature 0 immediately after window ``X[i]``.
        series: the full generated series, shape (n_samples + seq_len, n_features).

    Raises:
        ValueError: if a size argument is out of range.
    """
    if n_samples < 0:
        raise ValueError(f"n_samples must be >= 0, got {n_samples}")
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")
    if n_features < 1:
        raise ValueError(f"n_features must be >= 1, got {n_features}")

    np.random.seed(seed)
    t = np.arange(n_samples + seq_len)
    n_steps = len(t)

    columns = []
    for _ in range(n_features):
        # The order of the np.random calls below determines the generated values
        # for a given seed; keep it unchanged to stay reproducible.
        freq = np.random.uniform(0.01, 0.1)
        amp = np.random.uniform(0.5, 2.0)
        wave = amp * np.sin(2 * np.pi * freq * t) + amp / 2 * np.cos(2 * np.pi * freq * 0.5 * t)

        # Linear trend
        trend = t * np.random.uniform(0.0005, 0.002)

        # Seasonal component with a random integer period
        season = 0.5 * np.sin(2 * np.pi * t / np.random.randint(30, 100))

        # Gaussian noise
        noise = np.random.normal(0, 0.2, n_steps)

        # Occasional positive spikes at distinct random positions
        spikes = np.zeros(n_steps)
        spike_idx = np.random.choice(n_steps, size=int(0.01 * n_steps), replace=False)
        spikes[spike_idx] = np.random.uniform(1, 3, len(spike_idx))

        columns.append(wave + trend + season + noise + spikes)

    series = np.stack(columns, axis=-1)  # (n_samples + seq_len, n_features)

    # Vectorised sliding windows: row i of window_idx is [i, i+1, ..., i+seq_len-1].
    # Fancy indexing returns a copy and keeps shape (0, seq_len, n_features) when
    # n_samples == 0 (a list-of-windows build would collapse that to (0,)).
    window_idx = np.arange(n_samples)[:, None] + np.arange(seq_len)[None, :]
    X = series[window_idx]
    # Target: feature 0 one step after each window; copied so y is not a view of series.
    y = series[seq_len:seq_len + n_samples, 0].copy()
    return X, y, series

# Dataset configuration: tune these values here rather than inline.
SEQ_LEN = 50          # input window length (time steps)
N_FEATURES = 3        # number of generated feature columns
N_SAMPLES = 5000      # number of (window, target) pairs
TRAIN_FRACTION = 0.8  # first 80% of windows for training, the rest for validation

X, y, true_series = create_complex_time_series(
    n_samples=N_SAMPLES, seq_len=SEQ_LEN, n_features=N_FEATURES
)

# Chronological (unshuffled) split: since y[i] is the value right after window i,
# every validation target lies later in time than every training target.
train_size = int(TRAIN_FRACTION * len(X))
X_train, X_val = X[:train_size], X[train_size:]
y_train, y_val = y[:train_size], y[train_size:]
assert len(X_train) + len(X_val) == len(X) and len(y_val) == len(X_val)

In [None]:
# ===============================
# 2. Plotting the time series
# ===============================
# Feature 0 is the prediction target. It includes the noise and spikes, so it is
# the observed series rather than a clean underlying signal. The dashed line
# marks where validation targets begin: y_val[0] == true_series[train_size + SEQ_LEN, 0].
fig, ax = plt.subplots(figsize=(14, 5))
ax.plot(true_series[:, 0], label="Target series (feature 0)")
ax.axvline(train_size + SEQ_LEN, color="gray", linestyle="--", label="Train / validation split")
ax.set_xlabel("Time step")
ax.set_ylabel("Value")
ax.set_title("True Continuous Time Series")
ax.legend()
plt.show()

In [None]:
# ===============================
# 3. GRU + multi-head self-attention model
# ===============================
# Seed Python, NumPy and TensorFlow RNGs so weight initialisation and batch
# shuffling repeat on Restart & Run All (GPU kernels may still be
# nondeterministic). Data generation above is seeded separately.
MODEL_SEED = 42
tf.keras.utils.set_random_seed(MODEL_SEED)

# GRU encoder -> multi-head self-attention over its per-step outputs ->
# mean over time -> small MLP head -> one regression output (next value of feature 0).
# NOTE(review): inputs and targets are not scaled, and the series has a positive
# trend with a chronological split, so validation targets may fall partly outside
# the training range; consider normalising with train-split statistics.
input_layer = layers.Input(shape=(SEQ_LEN, N_FEATURES))
x = layers.GRU(64, return_sequences=True)(input_layer)        # one 64-d vector per time step
x = layers.MultiHeadAttention(num_heads=2, key_dim=32)(x, x)  # self-attention: query = value = GRU outputs
x = layers.GlobalAveragePooling1D()(x)                        # average over the time axis
x = layers.Dense(64, activation='relu')(x)
x = layers.Dropout(0.3)(x)
output_layer = layers.Dense(1)(x)

# Model assembly
gru_model = models.Model(inputs=input_layer, outputs=output_layer)
gru_model.compile(optimizer='adam', loss='mse', metrics=['mae'])
gru_model.summary()

# Training: fixed 30 epochs, validated on the chronologically later split.
gru_history = gru_model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=30,
    batch_size=32,
    verbose=1,
)

In [None]:
# ===============================
# 4. Plot Predictions
# ===============================
def plot_predictions(model, X, y_true, n=200):
    """Plot one-step-ahead predictions against the true targets for the first n samples.

    Args:
        model: trained model exposing a Keras-style ``predict`` method.
        X: input windows; only ``X[:n]`` is predicted.
        y_true: targets aligned with ``X`` (``y_true[i]`` is the value after ``X[i]``).
        n: number of leading samples to plot.
    """
    # verbose=0 keeps the notebook output free of a progress bar; ravel turns the
    # (n, 1) model output into a 1-D series.
    y_pred = np.asarray(model.predict(X[:n], verbose=0)).ravel()

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(y_true[:n], label='True')
    ax.plot(y_pred, label='Pred')
    ax.set_xlabel('Sample index')
    ax.set_ylabel('Value')
    ax.set_title('One-step-ahead predictions')
    ax.legend()
    plt.show()

# One-step-ahead predictions for the first 200 validation windows.
plot_predictions(gru_model, X_val, y_val, n=200)

In [None]:
# ===============================
# 5. Recursive Multi-step Forecasting
# ===============================
def recursive_forecast(model, X_init, n_future=100, fill="last"):
    """Forecast feature 0 ``n_future`` steps ahead by feeding predictions back in.

    The model predicts only feature 0, so the other features of each new time step
    have to be filled in:

    * ``fill="last"`` (default): repeat the most recent values of the other
      features (naive persistence). The window then stays close to the inputs
      the model was trained on.
    * ``fill="zero"``: set them to 0. After ``seq_len`` steps those features are
      all zero, an input pattern absent from the training data.

    Args:
        model: callable Keras-style model mapping an array of shape
            (1, seq_len, n_features) to an output indexable as ``[0, 0]``.
        X_init: initial window, shape (seq_len, n_features). Not modified.
        n_future: number of steps to forecast.
        fill: ``"last"`` or ``"zero"``, see above.

    Returns:
        1-D array with the ``n_future`` predicted feature-0 values.

    Raises:
        ValueError: if ``fill`` is unknown or ``X_init`` is not 2-D.
    """
    if fill not in ("last", "zero"):
        raise ValueError(f"fill must be 'last' or 'zero', got {fill!r}")
    # Work on a float copy so the caller's window is never mutated and predicted
    # values are not truncated if X_init happens to be an integer array.
    window = np.array(X_init, dtype=np.float64)
    if window.ndim != 2:
        raise ValueError(f"X_init must have shape (seq_len, n_features), got {window.shape}")

    predictions = []
    for _ in range(n_future):
        # Calling the model directly avoids model.predict's per-call overhead and
        # progress bar, which matter when it is invoked once per forecast step.
        pred = np.asarray(model(window[np.newaxis, :, :], training=False))[0, 0]
        predictions.append(pred)

        # Build the next time step: predicted target plus filled-in other features.
        if fill == "last":
            new_step = window[-1].copy()
        else:
            new_step = np.zeros(window.shape[1])
        new_step[0] = pred
        # Slide the window forward by one step.
        window = np.vstack([window[1:], new_step])

    return np.array(predictions)

# Seed the recursive forecast with the final validation window.
n_future = 20
X_start = X_val[-1]
future_gru = recursive_forecast(gru_model, X_start, n_future=n_future)

In [None]:
# ===============================
# 6. Plot Future Forecasts
# ===============================
# The forecast was seeded with X_val[-1], whose following value is y_val[-1].
# Its first step therefore lines up with the LAST plotted true value, not one
# step after it.
recent_true = y_val[-200:]
n_recent = len(recent_true)
forecast_x = np.arange(n_recent - 1, n_recent - 1 + len(future_gru))

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(np.arange(n_recent), recent_true, label='Recent True Values', color='blue')
ax.plot(forecast_x, future_gru, label='GRU Forecast', color='orange')
ax.set_xlabel('Time step')
ax.set_ylabel('Value')
ax.set_title('Multi-step Future Forecasting')
ax.legend()
plt.show()

In [None]:
# ===============================
# 7. Plot Predictions + Future Forecasts
# ===============================
def plot_predictions_and_forecast(model, X_val, y_val, n_past=200, n_future=100, label="Model"):
    """Plot one-step predictions on the tail of the validation set plus a recursive forecast.

    Args:
        model: trained Keras-style model (used with ``predict`` and by ``recursive_forecast``).
        X_val: validation windows; ``X_val[i]`` is followed by ``y_val[i]``.
        y_val: validation targets aligned with ``X_val``.
        n_past: number of trailing validation points to show; must be >= 1.
        n_future: number of recursive forecast steps.
        label: model name used in the legend and title.

    Raises:
        ValueError: if ``n_past < 1`` (``y_val[-0:]`` would silently select everything).
    """
    if n_past < 1:
        raise ValueError(f"n_past must be >= 1, got {n_past}")

    y_true_past = y_val[-n_past:]
    n_shown = len(y_true_past)
    past_x = np.arange(n_shown)

    # One-step-ahead predictions, aligned index-for-index with y_true_past.
    y_pred_past = np.asarray(model.predict(X_val[-n_past:], verbose=0)).ravel()

    # Recursive forecast from the last validation window. Its first step estimates
    # y_val[-1] (the value right after X_val[-1]), so the forecast axis starts at
    # the last past index rather than one step later.
    future_pred = recursive_forecast(model, X_val[-1], n_future)
    future_x = np.arange(n_shown - 1, n_shown - 1 + len(future_pred))

    fig, ax = plt.subplots(figsize=(14, 5))
    ax.plot(past_x, y_true_past, label='True (Past)', color='blue')
    ax.plot(past_x, y_pred_past, label=f'Predicted (Past) - {label}', color='orange')
    ax.plot(future_x, future_pred, label=f'Forecast (Future) - {label}', color='green')

    ax.set_xlabel('Time step')
    ax.set_ylabel('Value')
    ax.set_title(f'{label} - Past Predictions & Future Forecast')
    ax.legend()
    plt.show()

# GRU: last 200 validation points plus a 100-step recursive forecast.
plot_predictions_and_forecast(
    gru_model,
    X_val,
    y_val,
    label="GRU",
    n_past=200,
    n_future=100,
)

In [None]:
# ===============================
# 8. Plot Predictions + Long Future Forecasts
# ===============================
# Longer horizon: 400 past validation points and 400 recursive forecast steps.
plot_predictions_and_forecast(
    model=gru_model,
    X_val=X_val,
    y_val=y_val,
    n_past=400,
    n_future=400,
    label="GRU",
)