In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from conv import window_images
from utils import batch_plot

In [None]:
# fmt: off
# An 8x8 grayscale image (apparently a handwritten digit 0) used as the running example.
zero = np.array(
    [
        [0, 0,  5, 13,  9,  1, 0, 0],
        [0, 0, 13, 15, 10, 15, 5, 0],
        [0, 3, 15,  2,  0, 11, 8, 0],
        [0, 4, 12,  0,  0,  8, 8, 0],
        [0, 5,  8,  0,  0,  9, 8, 0],
        [0, 4, 11,  0,  1, 12, 7, 0],
        [0, 2, 14,  5, 10, 12, 0, 0],
        [0, 0,  6, 13, 10,  0, 0, 0],
    ]
)
# fmt: on

# batch_plot takes a batch of images, so add a leading batch axis: (8, 8) -> (1, 8, 8)
batch_plot(zero[np.newaxis], cmap=plt.cm.gray_r)

# Max pooling with stride=2

## 1. using the max values to find the maximum locations (every tied maximum is marked)

<img src="./images/max_pooling_with_duplicates.png">

### 1.1 generate sliding window views 

In [None]:
# fmt: off
def generate_windowed_2d(x, kernel=(2, 2), stride=(2, 2), channels_first=False):
    """Split a single 2-D array into sliding windows.

    The array temporarily gets a batch axis and a trailing channel axis so that it has the
    (batch, height, width, channel) layout passed to ``window_images``. Both singleton axes
    are removed again from the result.

    Args:
        x: 2-D array of shape (height, width).
        kernel: (kernel_height, kernel_width) window size.
        stride: (stride_height, stride_width) step between consecutive windows.
        channels_first: forwarded to ``window_images``. NOTE(review): the axes added and
            squeezed here assume a channels-last layout, so ``True`` is presumably not
            supported; confirm.

    Returns:
        Array of shape (height_blocks, width_blocks, kernel_height, kernel_width).
    """
    batched = np.expand_dims(x, (0, 3))  # (height, width) -> (1, height, width, 1)
    windows = window_images(
        images=batched,
        kernel_size=kernel,
        stride=stride,
        channels_first=channels_first,
    )
    # (1, height_blocks, width_blocks, kernel_height, kernel_width, 1)
    #   -> (height_blocks, width_blocks, kernel_height, kernel_width)
    return windows.squeeze(axis=(0, -1))
# fmt: on

In [None]:
height, width = zero.shape
kernel_height, kernel_width = 2, 2
stride_height, stride_width = 2, 2
# pass the kernel explicitly; otherwise editing kernel_height/kernel_width above has no effect
windowed_zero = generate_windowed_2d(
    x=zero,
    kernel=(kernel_height, kernel_width),
    stride=(stride_height, stride_width),
)
assert windowed_zero.ndim == 4

### 1.2. find the max values along the kernel height and width axes

In [None]:
# largest value of every window, reduced over the two kernel axes -> (height_blocks, width_blocks)
max_windowed_zero = windowed_zero.max(axis=(2, 3))
max_windowed_zero

### 1.3. expand the max values for broadcasting

In [None]:
# append two singleton kernel axes so the maxima broadcast against every window element
expanded_max_windowed_zero = max_windowed_zero[:, :, np.newaxis, np.newaxis]
assert expanded_max_windowed_zero.ndim == 4

### 1.4. compare the windowed array with the expanded max values to find the maximum locations

In [None]:
# 1 wherever an element equals its window's maximum; tied maxima are *all* marked
is_window_max = np.equal(windowed_zero, expanded_max_windowed_zero)
max_windowed_zero_index_with_duplicates = is_window_max.astype(int)
assert max_windowed_zero_index_with_duplicates.shape == windowed_zero.shape

### 1.5 get the gradient

In [None]:
# (height_blocks, width_blocks, kh, kw) -> (height_blocks, kh, width_blocks, kw), then flatten
# the pairs back into image rows/columns (valid because stride == kernel here)
grad_duplicates = np.swapaxes(max_windowed_zero_index_with_duplicates, 1, 2).reshape(zero.shape)

grad_duplicates


## 2. using `argmax` to find the maximum locations

<img src="./images/max_pooling_without_duplicates.png">

### 2.1 generate sliding window views 

