In [24]:
import torch
import torch.nn as nn

In [25]:
# Problem sizes shared (as globals) by every function in this notebook.
# Tensor extents: batch, input channels, output channels.
batch_size, in_channels, out_channels = 1, 2, 3
# Spatial size of the unpadded input.
height, width = 4, 4
# Number of rank-1 terms per output channel.
rank = 2
# Convolution geometry; note kernel_size == 2 * padding + 1 with these values.
kernel_size, stride, padding = 3, 1, 1

In [26]:
def ori_calcul(input, C):
    """Reference "head" stage: per-pixel channel mixing done as one matmul.

    The input is zero-padded by the notebook global ``padding`` on all four
    spatial sides, moved to channels-last layout, flattened to one row per
    pixel and multiplied by ``C`` flattened to (in_channels, out_channels * rank).

    All sizes are read from the tensors themselves rather than from the
    notebook globals, so a tensor whose batch/channel split differs from the
    globals is no longer silently mis-reshaped.

    Args:
        input: tensor of shape (batch, in_channels, height, width).
        C: tensor of shape (in_channels, out_channels, rank).

    Returns:
        Tensor of shape (batch, padded_h, padded_w, out_channels, rank).
    """
    n_batch, n_in = input.shape[0], input.shape[1]
    n_out, n_rank = C.shape[1], C.shape[2]

    padded_I = nn.functional.pad(input, pad=[padding] * 4)
    # NCHW -> NHWC so that each row of the flattened matrix is one pixel.
    padded_I = padded_I.permute(0, 2, 3, 1)
    padded_h, padded_w = padded_I.shape[1], padded_I.shape[2]

    padded_I_col = padded_I.reshape(n_batch * padded_h * padded_w, n_in)
    # Keep C's own leading axis; matmul raises if it does not match n_in.
    weight_col = C.reshape(C.shape[0], n_out * n_rank)
    output = torch.matmul(padded_I_col, weight_col).reshape(n_batch, padded_h, padded_w, n_out, n_rank)

    return output

In [27]:
def my_calcul(input, C):
    """Conv2d-based counterpart of ``ori_calcul``.

    For every output channel a bias-free 1x1 ``nn.Conv2d`` (in_channels -> rank)
    is created, its weight is overwritten with the matching slice of ``C``, and
    the convolution result is copied into that channel's slot of the result.

    Sizes come from the notebook globals ``batch_size``, ``height``, ``width``,
    ``in_channels``, ``out_channels``, ``rank`` and ``padding``.

    Returns:
        Tensor of shape
        (batch_size, height + 2*padding, width + 2*padding, out_channels, rank).
    """
    padded_h = height + 2 * padding
    padded_w = width + 2 * padding
    result = torch.zeros(batch_size, padded_h, padded_w, out_channels, rank)

    for ch in range(out_channels):
        pointwise = nn.Conv2d(in_channels, rank, kernel_size=1, padding=padding, bias=False)
        # C[:, ch, :] is (in_channels, rank); the conv weight is (rank, in_channels, 1, 1).
        pointwise.weight.data = C[:, ch, :].t()[:, :, None, None]
        # NCHW -> NHWC so the rank axis ends up last.
        result[:, :, :, ch] = pointwise(input).permute(0, 2, 3, 1)

    return result


In [28]:
# Random test data; no seed is set, so the values change from run to run.
input = torch.rand((batch_size, in_channels, height, width))
C = torch.rand((in_channels, out_channels, rank))

# Run the matmul reference and the Conv2d-based version on the same data.
ori_out = ori_calcul(input, C)
my_out = my_calcul(input, C).detach().clone()
# NOTE(review): a 1e-12 tolerance is presumably far below float32 resolution,
# so this check only passes when the results are effectively identical.
print(((ori_out - my_out).abs() < 1e-12).all())
print(ori_out == my_out)

