In [1]:
import torch
import torch.nn as nn

In [2]:
class VerboseLSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.0, bidirectional: bool = False) -> None:
        super(VerboseLSTM, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional

        self.lstm = nn.ModuleList()

        for i in range(num_layers - 1):
            print(f"Adding LSTM layer {i + 1} with input size {input_size} and hidden size {hidden_size} and dropout {dropout}")
            self.lstm.append(nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1, bidirectional=bidirectional, batch_first=True))
            input_size = hidden_size * (2 if bidirectional else 1)

            if dropout > 0:
                self.lstm.append(nn.Dropout(p=dropout))

        self.lstm.append(nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1, bidirectional=bidirectional, batch_first=True))


    def to(self, *args, **kwargs) -> 'VerboseLSTM':
        self.lstm = self.lstm.to(*args, **kwargs)
        self.linear = self.linear.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = torch.empty(self.num_layers, x.size(0), x.size(1), self.hidden_size * (2 if self.bidirectional else 1), device=x.device)

        lstm_index = 0
        for layer in self.lstm:
            if isinstance(layer, nn.LSTM):
                print(x.shape)
                x, _= layer.forward(x)
                h[lstm_index] = x
                lstm_index += 1
            else:
                x = layer.forward(x)
        
        return x,  h.permute(1, 0, 2, 3)

In [3]:
from bcnf.models import FullyConnectedFeatureNetwork
from typing import Any, Type

In [4]:
class DualDomainLSTM(nn.Module):
    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            fc_sizes: list[int],
            fc_dropout: float = 0.0,
            num_layers: int = 1,
            dropout: float = 0.0,
            bidirectional: bool = False,
            pooling: str = 'mean') -> None:
        super(DualDomainLSTM, self).__init__()

        self.input_size = input_size
        self.output_size = fc_sizes[-1]
        self.pooling = pooling

        self.lstm = VerboseLSTM(input_size, hidden_size, num_layers, dropout, bidirectional)
        self.frequency_lstm = VerboseLSTM(input_size, hidden_size, num_layers, dropout, bidirectional)

        fc_sizes = [hidden_size * (2 if bidirectional else 1) * 2] + fc_sizes

        self.fc = FullyConnectedFeatureNetwork(sizes=fc_sizes, dropout=fc_dropout)

    def to(self, *args, **kwargs) -> 'VerboseLSTM':
        self.lstm = self.lstm.to(*args, **kwargs)
        self.linear = self.linear.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_lstm, _ = self.lstm.forward(x)
        
        x_frequencies = torch.fft.rfft(x, dim=1)

        x_frequencies_lstm, _ = self.frequency_lstm.forward(torch.cat([x_frequencies.real, x_frequencies.imag], dim=2))

        if self.pooling == 'mean':
            x_lstm_pooled = x_lstm.mean(dim=1)
            x_frequencies_lstm_pooled = x_frequencies_lstm.mean(dim=1)
        elif self.pooling == 'max':
            x_lstm_pooled = x_lstm.max(dim=1).values
            x_frequencies_lstm_pooled = x_frequencies_lstm.max(dim=1).values
        else:
            raise ValueError(f"Invalid pooling method: {self.pooling}")
        
        x_cat = torch.cat([x_lstm_pooled, x_frequencies_lstm_pooled], dim=1)

        return self.fc.forward(x_cat)

In [5]:
lstm = DualDomainLSTM(input_size=3, hidden_size=13, fc_sizes=[128], fc_dropout=0.5, num_layers=5, dropout=0.1, bidirectional=True)

Adding LSTM layer 1 with input size 3 and hidden size 13 and dropout 0.1
Adding LSTM layer 2 with input size 26 and hidden size 13 and dropout 0.1
Adding LSTM layer 3 with input size 26 and hidden size 13 and dropout 0.1
Adding LSTM layer 4 with input size 26 and hidden size 13 and dropout 0.1
Adding LSTM layer 1 with input size 3 and hidden size 13 and dropout 0.1
Adding LSTM layer 2 with input size 26 and hidden size 13 and dropout 0.1
Adding LSTM layer 3 with input size 26 and hidden size 13 and dropout 0.1
Adding LSTM layer 4 with input size 26 and hidden size 13 and dropout 0.1
