In [1]:
from convolutions import *
from scipy import signal
import torch

## small N
For small inputs, direct matrix multiplication is faster, and scipy's `signal.convolve` also chooses the direct method to compute the convolution.

In [2]:
# Seed NumPy's global RNG so the random test inputs drawn here (and by the
# later np.random calls in this notebook) are identical on every
# Restart & Run All.
np.random.seed(0)
image=np.random.randn(25,25)   # small 2-D input
kernel=np.random.randn(3,3)    # small 2-D kernel
mode='full'                    # forwarded as `mode=` to the convolution calls

In [3]:
# direct mat. mul. method
result_direct=convolve_direct(image[None,None,...], kernel[None,None,...],mode=mode)[0,0]
%timeit result_direct=convolve_direct(image[None,None,...], kernel[None,None,...],mode=mode)[0,0]

176 µs ± 442 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [4]:
# fft method
result_fft=convolve_fft(image, kernel,mode=mode)
%timeit result_fft=convolve_fft(image, kernel,mode=mode)

594 µs ± 7.98 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [5]:
# overlap add fft method
result_oa=convolve_oa(image, kernel,mode=mode)[0]
%timeit result_oa=convolve_oa(image, kernel,mode=mode)[0]

699 µs ± 7.42 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [6]:
# scipy signal convolve
result_true=signal.convolve(image,kernel,mode=mode)
%timeit result_true=signal.convolve(image,kernel,mode=mode)
print(f'scipy uses the "{signal.choose_conv_method(image,kernel)}" method to compute this convolution.\n')

177 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
scipy uses the "direct" method to compute this convolution.



In [7]:
# Compare each implementation with the scipy reference; prints one
# True/False per line in the order direct, fft, overlap-add.
print(*(np.allclose(result_true, candidate)
        for candidate in (result_direct, result_fft, result_oa)), sep='\n')

True
True
True


## large N (output size a bit smaller than a power of 2)
The fft method is faster here, because the cost of the direct matrix multiplication method grows as O(n^2). Note that each dimension of a full-size convolution is (image size + kernel size - 1), and fft zero-pads it to the next power of 2. For example, if the size is 1000, fft will pad it to 1024.

In [8]:
# Full output shape: (823,809) + (123,109) - (1,1) = (945,917),
# i.e. just below the next power of 2 (1024) on both axes.
image, kernel = np.random.randn(823, 809), np.random.randn(123, 109)
mode = 'full'

In [9]:
result_direct=convolve_direct(image[None,None,...], kernel[None,None,...],mode=mode)[0,0]
%timeit -r1 -n1 result_direct=convolve_direct(image[None,None,...], kernel[None,None,...],mode=mode)[0,0]

15.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [10]:
result_fft=convolve_fft(image, kernel,mode=mode)
%timeit result_fft=convolve_fft(image, kernel,mode=mode)

1.72 s ± 15.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
result_oa=convolve_oa(image, kernel,mode=mode)[0]
%timeit result_oa=convolve_oa(image, kernel,mode=mode)[0]

1.72 s ± 13.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
result_true=signal.convolve(image,kernel,mode=mode)
%timeit result_true=signal.convolve(image,kernel,mode=mode)
print(f'scipy uses the "{signal.choose_conv_method(image,kernel)}" method to compute this convolution.\n')

33.9 ms ± 482 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
scipy uses the "fft" method to compute this convolution.



In [13]:
# One True/False per line: direct, fft and overlap-add results compared
# with the scipy reference.
checks = (result_direct, result_fft, result_oa)
print(*(np.allclose(result_true, computed) for computed in checks), sep='\n')

True
True
True


## larger N (output size a bit larger than a power of 2)
If the output size is 1025, fft will pad it to 2048, which makes it about 4x less efficient in 2D. A better approach is to divide the image into chunks such that (chunk size + kernel size - 1) equals a power of 2 (e.g. 1024), then overlap the partial results and add them together.

The chunk size should be set to the largest power of 2 that is smaller than the output shape. For example, if the output shape is (1045,1027), the chunk size should be (1024,1024). The chunk size can also be set lower to save memory.

In [14]:
# Full output shape: (823,809) + (223,219) - (1,1) = (1045,1027).
# That is just past 1024, so fft pads to (2048,2048) and is about 4x slower.
image = np.random.randn(823, 809)
kernel = np.random.randn(223, 219)
mode = 'full'

