In [2]:
import numpy as np
import torch

### Testing how `torch.nn.Unfold` works

In [9]:
# Toy image of shape [1, 3, h, w]. Every pixel holds its flat index (row * w + col),
# so the contents of any patch can be read straight off the printed values.
# Channel 1 is scaled by -10 so it is distinguishable from channels 0 and 2.
h, w = 10, 20
base = np.arange(h * w).reshape(1, 1, h, w)
img = torch.from_numpy(np.tile(base, [1, 3, 1, 1])).type(torch.double)
img[0, 1] *= -10
print(img.shape)
print(img[0, :, 1, 1])

torch.Size([1, 3, 10, 20])
tensor([  21., -210.,   21.], dtype=torch.float64)


In [11]:
# Extract psize x psize patches with dilation 2 (kernel taps 2 px apart), no padding
# and stride 1. The settings are named rather than left as bare literals.
# Unfold flattens every patch into the channel axis: the output is
# [B, C * psize * psize, L], where L is the number of window positions.
psize = 3
dilation, padding, stride = 2, 0, 1
unfold = torch.nn.Unfold(
    kernel_size=(psize, psize), dilation=dilation, padding=padding, stride=stride
)
output = unfold(img)
print(output.shape)

torch.Size([1, 27, 96])


In [13]:
# torch.nn.Unfold returns [B, C * psize * psize, L], where L is the number of
# sliding-window positions. Reshape to [B, C, psize, psize, out_h, out_w] so that
# a patch can be addressed by its spatial position.
# Read the settings back from `unfold` instead of repeating the literal dilation
# (assumes scalar int settings, as passed when `unfold` was built).
dilation, padding, stride = unfold.dilation, unfold.padding, unfold.stride
effective_psize = 1 + dilation * (psize - 1)  # side length covered by one dilated kernel
ofs = (effective_psize - 1) // 2  # half-extent of the dilated kernel
# Sliding-window output size (same formula as the torch.nn.Unfold docs). For
# padding=0, stride=1 and an odd effective_psize this equals h - 2 * ofs, but unlike
# that shortcut it stays correct for even kernels, padding and stride > 1.
out_h = (h + 2 * padding - effective_psize) // stride + 1
out_w = (w + 2 * padding - effective_psize) // stride + 1
assert out_h * out_w == output.shape[-1], "window count does not match unfold output"
batch, channels = img.shape[:2]
patches = output.view(batch, channels, psize, psize, out_h, out_w)
print(patches.shape)
print(patches[0, :, :, :, 0, 0])  # top-left patch

torch.Size([1, 3, 3, 3, 6, 16])
tensor([[[   0.,    2.,    4.],
         [  40.,   42.,   44.],
         [  80.,   82.,   84.]],

        [[  -0.,  -20.,  -40.],
         [-400., -420., -440.],
         [-800., -820., -840.]],

        [[   0.,    2.,    4.],
         [  40.,   42.,   44.],
         [  80.,   82.,   84.]]], dtype=torch.float64)


In [14]:
top_right = patches[0, :, :, :, 0, -1]
print(top_right)  # top-right patch
# Cross-check (valid for padding=0, stride=1): the first-row, last-column window starts
# at row 0 and ends at the right edge, sampling every `dilation`-th pixel.
assert torch.equal(
    top_right,
    img[0, :, :effective_psize:unfold.dilation, -effective_psize::unfold.dilation],
)

tensor([[[  15.,   17.,   19.],
         [  55.,   57.,   59.],
         [  95.,   97.,   99.]],

        [[-150., -170., -190.],
         [-550., -570., -590.],
         [-950., -970., -990.]],

        [[  15.,   17.,   19.],
         [  55.,   57.,   59.],
         [  95.,   97.,   99.]]], dtype=torch.float64)


In [16]:
print(img[0, :, -effective_psize:, -effective_psize:])  # bottom-right corner region of the image
bottom_right = patches[0, :, :, :, -1, -1]
print(bottom_right)  # bottom-right patch
# With padding=0 and stride=1 the last window ends exactly at the bottom-right
# corner, so its patch is every `dilation`-th pixel of that corner region.
assert torch.equal(
    bottom_right,
    img[0, :, -effective_psize::unfold.dilation, -effective_psize::unfold.dilation],
)

tensor([[[  115.,   116.,   117.,   118.,   119.],
         [  135.,   136.,   137.,   138.,   139.],
         [  155.,   156.,   157.,   158.,   159.],
         [  175.,   176.,   177.,   178.,   179.],
         [  195.,   196.,   197.,   198.,   199.]],

        [[-1150., -1160., -1170., -1180., -1190.],
         [-1350., -1360., -1370., -1380., -1390.],
         [-1550., -1560., -1570., -1580., -1590.],
         [-1750., -1760., -1770., -1780., -1790.],
         [-1950., -1960., -1970., -1980., -1990.]],

        [[  115.,   116.,   117.,   118.,   119.],
         [  135.,   136.,   137.,   138.,   139.],
         [  155.,   156.,   157.,   158.,   159.],
         [  175.,   176.,   177.,   178.,   179.],
         [  195.,   196.,   197.,   198.,   199.]]], dtype=torch.float64)
tensor([[[  115.,   117.,   119.],
         [  155.,   157.,   159.],
         [  195.,   197.,   199.]],

        [[-1150., -1170., -1190.],
         [-1550., -1570., -1590.],
         [-1950., -1970., -1990.]],

        [[  115.,   117.,   119.],
         [  155.,   157.,   159.],
         [  195.,   197.,   199.]]], dtype=torch.float64)

#### Permute axes to get shape `[B, 3, h - 2 * ofs, w - 2 * ofs, psize, psize]`

In [17]:
# Bring the spatial axes in front of the kernel axes:
# [B, C, psize, psize, out_h, out_w] -> [B, C, out_h, out_w, psize, psize]
axis_order = (0, 1, 4, 5, 2, 3)
patches_t = patches.permute(*axis_order)

In [18]:
top_left_t = patches_t[0, :, 0, 0, :, :]
print(top_left_t)  # top-left
# permute only reorders axes, so this must be the same patch as patches[0, :, :, :, 0, 0].
assert torch.equal(top_left_t, patches[0, :, :, :, 0, 0])

tensor([[[   0.,    2.,    4.],
         [  40.,   42.,   44.],
         [  80.,   82.,   84.]],

        [[  -0.,  -20.,  -40.],
         [-400., -420., -440.],
         [-800., -820., -840.]],

        [[   0.,    2.,    4.],
         [  40.,   42.,   44.],
         [  80.,   82.,   84.]]], dtype=torch.float64)


In [20]:
top_right_t = patches_t[0, :, 0, -1, :, :]
print(top_right_t)  # top-right patch
# permute only reorders axes, so this must be the same patch as patches[0, :, :, :, 0, -1].
assert torch.equal(top_right_t, patches[0, :, :, :, 0, -1])

tensor([[[  15.,   17.,   19.],
         [  55.,   57.,   59.],
         [  95.,   97.,   99.]],

        [[-150., -170., -190.],
         [-550., -570., -590.],
         [-950., -970., -990.]],

        [[  15.,   17.,   19.],
         [  55.,   57.,   59.],
         [  95.,   97.,   99.]]], dtype=torch.float64)
