In [4]:
import torch
import torch.nn as nn

from typing import Any
import numpy as np

In [51]:
class FFTLayer(nn.Module):
    """
    Fast Fourier Transform (FFT) layer. Transforms the input into the frequency domain. The FFT of a real signal is Hermitian-symmetric, X[i] = conj(X[-i]) so the output contains only the positive frequencies below the Nyquist frequency.
    """
    def __init__(self) -> None:
        super(FFTLayer, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        complex_f = torch.fft.rfft(input=x, dim=-1, norm='forward')

        # Concatenate the real and imaginary parts
        return torch.cat((complex_f.real, complex_f.imag), dim=-1)


class FFTEnrichLayer(nn.Module):
    """
    Concatenate the input with the FFT of the input.
    """
    def __init__(self) -> None:
        super(FFTEnrichLayer, self).__init__()
        self.fft = FFTLayer()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat((x, self.fft(x)), dim=-1)
    

class LinearFFTEnriched(nn.Module):
    """
    Linear layer that enriches the input with the FFT of the input.
    """
    def __init__(self, input_size: int, output_size: int) -> None:
        super(LinearFFTEnriched, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.fft_enrich = FFTEnrichLayer()

        self.linear = nn.Linear(input_size + 2 * (input_size // 2 + 1), output_size)

    def to(self, device: torch.device) -> Any:
        self.linear.to(device)
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.fft_enrich(x))

In [54]:
LinearFFTEnriched(input_size=175, output_size=175)(torch.randn(2, 175))

tensor([[ 0.2342,  0.0018, -0.5272,  0.0961, -0.9642,  0.4963, -0.6998, -0.0798,
          0.1426,  0.6610, -0.4360,  0.5705, -0.5102, -0.5341,  0.1643, -0.2726,
          0.3506,  0.0122,  0.6797,  0.5360,  0.0830, -0.3847,  0.1212,  0.2409,
          0.5764, -0.3908,  0.3484, -0.0890,  0.3017, -0.0988,  0.3992, -0.7167,
         -0.0200,  0.3042, -0.2587, -0.1973, -0.5728, -0.4654,  0.0693, -0.4362,
          0.4107, -0.4018, -0.2631,  0.5925,  0.5228,  0.2471,  0.1502, -0.1625,
         -0.4511,  0.2259, -0.0824,  0.8160,  0.4593, -0.8728,  0.3651, -0.5238,
         -0.1732, -0.2172, -0.0479, -0.3563, -0.0092,  0.4076,  0.1475,  0.7072,
          0.1646, -0.4238,  0.3235, -0.4178, -0.2224, -0.0540,  0.6400, -0.3012,
          0.2719, -0.3525,  0.9355,  0.2347,  0.8612, -0.3089,  0.2750,  0.6096,
          0.4522,  0.3103,  0.0908,  0.0273,  0.3478,  0.2021, -0.8486,  1.1881,
         -0.9351,  0.2825, -0.1407, -0.0525,  0.0139,  0.0466,  0.1369,  0.0904,
         -0.1624,  0.3797,  