In [None]:
# ===============================
# 0. Libraries
# ===============================
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt

In [None]:
# ===============================
# 1. Generate complex synthetic time series
# ===============================
def create_complex_time_series(n_samples=2000, seq_len=50, n_features=3, seed=42):
    """Generate a synthetic multivariate series and sliding-window samples.

    Every feature is the sum of five components: a sine/cosine pair at a
    random base frequency, a linear trend, a seasonal sine with a random
    integer period, Gaussian noise, and sparse positive spikes on about 1%
    of the time steps.

    Args:
        n_samples: Number of (window, target) pairs to build.
        seq_len: Number of consecutive time steps in each input window.
        n_features: Number of independent channels to generate.
        seed: Seed for the private random generator.

    Returns:
        Tuple ``(X, y, series)`` where
        ``X`` has shape ``(n_samples, seq_len, n_features)``,
        ``y`` has shape ``(n_samples,)`` and holds feature 0 at the step that
        immediately follows each window, and
        ``series`` has shape ``(n_samples + seq_len, n_features)``.
    """
    # A private legacy generator yields exactly the stream that
    # np.random.seed(seed) followed by np.random.* calls would, but without
    # resetting NumPy's global RNG for the rest of the notebook.
    rng = np.random.RandomState(seed)

    n_steps = n_samples + seq_len
    t = np.arange(n_steps)

    channels = []
    for _ in range(n_features):
        # Base oscillation: sine at `freq` plus a half-amplitude cosine at freq/2.
        freq = rng.uniform(0.01, 0.1)
        amp = rng.uniform(0.5, 2.0)
        wave = amp * np.sin(2 * np.pi * freq * t) + amp / 2 * np.cos(2 * np.pi * freq * 0.5 * t)

        # Linear upward trend with a random slope.
        trend = t * rng.uniform(0.0005, 0.002)

        # Seasonal component with an integer period drawn from [30, 100).
        season = 0.5 * np.sin(2 * np.pi * t / rng.randint(30, 100))

        # Gaussian noise.
        noise = rng.normal(0, 0.2, n_steps)

        # Positive spikes on ~1% of the steps, at distinct random positions.
        spikes = np.zeros(n_steps)
        spike_idx = rng.choice(n_steps, size=int(0.01 * n_steps), replace=False)
        spikes[spike_idx] = rng.uniform(1, 3, len(spike_idx))

        channels.append(wave + trend + season + noise + spikes)

    series = np.stack(channels, axis=-1)

    # Row i of window_idx is [i, i+1, ..., i+seq_len-1]; fancy indexing then
    # gathers every window in one shot and returns a fresh array.
    window_idx = np.arange(n_samples)[:, None] + np.arange(seq_len)[None, :]
    X = series[window_idx]  # (n_samples, seq_len, n_features)

    # Target is feature 0 at the step right after each window.
    y = series[seq_len:seq_len + n_samples, 0].copy()  # (n_samples,)
    return X, y, series

# Window length and channel count, shared by data generation and the model.
SEQ_LEN = 50
N_FEATURES = 3

X, y, true_series = create_complex_time_series(
    n_samples=5000,
    seq_len=SEQ_LEN,
    n_features=N_FEATURES,
)

# Chronological split: the first 80% of the windows train the model, the
# remaining 20% are held out for validation (no shuffling).
train_size = int(0.8 * len(X))
X_train, y_train = X[:train_size], y[:train_size]
X_val, y_val = X[train_size:], y[train_size:]

In [None]:
# ===============================
# 2. Plotting the time series
# ===============================
# Show feature 0 of the full generated series (the prediction target).
fig, ax = plt.subplots(figsize=(14, 5))
ax.plot(true_series[:, 0], label="True underlying target series")
ax.set_title("True Continuous Time Series")
ax.legend()
plt.show()

