In [None]:
# ===============================
# 0. Libraries
# ===============================
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.metrics import MeanAbsoluteError
import keras
import matplotlib.pyplot as plt

In [None]:
# ===============================
# 1. Generate complex synthetic time series
# ===============================
def create_complex_time_series(n_samples=2000, seq_len=50, n_features=3, seed=42):
    """Generate a multivariate synthetic signal and cut it into supervised windows.

    Every feature is the sum of a two-tone sine/cosine wave, a linear trend,
    a seasonal sine, Gaussian noise and sparse positive spikes (on ~1% of the
    steps). All random parameters come from NumPy's global RNG, seeded with
    ``seed``, so the output is reproducible.

    Returns:
        X: (n_samples, seq_len, n_features) sliding input windows.
        y: (n_samples,) value of feature 0 right after each window.
        series: (n_samples + seq_len, n_features) the full generated signal.
    """
    np.random.seed(seed)
    t = np.arange(n_samples + seq_len)

    def _one_feature():
        # The order of the random draws below is part of the seeded output;
        # do not reorder them.
        freq = np.random.uniform(0.01, 0.1)
        amp = np.random.uniform(0.5, 2.0)
        wave = amp * np.sin(2 * np.pi * freq * t) + amp/2 * np.cos(2 * np.pi * freq*0.5 * t)
        trend = t * np.random.uniform(0.0005, 0.002)
        season = 0.5 * np.sin(2 * np.pi * t / np.random.randint(30, 100))
        noise = np.random.normal(0, 0.2, len(t))
        spikes = np.zeros(len(t))
        hit = np.random.choice(len(t), size=int(0.01*len(t)), replace=False)
        spikes[hit] = np.random.uniform(1, 3, len(hit))
        return wave + trend + season + noise + spikes

    series = np.stack([_one_feature() for _ in range(n_features)], axis=-1)

    # Window k covers steps [k, k + seq_len); its target is feature 0 at k + seq_len.
    X = np.array([series[start:start + seq_len] for start in range(n_samples)])
    y = series[seq_len:seq_len + n_samples, 0].copy()
    return X, y, series


SEQ_LEN = 50      # time steps per input window
N_FEATURES = 3    # synthetic channels; feature 0 is the forecast target
X, y, true_series = create_complex_time_series(
    n_samples=5000,
    seq_len=SEQ_LEN,
    n_features=N_FEATURES,
)

In [None]:
# ===============================
# 2. Plotting the time series
# ===============================
fig, ax = plt.subplots(figsize=(14, 5))
ax.plot(true_series[:, 0], label="True underlying target series")
ax.set_title("True Continuous Time Series")
ax.legend()
plt.show()

In [None]:
# =============================== 
# 3. Prepare Encoder–Decoder Data 
# ===============================
def build_encoder_decoder_dataset(X, y, dec_len=None):
    """Build encoder/decoder inputs and targets for the seq2seq Transformer.

    Args:
        X: array of shape (N, seq_len, n_features) holding the past windows.
        y: N target values (next value of feature 0 for each window).
        dec_len: length of the decoder input, in ``[1, seq_len]``. Defaults to
            ``seq_len``. The decoder input covers the last ``dec_len`` steps
            of the window (see ``X_dec`` below).

    Returns:
        X_enc: (N, seq_len, n_features) copy of ``X`` (encoder input).
        X_dec: (N, dec_len, n_features) teacher-forcing decoder input: an
            all-zero start token at position 0, followed by the window's last
            ``dec_len - 1`` steps *before* the final one (i.e. shifted right
            by one step). For ``dec_len == seq_len`` this is the whole window
            shifted right by one.
        Y: (N, 1) targets.

    Raises:
        ValueError: if ``dec_len`` is outside ``[1, seq_len]`` or ``y`` does
            not hold exactly N values.
    """
    N, seq_len, n_features = X.shape
    if dec_len is None:
        dec_len = seq_len
    if not 1 <= dec_len <= seq_len:
        raise ValueError(f"dec_len must be in [1, {seq_len}], got {dec_len}")
    y = np.asarray(y)
    if y.size != N:
        raise ValueError(f"expected {N} target values, got {y.size}")

    # Encoder input = original window
    X_enc = X.copy()  # (N, seq_len, n_features)

    # Decoder step j >= 1 sees window step (seq_len - dec_len + j - 1); step 0
    # is the zero start token. The window's last step is never fed to the
    # decoder, matching the one-step shift used during teacher forcing.
    X_dec = np.zeros((N, dec_len, n_features), dtype=X.dtype)
    X_dec[:, 1:, :] = X[:, seq_len - dec_len:seq_len - 1, :]

    # Target values (next value of feature 0)
    Y = y.reshape(-1, 1)  # (N, 1)
    return X_enc, X_dec, Y

