In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import datasets as misc

from utils import batch_plot
from conv import load_sample_filters

In [None]:
# Load scipy's sample "face" image and rescale it so its maximum value is 1.
raw_face = misc.face()
sample_image = raw_face / raw_face.max()

# Crop the top-left square so the n x n tile grid plotted later has square tiles.
n = min(sample_image.shape[:2])
sample_image = sample_image[:n, :n]

batch_plot(sample_image[np.newaxis], with_border=False, imgsize=6)

# Generate sliding window views of the image

In [None]:
# Example: numpy's sliding_window_view vs. the equivalent low-level as_strided call.
# Valid for any `stride` >= 1: sliding_window_view always steps one pixel at a time,
# so its two position axes are subsampled to emulate a larger stride.
height, width, channel, window, stride = 7, 7, 3, 3, 1
x = np.arange(height * width * channel).reshape((height, width, channel))

# The window spans the whole channel axis, so that window-position axis has length 1; drop it.
y = np.lib.stride_tricks.sliding_window_view(x, window_shape=(window, window, channel)).squeeze(axis=2)
y = y[::stride, ::stride]

# number of window positions along an axis: (size - window) // stride + 1
chunk_height, chunk_width = (height - window) // stride + 1, (width - window) // stride + 1
assert y.shape == (chunk_height, chunk_width, window, window, channel)

# Low-level equivalent. x.strides are byte offsets between neighbouring elements of each axis:
# the next window is `stride` rows/columns away, while inside a window we move one row/column/channel.
stride_height, stride_width, stride_channel = x.strides
z = np.lib.stride_tricks.as_strided(
    x,
    shape=(chunk_height, chunk_width, window, window, channel),
    strides=(
        stride * stride_height,
        stride * stride_width,
        stride_height,
        stride_width,
        stride_channel,
    ),
    writeable=False,  # overlapping windows share memory; writing through them would be unsafe
)
# integer data: compare exactly (array_equal also rejects shape mismatches instead of broadcasting)
assert np.array_equal(y, z)

In [None]:
# Cut the image into an n x n grid of tiles using raw strides.
# Tiles start every `num_stride_*` pixels, and the tile size is chosen so the last tile
# ends exactly at the image border:
#   (height - filter_height) / num_stride_height + 1 = chunk_height
# When a side is not divisible by n, neighbouring tiles overlap by the remainder.
n = 4
chunk_height = chunk_width = n
height, width, channel = sample_image.shape
num_stride_height = height // n
num_stride_width = width // n
filter_height = height - (chunk_height - 1) * num_stride_height
filter_width = width - (chunk_width - 1) * num_stride_width

# byte offsets between neighbouring rows / columns / channels of the image buffer
stride_height, stride_width, stride_channel = sample_image.strides
tile_shape = (chunk_height, chunk_width, filter_height, filter_width, channel)
tile_strides = (
    num_stride_height * stride_height,  # next tile row
    num_stride_width * stride_width,  # next tile column
    stride_height,  # next pixel row inside a tile
    stride_width,  # next pixel column inside a tile
    stride_channel,  # next channel
)
chunks = np.lib.stride_tricks.as_strided(sample_image, shape=tile_shape, strides=tile_strides)
assert chunks.shape == tile_shape

In [None]:
# Flatten the 2-D grid of tiles into a flat batch of images (row-major tile order) and plot it.
num_tiles = chunk_height * chunk_width
sliding_sample_image = chunks.reshape((num_tiles, filter_height, filter_width, channel))
assert sliding_sample_image.shape == (num_tiles, filter_height, filter_width, channel)

plot_options = dict(
    with_border=False,
    cmap=plt.cm.gray,
    tight_layout=None,
    wspace=0.01,
    hspace=0.01,
    imgsize=2,
)
batch_plot(sliding_sample_image, **plot_options)

# Apply filters to the sliding window chunks

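<img src="https://miro.medium.com/v2/resize:fit:790/1*1VJDP6qDY9-ExTuQVEOlVg.gif" alt="Animation of a convolution operation with stride 2">

Source: [Convolution Operation with Stride Length = 2](https://towardsdatascience.com/a-comprehensive-guide-to-convolutional-neural-networks-the-eli5-way-3bd2b1164a53). The animation uses a stride of 2; the filtering below uses a stride of 1.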

In [None]:
# Build the square sample filters with as many channels as the image and show them.
# NOTE(review): sample_filters appears to map a display name -> filter array; `sigma`
# presumably controls a Gaussian-based filter — confirm in conv.load_sample_filters.
filter_size = 7
filter_height = filter_width = filter_size
sample_filters = load_sample_filters(size=filter_size, sigma=1, channel=channel)

filter_kernels = list(sample_filters.values())
filter_names = list(sample_filters.keys())
batch_plot(
    filter_kernels,
    filter_names,
    with_border=True,
    tight_layout=None,
    wspace=0.1,
    hspace=0.1,
    imgsize=4,
)

In [None]:
# Filter the whole image with every sample filter: stride 1, no padding ("valid" output size).
# Like CNN "convolution" layers this is a cross-correlation: the filters are not flipped.
height, width, channel = sample_image.shape
filters = np.asarray(list(sample_filters.values()))
# The img2col reshape below assumes a (num_filters, filter_height, filter_width, channel) layout.
# Any other layout of the same size (e.g. channel-first) would reshape silently into wrong results.
assert filters.shape[1:] == (filter_height, filter_width, channel), filters.shape

chunks = np.lib.stride_tricks.sliding_window_view(
    sample_image, window_shape=(filter_height, filter_width, channel)
).squeeze(axis=2)  # the window spans all channels, so that position axis has length 1

chunk_height, chunk_width = height - filter_height + 1, width - filter_width + 1
assert chunks.shape == (chunk_height, chunk_width, filter_height, filter_width, channel)

# chunks:   (chunk_height, chunk_width, filter_height, filter_width, channel)
# filters:  (num_filters, filter_height, filter_width, channel)
#
# Equivalent formulations, all producing (num_filters, chunk_height, chunk_width):
# 1. step by step with broadcasting (very memory hungry)
#    np.expand_dims(chunks, 2):                                    (chunk_height, chunk_width, 1, filter_height, filter_width, channel)
#    np.expand_dims(chunks, 2) * filters:                          (chunk_height, chunk_width, num_filters, filter_height, filter_width, channel)
#    (np.expand_dims(chunks, 2) * filters).sum(axis=(-3, -2, -1)): (chunk_height, chunk_width, num_filters)
#    filtered_sample_image = (np.expand_dims(chunks, 2) * filters).sum(axis=(-3, -2, -1)).transpose((2, 0, 1))
# 2. tensordot
#    filtered_sample_image = np.tensordot(chunks, filters, axes=((2, 3, 4), (1, 2, 3))).transpose((2, 0, 1))
# 3. einsum
#    filtered_sample_image = np.einsum('ijklc,nklc->nij', chunks, filters)
#
# 4. img2col (used below): flatten every window into a row and every filter into a column,
#    so a single matrix product computes all window/filter dot products at once.
#    chunks.reshape(...):     (chunk_height*chunk_width, filter_height*filter_width*channel)
#                             -- chunks is a strided view, so this reshape materializes a (large) copy
#    filters.reshape(...).T:  (filter_height*filter_width*channel, num_filters)
#    product:                 (chunk_height*chunk_width, num_filters)
patch_size = filter_height * filter_width * channel
filtered_sample_image = chunks.reshape((chunk_height * chunk_width, patch_size)) @ filters.reshape((-1, patch_size)).T
# back to one 2-D response map per filter: (num_filters, chunk_height, chunk_width)
filtered_sample_image = filtered_sample_image.reshape((chunk_height, chunk_width, -1)).transpose((2, 0, 1))

assert filtered_sample_image.shape == (len(filters), chunk_height, chunk_width)

batch_plot(
    filtered_sample_image,
    list(sample_filters.keys()),
    with_border=False,
    cmap=plt.cm.gray,
    tight_layout=None,
    wspace=0.1,
    hspace=0.1,
    imgsize=6,
)