In [1]:
import torch
import torch.nn as nn

In [2]:
# Random 4-D example tensor laid out as (batch=1, channels=2, height=4, width=3).
# The name shadows Python's built-in `input`; it is kept because the later cells refer to it.
input = torch.randn((1, 2, 4, 3))
print(input)

tensor([[[[ 0.9708, -0.0897,  0.2432],
          [ 0.3745, -0.3273, -0.4943],
          [ 0.2818, -0.2199, -2.3114],
          [-0.8519,  0.8845,  0.9618]],

         [[ 0.6366, -0.9750, -0.0879],
          [-1.1110,  0.3916,  0.0232],
          [ 0.8274, -0.1599,  1.3982],
          [ 1.7917,  0.2212,  0.4457]]]])


# nn.Unfold
Extracts sliding local blocks from a batched input tensor.

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold

In [3]:
# kernel size 1: (3, 1), i.e. blocks that are 3 rows tall and 1 column wide
unfold_with_size31 = nn.Unfold(kernel_size=(3, 1))
output_with_size31 = unfold_with_size31(input)
# one row per (channel, kernel element) pair, one column per block position
print(output_with_size31, output_with_size31.shape, sep="\n")

tensor([[[ 0.9708, -0.0897,  0.2432,  0.3745, -0.3273, -0.4943],
         [ 0.3745, -0.3273, -0.4943,  0.2818, -0.2199, -2.3114],
         [ 0.2818, -0.2199, -2.3114, -0.8519,  0.8845,  0.9618],
         [ 0.6366, -0.9750, -0.0879, -1.1110,  0.3916,  0.0232],
         [-1.1110,  0.3916,  0.0232,  0.8274, -0.1599,  1.3982],
         [ 0.8274, -0.1599,  1.3982,  1.7917,  0.2212,  0.4457]]])
torch.Size([1, 6, 6])


In [4]:
# Alternative method: torch.Tensor.unfold(dimension, size, step)
# Note that torch.Tensor.unfold slides a window along ONE dimension only
print("1. the output of torch.Tensor.unfold")
# kernel_size = (3, 1) -> windows of length 3 along dimension 2 (the rows)
output_with_unfold = input.unfold(dimension=2, size=3, step=1)
print(output_with_unfold)  # the elements of each window end up in a new, last dimension
print(output_with_unfold.shape)
print("\n2. reshape the output of torch.Tensor.unfold to behave like nn.Unfold")
# (N, C, L_h, W, k) -> (N, C, k, L_h, W) -> (N, C * k, L_h * W); no hard-coded sizes needed
output_like_nn_unfold = output_with_unfold.permute(0, 1, 4, 2, 3).flatten(1, 2).flatten(2)
print(output_like_nn_unfold)
print(output_like_nn_unfold.shape)
# the rearranged tensor must match what nn.Unfold(kernel_size=(3, 1)) produced earlier
assert torch.equal(output_like_nn_unfold, output_with_size31)

1. the output of torch.Tensor.fold
tensor([[[[[ 0.9708,  0.3745,  0.2818],
           [-0.0897, -0.3273, -0.2199],
           [ 0.2432, -0.4943, -2.3114]],

          [[ 0.3745,  0.2818, -0.8519],
           [-0.3273, -0.2199,  0.8845],
           [-0.4943, -2.3114,  0.9618]]],


         [[[ 0.6366, -1.1110,  0.8274],
           [-0.9750,  0.3916, -0.1599],
           [-0.0879,  0.0232,  1.3982]],

          [[-1.1110,  0.8274,  1.7917],
           [ 0.3916, -0.1599,  0.2212],
           [ 0.0232,  1.3982,  0.4457]]]]])
torch.Size([1, 2, 2, 3, 3])

2. reshape the output of torch.Tensor.fold to behave like nn.unfold
tensor([[[ 0.9708, -0.0897,  0.2432,  0.3745, -0.3273, -0.4943],
         [ 0.3745, -0.3273, -0.4943,  0.2818, -0.2199, -2.3114],
         [ 0.2818, -0.2199, -2.3114, -0.8519,  0.8845,  0.9618],
         [ 0.6366, -0.9750, -0.0879, -1.1110,  0.3916,  0.0232],
         [-1.1110,  0.3916,  0.0232,  0.8274, -0.1599,  1.3982],
         [ 0.8274, -0.1599,  1.3982,  1.7917,  0.2212,  0.4457]]])
