In [1]:
import torch
import torch.nn as nn

In [2]:
# Toy (N=1, C=2, H=3, W=4) tensor holding the values 0..23, so every element
# is easy to follow through the unfolding below.
# NOTE: `input` shadows the Python builtin; kept because later cells use this name.
input = torch.arange(0, 24, dtype=torch.float32).view(1, 2, 3, 4)
print(input)

tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.]],

         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.]]]])


# nn.Unfold
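Extracts sliding local blocks from a batched input tensor. For an input of shape $(N, C, H, W)$ it returns a tensor of shape $(N, C \times \prod(\text{kernel\_size}), L)$, where $L$ is the number of block positions.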

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold

- example1: kernel_size = (2, 3)

In [3]:
# Kernel of 2 rows x 3 columns; stride, padding and dilation keep their defaults.
size = 2, 3
unfold_size23 = nn.Unfold(kernel_size=size)
output_size23 = unfold_size23(input)
# Show the unfolded tensor, then its shape, each on its own line.
print(output_size23, output_size23.shape, sep="\n")

tensor([[[ 0.,  1.,  4.,  5.],
         [ 1.,  2.,  5.,  6.],
         [ 2.,  3.,  6.,  7.],
         [ 4.,  5.,  8.,  9.],
         [ 5.,  6.,  9., 10.],
         [ 6.,  7., 10., 11.],
         [12., 13., 16., 17.],
         [13., 14., 17., 18.],
         [14., 15., 18., 19.],
         [16., 17., 20., 21.],
         [17., 18., 21., 22.],
         [18., 19., 22., 23.]]])
torch.Size([1, 12, 4])


In [4]:
# details in unfolding with kernel_size = (2, 3)
# Column `block_idx` of the Unfold output is the block whose top-left corner is at
# (top, left). Blocks are enumerated row by row (stride 1, no padding, no dilation),
# so the offsets can be derived instead of being written out by hand for each step.
# Only channel 0 is shown: its values occupy the first kh * kw rows of the output.
n_block_cols = input.size(3) - size[1] + 1  # block positions along W
for block_idx in range(output_size23.size(2)):
    top, left = divmod(block_idx, n_block_cols)
    print(f"---------- step{block_idx + 1} ----------")
    print(f"block from input:\n{input[:, [0], top:top + size[0], left:left + size[1]]}")
    print(f"block in output:\n{output_size23[:, :size[0] * size[1], [block_idx]]}\n")

---------- step1 ----------
block from input:
tensor([[[[0., 1., 2.],
          [4., 5., 6.]]]])
block in output:
tensor([[[0.],
         [1.],
         [2.],
         [4.],
         [5.],
         [6.]]])

---------- step2 ----------
block from input:
tensor([[[[1., 2., 3.],
          [5., 6., 7.]]]])
block in output:
tensor([[[1.],
         [2.],
         [3.],
         [5.],
         [6.],
         [7.]]])

---------- step3 ----------
block from input:
tensor([[[[ 4.,  5.,  6.],
          [ 8.,  9., 10.]]]])
block in output:
tensor([[[ 4.],
         [ 5.],
         [ 6.],
         [ 8.],
         [ 9.],
         [10.]]])

---------- step4 ----------
block from input:
tensor([[[[ 5.,  6.,  7.],
          [ 9., 10., 11.]]]])
block in output:
tensor([[[ 5.],
         [ 6.],
         [ 7.],
         [ 9.],
         [10.],
         [11.]]])



- example2: kernel_size = (2, 2)

In [5]:
# Show the input again next to its (N, C, H, W) shape before the 2x2 example.
print(input, input.shape, sep="\n")

tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.]],

         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.]]]])
torch.Size([1, 2, 3, 4])


In [6]:
# Square 2x2 kernel: each block holds C * 2 * 2 = 8 values, and there are
# (3 - 2 + 1) * (4 - 2 + 1) = 6 block positions.
unfold_size22 = nn.Unfold(kernel_size=(2, 2))
output_size22 = unfold_size22(input)
print(output_size22, output_size22.shape, sep="\n")

tensor([[[ 0.,  1.,  2.,  4.,  5.,  6.],
         [ 1.,  2.,  3.,  5.,  6.,  7.],
         [ 4.,  5.,  6.,  8.,  9., 10.],
         [ 5.,  6.,  7.,  9., 10., 11.],
         [12., 13., 14., 16., 17., 18.],
         [13., 14., 15., 17., 18., 19.],
         [16., 17., 18., 20., 21., 22.],
         [17., 18., 19., 21., 22., 23.]]])
torch.Size([1, 8, 6])


- example3: kernel_size = 2 (a single int is used for both the kernel height and width)

In [7]:
# An int kernel_size is used for both spatial dims, so kernel_size=2 should
# produce exactly the same result as kernel_size=(2, 2).
unfold_size2 = nn.Unfold(kernel_size=2)
output_size2 = unfold_size2(input)
# torch.equal also requires matching shapes; `(a == b).all()` would broadcast
# mismatched shapes (possibly reporting a false match) or raise instead.
print(f"kernel_size=2 has the same result as kernel_size=(2, 2): {torch.equal(output_size2, output_size22)}\n")
print(output_size2)
print(output_size2.shape)

kernel_size=2 has the same result as kernel_size=(2, 2): True

tensor([[[ 0.,  1.,  2.,  4.,  5.,  6.],
         [ 1.,  2.,  3.,  5.,  6.,  7.],
         [ 4.,  5.,  6.,  8.,  9., 10.],
         [ 5.,  6.,  7.,  9., 10., 11.],
         [12., 13., 14., 16., 17., 18.],
         [13., 14., 15., 17., 18., 19.],
         [16., 17., 18., 20., 21., 22.],
         [17., 18., 19., 21., 22., 23.]]])
torch.Size([1, 8, 6])


# How to reimplement nn.Unfold manually?

In [8]:
output_manual = []
# Reuse the kernel passed to nn.Unfold above so the comparison below stays valid
# if `size` is changed.
kernel_size = size
n_rows = input.size(2) - kernel_size[0] + 1  # block positions along H (stride 1, no padding)
n_cols = input.size(3) - kernel_size[1] + 1  # block positions along W
# sliding window approach: visit blocks row by row, the same order nn.Unfold uses
for i in range(n_rows):
    for j in range(n_cols):
        # index current patch: shape (N, C, kh, kw)
        patch = input[:, :, i:i + kernel_size[0], j:j + kernel_size[1]]
        # flatten channel and kernel dims but keep the batch dim:
        # shape (N, C * kh * kw), i.e. [1, 12] here
        patch = patch.reshape(patch.size(0), -1)
        output_manual.append(patch)
        print(patch)
        print(patch.shape)

# stack the flattened patches along dim 2 -> (N, C * kh * kw, number_of_blocks)
output_manual = torch.stack(output_manual, dim=2)

# compare
print((output_manual == output_size23).all())
# > tensor(True)
assert torch.equal(output_manual, output_size23), "manual unfold differs from nn.Unfold"

tensor([[ 0.,  1.,  2.,  4.,  5.,  6., 12., 13., 14., 16., 17., 18.]])
torch.Size([1, 12])
tensor([[ 1.,  2.,  3.,  5.,  6.,  7., 13., 14., 15., 17., 18., 19.]])
torch.Size([1, 12])
tensor([[ 4.,  5.,  6.,  8.,  9., 10., 16., 17., 18., 20., 21., 22.]])
torch.Size([1, 12])
tensor([[ 5.,  6.,  7.,  9., 10., 11., 17., 18., 19., 21., 22., 23.]])
torch.Size([1, 12])
tensor(True)


In [9]:
# Rich display of the manually unfolded result; the previous cell confirmed it equals output_size23.
output_manual

tensor([[[ 0.,  1.,  4.,  5.],
         [ 1.,  2.,  5.,  6.],
         [ 2.,  3.,  6.,  7.],
         [ 4.,  5.,  8.,  9.],
         [ 5.,  6.,  9., 10.],
         [ 6.,  7., 10., 11.],
         [12., 13., 16., 17.],
         [13., 14., 17., 18.],
         [14., 15., 18., 19.],
         [16., 17., 20., 21.],
         [17., 18., 21., 22.],
         [18., 19., 22., 23.]]])

# How to reshape the output of nn.Unfold into per-channel patches, as a convolution sees them

In [10]:
# nn.Unfold returns (N, C * kh * kw, L). Its C * kh * kw rows are channel-major:
# all kh * kw kernel positions of channel 0 come first, then channel 1, and so on.
# Regroup them into one (kh, kw) patch per channel and block position, as a
# convolution would see them. Dimensions are derived instead of hard-coded.
batch, channels = input.shape[:2]
kh, kw = size
n_blocks = output_size23.size(-1)
output_like_convolution = (
    output_size23
    .reshape(batch, channels, kh * kw, n_blocks)  # split channel from kernel position
    .transpose(-1, -2)                            # -> (N, C, L, kh * kw)
    .reshape(batch, channels, n_blocks, kh, kw)   # -> (N, C, L, kh, kw)
)
print(output_like_convolution)
print(output_like_convolution.shape)

tensor([[[[[ 0.,  1.,  2.],
           [ 4.,  5.,  6.]],

          [[ 1.,  2.,  3.],
           [ 5.,  6.,  7.]],

          [[ 4.,  5.,  6.],
           [ 8.,  9., 10.]],

          [[ 5.,  6.,  7.],
           [ 9., 10., 11.]]],


         [[[12., 13., 14.],
           [16., 17., 18.]],

          [[13., 14., 15.],
           [17., 18., 19.]],

          [[16., 17., 18.],
           [20., 21., 22.]],

          [[17., 18., 19.],
           [21., 22., 23.]]]]])
torch.Size([1, 2, 4, 2, 3])


# Reference
https://discuss.pytorch.org/t/how-nn-unfold-works/137349