### Imports and helper functions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import math
import matplotlib.pyplot as plt
import numpy as np

# Use the Apple-silicon MPS backend when PyTorch reports it as available,
# otherwise fall back to the CPU.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [None]:
# create example data
def create_example_data(batch_size, seq_length,
                        input_size, device, horizon, trend=False):
    """Generate a batch of noisy sine waves split into inputs and targets.

    Each sample is ``sin(f * t + p)`` with a random frequency ``f`` in
    ``[0, 10)`` and a random phase ``p`` in ``[0, 2*pi)``, evaluated on
    ``seq_length + horizon`` evenly spaced points ``t`` in ``[0, 2*pi]``.
    Uniform noise in ``[0, 0.5)`` is added to the first ``seq_length`` points
    only, so the targets stay noise-free.

    Args:
        batch_size: Number of independent series to generate.
        seq_length: Number of time steps used as model input.
        input_size: Unused; kept so existing callers keep working. The
            feature dimension of the returned tensors is always 1.
        device: Device the ``inputs`` and ``targets`` tensors are moved to.
        horizon: Number of trailing time steps used as target (must be >= 1).
        trend: If True, add the time vector itself as a linear trend.

    Returns:
        Tuple ``(sine_waves, inputs, targets, time)`` where ``sine_waves`` is
        the full series of shape ``(batch_size, seq_length + horizon, 1)``
        (not moved to ``device``), ``inputs`` is
        ``(batch_size, seq_length, 1)``, ``targets`` is
        ``(batch_size, horizon, 1)`` and ``time`` is the 1-D time vector.

    Raises:
        ValueError: If ``horizon`` is smaller than 1.
    """
    if horizon < 1:
        # The original negative-index slicing (`:-horizon`) silently breaks
        # for horizon == 0, so reject it up front with a clear message.
        raise ValueError(f"horizon must be >= 1, got {horizon}")

    # Generate random frequencies and phases (0, 10) (0, 2*pi)
    frequencies = torch.rand(batch_size) * 10
    phases = torch.rand(batch_size) * 2 * math.pi

    # Create a time vector (+ target)
    time = torch.linspace(0, 2 * math.pi, seq_length + horizon)

    # Generate all sine waves at once by broadcasting (batch, 1) against (T,)
    sine_waves = torch.sin(frequencies[:, None] * time + phases[:, None])

    if trend:
        sine_waves += time

    # add additive amplitude noise to features only (targets stay clean)
    sine_waves[:, :seq_length] += torch.rand(batch_size, seq_length) / 2

    # Split into inputs and targets; a trailing feature axis of size 1 is added
    sine_waves = sine_waves[:, :, None]
    inputs = sine_waves[:, :seq_length].to(device)
    targets = sine_waves[:, seq_length:].to(device)

    return sine_waves, inputs, targets, time

# Build a small demo batch: 32 series, 20 input steps, 3 target steps, with trend
sine_waves, example_input, example_target, time = create_example_data(
    batch_size=32, seq_length=20, input_size=1,
    device=DEVICE, horizon=3, trend=True)

print(example_input.shape, example_target.shape, time.shape)

# Plot five randomly chosen series (blue) and mark their target steps (red dots)
plt.figure()

for i in np.random.randint(0, len(example_input), 5):
    series = sine_waves[i].squeeze().cpu().detach().numpy()
    plt.plot(time, series, 'b-')
    # the last 3 points correspond to horizon=3 used above
    plt.plot(time[-3:], series[-3:], 'r.')
plt.show()

# Informer (transformer) model for long sequences

Includes ProbSparse attention, positional encoding (PosEnc), and an encoder-decoder model.

### Positional encoding