In [None]:
height, width = zero.shape
kernel_height, kernel_width = 2, 2
stride_height, stride_width = 2, 2
# pass the kernel explicitly; otherwise editing kernel_height/kernel_width above has no effect
windowed_zero = generate_windowed_2d(
    x=zero,
    kernel=(kernel_height, kernel_width),
    stride=(stride_height, stride_width),
)
assert windowed_zero.ndim == 4

### 2.2 reshape the windowed array

In [None]:
height_blocks, width_blocks, kernel_height, kernel_width = windowed_zero.shape
block_size = height_blocks * width_blocks
kernel_size = kernel_height * kernel_width
# one row per window, one column per element inside that window
reshaped_windowed_zero = windowed_zero.reshape(block_size, kernel_size)

### 2.3 find maximum locations using `argmax`

In [None]:
# flat (row-major) position of the first maximum inside each window
reshaped_windowed_zero_argmax = reshaped_windowed_zero.argmax(axis=1)
max_windowed_zero = windowed_zero.max(axis=(2, 3))
max_windowed_zero

### 2.4 map maximum locations to the reshaped array 

In [None]:
reshaped_maximum_locations = np.zeros_like(reshaped_windowed_zero)
# one-hot per row: only the first maximum of each window receives the gradient
window_rows = np.arange(block_size)
reshaped_maximum_locations[window_rows, reshaped_windowed_zero_argmax] = 1

### 2.5 get the gradient

In [None]:
# back to the window grid: (height_blocks, width_blocks, kernel_height, kernel_width)
max_windowed_zero_index_without_duplicates = reshaped_maximum_locations.reshape(windowed_zero.shape)
# interleave block and kernel axes so windows tile the image (valid because stride == kernel)
grad_normal = np.swapaxes(max_windowed_zero_index_without_duplicates, 1, 2).reshape(zero.shape)

grad_normal

### 2.6 compare the result with pytorch

In [None]:
# add a leading channel axis: (height, width) -> (1, height, width)
torch_zero = torch.tensor(zero, requires_grad=True, dtype=torch.float).unsqueeze(0)
# unsqueeze makes torch_zero a non-leaf tensor, so ask autograd to keep its .grad
torch_zero.retain_grad()
pool = torch.nn.functional.max_pool2d(
    torch_zero, (kernel_height, kernel_width), (stride_height, stride_width)
)

# forward: torch agrees with the numpy window maxima
torch_pooled = pool.detach().numpy()
assert np.allclose(torch_pooled, max_windowed_zero)

# backward: d(sum of pooled values) / d(input) is the pooling gradient
pool.sum().backward()
torch_grad = torch_zero.grad.numpy()
assert np.allclose(torch_grad, grad_normal)

# Max pooling with stride=1


## 1. using `argmax` to find the maximum locations

<img src="./images/max_pooling_with_overlaps.png">

### 1.1 generate sliding window views 

In [None]:
height, width = zero.shape
kernel_height, kernel_width = 2, 2
stride_height, stride_width = 1, 1
# pass the kernel explicitly; otherwise editing kernel_height/kernel_width above has no effect
windowed_zero = generate_windowed_2d(
    x=zero,
    kernel=(kernel_height, kernel_width),
    stride=(stride_height, stride_width),
)
assert windowed_zero.ndim == 4

### 1.2 reshape the windowed array

In [None]:
height_blocks, width_blocks, kernel_height, kernel_width = windowed_zero.shape
block_size = height_blocks * width_blocks
kernel_size = kernel_height * kernel_width
# one row per window, one column per element inside that window
reshaped_windowed_zero = windowed_zero.reshape(block_size, kernel_size)

### 1.3 find maximum locations using `argmax`

In [None]:
# flat (row-major) position of the first maximum inside each window
reshaped_windowed_zero_argmax = reshaped_windowed_zero.argmax(axis=1)
max_windowed_zero = windowed_zero.max(axis=(2, 3))
max_windowed_zero

### 1.4 map maximum locations to the reshaped array 

In [None]:
reshaped_maximum_locations = np.zeros_like(reshaped_windowed_zero)
# one-hot per row: only the first maximum of each window receives the gradient
window_rows = np.arange(block_size)
reshaped_maximum_locations[window_rows, reshaped_windowed_zero_argmax] = 1

### 1.5 get the gradient

#### 1.5.1 Example: combining overlapping windows (TODO: look for a better approach)

<img src="./images/combine_overlaps.png">