In [None]:
# ===============================
# 3. Custom Attention Layers
# ===============================
class MultiQueryAttention(layers.Layer):
    """Multi-query self-attention with key/value projections shared across heads.

    Each head owns its own query projection; a single key projection and a
    single value projection are reused by every head. The per-head results are
    concatenated on the last axis and mapped back to ``d_model`` units.

    Args:
        d_model: Output width of every projection and of the final layer.
        num_heads: Number of query projections (heads).
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to ``layers.Layer``.
    """
    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model = d_model
        self.q_dense = [layers.Dense(d_model) for _ in range(num_heads)]
        self.k_dense = layers.Dense(d_model)  # shared by all heads
        self.v_dense = layers.Dense(d_model)  # shared by all heads
        self.out = layers.Dense(d_model)

    def call(self, x):
        """Apply self-attention to ``x``.

        Assumes ``x`` is ``(batch, seq_len, features)`` -- TODO confirm against
        callers; the layer itself only requires a shape ``tf.matmul`` accepts.
        """
        k = self.k_dense(x)
        v = self.v_dense(x)
        # The scale is the same for every head, so compute it once.
        scale = tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        head_outputs = []
        for q_layer in self.q_dense:
            q = q_layer(x)
            attn = tf.matmul(q, k, transpose_b=True) / scale
            attn = tf.nn.softmax(attn, axis=-1)
            head_outputs.append(tf.matmul(attn, v))
        return self.out(tf.concat(head_outputs, axis=-1))

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"d_model": self.d_model, "num_heads": self.num_heads})
        return config

class LatentAttention(layers.Layer):
    """Single-head self-attention computed in a ``latent_dim``-wide space.

    Queries, keys and values are projected to ``latent_dim`` units, attention
    is applied there, and the result is projected to ``d_model`` units.

    Args:
        d_model: Width of the output projection.
        latent_dim: Width of the query/key/value projections.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to ``layers.Layer``.
    """
    def __init__(self, d_model, latent_dim, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can rebuild the layer.
        self.d_model = d_model
        self.latent_dim = latent_dim
        self.q_dense = layers.Dense(latent_dim)
        self.k_dense = layers.Dense(latent_dim)
        self.v_dense = layers.Dense(latent_dim)
        self.out = layers.Dense(d_model)

    def call(self, x):
        """Apply latent-space self-attention to ``x``.

        Assumes ``x`` is ``(batch, seq_len, features)`` -- TODO confirm against
        callers.
        """
        q = self.q_dense(x)
        k = self.k_dense(x)
        v = self.v_dense(x)
        # Scale by sqrt of the projected width (the last axis of q).
        attn = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(q.shape[-1], tf.float32))
        attn = tf.nn.softmax(attn, axis=-1)
        return self.out(tf.matmul(attn, v))

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"d_model": self.d_model, "latent_dim": self.latent_dim})
        return config

In [None]:
# ===============================
# 4. Transformer Block
# ===============================
class TransformerBlock(layers.Layer):
    """Self-attention followed by a two-layer feed-forward network.

    Both sub-layers are wrapped in a residual connection and followed by layer
    normalization: ``x = norm1(x + attn(x))`` then ``x = norm2(x + ffn(x))``.

    Args:
        d_model: Width of the block's input/output and of the attention layer.
        num_heads: Number of heads for the built-in or multi-query attention.
        ff_dim: Hidden width of the feed-forward network.
        use_mqa: Use ``MultiQueryAttention`` (ignored when ``use_latent`` is set).
        use_latent: Use ``LatentAttention``; takes precedence over ``use_mqa``.
        latent_dim: Projection width passed to ``LatentAttention``.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to ``layers.Layer``.
    """
    def __init__(self, d_model, num_heads, ff_dim, use_mqa=False, use_latent=False, latent_dim=16, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can rebuild the block.
        self.d_model = d_model
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.use_mqa = use_mqa
        self.use_latent = use_latent
        self.latent_dim = latent_dim

        if use_latent:
            self.attn = LatentAttention(d_model, latent_dim)
        elif use_mqa:
            self.attn = MultiQueryAttention(d_model, num_heads)
        else:
            self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = models.Sequential([
            layers.Dense(ff_dim, activation='relu'),
            layers.Dense(d_model)
        ])
        self.norm1 = layers.LayerNormalization()
        self.norm2 = layers.LayerNormalization()
        # The built-in layer is called with (query, value); the custom layers
        # take a single tensor.
        self.use_builtin = not isinstance(self.attn, (MultiQueryAttention, LatentAttention))

    def call(self, x):
        """Run attention and the feed-forward network on ``x``."""
        if self.use_builtin:
            attn_out = self.attn(x, x)   # standard MHA: query = value = x
        else:
            attn_out = self.attn(x)      # custom multi-query / latent attention
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.ffn(x))
        return x

    def get_config(self):
        """Return the constructor arguments so the block can be re-created."""
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "use_mqa": self.use_mqa,
            "use_latent": self.use_latent,
            "latent_dim": self.latent_dim,
        })
        return config

