In [1]:
import torch
import math

FFT

In [47]:
t = torch.rand(8,requires_grad=True)
k = torch.rand(8,requires_grad=True)

In [3]:
compare = lambda a,b: torch.max(torch.abs(a - b))
def FFT_matrix(n):
    w = torch.exp(torch.as_tensor(-1j*2*math.pi/n))
    a = torch.arange(n)[:,None]
    b = torch.arange(n)
    return torch.pow(w,a*b)

In [4]:
FFT8 = FFT_matrix(8)
FFT16 = FFT_matrix(16)

In [5]:
# +0j to convert to complex
o1 = FFT8@(t+0j)

In [6]:
# freq of 0/8,1/8,...7/8
# freq of 1/8 and 7/8 are conjugate of each other, and signal is real,
# hence the freq component of o1[i] = conj(o1[-i]), for i = 1,2,...
# note that 0/8 and 4/8 freq component for real signal are real
o1

tensor([ 3.8516+0.0000e+00j,  0.1502+5.0126e-01j,  0.9479-7.4898e-02j,
         0.8196+5.0151e-01j, -0.5733+1.2291e-06j,  0.8196-5.0151e-01j,
         0.9479+7.4898e-02j,  0.1502-5.0126e-01j], grad_fn=<MvBackward0>)

In [7]:
compare(torch.fft.fft(t) , o1)

tensor(2.1499e-06, grad_fn=<MaxBackward1>)

In [8]:
# torch.fft.rfft gives only the positive frequencies
# i.e., 0/8, 1/8, 2/8, 3/8, 4/8
compare(torch.fft.fft(t)[:5] , torch.fft.rfft(t))

tensor(0., grad_fn=<MaxBackward1>)

In [9]:
# fft.irfft expects only positive freq as given by rfft
compare(t,torch.fft.irfft(torch.fft.rfft(t)))

tensor(5.9605e-08, grad_fn=<MaxBackward1>)

In [10]:
# pad zero to the right
t_pad = torch.concat([t,torch.zeros(8)])

In [11]:
o2 = FFT16@(t_pad+0j)

In [12]:
# pad to the right does not change the original frequency at 0/16, 2/16, ..., 14/16
compare(o1,o2[::2])

tensor(2.0399e-06, grad_fn=<MaxBackward1>)

In [13]:
# FFT pad zero to the right
compare(o2,torch.fft.fft(t,n=16))

tensor(7.3314e-06, grad_fn=<MaxBackward1>)

In [14]:
# rFFT pad to the right as well
compare(o2[:9],torch.fft.rfft(t,n=16))

tensor(2.0691e-06, grad_fn=<MaxBackward1>)

In [28]:
t,torch.flip(t,(0,))

(tensor([0.8892, 0.2762, 0.1729, 0.4755, 0.4043, 0.8674, 0.1727, 0.5933],
        requires_grad=True),
 tensor([0.5933, 0.1727, 0.8674, 0.4043, 0.4755, 0.1729, 0.2762, 0.8892],
        grad_fn=<FlipBackward0>))

Conv

In [48]:
from torchaudio.functional import convolve

In [53]:
# Causal Conv
# conv1d expects shape (batch,channel,seq_len)
# padding mode (str, optional)
# ”full”: one overlap, with shape (…, N + M - 1). Same as in polynomial multiply
# ”valid”: two inputs overlap completely, with shape (…, max(N, M) - min(N, M) + 1).
# ”same”: Returns the center segment of the full convolution result, with shape (…, N).

# note conv1d is actually cross-correlation, to get conv, we need to flip k
compare(torch.nn.functional.conv1d(t[None,None], torch.flip(k,(0,))[None,None],padding=7)[...,:8][0,0],\
        convolve(t, k, 'full')[:8])

tensor(0., grad_fn=<MaxBackward1>)

In [57]:
o3 = convolve(t, k, 'full')[:8]

In [58]:
o3[2].backward()

In [59]:
# as we backprop on the 3rd element of the output, and the conv is causal,
# only the first 3 input and 3 kernel have non-zero grad
list(zip(t.grad,k.grad))

[(tensor(0.0685), tensor(0.5079)),
 (tensor(0.5150), tensor(0.4454)),
 (tensor(0.0618), tensor(0.6872)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.))]

In [64]:
def conv_fft(s,k):
    return torch.fft.irfft(torch.fft.rfft(s) * torch.fft.rfft(k))

In [60]:
def conv_fft2(s,k):
    L = s.shape[-1]
    return torch.fft.irfft(torch.fft.rfft(s,n=2*L-1) * torch.fft.rfft(k,n=2*L-1),n=2*L-1)[...,:L]

In [61]:
compare(conv_fft2(t,k), o3)

tensor(1.4901e-07, grad_fn=<MaxBackward1>)

In [65]:
compare(conv_fft(t,k), o3)

tensor(1.4767, grad_fn=<MaxBackward1>)

In [None]:
# why does padding make it Causal Convolution?
# 1. Polynomial view, (a0,a1,a2) * (b0,b1,b2) will result in 4th order polynomial
# Hence need to represent polynomial with (1+4) samples, i.e. F5 or pad 2 zero at the end
# 
# 2. Circulant matrix. C(a) * x represent circulant conv. If we pad n-1 zeros (assuming a,x are of len n),
# only the zero wraps around, resulting in causal conv