tensor(True)
tensor([[[[[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]]],


         [[[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]]],


         [[[True, True],
           [True, True],
         

In [29]:
# Shape check: ori_calcul returns (batch, padded_h, padded_w, out_channels, rank).
ori_out.shape

torch.Size([1, 6, 6, 3, 2])

In [30]:
def ori_body(input, B):
    """Reference "body" stage: slide the 1-D kernels in ``B`` along the width axis.

    The window length and step come from the notebook globals ``kernel_size``
    and ``stride``.

    Args:
        input: 5-D tensor whose axis 2 is the (padded) width. The notebook
            passes the output of ``ori_calcul``, shaped
            (batch, padded_h, padded_w, out_channels, rank).
        B: tensor of shape (out_channels, rank, kernel_size).

    Returns:
        Tensor of shape (batch, padded_h, out_channels, rank, n_windows), where
        ``n_windows`` is the number of window start positions.
    """
    # Move the width axis last: (batch, padded_h, out_channels, rank, padded_w).
    Oc = input.permute(0, 1, 3, 4, 2)
    padded_w = Oc.size(-1)

    # Add broadcast axes to B: (1, 1, out_channels, rank, 1, kernel_size).
    B_expanded = B[None, None, :, :, None, :]

    # Start a window at every stride-th position where a full kernel_size-long
    # window still fits. The previous bound (padded_w - 2 * padding) was only
    # right when kernel_size == 2 * padding + 1 and otherwise indexed out of
    # range or dropped windows.
    starts = torch.arange(start=0, end=padded_w - kernel_size + 1, step=stride)
    # (n_windows, kernel_size) table of width indices, one row per window.
    window_indices = starts[:, None] + torch.arange(kernel_size)
    # Gather the windows: (batch, padded_h, out_channels, rank, n_windows, kernel_size).
    Oc_expanded = Oc[:, :, :, :, window_indices]

    # Weight each window by the kernel and sum over the kernel axis.
    output = torch.sum(Oc_expanded * B_expanded, dim=-1)

    return output

In [31]:
# Random 1-D kernels of length kernel_size, one per (out_channel, rank) pair.
B = torch.rand(out_channels, rank, kernel_size)
# Apply them along the width axis of the head output.
ori_out_body = ori_body(ori_out, B)

In [32]:
# Inspect the reference body output and its shape.
print(ori_out_body)
print(ori_out_body.shape)


tensor([[[[[0.0000, 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000, 0.0000]],

          [[0.0000, 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000, 0.0000]],

          [[0.0000, 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000, 0.0000]]],


         [[[0.0584, 0.8072, 0.5685, 0.8712],
           [0.1074, 0.8548, 0.2802, 0.4322]],

          [[0.3180, 0.5677, 0.5819, 0.2089],
           [0.1709, 0.7285, 0.3866, 0.3973]],

          [[0.6018, 0.6952, 0.7753, 0.8649],
           [0.6341, 1.5196, 1.2646, 0.8859]]],


         [[[0.0352, 0.5157, 0.3331, 0.6375],
           [0.3577, 0.4239, 0.3442, 0.1625]],

          [[0.1349, 0.4280, 0.2980, 0.1556],
           [0.3239, 0.4277, 0.3351, 0.1689]],

          [[0.3291, 0.2908, 0.5857, 0.5190],
           [0.5138, 1.0133, 0.7888, 0.5281]]],


         [[[0.1133, 0.8163, 1.1251, 0.7989],
           [0.6182, 0.2720, 0.7370, 0.0869]],

          [[0.5863, 0.6727, 0.3470, 0.1827],
           [0.5590, 0.3944, 0

In [33]:
def my_body(input, C, B):
    """Conv2d-based head + body stages, applied per output channel.

    For each output channel ``i`` a 1x1 convolution (in_channels -> rank)
    weighted by ``C[:, i, :]`` is followed by a (1, kernel_size) convolution
    (rank -> 1) weighted by ``B[i]``, which also sums over the rank axis.

    Sizes come from the notebook globals ``batch_size``, ``in_channels``,
    ``out_channels``, ``rank``, ``height``, ``width``, ``kernel_size`` and
    ``padding``.

    Args:
        input: tensor of shape (batch_size, in_channels, height, width).
        C: tensor of shape (in_channels, out_channels, rank).
        B: tensor of shape (out_channels, rank, kernel_size).

    Returns:
        Tensor of shape (batch_size, out_channels, height, width).
    """
    output = torch.zeros(batch_size, out_channels, height, width)
    for i in range(out_channels):
        conv_head = nn.Conv2d(in_channels, rank, 1, padding=0, bias=False)
        # (in_channels, rank) -> (rank, in_channels, 1, 1)
        conv_head.weight.data = C[:, i, :].permute(1, 0).unsqueeze(-1).unsqueeze(-1)
        out = conv_head(input)
        conv_body = nn.Conv2d(rank, 1, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
        # (rank, kernel_size) -> (1, rank, 1, kernel_size)
        conv_body.weight.data = B[i].unsqueeze(0).unsqueeze(2)
        out = conv_body(out)

        # out is (batch, 1, H, W). Drop the singleton channel axis explicitly;
        # assigning the 4-D tensor directly only worked for batch size 1.
        output[:, i] = out[:, 0]

    return output

In [34]:
# Conv2d-based head + body applied directly to the raw (unpadded) input.
my_out_body = my_body(input, C, B)

In [35]:
# Inspect the Conv2d-based body output; my_body also sums over rank, so the
# result is 4-D (batch, out_channels, height, width).
print(my_out_body)
print(my_out_body.shape)

tensor([[[[0.1658, 1.6620, 0.8488, 1.3033],
          [0.3929, 0.9397, 0.6773, 0.8000],
          [0.7315, 1.0883, 1.8622, 0.8857],
          [0.6145, 0.9710, 1.1205, 0.3396]],

         [[0.4889, 1.2962, 0.9685, 0.6061],
          [0.4588, 0.8556, 0.6331, 0.3245],
          [1.1453, 1.0671, 0.9767, 0.3353],
          [0.7510, 0.6505, 0.6231, 0.1667]],

         [[1.2358, 2.2148, 2.0399, 1.7508],
          [0.8429, 1.3041, 1.3745, 1.0471],
          [1.9841, 2.5514, 2.3203, 0.9172],
          [1.2765, 1.3025, 0.9983, 0.3580]]]], grad_fn=<CopySlices>)
torch.Size([1, 3, 4, 4])


In [36]:
def ori_tail(input, A):
    """Reference "tail" stage: slide the 1-D kernels in ``A`` along the height axis,
    then sum over the rank axis.

    The window length and step come from the notebook globals ``kernel_size``
    and ``stride``.

    Args:
        input: 5-D tensor whose axis 1 is the (padded) height. The notebook
            passes the output of ``ori_body``, shaped
            (batch, padded_h, out_channels, rank, width).
        A: tensor of shape (out_channels, rank, kernel_size).

    Returns:
        Tensor of shape (batch, out_channels, n_windows, width), where
        ``n_windows`` is the number of window start positions along the height.
    """
    # Move the height axis last: (batch, width, out_channels, rank, padded_h).
    Ob = input.permute(0, 4, 2, 3, 1)
    padded_h = Ob.size(-1)

    # Add broadcast axes to A: (1, 1, out_channels, rank, 1, kernel_size).
    A_expanded = A[None, None, :, :, None, :]

    # Start a window at every stride-th position where a full kernel_size-long
    # window still fits. The previous bound (padded_h - 2 * padding) was only
    # right when kernel_size == 2 * padding + 1.
    starts = torch.arange(start=0, end=padded_h - kernel_size + 1, step=stride)
    # (n_windows, kernel_size) table of height indices, one row per window.
    window_indices = starts[:, None] + torch.arange(kernel_size)
    # Gather the windows: (batch, width, out_channels, rank, n_windows, kernel_size).
    Ob_expanded = Ob[:, :, :, :, window_indices]

    # Weight each window by the kernel and sum over the kernel axis.
    Oa = torch.sum(Ob_expanded * A_expanded, dim=-1)

    # (batch, width, out_channels, rank, n_windows) -> (batch, n_windows, width, out_channels, rank)
    Oa = Oa.permute(0, 4, 1, 2, 3)

    # Combine the rank-1 terms by summing over the rank axis.
    output = torch.sum(Oa, dim=-1)

    # Back to channels-first: (batch, out_channels, n_windows, width).
    output = output.permute(0, 3, 1, 2)

    return output

In [37]:
# Random 1-D kernels of length kernel_size for the tail stage, one per (out_channel, rank) pair.
A = torch.rand(out_channels, rank, kernel_size)
# Apply them along the height axis of the body output and sum over rank.
ori_out_tail = ori_tail(ori_out_body, A)

In [38]:
# Reference result of the full head -> body -> tail pipeline.
ori_out_tail

tensor([[[[0.1501, 1.0284, 0.6259, 0.9927],
          [0.4202, 1.9828, 1.7689, 1.6817],
          [0.6004, 1.4801, 1.5663, 0.9666],
          [0.6038, 0.9097, 1.3912, 0.5852]],

         [[0.5948, 0.9537, 0.8535, 0.3682],
          [0.8837, 1.5442, 1.3040, 0.6932],
          [1.3286, 1.4310, 1.1224, 0.4507],
          [0.8326, 0.6714, 0.7885, 0.2108]],

         [[0.9234, 1.5056, 1.3871, 1.1832],
          [2.3927, 3.5863, 3.4348, 2.4604],
          [2.1241, 2.8999, 2.7304, 1.4651],
          [2.3006, 2.7939, 2.4352, 0.9714]]]])

In [39]:
def my_tail(input, C, B, A):
    """Conv2d-based head -> body -> tail chain, built per output channel.

    For each output channel the chain is: a 1x1 convolution (in_channels -> rank)
    weighted by ``C``, a (1, kernel_size) convolution (rank -> 1) weighted by
    ``B``, and a (kernel_size, 1) convolution (1 -> rank) weighted by ``A``.
    The rank channels of the last result are then summed. Note that the rank
    axis is therefore reduced twice: once inside the body convolution and once
    by the final sum.

    Sizes come from the notebook globals ``batch_size``, ``in_channels``,
    ``out_channels``, ``rank``, ``height``, ``width``, ``kernel_size`` and
    ``padding``.

    Returns:
        Tensor of shape (batch_size, out_channels, height, width).
    """
    result = torch.zeros(batch_size, out_channels, height, width)

    for ch in range(out_channels):
        head = nn.Conv2d(in_channels, rank, kernel_size=1, padding=0, bias=False)
        # (in_channels, rank) -> (rank, in_channels, 1, 1)
        head.weight.data = C[:, ch, :].t()[:, :, None, None]

        body = nn.Conv2d(rank, 1, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
        # (rank, kernel_size) -> (1, rank, 1, kernel_size)
        body.weight.data = B[ch][None, :, None, :]

        tail = nn.Conv2d(1, rank, kernel_size=(kernel_size, 1), padding=(padding, 0), bias=False)
        # (rank, kernel_size) -> (rank, 1, kernel_size, 1)
        tail.weight.data = A[ch][:, None, :, None]

        stage = tail(body(head(input)))
        # Collapse the rank channels produced by the tail convolution.
        result[:, ch] = stage.sum(dim=1)

    return result

In [40]:
# Conv2d variant that reduces over rank in both the body conv and the final sum;
# the values shown below differ from ori_out_tail.
my_tail(input, C, B, A)

tensor([[[[0.5102, 2.0036, 1.2280, 1.6367],
          [1.1911, 3.6057, 3.3186, 2.9001],
          [1.5382, 2.7524, 3.0874, 1.8128],
          [1.2146, 1.8421, 2.8170, 1.2202]],

         [[0.9048, 2.0377, 1.5164, 0.8772],
          [2.0718, 3.1927, 2.5595, 1.2602],
          [2.2717, 2.5114, 2.1708, 0.8091],
          [1.8730, 1.7001, 1.5814, 0.4998]],

         [[1.7565, 2.9829, 2.8860, 2.3710],
          [4.5414, 7.1550, 6.7183, 4.7912],
          [4.2707, 5.6158, 5.2966, 2.9681],
          [4.6774, 5.7178, 5.0369, 1.9589]]]], grad_fn=<CopySlices>)

In [41]:
def my_other(input, C, B, A):
    """Variant that collapses rank after the head and then uses two shared convolutions.

    Step 1: for each output channel, a 1x1 convolution (in_channels -> rank)
    weighted by ``C`` is applied and its rank channels are summed.
    Step 2: a single (1, kernel_size) convolution (out_channels -> rank)
    weighted by ``B`` is applied to the stacked result.
    Step 3: a single (kernel_size, 1) convolution (rank -> out_channels)
    weighted by ``A`` produces the output.

    Sizes come from the notebook globals ``batch_size``, ``in_channels``,
    ``out_channels``, ``rank``, ``height``, ``width``, ``kernel_size`` and
    ``padding``.

    Returns:
        Tensor of shape (batch_size, out_channels, height, width).
    """
    pooled = torch.zeros(batch_size, out_channels, height, width)
    for ch in range(out_channels):
        head = nn.Conv2d(in_channels, rank, kernel_size=1, padding=0, bias=False)
        # (in_channels, rank) -> (rank, in_channels, 1, 1)
        head.weight.data = C[:, ch, :].t()[:, :, None, None]
        pooled[:, ch] = head(input).sum(dim=1)

    body = nn.Conv2d(out_channels, rank, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
    # (out_channels, rank, k) -> (rank, out_channels, 1, k)
    body.weight.data = B.transpose(0, 1)[:, :, None, :]

    tail = nn.Conv2d(rank, out_channels, kernel_size=(kernel_size, 1), padding=(padding, 0), bias=False)
    # (out_channels, rank, k) -> (out_channels, rank, k, 1)
    tail.weight.data = A[..., None]

    return tail(body(pooled))

In [42]:
# Variant with shared body/tail convolutions; the values shown below differ from ori_out_tail.
my_other(input, C, B, A)

tensor([[[[ 2.8290,  5.9203,  5.3343,  4.5124],
          [ 6.8386, 12.7874, 11.5765,  7.5944],
          [ 6.5085,  9.9331,  8.8002,  4.6937],
          [ 6.1629,  7.7181,  7.2002,  3.0241]],

         [[ 3.6090,  7.6436,  6.3064,  5.5358],
          [ 8.6532, 14.1680, 12.3749,  7.6686],
          [ 7.6412, 11.4409, 10.2043,  4.9115],
          [ 6.9913,  8.3687,  7.4327,  3.1170]],

         [[ 3.2049,  6.9593,  5.7525,  4.9095],
          [ 8.9991, 16.1579, 14.0614, 10.0175],
          [ 8.2643, 12.7301, 11.4814,  6.3622],
          [ 9.3103, 11.8172, 11.0133,  4.6841]]]],
       grad_fn=<ConvolutionBackward0>)

In [43]:
def my_another(input, C, B, A):
    """Conv2d chain built separately for every (output channel, rank) pair.

    For each pair ``(i, j)`` three single-output convolutions are chained:
    a 1x1 convolution weighted by ``C[:, i, j]``, a (1, kernel_size)
    convolution weighted by ``B[i, j]`` and a (kernel_size, 1) convolution
    weighted by ``A[i, j]``. The results are accumulated over ``j`` and stored
    as output channel ``i``.

    Sizes come from the notebook globals ``batch_size``, ``in_channels``,
    ``out_channels``, ``rank``, ``height``, ``width``, ``kernel_size`` and
    ``padding``.

    Args:
        input: tensor of shape (batch_size, in_channels, height, width).
        C: tensor of shape (in_channels, out_channels, rank).
        B: tensor of shape (out_channels, rank, kernel_size).
        A: tensor of shape (out_channels, rank, kernel_size).

    Returns:
        Tensor of shape (batch_size, out_channels, height, width).
    """
    output = torch.zeros(batch_size, out_channels, height, width)

    for i in range(out_channels):
        # Accumulator over the rank terms (named so it does not shadow builtin sum).
        rank_total = torch.zeros(batch_size, 1, height, width)
        for j in range(rank):

            conv_head = nn.Conv2d(in_channels, 1, 1, padding=0, bias=False)
            # (in_channels,) -> (1, in_channels, 1, 1)
            conv_head.weight.data = C[:, i, j].unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            out = conv_head(input)

            conv_body = nn.Conv2d(1, 1, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
            # (kernel_size,) -> (1, 1, 1, kernel_size)
            conv_body.weight.data = B[i, j].unsqueeze(0).unsqueeze(0).unsqueeze(0)
            out = conv_body(out)

            conv_tail = nn.Conv2d(1, 1, kernel_size=(kernel_size, 1), padding=(padding, 0), bias=False)
            # (kernel_size,) -> (1, 1, kernel_size, 1)
            conv_tail.weight.data = A[i, j].unsqueeze(0).unsqueeze(0).unsqueeze(-1)
            out = conv_tail(out)

            rank_total += out

        # rank_total is (batch, 1, H, W). Drop the singleton channel axis
        # explicitly; assigning the 4-D tensor only worked for batch size 1.
        output[:, i] = rank_total[:, 0]

    return output

In [44]:
# Per-(channel, rank) Conv2d chain; the values shown below match ori_out_tail to the printed precision.
my_another(input, C, B, A)

tensor([[[[0.1501, 1.0284, 0.6259, 0.9927],
          [0.4202, 1.9828, 1.7689, 1.6817],
          [0.6004, 1.4801, 1.5663, 0.9666],
          [0.6038, 0.9097, 1.3912, 0.5852]],

         [[0.5948, 0.9537, 0.8535, 0.3682],
          [0.8837, 1.5442, 1.3040, 0.6932],
          [1.3286, 1.4310, 1.1224, 0.4507],
          [0.8326, 0.6714, 0.7885, 0.2108]],

         [[0.9234, 1.5056, 1.3871, 1.1832],
          [2.3927, 3.5863, 3.4348, 2.4604],
          [2.1241, 2.8999, 2.7304, 1.4651],
          [2.3006, 2.7939, 2.4352, 0.9714]]]], grad_fn=<CopySlices>)

In [45]:
def my_func1(input, C, B, A):
    """Variant that loops over rank and uses channel-mixing convolutions.

    For each rank index a 1x1 convolution (in_channels -> out_channels)
    weighted by ``C[:, :, r]`` is followed by a (1, kernel_size) convolution
    (out_channels -> 1) weighted by ``B[:, r]`` and a (kernel_size, 1)
    convolution (1 -> out_channels) weighted by ``A[:, r]``. The per-rank
    results are added together. The body convolution sums over the output
    channels, so channels are mixed in this variant.

    Sizes come from the notebook globals ``batch_size``, ``in_channels``,
    ``out_channels``, ``rank``, ``height``, ``width``, ``kernel_size`` and
    ``padding``.

    Returns:
        Tensor of shape (batch_size, out_channels, height, width).
    """
    total = torch.zeros(batch_size, out_channels, height, width)

    for r in range(rank):
        head = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False)
        # (in_channels, out_channels) -> (out_channels, in_channels, 1, 1)
        head.weight.data = C[:, :, r].t()[:, :, None, None]

        body = nn.Conv2d(out_channels, 1, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
        # (out_channels, kernel_size) -> (1, out_channels, 1, kernel_size)
        body.weight.data = B[:, r][None, :, None, :]

        tail = nn.Conv2d(1, out_channels, kernel_size=(kernel_size, 1), padding=(padding, 0), bias=False)
        # (out_channels, kernel_size) -> (out_channels, 1, kernel_size, 1)
        tail.weight.data = A[:, r][:, None, :, None]

        total += tail(body(head(input)))

    return total

In [46]:
# NOTE(review): my_func1 is declared as (input, C, B, A), but this call passes
# A and B in swapped positions; the output below reflects that order. Confirm
# whether the swap is intentional.
my_func1(input, C, A, B)

tensor([[[[0.8323, 1.7989, 1.5240, 1.1652],
          [3.4157, 6.0044, 6.0656, 4.2418],
          [2.5548, 4.1356, 4.1657, 2.4511],
          [2.9785, 4.1634, 4.1622, 1.8136]],

         [[1.9587, 4.0324, 3.6233, 3.1005],
          [4.9612, 8.3164, 7.8639, 4.8326],
          [3.9593, 5.6802, 5.2868, 2.5720],
          [2.3324, 3.1258, 3.3784, 1.3571]],

         [[2.6241, 5.2695, 4.7364, 4.0373],
          [4.0072, 7.2106, 7.1663, 5.1479],
          [4.9980, 7.5350, 7.2089, 3.5260],
          [3.7216, 4.7828, 4.5968, 1.8483]]]], grad_fn=<AddBackward0>)