In [None]:
# ===============================
# 5. Simple Transformer Model for Time Series
# ===============================
class TimeSeriesTransformer(tf.keras.Model):
    """One-block transformer regressor producing a single value per window.

    The input is projected to ``d_model`` units, a fixed sinusoidal positional
    table is added, one ``TransformerBlock`` is applied, the result is averaged
    over axis 1 and mapped to one output unit.

    The positional table is built once for ``seq_len`` positions, so inputs
    must be broadcast-compatible with shape ``(1, seq_len, d_model)`` after
    the embedding.
    """

    def __init__(self, seq_len, n_features, d_model=32, num_heads=2, ff_dim=64, use_mqa=False, use_latent=False, latent_dim=16):
        super().__init__()
        # Note: n_features is accepted but not read here.
        self.embedding = layers.Dense(d_model)
        self.pos_encoding = self._positional_encoding(seq_len, d_model)
        self.block = TransformerBlock(
            d_model,
            num_heads,
            ff_dim,
            use_mqa=use_mqa,
            use_latent=use_latent,
            latent_dim=latent_dim,
        )
        self.out = layers.Dense(1)

    def call(self, x):
        """Embed, add positions, attend, average over axis 1 and regress."""
        embedded = self.embedding(x) + self.pos_encoding
        encoded = self.block(embedded)
        pooled = tf.reduce_mean(encoded, axis=1)
        return self.out(pooled)

    def _positional_encoding(self, seq_len, d_model):
        """Build a ``(1, seq_len, d_model)`` float32 sinusoidal position table.

        Even columns hold the sine and odd columns the cosine of
        ``position / 10000 ** (2 * (column // 2) / d_model)``.
        """
        positions = np.arange(seq_len).reshape(-1, 1)
        columns = np.arange(d_model).reshape(1, -1)
        inv_freq = 1 / np.power(10000, (2 * (columns // 2)) / d_model)
        angles = positions * inv_freq
        table = np.where(columns % 2 == 0, np.sin(angles), np.cos(angles))
        return tf.cast(table[np.newaxis, ...], tf.float32)

In [None]:
# ===============================
# 6. Usage Example
# ===============================
# Seed TensorFlow's global generator before any layer is created. The data
# generator above is seeded, but without this the weight initialisation
# differs on every Restart & Run All.
tf.random.set_seed(42)

# Baseline configuration: built-in multi-head attention. Set use_mqa=True or
# use_latent=True to switch to the custom attention layers instead.
transformer_model = TimeSeriesTransformer(
    SEQ_LEN, N_FEATURES,
    d_model=32, num_heads=2, ff_dim=64,
    use_mqa=False,
    use_latent=False
)

transformer_model.compile(optimizer='adam', loss='mse')
transformer_history = transformer_model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=50, batch_size=32)

In [None]:
# ===============================
# 7. Plot Predictions
# ===============================
def plot_predictions(model, X, y_true, n=200):
    """Plot the model's predictions for the first ``n`` samples against the truth.

    Args:
        model: Object with a ``predict`` method that accepts ``X[:n]``.
        X: Input samples; only the first ``n`` are used.
        y_true: Reference values; only the first ``n`` are plotted.
        n: Number of leading samples to show.
    """
    predicted = model.predict(X[:n])

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(y_true[:n], label='True')
    ax.plot(predicted, label='Pred')
    ax.legend()
    plt.show()

plot_predictions(transformer_model, X_val, y_val)

In [None]:
# ===============================
# 8. Recursive Multi-step Forecasting
# ===============================
def recursive_forecast(model, X_init, n_future=100, fill_mode="last"):
    """Forecast ``n_future`` steps by feeding each prediction back as input.

    At every step the model predicts the next value of feature 0 from the
    current window. The window then drops its oldest row and gains a new row
    whose feature 0 is that prediction.

    The model only predicts feature 0, so the remaining features of the new
    row have to be filled in. By default the last observed values are carried
    forward. Writing zeros instead (the previous behaviour, still available
    via ``fill_mode="zero"``) hands the model inputs far from anything it saw
    during training whenever those features are not centred on zero.

    Args:
        model: Keras-style model; ``model.predict(x, verbose=0)`` must return
            an array indexable as ``[0, 0]`` for ``x`` of shape
            ``(1, seq_len, n_features)``.
        X_init: Initial window of shape ``(seq_len, n_features)``. Not modified.
        n_future: Number of steps to forecast.
        fill_mode: ``"last"`` repeats the previous row's values for features
            1..n-1; ``"zero"`` sets them to 0.

    Returns:
        1-D array with ``n_future`` predictions.

    Raises:
        ValueError: If ``fill_mode`` is unknown or ``X_init`` is not 2-D.
    """
    if fill_mode not in ("last", "zero"):
        raise ValueError(f"fill_mode must be 'last' or 'zero', got {fill_mode!r}")

    # Work on a private float copy so the caller's array is never touched.
    window = np.array(X_init, dtype=float)
    if window.ndim != 2:
        raise ValueError(f"X_init must be 2-D (seq_len, n_features), got shape {window.shape}")

    predictions = []
    for _ in range(n_future):
        # verbose=0: otherwise every single-step call prints its own progress bar.
        pred = model.predict(window[np.newaxis, :, :], verbose=0)[0, 0]
        predictions.append(pred)

        if fill_mode == "last":
            new_step = window[-1].copy()
        else:
            new_step = np.zeros(window.shape[1])
        new_step[0] = pred  # feature 0 is the predicted target

        # Slide the window: drop the oldest row, append the new one.
        window = np.vstack([window[1:], new_step])

    return np.array(predictions)

# Forecast horizon, and the most recent validation window to start from.
n_future = 20
X_start = X_val[-1]
future_transformer = recursive_forecast(
    transformer_model,
    X_start,
    n_future=n_future,
)

In [None]:
# ===============================
# 9. Plot Future Forecasts
# ===============================
# Last 200 validation targets, followed on the same axis by the recursive forecast.
recent_truth = y_val[-200:]
history_len = len(recent_truth)

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(np.arange(history_len), recent_truth, label='Recent True Values', color='blue')
ax.plot(
    np.arange(history_len, history_len + n_future),
    future_transformer,
    label='Transformer Forecast',
    color='orange',
)
ax.set_xlabel('Time step')
ax.set_ylabel('Value')
ax.set_title('Multi-step Future Forecasting')
ax.legend()
plt.show()

In [None]:
# ===============================
# 10. Plot Predictions + Future Forecasts
# ===============================
def plot_predictions_and_forecast(model, X_val, y_val, n_past=200, n_future=100, label="Model"):
    """Plot recent one-step predictions next to a recursive forecast.

    The last ``n_past`` validation samples are predicted in one batch and drawn
    against their true values; a ``recursive_forecast`` of ``n_future`` steps,
    started from the final validation window, is drawn after them.

    Args:
        model: Model passed to ``predict`` and to ``recursive_forecast``.
        X_val: Validation windows; the last one seeds the forecast.
        y_val: Validation targets aligned with ``X_val``.
        n_past: Number of trailing validation samples to show.
        n_future: Number of steps to forecast beyond the validation set.
        label: Model name used in the legend entries and the title.
    """
    # One-step-ahead predictions over the recent history.
    history_true = y_val[-n_past:]
    history_pred = model.predict(X_val[-n_past:])

    # Multi-step forecast seeded with the final validation window.
    forecast = recursive_forecast(model, X_val[-1], n_future)

    n_history = len(history_true)
    past_axis = np.arange(n_history)
    future_axis = np.arange(n_history, n_history + n_future)

    fig, ax = plt.subplots(figsize=(14, 5))
    ax.plot(past_axis, history_true, label='True (Past)', color='blue')
    ax.plot(past_axis, history_pred, label=f'Predicted (Past) - {label}', color='orange')
    ax.plot(future_axis, forecast, label=f'Forecast (Future) - {label}', color='green')

    ax.set_xlabel('Time step')
    ax.set_ylabel('Value')
    ax.set_title(f'{label} - Past Predictions & Future Forecast')
    ax.legend()
    plt.show()

# Plot for the transformer model
plot_predictions_and_forecast(
    transformer_model,
    X_val,
    y_val,
    n_past=200,
    n_future=100,
    label="Transformer",
)

In [None]:
# ===============================
# 11. Plot Predictions + Long Future Forecasts
# ===============================
# Longer view: up to 1000 validation steps of history and a 400-step forecast.
plot_predictions_and_forecast(
    transformer_model,
    X_val,
    y_val,
    n_past=1000,
    n_future=400,
    label="Transformer",
)