In [None]:
# A 4x4 grid of 2x2 windows (stride 1) filled with small integers, used to show how
# overlapping windows are summed back onto a single 5x5 array.
x = np.array(
    [
        [0, 0, 0, 0, 2, 1, 2, 1],
        [2, 1, 0, 2, 0, 0, 1, 0],
        [0, 2, 1, 1, 1, 2, 2, 2],
        [1, 2, 0, 1, 0, 0, 0, 2],
        [0, 2, 1, 0, 1, 1, 2, 1],
        [0, 2, 0, 1, 2, 1, 2, 1],
        [1, 0, 0, 1, 2, 2, 1, 2],
        [0, 2, 1, 0, 2, 2, 1, 2],
    ]
).reshape((4, 4, 2, 2))

# Step 1, along the width: the first column only sees the left column of the first window,
# the last column only the right column of the last window, and every interior column is
# shared by two horizontally adjacent windows.
col_start = x[:, 0, :, 0]
col_end = x[:, -1, :, -1]
col_middle = x[:, :-1, :, 1] + x[:, 1:, :, 0]

col_pieces = [col_start[:, np.newaxis, :], col_middle, col_end[:, np.newaxis, :]]
# (height_blocks, width, kernel_height) -> (height_blocks, kernel_height, width)
col_overlap = np.concatenate(col_pieces, axis=1).transpose((0, 2, 1))

assert col_overlap.shape == (4, 2, 5)

# Step 2: the same merge along the height.
row_start = col_overlap[0, 0, :]
row_end = col_overlap[-1, -1, :]
row_middle = col_overlap[:-1, 1, :] + col_overlap[1:, 0, :]
row_pieces = [row_start[np.newaxis, :], row_middle, row_end[np.newaxis, :]]
row_overlap = np.concatenate(row_pieces, axis=0)

assert row_overlap.shape == (5, 5)


def combine_overlap(x, stride=(1, 1)):
    """Sum a grid of (possibly overlapping) windows back into a single 2-D array.

    Every window element is added to the array position it covers, so positions covered
    by several windows receive the sum of all their contributions. This is how gradients
    of sliding-window operations are accumulated.

    The original version only handled 2x2 windows with stride 1 and silently produced
    wrong results for larger kernels. This version handles any kernel size and stride.
    With the default stride it returns exactly what the original returned.

    Args:
        x: array of shape (height_blocks, width_blocks, kernel_height, kernel_width).
        stride: (stride_height, stride_width) that was used to extract the windows, or a
            single int for both. Defaults to (1, 1).

    Returns:
        Array with the dtype of ``x`` and shape
        ((height_blocks - 1) * stride_height + kernel_height,
         (width_blocks - 1) * stride_width + kernel_width);
        an axis with no windows has length 0.
    """
    if isinstance(stride, int):
        stride = (stride, stride)
    stride_height, stride_width = stride
    height_blocks, width_blocks, kernel_height, kernel_width = x.shape

    out_height = (height_blocks - 1) * stride_height + kernel_height if height_blocks else 0
    out_width = (width_blocks - 1) * stride_width + kernel_width if width_blocks else 0
    combined = np.zeros((out_height, out_width), dtype=x.dtype)

    # For kernel offset (i, j), window (a, b) covers position
    # (i + a * stride_height, j + b * stride_width), so a strided slice picks exactly one
    # target per window and the whole grid is scattered in one vectorised add.
    row_stop = height_blocks * stride_height
    col_stop = width_blocks * stride_width
    for i in range(kernel_height):
        for j in range(kernel_width):
            combined[i : i + row_stop : stride_height, j : j + col_stop : stride_width] += x[:, :, i, j]
    return combined

#### 1.5.2 calculate the gradient

In [None]:
# back to the window grid: (height_blocks, width_blocks, kernel_height, kernel_width)
max_windowed_zero_index_without_duplicates = reshaped_maximum_locations.reshape(windowed_zero.shape)

# with stride 1 the windows overlap, so their one-hot masks are summed rather than tiled
grad_normal = combine_overlap(max_windowed_zero_index_without_duplicates)
assert grad_normal.shape == zero.shape

In [None]:
# add a leading channel axis: (height, width) -> (1, height, width)
torch_zero = torch.tensor(zero, requires_grad=True, dtype=torch.float).unsqueeze(0)
# unsqueeze makes torch_zero a non-leaf tensor, so ask autograd to keep its .grad
torch_zero.retain_grad()
pool = torch.nn.functional.max_pool2d(
    torch_zero, (kernel_height, kernel_width), (stride_height, stride_width)
)