# Build encoder–decoder dataset (note: a PRED_HORIZON = 20 setting is not used here; targets are one step ahead)
X_enc, X_dec, Y = build_encoder_decoder_dataset(X, y, dec_len=SEQ_LEN)

# Chronological 80/20 split: the first 80% of windows train, the rest validate.
train_size = int(0.8 * len(X))
X_enc_train, X_dec_train, Y_train = X_enc[:train_size], X_dec[:train_size], Y[:train_size]
X_enc_val, X_dec_val, Y_val = X_enc[train_size:], X_dec[train_size:], Y[train_size:]

In [None]:
# ===============================
# 4. Attention Variants
# ===============================
class MultiQueryAttention(layers.Layer):
    """
    Multi-Query Attention: every head has its own query projection, while a
    single key projection and a single value projection are shared by all heads.

    Call interface mirrors ``layers.MultiHeadAttention``: ``query`` is
    required, ``value`` defaults to ``query`` and ``key`` defaults to
    ``value``. The mask is passed as ``mask`` (not ``attention_mask``) and
    may be either
      * additive (float): 0 where attention is allowed, a large negative
        number where it is blocked, or
      * boolean: True where attention is allowed.
    It must broadcast to (B, num_heads, Tq, Tk), e.g. (B, 1, Tq, Tk).

    The output's last dimension is ``num_heads * key_dim``.
    """
    def __init__(self, num_heads, key_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.supports_masking = True

        # Query gets a separate projection for each head (stacked in one Dense)
        self.q_proj = layers.Dense(num_heads * key_dim)
        # Keys and Values: one projection shared across all heads
        self.k_proj = layers.Dense(key_dim)
        self.v_proj = layers.Dense(key_dim)
        # Final linear projection
        self.out = layers.Dense(num_heads * key_dim)

    def call(self, query, value=None, key=None, mask=None):
        """
        query: (B, Tq, d_model)
        value: (B, Tk, d_model) or None (-> query)
        key  : (B, Tk, d_model) or None (-> value)
        mask : optional additive float or boolean mask, see class docstring
        """
        if value is None:
            value = query
        if key is None:
            # Same default as layers.MultiHeadAttention. Defaulting to `query`
            # would make attn(x, enc_out) score decoder keys against encoder values.
            key = value

        B = tf.shape(query)[0]
        Tq = tf.shape(query)[1]

        # Linear projections
        q = self.q_proj(query)                  # (B, Tq, H*D)
        k = self.k_proj(key)                    # (B, Tk, D)
        v = self.v_proj(value)                  # (B, Tk, D)

        # Split Q into heads
        q = tf.reshape(q, (B, Tq, self.num_heads, self.key_dim))
        q = tf.transpose(q, [0, 2, 1, 3])       # (B, H, Tq, D)

        # Shared K/V: add a size-1 head axis so they broadcast over heads
        k = k[:, None, :, :]                    # (B, 1, Tk, D)
        v = v[:, None, :, :]                    # (B, 1, Tk, D)

        # Scaled dot-product attention
        scores = tf.matmul(q, k, transpose_b=True)  # (B, H, Tq, Tk)
        scores /= tf.sqrt(tf.cast(self.key_dim, scores.dtype))
        if mask is not None:
            mask = tf.convert_to_tensor(mask)
            if mask.dtype == tf.bool:
                # Boolean (True = attend) -> additive form
                mask = tf.where(mask, 0.0, -1e9)
            scores += tf.cast(mask, scores.dtype)

        # Softmax over the key axis, then weighted sum of values
        attn = tf.nn.softmax(scores, axis=-1)
        out = tf.matmul(attn, v)                # (B, H, Tq, D)

        # Merge heads back
        out = tf.transpose(out, [0, 2, 1, 3])   # (B, Tq, H, D)
        out = tf.reshape(out, (B, Tq, self.num_heads * self.key_dim))

        return self.out(out)

In [None]:
# ===============================
# 5. Transformer Encoder Block
# ===============================
class EncoderBlock(layers.Layer):
    """
    One post-norm Transformer encoder block.

    Consists of:
    1. Multi-Head (or Multi-Query) self-attention
    2. Add & LayerNorm
    3. Feed-forward network (Dense+ReLU+Dropout(0.1) per hidden width, then Dense(d_model))
    4. Add & LayerNorm

    Args:
        d_model: model width; attention key_dim is ``d_model // num_heads``.
        num_heads: number of attention heads.
        ff_layers: hidden widths of the feed-forward network (default ``[128]``).
        use_mqa: use ``MultiQueryAttention`` instead of ``layers.MultiHeadAttention``.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).
    """
    def __init__(self, d_model, num_heads, ff_layers=None, use_mqa=False, **kwargs):
        super().__init__(**kwargs)

        # Standard MHA or Multi-Query Attention
        key_dim = d_model // num_heads
        if use_mqa:
            self.attn = MultiQueryAttention(num_heads=num_heads, key_dim=key_dim)
        else:
            self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)

        # Feed-forward network
        if ff_layers is None:
            ff_layers = [128]
        ffn_list = []
        for ff_layer in ff_layers:
            ffn_list.append(layers.Dense(ff_layer, activation='relu'))
            ffn_list.append(layers.Dropout(0.1))
        ffn_list.append(layers.Dense(d_model))
        self.ffn = tf.keras.Sequential(ffn_list)

        # Two LayerNorms for stabilizing training
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

    @staticmethod
    def _to_keras_attention_mask(mask):
        """Convert a mask to the boolean form Keras MultiHeadAttention expects.

        Keras MHA treats ``attention_mask`` as boolean (1/True = may attend,
        0/False = blocked). An additive mask (0 = attend, large negative =
        blocked) passed directly would therefore be inverted, so non-boolean
        masks are interpreted as additive and mapped with ``mask == 0``.
        A (B, 1, T, S) mask is squeezed to (B, T, S), the shape MHA documents.
        """
        if mask is None:
            return None
        mask = tf.convert_to_tensor(mask)
        if mask.dtype != tf.bool:
            mask = tf.equal(mask, 0)
        if mask.shape.rank == 4 and mask.shape[1] == 1:
            mask = tf.squeeze(mask, axis=1)
        return mask

    def call(self, x, mask=None, training=False):
        """
        x: (batch, seq_len, d_model)
        mask: optional, broadcastable to (batch, 1, seq_len, seq_len). Passed
              unchanged to MultiQueryAttention; for MultiHeadAttention a float
              mask is treated as additive and converted to boolean.
        training: forwarded to MultiHeadAttention and to the FFN dropout.
        """
        # 1. SELF-ATTENTION
        if isinstance(self.attn, MultiQueryAttention):
            attn_out = self.attn(query=x, value=x, key=x, mask=mask)
        else:
            attn_out = self.attn(
                query=x, value=x, key=x,
                attention_mask=self._to_keras_attention_mask(mask),
                training=training,
            )
        # Residual connection + normalization
        x = self.norm1(x + attn_out)

        # 2. FEED FORWARD (dropout active only when training=True)
        ffn_output = self.ffn(x, training=training)
        # Residual + normalization
        return self.norm2(x + ffn_output)

