In [None]:
from fastcore.test import *
from manim import *
import numpy as np
from scipy.signal import convolve2d
from sympy import (
    Add,
    Derivative,
    Expr,
    Function,
    Indexed,
    IndexedBase,
    latex,
    Mul,
    Symbol,
    symbols,
)
from typing import Tuple


In [None]:
# manim configuration flag. NOTE(review): by its name this presumably embeds the
# rendered media directly in the notebook output instead of linking to files on
# disk — confirm against manim's configuration documentation.
config.media_embed = True

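Below is the convolution function I implemented for my CNN course. It's probably not the most efficient implementation, but writing it was a good learning exercise. As is conventional for CNN layers, it actually computes cross-correlation: the filter is slid over the input as-is, without first being rotated by 180° as in a textbook (mathematical) convolution.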

In [None]:
#| export
def convolve(X: np.ndarray, filter: np.ndarray, zero_pad_width: int, stride: int):
    """Slide `filter` over a zero-padded 2-D `X` and return the map of window sums.

    Like the "convolution" layers used in CNNs, this is technically
    cross-correlation: the filter is applied as-is, without the 180° flip of a
    true mathematical convolution. With ``zero_pad_width=0`` and ``stride=1`` it
    therefore matches ``scipy.signal.correlate2d(X, filter, mode='valid')``, and
    matches ``scipy.signal.convolve2d`` only for filters that are symmetric under
    a 180° rotation.

    Any element type supporting ``*`` and ``+`` works, including object arrays
    of SymPy symbols, in which case every output element is a symbolic expression.

    Parameters
    ----------
    X : 2-D input array of shape (in_H, in_W).
    filter : 2-D filter of shape (f_H, f_W); it does not have to be square.
    zero_pad_width : number of zeros added on every side of `X`; must be >= 0.
    stride : step between successive window positions; must be >= 1.

    Returns
    -------
    np.ndarray of shape
    ``((in_H + 2*zero_pad_width - f_H) // stride + 1, (in_W + 2*zero_pad_width - f_W) // stride + 1)``
    whose dtype is ``np.result_type(X, filter)``.

    Raises
    ------
    ValueError
        If either array is not 2-D, ``stride < 1``, ``zero_pad_width < 0``, or
        the filter is larger than the padded input.
    """
    X = np.asarray(X)
    filter = np.asarray(filter)
    if X.ndim != 2 or filter.ndim != 2:
        raise ValueError(f"X and filter must be 2-D, got shapes {X.shape} and {filter.shape}")
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    if zero_pad_width < 0:
        raise ValueError(f"zero_pad_width must be >= 0, got {zero_pad_width}")

    X_pad = np.pad(
        X,
        ((zero_pad_width, zero_pad_width), (zero_pad_width, zero_pad_width)),
        mode='constant',
        constant_values=(0.0, 0.0),
    )

    pad_H, pad_W = X_pad.shape
    f_H, f_W = filter.shape  # height and width are tracked separately: non-square filters are fine
    if f_H > pad_H or f_W > pad_W:
        raise ValueError(
            f"filter of shape {filter.shape} is larger than the padded input of shape {X_pad.shape}"
        )

    # Floor division counts the whole strides that fit inside the padded input.
    out_H = (pad_H - f_H) // stride + 1
    out_W = (pad_W - f_W) // stride + 1

    # Promote to a dtype that can hold X*filter products, so that e.g. an int
    # input with a float filter is not silently truncated back to int.
    output = np.zeros((out_H, out_W), dtype=np.result_type(X, filter))

    for out_row in range(out_H):
        for out_col in range(out_W):
            in_start_row = out_row * stride
            in_start_col = out_col * stride
            window = X_pad[in_start_row:in_start_row + f_H, in_start_col:in_start_col + f_W]
            # Built-in sum over the row-major elementwise products keeps the same
            # left-to-right accumulation order as an explicit double loop. It
            # starts from 0, which also works for SymPy expressions.
            output[out_row, out_col] = sum((window * filter).flat)

    return output

We can test it with some sample data.

In [None]:
# 4x4 input in which every row is 1, 2, 3, 4.
X_test = np.array([[1.0, 2.0, 3.0, 4.0]] * 4)

# 2x2 anti-diagonal filter. Note that it is unchanged by a 180° rotation, so on
# its own it cannot distinguish cross-correlation from true (flipped) convolution.
f_test = np.array([
    [0.0, 1.0],
    [1.0, 0.0],
])

# Each output is (top-right + bottom-left) of a 2x2 window: 2+1, 3+2, 4+3 in every row.
expected = np.array([[3.0, 5.0, 7.0]] * 3)

test_eq(convolve(X_test, f_test, 0, 1), expected)

In [None]:
# Test output of our convolve function is consistent with scipy.
# Careful: scipy's convolve2d performs true convolution (it rotates the kernel
# by 180° first), whereas our convolve, like CNN "convolution" layers, computes
# cross-correlation without flipping. Flipping the kernel before handing it to
# convolve2d makes the two comparable. Without the flip, this check passed only
# because f_test happens to be symmetric under a 180° rotation.
test_eq(
    convolve(X_test, f_test, 0, 1),                     # ours
    convolve2d(X_test, np.flip(f_test), mode='valid'),  # scipy, kernel pre-flipped
)

# Repeat with an asymmetric filter, which would expose a flipped/unflipped mismatch.
f_asym = np.array([
    [1.0, 2.0],
    [3.0, 4.0],
])
test_eq(
    convolve(X_test, f_asym, 0, 1),
    convolve2d(X_test, np.flip(f_asym), mode='valid'),
)

It turns out that the same convolution code works just as well when the input arrays contain SymPy `Symbol`s instead of numbers!

In [None]:
# Grids of SymPy symbols named x_{ij} (4x4 input) and w_{ij} (2x2 filter), with
# 1-based row i and column j, so the convolution returns symbolic expressions.
X_test = np.array([
    [Symbol(rf'x_{{{row}{col}}}') for col in range(1, 5)]
    for row in range(1, 5)
])

f_test = np.array([
    [Symbol(rf'w_{{{row}{col}}}') for col in range(1, 3)]
    for row in range(1, 3)
])
result = convolve(X_test, f_test, 0, 1)
result

array([[w_{11}*x_{11} + w_{12}*x_{12} + w_{21}*x_{21} + w_{22}*x_{22},
        w_{11}*x_{12} + w_{12}*x_{13} + w_{21}*x_{22} + w_{22}*x_{23},
        w_{11}*x_{13} + w_{12}*x_{14} + w_{21}*x_{23} + w_{22}*x_{24}],
       [w_{11}*x_{21} + w_{12}*x_{22} + w_{21}*x_{31} + w_{22}*x_{32},
        w_{11}*x_{22} + w_{12}*x_{23} + w_{21}*x_{32} + w_{22}*x_{33},
        w_{11}*x_{23} + w_{12}*x_{24} + w_{21}*x_{33} + w_{22}*x_{34}],
       [w_{11}*x_{31} + w_{12}*x_{32} + w_{21}*x_{41} + w_{22}*x_{42},
        w_{11}*x_{32} + w_{12}*x_{33} + w_{21}*x_{42} + w_{22}*x_{43},
        w_{11}*x_{33} + w_{12}*x_{34} + w_{21}*x_{43} + w_{22}*x_{44}]],
      dtype=object)

Each element of the output array is an _expression_ in terms of the input symbols that defines how that element is calculated. We can pretty-print the first one to see it more clearly:

In [None]:
# Entry at row 0, column 0: the filter overlaid on the top-left corner of the input.
result[0, 0]

w_{11}*x_{11} + w_{12}*x_{12} + w_{21}*x_{21} + w_{22}*x_{22}

This is exactly the expression for the first element of the convolution output. If you overlaid the filter on the top-left corner of the input matrix, multiplied corresponding elements, and summed the products, this is the expression you'd get. By running the convolution function on symbols, we've obtained the expressions that define each element of the convolution output.