# forward: torch agrees with the numpy window maxima
torch_pooled = pool.detach().numpy()
assert np.allclose(torch_pooled, max_windowed_zero)

# backward: overlapping windows accumulate gradient on shared pixels
pool.sum().backward()
torch_grad = torch_zero.grad.numpy()
assert np.allclose(torch_grad, grad_normal)

# Mean pooling with stride=2

## 1. simply calculate the mean of each sliding view

### 1.1 generate sliding window views 

In [None]:
height, width = zero.shape
kernel_height, kernel_width = 2, 2
stride_height, stride_width = 2, 2
# pass the kernel explicitly; otherwise editing kernel_height/kernel_width above has no effect
windowed_zero = generate_windowed_2d(
    x=zero,
    kernel=(kernel_height, kernel_width),
    stride=(stride_height, stride_width),
)
assert windowed_zero.ndim == 4

height_blocks, width_blocks, kernel_height, kernel_width = windowed_zero.shape
block_size, kernel_size = height_blocks * width_blocks, kernel_height * kernel_width

### 1.2. find the mean values along the kernel height and width axes

In [None]:
# average of every window over the two kernel axes -> (height_blocks, width_blocks)
mean_windowed_zero = windowed_zero.mean(axis=(2, 3))
mean_windowed_zero

### 1.3. get the gradient

In [None]:
# non-overlapping windows: each pixel feeds exactly one mean, weighted by 1 / window area
grad_normal = np.ones(zero.shape) / (kernel_height * kernel_width)

### 1.4 compare the result with pytorch

In [None]:
# add a leading channel axis: (height, width) -> (1, height, width)
torch_zero = torch.tensor(zero, requires_grad=True, dtype=torch.float).unsqueeze(0)
# unsqueeze makes torch_zero a non-leaf tensor, so ask autograd to keep its .grad
torch_zero.retain_grad()
pool = torch.nn.functional.avg_pool2d(
    torch_zero, (kernel_height, kernel_width), (stride_height, stride_width)
)

# forward: torch agrees with the numpy window means
torch_pooled = pool.detach().numpy()
assert np.allclose(torch_pooled, mean_windowed_zero)

# backward: d(sum of pooled values) / d(input) is the pooling gradient
pool.sum().backward()
torch_grad = torch_zero.grad.numpy()
assert np.allclose(torch_grad, grad_normal)

# Mean pooling with stride=1

## 1. simply calculate the mean of each sliding view

### 1.1 generate sliding window views 

In [None]:
height, width = zero.shape
kernel_height, kernel_width = 2, 2
stride_height, stride_width = 1, 1
# pass the kernel explicitly; otherwise editing kernel_height/kernel_width above has no effect
windowed_zero = generate_windowed_2d(
    x=zero,
    kernel=(kernel_height, kernel_width),
    stride=(stride_height, stride_width),
)
assert windowed_zero.ndim == 4

height_blocks, width_blocks, kernel_height, kernel_width = windowed_zero.shape
block_size, kernel_size = height_blocks * width_blocks, kernel_height * kernel_width

### 1.2. find the mean values along the kernel height and width axes

In [None]:
# average of every window over the two kernel axes -> (height_blocks, width_blocks)
mean_windowed_zero = windowed_zero.mean(axis=(2, 3))
mean_windowed_zero

### 1.3. get the gradient

In [None]:
# number of windows covering each pixel: sum an all-ones window grid back onto the image
all_ones_windows = np.ones(windowed_zero.shape, dtype=windowed_zero.dtype)
position_matrix = combine_overlap(all_ones_windows)

# each covering window contributes 1 / window area to the pixel's gradient
grad_normal = position_matrix / (kernel_height * kernel_width)

### 1.4 compare the result with pytorch

In [None]:
# add a leading channel axis: (height, width) -> (1, height, width)
torch_zero = torch.tensor(zero, requires_grad=True, dtype=torch.float).unsqueeze(0)
# unsqueeze makes torch_zero a non-leaf tensor, so ask autograd to keep its .grad
torch_zero.retain_grad()
pool = torch.nn.functional.avg_pool2d(
    torch_zero, (kernel_height, kernel_width), (stride_height, stride_width)
)

