In [None]:
import time

import numpy as np

In [None]:
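# Warm-up: a cumulative sum is a causal convolution with an all-ones kernel, so it
# is a simple case for checking FFT/IFFT-based convolution against np.cumsum.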


def fft_cumsum(x):
    """Cumulative sum along axis 0, computed as a linear convolution via the FFT.

    The cumulative sum ``y[t] = sum_{s<=t} x[s]`` equals the causal convolution of
    ``x`` with an all-ones kernel. Both are zero-padded to ``2 * len(x)`` so the
    FFT's circular convolution does not wrap around, and the first ``len(x)``
    samples are kept.

    Args:
        x: Array-like with at least one dimension; the sum runs along axis 0
            (matching ``np.cumsum(x, axis=0)``). Real or complex.

    Returns:
        np.ndarray with the same shape as ``x``: float64 for real input, complex
        for complex input. Agrees with ``np.cumsum`` up to floating-point round-off.
    """
    x = np.asarray(x)
    n = x.shape[0]
    if n == 0:
        # np.fft rejects zero-length transforms; the cumulative sum of nothing is empty.
        return np.zeros(x.shape, dtype=np.result_type(x.dtype, np.float64))

    n_fft = 2 * n  # >= 2n - 1, the full linear-convolution length
    # Broadcast the 1-D kernel spectrum against any trailing dimensions of x.
    kernel_shape = (-1,) + (1,) * (x.ndim - 1)
    kernel = np.ones(n)

    if np.iscomplexobj(x):
        # Keep the imaginary part: np.real here would silently corrupt the result.
        kernel_fft = np.fft.fft(kernel, n=n_fft).reshape(kernel_shape)
        spectrum = np.fft.fft(x, n=n_fft, axis=0) * kernel_fft
        conv = np.fft.ifft(spectrum, n=n_fft, axis=0)
    else:
        # Real input: the half-spectrum rfft/irfft pair is cheaper and yields real output.
        kernel_fft = np.fft.rfft(kernel, n=n_fft).reshape(kernel_shape)
        spectrum = np.fft.rfft(x, n=n_fft, axis=0) * kernel_fft
        conv = np.fft.irfft(spectrum, n=n_fft, axis=0)

    return conv[:n]


def test_fft_cumsum(length=10, seed=0):
    """Check ``fft_cumsum`` against ``np.cumsum`` on a random sequence.

    Args:
        length: Number of samples in the random test sequence (default 10).
        seed: Seed for ``np.random.default_rng`` so a failing case can be
            reproduced; pass ``None`` for fresh randomness on every call.

    Raises:
        AssertionError: if the FFT result differs from ``np.cumsum`` beyond
            ``np.allclose`` tolerances.
    """
    rng = np.random.default_rng(seed)
    x = rng.random(length)

    expected = np.cumsum(x)
    result = fft_cumsum(x)

    print("Expected", expected)
    print("Result", result)
    assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

    print("Test passed!")


# Run the FFT cumulative-sum sanity check; it raises AssertionError on a mismatch.
test_fft_cumsum()

In [None]:
import numpy as np


def fft_convolve(inputs, rho, dt):
    """Causal convolution of ``inputs`` with the sampled kernel ``rho``, via FFT.

    For every batch index ``b`` and trailing index ``d`` this computes::

        out[b, t, d] = sum_{s=0}^{t} inputs[b, t - s, d] * rho(s * dt)

    i.e. a plain discrete sum (no ``dt`` quadrature weight is applied). Both
    sequences are zero-padded to ``2 * seq_length`` so the FFT's circular
    convolution equals the linear one on the first ``seq_length`` samples.

    Args:
        inputs: Array of shape ``(batch, seq_length, ...)``. Time runs along
            axis 1; any trailing axes (e.g. ``input_dim``) are convolved
            independently.
        rho: Callable mapping a scalar time to a scalar kernel value. It is
            called ``seq_length`` times, at ``0, dt, 2*dt, ...``.
        dt: Time step used to sample ``rho``.

    Returns:
        np.ndarray with the same shape as ``inputs``. It is real (float) when both
        ``inputs`` and the sampled kernel are real, and complex otherwise.
    """
    inputs = np.asarray(inputs)
    seq_length = inputs.shape[1]
    if seq_length == 0:
        # np.fft rejects zero-length transforms; an empty sequence stays empty.
        return np.zeros(inputs.shape, dtype=np.result_type(inputs.dtype, np.float64))

    # Sample the kernel once on the time grid 0, dt, 2*dt, ...
    rho_vals = np.array([rho(t * dt) for t in range(seq_length)])
    n_fft = 2 * seq_length  # >= 2 * seq_length - 1 avoids circular wrap-around
    # Reshape the 1-D kernel spectrum to broadcast over batch and trailing axes.
    kernel_shape = (1, -1) + (1,) * (inputs.ndim - 2)

    if np.iscomplexobj(inputs) or np.iscomplexobj(rho_vals):
        # Full complex FFT; taking np.real would discard the imaginary part.
        rho_fft = np.fft.fft(rho_vals, n=n_fft).reshape(kernel_shape)
        spectrum = np.fft.fft(inputs, n=n_fft, axis=1) * rho_fft
        conv = np.fft.ifft(spectrum, n=n_fft, axis=1)
    else:
        # Real data: rfft/irfft does roughly half the work and yields real output.
        # The n= argument performs the zero-padding without an explicit copy.
        rho_fft = np.fft.rfft(rho_vals, n=n_fft).reshape(kernel_shape)
        spectrum = np.fft.rfft(inputs, n=n_fft, axis=1) * rho_fft
        conv = np.fft.irfft(spectrum, n=n_fft, axis=1)

    # Keep only the causal part of length seq_length.
    return conv[:, :seq_length]

In [None]:
def test_convolution(dt, size, seq_length, input_dim, rho, Gaussian_input=False, seed=None):
    """Compare a direct O(T^2) causal convolution against ``fft_convolve`` and time both.

    Args:
        dt: Time step. It scales the Gaussian increments and sets the kernel
            sample times ``s * dt``.
        size: Batch size (number of independent sequences).
        seq_length: Number of time steps T.
        input_dim: Number of channels per time step.
        rho: Callable kernel evaluated at scalar times ``s * dt``.
        Gaussian_input: If True, the inputs are the running sum over time of the
            Gaussian increments (a random walk) rather than the increments themselves.
        seed: Optional seed for ``np.random.default_rng``. ``None`` (the default)
            draws from NumPy's global random state, as before.

    Raises:
        AssertionError: if the two methods disagree beyond ``np.allclose`` tolerances.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)
    inputs = dt * rng.normal(size=(size, seq_length, input_dim))

    if Gaussian_input:
        inputs = np.cumsum(inputs, axis=1)

    # Time and execute the direct convolution.
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time_direct = time.perf_counter()
    # Sample the kernel once (seq_length calls, like the FFT path) instead of
    # re-evaluating rho inside the O(T^2) loop, so the timings are comparable.
    rho_vals = [rho(s * dt) for s in range(seq_length)]
    outputs = []
    for t in range(seq_length):
        output = 0
        for s in range(t + 1):
            output += inputs[:, t - s, :] * rho_vals[s]
        outputs.append(output)
    direct_outputs = np.array(outputs).transpose(1, 0, 2)
    end_time_direct = time.perf_counter()

    # Time and execute the FFT-based convolution.
    start_time_fft = time.perf_counter()
    fft_outputs = fft_convolve(inputs, rho, dt)
    end_time_fft = time.perf_counter()

    print(f"Direct Convolution Time: {end_time_direct - start_time_direct:.5f} seconds")
    print(f"FFT-Based Convolution Time: {end_time_fft - start_time_fft:.5f} seconds")

    assert np.allclose(direct_outputs, fft_outputs), (
        "Direct method and FFT-based method outputs do not match! "
        f"max abs diff = {np.max(np.abs(direct_outputs - fft_outputs)):.3e}"
    )
    print("Test passed!")


# Benchmark setup: an exponentially decaying kernel, rho(t) = exp(-t).
def rho(t):
    """Return ``exp(-t)``; accepts scalars or NumPy arrays."""
    exponent = np.negative(t)
    return np.exp(exponent)

# Benchmark direct vs. FFT convolution for growing sequence lengths.
# The direct method performs O(seq_length^2) Python-level loop iterations,
# so its time should grow much faster than the FFT method's.
dt = 1.0
size = 100
input_dim = 1
Gaussian_input = False
for seq_length in (1000, 2000, 4000):
    print("seq_length", seq_length)
    test_convolution(dt, size, seq_length, input_dim, rho, Gaussian_input)