In [2]:
import torch
import torch.nn as nn

class SimpleLSTM(nn.Module):
    def __init__(self, input_channels, num_layers, output_timesteps, output_channels):
        super(SimpleLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=input_channels, 
                            hidden_size=output_channels, 
                            num_layers=num_layers, 
                            batch_first=True)
        self.fc = nn.Linear(output_channels, output_channels)
        self.output_timesteps = output_timesteps

    def forward(self, x):
        # x shape: [Batch, Channels, Timesteps]
        x = x.permute(0, 2, 1)  # Change to [Batch, Timesteps, Channels]
        lstm_out, _ = self.lstm(x)
        output = self.fc(lstm_out)  # Apply fully connected layer to the entire output
        output = output.permute(0, 2, 1)  # Change to [Batch, Out_Channels, Timesteps]
        return output


# Example usage
batch_size = 32
input_channels = 5
timesteps = 100
num_layers = 2
output_timesteps = 3
output_channels = 2

model = SimpleLSTM(input_channels, num_layers, output_timesteps, output_channels)
input_tensor = torch.randn(batch_size, input_channels, timesteps)
output = model(input_tensor)
print(output.shape)  # Should be [Batch, output_timesteps, output_channels]


torch.Size([32, 2, 100])


In [7]:
import torch
import torch.nn as nn

class TimeSeriesTransformer(nn.Module):
    def __init__(self, input_num_channels, 
                 input_num_timesteps, 
                 num_transformer_blocks_stacked, 
                 output_num_channels, 
                 output_num_timesteps, 
                 hidden_dim, 
                 nhead=8):
        
        super(TimeSeriesTransformer, self).__init__()
        
        assert hidden_dim % nhead == 0, "hidden_dim must be divisible by nhead"
        
        self.input_num_channels = input_num_channels
        self.input_num_timesteps = input_num_timesteps
        self.num_transformer_blocks_stacked = num_transformer_blocks_stacked
        self.output_num_channels = output_num_channels
        self.output_num_timesteps = output_num_timesteps
        self.hidden_dim = hidden_dim
        
        self.input_projection = nn.Linear(input_num_channels, hidden_dim)
        self.positional_encoding = nn.Parameter(torch.zeros(1, input_num_timesteps, hidden_dim))
        
        transformer_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=num_transformer_blocks_stacked)
        
        self.conv1d = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1, stride=input_num_timesteps // output_num_timesteps)
        self.output_projection = nn.Linear(hidden_dim, output_num_channels)
    
    def forward(self, x):
        # x shape: [Batch, input_num_channels, input_num_timesteps]
        x = x.permute(0, 2, 1)  # Change to [Batch, input_num_timesteps, input_num_channels]
        x = self.input_projection(x)  # Project input to hidden_dim
        x += self.positional_encoding  # Add positional encoding
        
        x = self.transformer_encoder(x)  # Apply transformer encoder
        x = x.permute(0, 2, 1)  # Change to [Batch, hidden_dim, input_num_timesteps]
        x = self.conv1d(x)  # Apply convolution to reduce timesteps
        x = x.permute(0, 2, 1)  # Change to [Batch, output_num_timesteps, hidden_dim]
        x = self.output_projection(x)  # Project to output_num_channels
        x = x.permute(0, 2, 1)  # Change to [Batch, output_num_channels, output_num_timesteps]
        
        return x

# Example usage
batch_size = 32
input_num_channels = 5
input_num_timesteps = 100
num_transformer_blocks_stacked = 4
output_num_channels = 2  # Final output channels
output_num_timesteps = 5
hidden_dim = 64  # Hidden dimension for transformer
nhead = 8

model = TimeSeriesTransformer(input_num_channels, input_num_timesteps, num_transformer_blocks_stacked, output_num_channels, output_num_timesteps, hidden_dim, nhead)
input_data = torch.randn(batch_size, input_num_channels, input_num_timesteps)
output_data = model(input_data)
print(output_data.shape)  # Should be [Batch, output_num_channels, output_num_timesteps]


torch.Size([32, 2, 5])