torch.Size([1, 6, 6])

In [5]:
# kernel size 2: a single int is used for both the height and the width of the block
unfold_with_size2 = nn.Unfold(kernel_size=2)
output_with_size2 = unfold_with_size2(input)
print(f"{output_with_size2}\n{output_with_size2.shape}")

tensor([[[ 0.9708, -0.0897,  0.3745, -0.3273,  0.2818, -0.2199],
         [-0.0897,  0.2432, -0.3273, -0.4943, -0.2199, -2.3114],
         [ 0.3745, -0.3273,  0.2818, -0.2199, -0.8519,  0.8845],
         [-0.3273, -0.4943, -0.2199, -2.3114,  0.8845,  0.9618],
         [ 0.6366, -0.9750, -1.1110,  0.3916,  0.8274, -0.1599],
         [-0.9750, -0.0879,  0.3916,  0.0232, -0.1599,  1.3982],
         [-1.1110,  0.3916,  0.8274, -0.1599,  1.7917,  0.2212],
         [ 0.3916,  0.0232, -0.1599,  1.3982,  0.2212,  0.4457]]])
torch.Size([1, 8, 6])


In [6]:
# kernel size 3: (2, 2)
# the tuple spells out what kernel_size=2 means, so the output is the same
unfold_with_size22 = nn.Unfold(kernel_size=(2, 2))
output_with_size22 = unfold_with_size22(input)
print(output_with_size22, output_with_size22.shape, sep="\n")

tensor([[[ 0.9708, -0.0897,  0.3745, -0.3273,  0.2818, -0.2199],
         [-0.0897,  0.2432, -0.3273, -0.4943, -0.2199, -2.3114],
         [ 0.3745, -0.3273,  0.2818, -0.2199, -0.8519,  0.8845],
         [-0.3273, -0.4943, -0.2199, -2.3114,  0.8845,  0.9618],
         [ 0.6366, -0.9750, -1.1110,  0.3916,  0.8274, -0.1599],
         [-0.9750, -0.0879,  0.3916,  0.0232, -0.1599,  1.3982],
         [-1.1110,  0.3916,  0.8274, -0.1599,  1.7917,  0.2212],
         [ 0.3916,  0.0232, -0.1599,  1.3982,  0.2212,  0.4457]]])
torch.Size([1, 8, 6])


In [7]:
# kernel size 4: (2, 3)
# the block is as wide as the input (3 columns), so it can only slide down the rows
unfold_with_size23 = nn.Unfold(kernel_size=(2, 3))
output_with_size23 = unfold_with_size23(input)
print("\n".join(str(part) for part in (output_with_size23, output_with_size23.shape)))

tensor([[[ 0.9708,  0.3745,  0.2818],
         [-0.0897, -0.3273, -0.2199],
         [ 0.2432, -0.4943, -2.3114],
         [ 0.3745,  0.2818, -0.8519],
         [-0.3273, -0.2199,  0.8845],
         [-0.4943, -2.3114,  0.9618],
         [ 0.6366, -1.1110,  0.8274],
         [-0.9750,  0.3916, -0.1599],
         [-0.0879,  0.0232,  1.3982],
         [-1.1110,  0.8274,  1.7917],
         [ 0.3916, -0.1599,  0.2212],
         [ 0.0232,  1.3982,  0.4457]]])
torch.Size([1, 12, 3])


# how to reimplement it?