# forward: torch agrees with the numpy window means
torch_pooled = pool.detach().numpy()
assert np.allclose(torch_pooled, mean_windowed_zero)

# backward: overlapping windows accumulate gradient on shared pixels
pool.sum().backward()
torch_grad = torch_zero.grad.numpy()
assert np.allclose(torch_grad, grad_normal)

# Handle the boundary case (input size not a multiple of the stride)

In [None]:
# A 7x5 input is not a multiple of the 2x2 stride: the last row and the last column are
# never covered by a window and must receive zero gradient.
x = np.arange(35).reshape((7, 5))

height, width = x.shape
kernel_height, kernel_width = 2, 2
stride_height, stride_width = 2, 2
# pass the kernel explicitly; otherwise editing kernel_height/kernel_width above has no effect
windowed_x = generate_windowed_2d(
    x=x,
    kernel=(kernel_height, kernel_width),
    stride=(stride_height, stride_width),
)
assert windowed_x.ndim == 4

height_blocks, width_blocks, kernel_height, kernel_width = windowed_x.shape
block_size, kernel_size = height_blocks * width_blocks, kernel_height * kernel_width

mean_windowed_x = np.mean(windowed_x, axis=(2, 3))

# The uniform gradient below is only correct for non-overlapping, gap-free windows.
assert (stride_height, stride_width) == (kernel_height, kernel_width), "expected stride == kernel"

# every covered pixel feeds exactly one mean (weight 1 / window area); the uncovered
# bottom/right border is zero-padded
grad_height, grad_width = height_blocks * stride_height, width_blocks * stride_width
grad_x = np.ones((grad_height, grad_width)) / (kernel_height * kernel_width)
padding_height, padding_width = height - grad_height, width - grad_width
grad_x = np.pad(grad_x, ((0, padding_height), (0, padding_width)))

torch_x = torch.tensor(x, requires_grad=True, dtype=torch.float).unsqueeze(0)
torch_x.retain_grad()  # non-leaf after unsqueeze
pool = torch.nn.functional.avg_pool2d(
    torch_x,
    kernel_size=(kernel_height, kernel_width),
    stride=(stride_height, stride_width),
)

assert np.allclose(pool.detach().numpy(), mean_windowed_x)

pool.sum().backward()
assert np.allclose(torch_x.grad.numpy(), grad_x)

# Handle a batch of images

In [None]:
from conv import load_sample_filters

# single-channel filters with a trailing channel axis appended (channels-last)
filters = np.asarray(list(load_sample_filters(size=7, channel=1).values()))[:, :, :, None]
assert filters.ndim == 4

batch_size, height, width, channels = filters.shape

kernel_height, kernel_width = 2, 2
stride_height, stride_width = 2, 2

# sliding windows with the single channel axis dropped:
# (batch_size, height_blocks, width_blocks, kernel_height, kernel_width)
windowed_filters = window_images(
    filters,
    kernel_size=(kernel_height, kernel_width),
    stride=(stride_height, stride_width),
    channels_first=False,
).squeeze(axis=-1)

batch_size, height_blocks, width_blocks, kernel_height, kernel_width = windowed_filters.shape
block_size = height_blocks * width_blocks
kernel_size = kernel_height * kernel_width

# forward pass: maximum of every window
max_pooling_filters = np.max(windowed_filters, axis=(3, 4))

# backward pass: flatten each window and mark its first maximum with a 1
reshaped_windowed_filters = windowed_filters.reshape((batch_size, block_size, kernel_size))
reshaped_windowed_filters_argmax = np.argmax(reshaped_windowed_filters, axis=2)  # (batch_size, block_size)

reshaped_maximum_locations = np.zeros_like(reshaped_windowed_filters)
np.put_along_axis(
    reshaped_maximum_locations,
    reshaped_windowed_filters_argmax[:, :, np.newaxis],
    1,
    axis=2,
)

max_windowed_filters_index_without_duplicates = reshaped_maximum_locations.reshape(windowed_filters.shape)

# the windows tile only height_blocks * stride rows and width_blocks * stride columns;
# whatever is left at the bottom/right is padded with zeros below
grad_height = height_blocks * stride_height
grad_width = width_blocks * stride_width
padding_height = height - grad_height
padding_width = width - grad_width

# (batch, hb, wb, kh, kw) -> (batch, hb, kh, wb, kw) -> (batch, grad_height, grad_width, channels)
grad_normal = np.swapaxes(max_windowed_filters_index_without_duplicates, 2, 3).reshape(
    (batch_size, grad_height, grad_width, channels)
)

