<a href="https://colab.research.google.com/github/parag2489/Algorithms/blob/master/conv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
from enum import Enum
from typing import Tuple

import numpy as np
from torch import nn

In [59]:
class Padding(Enum):
  VALID = "valid"
  SAME = "same"


class Conv2D:
  """Convolution class."""

  def __init__(self, kernel_size: Tuple[int, int], stride: Tuple[int, int], padding: Padding, out_channels: int) -> None:
    """Initialization."""
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    self.out_channels = out_channels

  def _prepare_conv_weight_kernel(self, in_channels: int):
    weights = np.ones((self.out_channels, in_channels, *self.kernel_size), dtype=np.float32)
    return weights

  def _compute_output_dim(self, dim: int, kernel_dim: int, stride_dim: int, padding_dim: int) -> int:
    """Compute output height or width."""
    out_dim = (dim - kernel_dim + 2 * padding_dim) // stride_dim + 1
    return out_dim

  def __call__(self, x: np.ndarray) -> np.ndarray:
    """Run convolution."""
    padding_amount = (0, 0)
    if self.padding is Padding.SAME:
      padding_amount = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)

    self.weights = self._prepare_conv_weight_kernel(in_channels=x.shape[1])
    out = np.zeros((x.shape[0], self.out_channels, self._compute_output_dim(x.shape[2], self.kernel_size[0], self.stride[0], padding_amount[0]), self._compute_output_dim(x.shape[3], self.kernel_size[1], self.stride[1], padding_amount[1])))

    for i in range(out.shape[2]):
      for j in range(out.shape[3]):
        out[:, :, i, j] = np.sum(x[:, None, :, i * self.stride[0] : i * self.stride[0] + self.kernel_size[0], j * self.stride[1] : j * self.stride[1] + self.kernel_size[1]] * self.weights[None, ...], axis=(2, 3, 4))

    return out



In [86]:
conv_input = np.random.randint(0, 10, (2, 3, 10, 10))
conv = Conv2D(kernel_size=(3, 3), stride=(2, 2), padding=Padding.VALID, out_channels=5)
out = conv(conv_input)

print(conv_input)
print(out[0, 0, 1, 1])
np.sum(conv_input[0, :, 2:5, 2:5] * np.ones((1, 3, 3, 3)))

