In [None]:
import numpy as np
import torch
from scipy import datasets as misc
from sklearn.datasets import load_sample_images

from utils import batch_plot
from conv import generate_output_size


benchmark: bool = False  # set to True to run the %timeit comparisons in the patch-extraction cells below

# Grayscale image: splitting into non-overlapping patches

In [None]:
# Load SciPy's "ascent" test image (the module is imported above under the alias `misc`).
sample_image = misc.ascent()
# Rescale so that the brightest pixel becomes exactly 1.
sample_image = np.divide(sample_image, sample_image.max())
# batch_plot takes a batch of images, so display this one as a batch of size 1.
batch_plot(sample_image[np.newaxis], with_border=False, cmap="gray", imgsize=6)

# Wrap the NumPy buffer as a torch tensor (shares memory, no copy).
sample_image = torch.from_numpy(sample_image)

In [None]:
patch_size = 64
image_height, image_width = sample_image.shape
stride_height, stride_width = sample_image.stride()

output_height = generate_output_size(image_height, kernel_size=patch_size, stride=patch_size, padding=0)
output_width = generate_output_size(image_width, kernel_size=patch_size, stride=patch_size, padding=0)

patches = torch.as_strided(
    sample_image,
    size=(output_height, output_width, patch_size, patch_size),
    stride=(stride_height * patch_size, stride_width * patch_size, stride_height, stride_width),
)

assert torch.allclose(
    sample_image.reshape(output_height, patch_size, output_width, patch_size).transpose(1, 2), patches
)

assert torch.allclose(sample_image.unfold(0, patch_size, patch_size).unfold(1, patch_size, patch_size), patches)

if benchmark:
    %timeit torch.as_strided(sample_image, size=(output_height, output_width, patch_size, patch_size), stride=(stride_height*patch_size, stride_width*patch_size, stride_height, stride_width))
    %timeit sample_image.reshape(output_height, patch_size, output_width, patch_size).transpose(1,2)
    %timeit sample_image.unfold(0, patch_size, patch_size).unfold(1, patch_size, patch_size)

In [None]:
# Merge the (output_height, output_width) grid axes into a single batch axis so each
# patch_size x patch_size tile is drawn as its own image; vmin/vmax fix the colour
# scale at 0..1 (the image was divided by its maximum) so tiles are comparable.
batch_plot(
    patches.flatten(0, 1).numpy(),
    cmap="gray",
    vmin=0,
    vmax=1,
    imgsize=2,
    with_border=False,
    tight_layout=None,
    wspace=0.01,
    hspace=0.01,
)

# Color image: patches with a trailing channel axis

In [None]:
# Load SciPy's "face" test image and crop it to a square whose side is the shorter
# of its two spatial dimensions.
sample_image = misc.face()
n = np.min(sample_image.shape[:2])
sample_image = sample_image[:n, :n]
# Rescale so the largest value in the cropped image becomes exactly 1.
sample_image = np.divide(sample_image, sample_image.max())
# NOTE(review): cmap="gray" likely has no effect on 3-channel data if batch_plot
# forwards it to matplotlib's imshow — confirm against utils.batch_plot.
batch_plot(sample_image[np.newaxis], with_border=False, cmap="gray", imgsize=6)

# Share the NumPy buffer with a torch tensor (no copy).
sample_image = torch.from_numpy(sample_image)

In [None]:
patch_size = 96
image_height, image_width, channel = sample_image.shape
stride_height, stride_width, stride_channel = sample_image.stride()

output_height = generate_output_size(image_height, kernel_size=patch_size, stride=patch_size, padding=0)
output_width = generate_output_size(image_width, kernel_size=patch_size, stride=patch_size, padding=0)

patches = torch.as_strided(
    sample_image,
    size=(output_height, output_width, patch_size, patch_size, channel),
    stride=(stride_height * patch_size, stride_width * patch_size, stride_height, stride_width, stride_channel),
)

assert torch.allclose(
    sample_image.reshape(output_height, patch_size, output_width, patch_size, channel).transpose(1, 2), patches
)

assert torch.allclose(
    sample_image.unfold(0, patch_size, patch_size).unfold(1, patch_size, patch_size).permute(0, 1, 3, 4, 2), patches
)

if benchmark:
    %timeit torch.as_strided(sample_image, size=(output_height, output_width, patch_size, patch_size, channel), stride=(stride_height*patch_size, stride_width*patch_size, stride_height, stride_width, stride_channel))
    %timeit sample_image.reshape(output_height, patch_size, output_width, patch_size, channel).transpose(1,2)
    %timeit sample_image.unfold(0, patch_size, patch_size).unfold(1, patch_size, patch_size).permute(0,1,3,4,2)

In [None]:
# Collapse the patch grid into one batch axis: (output_height * output_width,
# patch_size, patch_size, channel), one colour tile per image in the plot.
batch_plot(
    patches.flatten(0, 1).numpy(),
    cmap="gray",
    vmin=0,
    vmax=1,
    imgsize=2,
    with_border=False,
    tight_layout=None,
    wspace=0.01,
    hspace=0.01,
)

# Batch of color images: patches for every image at once

In [None]:
# Combine scikit-learn's sample photos into one array with a leading batch axis.
sample_images = np.asarray(load_sample_images().images)
# Square crop; 420 = 10 * 42, i.e. a whole number of the 42-pixel patches used below.
n = 420
sample_images = sample_images[:, :n, :n]
# Normalise by a single global maximum shared by all images (not a per-image max).
sample_images = np.divide(sample_images, sample_images.max())

batch_plot(sample_images, with_border=False, imgsize=6)
sample_images = torch.from_numpy(sample_images)

In [None]:
patch_size = 42
batch_size, image_height, image_width, channel = sample_images.shape
stride_batch, stride_height, stride_width, stride_channel = sample_images.stride()

output_height = generate_output_size(image_height, kernel_size=patch_size, stride=patch_size, padding=0)
output_width = generate_output_size(image_width, kernel_size=patch_size, stride=patch_size, padding=0)

patches = torch.as_strided(
    sample_images,
    size=(batch_size, output_height, output_width, patch_size, patch_size, channel),
    stride=(
        stride_batch,
        stride_height * patch_size,
        stride_width * patch_size,
        stride_height,
        stride_width,
        stride_channel,
    ),
)

assert torch.allclose(
    sample_images.reshape(batch_size, output_height, patch_size, output_width, patch_size, channel).transpose(2, 3),
    patches,
)

assert torch.allclose(
    sample_images.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size).permute(0, 1, 2, 4, 5, 3),
    patches,
)

if benchmark:
    %timeit torch.as_strided(sample_images, size=(batch_size, output_height, output_width, patch_size, patch_size, channel), stride=(stride_batch, stride_height*patch_size, stride_width*patch_size, stride_height, stride_width, stride_channel))
    %timeit sample_images.reshape(batch_size, output_height, patch_size, output_width, patch_size, channel).transpose(2,3)
    %timeit sample_images.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size).permute(0,1,2,4,5,3)

In [None]:
# One figure per image in the batch: iterating over `patches` walks the batch axis,
# and flattening the two grid axes turns that image's patch grid into a batch of tiles.
for patch in patches:
    batch_plot(
        patch.flatten(0, 1).numpy(),
        cmap="gray",
        vmin=0,
        vmax=1,
        imgsize=2,
        with_border=False,
        tight_layout=None,
        wspace=0.01,
        hspace=0.01,
    )