if padding_height > 0 or padding_width > 0:
    grad_normal = np.pad(grad_normal, ((0, 0), (0, padding_height), (0, padding_width), (0, 0)))

assert grad_normal.shape == filters.shape

In [None]:
# NHWC -> NCHW, the layout torch's pooling expects. The previous permute(0, 3, 2, 1)
# produced NCWH (height and width swapped). That is wrong for non-square kernels/strides
# and changes the scan order torch uses to break ties between equal maxima.
torch_filters = torch.tensor(filters, requires_grad=True, dtype=torch.float).permute(0, 3, 1, 2)
torch_filters.retain_grad()  # non-leaf after permute
pool = torch.nn.functional.max_pool2d(
    torch_filters,
    kernel_size=(kernel_height, kernel_width),
    stride=(stride_height, stride_width),
)

# NCHW -> NHWC, then drop the single channel to match (batch, height_blocks, width_blocks)
assert np.allclose(pool.permute(0, 2, 3, 1).detach().numpy().squeeze(axis=-1), max_pooling_filters)

pool.sum().backward()
assert np.allclose(torch_filters.grad.permute(0, 2, 3, 1).numpy(), grad_normal)

# Wrap max pooling and its gradient in a reusable function

In [None]:
def max_pool2d_with_grad(x, kernel_size=None, stride=None):
    """Channels-last 2-D max pooling together with the gradient of ``pooled.sum()``.

    Args:
        x: array of shape (batch_size, height, width, channels).
        kernel_size: window size as a (kernel_height, kernel_width) pair or a single int;
            defaults to (2, 2).
        stride: step between windows as a (stride_height, stride_width) pair or a single
            int; defaults to (2, 2).

    Returns:
        (pooled, grad):
            pooled has shape (batch_size, height_blocks, width_blocks, channels).
            grad has shape (batch_size, height, width, channels) and the dtype of the
            windowed view of ``x``. Each element counts how many windows selected it as
            their (first, row-major) maximum. Elements not covered by any window, e.g. a
            border left over when the size is not aligned to the stride, get 0.

    Overlapping windows (stride < kernel) and gaps (stride > kernel) are supported. The
    previous transpose/reshape formulation only worked for stride == kernel.
    """
    if kernel_size is None:
        kernel_size = (2, 2)
    elif isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if stride is None:
        stride = (2, 2)
    elif isinstance(stride, int):
        stride = (stride, stride)

    stride_height, stride_width = stride
    batch_size, height, width, channels = x.shape

    windowed_x = window_images(x, kernel_size=kernel_size, stride=stride, channels_first=False)
    # (batch_size, height_blocks, width_blocks, kernel_height, kernel_width, channels)
    batch_size, height_blocks, width_blocks, kernel_height, kernel_width, channels = windowed_x.shape
    # named window_area so the kernel_size parameter is not shadowed
    block_size, window_area = height_blocks * width_blocks, kernel_height * kernel_width

    # forward: maximum over the two kernel axes
    max_pooling_x = np.max(windowed_x, axis=(3, 4))

    # backward: one-hot mask of the first maximum inside every window (per channel)
    flat_windows = windowed_x.reshape((batch_size, block_size, window_area, channels))
    argmax = np.argmax(flat_windows, axis=2)  # (batch_size, block_size, channels)
    mask = np.zeros_like(flat_windows)
    np.put_along_axis(mask, argmax[:, :, np.newaxis, :], 1, axis=2)
    mask = mask.reshape(windowed_x.shape)

    # Scatter-add every kernel offset back to the input position it was read from. For
    # offset (i, j), window (a, b) reads (i + a * stride_height, j + b * stride_width), so a
    # strided slice picks exactly one target per window. Uncovered borders stay 0.
    grad = np.zeros((batch_size, height, width, channels), dtype=mask.dtype)
    row_stop = height_blocks * stride_height
    col_stop = width_blocks * stride_width
    for i in range(kernel_height):
        for j in range(kernel_width):
            grad[:, i : i + row_stop : stride_height, j : j + col_stop : stride_width, :] += mask[:, :, :, i, j, :]

    return max_pooling_x, grad

In [None]:
kernel_size, stride, channels = 2, 2, 3
# channels-last filters: (n_filters, height, width, channels)
filters = np.asarray(list(load_sample_filters(size=7, channel=channels).values()))
assert filters.ndim == 4

