In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [10]:
def convolution(image, kernel, stride=1):
    """Valid (no padding) 2-D cross-correlation of one channels-last image with one kernel.

    As in most deep-learning libraries the kernel is NOT flipped, so output
    pixel (r, c) equals
    ``sum(image[r*stride:r*stride+kh, c*stride:c*stride+kw, :] * kernel)``.

    Every receptive field is gathered with a single ``take`` call (im2col) and
    reduced with one matrix-vector product, so there is no Python loop over
    output pixels.

    Parameters
    ----------
    image : array_like, shape (H, W, C) or (H, W)
        Input image, channels last. A 2-D array is treated as C == 1.
    kernel : array_like, shape (kh, kw, C) or (kh, kw)
        Filter weights. Its channel count must equal the image's.
    stride : int, optional
        Step between neighbouring receptive fields (default 1).

    Returns
    -------
    numpy.ndarray, shape (out_h, out_w)
        ``out_h = (H - kh) // stride + 1`` and ``out_w = (W - kw) // stride + 1``.

    Raises
    ------
    ValueError
        If the inputs are not 2-D/3-D, the channel counts differ, the kernel is
        empty or larger than the image, or ``stride`` is not a positive integer.
    """
    image = np.asarray(image)
    kernel = np.asarray(kernel)
    # Treat single-channel 2-D inputs as (H, W, 1) so one code path handles both.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if kernel.ndim == 2:
        kernel = kernel[:, :, np.newaxis]
    if image.ndim != 3 or kernel.ndim != 3:
        raise ValueError(
            f"expected 2-D or 3-D image and kernel, got {image.ndim}-D and {kernel.ndim}-D"
        )

    num_rows, num_cols, num_channels = image.shape
    k_rows, k_cols, k_channels = kernel.shape
    if k_channels != num_channels:
        raise ValueError(
            f"kernel has {k_channels} channel(s) but image has {num_channels}"
        )
    if not (1 <= k_rows <= num_rows and 1 <= k_cols <= num_cols):
        raise ValueError(
            f"kernel size {k_rows}x{k_cols} must be between 1x1 and the "
            f"image size {num_rows}x{num_cols}"
        )
    if not isinstance(stride, (int, np.integer)) or stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride!r}")

    out_h = (num_rows - k_rows) // stride + 1
    out_w = (num_cols - k_cols) // stride + 1

    # Flat offsets, in the (C, H, W) layout, of every element of the receptive
    # field anchored at the origin. They are ordered (channel, row, col) to match
    # kernel.transpose(2, 0, 1).ravel() below.
    window = (np.arange(num_channels)[:, None, None] * (num_rows * num_cols)
              + np.arange(k_rows)[None, :, None] * num_cols
              + np.arange(k_cols)[None, None, :]).ravel()
    # Flat offset of each receptive field's top-left corner.
    anchors = ((np.arange(out_h) * stride)[:, None] * num_cols
               + (np.arange(out_w) * stride)[None, :]).ravel()
    grid = anchors[:, None] + window[None, :]   # (out_h*out_w, kh*kw*C)

    # take() indexes the logical C-order flattening of the transposed view, so
    # the offsets above address the (C, H, W) layout whatever the memory layout.
    patches = image.transpose(2, 0, 1).take(grid)
    conv = patches.dot(kernel.transpose(2, 0, 1).ravel())
    return conv.reshape(out_h, out_w)

In [11]:
# Seeded generator so the result is identical on every Restart & Run All.
rng = np.random.default_rng(0)
# integers() excludes `high`, so 256 covers the full 8-bit range 0..255.
image = rng.integers(0, 256, size=(5, 5, 3))
kernel = rng.uniform(size=(3, 3, 3))
conv = convolution(image, kernel)

# Cross-check the vectorised result against a direct sliding-window sum.
kh, kw, _ = kernel.shape
expected = np.array([[np.sum(image[r:r + kh, c:c + kw] * kernel)
                      for c in range(conv.shape[1])]
                     for r in range(conv.shape[0])])
np.testing.assert_allclose(conv, expected)
conv

[[1971.01010935 1766.37017284 1930.51001495]
 [1643.29022906 1599.53701142 1759.19852932]
 [1958.55360629 2098.80408001 2017.63643403]]


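### Extending to a batch of images

The per-image index grid is reused for every image. In the flattened `(batch, C, H, W)` array each image is just shifted by `C*H*W` positions, so a single `take` and a single `dot` handle the whole batch.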

In [12]:
# Batch of two 5x5 RGB images, shape (batch, H, W, C). Seeded for reproducibility;
# integers() excludes `high`, so 256 covers the full 8-bit range 0..255.
image = np.random.default_rng(1).integers(0, 256, size=(2, 5, 5, 3))

In [13]:
def convolution2(image, kernel, stride=1):
    """Batched valid (no padding) 2-D cross-correlation of channels-last images with one kernel.

    The kernel is NOT flipped, so for batch item b, output pixel (r, c) equals
    ``sum(image[b, r*stride:r*stride+kh, c*stride:c*stride+kw, :] * kernel)``.
    The whole batch is gathered with one ``take`` call (im2col) and reduced
    with one ``dot``, so there is no Python loop over images or pixels.

    Parameters
    ----------
    image : array_like, shape (N, H, W, C)
        Batch of channels-last images. N may be 0.
    kernel : array_like, shape (kh, kw, C) or (kh, kw)
        Filter weights. A 2-D kernel is treated as C == 1. Its channel count
        must equal the images'.
    stride : int, optional
        Step between neighbouring receptive fields (default 1).

    Returns
    -------
    numpy.ndarray, shape (N, out_h, out_w)
        ``out_h = (H - kh) // stride + 1`` and ``out_w = (W - kw) // stride + 1``.

    Raises
    ------
    ValueError
        If the ranks or channel counts are wrong, the kernel is empty or larger
        than the images, or ``stride`` is not a positive integer.
    """
    image = np.asarray(image)
    kernel = np.asarray(kernel)
    if kernel.ndim == 2:
        kernel = kernel[:, :, np.newaxis]
    if image.ndim != 4 or kernel.ndim != 3:
        raise ValueError(
            f"expected a 4-D (N, H, W, C) image batch and a 2-D/3-D kernel, "
            f"got {image.ndim}-D and {kernel.ndim}-D"
        )

    mb, num_rows, num_cols, num_channels = image.shape
    k_rows, k_cols, k_channels = kernel.shape
    if k_channels != num_channels:
        raise ValueError(
            f"kernel has {k_channels} channel(s) but images have {num_channels}"
        )
    if not (1 <= k_rows <= num_rows and 1 <= k_cols <= num_cols):
        raise ValueError(
            f"kernel size {k_rows}x{k_cols} must be between 1x1 and the "
            f"image size {num_rows}x{num_cols}"
        )
    if not isinstance(stride, (int, np.integer)) or stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride!r}")

    out_h = (num_rows - k_rows) // stride + 1
    out_w = (num_cols - k_cols) // stride + 1

    # Flat offsets, inside one (C, H, W) image, of every receptive-field element,
    # ordered (channel, row, col) to match kernel.transpose(2, 0, 1).ravel().
    window = (np.arange(num_channels)[:, None, None] * (num_rows * num_cols)
              + np.arange(k_rows)[None, :, None] * num_cols
              + np.arange(k_cols)[None, None, :]).ravel()
    # Flat offset of each receptive field's top-left corner within one image.
    anchors = ((np.arange(out_h) * stride)[:, None] * num_cols
               + (np.arange(out_w) * stride)[None, :]).ravel()
    per_image = anchors[:, None] + window[None, :]   # (out_h*out_w, kh*kw*C)

    # Each image occupies C*H*W consecutive slots of the (N, C, H, W) layout.
    # np.arange keeps an integer dtype even when mb == 0 (np.array(range(0))
    # would be float64, which take() rejects).
    batch_offsets = np.arange(mb) * (num_channels * num_rows * num_cols)
    grid = batch_offsets[:, None, None] + per_image[None, :, :]   # (N, P, K)

    patches = image.transpose(0, 3, 1, 2).take(grid)
    conv = patches.dot(kernel.transpose(2, 0, 1).ravel())
    return conv.reshape(mb, out_h, out_w)

In [14]:
batch_conv = convolution2(image, kernel)
# The batched path must agree, item by item, with the single-image implementation.
for idx, single_image in enumerate(image):
    np.testing.assert_allclose(batch_conv[idx], convolution(single_image, kernel))
batch_conv

array([[[1481.93797227, 1635.12212008, 2249.62739054],
        [1508.46592949, 2086.59998072, 2086.55118486],
        [1503.12318692, 1539.60798838, 2057.51121662]],

       [[1843.67827386, 1720.54255614, 2153.13613707],
        [1777.41309311, 1542.20001529, 1997.57462824],
        [1655.13133393, 1630.6294293 , 1794.80815171]]])

#### Reference

The index-gathering approach used above is adapted from [Convolution in depth](https://sgugger.github.io/convolution-in-depth.html#convolution-in-depth).