### The TFT architecture includes:
- Input embeddings: Embedding layers for categorical and continuous features.
- Variable selection networks: Dynamically select relevant variables.
- Gating mechanisms: Control how much information flows between layers.
- Temporal attention mechanism: Focus on important time steps.
- Fully connected output network: Predict future values.

# Imports and helper functions

In [None]:
import math
import time as timeit
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

print(torch.__version__)

# Pick the fastest available backend: CUDA, then Apple-silicon MPS, then CPU.
# Assign DEVICE = 'cpu' afterwards to force CPU execution.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

In [None]:
# create example data
def create_example_data(batch_size, seq_length, input_size, device):
    """Generate a batch of noisy sine waves for one-step-ahead forecasting.

    Each sample is a sine wave with a random frequency in [0, 10) and a random
    phase in [0, 2*pi). It is sampled at ``seq_length + 1`` evenly spaced
    points on [0, 2*pi]. Uniform noise in [0, 0.25) is added to the first
    ``seq_length`` points (the inputs). The last point (the target) is left
    noise-free.

    Args:
        batch_size: number of sine waves to generate.
        seq_length: number of input time steps per sample.
        input_size: NOTE(review): currently unused; the generated data always
            has a single feature.
        device: device that ``inputs`` and ``targets`` are moved to.

    Returns:
        sine_waves: (batch_size, seq_length + 1, 1) full series, left on CPU.
        inputs: (batch_size, seq_length, 1) model inputs, on ``device``.
        targets: (batch_size, 1) value at the final time step, on ``device``.
        time: (seq_length + 1,) sample times, left on CPU.
    """
    # One random frequency in [0, 10) and phase in [0, 2*pi) per sample.
    frequencies = torch.rand(batch_size) * 10
    phases = torch.rand(batch_size) * 2 * math.pi

    # Time vector covering the inputs plus the one-step-ahead target.
    time = torch.linspace(0, 2 * math.pi, seq_length + 1)

    # Broadcast (batch, 1) against (1, time) instead of looping over samples.
    sine_waves = torch.sin(frequencies[:, None] * time[None, :] + phases[:, None])

    # Additive amplitude noise on the input part only (not the target).
    sine_waves[:, :-1] += torch.rand(batch_size, seq_length) / 4

    # Add a trailing feature axis, then split into inputs and targets.
    sine_waves = sine_waves[:, :, None]
    inputs = sine_waves[:, :-1].to(device)
    targets = sine_waves[:, -1].to(device)
    return sine_waves, inputs, targets, time

# Build a small demo batch: 32 series, 20 input steps, a single feature.
sine_waves, example_input, example_target, time = create_example_data(
    batch_size=32, seq_length=20, input_size=1, device=DEVICE
)

# Show five randomly picked series; the red dot is the value to forecast.
plt.figure()
for idx in np.random.randint(0, len(example_input), 5):
    series = sine_waves[idx].squeeze().cpu().detach().numpy()
    plt.plot(time, series, 'b-')
    plt.plot(time[-1], series[-1], 'r.')
plt.show()

## Basic Temporal Fusion Transformer (TFT) (many to one)

### Model definition

In [None]:
# Basic Temporal Fusion Transformer TFT
class TemporalFusionTransformer(nn.Module):
    """Minimal many-to-one TFT-style model.

    Pipeline:
        1. Linear input embedding plus sinusoidal positional encoding.
        2. Multi-head self-attention over time.
        3. A sigmoid gating step.
        4. Sum over time.
        5. Two-layer MLP head.

    Input ``x`` is (batch, time, input_size) with ``time <= seq_length``.
    The output is (batch, output_size).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_heads: int,
        seq_length: int,
        output_size: int,
        device: torch.device,
        dropout: float = 0.1,
    ):
        """
        Args:
            input_size: number of features per time step.
            hidden_size: model width; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            seq_length: longest sequence covered by the positional encoding.
            output_size: number of predicted values per sample.
            device: device the positional encoding is initially built on.
            dropout: dropout probability inside the attention layer.
        """
        super(TemporalFusionTransformer, self).__init__()

        self.device = device

        # Embeddings: project each time step's features to hidden_size.
        self.input_embedding = nn.Linear(input_size, hidden_size)

        # Temporal attention. Inputs are (batch, time, hidden), so batch_first
        # must be True; otherwise attention runs across the batch axis and
        # mixes information between different samples.
        self.attention_layer = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Gating mechanism: per-unit sigmoid gate in (0, 1).
        self.gate = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Sigmoid(),
        )

        # FCL for output
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

        # Positional encoding. It is a non-persistent buffer, so model.to(...)
        # moves it with the parameters while the state_dict keys are unchanged.
        self.register_buffer(
            'positional_encoding',
            self._get_positional_encoding(seq_length, hidden_size),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, time, input_size) to (batch, output_size)."""
        seq_len = x.size(1)

        # apply input embedding plus the encoding of the first seq_len positions
        x = self.input_embedding(x) + self.positional_encoding[:, :seq_len]

        # apply temporal self-attention over the time axis
        attn_output, _ = self.attention_layer(x, x, x)

        # apply gating mechanism, then sum over time -> (batch, hidden)
        gated_output = (
            self.gate(attn_output) * torch.sigmoid(attn_output)
        ).sum(dim=1)

        # apply output layer
        return self.fc(gated_output)

    def _get_positional_encoding(
        self,
        seq_length: int,
        hidden_size: int
    ) -> torch.Tensor:
        """Return sinusoidal encodings of shape (1, seq_length, hidden_size)."""
        position = torch.arange(0, seq_length).unsqueeze(1).to(self.device)

        div_term = torch.exp(
            torch.arange(0, hidden_size, 2) * (-math.log(10000.0) / hidden_size)
        ).to(self.device)

        pe = torch.zeros(1, seq_length, hidden_size).to(self.device)

        pe[0, :, 0::2] = torch.sin(position * div_term)
        # Odd hidden_size has one fewer cosine column than sine columns.
        pe[0, :, 1::2] = torch.cos(position * div_term[: hidden_size // 2])
        return pe

### Create TFT model

In [None]:
# Hyperparameters of the basic many-to-one model.
seq_length, input_size = 30, 1
hidden_size, num_heads = 64, 4
output_size = 1

# Wrap the constructor in torch.compile(...) if compilation is desired.
model = TemporalFusionTransformer(
    input_size=input_size,
    hidden_size=hidden_size,
    num_heads=num_heads,
    seq_length=seq_length,
    output_size=output_size,
    device=DEVICE,
).to(DEVICE)
print(model)

### Testing the model

In [None]:
# Regenerate the demo data using the model's sequence length.
sine_waves, example_input, example_target, time = create_example_data(
    batch_size=32, seq_length=seq_length, input_size=input_size, device=DEVICE
)
print(example_input.shape, example_target.shape, time.shape)

# Show five randomly picked series; the red dot is the value to forecast.
plt.figure()
for idx in np.random.randint(0, len(example_input), 5):
    series = sine_waves[idx].squeeze().cpu().detach().numpy()
    plt.plot(time, series, 'b-')
    plt.plot(time[-1], series[-1], 'r.')
plt.show()

In [None]:
# Smoke test: run one forward pass and compare input/output shapes.
example_output = model(example_input)
print(example_input.shape, example_output.shape)

### Example Training loop

In [None]:
# Fit one fixed batch. A falling loss shows that the model can learn the data.
optimizer = optim.Adam(model.parameters(), lr=5e-4)
criterion = nn.MSELoss()
num_epochs = 4000

start_time = timeit.time()

model.train()  # the mode never changes inside the loop, so set it once
for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model(example_input)
    loss = criterion(output, example_target)
    loss.backward()
    optimizer.step()

    step = epoch + 1
    if step % 400 == 0:
        print(f"Epoch: {step:4}, Loss: {loss.item():.4f}")

print(f"Training time: {timeit.time() - start_time:.2f} seconds")

In [None]:
# Plot predictions from the trained model in eval mode (dropout disabled).
# The last training-step `output` was produced with dropout active.
model.eval()
with torch.no_grad():
    predictions = model(example_input)

plt.figure()
for i in range(5):
    plt.plot(time, sine_waves[i].squeeze().cpu().detach().numpy(), 'b-')
    plt.plot(time[-1], predictions[i].squeeze().cpu().numpy(), 'r.')
plt.show()

# Extended Temporal Fusion Transformer (TFT) (many to many)
This version adds:
- Variable Selection Network: Dynamically selects important input variables for each time step. Includes gating to enhance interpretability.
- Static Embedding: Handles features that do not vary across time (e.g., location or demographic data). These are integrated into the temporal context.
- Quantile Forecasting: Provides probabilistic outputs for multiple quantiles (e.g., 10%, 50%, 90%).
- Attention with Gating: Combines attention mechanisms with gating to dynamically focus on relevant features and time steps.

In [None]:
class VariableSelectionNetwork(nn.Module):
    """
    Dynamically selects important variables using gating.

    Computes ``output_layer(gate_layer(x) * input_layer(x))``. The sigmoid
    gate and the input projection both read ``x`` and produce
    ``hidden_size`` features. These are multiplied element-wise and then
    projected to ``output_size``.
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
    ):
        super(VariableSelectionNetwork, self).__init__()

        # projection of the raw input to hidden_size
        self.input_layer = nn.Linear(input_size, hidden_size)

        # Gate in (0, 1). It reads x directly, so it maps input_size ->
        # hidden_size to match input_layer(x) for any combination of sizes.
        self.gate_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Sigmoid()
        )

        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        """Map (..., input_size) to (..., output_size)."""
        gated_input = self.gate_layer(x) * self.input_layer(x)

        return self.output_layer(gated_input)
    

class TemporalFusionTransformer(nn.Module):
    """Many-to-many TFT-style model with static context and quantile outputs.

    forward(x, static_inputs):
        x: (batch, time, input_size) time-varying inputs, time <= seq_length.
        static_inputs: (batch, static_input_size) time-invariant inputs.
        returns: (batch, time, len(quantiles), output_size).
    """
    def __init__(
        self,
        input_size: int,
        static_input_size: int,
        hidden_size: int,
        num_heads: int,
        seq_length: int,
        output_size: int,
        device: torch.device,
        dropout: float=0.1,
        quantiles: tuple=(0.1, 0.5, 0.9)
    ):
        super(TemporalFusionTransformer, self).__init__()

        self.device = device

        # quantile probabilistic forecasting: one set of outputs per quantile
        self.quantiles = quantiles

        # Embeddings for the time-varying and the static inputs
        self.input_embedding = nn.Linear(input_size, hidden_size)
        self.static_embedding = nn.Linear(static_input_size, hidden_size)

        # variable selection on the embedded sequence
        self.variable_selection = VariableSelectionNetwork(
            input_size=hidden_size,
            hidden_size=hidden_size,
            output_size=hidden_size
        )

        # Temporal attention. x is (batch, time, hidden), so batch_first must
        # be True; otherwise attention runs across the batch axis and mixes
        # information between different samples.
        self.attention_layer = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # gating mechanism: per-unit sigmoid gate in (0, 1)
        self.gate = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Sigmoid(),
        )

        # output layer: output_size values for every quantile
        self.fc = nn.Linear(
            hidden_size,
            output_size * len(self.quantiles)
        )

        # Positional encoding as a non-persistent buffer, so model.to(...)
        # moves it with the parameters while the state_dict keys are unchanged.
        self.register_buffer(
            'positional_encoding',
            self._get_positional_encoding(seq_length, hidden_size),
            persistent=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        static_inputs: torch.Tensor
    ) -> torch.Tensor:
        seq_len = x.size(1)

        # static embeddings: (batch, hidden)
        static_context = self.static_embedding(static_inputs)

        # input embeddings plus the encoding of the first seq_len positions
        x = self.input_embedding(x) + self.positional_encoding[:, :seq_len]

        # variable selection
        x = self.variable_selection(x)

        # temporal self-attention over the time axis
        attn_output, _ = self.attention_layer(x, x, x)

        # gating mechanism
        gated_output = self.gate(attn_output) * attn_output

        # add the static context to every time step
        gated_with_context = gated_output + static_context.unsqueeze(1)

        # output layer
        output = self.fc(gated_with_context)

        # reshape for quantiles: [B, T, Q, O]
        batch_size, seq_length, _ = output.shape
        output = output.view(
            batch_size,
            seq_length,
            len(self.quantiles),
            -1
        )

        return output

    def _get_positional_encoding(
        self,
        seq_length: int,
        hidden_size: int
    ) -> torch.Tensor:
        """Return sinusoidal encodings of shape (1, seq_length, hidden_size)."""
        position = torch.arange(0, seq_length).unsqueeze(1).to(self.device)

        div_term = torch.exp(
            torch.arange(0, hidden_size, 2) * (-math.log(10000.0) / hidden_size)
        ).to(self.device)

        pe = torch.zeros(1, seq_length, hidden_size).to(self.device)

        pe[0, :, 0::2] = torch.sin(position * div_term)
        # Odd hidden_size has one fewer cosine column than sine columns.
        pe[0, :, 1::2] = torch.cos(position * div_term[: hidden_size // 2])
        return pe

### Testing the model

In [None]:
# Hyperparameters of the many-to-many quantile model.
seq_length, input_size, static_input_size = 30, 1, 2
hidden_size, num_heads, output_size = 64, 4, 1
quantiles = (0.1, 0.5, 0.9)

model = TemporalFusionTransformer(
    input_size=input_size,
    static_input_size=static_input_size,
    hidden_size=hidden_size,
    num_heads=num_heads,
    seq_length=seq_length,
    output_size=output_size,
    device=DEVICE,
    quantiles=quantiles,
).to(DEVICE)
print(model)

In [None]:
# Random demo batch: dynamic inputs, all-zero static inputs (one row per
# sample) and random per-step targets.
dynamic_inputs = torch.randn(32, seq_length, input_size).to(DEVICE)
static_inputs = torch.zeros(32, static_input_size).to(DEVICE)
targets = torch.randn(32, seq_length, output_size).to(DEVICE)

# One forward pass; expected shape [32, seq_length, len(quantiles), output_size].
output = model(dynamic_inputs, static_inputs)
print(output.shape)

### Training with Quantile loss

In [None]:
class QuantileLoss(nn.Module):
    """Pinball (quantile) loss, summed over quantiles.

    For each quantile ``q`` with error ``e = target - prediction`` the loss is
    ``mean(max((q - 1) * e, q * e))``. The per-quantile values are added.
    """
    def __init__(self, quantiles: tuple):
        super(QuantileLoss, self).__init__()
        self.quantiles = quantiles

    def forward(self, preds, target):
        """
        Args:
            preds: predictions indexed by quantile on axis 2, e.g.
                (batch, time, len(quantiles), output_size).
            target: tensor broadcastable against ``preds[:, :, i]``.
        """
        def pinball(q, pred_q):
            err = target - pred_q
            return torch.mean(torch.max((q - 1) * err, q * err))

        per_quantile = (
            pinball(q, preds[:, :, idx]) for idx, q in enumerate(self.quantiles)
        )
        return sum(per_quantile, 0)
    
# Train the quantile model on the fixed random batch.
criterion = QuantileLoss(quantiles)
optimizer = optim.Adam(model.parameters(), lr=5e-4)
num_epochs = 1000

model.train()  # the mode never changes inside the loop, so set it once
for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model(dynamic_inputs, static_inputs)
    loss = criterion(output, targets)
    loss.backward()
    optimizer.step()

    step = epoch + 1
    if step % 100 == 0:
        print(f"Epoch: {step:4}, Loss: {loss.item():.4f}")

# Temporal Fusion Transformer for forecasting (TFT) (many to one)

In [None]:
class TemporalFusionTransformerForecast(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_heads: int,
        num_layers: int,
        agg_method: str=None,
        dropout: float=0.1
    ):
        
        super(TemporalFusionTransformerForecast, self).__init__()
        
        # temporal embeddings
        self.input_projection = nn.Linear(
            input_size, hidden_size
        )
        
        # temporal processing layers
        self.temporal_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers
        )
        
        # fully connected layes for forecasting
        self.gated_residual_network = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
        )
        
        # output layer
        self.output_layer = nn.Linear(
            hidden_size, output_size
        )
        
        # aggregation method
        match agg_method:
            case 'last':
                agg_func = lambda x: x[:, -1]
            case 'mean':
                agg_func = lambda x: torch.mean(x, dim=1)
            case 'sum':
                agg_func = lambda x: torch.sum(x, dim=1)
            case _:
                raise ValueError(f"Invalid aggregation method: {agg_method}")
        self.agg_func = agg_func
        
    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        
        # project temporal features
        x_proj = self.input_projection(x)
        
        # process temporal sequences
        temporal_encoded = self.temporal_encoder(x_proj)
        
        # aggregate information from temporal features
        temporal_context = self.agg_func(temporal_encoded)

        # pass through gated residual network
        gated_output = self.gated_residual_network(temporal_context)
        
        # generate forecast
        forecast = self.output_layer(gated_output)
        
        return forecast

### Test the network

In [None]:
# Hyperparameters of the Transformer-encoder forecaster.
input_size, hidden_size, output_size = 1, 64, 1
num_heads, num_layers = 4, 3
dropout = 0.1

# Build the model on the selected device, aggregating by the last time step.
model = TemporalFusionTransformerForecast(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    num_heads=num_heads,
    num_layers=num_layers,
    agg_method='last',
    dropout=dropout,
).to(DEVICE)
print(model)

### Testing the model

In [None]:
# Random temporal features to check the output shape.
batch_size, seq_length = 32, 12

x_temporal = torch.rand(batch_size, seq_length, input_size).to(DEVICE)

# One forward pass; expected shape [batch_size, output_size].
forecast = model(x_temporal)
print(f"Forecast shape: {forecast.shape}")

### Train the model

In [None]:
# Sine-wave training data matching the current batch size and sequence length.
sine_waves, example_input, example_target, time = create_example_data(
    batch_size=batch_size, seq_length=seq_length, input_size=input_size, device=DEVICE
)
print(example_input.shape, example_target.shape, time.shape)

# Show five randomly picked series; the red dot is the value to forecast.
plt.figure()
for idx in np.random.randint(0, len(example_input), 5):
    series = sine_waves[idx].squeeze().cpu().detach().numpy()
    plt.plot(time, series, 'b-')
    plt.plot(time[-1], series[-1], 'r.')
plt.show()

In [None]:
# Fit the forecaster to the fixed sine-wave batch with mean squared error.
optimizer = optim.Adam(model.parameters(), lr=5e-4)
criterion = nn.MSELoss()
num_epochs = 4000

model.train()  # the mode never changes inside the loop, so set it once
for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model(example_input)
    loss = criterion(output, example_target)
    loss.backward()
    optimizer.step()

    step = epoch + 1
    if step % 400 == 0:
        print(f"Epoch: {step:4}, Loss: {loss.item():.4f}")

In [None]:
# Plot predictions from the trained model in eval mode (dropout disabled).
# The last training-step `output` was produced with dropout active.
model.eval()
with torch.no_grad():
    predictions = model(example_input)

plt.figure()
for i in range(5):
    plt.plot(time, sine_waves[i].squeeze().cpu().detach().numpy(), 'b-')
    plt.plot(time[-1], predictions[i].squeeze().cpu().numpy(), 'r.')
plt.show()

# Extended Temporal Fusion Transformer for forecasting (TFT) (many to one)
Adds a variable selection network between the input projection and the Transformer encoder.

In [None]:
class VariableSelectionNetwork(nn.Module):
    """
    Dynamically selects important variables using gating.

    Computes ``output_layer(gate_layer(x) * input_layer(x))``. The sigmoid
    gate and the input projection both read ``x`` and produce
    ``hidden_size`` features. These are multiplied element-wise and then
    projected to ``output_size``.
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
    ):
        super(VariableSelectionNetwork, self).__init__()

        # projection of the raw input to hidden_size
        self.input_layer = nn.Linear(input_size, hidden_size)

        # Gate in (0, 1). It reads x directly, so it maps input_size ->
        # hidden_size to match input_layer(x) for any combination of sizes.
        self.gate_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Sigmoid()
        )

        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        """Map (..., input_size) to (..., output_size)."""
        gated_input = self.gate_layer(x) * self.input_layer(x)

        return self.output_layer(gated_input)

class TemporalFusionTransformerForecast(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_heads: int,
        num_layers: int,
        agg_method: str,
        dropout: float=0.1
    ):
        
        super(TemporalFusionTransformerForecast, self).__init__()
        
        # temporal embeddings
        self.input_projection = nn.Linear(
            input_size, hidden_size
        )
        
        # variable selection
        self.variable_selection = VariableSelectionNetwork(
            input_size=hidden_size,
            hidden_size=hidden_size,
            output_size=hidden_size
        )
        
        # temporal processing layers
        self.temporal_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers
        )
        
        # fully connected layes for forecasting
        self.gated_residual_network = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
        )
        
        # output layer
        self.output_layer = nn.Linear(
            hidden_size, output_size
        )
        
        # aggregation method
        match agg_method:
            case 'last':
                agg_func = lambda x: x[:, -1]
            case 'mean':
                agg_func = lambda x: torch.mean(x, dim=1)
            case 'sum':
                agg_func = lambda x: torch.sum(x, dim=1)
            case _:
                raise ValueError(f"Invalid aggregation method: {agg_method}")
        self.agg_func = agg_func
        
    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        
        # project temporal features
        x_proj = self.input_projection(x)
        
        # variable selection`
        x_vs = self.variable_selection(x_proj)
        
        # process temporal sequences
        temporal_encoded = self.temporal_encoder(x_vs)
        
        # aggregate information from temporal features
        temporal_context = self.agg_func(temporal_encoded)

        
        # pass through gated residual network
        gated_output = self.gated_residual_network(temporal_context)
        
        # generate forecast
        forecast = self.output_layer(gated_output)
        
        return forecast

### Test the model

In [None]:
# Hyperparameters of the forecaster with variable selection.
input_size, hidden_size, output_size = 1, 64, 1
num_heads, num_layers = 4, 3
dropout = 0.1

# Build the model on the selected device, aggregating by the last time step.
model = TemporalFusionTransformerForecast(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    num_heads=num_heads,
    num_layers=num_layers,
    agg_method='last',
    dropout=dropout,
).to(DEVICE)
print(model)

In [None]:
# Random temporal features to check the output shape.
batch_size, seq_length = 32, 12

x_temporal = torch.rand(batch_size, seq_length, input_size).to(DEVICE)

# One forward pass; expected shape [batch_size, output_size].
forecast = model(x_temporal)
print(f"Forecast shape: {forecast.shape}")

### Train the model

In [None]:
# Sine-wave training data matching the current batch size and sequence length.
sine_waves, example_input, example_target, time = create_example_data(
    batch_size=batch_size, seq_length=seq_length, input_size=input_size, device=DEVICE
)
print(example_input.shape, example_target.shape, time.shape)

# Show five randomly picked series; the red dot is the value to forecast.
plt.figure()
for idx in np.random.randint(0, len(example_input), 5):
    series = sine_waves[idx].squeeze().cpu().detach().numpy()
    plt.plot(time, series, 'b-')
    plt.plot(time[-1], series[-1], 'r.')
plt.show()

In [None]:
# Fit the forecaster to the fixed sine-wave batch with mean squared error.
optimizer = optim.Adam(model.parameters(), lr=5e-4)
criterion = nn.MSELoss()
num_epochs = 4000

model.train()  # the mode never changes inside the loop, so set it once
for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model(example_input)
    loss = criterion(output, example_target)
    loss.backward()
    optimizer.step()

    step = epoch + 1
    if step % 400 == 0:
        print(f"Epoch: {step:4}, Loss: {loss.item():.4f}")

In [None]:
# Plot predictions from the trained model in eval mode (dropout disabled).
# The last training-step `output` was produced with dropout active.
model.eval()
with torch.no_grad():
    predictions = model(example_input)

plt.figure()
for i in range(5):
    plt.plot(
        time,
        sine_waves[i].squeeze().cpu().detach().numpy(),
        'b-'
    )
    plt.plot(
        time[-1],
        predictions[i].squeeze().cpu().numpy(),
        'r.'
    )
plt.show()

# Extended CI Temporal Fusion Transformer for forecasting (TFT) (many to one)
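Adds a variable selection network and predicts a Gaussian distribution (mean and standard deviation) for each forecast, which gives a confidence interval (CI) around the prediction.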

In [None]:
class VariableSelectionNetwork(nn.Module):
    """
    Dynamically selects important variables using gating.

    Computes ``output_layer(gate_layer(x) * input_layer(x))``. The sigmoid
    gate and the input projection both read ``x`` and produce
    ``hidden_size`` features. These are multiplied element-wise and then
    projected to ``output_size``.
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
    ):
        super(VariableSelectionNetwork, self).__init__()

        # projection of the raw input to hidden_size
        self.input_layer = nn.Linear(input_size, hidden_size)

        # Gate in (0, 1). It reads x directly, so it maps input_size ->
        # hidden_size to match input_layer(x) for any combination of sizes.
        self.gate_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Sigmoid()
        )

        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        """Map (..., input_size) to (..., output_size)."""
        gated_input = self.gate_layer(x) * self.input_layer(x)

        return self.output_layer(gated_input)

class TemporalFusionTransformerProbabilityForecast(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_heads: int,
        num_layers: int,
        agg_method: str,
        dropout: float=0.1
    ):
        
        super(TemporalFusionTransformerProbabilityForecast, self).__init__()
        
        # temporal embeddings
        self.input_projection = nn.Linear(
            input_size, hidden_size
        )
        
        # variable selection
        self.variable_selection = VariableSelectionNetwork(
            input_size=hidden_size,
            hidden_size=hidden_size,
            output_size=hidden_size
        )
        
        # temporal processing layers
        self.temporal_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers
        )
        
        # fully connected layes for forecasting
        self.gated_residual_network = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
        )
        
        # mean and std layer
        self.mean_layer = nn.Linear(
            hidden_size, output_size
        )
        self.std_layer = nn.Linear(
            hidden_size, output_size
        )
        
        # aggregation method
        match agg_method:
            case 'last':
                agg_func = lambda x: x[:, -1]
            case 'mean':
                agg_func = lambda x: torch.mean(x, dim=1)
            case 'sum':
                agg_func = lambda x: torch.sum(x, dim=1)
            case _:
                raise ValueError(f"Invalid aggregation method: {agg_method}")
        self.agg_func = agg_func
        
    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        
        # project temporal features
        x_proj = self.input_projection(x)
        
        # variable selection`
        x_vs = self.variable_selection(x_proj)
        
        # process temporal sequences
        temporal_encoded = self.temporal_encoder(x_vs)
        
        # aggregate information from temporal features
        temporal_context = self.agg_func(temporal_encoded)

        
        # pass through gated residual network
        gated_output = self.gated_residual_network(temporal_context)
        
        # generate forecast
        mean = self.mean_layer(gated_output)
        std = self.std_layer(gated_output)
        
        return mean, std
    

# Loss Function: Negative Log-Likelihood
def gaussian_nll_loss(y_true, mean, std, eps: float = 1e-6):
    """
    Gaussian negative log-likelihood loss (constant 0.5*log(2*pi) omitted).

        var  = max(std**2, eps)
        loss = mean(0.5 * log(var) + (y_true - mean)**2 / (2 * var))

    Args:
        y_true: ground truth, e.g. [batch_size, 1].
        mean: predicted mean, broadcastable with y_true.
        std: predicted standard deviation, broadcastable with y_true. Only
            std**2 is used, so its sign does not matter.
        eps: lower bound on the variance. It keeps log(var) and the division
            finite when std is zero or very close to it.

    Returns:
        Scalar tensor: the loss averaged over all elements.
    """
    variance = torch.clamp(std ** 2, min=eps)
    output = torch.mean(
        0.5 * torch.log(variance) + (y_true - mean) ** 2 / (2 * variance)
    )
    return output

### Test the model

In [None]:
# Hyperparameters of the Gaussian (mean, std) forecaster.
input_size, hidden_size, output_size = 1, 64, 1
num_heads, num_layers = 4, 3
dropout = 0.1

# Build the model on the selected device, aggregating by the last time step.
model = TemporalFusionTransformerProbabilityForecast(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    num_heads=num_heads,
    num_layers=num_layers,
    agg_method='last',
    dropout=dropout,
).to(DEVICE)
print(model)

In [None]:
# Sample data. The input below reads `seq_length`, so that name is the one set
# here; `sequence_length` is kept only as an alias for any later cells.
batch_size = 32
seq_length = 12
sequence_length = seq_length

x_temporal = torch.rand(
    batch_size, seq_length, input_size
).to(DEVICE)

# Forward pass: one mean and one standard deviation per sample.
mean, std = model(x_temporal)
print(f"Mean shape: {mean.shape}, Std shape: {std.shape}")

# Example usage of the loss with random targets of shape [batch_size, 1].
y_true = torch.rand(
    batch_size, 1
).to(DEVICE)
loss = gaussian_nll_loss(y_true, mean, std)
print(f"Loss: {loss.item()}")

### Train the model

In [None]:
# Sine-wave training data matching the current batch size and sequence length.
sine_waves, example_input, example_target, time = create_example_data(
    batch_size=batch_size, seq_length=seq_length, input_size=input_size, device=DEVICE
)
print(example_input.shape, example_target.shape, time.shape)

# Show five randomly picked series; the red dot is the value to forecast.
plt.figure()
for idx in np.random.randint(0, len(example_input), 5):
    series = sine_waves[idx].squeeze().cpu().detach().numpy()
    plt.plot(time, series, 'b-')
    plt.plot(time[-1], series[-1], 'r.')
plt.show()

In [None]:
# Fit the Gaussian head by minimising the negative log-likelihood on one batch.
optimizer = optim.Adam(model.parameters(), lr=5e-4)
num_epochs = 6000

model.train()  # the mode never changes inside the loop, so set it once
for epoch in range(num_epochs):
    optimizer.zero_grad()
    mean, std = model(example_input)
    loss = gaussian_nll_loss(example_target, mean, std)
    loss.backward()
    optimizer.step()

    step = epoch + 1
    if step % 500 == 0:
        print(f"Epoch: {step:4}, Loss: {loss.item():.4f}")

In [None]:
# For two samples, plot the predicted mean (large red dot) and a
# mean +/- 1 std band (small red dots). Predictions come from eval mode
# (dropout disabled), not from the last training step.
model.eval()
with torch.no_grad():
    pred_mean, pred_std = model(example_input)

for i in range(2):
    mu = pred_mean[i].squeeze().cpu().numpy()
    # Only std**2 enters the Gaussian, so abs() gives the same spread while
    # keeping upper >= lower even if a model outputs a negative std.
    sigma = pred_std[i].abs().squeeze().cpu().numpy()

    plt.figure()
    plt.plot(time, sine_waves[i].squeeze().cpu().detach().numpy(), 'b-')
    plt.plot(time[-1], mu, 'ro')
    plt.plot(time[-1], mu + sigma, 'r.')
    plt.plot(time[-1], mu - sigma, 'r.')
    plt.show()

# LSTM Temporal Fusion Network
Includes variable selection networks for static and temporal inputs, an LSTM encoder followed by multi-head attention, and static-context enrichment of the LSTM outputs.

In [None]:
class VariableSelectionNetwork(nn.Module):
    """Dynamically selects relevant variables at each time step.

    Returns ``softmax(gate(x)) * transformation(x)``. Both linear maps send
    ``input_size`` to ``hidden_size``. The softmax runs over the last
    (hidden) axis, so the weights sum to 1 across the ``hidden_size``
    outputs.
    NOTE(review): these are weights over projected features, not over the
    original input variables.
    """
    def __init__(self,
                 input_size: int,
                 hidden_size: int
    ) -> None:
        super().__init__()

        # gate network: produces the (pre-softmax) selection logits
        self.gate = nn.Linear(input_size, hidden_size)

        # transformation network
        self.transformation = nn.Linear(input_size, hidden_size)

        # softmax over the last axis
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Map (..., input_size) to (..., hidden_size)."""
        # Variable selection weights
        gate_weights = self.softmax(self.gate(inputs))

        # Transformed inputs
        transformed_inputs = self.transformation(inputs)

        # Apply variable selection
        return gate_weights * transformed_inputs


class TemporalFusionTransformerLSTM(nn.Module):
    """Simplified Temporal Fusion Transformer with LSTM and variable selection.

    Pipeline:
        1. Variable selection on the static and the temporal inputs.
        2. LSTM encoder.
        3. Add the enriched static context to every step.
        4. Multi-head self-attention.
        5. Feed-forward network.
        6. Per-step linear output.
    """
    def __init__(self,
                 input_size,
                 static_size,
                 lstm_hidden_size,
                 attention_heads,
                 output_size,
                 seq_length
    ):
        """
        Args:
            input_size: features per time step.
            static_size: features per static input row.
            lstm_hidden_size: width of the LSTM, attention and MLP layers;
                must be divisible by ``attention_heads``.
            attention_heads: number of attention heads.
            output_size: predicted values per time step.
            seq_length: nominal sequence length (stored for reference; the
                forward pass works with any sequence length).
        """
        super(TemporalFusionTransformerLSTM, self).__init__()

        # Static Variable Selection
        self.static_selection = VariableSelectionNetwork(static_size,
                                                         lstm_hidden_size)

        # Temporal Variable Selection
        self.temporal_selection = VariableSelectionNetwork(input_size,
                                                           lstm_hidden_size)

        # LSTM Encoder
        self.lstm = nn.LSTM(lstm_hidden_size,
                            lstm_hidden_size,
                            batch_first=True)

        # Multi-head Attention
        self.attention = nn.MultiheadAttention(lstm_hidden_size,
                                               num_heads=attention_heads,
                                               batch_first=True)

        # Static Enrichment
        self.static_enrichment = nn.Linear(lstm_hidden_size,
                                           lstm_hidden_size)

        # Feedforward Network
        self.feedforward = nn.Sequential(
            nn.Linear(lstm_hidden_size,
                      lstm_hidden_size),
            nn.ReLU(),
            nn.Linear(lstm_hidden_size,
                      lstm_hidden_size)
        )

        # Output layer
        self.output_layer = nn.Linear(lstm_hidden_size,
                                      output_size)

        # Sequence length
        self.seq_length = seq_length

    def forward(self, static_inputs, temporal_inputs):
        """
        Args:
            static_inputs: (batch, n_static, static_size); the selected
                features are averaged over axis 1. A 2-D
                (batch, static_size) tensor is also accepted and used as-is.
            temporal_inputs: (batch, time, input_size).

        Returns:
            (batch, time, output_size) predictions, one per time step.
        """
        # Static variable selection
        static_context = self.static_selection(static_inputs)
        if static_context.dim() > 2:
            # Aggregate across the static rows -> (batch, hidden)
            static_context = static_context.mean(dim=1)

        # Temporal variable selection
        temporal_features = self.temporal_selection(temporal_inputs)

        # LSTM encoder
        lstm_output, _ = self.lstm(temporal_features)

        # Static enrichment: broadcast (batch, 1, hidden) over the actual
        # number of time steps instead of a fixed seq_length.
        enriched_static = self.static_enrichment(static_context).unsqueeze(1)
        enriched_lstm_output = lstm_output + enriched_static

        # Temporal attention
        attention_output, _ = self.attention(
            enriched_lstm_output, enriched_lstm_output, enriched_lstm_output
        )

        # Feedforward processing
        processed_output = self.feedforward(attention_output)

        # Output layer
        predictions = self.output_layer(processed_output)

        return predictions