# give numpy the same window/stride as torch instead of relying on the function's defaults
filters_pooled, filters_grad = max_pool2d_with_grad(
    filters, kernel_size=(kernel_size, kernel_size), stride=(stride, stride)
)
batch_size, block_height, block_width, channels = filters_pooled.shape
block_size = block_height * block_width

# NHWC -> NCHW for torch
torch_filters = torch.tensor(filters, requires_grad=True, dtype=torch.float).permute(0, 3, 1, 2)
torch_filters.retain_grad()  # non-leaf after permute
pool = torch.nn.functional.max_pool2d(torch_filters, kernel_size=kernel_size, stride=stride)

# compare back in NHWC
assert np.allclose(pool.permute(0, 2, 3, 1).detach().numpy(), filters_pooled)

pool.sum().backward()
assert np.allclose(torch_filters.grad.permute(0, 2, 3, 1).numpy(), filters_grad)

In [None]:
from sklearn.datasets import load_sample_images

# scikit-learn's bundled sample photos, stacked into one channels-last array
sample_images = load_sample_images().images
images = np.asarray(sample_images)

images.shape

In [None]:
# give numpy the same window/stride as torch instead of relying on the function's defaults
images_pooled, images_grad = max_pool2d_with_grad(
    images, kernel_size=(kernel_size, kernel_size), stride=(stride, stride)
)
batch_size, block_height, block_width, channels = images_pooled.shape
block_size = block_height * block_width

# NHWC -> NCHW for torch
torch_images = torch.tensor(images, requires_grad=True, dtype=torch.float).permute(0, 3, 1, 2)
torch_images.retain_grad()  # non-leaf after permute
pool = torch.nn.functional.max_pool2d(torch_images, kernel_size=kernel_size, stride=stride)

# compare back in NHWC
assert np.allclose(pool.permute(0, 2, 3, 1).detach().numpy(), images_pooled)

pool.sum().backward()
assert np.allclose(torch_images.grad.permute(0, 2, 3, 1).numpy(), images_grad)

# Channels-first inputs

In [None]:
from sklearn.datasets import load_sample_images
from conv import max_pool2d_with_grad, avg_pool2d_with_grad

# NOTE: the import above replaces the channels-last max_pool2d_with_grad defined earlier
# in this notebook with the one from conv.

# move the channel axis to position 1: NHWC -> NCHW
images = np.moveaxis(np.asarray(load_sample_images().images), -1, 1)

images.shape

In [None]:
# defined here rather than inherited from an earlier cell
kernel_size, stride = 2, 2

# NOTE(review): conv.max_pool2d_with_grad is called with its defaults, presumably a
# 2x2 window and stride 2 to match torch below; confirm.
images_pooled, images_grad = max_pool2d_with_grad(images)
# Channels-first: the output is compared directly with torch's NCHW result below, so its
# axes are (batch, channels, height_blocks, width_blocks).
batch_size, channels, block_height, block_width = images_pooled.shape
block_size = block_height * block_width

torch_images = torch.tensor(images, requires_grad=True, dtype=torch.float)
torch_images.retain_grad()
pool = torch.nn.functional.max_pool2d(torch_images, kernel_size=kernel_size, stride=stride)

assert np.allclose(pool.detach().numpy(), images_pooled)

pool.sum().backward()
assert np.allclose(torch_images.grad.numpy(), images_grad)

In [None]:
# defined here rather than inherited from an earlier cell
kernel_size, stride = 2, 2

# NOTE(review): conv.avg_pool2d_with_grad is called with its defaults, presumably a
# 2x2 window and stride 2 to match torch below; confirm.
images_pooled, images_grad = avg_pool2d_with_grad(images)
# Channels-first: the output is compared directly with torch's NCHW result below, so its
# axes are (batch, channels, height_blocks, width_blocks).
batch_size, channels, block_height, block_width = images_pooled.shape
block_size = block_height * block_width

torch_images = torch.tensor(images, requires_grad=True, dtype=torch.float)
torch_images.retain_grad()
pool = torch.nn.functional.avg_pool2d(torch_images, kernel_size=kernel_size, stride=stride)

assert np.allclose(pool.detach().numpy(), images_pooled)

pool.sum().backward()
assert np.allclose(torch_images.grad.numpy(), images_grad)