In [1]:
from __future__ import annotations
import torch
import math

In [2]:
# Move to project root: if `./src` is not visible from the current working
# directory, walk up the ancestors until one containing `src/` is found and
# make it the working directory (so `from src import ...` works below).
from pathlib import Path
import os

if not Path("./src").is_dir():
    parent_path = next(
        (candidate for candidate in Path.cwd().parents if (candidate / "src").is_dir()),
        None,
    )
    if parent_path is None:
        raise FileNotFoundError("Can't find project root")
    os.chdir(parent_path)

assert Path("./src").is_dir()

In [3]:
# Load the K-MNIST dataset through the project loader. The repr shows the
# train/test image tensors (N, 1, 28, 28) and their label tensors (N,).
from src import load_data

k_mnist = load_data.k_mnist()
k_mnist

Dataset(x_train=torch.Size([60000, 1, 28, 28]), x_test=torch.Size([10000, 1, 28, 28]), y_train=torch.Size([60000]), y_test=torch.Size([10000]))

In [6]:
# Benchmark input: the first 4096 training images with the single channel
# repeated 5x, i.e. a (4096, 5, 28, 28) tensor on the current CUDA device.
test_i = k_mnist.x_train[:4096].repeat(1, 5, 1, 1).to("cuda")
test_i.shape

torch.Size([4096, 5, 28, 28])

In [7]:
def _as_tup(v: int | tuple[int] | tuple[int, int]):
    """Expand a scalar or short sequence into an explicit ``(y, x)`` pair.

    ``3`` -> ``(3, 3)``; ``(3,)`` -> ``(3, 3)``; a length-2 sequence is
    returned unchanged (the same object).

    Raises:
        ValueError: if ``v`` is a sequence whose length is neither 1 nor 2.
    """
    if isinstance(v, int):
        return v, v
    length = len(v)
    if length == 2:
        return v
    if length == 1:
        first = v[0]
        return first, first
    raise ValueError(f"Invalid 2-tuple-like object {v=}")

In [8]:
def unfold_view(imgs: torch.Tensor, kernel_size: int | tuple[int, int],
                dilation: int | tuple[int, int] = 1,
                padding: int | tuple[int, int] = 0,
                stride: int | tuple[int, int] = 1):
    """Zero-copy sliding-window view of a BCHW tensor, laid out like ``F.unfold``.

    Returns a strided view of ``imgs`` (no data is copied) with shape
    ``(B, C, kernel_h, kernel_w, out_h, out_w)`` such that::

        out[b, c, ky, kx, oy, ox] == imgs[b, c,
                                          oy * stride_y + ky * dilation_y,
                                          ox * stride_x + kx * dilation_x]

    which is the same layout as
    ``F.unfold(imgs, ...).view(B, C, kernel_h, kernel_w, out_h, out_w)``.
    Windows may share memory (whenever the stride is smaller than the dilated
    kernel extent), so treat the result as read-only.

    Args:
        imgs: 4-D tensor in BCHW layout. Its actual strides and storage
            offset are honoured, so non-contiguous inputs and slices work.
        kernel_size: window size, int or 1-/2-tuple ``(h, w)``; must be > 0.
        dilation: spacing between taps inside a window; must be > 0.
        padding: must be 0. A view cannot pad, so pad the input beforehand.
        stride: step between consecutive windows; must be > 0.

    Returns:
        The strided view described above.

    Raises:
        ValueError: if ``imgs`` is not 4-D, ``padding`` is non-zero, any of
            ``kernel_size``/``dilation``/``stride`` is not positive, or the
            output would be empty along a spatial dimension.
    """
    if imgs.ndim != 4:
        raise ValueError(f"imgs must be in BCHW, but {imgs.shape=}")
    krs_y, krs_x = _as_tup(kernel_size)
    dil_y, dil_x = _as_tup(dilation)
    str_y, str_x = _as_tup(stride)
    pad_y, pad_x = _as_tup(padding)

    if pad_x or pad_y:
        msg = ("unfold_view produces a view, and cannot pad."
               " Please perform the padding beforehand.")
        raise ValueError(msg)

    # Non-positive values would divide by zero (stride) or build a
    # meaningless/negative-extent view (kernel_size, dilation).
    for name, (val_y, val_x) in (
        ("kernel_size", (krs_y, krs_x)),
        ("dilation", (dil_y, dil_x)),
        ("stride", (str_y, str_x)),
    ):
        if val_y <= 0 or val_x <= 0:
            raise ValueError(f"{name} must be positive, but got {(val_y, val_x)}")

    # Standard convolution output size with zero padding; integer floor
    # division keeps it exact (no float rounding).
    out_y = (imgs.shape[2] - dil_y * (krs_y - 1) - 1) // str_y + 1
    out_x = (imgs.shape[3] - dil_x * (krs_x - 1) - 1) // str_x + 1
    if out_y <= 0:
        raise ValueError("Output collapsed in y-dimension")
    if out_x <= 0:
        raise ValueError("Output collapsed in x-dimension")

    stride_b, stride_c, stride_h, stride_w = imgs.stride()
    return imgs.as_strided(
        (imgs.shape[0], imgs.shape[1], krs_y, krs_x, out_y, out_x),
        (
            stride_b,
            stride_c,
            stride_h * dil_y,  # next tap inside a window (rows)
            stride_w * dil_x,  # next tap inside a window (cols)
            stride_h * str_y,  # next window position (rows)
            stride_w * str_x,  # next window position (cols)
        ),
        # Explicit so sliced inputs keep their offset regardless of the
        # as_strided default in the installed PyTorch version.
        imgs.storage_offset(),
    )