In [None]:
# positional encoding
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first inputs.

    Even feature indices carry ``sin(pos * w_i)`` and odd indices carry
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Feature dimension of the tensors the encoding is added to.
        max_len: Largest sequence length that can be encoded.
        device: Device on which the encoding table is created.
    """

    def __init__(self, d_model, max_len, device):
        super().__init__()
        position = torch.arange(
            0, max_len, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float, device=device)
            * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model, device=device)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Shape (1, max_len, d_model) so it broadcasts over the batch axis.
        # Registered as a non-persistent buffer: it follows `.to(device)` but
        # does not add a key to the state dict.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Add the positional encoding to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` given at
                construction time.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}")
        # Slice along the time axis (dim 1), not the batch axis.
        return x + self.pe[:, :seq_len, :]

### ProbSparse Attention

In [None]:
# ProbSparse Attention
class ProbSparseAttention(nn.Module):
    """Multi-head attention that keeps only the top-k scores per query.

    For every query the ``k`` largest scaled dot-product scores are kept and
    all other keys are masked out before the softmax. ``k`` is
    ``int(sparse_factor * query_length)``, clamped to ``[1, key_length]``.

    Args:
        d_model: Model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        sparse_factor: Fraction of the query length used as top-k.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, sparse_factor=0.1, dropout=0.1):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = math.sqrt(self.d_k)
        self.sparse_factor = sparse_factor

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x, batch_size, length):
        """Reshape (batch, length, d_model) to (batch, heads, length, d_k)."""
        return x.view(batch_size, length, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V):
        """Compute sparse attention.

        Args:
            Q: Queries of shape ``(batch, q_len, d_model)``.
            K: Keys of shape ``(batch, k_len, d_model)``.
            V: Values of shape ``(batch, k_len, d_model)``.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size = Q.size(0)
        q_len, k_len, v_len = Q.size(1), K.size(1), V.size(1)

        Q = self._split_heads(self.query(Q), batch_size, q_len)
        K = self._split_heads(self.key(K), batch_size, k_len)
        V = self._split_heads(self.value(V), batch_size, v_len)

        # (batch_size, num_heads, q_len, k_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # number of top-k; at least 1, and never more than the number of keys
        # (topk raises if k exceeds the size of the selected dimension)
        top_k = min(max(int(self.sparse_factor * q_len), 1), k_len)

        # the k-th largest score per query is the cut-off
        top_k_scores, _ = scores.topk(k=top_k, dim=-1)
        threshold = top_k_scores[..., -1:]

        # keep the top-k scores and mask everything BELOW the cut-off
        sparse_scores = scores.masked_fill(scores < threshold, float('-inf'))

        # calculate attention
        attention_weights = self.dropout(self.softmax(sparse_scores))
        output = torch.matmul(attention_weights, V)

        # recover input shape
        output = output.transpose(1, 2).contiguous().view(batch_size, q_len, -1)

        return output

### Encoder block

In [None]:
# Encoder Block
class Encoder(nn.Module):
    """Single encoder layer: sparse self-attention followed by a feed-forward block.

    Both sub-layers use a residual connection, dropout and post-layer-norm.

    Args:
        d_model: Feature dimension of the input and output.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward block.
        dropout: Dropout probability used on the sub-layer outputs and
            forwarded to the attention layer.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # ProbSparse Attention; forward the configured dropout rate so the
        # attention weights do not silently fall back to the layer's default
        self.attention = ProbSparseAttention(d_model, num_heads, dropout=dropout)

        # Feed-forward residual block
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the encoder layer to ``x`` and return a tensor of the same shape."""
        # Self-attention: queries, keys and values all come from x
        attn_output = self.attention(x, x, x)

        # residual + normalization
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward residual block
        ffn_output = self.ffn(x)

        # residual + normalization
        x = self.norm2(x + self.dropout(ffn_output))
        return x

### Decoder block

In [None]:
# Decoder
class Decoder(nn.Module):
    """Single decoder layer: sparse cross-attention followed by a feed-forward block.

    The layer contains no self-attention; the decoder input only attends to
    the encoder output. Both sub-layers use a residual connection, dropout
    and post-layer-norm.

    Args:
        d_model: Feature dimension of the input and output.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward block.
        dropout: Dropout probability used on the sub-layer outputs and
            forwarded to the attention layer.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # ProbSparse Attention; forward the configured dropout rate so the
        # attention weights do not silently fall back to the layer's default
        self.attention = ProbSparseAttention(d_model, num_heads, dropout=dropout)

        # Feed-forward residual block
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output):
        """Apply the decoder layer.

        Args:
            x: Decoder input; supplies the attention queries.
            encoder_output: Encoder result; supplies keys and values.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Cross-attention: queries from x, keys/values from the encoder
        attn_output = self.attention(x, encoder_output, encoder_output)

        # residual + normalization
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward residual block
        ffn_output = self.ffn(x)

        # residual + normalization
        x = self.norm2(x + self.dropout(ffn_output))
        return x

### Full informer model

In [None]:
# Informer Model
class Informer(nn.Module):
    """Encoder-decoder forecasting model built from the blocks in this file.

    The encoder input is linearly embedded, positionally encoded and passed
    through ``num_enc_layers`` encoder layers. The decoder is always fed the
    embedded LAST time step of the encoder input, so the model emits exactly
    one time step per call.

    Args:
        input_dim: Number of features per time step.
        d_model: Internal model dimension.
        num_heads: Number of attention heads per layer.
        d_ff: Hidden width of the feed-forward blocks.
        num_enc_layers: Number of encoder layers.
        num_dec_layers: Number of decoder layers.
        output_dim: Number of features produced per time step.
        seq_len: Maximum length handed to the positional encoding.
        dropout: Dropout probability passed to every encoder/decoder layer.
        device: Device on which the positional encoding is created.
    """

    def __init__(self,
                 input_dim,
                 d_model,
                 num_heads,
                 d_ff,
                 num_enc_layers,
                 num_dec_layers,
                 output_dim,
                 seq_len,
                 dropout=0.1,
                 device='cpu'):
        super().__init__()
        self.input_embedding = nn.Linear(input_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model, seq_len, device)

        # stacks of encoder and decoder layers
        self.encoder = nn.ModuleList(
            [Encoder(d_model, num_heads, d_ff, dropout)
             for _ in range(num_enc_layers)])
        self.decoder = nn.ModuleList(
            [Decoder(d_model, num_heads, d_ff, dropout)
             for _ in range(num_dec_layers)])

        # output layer
        self.output_layer = nn.Linear(d_model, output_dim)

    def forward(self, x_enc, x_dec=None):
        """Forecast one step from the encoder input.

        Args:
            x_enc: Encoder input of shape ``(batch, seq_len, input_dim)``.
            x_dec: Ignored. Accepted (and now optional) only for backward
                compatibility; the decoder input is always the last time
                step of ``x_enc``.

        Returns:
            Tensor of shape ``(batch, 1, output_dim)``.
        """
        # The decoder query is the last observed time step; any x_dec passed
        # by the caller is intentionally discarded.
        x_dec = x_enc[:, -1:, :]

        # Encoder input processing
        enc_out = self.positional_encoding(self.input_embedding(x_enc))
        for layer in self.encoder:
            enc_out = layer(enc_out)

        # Decoder input processing (no positional encoding is applied here)
        dec_out = self.input_embedding(x_dec)
        for layer in self.decoder:
            dec_out = layer(dec_out, enc_out)

        # Output layer
        return self.output_layer(dec_out)

## Testing the model

In [None]:
# Instantiate the model on the default device ('cpu') and run one forward
# pass on random data as a shape check.
model = Informer(input_dim=1,
                 d_model=64,
                 num_heads=4,
                 d_ff=256,
                 num_enc_layers=2,
                 num_dec_layers=2,
                 output_dim=1,
                 seq_len=100)

print(model)

x_enc = torch.randn(32, 100, 1)  # (batch_size, sequence_len, input_dim)
x_dec = torch.randn(32, 10, 1)  # Decoder input; Informer.forward replaces it with the last step of x_enc
output = model(x_enc, x_dec)

print(output.shape)  # (batch_size, 1, output_dim): the decoder is fed a single time step

## Training the model

In [None]:
# Hyper-parameters shared by the model, the data generator and the cells below.
input_size = 1
batch_size = 32
seq_length = 50
# Informer.forward decodes a single time step, so a horizon of 1 keeps the
# target shape equal to the model output shape.
horizon= 1
# Instantiate the model
model = Informer(input_dim=input_size,
                 d_model=64,
                 num_heads=4,
                 d_ff=256,
                 num_enc_layers=2,
                 num_dec_layers=2,
                 output_dim=1,
                 seq_len=seq_length,
                 device=DEVICE
).to(DEVICE)

# Generate example data (one fixed batch that is reused for every training step)
sine_waves, example_input, example_target, time = create_example_data(
    batch_size,
    seq_length,
    input_size,
    DEVICE,
    horizon=horizon,
    trend=True
)

print(example_input.shape, example_target.shape, time.shape)

In [None]:
# Adam optimizer and mean-squared-error loss for the regression target.
optimizer = optim.Adam(model.parameters(), lr=5e-4)
criterion = nn.MSELoss()

# Placeholder decoder input with the target's shape. Informer.forward
# overwrites its x_dec argument, so these zeros never reach the decoder.
forecast_input = torch.zeros_like(example_target, device=DEVICE
)

num_epochs = 4000

# Full-batch training: every iteration uses the same example batch.
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    
    output = model(example_input,
                   forecast_input)
    
    loss = criterion(output, example_target)
    
    loss.backward()
    optimizer.step()
    
    # report the loss every 200 iterations
    if (epoch + 1) % 200 == 0:
        print(f"Epoch: {epoch+1:4}, Loss: {loss.item():.4f}")

In [None]:
# Plot five randomly chosen series (blue) with the model forecast (red dots).
# `output` is whatever the last training iteration produced.
plt.figure()
for i in np.random.randint(0, len(example_input), 5):
    truth = sine_waves[i].squeeze().cpu().detach().numpy()
    forecast = output[i].squeeze().cpu().detach().numpy()
    plt.plot(time, truth, 'b-')
    # forecasts sit on the last `horizon` points of the time axis
    plt.plot(time[-horizon:], forecast, 'r.')
plt.show()