In [None]:
# ===============================
# 6. Decoder Block (masked + cross-attention)
# ===============================
class DecoderBlock(layers.Layer):
    """
    One post-norm Transformer decoder block.

    Consists of three sublayers applied in sequence, each followed by Add & LayerNorm:
    1. Causally masked self-attention over the decoder input
    2. Encoder-decoder cross-attention: queries from sublayer 1's output,
       keys and values from the encoder output
    3. Feed-forward network (Dense+ReLU+Dropout(0.1) per hidden width, then Dense(d_model))

    Args:
        d_model: model width; attention key_dim is ``d_model // num_heads``.
        num_heads: number of attention heads.
        ff_layers: hidden widths of the feed-forward network (default ``[128]``).
        use_mqa: use ``MultiQueryAttention`` instead of ``layers.MultiHeadAttention``.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).
    """
    def __init__(self, d_model, num_heads, ff_layers=None, use_mqa=False, **kwargs):
        super().__init__(**kwargs)

        # Self-attention (causal mask applied) and cross-attention (decoder queries, encoder keys/values)
        key_dim = d_model // num_heads
        if use_mqa:
            self.self_attn = MultiQueryAttention(num_heads=num_heads, key_dim=key_dim)
            self.cross_attn = MultiQueryAttention(num_heads=num_heads, key_dim=key_dim)
        else:
            self.self_attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
            self.cross_attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)

        # Feed-forward network
        if ff_layers is None:
            ff_layers = [128]
        ffn_list = []
        for ff_layer in ff_layers:
            ffn_list.append(layers.Dense(ff_layer, activation='relu'))
            ffn_list.append(layers.Dropout(0.1))
        ffn_list.append(layers.Dense(d_model))
        self.ffn = tf.keras.Sequential(ffn_list)

        # LayerNorms (one per sublayer)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.norm3 = layers.LayerNormalization(epsilon=1e-6)

    @staticmethod
    def _to_keras_attention_mask(mask):
        """Convert a mask to the boolean form Keras MultiHeadAttention expects.

        Keras MHA treats ``attention_mask`` as boolean (1/True = may attend,
        0/False = blocked). An additive mask (0 = attend, large negative =
        blocked) passed directly would therefore be inverted, so non-boolean
        masks are interpreted as additive and mapped with ``mask == 0``.
        A (B, 1, T, S) mask is squeezed to (B, T, S), the shape MHA documents.
        """
        if mask is None:
            return None
        mask = tf.convert_to_tensor(mask)
        if mask.dtype != tf.bool:
            mask = tf.equal(mask, 0)
        if mask.shape.rank == 4 and mask.shape[1] == 1:
            mask = tf.squeeze(mask, axis=1)
        return mask

    def call(self, x, enc_out, causal_mask=None, training=False):
        """
        x          : decoder input (batch, dec_len, d_model)
        enc_out    : encoder output (batch, enc_len, d_model)
        causal_mask: prevents the decoder from looking into the future. Passed
                     unchanged to MultiQueryAttention; for MultiHeadAttention a
                     float mask is treated as additive and converted to boolean.
        training   : forwarded to MultiHeadAttention and to the FFN dropout.
        """
        use_mqa = isinstance(self.self_attn, MultiQueryAttention)

        # 1. Masked self-attention + Add & Norm
        if use_mqa:
            attn1 = self.self_attn(query=x, value=x, key=x, mask=causal_mask)
        else:
            attn1 = self.self_attn(
                query=x, value=x, key=x,
                attention_mask=self._to_keras_attention_mask(causal_mask),
                training=training,
            )
        x = self.norm1(x + attn1)

        # 2. Cross-attention on the self-attention output. `key` is passed
        #    explicitly so keys and values both come from the encoder and the
        #    key never falls back to the query.
        if use_mqa:
            attn2 = self.cross_attn(query=x, value=enc_out, key=enc_out)
        else:
            attn2 = self.cross_attn(query=x, value=enc_out, key=enc_out, training=training)
        x = self.norm2(x + attn2)

        # 3. Feed-forward + final normalization
        ffn_out = self.ffn(x, training=training)
        return self.norm3(x + ffn_out)