[[[[2 5 6 0 7 7 8 5 0 5]
   [6 9 6 8 5 6 4 8 7 6]
   [4 0 5 5 2 3 5 6 7 0]
   [8 6 3 4 2 9 8 0 3 2]
   [7 4 5 3 4 4 7 0 7 6]
   [9 3 1 8 9 1 6 8 1 2]
   [9 8 4 9 4 8 0 2 0 8]
   [5 6 4 0 2 1 0 5 0 1]
   [7 6 4 6 7 9 5 6 8 7]
   [3 7 9 9 9 5 6 8 2 7]]

  [[3 1 9 3 7 0 5 1 0 8]
   [5 0 6 8 6 8 9 4 6 8]
   [3 6 5 2 9 8 2 7 0 9]
   [4 0 3 1 3 3 5 7 3 8]
   [7 7 0 5 5 4 4 6 1 1]
   [8 1 6 5 6 4 6 8 3 0]
   [3 7 8 1 1 1 8 4 9 4]
   [8 6 6 8 3 9 4 1 4 5]
   [8 4 9 9 5 5 6 4 7 4]
   [1 6 7 4 7 9 4 9 6 1]]

  [[9 1 3 4 2 1 8 6 8 9]
   [0 0 2 7 5 8 4 7 4 5]
   [1 5 2 4 8 1 1 2 6 3]
   [3 2 9 1 7 0 1 2 8 9]
   [3 1 6 4 8 4 2 9 9 9]
   [9 3 9 8 9 5 9 1 2 7]
   [6 0 0 8 8 9 2 6 9 4]
   [0 8 2 3 8 1 0 7 4 1]
   [8 0 9 9 5 8 0 8 0 2]
   [3 5 8 0 1 4 8 6 2 2]]]


 [[[2 6 4 1 4 3 4 6 8 3]
   [1 1 1 1 6 3 0 5 2 3]
   [2 8 7 0 6 6 9 6 9 5]
   [7 2 8 1 9 7 5 2 9 9]
   [6 1 6 0 0 0 4 6 3 9]
   [8 7 1 4 8 6 9 7 5 1]
   [3 3 5 9 0 4 7 3 7 8]
   [6 9 6 2 3 7 9 5 2 6]
   [7 6 1 0 2 6 7 5 2 3]
   [2 2 8 4 1 6 7

115.0

In [195]:
def convolve(image, kernel, stride):
  # We start off by defining some constants, which are required for this code
  kernelH, kernelW = kernel.shape
  imageH, imageW = image.shape
  stride_h, stride_w = stride
  h, w = (imageH - kernelH) // stride_h + 1 , (imageW - kernelW) // stride_w + 1
  print(h, w)

  # filter1 creates an index system that calculates the sum of the x and y indices at each point
  # Shape of filter1 is h x kernelH
  filter1 = np.arange(imageH - kernelH + 1, step=stride_h)[:, None] + np.arange(kernelH)
  print(filter1)

  # intermediate is the stepped data, which has the shape h x kernelH x imageW
  intermediate = image[filter1]
  print(intermediate.shape)

  # transpose the inner dimensions of the intermediate so as to enact another filter
  # shape is now h x imageW x kernelH
  intermediate = np.transpose(intermediate, (0, 2, 1))  # (8, 16, 3)

  # filter2 similarly creates an index system
  # Shape of filter2 is w * kernelW
  filter2 = np.arange(imageW - kernelW + 1, step=stride_w)[:, None] + np.arange(kernelW)
  print(filter2)

  # Apply filter2 on the inner data piecewise, resultant shape is h x w x kernelW x kernelH
  intermediate = intermediate[:, filter2]
  print(intermediate.shape)

  # transpose inwards again to get a resultant shape of h x w x kernelH x kernelW
  intermediate = np.transpose(intermediate, (0, 1, 3, 2))

  # piecewise multiplication with kernel
  print(intermediate.shape)
  print(kernel.shape)
  product = intermediate * kernel[None, None, ...]
  print(product.shape)

  # find the sum of each piecewise product, shape is now h x w
  convolved = product.sum(axis = (2,3))

  return convolved

In [196]:
conv_input = np.random.randint(0, 10, (1, 1, 10, 16))
conv = Conv2D(kernel_size=(3, 5), stride=(1, 2), padding=Padding.VALID, out_channels=5)
out = conv(conv_input)

print(conv_input)
print(out[0, 0])
print(out.shape)
np.sum(conv_input[0, :, 1:4, 1:6] * np.ones((1, 1, 3, 5)))

[[[[6 1 3 0 7 1 3 5 8 5 0 5 9 7 0 4]
   [3 6 4 3 7 4 4 9 6 8 4 0 5 2 9 1]
   [8 2 4 4 9 0 4 5 5 7 6 4 5 9 4 8]
   [3 6 6 9 0 7 2 8 8 1 5 7 4 8 4 6]
   [1 2 6 0 8 5 4 2 0 0 6 7 4 6 2 7]
   [8 3 8 4 9 8 7 8 2 2 9 0 6 0 4 0]
   [2 5 1 9 5 8 6 2 7 7 7 0 4 0 6 5]
   [4 7 5 6 0 6 2 7 0 5 7 2 5 3 3 8]
   [3 0 4 7 2 0 0 2 9 1 4 2 3 4 5 3]
   [6 4 6 3 2 2 3 9 3 3 0 7 5 1 2 7]]]]
[[67. 57. 77. 79. 77. 69.]
 [74. 67. 78. 82. 75. 76.]
 [68. 68. 67. 63. 69. 81.]
 [73. 83. 78. 64. 61. 72.]
 [71. 88. 81. 69. 61. 61.]
 [76. 84. 77. 78. 63. 56.]
 [60. 61. 56. 66. 63. 55.]
 [59. 48. 47. 55. 56. 53.]]
(1, 5, 8, 6)


71.0

In [198]:
convolve(conv_input.reshape((10, 16)), np.ones((3, 5)), (1, 2)) == out[0][0]

8 6
[[0 1 2]
 [1 2 3]
 [2 3 4]
 [3 4 5]
 [4 5 6]
 [5 6 7]
 [6 7 8]
 [7 8 9]]
(8, 3, 16)
[[ 0  1  2  3  4]
 [ 2  3  4  5  6]
 [ 4  5  6  7  8]
 [ 6  7  8  9 10]
 [ 8  9 10 11 12]
 [10 11 12 13 14]]
(8, 6, 5, 3)
(8, 6, 3, 5)
(3, 5)
(8, 6, 3, 5)


array([[ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True]])