In [21]:
# 3x3 windows, stride 1: 28 - 3 + 1 = 26 positions per spatial axis.
test_s = unfold_view(test_i, (3, 3))
test_s.shape

torch.Size([4096, 5, 3, 3, 26, 26])


In [11]:
# Reference: F.unfold returns (B, C*kh*kw, L); reshape it to the same 6-D
# (B, C, kh, kw, out_h, out_w) layout that unfold_view produces.
test_u = torch.nn.functional.unfold(test_i, kernel_size=(3, 3)).view((4096, 5, 3, 3, 26, 26))
test_u.shape

torch.Size([4096, 5, 3, 3, 26, 26])

In [12]:
test_s.isclose(test_u).all()  # the view and F.unfold agree element-wise

tensor(True, device='cuda:0')

In [29]:
# Rectangular kernel with a different stride per axis.
test_large = unfold_view(test_i, kernel_size=(7, 6), stride=(1, 2))
test_large.shape

torch.Size([4096, 5, 7, 6, 22, 12])

In [30]:
# Compare against F.unfold on the same input as test_large (test_i), so the
# cell also works on a fresh kernel run top-to-bottom.
torch.allclose(
    torch.nn.functional.unfold(test_i, (7, 6), stride=(1, 2)).view(4096, 5, 7, 6, 22, 12),
    test_large,
)

True

In [31]:
# Mixed per-axis stride and dilation.
test_strange = unfold_view(test_i, kernel_size=(10, 2), dilation=(2, 3), stride=(1, 4))
test_strange.shape

torch.Size([4096, 5, 10, 2, 10, 7])

In [33]:
# Compare against F.unfold on the same input as test_strange (test_i), so the
# cell also works on a fresh kernel run top-to-bottom.
torch.allclose(
    torch.nn.functional.unfold(test_i, (10, 2), stride=(1, 4), dilation=(2, 3))
    .view(4096, 5, 10, 2, 10, 7),
    test_strange,
)


True

In [35]:
# Scratch check: amax can reduce over several dims at once (global max here).
a = torch.arange(12).view(3, 4)
a.amax(dim=(0, 1))

tensor(11)

In [44]:
# Scratch check: with ties (all ones) the gradient of max(dim=0) reaches
# exactly one element per column.
ag = torch.ones((3, 4), requires_grad=True)
ag.max(dim=0).values.sum().backward()
ag.grad

tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [13]:
# Reference gradient: backprop a plain sum through F.unfold
# (7x7 windows, stride 2 -> 11x11 positions on 28x28 input).
test_ig = test_i.clone().requires_grad_(True)
res = torch.nn.functional.unfold(test_ig, kernel_size=7, stride=(2, 2)).view(
    4096, 5, 7, 7, 11, 11
)
res.sum().backward()
ug = test_ig.grad.clone()

In [14]:
# Same sum-then-backward, but through the strided view.
test_ig = test_i.clone().requires_grad_(True)
res = unfold_view(test_ig, kernel_size=7, stride=2)
res.sum().backward()
sg = test_ig.grad.clone()

In [15]:
ug.isclose(sg).all()  # gradients via the view match those via F.unfold

tensor(True, device='cuda:0')

In [19]:
test_ig = test_i.clone().requires_grad_(True)


def run_one():
    padded = torch.constant_pad_nd(
        test_ig, (3, 3, 3, 3), 5
    )
    time_res = torch.nn.functional.unfold(padded, (7, 7), stride=2).view(4096, 5, 7, 7, 14, 14)
    time_res.sum().backward()
    torch.cuda.synchronize()


run_one()
%timeit run_one()

56.2 ms ± 5.82 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [20]:
test_ig = test_i.clone().requires_grad_(True)


def run_one():
    padded = torch.constant_pad_nd(
        test_ig, (3, 3, 3, 3), 5
    )
    time_res = unfold_view(padded, (7, 7), stride=2)
    time_res.sum().backward()
    torch.cuda.synchronize()


run_one()
%timeit run_one()

3.35 ms ± 2.05 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
