<a href="https://colab.research.google.com/github/pattichis/MLTransforms/blob/main/FFTStart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Learnable Transforms

In [2]:
import torch
import torch.nn as nn
import torch.fft

class SpectralFilter(nn.Module):
    """Learnable per-frequency filter applied through a real-FFT round trip.

    The signal is moved to the frequency domain with ``torch.fft.rfft``,
    scaled bin-by-bin by a learnable complex vector, and returned to the
    time domain with ``torch.fft.irfft`` at the input's original length.

    Args:
        size: Signal length the filter is built for. The learnable vector
            holds ``size // 2 + 1`` complex weights, which is the number of
            bins ``rfft`` produces for a length-``size`` real signal.
    """

    def __init__(self, size):
        super().__init__()
        num_bins = size // 2 + 1
        # Complex weights let each frequency bin be both scaled and
        # phase-shifted; initialised from a standard complex normal.
        initial_weights = torch.randn(num_bins, dtype=torch.complex64)
        self.filter_weights = nn.Parameter(initial_weights)

    def forward(self, x):
        """Filter ``x`` along its last dimension.

        Args:
            x: Real-valued tensor whose last dimension is the signal axis,
                e.g. shape ``[batch, signal_length]``. Its bin count,
                ``x.size(-1) // 2 + 1``, must broadcast against
                ``filter_weights``. In practice the signal length should be
                ``size``. NOTE(review): a length of ``size + 1`` (even
                ``size``) yields the same bin count and is silently
                accepted — confirm whether that is intended.

        Returns:
            Real tensor whose last dimension has the same length as ``x``.
        """
        signal_length = x.size(-1)
        spectrum = torch.fft.rfft(x)
        # Out-of-place multiply so autograd can track both x and the weights.
        spectrum = spectrum * self.filter_weights
        # Passing n restores the original length (needed for odd lengths).
        # irfft ignores the imaginary part of the DC bin (and of the Nyquist
        # bin for even lengths), so only the real part of those weights has
        # any effect.
        return torch.fft.irfft(spectrum, n=signal_length)

# Usage: run a random batch through the filter and backpropagate a scalar
# loss to confirm gradients flow through the rfft -> multiply -> irfft path.
model = SpectralFilter(size=128)
# 16 signals of length 128 (matching `size`). requires_grad=True lets the
# input's gradient be inspected after backward().
input_data = torch.randn(16, 128, requires_grad=True)
output = model(input_data)
# Summing yields a scalar, which backward() accepts without an explicit
# gradient argument; this also fills model.filter_weights.grad.
loss = output.sum()
loss.backward()

# True when autograd populated a gradient for the input tensor.
print(f"Gradient exists: {input_data.grad is not None}")
# NOTE(review): no RNG seed is set (e.g. torch.manual_seed), so this value
# differs from run to run and will not match the recorded cell output.
print(loss)

Gradient exists: True
tensor(-5.2260, grad_fn=<SumBackward0>)
