In [91]:
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

## Convolutions

Here we will build up some of the basic approaches for convolution, from a simple all-for-loop algorithm to an algorithm that uses a single matrix multiplication plus resize operations.

### Storage Order Images

In the simple fully-connected networks, hidden units are typically simply represented as vectors, i.e., a quantity $z \in \mathbb{R}^n$, or when representing an entire minibatch, a matrix $Z \in \mathbb{R}^{B \times n}$.  But when we move to convolutional networks, we need to include additional structure in the hidden unit.  This is typically done by representing each hidden vector as a 3D array, with dimensions `height x width x channels`, or in the minibatch case, with an additional batch dimension.  That is, we could represent a hidden unit as an array:

```c++
float Z[BATCHES][HEIGHT][WIDTH][CHANNELS];
```

The format above is referred to as **NHWC format** (number(batch)-height-width-channel).  However, there are other ways we can represent the hidden unit as well.  For example, PyTorch defaults to the **NCHW format** (indexing over channels in the second dimension, then height and width), though it can also support **NHWC** in later versions.  There are subtle but substantial differences in the performance for each different setting: convolutions are typically faster in NHWC format, owing to their ability to better exploit tensor cores; but NCHW format is typically faster for BatchNorm operation (because batch norm for convolutional networks operates over all pixels in an individual channel).
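To make the difference between the two layouts concrete, the following sketch stores the same small tensor in both orders and reports how many elements apart neighbors along each axis sit in memory. The sizes and variable names are arbitrary choices for illustration, not part of the notebook's later code.

```python
import numpy as np

# Arbitrary illustrative sizes (assumptions, not taken from the text)
N, H, W, C = 2, 4, 5, 3

# The same data stored in NHWC order and in NCHW order
nhwc = np.arange(N * H * W * C, dtype=np.float32).reshape(N, H, W, C)
nchw = np.ascontiguousarray(nhwc.transpose(0, 3, 1, 2))

# Distance in elements (not bytes) between neighbors along each axis
nhwc_steps = tuple(s // nhwc.itemsize for s in nhwc.strides)
nchw_steps = tuple(s // nchw.itemsize for s in nchw.strides)

print(nhwc_steps)  # (60, 15, 3, 1)
print(nchw_steps)  # (60, 20, 5, 1)

# Both arrays hold the same logical values, only the index order differs
same = bool(nhwc[1, 2, 3, 0] == nchw[1, 0, 2, 3])
```

In NHWC the channels of a single pixel are adjacent in memory (step of 1 along the last axis), whereas in NCHW the pixels of a single channel row are adjacent.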

### Storage Order Kernels

Although less commonly discussed, there is a similar trade-off to be had when it comes to storing the convolutional weights (filter) as well.  Convolutional filters are specified by their kernel size (which can technically be different over different height and width dimensions, but this is quite uncommon), their input channels, and their output channels.  We'll store these weights in the form:

```c++
float weights[KERNEL_SIZE][KERNEL_SIZE][IN_CHANNELS][OUT_CHANNELS];
```

Again, PyTorch does things a bit differently here (for no good reason, it was just done that way historically), storing weights in the order `OUT_CHANNELS x IN_CHANNELS x KERNEL_SIZE x KERNEL_SIZE`.
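As a small sketch of how the two weight orderings relate, the snippet below permutes the axes of a randomly filled array from the ordering used here to the PyTorch ordering. The variable names `w_kkio` and `w_oikk` are made up for this illustration.

```python
import numpy as np

K, C_in, C_out = 3, 8, 16
w_kkio = np.random.randn(K, K, C_in, C_out)

# Reorder axes: (K, K, C_in, C_out) -> (C_out, C_in, K, K)
w_oikk = w_kkio.transpose(3, 2, 0, 1)

print(w_oikk.shape)  # (16, 8, 3, 3)

# The same weight is reached with correspondingly permuted indices
i, j, c_in, c_out = 1, 2, 5, 7
assert w_kkio[i, j, c_in, c_out] == w_oikk[c_out, c_in, i, j]
```

Note that `transpose` only returns a view with permuted strides; no data is moved until a contiguous copy is requested.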

---
---

## Convolutions with Simple Loops

Let's begin by implementing a simple convolutional operator.  We're going to implement a simple version, which allows for different kernel sizes but which *doesn't* have any built-in padding: to implement padding, you'd just explicitly form a new `ndarray` with the padding built in. This means that if we have an $H \times W$ input image and convolution with kernel size $K$, we'll end up with a $(H - K + 1) \times (W - K + 1)$ image.
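For instance, explicit zero padding of the spatial dimensions could be sketched with `np.pad` as follows. The helper name `pad_nhwc` is hypothetical and not used elsewhere in this notebook.

```python
import numpy as np


def pad_nhwc(Z: np.ndarray, pad: int) -> np.ndarray:
    """Zero-pad the height and width axes of an NHWC array by `pad` pixels per side."""
    # Batch and channel axes get no padding; np.pad fills with zeros by default
    return np.pad(Z, ((0, 0), (pad, pad), (pad, pad), (0, 0)))


Z_small = np.ones((2, 5, 5, 3))
Z_padded = pad_nhwc(Z_small, pad=1)

print(Z_padded.shape)  # (2, 7, 7, 3)
```

With a kernel size of $K = 3$, the padded $7 \times 7$ input yields a $(7 - 3 + 1) \times (7 - 3 + 1) = 5 \times 5$ output, the same spatial size as the unpadded input.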

We use PyTorch as a reference implementation of convolution that we will check against. However, since PyTorch, as mentioned above, uses the **NCHW format** (and stores the convolutional weights in a different ordering as well), and we'll use the **NHWC format** and the weights ordering stated above, we will need to swap things around for our reference implementation.

In [4]:
import numpy as np
import torch
import torch.nn as nn


def conv_reference(Z: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """
    Reference implementation of convolution operation using PyTorch.

    Parameters
    ----------
    Z : np.ndarray
        The input to the convolutional layer.
    weight : np.ndarray
        The weights of the convolutional layer.

    Returns
    -------
    np.ndarray
        The output of the convolutional layer.
    """
    # Convert NHWC to NCHW for PyTorch, returning a view of the original tensor input with its dimensions permuted
    z_torch = torch.tensor(Z).permute(0, 3, 1, 2)

    # Convert KKIO to OIKK where K is the kernel size, O is the number of output channels, and I is the number of input channels
    weight_torch = torch.tensor(weight).permute(3, 2, 0, 1)

    out = nn.functional.conv2d(z_torch, weight_torch, stride=1, padding=0)

    # Convert back to NHWC, returning a contiguous in memory tensor containing the same data as self tensor
    return out.permute(0, 2, 3, 1).contiguous().numpy()

In [5]:
# Batch of 10 images of size 32 x 32 with 8 channels
Z = np.random.randn(10, 32, 32, 8)
# Kernel size of 3 x 3 with 8 input channels and 16 output channels
W = np.random.randn(3, 3, 8, 16)

out = conv_reference(Z, W)

out.shape

(10, 30, 30, 16)

- The input is a batch of 10 images, each of which is $32 \times 32$ pixels with 8 channels.
  
- The kernel is a $3 \times 3$ matrix with 8 input channels and 16 output channels.
  
- The stride is assumed to be 1 (which is a common default), meaning the kernel moves 1 pixel at a time.
  
- The padding is assumed to be 0, meaning no additional pixels are added to the border of the image.
  
Given these parameters, the dimensions of the output can be calculated as follows:

1. Height and Width: Since the kernel is $3 \times 3$ and the stride is 1, the kernel can be placed at $(I - F)/S + 1 = (32 - 3)/1 + 1 = 30$ positions along both the width and the height, where $I$ is the input size, $F$ is the kernel (filter) size, and $S$ is the stride. This results in an output size of $30 \times 30$. The size shrinks because the kernel cannot slide over the full image without going out of the image boundaries (unless padding is used).
   
2. Depth: The depth of the output is determined by the number of output channels in the kernel, which is 16 in this case.
   
3. Batch Size: The batch size remains the same, so the output includes results for all 10 input images.
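The size calculation above can be written as a small helper. This is a sketch using the standard output-size formula; the function name `conv_output_size` and the `padding` argument (which generalizes the zero-padding case discussed here) are assumptions introduced for illustration.

```python
def conv_output_size(
    input_size: int, kernel_size: int, stride: int = 1, padding: int = 0
) -> int:
    """Number of valid kernel positions along one spatial dimension."""
    return (input_size + 2 * padding - kernel_size) // stride + 1


print(conv_output_size(32, 3))             # 30
print(conv_output_size(32, 3, padding=1))  # 32
print(conv_output_size(32, 3, stride=2))   # 15
```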

Simplest possible implementation of a convolution using for loops:

In [9]:
def conv_naive(Z: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """
    Naive implementation of convolution operation.

    Parameters
    ----------
    Z : np.ndarray
        The input to the convolutional layer.
    weight : np.ndarray
        The weights of the convolutional layer.

    Returns
    -------
    np.ndarray
        The output of the convolutional layer.
    """
    # Shapes of input and weight
    N, H, W, C_in = Z.shape
    K, _, _, C_out = weight.shape

    # Initialize output
    out = np.zeros((N, H - K + 1, W - K + 1, C_out))

    for n in range(N):
        for c_in in range(C_in):
            for c_out in range(C_out):
                # Loop over the height of the output feature map
                for y in range(H - K + 1):
                    # Loop over the width of the output feature map
                    for x in range(W - K + 1):
                        for i in range(K):
                            for j in range(K):
                                out[n, y, x, c_out] += (
                                    Z[n, y + i, x + j, c_in] * weight[i, j, c_in, c_out]
                                )
    return out

We can check to make sure this implementation works by comparing to the PyTorch reference implementation.

In [10]:
out_2 = conv_naive(Z, W)

np.linalg.norm(out - out_2)

1.2281529924016432e-12

The implementation works, but not surprisingly, the 7-fold loop in interpreted code is much slower than the PyTorch implementation:

In [11]:
%%timeit
out_2 = conv_naive(Z, W)

15.4 s ± 1.47 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%%timeit
out = conv_reference(Z, W)

1.29 ms ± 462 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


---
---

## Convolutions as Matrix Multiplications

Another way to implement convolution is to perform it as a sequence of matrix multiplications.  Remember that a kernel size $K = 1$ convolution is equivalent to performing matrix multiplication over the channel dimensions.  That is, suppose we have the following convolution.

In [13]:
# A kernel size of 1 x 1 with 8 input channels and 16 output channels
W_1 = np.random.randn(1, 1, 8, 16)
out = conv_reference(Z, W_1)

out.shape

(10, 32, 32, 16)

Then we could implement the convolution using a _single_ matrix multiplication.

In [18]:
Z.shape

(10, 32, 32, 8)

In [20]:
# W_1[0, 0] is a 8 x 16 matrix
out_2 = Z @ W_1[0, 0]

np.linalg.norm(out - out_2)

9.72678468389406e-15

Here we are exploiting the fact that in numpy, when you multiply a multi-dimensional array (a rank 4 tensor in $Z$) by a matrix, it treats all the leading dimensions of $Z$ as rows of a matrix. That is, the above operation is equivalent to:

In [26]:
out_2 = (Z.reshape(-1, 8) @ W_1[0, 0]).reshape(
    Z.shape[0], Z.shape[1], Z.shape[2], W_1.shape[3]
)
np.linalg.norm(out - out_2)

9.72678468389406e-15

In [30]:
print(
    f"We have reshaped Z from {Z.shape} to {(Z.reshape(-1, 8)).shape} in order to perform a matrix multiplication of it with {W_1[0, 0].shape}"
)

We have reshaped Z from (10, 32, 32, 8) to (10240, 8) in order to perform a matrix multiplication of it with (8, 16)


This strategy immediately motivates a very natural approach to convolution: we can iterate over just the kernel dimensions $i$ and $j$, and use matrix multiplication to perform the convolution.

In [32]:
def conv_matrix_mult(Z: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """
    Implementation of convolution operation using matrix multiplication.

    Parameters
    ----------
    Z : np.ndarray
        The input to the convolutional layer.
    weight : np.ndarray
        The weights of the convolutional layer.

    Returns
    -------
    np.ndarray
        The output of the convolutional layer.
    """
    # Shapes of input and weight
    N, H, W, C_in = Z.shape
    K, _, _, C_out = weight.shape

    # Initialize output
    out = np.zeros((N, H - K + 1, W - K + 1, C_out))

    for i in range(K):
        for j in range(K):
            # Each weight[i, j] is a C_in x C_out matrix
            out += Z[:, i : (i + H - K + 1), j : (j + W - K + 1), :] @ weight[i, j]

    return out

The `[:, i:(i + H - K + 1), j:(j + W - K + 1), :]` line is where we are sliding the window over patches of the input $Z$ where

* $H - K + 1$ is the height of the output feature map

* $W - K + 1$ is the width of the output feature map

**All assuming that stride equals 1**.
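To see what these shifted slices contain, here is a minimal check on a tiny input (sizes chosen arbitrarily for illustration). It confirms that every slice has the spatial shape of the output, and that entry `(y, x)` of the slice for offset `(i, j)` is the input pixel at `(y + i, x + j)`.

```python
import numpy as np

# Tiny illustrative sizes (assumptions, not the ones used in the notebook)
N, H, W, C, K = 1, 5, 6, 2, 3
Z_small = np.arange(N * H * W * C, dtype=np.float64).reshape(N, H, W, C)

out_h, out_w = H - K + 1, W - K + 1
for i in range(K):
    for j in range(K):
        window = Z_small[:, i : i + out_h, j : j + out_w, :]
        # Every shifted slice has the output's spatial shape
        assert window.shape == (N, out_h, out_w, C)
        # Entry (y, x) = (1, 2) of the slice is the input pixel at (1 + i, 2 + j)
        assert np.array_equal(window[0, 1, 2], Z_small[0, 1 + i, 2 + j])

print(out_h, out_w)  # 3 4
```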

In [34]:
# A batch of 100 images of size 32 x 32 with 8 channels
Z = np.random.randn(100, 32, 32, 8)
# Kernel size of 3 x 3 with 8 input channels and 16 output channels
W = np.random.randn(3, 3, 8, 16)

out = conv_reference(Z, W)
out_2 = conv_matrix_mult(Z, W)

np.linalg.norm(out - out_2)

3.149296845152869e-12

This works as well, and (as expected) is _much_ faster, starting to be competitive even with the PyTorch version.

In [39]:
%%timeit
out = conv_reference(Z, W)

29.7 ms ± 14.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [40]:
%%timeit
out = conv_matrix_mult(Z, W)

63.8 ms ± 7.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


---
---

## Manipulating Matrices via Strides

Before implementing convolutions via **im2col**, which is a technique to rearrange discrete image blocks of size m-by-n into columns, we consider an example that actually has nothing to do with convolution. Instead, we consider the efficient matrix multiplication operations where we think of storing a matrix as a 2D array:

```c++
float A[M][N];
```

In the typical **row-major format**, this will store each N-dimensional row vector of the matrix one after another in memory. However, recall that in order to make better use of the caches and vector operations in modern CPUs, it was beneficial to lay out our matrix memory grouped into individual small "tiles", so that the CPU vector operations could efficiently access their operands:

```c++
float A[M/TILE][N/TILE][TILE][TILE];
```    

where `TILE` is some small constant (like 4), which allows the CPU to use its vector processor to perform very efficient operations on `TILE x TILE` blocks. Importantly, what enables this to be so efficient is that in the standard memory ordering for an ND array, this grouping would locate each `TILE x TILE` block consecutively in memory, so it could quickly be loaded in and out of cache / registers / etc.

How exactly would we convert a matrix to this form? We can use the function `np.lib.stride_tricks.as_strided()`, which lets us create new matrices by manually manipulating the strides of a matrix while *not* changing the underlying data. We can then use `np.ascontiguousarray()` to lay out the memory sequentially. These sets of tricks let us rearrange matrices fairly efficiently in just one or two lines of `numpy` code.

### An Example: A 6x6 2D Array

To see how this works, let's consider an example 6x6 numpy array.

In [43]:
n = 6

A = np.arange(n**2, dtype=np.float32).reshape(n, n)

A

array([[ 0.,  1.,  2.,  3.,  4.,  5.],
       [ 6.,  7.,  8.,  9., 10., 11.],
       [12., 13., 14., 15., 16., 17.],
       [18., 19., 20., 21., 22., 23.],
       [24., 25., 26., 27., 28., 29.],
       [30., 31., 32., 33., 34., 35.]], dtype=float32)

#### Row-Major Format

This array is laid out in memory by row.  It's actually a bit of a pain to access the underlying raw memory of a numpy array in Python (numpy goes to great lengths to try to *prevent* us from doing this), but we can see how the array is laid out using the following code:

In [52]:
import ctypes

np.frombuffer(
    ctypes.string_at(A.ctypes.data, size=A.nbytes), dtype=A.dtype, count=A.size
)

array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
       26., 27., 28., 29., 30., 31., 32., 33., 34., 35.], dtype=float32)

* `numpy.ndarray.nbytes` is an attribute that shows the total bytes consumed by the elements of the array. Note that this does not include memory consumed by non-element attributes of the array object.

* `ndarray.ctypes.data` (here `A.ctypes.data`) is a pointer to the memory area of the array as a Python integer.

* `ctypes.string_at(address, size=-1)` is a function that returns the C string starting at memory address `address` as a bytes object. If `size` is specified, that many bytes are read; otherwise the string is assumed to be zero-terminated.

#### Strides

The `strides` structure can be a way to lay out n-dimensional arrays in memory. In order to access the `A[i][j]` element of a 2D array, for instance, we would access the memory location at:

```c++
A.bytes[i * strides[0] + j * strides[1]];
```
    
The same can be done e.g., with a 3D tensor, accessing `A[i][j][k]` at memory location:

```c++
A.bytes[i * strides[0] + j * strides[1] + k * strides[2]];
```

For an array in row-major format, we would thus have

```c++
strides[0] = num_cols;
strides[1] = 1;
```

We can look at the strides of the array we have created using the `.strides` attribute.

In [61]:
A.strides

(24, 4)

Note that numpy, somewhat unconventionally, actually uses strides equal to the total number of *bytes*, so these numbers are all multiplied by 4 from the above, because a `float32` type takes up 4 bytes. The `strides` attribute returns a tuple of bytes to step in each dimension when traversing an array.

For example, the byte offset of element `(i[0], i[1], ..., i[n])` in an array `a` is:

```python
offset = sum(np.array(i) * a.strides)
```

In [66]:
# Byte offset of each diagonal element A[i, i]
for i in range(n):
    print(np.sum(np.array((i, i)) * A.strides))

0
28
56
84
112
140


The strides of an array tell us how many bytes we have to skip in memory to move to the next position along a certain axis. For example, we have to skip 4 bytes (1 value) to move to the next column, but $6 \times 4 = 24$ bytes (6 values) to get to the same position in the next row. As such, the strides for the array `A` will be `(24, 4)`.
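As a sketch of the offset formula in action, the snippet below reads a single element directly out of a raw copy of the buffer using only the strides. The names `A_demo`, `raw`, and `element_at` are introduced here for illustration.

```python
import numpy as np

A_demo = np.arange(36, dtype=np.float32).reshape(6, 6)
raw = A_demo.tobytes()  # a copy of the data as bytes, in row-major (C) order


def element_at(i: int, j: int) -> np.float32:
    """Read A_demo[i, j] straight from the raw bytes using the strides."""
    offset = i * A_demo.strides[0] + j * A_demo.strides[1]
    return np.frombuffer(raw, dtype=A_demo.dtype, count=1, offset=offset)[0]


print(element_at(2, 3))  # 15.0
```

Here the offset is $2 \cdot 24 + 3 \cdot 4 = 60$ bytes, i.e. 15 `float32` values into the buffer.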

### Tiling a Matrix using Strides

For simplicity, assume we want to tile into $2 \times 2$ blocks, and thus we want to convert `A`, which is $6 \times 6$, into a `3 x 3 x 2 x 2` array.  What would the strides be in this case?  In other words, if we accessed the element `A[i][j][k][l]`, how would this index into a memory location in the array as laid out above?  

1. Incrementing the first index, `i`, would move down *two* rows in the matrix with 6 rows since 
   
   * each block of the tile is $2 \times 2$ 
   * `tile = 2`, `num_rows = 6`, and `num_rows / tile = 3`

   Therefore, incrementing the first index by 1 actually skips `strides[0] = 12` elements (two rows) 

2. Similarly, incrementing the second index `j` would move over two columns, so `strides[1] = 2` elements (two columns) in row-major format
   
3. Things get a bit trickier next, but are still fairly straightforward: incrementing the next index `k` (the height index within a tile) moves down one row within the tile, so `strides[2] = 6` elements (one row), which is half of what we have to skip when incrementing `i` 
   
4. Finally incrementing the last index `l` just moves us over one column, so `strides[3] = 1` element

The `np.lib.stride_tricks.as_strided()` function lets us specify the shape and stride of a new matrix, created from the same memory as the old matrix. That is, it doesn't do any memory copies, so it's very efficient. But we also have to be careful when we use it, since it is directly creating a new view of an existing array, and without proper care we could e.g., go outside the bounds of the array.
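Since `as_strided()` performs no bounds checking itself, one way to guard against reading outside the array is to check the furthest byte a requested view could reach. The helper below is a minimal sketch under the assumptions that the base array is contiguous and that all strides are non-negative; the name `fits_in_buffer` is hypothetical.

```python
import numpy as np


def fits_in_buffer(base: np.ndarray, shape: tuple, strides: tuple) -> bool:
    """True if a view with non-negative byte `strides` stays inside `base`'s memory."""
    # Byte offset of the last reachable element, plus the size of that element
    last_byte = sum((n - 1) * s for n, s in zip(shape, strides)) + base.itemsize
    return last_byte <= base.nbytes


A_demo = np.arange(36, dtype=np.float32).reshape(6, 6)

# 2 x 2 tiling of a 6 x 6 float32 array: (12, 2, 6, 1) elements times 4 bytes
print(fits_in_buffer(A_demo, (3, 3, 2, 2), (48, 8, 24, 4)))  # True
# Asking for a fourth row of tiles would run past the end of the data
print(fits_in_buffer(A_demo, (4, 3, 2, 2), (48, 8, 24, 4)))  # False
```

`as_strided()` also accepts a `writeable=False` argument, which at least prevents accidental writes through the view.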

To create the tiled view of the matrix `A`:

In [68]:
A

array([[ 0.,  1.,  2.,  3.,  4.,  5.],
       [ 6.,  7.,  8.,  9., 10., 11.],
       [12., 13., 14., 15., 16., 17.],
       [18., 19., 20., 21., 22., 23.],
       [24., 25., 26., 27., 28., 29.],
       [30., 31., 32., 33., 34., 35.]], dtype=float32)

In [74]:
# Note we need to multiply the strides by 4 since each element is a float32 which is 4 bytes
B = np.lib.stride_tricks.as_strided(
    A, shape=(3, 3, 2, 2), strides=np.array((12, 2, 6, 1)) * 4
)
B

array([[[[ 0.,  1.],
         [ 6.,  7.]],

        [[ 2.,  3.],
         [ 8.,  9.]],

        [[ 4.,  5.],
         [10., 11.]]],


       [[[12., 13.],
         [18., 19.]],

        [[14., 15.],
         [20., 21.]],

        [[16., 17.],
         [22., 23.]]],


       [[[24., 25.],
         [30., 31.]],

        [[26., 27.],
         [32., 33.]],

        [[28., 29.],
         [34., 35.]]]], dtype=float32)

Parsing numpy output for an ND array isn't the most intuitive thing, but if we look closely we can see that the array above lays out each 2x2 block of the matrix `A` as desired. However, we can also see that this call didn't change the actual memory layout by again inspecting the raw memory.

In [72]:
np.array_equal(
    np.frombuffer(ctypes.string_at(B.ctypes.data, size=B.nbytes), B.dtype, B.size),
    np.frombuffer(ctypes.string_at(A.ctypes.data, size=A.nbytes), A.dtype, A.size),
)

True

In [92]:
B.strides
B.strides == tuple((12, 2, 6, 1) * np.array(np.dtype(np.float32).itemsize))

(48, 8, 24, 4)

True

In order to reorder the memory so that the underlying matrix is contiguous/compact (which is what we need for making the matrix multiplication efficient), we can use the `np.ascontiguousarray()` function.

In [86]:
C = np.ascontiguousarray(B)

np.frombuffer(ctypes.string_at(C.ctypes.data, size=C.nbytes), C.dtype, C.size)

array([ 0.,  1.,  6.,  7.,  2.,  3.,  8.,  9.,  4.,  5., 10., 11., 12.,
       13., 18., 19., 14., 15., 20., 21., 16., 17., 22., 23., 24., 25.,
       30., 31., 26., 27., 32., 33., 28., 29., 34., 35.], dtype=float32)

As you can see, the `C` array is laid out in compact order.  This can also be verified by looking at its `.strides` attribute.

In [93]:
print(C.strides)

(48, 16, 8, 4)
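As a cross-check, the same tiled layout can also be reached without `as_strided()`, by splitting each axis into a block index and an index within the block, and then reordering the axes. This is an alternative sketch rather than the approach used above; the variable names are chosen for illustration.

```python
import numpy as np

n, tile = 6, 2
A_demo = np.arange(n * n, dtype=np.float32).reshape(n, n)

# Split each axis into (block index, index within block), then bring the
# two block indices to the front: (3, 2, 3, 2) -> (3, 3, 2, 2)
tiled = A_demo.reshape(n // tile, tile, n // tile, tile).transpose(0, 2, 1, 3)
tiled_compact = np.ascontiguousarray(tiled)

print(tiled.strides)              # (48, 8, 24, 4)
print(tiled_compact.strides)      # (48, 16, 8, 4)
print(tiled_compact.ravel()[:8])  # [0. 1. 6. 7. 2. 3. 8. 9.]
```

The view has the same strides as the one built with `as_strided()`, and the compact copy has the same memory order, but `reshape` and `transpose` cannot index outside the original array.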


## Convolutions via im2col

Let's consider finally the "real" way to implement convolutions, which narrows the gap to PyTorch's implementation. Essentially, we want to bundle all the computation needed for convolution into a *single* matrix multiplication, which will then leverage all the optimizations that we can implement for normal matrix multiplication.

The key approach to doing this is called the `im2col` operator, which "unfolds" a 4D array (rank 4 tensor) into exactly the form needed to perform convolution via matrix multiplication.  Let's see an example of how this works using a simple 2D array, before we move to the 4D case:

In [97]:
A = np.arange(n**2, dtype=np.float32).reshape(n, n)
A

array([[ 0.,  1.,  2.,  3.,  4.,  5.],
       [ 6.,  7.,  8.,  9., 10., 11.],
       [12., 13., 14., 15., 16., 17.],
       [18., 19., 20., 21., 22., 23.],
       [24., 25., 26., 27., 28., 29.],
       [30., 31., 32., 33., 34., 35.]], dtype=float32)

A 3x3 filter:

In [98]:
W = np.arange(9, dtype=np.float32).reshape(3, 3)

W

array([[0., 1., 2.],
       [3., 4., 5.],
       [6., 7., 8.]], dtype=float32)

Recall that a convolution will multiply this filter element-wise with every $3 \times 3$ block in the image and sum the products, giving one output value per block. 

So how can we extract every such $3 \times 3$ block from the matrix `A`? 

The key will be to form a $(H - K + 1) \times (W - K + 1) \times K \times K$ array that contains all of these blocks, then flatten it to a matrix we can multiply by the filter. To create this array of all blocks, short of manual copying, we can use the `as_strided()` function.

Specifically, if we create a new view of `A` with shape `(6 - 3 + 1 = 4, 6 - 3 + 1 = 4, 3, 3)`, how can we use `as_strided()` to return the array we want? 

Note that the first two dimensions or indices will have strides of `6` elements and `1` element:

1. Incrementing the first index by `1` will move to the next row since we assume a stride of 1 vertically
   
2. Incrementing the next index by `1` will move to the next column since we assume a stride of 1 horizontally
   
Interestingly (and this is the "trick"), the third and fourth dimensions *also* have strides of `6` and `1`, respectively. This is because incrementing the third index by one *also* moves to the next row, and similarly, for the fourth index, moves to the next column. Again, this is all assuming a stride of 1 both vertically and horizontally within each tile of $3 \times 3$.

In [100]:
B = np.lib.stride_tricks.as_strided(
    A, shape=(4, 4, 3, 3), strides=(np.array((6, 1, 6, 1))) * 4
)
B.shape
B

(4, 4, 3, 3)

array([[[[ 0.,  1.,  2.],
         [ 6.,  7.,  8.],
         [12., 13., 14.]],

        [[ 1.,  2.,  3.],
         [ 7.,  8.,  9.],
         [13., 14., 15.]],

        [[ 2.,  3.,  4.],
         [ 8.,  9., 10.],
         [14., 15., 16.]],

        [[ 3.,  4.,  5.],
         [ 9., 10., 11.],
         [15., 16., 17.]]],


       [[[ 6.,  7.,  8.],
         [12., 13., 14.],
         [18., 19., 20.]],

        [[ 7.,  8.,  9.],
         [13., 14., 15.],
         [19., 20., 21.]],

        [[ 8.,  9., 10.],
         [14., 15., 16.],
         [20., 21., 22.]],

        [[ 9., 10., 11.],
         [15., 16., 17.],
         [21., 22., 23.]]],


       [[[12., 13., 14.],
         [18., 19., 20.],
         [24., 25., 26.]],

        [[13., 14., 15.],
         [19., 20., 21.],
         [25., 26., 27.]],

        [[14., 15., 16.],
         [20., 21., 22.],
         [26., 27., 28.]],

        [[15., 16., 17.],
         [21., 22., 23.],
         [27., 28., 29.]]],


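       [[[18., 19., 20.],
         [24., 25., 26.],
         [30., 31., 32.]],

        [[19., 20., 21.],
         [25., 26., 27.],
         [31., 32., 33.]],

        [[20., 21., 22.],
         [26., 27., 28.],
         [32., 33., 34.]],

        [[21., 22., 23.],
         [27., 28., 29.],
         [33., 34., 35.]]]], dtype=float32)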

In [101]:
np.frombuffer(ctypes.string_at(B.ctypes.data, size=A.nbytes), B.dtype, A.size)

array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
       26., 27., 28., 29., 30., 31., 32., 33., 34., 35.], dtype=float32)

This is exactly the 4D array we need.  Now, if we want to compute the convolution as a "single" matrix multiplication, we can simply 

1. Flatten this array to a $(4 \cdot 4 = 16) \times (3 \cdot 3 = 9)$ matrix
   
2. Reshape the weights to a $9$ dimensional vector (the weights will become a matrix again for the case of multi-channel convolutions)

3. Perform the matrix multiplication  
   
4. Finally, reshape the resulting vector back into a $4 \times 4$ array to obtain the output of the convolution.

In [103]:
C = B.reshape(16, 9)
C

array([[ 0.,  1.,  2.,  6.,  7.,  8., 12., 13., 14.],
       [ 1.,  2.,  3.,  7.,  8.,  9., 13., 14., 15.],
       [ 2.,  3.,  4.,  8.,  9., 10., 14., 15., 16.],
       [ 3.,  4.,  5.,  9., 10., 11., 15., 16., 17.],
       [ 6.,  7.,  8., 12., 13., 14., 18., 19., 20.],
       [ 7.,  8.,  9., 13., 14., 15., 19., 20., 21.],
       [ 8.,  9., 10., 14., 15., 16., 20., 21., 22.],
       [ 9., 10., 11., 15., 16., 17., 21., 22., 23.],
       [12., 13., 14., 18., 19., 20., 24., 25., 26.],
       [13., 14., 15., 19., 20., 21., 25., 26., 27.],
       [14., 15., 16., 20., 21., 22., 26., 27., 28.],
       [15., 16., 17., 21., 22., 23., 27., 28., 29.],
       [18., 19., 20., 24., 25., 26., 30., 31., 32.],
       [19., 20., 21., 25., 26., 27., 31., 32., 33.],
       [20., 21., 22., 26., 27., 28., 32., 33., 34.],
       [21., 22., 23., 27., 28., 29., 33., 34., 35.]], dtype=float32)

In matrix $C$, every row is a flattened version of the $3 \times 3$ block of the original input $A$, simulating a sliding window over $A$. This can then be multiplied by a flattened $W$:

In [109]:
(
    C
    @ W.reshape(
        9,
    )
).reshape(4, 4)

array([[ 366.,  402.,  438.,  474.],
       [ 582.,  618.,  654.,  690.],
       [ 798.,  834.,  870.,  906.],
       [1014., 1050., 1086., 1122.]], dtype=float32)
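The result can be cross-checked in two independent ways. The sketch below uses `sliding_window_view` (available in NumPy 1.20 and later), which builds the same kind of windowed view with bounds checking, and compares it against a direct double loop over output positions. The names `A_demo` and `W_demo` simply recreate the two arrays used above.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

A_demo = np.arange(36, dtype=np.float32).reshape(6, 6)
W_demo = np.arange(9, dtype=np.float32).reshape(3, 3)

# All 3 x 3 windows as a read-only (4, 4, 3, 3) view
windows = sliding_window_view(A_demo, (3, 3))

# Multiply each window by the filter and sum over the two window axes
out_windows = (windows * W_demo).sum(axis=(2, 3))

# Direct double loop over output positions for comparison
out_loops = np.zeros((4, 4), dtype=np.float32)
for y in range(4):
    for x in range(4):
        out_loops[y, x] = np.sum(A_demo[y : y + 3, x : x + 3] * W_demo)

print(out_windows[0])  # [366. 402. 438. 474.]
```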

#### A Critical Note on Memory Efficiency

There is a _very_ crucial point to make regarding memory efficiency of this operation. While reshaping `W` into a vector (or what will be a matrix for multi-channel convolutions) is "free" in that it doesn't allocate any new memory, reshaping the `B` array above is very much *not* a free operation. 

Specifically, while the strided form of `B` uses the same memory as `A`, once we actually convert `B` into a 2D matrix with `reshape`, there is no way to represent this data using any kind of strides, and we have to just allocate the entire matrix. This means we actually need to *form* the full `im2col` matrix, which requires roughly a factor of $K^2$ more memory than the original image, which can be quite costly for large kernel sizes.

For this reason, in practice it's often the case that the best modern implementations *won't* actually instantiate the full `im2col` matrix, and will instead perform a kind of "lazy" formation, or specialize the matrix operation natively to `im2col` matrices in their native strided form. These are all fairly advanced topics; for our purposes, it will be sufficient to just allocate this matrix and then quickly de-allocate it after we perform the convolution (remember that we aren't e.g., doing back-propagation through the `im2col` operation).
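The copy can be observed directly on the $6 \times 6$ example. The following sketch rebuilds the strided view and its reshaped version (under the illustrative names `A_demo`, `B_demo`, and `C_demo`) and checks which of them share memory with the original array.

```python
import numpy as np

A_demo = np.arange(36, dtype=np.float32).reshape(6, 6)
B_demo = np.lib.stride_tricks.as_strided(
    A_demo, shape=(4, 4, 3, 3), strides=(24, 4, 24, 4)
)
C_demo = B_demo.reshape(16, 9)

print(np.shares_memory(A_demo, B_demo))  # True: B_demo is only a view
print(np.shares_memory(A_demo, C_demo))  # False: reshape had to copy
print(A_demo.nbytes, C_demo.nbytes)      # 144 576
```

Here the copy is 4 times larger than the input rather than $K^2 = 9$ times, because the image is so small that only $4 \times 4$ windows fit; for large images the ratio approaches $K^2$.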

### Using im2col for Multi-Channel Convolutions

So how do we actually implement an `im2col` operation for real multi-channel, mini-batched convolutions? 

Instead of forming a 4D $(H - K + 1) \times (W - K + 1) \times K \times K$ array, we form a 6D $N \times (H - K + 1) \times (W - K + 1) \times K \times K \times C$ array (leaving the mini-batch and channel dimensions untouched). 

After thinking about it for a bit, it should be pretty clear that we can apply the same trick by just repeating the strides for dimensions 1 and 2 (the height and width) for dimensions 3 and 4 (the $K \times K$ blocks), and leave the strides for the mini-batch and channels unchanged. Furthermore, we don't even need to worry about computing the strides manually: we can just use the strides of the $Z$ input and repeat whatever they are.

To compute the convolution, 

1. Flatten the `im2col` matrix to a $(N \cdot (H - K + 1) \cdot (W - K + 1)) \times (K \cdot K \cdot C)$ matrix (remember, this operation is highly memory inefficient)
   * Also note that this is exactly how we reshaped the 4D strided tensor into a $((H - K + 1) \cdot (W - K + 1)) \times (K \cdot K)$ matrix earlier, except that we now include the batch and channel dimensions since we have a 6D tensor
2. Flatten the weights array to a $(K \cdot K \cdot C) \times C_{out}$ matrix so that the inner dimensions of the two matrices match
3. Perform the multiplication
4. Reshape back to the desired size of the final 4D array output

In [142]:
def conv_im2col(Z: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """
    Implementation of convolution operation using im2col.

    Parameters
    ----------
    Z : np.ndarray
        The input to the convolutional layer.
    weight : np.ndarray
        The weights of the convolutional layer.

    Returns
    -------
    np.ndarray
        The output of the convolutional layer.
    """
    # Shapes of input and weight
    N, H, W, C_in = Z.shape
    K, _, _, C_out = weight.shape
    # Strides of the input
    Ns, Hs, Ws, Cs = Z.strides
    # The inner dimension of the matrix product (number of columns of the im2col matrix)
    inner_dim = K * K * C_in
    # The im2col matrix
    A = np.lib.stride_tricks.as_strided(
        Z, shape=(N, H - K + 1, W - K + 1, K, K, C_in), strides=(Ns, Hs, Ws, Hs, Ws, Cs)
    ).reshape(-1, inner_dim)

    # A is a N * (H - K + 1) * (W - K + 1) x (K * K * C_in) matrix
    # Reshaped weight is a (K * K * C_in) x C_out matrix
    out = A @ weight.reshape(-1, C_out)

    # Reshape to output shape
    return out.reshape(N, H - K + 1, W - K + 1, C_out)

Again, we can check that this version produces the same output as the PyTorch reference (or our other implementations, at this point):

In [143]:
# A batch of 100 images of size 32 x 32 with 8 channels
Z = np.random.randn(100, 32, 32, 8)
# Kernel size of 3 x 3 with 8 input channels and 16 output channels
W = np.random.randn(3, 3, 8, 16)

out = conv_reference(Z, W)
out_2 = conv_im2col(Z, W)

np.linalg.norm(out - out_2)

0.0
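For a check that does not depend on PyTorch, the multi-channel case can also be sketched with `sliding_window_view` (NumPy 1.20 or later) and `np.einsum`, and compared against the per-offset matrix multiplication approach on a small random input. The function name `conv_windows` and the test sizes are assumptions made for this illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv_windows(Z: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Stride-1, unpadded NHWC convolution via a windowed view and einsum."""
    K = weight.shape[0]
    # Shape (N, H - K + 1, W - K + 1, C_in, K, K): window axes are appended last
    windows = sliding_window_view(Z, (K, K), axis=(1, 2))
    # Sum over the kernel offsets (i, j) and the input channels c
    return np.einsum("nyxcij,ijco->nyxo", windows, weight)


rng = np.random.default_rng(0)
Z_s = rng.standard_normal((2, 6, 5, 3))
W_s = rng.standard_normal((3, 3, 3, 4))

# Reference: accumulate one matrix multiplication per kernel offset
ref = np.zeros((2, 4, 3, 4))
for i in range(3):
    for j in range(3):
        ref += Z_s[:, i : i + 4, j : j + 3, :] @ W_s[i, j]

print(np.allclose(conv_windows(Z_s, W_s), ref))  # True
```

Note that the window axes come after the channel axis here, which differs from the `(K, K, C)` ordering of the 6D array above; `einsum` handles the reordering through its index labels.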

Runtime:

In [144]:
%%timeit
out_3 = conv_im2col(Z, W)

48.7 ms ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


This outperforms the implementation using matrix multiplication:

In [146]:
%%timeit
out = conv_matrix_mult(Z, W)

67.6 ms ± 8.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


However, the PyTorch implementation is still the fastest:

In [145]:
%%timeit
out_3 = conv_reference(Z, W)

24.7 ms ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
