# In [1]:
import numpy as np

def _convolve2d(image, kernel):
  """Slide `kernel` over a 2-D `image` and sum the element-wise products.

  The kernel is applied as-is (it is not flipped) and only positions where
  it fits entirely inside the image are evaluated, so the result has shape
  (image rows - kernel rows + 1, image cols - kernel cols + 1).
  """
  # Number of kernel placements along each image axis.
  window_counts = (
      image.shape[0] - kernel.shape[0] + 1,
      image.shape[1] - kernel.shape[1] + 1,
  )
  # Zero-copy 4-D view: windows[i, j, k, l] aliases image[i + k, j + l].
  windows = np.lib.stride_tricks.as_strided(
      image,
      shape=window_counts + kernel.shape,
      strides=image.strides + image.strides,
  )
  # Contract the two window axes against the kernel.
  return np.einsum('kl,ijkl->ij', kernel, windows)

def _convolve2d_multichannel(image, kernel):
  """Apply `_convolve2d` with the same 2-D kernel to every channel.

  `image` is indexed as (rows, cols, channels). The result keeps the channel
  count, shrinks each spatial axis by (kernel size - 1), and is allocated
  with NumPy's default floating-point dtype.
  """
  out_rows = image.shape[0] - kernel.shape[0] + 1
  out_cols = image.shape[1] - kernel.shape[1] + 1
  n_channels = image.shape[2]
  result = np.empty((out_rows, out_cols, n_channels))
  for channel in range(n_channels):
    result[..., channel] = _convolve2d(image[..., channel], kernel)
  return result

def _pad_singlechannel_image(image, kernel_shape, boundary):
  """Pad a 2-D image so a valid-mode kernel pass preserves its size.

  A valid-mode pass with a kernel of size k removes k - 1 samples along an
  axis, so exactly k - 1 samples are added back: (k - 1) // 2 before and the
  remainder after. For odd k this is the symmetric k // 2 on both sides; for
  even k the extra sample goes at the end, instead of over-padding by one.

  Args:
    image: 2-D array to pad.
    kernel_shape: (rows, cols) of the kernel that will be applied.
    boundary: `mode` argument forwarded to `np.pad` (e.g. 'edge').

  Returns:
    The padded array.
  """
  pad_width = []
  for size in (kernel_shape[0], kernel_shape[1]):
    total = max(size - 1, 0)  # never request negative padding
    before = total // 2
    pad_width.append((before, total - before))
  return np.pad(image, pad_width, boundary)

def _pad_multichannel_image(image, kernel_shape, boundary):
  """Pad the two spatial axes of a (rows, cols, channels) image.

  A valid-mode pass with a kernel of size k removes k - 1 samples along an
  axis, so exactly k - 1 samples are added back: (k - 1) // 2 before and the
  remainder after. For odd k this is the symmetric k // 2 on both sides; for
  even k the extra sample goes at the end, instead of over-padding by one.
  The channel axis is never padded.

  Args:
    image: 3-D array whose last axis is the channel axis.
    kernel_shape: (rows, cols) of the kernel that will be applied.
    boundary: `mode` argument forwarded to `np.pad` (e.g. 'edge').

  Returns:
    The padded array, with the channel count unchanged.
  """
  pad_width = []
  for size in (kernel_shape[0], kernel_shape[1]):
    total = max(size - 1, 0)  # never request negative padding
    before = total // 2
    pad_width.append((before, total - before))
  pad_width.append((0, 0))  # leave the channel axis untouched
  return np.pad(image, pad_width, boundary)

def convolve2d(image, kernel, boundary='edge'):
  """Apply a 2-D kernel to a single- or multi-channel image.

  The kernel is applied without flipping, independently to each channel
  when the image has a channel axis.

  Args:
    image: array of shape (rows, cols) or (rows, cols, channels).
    kernel: 2-D array of weights.
    boundary: padding mode forwarded to `np.pad` before filtering, so the
      output keeps roughly the input's spatial size (exactly, for odd kernel
      sizes). Pass None to skip padding; the output then shrinks by
      (kernel size - 1) along each spatial axis.

  Returns:
    The filtered array; a 3-D input yields a 3-D output with the same
    number of channels.

  Raises:
    ValueError: if `image` is neither 2-D nor 3-D (previously this case
      silently returned None).
  """
  if image.ndim == 2:
    pad, apply_kernel = _pad_singlechannel_image, _convolve2d
  elif image.ndim == 3:
    pad, apply_kernel = _pad_multichannel_image, _convolve2d_multichannel
  else:
    raise ValueError(
        'convolve2d expects a 2-D or 3-D image, got %d dimension(s)' % image.ndim)

  if boundary is not None:
    image = pad(image, kernel.shape, boundary)
  return apply_kernel(image, kernel)