In [8]:
def unfold_manual(tensor, kernel_size):
    """Re-implement nn.Unfold with an explicit sliding window.

    Only the default settings are covered: stride=1, padding=0, dilation=1.

    Args:
        tensor: 4-D tensor of shape (N, C, H, W).
        kernel_size: (kernel_height, kernel_width) of the sliding block.

    Returns:
        Tensor of shape (N, C * kernel_height * kernel_width, L), where L is
        the number of block positions, visited row by row.

    Raises:
        RuntimeError: from torch.stack when the block does not fit into the
            input even once, because there is nothing to stack.
    """
    kernel_h, kernel_w = kernel_size
    patches = []
    # plain Python ints as indices instead of 0-dim tensors from torch.arange
    for top in range(tensor.size(2) - kernel_h + 1):
        for left in range(tensor.size(3) - kernel_w + 1):
            # index current patch
            patch = tensor[:, :, top:top + kernel_h, left:left + kernel_w]
            # flatten and keep batch dim -> shape (N, C * kernel_h * kernel_w)
            patches.append(patch.reshape(patch.size(0), -1))
    # stack the flattened patches along a new last dimension
    return torch.stack(patches, dim=2)


kernel_size = [2, 3]
output_manual = unfold_manual(input, kernel_size)

# show every flattened patch, i.e. every column of the unfolded result
for position in range(output_manual.size(2)):
    flat_patch = output_manual[:, :, position]
    print(flat_patch)
    print(flat_patch.shape)

# compare
print((output_manual == output_with_size23).all())
# > tensor(True)

tensor([[ 0.9708, -0.0897,  0.2432,  0.3745, -0.3273, -0.4943,  0.6366, -0.9750,
         -0.0879, -1.1110,  0.3916,  0.0232]])
torch.Size([1, 12])
tensor([[ 0.3745, -0.3273, -0.4943,  0.2818, -0.2199, -2.3114, -1.1110,  0.3916,
          0.0232,  0.8274, -0.1599,  1.3982]])
torch.Size([1, 12])
tensor([[ 0.2818, -0.2199, -2.3114, -0.8519,  0.8845,  0.9618,  0.8274, -0.1599,
          1.3982,  1.7917,  0.2212,  0.4457]])
torch.Size([1, 12])
tensor(True)


In [9]:
# Display the manually unfolded tensor; a bare last expression is rendered as the cell's output.
output_manual

tensor([[[ 0.9708,  0.3745,  0.2818],
         [-0.0897, -0.3273, -0.2199],
         [ 0.2432, -0.4943, -2.3114],
         [ 0.3745,  0.2818, -0.8519],
         [-0.3273, -0.2199,  0.8845],
         [-0.4943, -2.3114,  0.9618],
         [ 0.6366, -1.1110,  0.8274],
         [-0.9750,  0.3916, -0.1599],
         [-0.0879,  0.0232,  1.3982],
         [-1.1110,  0.8274,  1.7917],
         [ 0.3916, -0.1599,  0.2212],
         [ 0.0232,  1.3982,  0.4457]]])

# how to reshape the output of nn.Unfold back into the per-channel patches a convolution would see

In [10]:
# (N, C * kh * kw, L) -> (N, C, kh * kw, L): split the rows into channels and kernel elements
per_channel = output_with_size23.reshape(1, 2, 6, 3)
# (N, C, kh * kw, L) -> (N, C, L, kh * kw): put the block position before the kernel elements
per_position = per_channel.permute(0, 1, 3, 2)
# (N, C, L, kh * kw) -> (N, C, L, kh, kw): restore the 2 x 3 layout of every patch
output_like_convolution = per_position.reshape(1, 2, -1, 2, 3)
print(output_like_convolution, output_like_convolution.shape, sep="\n")

tensor([[[[[ 0.9708, -0.0897,  0.2432],
           [ 0.3745, -0.3273, -0.4943]],

          [[ 0.3745, -0.3273, -0.4943],
           [ 0.2818, -0.2199, -2.3114]],

          [[ 0.2818, -0.2199, -2.3114],
           [-0.8519,  0.8845,  0.9618]]],


         [[[ 0.6366, -0.9750, -0.0879],
           [-1.1110,  0.3916,  0.0232]],

          [[-1.1110,  0.3916,  0.0232],
           [ 0.8274, -0.1599,  1.3982]],

          [[ 0.8274, -0.1599,  1.3982],
           [ 1.7917,  0.2212,  0.4457]]]]])
torch.Size([1, 2, 3, 2, 3])


# reference
https://discuss.pytorch.org/t/how-nn-unfold-works/137349