It is possible to express a CNN layer as a fully connected (FCN) layer with weight sharing (tied weights) and some weights fixed at zero, as shown in https://medium.com/impactai/cnns-from-different-viewpoints-fab7f52d159c.

![CNN as matrix multiplication](https://miro.medium.com/max/1400/1*95lL-PY5WEeBAtfaWAIIRQ.png)

> The matrix above is a weight matrix, just like the ones from traditional neural networks. However, this weight matrix has two special properties:
> 1. The zeros shown in gray are untrainable. This means that they’ll stay zero throughout the optimization process.
> 2. Some of the weights are equal, and while they are trainable (i.e. changeable), they must remain equal. These are called “shared weights”.
>
> The zeros correspond to the pixels that the filter didn’t touch. Each row of the weight matrix corresponds to one application of the filter.

For sure, this is not an efficient approach, but it's a nice thought exercise. Note that this way we need (number of output elements) × (number of image pixels) weights: for a 28x28 image and a 3x3 filter the output is 26x26, so the matrix has 676 × 784 = 529,984 entries, even though only 9 of them are distinct trainable values.
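A quick sanity check of that count, as a small sketch: the dense matrix has one row per output pixel and one column per input pixel. It assumes a single channel, stride 1 and no padding (a "valid" convolution, which is what `torch.conv2d` does by default); the helper names are made up for illustration.

```python
def dense_equivalent_params(h: int, w: int, kh: int, kw: int) -> int:
    """Entries in a dense matrix that reproduces a stride-1, 'valid' convolution."""
    out_h, out_w = h - kh + 1, w - kw + 1
    # one row per output pixel, one column per input pixel
    return out_h * out_w * h * w


def conv_params(kh: int, kw: int) -> int:
    """Distinct trainable values in the filter itself (bias ignored)."""
    return kh * kw


print(dense_equivalent_params(28, 28, 3, 3))  # 529984
print(dense_equivalent_params(5, 5, 2, 2))    # 400
print(conv_params(3, 3))                      # 9
```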

**Q:** is there a way to express CNN as FCN using fewer weights?

**A:** Let's say the initial image is 5x5 and we use a 2x2 filter; then the output has 4x4 values. Looking only at the transformation, we need to map 5x5 to 4x4, so even the most 'efficient' fully connected reimplementation needs a 16x25 weight matrix, i.e. 5x5x4x4 = 400 weights, most of which are fixed zeros.

```python
import torch
```

```python
F = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
X = torch.randn((5, 5), dtype=torch.float32)
# Expected result: built-in convolution, output shape (1, 1, 4, 4)
Y = torch.conv2d(X[None, None, :, :], F[None, None, :, :])
```

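This is the reference result we want to reproduce. Note that `torch.conv2d` computes a cross-correlation (the filter is not flipped), so the entries of `F` go into the weight matrix as-is.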

```python
[[a, b],
 [c, d]] = F

# Row 4*i + j computes output (i, j): a, b hit flat input indices 5*i + j and 5*i + j + 1,
# c, d the same two pixels one image row below (5 positions later). Everything else is a fixed zero.
A = torch.tensor([[a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                  [0, a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                  [0, 0, a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d, 0],
                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, a, b, 0, 0, 0, c, d]])

assert torch.allclose(A @ X.flatten(), Y.flatten())
```
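Writing the matrix by hand only works for tiny cases. A minimal sketch of building the same dense matrix for any image and filter size, assuming a single channel, stride 1 and no padding; `conv_as_dense_matrix`, `img`, `filt` and `ref` are illustrative names (chosen so they do not overwrite `X`, `F` or `Y` in the notebook).

```python
import torch


def conv_as_dense_matrix(filt: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """(out_h*out_w, h*w) matrix M with M @ img.flatten() equal to a
    stride-1, 'valid' torch.conv2d of an h x w image with `filt`."""
    kh, kw = filt.shape
    out_h, out_w = h - kh + 1, w - kw + 1
    M = torch.zeros(out_h * out_w, h * w, dtype=filt.dtype)
    for i in range(out_h):
        for j in range(out_w):
            row = i * out_w + j  # output pixel (i, j), row-major
            for p in range(kh):
                for q in range(kw):
                    # filter tap (p, q) reads input pixel (i + p, j + q)
                    M[row, (i + p) * w + (j + q)] = filt[p, q]
    return M


torch.manual_seed(0)
filt = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
img = torch.randn(5, 5)
M = conv_as_dense_matrix(filt, 5, 5)
ref = torch.conv2d(img[None, None], filt[None, None]).flatten()
print(M.shape)   # torch.Size([16, 25])
print(M[0, :7])  # first row starts a, b, 0, 0, 0, c, d
print(torch.allclose(M @ img.flatten(), ref, atol=1e-4))
```

For the 5x5 image and the 2x2 filter above, this reproduces the hand-written 16x25 matrix row by row.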

```python
%%timeit

A @ X.flatten()
```

```
8.43 µs ± 300 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```


**Q:** is there a way to avoid image reshaping?

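**A:** Again, let's consider the shape transformation from the previous answer: we want to map 5x5 to 4x4. We could construct two weight matrices and multiply by one on the left and one on the right (4x5 @ 5x5 @ 5x4), but that only reproduces separable (rank-1) filters. Alternatively, we can mask (slice) parts of the image, one slice per filter column, and express the convolution as a sum of products: 4x5 @ 5x4 + 4x5 @ 5x4.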
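Why the two-sided form L @ X @ R is limited: L mixes neighbouring image rows with weights u and R mixes neighbouring image columns with weights v, so together they apply the filter u vᵀ, an outer product of two vectors. A minimal sketch of that separable case; `banded`, `L`, `R`, `u`, `v`, `img`, `filt` and `ref` are illustrative names, not from the original notebook.

```python
import torch


def banded(v: torch.Tensor, n_in: int) -> torch.Tensor:
    """(n_in - len(v) + 1, n_in) matrix with M[r, r + t] = v[t]."""
    k = v.numel()
    n_out = n_in - k + 1
    M = torch.zeros(n_out, n_in, dtype=v.dtype)
    rows = torch.arange(n_out)
    for t in range(k):
        M[rows, rows + t] = v[t]
    return M


torch.manual_seed(0)
img = torch.randn(5, 5)
u = torch.tensor([1.0, 3.0])  # vertical part of the filter
v = torch.tensor([2.0, 5.0])  # horizontal part of the filter
filt = torch.outer(u, v)      # separable (rank-1) filter [[2, 5], [6, 15]]

L = banded(u, 5)              # 4x5, mixes neighbouring image rows
R = banded(v, 5).T            # 5x4, mixes neighbouring image columns
ref = torch.conv2d(img[None, None], filt[None, None])[0, 0]
print(torch.allclose(L @ img @ R, ref, atol=1e-4))
```

A general filter such as [[1, 2], [3, 4]] has determinant 1·4 − 2·3 = −2 ≠ 0, so it has rank 2 and cannot be written as a single outer product; that is where the sum of masked products comes in.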

To make our task more interesting, let's do exactly that: produce two weight matrices that together replicate the convolution, one for each column of the filter.

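```python
[[a, b],
 [c, d]] = F

# A holds the left filter column (a, c) and acts on the first 4 image columns,
# B holds the right filter column (b, d) and acts on the last 4 image columns.
A = torch.tensor([[a, c, 0, 0, 0],
                  [0, a, c, 0, 0],
                  [0, 0, a, c, 0],
                  [0, 0, 0, a, c]])

B = torch.tensor([[b, d, 0, 0, 0],
                  [0, b, d, 0, 0],
                  [0, 0, b, d, 0],
                  [0, 0, 0, b, d]])

C = A @ X[:, :-1] + B @ X[:, 1:]

assert torch.allclose(C, Y[0, 0])  # Y has shape (1, 1, 4, 4)
```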
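The same idea generalizes to any kh x kw filter: one banded matrix per filter column, each multiplied by a column-shifted slice of the image. A minimal sketch assuming a single channel, stride 1 and no padding; `banded`, `conv2d_by_columns`, `img`, `filt` and `ref` are illustrative names. For a 2x2 filter, `banded(filt[:, 0], 5)` and `banded(filt[:, 1], 5)` are exactly `A` and `B` above.

```python
import torch


def banded(v: torch.Tensor, n_in: int) -> torch.Tensor:
    """(n_in - len(v) + 1, n_in) matrix with M[r, r + t] = v[t]."""
    k = v.numel()
    n_out = n_in - k + 1
    M = torch.zeros(n_out, n_in, dtype=v.dtype)
    rows = torch.arange(n_out)
    for t in range(k):
        M[rows, rows + t] = v[t]
    return M


def conv2d_by_columns(img: torch.Tensor, filt: torch.Tensor) -> torch.Tensor:
    """Stride-1, 'valid' torch.conv2d as one matrix product per filter column."""
    kh, kw = filt.shape
    h, w = img.shape
    out_w = w - kw + 1
    out = torch.zeros(h - kh + 1, out_w, dtype=img.dtype)
    for q in range(kw):
        # filter column q acts on image rows; the column shift q is done
        # by slicing the image instead of by a second matrix
        out += banded(filt[:, q], h) @ img[:, q:q + out_w]
    return out


torch.manual_seed(0)
img = torch.randn(8, 7)
filt = torch.randn(3, 2)  # non-square on purpose, to keep rows and columns apart
ref = torch.conv2d(img[None, None], filt[None, None])[0, 0]
print(torch.allclose(conv2d_by_columns(img, filt), ref, atol=1e-4))
```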

```python
%%timeit

C = A @ X[:, :-1] + B @ X[:, 1:]
```

```
24 µs ± 4.37 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```


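That looks much nicer, but the computation is about 3x slower (24 µs vs 8.43 µs). At this tiny size the cost is probably dominated by per-call overhead, and this version issues more operations (two slices, two matrix products and an addition).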

Let's benchmark a bigger case instead: a 256x256 image and a 3x3 filter.

```python
X = torch.randn((256, 256), dtype=torch.float32)
F = torch.randn((3, 3), dtype=torch.float32)
```

1. Built-in convolution (`torch.conv2d`)

```python
%%timeit

torch.conv2d(X[None, None, :, :], F[None, None, :, :])
```

```
400 µs ± 68.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```


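2. Our flat version

```python
# Only the first 254 rows of the full (254 * 254, 256 * 256) matrix;
# random values are fine because only the timing matters here.
A = torch.randn((254, 256 * 256))
```

```python
%%timeit

A @ X.flatten()
```

```
2.3 ms ± 299 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

Note that this `A` has only 254 rows, so the timing covers just 254 of the 64,516 output values (one row of the 254x254 output). The full dense matrix would have 64,516 × 65,536 ≈ 4.2 billion float32 entries (roughly 17 GB), which is why it is not built here.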
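Each row of the flat matrix has only kh·kw nonzeros, so one way to make the full flat version fit in memory is a sparse matrix. This swaps in PyTorch's sparse COO tensors (`torch.sparse_coo_tensor`, `torch.sparse.mm`), which the comparison above does not use; it is a sketch for a single channel, stride 1 and no padding, and no timing is claimed, since that depends on the sparse kernels of your PyTorch build. `conv_as_sparse_matrix`, `img`, `filt`, `S` and `ref` are illustrative names.

```python
import torch


def conv_as_sparse_matrix(filt: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Sparse COO version of the flat (out_h*out_w, h*w) convolution matrix."""
    kh, kw = filt.shape
    out_h, out_w = h - kh + 1, w - kw + 1
    # one stored value per (output pixel, filter tap) pair
    i, j, p, q = torch.meshgrid(
        torch.arange(out_h), torch.arange(out_w),
        torch.arange(kh), torch.arange(kw), indexing="ij",
    )
    rows = (i * out_w + j).flatten()          # flat output index
    cols = ((i + p) * w + (j + q)).flatten()  # flat input index
    vals = filt[p, q].flatten()               # filter tap value
    return torch.sparse_coo_tensor(
        torch.stack([rows, cols]), vals, (out_h * out_w, h * w)
    ).coalesce()


torch.manual_seed(0)
img = torch.randn(256, 256)
filt = torch.randn(3, 3)
S = conv_as_sparse_matrix(filt, 256, 256)
flat_out = torch.sparse.mm(S, img.reshape(-1, 1)).reshape(254, 254)
ref = torch.conv2d(img[None, None], filt[None, None])[0, 0]
print(S.values().numel())  # 580644 stored values (254 * 254 * 9)
print(torch.allclose(flat_out, ref, atol=1e-4))
```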


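3. Nicer version (one matrix per filter column)

```python
# One (254, 256) matrix per filter column; random values stand in for the
# banded filter matrices because only the timing matters here.
A = torch.randn((254, 256), dtype=torch.float32)
B = torch.randn((254, 256), dtype=torch.float32)
C = torch.randn((254, 256), dtype=torch.float32)
```

```python
%%timeit

A @ X[:, :-2] + B @ X[:, 1:-1] + C @ X[:, 2:]
```

```
567 µs ± 62.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```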


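So it turns out that the nicer, column-wise version computes the full 254x254 output in 567 µs, close to `torch.conv2d` (400 µs), while the flat version took 2.3 ms for a matrix covering only 254 of the 64,516 outputs. If the filter grows, we need more products (one per filter column), but filters are usually small compared to the image, and this is not meant to be a production implementation either way.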
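A back-of-the-envelope multiply-add count helps put these numbers in context. It ignores memory traffic and how well each kernel is tuned, so it does not predict the timings by itself. The column-wise version treats each banded matrix as dense, so each of its kw products costs out_h · h · out_w multiply-adds, far more arithmetic than a direct convolution; that it still ran close to `torch.conv2d` is presumably because dense matrix products are highly optimized. The function names are illustrative.

```python
def macs_direct(h: int, w: int, kh: int, kw: int) -> int:
    """Multiply-adds of a direct stride-1, 'valid' convolution."""
    return (h - kh + 1) * (w - kw + 1) * kh * kw


def macs_flat_dense(h: int, w: int, kh: int, kw: int) -> int:
    """Full dense (out_h*out_w) x (h*w) matrix times a flattened image."""
    return (h - kh + 1) * (w - kw + 1) * h * w


def macs_columns(h: int, w: int, kh: int, kw: int) -> int:
    """kw dense products of shape (out_h x h) @ (h x out_w)."""
    return kw * (h - kh + 1) * h * (w - kw + 1)


for name, count in [("direct", macs_direct), ("flat dense", macs_flat_dense),
                    ("column-wise", macs_columns)]:
    print(f"{name:12s} {count(256, 256, 3, 3):>15,d}")
```

For a 256x256 image and a 3x3 filter this gives about 0.58 million (direct), 4.2 billion (full flat) and 49.5 million (column-wise) multiply-adds.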

# CNN Layer (in progress)

```python
class CNNLayer:
    def __init__(self, shape, filter_shape):
        self.n = filter_shape[1]  # one matrix product per filter column
        self.F = torch.randn(filter_shape)
        self.b = torch.zeros(1)
        self.out_shape = (shape[0] - filter_shape[0] + 1,
                          shape[1] - filter_shape[1] + 1)
        self.params = [self.F, self.b]
        self.inp = None
        self.out = None

    def __call__(self, x):
        self.inp = x
        self.out = torch.zeros(self.out_shape)
        for i in range(self.n):
            # Column slice starting at i with width out_w; the last slice runs to the end.
            end_idx = i + 1 - self.n if i + 1 - self.n else None
            self.out += self._filt_as_mat(i) @ x[:, i:end_idx]
        self.out += self.b
        return self.out

    def _filt_as_mat(self, idx):
        # TODO: banded (out_h, in_h) matrix built from filter column `idx`
        raise NotImplementedError

    def back(self):
        raise NotImplementedError
```
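One possible way to fill in the missing `_filt_as_mat`, as a self-contained sketch of the forward pass only: filter column `idx` becomes an (out_h, in_h) banded matrix with `M[r, r + k] = F[k, idx]`, exactly like `A` and `B` in the 2x2 example. The class name `ColumnCNNLayer` is hypothetical, it assumes a single channel, stride 1 and no padding, and `back` is left out, as in the original.

```python
import torch


class ColumnCNNLayer:
    """Hypothetical forward-only conv layer built from banded matrices."""

    def __init__(self, shape, filter_shape):
        self.in_shape = shape
        self.kh, self.kw = filter_shape
        self.F = torch.randn(filter_shape)
        self.b = torch.zeros(1)
        self.out_shape = (shape[0] - self.kh + 1, shape[1] - self.kw + 1)
        self.params = [self.F, self.b]

    def _filt_as_mat(self, idx):
        """(out_h, in_h) banded matrix with M[r, r + k] = F[k, idx]."""
        out_h, in_h = self.out_shape[0], self.in_shape[0]
        rows = torch.arange(out_h)
        M = torch.zeros(out_h, in_h)
        for k in range(self.kh):
            M[rows, rows + k] = self.F[k, idx]
        return M

    def __call__(self, x):
        out_w = self.out_shape[1]
        out = torch.zeros(self.out_shape)
        for i in range(self.kw):
            # filter column i applied to the image columns i .. i + out_w - 1
            out = out + self._filt_as_mat(i) @ x[:, i:i + out_w]
        return out + self.b


torch.manual_seed(0)
layer = ColumnCNNLayer((8, 6), (3, 2))
img = torch.randn(8, 6)
out = layer(img)
ref = torch.conv2d(img[None, None], layer.F[None, None], layer.b)[0, 0]
print(out.shape)  # torch.Size([6, 5])
print(torch.allclose(out, ref, atol=1e-4))
```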