In [1]:
import torch
import torch.nn as nn

In [13]:
# Random example batch laid out as (N, C, H, W) = (1, 2, 3, 4).
# NOTE: the name `input` shadows the Python builtin; it is kept because the
# cells below refer to it.
input = torch.randn((1, 2, 3, 4))
print(input)

tensor([[[[-0.8331,  0.7040, -1.4139, -0.6550],
          [ 0.0994, -2.4464, -0.0328, -0.1848],
          [-0.5071, -0.5411, -0.6181, -0.4922]],

         [[-2.0584,  0.5070, -1.6810,  1.1463],
          [-1.0809, -2.6029, -0.8712,  0.0188],
          [ 0.1240,  0.3978, -1.2131,  0.4445]]]])


# nn.Unfold
Extracts sliding local blocks from a batched input tensor.

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold

In [14]:
# 2x3 sliding window; dilation, padding and stride are written out at their
# default values so the window geometry is explicit.
unfold = nn.Unfold(kernel_size=(2, 3), dilation=1, padding=0, stride=1)

In [15]:
# Each column of `output` is one flattened 2x3 window taken over both channels:
# (N, C * kh * kw, L) = (1, 2 * 2 * 3, 4) for the input above.
output = unfold(input)
print(output, output.shape, sep="\n")

tensor([[[-0.8331,  0.7040,  0.0994, -2.4464],
         [ 0.7040, -1.4139, -2.4464, -0.0328],
         [-1.4139, -0.6550, -0.0328, -0.1848],
         [ 0.0994, -2.4464, -0.5071, -0.5411],
         [-2.4464, -0.0328, -0.5411, -0.6181],
         [-0.0328, -0.1848, -0.6181, -0.4922],
         [-2.0584,  0.5070, -1.0809, -2.6029],
         [ 0.5070, -1.6810, -2.6029, -0.8712],
         [-1.6810,  1.1463, -0.8712,  0.0188],
         [-1.0809, -2.6029,  0.1240,  0.3978],
         [-2.6029, -0.8712,  0.3978, -1.2131],
         [-0.8712,  0.0188, -1.2131,  0.4445]]])
torch.Size([1, 12, 4])


# How can nn.Unfold be reimplemented by hand?

In [17]:
output_manual = []
kernel_size = [2, 3]
# Sliding-window approach: visit every top-left corner (i, j) at which a
# kernel_size window still fits inside the input (stride 1, no padding,
# no dilation), walking rows first and columns second.
for i in range(input.size(2) - kernel_size[0] + 1):
    for j in range(input.size(3) - kernel_size[1] + 1):
        # index the current patch across all batch entries and channels
        patch = input[:, :, i:i + kernel_size[0], j:j + kernel_size[1]]
        # flatten channels and window into one axis while keeping the batch dim;
        # the result has shape [N, C * kh * kw], i.e. [1, 12] for this input
        patch = patch.reshape(patch.size(0), -1)
        output_manual.append(patch)
        print(patch)
        print(patch.shape)

# stack the flattened patches along a new last dim -> [N, C * kh * kw, L]
output_manual = torch.stack(output_manual, dim=2)

# compare (exact equality is fine here: values are only copied, never computed)
print((output_manual == output).all())
# > tensor(True)

tensor([[-0.8331,  0.7040, -1.4139,  0.0994, -2.4464, -0.0328, -2.0584,  0.5070,
         -1.6810, -1.0809, -2.6029, -0.8712]])
torch.Size([1, 12])
tensor([[ 0.7040, -1.4139, -0.6550, -2.4464, -0.0328, -0.1848,  0.5070, -1.6810,
          1.1463, -2.6029, -0.8712,  0.0188]])
torch.Size([1, 12])
tensor([[ 0.0994, -2.4464, -0.0328, -0.5071, -0.5411, -0.6181, -1.0809, -2.6029,
         -0.8712,  0.1240,  0.3978, -1.2131]])
torch.Size([1, 12])
tensor([[-2.4464, -0.0328, -0.1848, -0.5411, -0.6181, -0.4922, -2.6029, -0.8712,
          0.0188,  0.3978, -1.2131,  0.4445]])
torch.Size([1, 12])
tensor(True)


In [18]:
# Display the hand-built result so it can be compared with the nn.Unfold output above.
output_manual

tensor([[[-0.8331,  0.7040,  0.0994, -2.4464],
         [ 0.7040, -1.4139, -2.4464, -0.0328],
         [-1.4139, -0.6550, -0.0328, -0.1848],
         [ 0.0994, -2.4464, -0.5071, -0.5411],
         [-2.4464, -0.0328, -0.5411, -0.6181],
         [-0.0328, -0.1848, -0.6181, -0.4922],
         [-2.0584,  0.5070, -1.0809, -2.6029],
         [ 0.5070, -1.6810, -2.6029, -0.8712],
         [-1.6810,  1.1463, -0.8712,  0.0188],
         [-1.0809, -2.6029,  0.1240,  0.3978],
         [-2.6029, -0.8712,  0.3978, -1.2131],
         [-0.8712,  0.0188, -1.2131,  0.4445]]])

# How to reshape the output of nn.Unfold into the per-channel (kernel_h, kernel_w) patches a convolution slides over

In [27]:
# Derive every dimension from the tensors/modules already defined instead of
# hard-coding (1, 2, 6, 4) and (2, 3), so the cell keeps working when the input
# shape or the kernel size is changed.
batch_size, channels = input.shape[:2]
kernel_h, kernel_w = unfold.kernel_size  # kernel_size was given as an (h, w) tuple

# (N, C * kh * kw, L)
#   -> (N, C, kh * kw, L)   split the channel axis off the flattened window
#   -> (N, C, L, kh * kw)   put the patch index before the window values
#   -> (N, C, L, kh, kw)    restore the 2-D window
output_like_convolution = (
    output.reshape(batch_size, channels, kernel_h * kernel_w, -1)
    .transpose(-1, -2)
    .reshape(batch_size, channels, -1, kernel_h, kernel_w)
)
print(output_like_convolution)
print(output_like_convolution.shape)

tensor([[[[[-0.8331,  0.7040, -1.4139],
           [ 0.0994, -2.4464, -0.0328]],

          [[ 0.7040, -1.4139, -0.6550],
           [-2.4464, -0.0328, -0.1848]],

          [[ 0.0994, -2.4464, -0.0328],
           [-0.5071, -0.5411, -0.6181]],

          [[-2.4464, -0.0328, -0.1848],
           [-0.5411, -0.6181, -0.4922]]],


         [[[-2.0584,  0.5070, -1.6810],
           [-1.0809, -2.6029, -0.8712]],

          [[ 0.5070, -1.6810,  1.1463],
           [-2.6029, -0.8712,  0.0188]],

          [[-1.0809, -2.6029, -0.8712],
           [ 0.1240,  0.3978, -1.2131]],

          [[-2.6029, -0.8712,  0.0188],
           [ 0.3978, -1.2131,  0.4445]]]]])
torch.Size([1, 2, 4, 2, 3])


# reference
https://discuss.pytorch.org/t/how-nn-unfold-works/137349