In [1]:
from fft import *
from scipy import signal

### Convolution implemented directly (strided sliding windows + `np.einsum`)

In [2]:
def convolve_direct(img, kernel, mode='full'):
    '''
    2-D convolution computed directly (no FFT): the kernel is flipped along both
    spatial axes and then cross-correlated with the image.

    img:    H*W, or C*H*W where C is channel; every channel is convolved with
            the same kernel.
    kernel: h_k*w_k 2-D array.
    mode:   'full', 'same' or 'valid'; padding and output size are determined
            by cross_correlation.

    Returns a C*H_out*W_out array (C=1 for a 2-D img).
    Raises ValueError for an unsupported mode (an explicit check, so it is not
    stripped under `python -O` the way an assert would be).
    '''
    if mode not in {'full', 'same', 'valid'}:
        raise ValueError(f"mode must be 'full', 'same' or 'valid', got {mode!r}")
    # Convolution == cross-correlation with the kernel rotated by 180 degrees.
    return cross_correlation(img, kernel[::-1, ::-1], mode)

def cross_correlation(img, kernel,mode):
    '''
    img: C*H*W, where C is channel (3 for rgb image, 1 for grayscale);
                  H and W are image height and width respectively.
    '''
    if img.ndim==2:
        img=img[None,...]

    kernel_size = kernel.shape
    h_k, w_k = kernel.shape  # kernel height and width
    if mode == 'full':
        padding = (h_k - 1, w_k - 1)
        pad_width = [[0, 0] for _ in range(img.ndim)]
        pad_width[-2] = [padding[0], padding[0]]
        pad_width[-1] = [padding[1], padding[1]]
        padded = np.pad(img, pad_width)
        return cross_correlation(padded, kernel, mode='valid')
    elif mode == 'same':
        padding = (h_k // 2, w_k // 2)
        pad_width = [[0, 0] for _ in range(img.ndim)]
        pad_width[-2] = [padding[0], padding[0]]
        pad_width[-1] = [padding[1], padding[1]]
        padded = np.pad(img, pad_width)
    elif mode == 'valid':
        padding = (0,0)
        padded=img

    C, H_in, W_in = img.shape  # batch size, rgb channel, Height, Width
    H_out = np.floor(H_in + 2 * padding[0] - kernel_size[0] + 1).astype(int)
    W_out = np.floor(W_in + 2 * padding[1] - kernel_size[1] + 1).astype(int)

    expanded = np.lib.stride_tricks.as_strided(
        padded,
        shape=(
            H_out,  # out channel height
            W_out,  # out channel width
            padded.shape[-3],  # input channel
            kernel.shape[-2],  # kernel height
            kernel.shape[-1],  # kernel width
        ),
        strides=(
            padded.strides[-2],  # H dimension
            padded.strides[-1],  # W dimension
            padded.strides[-3],  # input chennel
            padded.strides[-2],  # kernel height
            padded.strides[-1],  # kernel width
        ),
        writeable=False,
    )
    feature_map = np.ascontiguousarray(np.moveaxis(np.einsum('...ij,...ij->...', expanded, kernel), -1, -3))
    return feature_map

### Convolution implemented using the FFT

In [3]:
def convolve_fft(image, kernel, mode='full'):
    '''
    2-D linear convolution of `image` with `kernel` computed via the FFT.

    Both arrays are zero-padded on the bottom/right to a common size that is at
    least the size of the full convolution (so the circular convolution computed
    by the FFT has no wrap-around), multiplied in the frequency domain and
    transformed back; the requested region is then cropped from the full result.

    image:  H*W array (assumed real: the imaginary part of the inverse FFT is
            discarded).
    kernel: h_k*w_k array (assumed real).
    mode:   'full'  -> (H+h_k-1)*(W+w_k-1), the complete convolution
            'same'  -> H*W, taken from the full result starting at
                       ((h_k-1)//2, (w_k-1)//2) (scipy's 'same' convention)
            'valid' -> (|H-h_k|+1)*(|W-w_k|+1), only the outputs that do not
                       touch the implicit zero padding

    Raises ValueError for an unsupported mode.
    '''
    if mode not in {'full', 'same', 'valid'}:
        raise ValueError(f"mode must be 'full', 'same' or 'valid', got {mode!r}")
    h, w = image.shape
    hk, wk = kernel.shape
    h_full, w_full = h + hk - 1, w + wk - 1  # shape of the full convolution
    # Round each axis up to the next power of 2 (presumably required by the
    # fft2/ifft2 helpers -- TODO confirm); (n - 1).bit_length() is the smallest
    # p with 2**p >= n.
    hf, wf = 1 << (h_full - 1).bit_length(), 1 << (w_full - 1).bit_length()
    img_padded = np.pad(image, ((0, hf - h), (0, wf - w)))
    kernel_padded = np.pad(kernel, ((0, hf - hk), (0, wf - wk)))
    # Convolution theorem: pointwise product in the frequency domain.
    output_hat = fft2(img_padded) * fft2(kernel_padded)
    result_full = np.real(ifft2(output_hat))[:h_full, :w_full]

    if mode == 'full':
        return result_full
    if mode == 'same':
        top, left = (hk - 1) // 2, (wk - 1) // 2
        return result_full[top:top + h, left:left + w]
    # 'valid': convolution is commutative, so taking min/max also gives the
    # right region when the kernel is larger than the image along an axis.
    return result_full[min(h, hk) - 1:max(h, hk), min(w, wk) - 1:max(w, wk)]

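### Small kernel (21×21)
With a small kernel the direct method is faster than our FFT implementation (scipy, which picks an FFT method here, is faster than both).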

In [4]:
# Fixed seed so that Restart & Run All reproduces the same inputs.
np.random.seed(0)
image = np.random.randn(823, 809)
kernel = np.random.randn(21, 21)
mode = 'full'

In [5]:
%%time
result_fft=convolve_fft(image, kernel,mode=mode)

CPU times: user 2.02 s, sys: 112 ms, total: 2.13 s
Wall time: 1.73 s


In [6]:
%%time
result_direct=convolve_direct(image, kernel,mode=mode)[0,...]

CPU times: user 428 ms, sys: 0 ns, total: 428 ms
Wall time: 427 ms


In [7]:
%%time
result_true=signal.convolve(image,kernel,mode=mode)
print(f'scipy uses the "{signal.choose_conv_method(image,kernel)}" method to compute this convolution.\n')

scipy uses the "fft" method to compute this convolution.

CPU times: user 43.4 ms, sys: 0 ns, total: 43.4 ms
Wall time: 43.1 ms


In [8]:
# np.allclose broadcasts its arguments, so also require identical shapes;
# otherwise a mis-shaped result could still print True.
for result in (result_fft, result_direct):
    print(result.shape == result_true.shape and np.allclose(result_true, result))

True
True


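### Large kernel (123×109)
The FFT method is faster here. The direct method does one multiply-add per output pixel per kernel element, so its cost grows with the kernel area ($O(H W h_k w_k)$). The FFT method's cost depends on the padded transform size ($O(N \log N)$ for $N$ padded pixels) and barely changes with the kernel; both examples pad to 1024×1024. Note that the full convolution has size (image size + kernel size − 1) along each axis, and `convolve_fft` zero-pads each axis to the next power of 2. For example, a size of 1025 is padded to 2048, roughly doubling the work along that axis. A better approach is overlap-add: divide the image into chunks such that (chunk size + kernel size − 1) equals a power of 2 (e.g. 256), convolve each chunk separately, then add the overlapping results together.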

In [9]:
# Fixed seed so that Restart & Run All reproduces the same inputs.
np.random.seed(0)
image = np.random.randn(823, 809)
kernel = np.random.randn(123, 109)
mode = 'full'

In [10]:
%%time
result_fft=convolve_fft(image, kernel,mode=mode)

CPU times: user 2 s, sys: 50 ms, total: 2.05 s
Wall time: 1.66 s


In [11]:
%%time
result_direct=convolve_direct(image, kernel,mode=mode)[0,...]

CPU times: user 15.9 s, sys: 1.11 ms, total: 15.9 s
Wall time: 15.9 s


In [12]:
%%time
result_true=signal.convolve(image,kernel,mode=mode)
print(f'scipy uses the "{signal.choose_conv_method(image,kernel)}" method to compute this convolution.\n')

scipy uses the "fft" method to compute this convolution.

CPU times: user 46.5 ms, sys: 29 µs, total: 46.5 ms
Wall time: 45.1 ms


In [13]:
# np.allclose broadcasts its arguments, so also require identical shapes;
# otherwise a mis-shaped result could still print True.
for result in (result_fft, result_direct):
    print(result.shape == result_true.shape and np.allclose(result_true, result))

True
True