In [None]:
# ===============================
# 7. Positional Encoding
# ===============================
def positional_encoding(seq_len, d_model):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Channel i at position p uses angle p / 10000^(2*(i//2) / d_model);
    even channels take the sine of that angle, odd channels the cosine.

    Returns:
        tf.float32 tensor of shape (1, seq_len, d_model).
    """
    positions = np.arange(seq_len)[:, np.newaxis]      # (seq_len, 1)
    channels = np.arange(d_model)[np.newaxis, :]       # (1, d_model)
    pair_index = channels // 2
    angle_rates = 1 / np.power(10000, (2 * pair_index) / np.float32(d_model))
    angles = positions * angle_rates                   # (seq_len, d_model)
    is_even = (channels % 2 == 0)
    encoding = np.where(is_even, np.sin(angles), np.cos(angles))
    return tf.cast(encoding[np.newaxis, ...], tf.float32)

In [None]:
# ===============================
# 8. Full Encoder–Decoder Transformer
# ===============================
class TransformerTimeSeries(tf.keras.Model):
    """Encoder–decoder Transformer that predicts the next value of a series.

    The encoder embeds the past window and applies a stack of EncoderBlocks.
    The decoder embeds the teacher-forced decoder input and applies a stack of
    DecoderBlocks (causal self-attention + cross-attention to the encoder
    output). A Dense(1) head reads the last decoder position.

    Args:
        enc_len, dec_len: maximum encoder / decoder sequence lengths (size of
            the positional-encoding tables). Inputs may be shorter.
        d_model: embedding width.
        num_heads: attention heads per block.
        enc_ff_layers, dec_ff_layers: FFN hidden widths (None -> block default).
        num_enc_layers, num_dec_layers: number of stacked blocks.
        use_mqa: use MultiQueryAttention instead of MultiHeadAttention.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).

    Call input: ``(encoder_input, decoder_input)`` of shapes
    (B, T_enc, n_features) and (B, T_dec, n_features).
    Output: (B, 1, 1) prediction for the step after the window.
    """
    def __init__(self, enc_len, dec_len, d_model=64, num_heads=4,
                 enc_ff_layers=None, num_enc_layers=2,
                 dec_ff_layers=None, num_dec_layers=2,
                 use_mqa=False, **kwargs):
        super().__init__(**kwargs)

        # Embeddings
        self.enc_embedding = layers.Dense(d_model)
        self.dec_embedding = layers.Dense(d_model)

        # Positional encoding tables, each (1, max_len, d_model)
        self.enc_pos = positional_encoding(enc_len, d_model)
        self.dec_pos = positional_encoding(dec_len, d_model)

        # Encoder and decoder stacks
        self.encoder = [EncoderBlock(d_model, num_heads, enc_ff_layers, use_mqa) for _ in range(num_enc_layers)]
        self.decoder = [DecoderBlock(d_model, num_heads, dec_ff_layers, use_mqa) for _ in range(num_dec_layers)]

        self.out = layers.Dense(1)  # Final projection on the last step

    def causal_mask(self, seq_len):
        """
        Additive causal mask for decoder self-attention.

        Returns a float tensor of shape (1, 1, seq_len, seq_len) that is 0 where
        query position i may attend to key position j (j <= i) and -1e9 where
        j > i.
        """
        lower = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)  # 1 on/below diagonal
        mask = (1 - lower) * -1e9
        return mask[None, None, :, :]

    def call(self, inputs, training=False):
        encoder_input, decoder_input = inputs
        enc_steps = tf.shape(encoder_input)[1]
        dec_steps = tf.shape(decoder_input)[1]

        # ENCODER (positional table sliced to the actual input length)
        x = self.enc_embedding(encoder_input) + self.enc_pos[:, :enc_steps, :]
        for layer in self.encoder:
            x = layer(x, training=training)
        enc_out = x   # (B, T_enc, d_model)

        # DECODER
        y = self.dec_embedding(decoder_input) + self.dec_pos[:, :dec_steps, :]
        mask = self.causal_mask(dec_steps)
        for layer in self.decoder:
            y = layer(y, enc_out, causal_mask=mask, training=training)

        # Dense is position-wise, so projecting only the last step gives the
        # same (B, 1, 1) result as projecting every step and slicing.
        return self.out(y[:, -1:, :])

In [None]:
# ===============================
# 9. Create the model
# ===============================
model = TransformerTimeSeries(
    enc_len=SEQ_LEN,
    dec_len=SEQ_LEN,
    d_model=128,
    num_heads=8,            # per-head key_dim = 128 // 8 = 16
    enc_ff_layers=[512],    # encoder FFN hidden widths
    num_enc_layers=3,
    dec_ff_layers=[512],    # decoder FFN hidden widths
    num_dec_layers=3,
    use_mqa=True,           # MultiQueryAttention instead of MultiHeadAttention
)

In [None]:
# ===============================
# 10. Extra Metrics (for a more meaningful evaluation)
# ===============================
def _flatten_pair(y_true, y_pred):
    """Flatten targets/predictions to 1-D and give them a common dtype.

    The model outputs (B, 1, 1) while the targets are (B, 1); subtracting them
    directly would broadcast to (B, B, 1) and compare every prediction with
    every target in the batch.
    """
    y_pred = tf.reshape(y_pred, [-1])
    y_true = tf.cast(tf.reshape(y_true, [-1]), y_pred.dtype)
    return y_true, y_pred


# NOTE: Keras averages these per-batch values over batches, so the reported
# RMSE / R^2 are batch averages, not values computed on the whole dataset.
def rmse(y_true, y_pred):
    """Root mean squared error over the batch."""
    y_true, y_pred = _flatten_pair(y_true, y_pred)
    return tf.sqrt(tf.reduce_mean(tf.square(y_true - y_pred)))


def mape(y_true, y_pred, eps=1e-6):
    """Mean absolute percentage error (in %), with |y_true| floored at ``eps``.

    Flooring the absolute value avoids the division blow-up that
    ``y_true + eps`` still has for negative targets close to zero.
    """
    y_true, y_pred = _flatten_pair(y_true, y_pred)
    denom = tf.maximum(tf.abs(y_true), eps)
    return tf.reduce_mean(tf.abs(y_true - y_pred) / denom) * 100


def r2_score(y_true, y_pred):
    """Coefficient of determination 1 - SS_res / SS_tot over the batch."""
    y_true, y_pred = _flatten_pair(y_true, y_pred)
    ss_res = tf.reduce_sum(tf.square(y_true - y_pred))
    ss_tot = tf.reduce_sum(tf.square(y_true - tf.reduce_mean(y_true)))
    return 1 - ss_res / (ss_tot + 1e-8)

In [None]:
# ===============================
# 11. Learning Rate Scheduler: Cosine Decay with Warm Restarts
# ===============================
# Keras runs ceil(n_train / batch_size) optimizer steps per epoch (the last
# batch may be partial), so count steps the same way. 32 must match the
# batch_size passed to model.fit.
steps_per_epoch = -(-len(X_enc_train) // 32)
total_steps = steps_per_epoch * 20  # NOTE(review): not used below; model.fit runs for 10 epochs
lr_schedule = keras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate=1e-3,
    first_decay_steps=steps_per_epoch * 5,  # first cosine cycle lasts 5 epochs
    t_mul=1.5,    # each cycle is 1.5x longer than the previous one
    m_mul=0.8,    # each restart starts from 0.8x the previous cycle's peak
    alpha=0.05,   # minimum LR as a fraction of initial_learning_rate
)

In [None]:
# ===============================
# 12. Optimizer: AdamW with Learning-Rate Schedule and Weight Decay
# ===============================
optimizer = keras.optimizers.AdamW(
    weight_decay=1e-4,          # decoupled weight decay
    learning_rate=lr_schedule,  # cosine schedule with restarts
)

In [None]:
# ===============================
# 13. Compile model
# ===============================
model.compile(
    optimizer=optimizer,
    # Huber: quadratic for |error| <= delta, linear beyond, which limits the
    # influence of large errors such as the injected spikes.
    loss=tf.keras.losses.Huber(delta=1.0),
    metrics=[
        rmse,
        MeanAbsoluteError(),
        mape,
        r2_score,
    ],
)

In [None]:
# ===============================
# 14. Callbacks
# ===============================
callbacks = [
    EarlyStopping(
        monitor="val_loss",
        patience=6,                 # epochs without val_loss improvement before stopping
        restore_best_weights=True,  # roll back to the best epoch's weights
    ),
]

In [None]:
# ===============================
# 15. Training
# ===============================
history = model.fit(
    x=(X_enc_train, X_dec_train),
    y=Y_train,
    batch_size=32,
    epochs=10,  # upper bound; EarlyStopping may end training sooner
    validation_data=((X_enc_val, X_dec_val), Y_val),
    callbacks=callbacks,
)

In [None]:
# ===============================
# 16. Plot Predictions
# ===============================
X_train = (X_enc_train, X_dec_train)
X_val = (X_enc_val, X_dec_val)


def plot_predictions(model, X, y_true, n=500):
    """Plot the model's one-step predictions against the targets.

    Args:
        model: trained model with one output value per sample (here (B, 1, 1)).
        X: model input, either a single array or a tuple/list of arrays
            (here ``(X_enc, X_dec)``), all with the sample axis first.
        y_true: targets aligned with ``X`` (e.g. shape (B, 1)).
        n: number of leading samples to plot.
    """
    # Only the first n samples are plotted, so only those are predicted.
    if isinstance(X, (tuple, list)):
        X_head = tuple(arr[:n] for arr in X)
    else:
        X_head = X[:n]
    y_pred = np.asarray(model.predict(X_head)).reshape(-1)   # (B, 1, 1) -> (B,)
    y_plot = np.asarray(y_true[:n]).reshape(-1)

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(y_plot, label='True')
    ax.plot(y_pred, label='Pred')
    ax.set_xlabel('Sample index')
    ax.set_ylabel('Value')
    ax.set_title(f'One-step predictions (first {len(y_plot)} samples)')
    ax.legend()
    plt.show()


plot_predictions(model, X_val, Y_val)

In [None]:
# ===============================
# 17. Recursive Multi-step Forecasting
# ===============================
def recursive_forecast(model, X_enc_init, X_dec_init=None, n_future=100, exog_fill="last"):
    """
    Predict ``n_future`` steps of feature 0 recursively, feeding each
    prediction back in as the newest time step.

    Every model call gets an encoder window and a decoder input that follow the
    training convention: the decoder input is the encoder window shifted right
    by one step, with an all-zero start token. After each prediction the encoder
    window rolls forward by one step and the decoder input is rebuilt from it,
    so the two stay aligned exactly as during training.

    X_enc_init: (seq_len, n_features)  -> initial encoder window
    X_dec_init: (seq_len, n_features)  -> decoder input for the first step,
                optional (derived from X_enc_init when None)
    n_future  : number of steps to predict
    exog_fill : how to fill features 1.. of each appended step (the model only
                predicts feature 0): "last" repeats the last observed values,
                "zero" uses zeros.

    Returns: (n_future,) array of predictions.
    Raises : ValueError for an unknown ``exog_fill``.
    """
    if exog_fill not in ("last", "zero"):
        raise ValueError(f"exog_fill must be 'last' or 'zero', got {exog_fill!r}")

    def _shift_right(window):
        # Training decoder convention: step 0 = zeros, step j = window[j - 1].
        dec = np.zeros_like(window)
        dec[1:] = window[:-1]
        return dec

    seq_enc = np.array(X_enc_init)  # copy, so the caller's array is untouched
    seq_dec = _shift_right(seq_enc) if X_dec_init is None else np.array(X_dec_init)

    predictions = []
    for _ in range(n_future):
        # Model expects a tuple: (encoder_input, decoder_input); verbose=0
        # avoids printing one progress bar per step.
        y_pred = model.predict((seq_enc[np.newaxis, :, :], seq_dec[np.newaxis, :, :]), verbose=0)
        pred = np.asarray(y_pred).reshape(-1)[0]
        predictions.append(pred)

        # Next time step: the predicted target plus a fill for the other features
        if exog_fill == "last":
            new_step = seq_enc[-1].copy()
        else:
            new_step = np.zeros_like(seq_enc[-1])
        new_step[0] = pred
        seq_enc = np.concatenate([seq_enc[1:], new_step[np.newaxis, :]], axis=0)
        seq_dec = _shift_right(seq_enc)

    return np.array(predictions)

# Start from the last validation window. Its one-step target is Y_val[-1],
# so the first recursive forecast lands on the last known time step.
X_start_enc, X_start_dec = X_enc_val[-1], X_dec_val[-1]
n_future = 20  # number of recursive steps to forecast
future_transformer = recursive_forecast(model, X_start_enc, X_start_dec, n_future)

In [None]:
# ===============================
# 18. Plot Future Forecasts
# ===============================
recent_true = Y_val[-200:, 0]  # last 200 one-step targets of the validation set
# The first recursive forecast predicts the target of the last validation
# window, i.e. the same time step as recent_true[-1], so the forecast curve
# starts there instead of one step later.
forecast_x = np.arange(len(recent_true) - 1, len(recent_true) - 1 + len(future_transformer))

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(np.arange(len(recent_true)), recent_true, label='Recent True Values', color='blue')
ax.plot(forecast_x, future_transformer, label='Transformer Forecast', color='green')
ax.set_xlabel('Time step')
ax.set_ylabel('Value')
ax.set_title('Multi-step Future Forecasting')
ax.legend()
plt.show()

In [None]:
# ===============================
# 19. Plot Predictions + Future Forecasts
# ===============================
def plot_predictions_and_forecast(model, X_val, y_val, n_past=200, n_future=100, label="Model"):
    """
    Plot one-step predictions on the tail of the validation set, followed by
    a recursive ``n_future``-step forecast from the last validation window.

    X_val   : tuple (X_enc_val, X_dec_val), sample axis first
    y_val   : true values aligned with X_val (e.g. shape (N, 1))
    n_past  : number of most recent validation samples to show (>= 1,
              clipped to len(y_val))
    n_future: number of recursive forecast steps
    label   : model name used in the legend and title

    The first recursive forecast predicts the target of the last validation
    window, i.e. y_val[-1], so the forecast curve starts at x = n_past - 1
    and overlaps the last true point.

    Raises: ValueError if n_past < 1 (``arr[-0:]`` would select everything).
    """
    if n_past < 1:
        raise ValueError(f"n_past must be >= 1, got {n_past}")
    X_enc_val, X_dec_val = X_val
    y_true = np.asarray(y_val).reshape(-1)
    n_past = min(n_past, len(y_true))

    # Past one-step predictions
    y_true_past = y_true[-n_past:]
    y_pred_past = np.asarray(model.predict((X_enc_val[-n_past:], X_dec_val[-n_past:]))).reshape(-1)
    # Future forecast starting from the last validation window
    future_pred = recursive_forecast(model, X_enc_val[-1], X_dec_val[-1], n_future)

    past_x = np.arange(n_past)
    future_x = np.arange(n_past - 1, n_past - 1 + len(future_pred))

    fig, ax = plt.subplots(figsize=(14, 5))
    ax.plot(past_x, y_true_past, label='True (Past)', color='blue')
    ax.plot(past_x, y_pred_past, label=f'Predicted (Past) - {label}', color='orange')
    ax.plot(future_x, future_pred, label=f'Forecast (Future) - {label}', color='green')
    ax.set_xlabel('Time step')
    ax.set_ylabel('Value')
    ax.set_title(f'{label} - Past Predictions & Future Forecast')
    ax.legend()
    plt.show()

# Plot for Transformer
plot_predictions_and_forecast(
    model, X_val, Y_val,
    n_past=200,
    n_future=100,
    label="Transformer",
)

In [None]:
# ===============================
# 20. Plot Predictions + Long Future Forecasts
# ===============================
plot_predictions_and_forecast(
    model, X_val, Y_val,
    n_past=1000,
    n_future=400,
    label="Transformer",
)