In [15]:
result_direct=convolve_direct(image[None,None,...], kernel[None,None,...],mode=mode)[0,0]
%timeit -r1 -n1 result_direct=convolve_direct(image[None,None,...], kernel[None,None,...],mode=mode)[0,0]

1min 15s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [16]:
result_fft=convolve_fft(image, kernel,mode=mode)
%timeit -r1 -n1 result_fft=convolve_fft(image, kernel,mode=mode)

11.4 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [17]:
result_oa=convolve_oa(image, kernel,mode=mode,chunk_thres=(1024,1024))[0]
%timeit result_oa=convolve_oa(image, kernel,mode=mode,chunk_thres=(1024,1024))[0]

2.5 s ± 39.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
result_true=signal.convolve(image,kernel,mode=mode)
%timeit result_true=signal.convolve(image,kernel,mode=mode)
print(f'scipy uses the "{signal.choose_conv_method(image,kernel)}" method to compute this convolution.\n')

55.9 ms ± 579 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
scipy uses the "fft" method to compute this convolution.



In [19]:
# Agreement with the scipy reference, printed one per line in the order
# direct, fft, overlap-add.
print(
    *(np.allclose(result_true, other) for other in (result_direct, result_fft, result_oa)),
    sep='\n',
)

True
True
True


## compare with torch.nn.Conv2d
Convolution in torch is actually cross-correlation in the mathematical sense, so we need to flip the kernel along its height and width axes before using the fft method.  

Also, since the inputs are real, we can use `numpy.fft.rfft2` and `numpy.fft.irfft2` for efficiency.  

In this comparison, the fft method is both faster and more accurate (closer to the torch result) than the mat. mul. method.

In [20]:
# Uniform [0, 1) test data. Both arrays are later handed to torch's Conv2d
# as its input batch and its weight.
inpt = np.random.random(size=(32, 128, 8, 8))
weight = np.random.random(size=(128, 128, 3, 3))

In [21]:
# conv. as matrix multiplication
%time y_mat=cross_correlation(inpt,weight,mode='valid')

CPU times: user 2.63 s, sys: 31 µs, total: 2.63 s
Wall time: 2.63 s


In [22]:
%%time
# FFT method

hi,wi=inpt.shape[-2:]
hk,wk=weight.shape[-2:]
ho,wo=hi+hk-1,wi+wk-1
inpt_padded=np.pad(inpt,((0,0),(0,0),(0,ho-hi),(0,wo-wi)))
weight_padded=np.pad(np.flip(weight,axis=(-1,-2)),((0,0),(0,0),(0,ho-hk),(0,wo-wk)))

inpt_hat=np.fft.rfft2(inpt_padded)
inpt_hat=np.lib.stride_tricks.as_strided(inpt_hat,
                   shape=(inpt_hat.shape[0],weight.shape[0],*inpt_hat.shape[1:]),
                   strides=(inpt_hat.strides[0],0,*inpt_hat.strides[1:]),
                   writeable=False)
weight_hat=np.fft.rfft2(weight_padded)
result_hat=np.einsum('...ijk,...ijk->...jk',inpt_hat,weight_hat)
result=np.fft.irfft2(result_hat,s=(ho,wo))
y_fft=result[...,(hk-1):(hi-(hk&1==0)),(wk-1):(wi-(wk&1==0))] # mode valid

CPU times: user 137 ms, sys: 0 ns, total: 137 ms
Wall time: 136 ms


In [23]:
%%time
# torch

m=torch.nn.Conv2d(2,1,1,bias=False)
m.weight.data=torch.tensor(weight)
y_torch=m(torch.tensor(inpt))

CPU times: user 28.9 ms, sys: 7.89 ms, total: 36.8 ms
Wall time: 21.7 ms


In [24]:
# Agreement with the torch result, one True/False per line: first the
# matrix-multiplication result, then the fft result.
print(*(np.allclose(y_torch.detach().numpy(), y_custom)
        for y_custom in (y_mat, y_fft)), sep='\n')

True
True


In [25]:
# The reported "error" is the sum of squared element-wise differences between
# the torch output (detached and converted to a NumPy array) and each custom
# result, so smaller means closer to torch.
print(f'error using mat. mul. {((y_torch.detach().numpy()-y_mat)**2).sum()}')
print(f'error using fft       {((y_torch.detach().numpy()-y_fft)**2).sum()}')

error using mat. mul. 4.062655573288714e-21
error using fft       1.3044387843955493e-21
