In [2]:
import torch
import torch.nn as nn

## Sample code: export the model to ONNX

In [None]:
from src.TemporalFusionTransformer import TemporalFusionTransformer
from src.VQVAE import VQVAE
from typing import Optional

# Dimensions used to build the model and the dummy export input below.
feature_dim = 32     # size of the last axis of the model input
future_steps = 40    # passed to EnhancedTFT / the TFT as num_steps
lookback = 30        # length of the middle axis of the dummy input

# Seed torch's global RNG before any weights are created so the randomly
# initialised parameters (and the dummy input drawn later) are reproducible
# under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

vqvae = VQVAE(input_dim=feature_dim, hidden_dim=512, num_embeddings=128, embedding_dim=128, commitment_cost=0.25)

class EnhancedTFT(nn.Module):
    """Temporal Fusion Transformer whose input is augmented with a VQ-VAE embedding.

    The raw input is run through ``vqvae``; the embedding it returns is
    concatenated onto the raw input along the last axis, and the combined
    tensor is fed to a ``TemporalFusionTransformer``.

    Args:
        num_features: Size of the last axis of the raw input ``x``.
        num_hidden: Hidden size forwarded to the TFT.
        num_outputs: Output size forwarded to the TFT.
        num_steps: Step count forwarded to the TFT.
        vqvae: VQ-VAE module that produces the auxiliary embedding.
    """

    def __init__(self, num_features, num_hidden, num_outputs, num_steps, vqvae:VQVAE):
        super().__init__()
        # Register the VQ-VAE first so submodule order matches the original.
        self.vqvae = vqvae
        # NOTE(review): assumes the embedding returned by vqvae(x) has
        # vqvae.encoder.fc2.out_features channels on its last axis — confirm.
        tft_input_width = num_features + vqvae.encoder.fc2.out_features
        self.tft = TemporalFusionTransformer(
            tft_input_width,
            num_hidden,
            num_outputs,
            num_steps,
            num_attention_heads=8,
        )
        self.num_outputs = num_outputs
        self.num_steps = num_steps

    def forward(self, x, mask: Optional[torch.Tensor]=None):
        """Run the VQ-VAE, append its embedding to ``x``, and apply the TFT.

        Args:
            x: Raw input tensor; concatenation happens on its last axis.
            mask: Optional mask passed straight through to the TFT.

        Returns:
            Tuple ``(tft_output, vq_loss, perplexity)``. The VQ-VAE
            reconstruction is computed but discarded.
        """
        _, vq_loss, perplexity, code_embedding = self.vqvae(x)
        tft_input = torch.cat([x, code_embedding], dim=-1)
        tft_output = self.tft(tft_input, mask)
        return tft_output, vq_loss, perplexity

model = EnhancedTFT(num_features=feature_dim, num_hidden=128, num_outputs=2, num_steps=future_steps, vqvae=vqvae)

model.to('cpu')
# Put dropout / normalisation layers into inference mode before tracing.
model.eval()

dummy_input = torch.randn(1, lookback, feature_dim)


class ForecastOnly(nn.Module):
    """Export wrapper that keeps only the forecast from ``EnhancedTFT``.

    ``EnhancedTFT.forward`` returns ``(forecast, vq_loss, perplexity)``. Only
    the forecast is returned here, so the exported graph has exactly one output
    and it lines up with ``output_names=['output']`` and ``dynamic_axes`` below.
    """

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        forecast, _vq_loss, _perplexity = self.inner(x)
        return forecast


export_model = ForecastOnly(model).eval()

# Export the wrapped (forecast-only) model to ONNX format
torch.onnx.export(
    export_model,                    # Wrapped model to export
    dummy_input,                     # Example input used for tracing
    "model.onnx",                    # Output file name
    export_params=True,              # Store the parameter weights inside the model file
    opset_version=13,                # ONNX opset version (adjust as needed)
    do_constant_folding=True,        # Fold constant subgraphs for optimization
    input_names=['input'],           # Name of the single graph input
    output_names=['output'],         # Name of the single graph output (the forecast)
    dynamic_axes={
        'input': {0: 'batch_size'},  # Allow any batch size on the input
        'output': {0: 'batch_size'}  # ...and on the forecast output
    }
)


In [11]:
# EnhancedTFT.forward returns a tuple (prediction, vq_loss, perplexity), so the
# prediction must be indexed out before asking for a shape. no_grad avoids
# building an autograd graph for this inspection-only call.
with torch.no_grad():
    prediction = model(dummy_input)[0]
prediction.shape

torch.Size([1, 40, 2])