In [1]:
import torch
import torch.nn as nn

In [2]:
batch_size = 1
in_channels = 2
out_channels = 3
height = 4
width = 4
rank = 2
padding = 1
stride = 1
kernel_size = 3

In [3]:
def ori_calcul(input, C):
    padded_I = nn.functional.pad(input, pad=[padding]*4)
    padded_I = padded_I.permute(0, 2, 3, 1)
    padded_h = padded_I.shape[1]
    padded_w = padded_I.shape[2]
    padded_I_col = padded_I.reshape(batch_size * padded_h * padded_w, in_channels)
    weight_col = C.reshape(in_channels, out_channels * rank)
    output = torch.matmul(padded_I_col, weight_col).reshape(batch_size, padded_h, padded_w, out_channels, rank)

    return output

In [4]:
def my_calcul(input, C):
    output = torch.zeros(batch_size, height+2*padding, width+2*padding, out_channels, rank)
    for i in range(out_channels):
        conv = nn.Conv2d(in_channels, rank, 1, padding=padding, bias=False)
        conv.weight.data = C[:, i, :].permute(1, 0).unsqueeze(-1).unsqueeze(-1)
        out = conv(input)
        output[:, :, :, i] = out.permute(0, 2, 3, 1)
    
    return output


In [5]:
input = torch.rand(batch_size, in_channels, height, width)
C = torch.rand(in_channels, out_channels, rank)

ori_out = ori_calcul(input, C)
my_out = my_calcul(input, C).detach().clone()
print(torch.all(torch.lt(torch.abs(torch.add(ori_out, -my_out)), 1e-12)))
print(torch.eq(ori_out, my_out))

tensor(True)
tensor([[[[[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]]],


         [[[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]],

          [[True, True],
           [True, True],
           [True, True]]],


         [[[True, True],
           [True, True],
         

In [6]:
ori_out.shape

torch.Size([1, 6, 6, 3, 2])

In [8]:
def my_head(input, C):
    output = torch.zeros(batch_size, height+2*padding, width+2*padding, out_channels, rank)
    conv = nn.Conv2d(in_channels, out_channels*rank, 1, padding=padding, bias=False)
    conv.weight.data = C.reshape(in_channels, out_channels*rank).permute(1, 0).unsqueeze(-1).unsqueeze(-1)
    out = conv(input)
    out = out.reshape(batch_size, out_channels, rank, height+2*padding, width+2*padding)
    output = out.permute(0, 3, 4, 1, 2)

    return output


In [9]:
my_out

tensor([[[[[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]]],


         [[[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0508, 0.7597],
           [0.2611, 0.4634],
           [0.2565, 0.0368]],

          [[0.0619, 0.6816],
           [0.3019, 0.4350],
           [0.2953, 0.0411]],

          [[0.0807, 0.9394],
           [0.3970, 0.5941],
           [0.3886, 0.0544]],

          [[0.1108, 0.6109],
           [0.4994, 0.4551],
           [0.4851, 0.0640]],

          [[0.0000,

In [10]:
head_output = my_head(input, C)
print(head_output)
print(torch.eq(ori_out, head_output))

tensor([[[[[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]]],


         [[[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]],

          [[0.0508, 0.7597],
           [0.2611, 0.4634],
           [0.2565, 0.0368]],

          [[0.0619, 0.6816],
           [0.3019, 0.4350],
           [0.2953, 0.0411]],

          [[0.0807, 0.9394],
           [0.3970, 0.5941],
           [0.3886, 0.0544]],

          [[0.1108, 0.6109],
           [0.4994, 0.4551],
           [0.4851, 0.0640]],

          [[0.0000,

In [103]:
def ori_body(input, B):
    w = input.size(2) - 2*padding
    Oc = input.permute(0, 1, 3, 4, 2)
    # Add a new axis to B for broadcasting, B's shape becomes (1, 1, Cout, r, 1, d)
    B_expanded = B[None, None, :, :, None, :]
    # Assuming 'Oc' is a 5-dimensional and 'w' and 'd' are the window width and depth, respectively
    window_indices = torch.arange(start=0, end=w, step=stride)[:, None] + torch.arange(kernel_size)
    print(window_indices.shape)
    Oc_expanded = Oc[:, :, :, :, window_indices]
    print(Oc_expanded.shape)
    print(B_expanded.shape)
    print((Oc_expanded*B_expanded).shape)

    # Perform the element-wise multiplication and sum over the last axis (d)
    output = torch.sum(Oc_expanded * B_expanded, dim=-1)

    return output

In [104]:
B = torch.rand(out_channels, rank, kernel_size)
ori_out_body = ori_body(ori_out, B)

torch.Size([4, 3])
torch.Size([1, 6, 3, 2, 4, 3])
torch.Size([1, 1, 3, 2, 1, 3])
torch.Size([1, 6, 3, 2, 4, 3])


In [105]:
print(ori_out_body)
print(ori_out_body.shape)


tensor([[[[[0.0000, 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000, 0.0000]],

          [[0.0000, 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000, 0.0000]],

          [[0.0000, 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000, 0.0000]]],


         [[[0.0651, 0.1046, 0.1396, 0.0542],
           [0.2413, 0.4901, 0.4289, 0.3241]],

          [[0.0896, 0.2515, 0.3070, 0.3822],
           [0.3641, 0.4279, 0.4391, 0.2073]],

          [[0.3476, 0.6081, 0.7558, 0.6180],
           [0.0240, 0.0447, 0.0519, 0.0195]]],


         [[[0.0879, 0.0978, 0.1347, 0.0524],
           [0.2957, 0.4828, 0.3676, 0.2078]],

          [[0.0304, 0.1927, 0.3992, 0.3493],
           [0.4205, 0.4331, 0.3009, 0.1131]],

          [[0.3021, 0.5916, 0.8061, 0.5391],
           [0.0360, 0.0364, 0.0493, 0.0191]]],


         [[[0.1103, 0.1171, 0.1494, 0.0584],
           [0.2832, 0.3277, 0.4479, 0.1536]],

          [[0.0357, 0.2318, 0.4672, 0.3785],
           [0.3935, 0.3795, 0

In [106]:
def my_body(input, C, B):
    output = torch.zeros(batch_size, out_channels, height, width)
    for i in range(out_channels):
        conv_head = nn.Conv2d(in_channels, rank, 1, padding=0, bias=False)
        conv_head.weight.data = C[:, i, :].permute(1, 0).unsqueeze(-1).unsqueeze(-1)
        out = conv_head(input)
        conv_body = nn.Conv2d(rank, 1, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
        conv_body.weight.data = B[i].unsqueeze(0).unsqueeze(2)
        out = conv_body(out)

        output[:, i] = out

    return output

In [107]:
my_out_body = my_body(input, C, B)

In [108]:
print(my_out_body)
print(my_out_body.shape)

tensor([[[[0.3064, 0.5947, 0.5686, 0.3784],
          [0.3837, 0.5806, 0.5023, 0.2603],
          [0.3935, 0.4448, 0.5973, 0.2120],
          [0.2258, 0.2959, 0.5000, 0.2813]],

         [[0.4537, 0.6795, 0.7461, 0.5895],
          [0.4509, 0.6258, 0.7002, 0.4625],
          [0.4292, 0.6113, 0.7851, 0.5952],
          [0.3713, 0.6732, 0.7579, 0.5234]],

         [[0.3716, 0.6528, 0.8077, 0.6375],
          [0.3380, 0.6280, 0.8554, 0.5582],
          [0.4075, 0.7172, 0.9785, 0.6070],
          [0.6389, 0.7969, 0.8185, 0.4569]]]], grad_fn=<CopySlices>)
torch.Size([1, 3, 4, 4])


In [109]:
def test_body_d(input, C, B):
    conv = nn.Conv2d(in_channels, out_channels*rank, 1, padding=padding, bias=False)
    conv.weight.data = C.reshape(in_channels, out_channels*rank).permute(1, 0).unsqueeze(-1).unsqueeze(-1)
    out = conv(input)

    out = out.unsqueeze(1)
    conv_body = nn.Conv3d(1, out_channels*rank, kernel_size, 1, padding=1, bias=False)
    conv_body.weight.data = B.reshape(out_channels*rank, kernel_size).permute(1, 0).unsqueeze(-1).unsqueeze(-1)
    out = conv_body(out)

    return out

In [110]:
test_body_d(input, C, B)

tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.5765, 0.6112, 0.8158, 0.8656, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.3346, 0.8899, 0.7524, 0.5976, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.2700, 1.0619, 0.6690, 0.7970, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.6699, 0.8508, 0.3570, 0.9570, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.5612, 0.5825, 0.7806, 0.7895, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.3692, 0.8357, 0.6996, 0.5285, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.2870, 0.9901, 0.5865, 0.7452, 0.0000, 0.0000],
          

In [111]:
def ori_tail(input, A):
    h = input.size(1) - 2*padding
    Ob = input.permute(0, 4, 2, 3, 1)
    # Add a new axis to B for broadcasting, A's shape becomes (1, 1, Cout, r, 1, d)
    A_expanded = A[None, None, :, :, None, :]

    # Assuming 'Ob' is a 5-dimensional and 'h' and 'd' are the window width and depth, respectively
    window_indices = torch.arange(start=0, end=h, step=stride)[:, None] + torch.arange(kernel_size)
    Ob_expanded = Ob[:, :, :, :, window_indices]

    # Perform the element-wise multiplication and sum over the last axis (d)
    Oa = torch.sum(Ob_expanded * A_expanded, axis=-1)

    Oa = Oa.permute(0, 4, 1, 2, 3)

    # print(Oa)

    # Step 4: Compute O
    output = torch.sum(Oa, dim=-1)

    # print(output)

    output = output.permute(0, 3, 1, 2)
    # print(output)


    return output

In [112]:
A = torch.rand(out_channels, rank, kernel_size)
ori_out_tail = ori_tail(ori_out_body, A)

In [113]:
ori_out_tail

tensor([[[[0.2911, 0.5189, 0.4864, 0.3189],
          [0.3858, 0.5787, 0.5232, 0.2741],
          [0.3797, 0.4290, 0.5659, 0.2108],
          [0.1319, 0.2191, 0.3594, 0.2203]],

         [[0.2898, 0.4234, 0.5110, 0.3431],
          [0.5886, 0.8285, 0.9320, 0.6531],
          [0.6056, 0.8541, 0.8650, 0.5699],
          [0.3927, 0.5159, 0.6416, 0.4719]],

         [[0.1915, 0.3474, 0.4677, 0.3028],
          [0.2426, 0.4124, 0.5611, 0.3444],
          [0.3586, 0.4514, 0.4929, 0.2691],
          [0.0525, 0.0645, 0.0837, 0.0399]]]])

In [114]:
def my_tail(input, C, B, A):
    output = torch.zeros(batch_size, out_channels, height, width)

    for i in range(out_channels):

        conv_head = nn.Conv2d(in_channels, rank, 1, padding=0, bias=False)
        conv_head.weight.data = C[:, i, :].permute(1, 0).unsqueeze(-1).unsqueeze(-1)
        out = conv_head(input)
        conv_body = nn.Conv2d(rank, 1, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
        conv_body.weight.data = B[i].unsqueeze(0).unsqueeze(2)
        out = conv_body(out)

        conv_tail = nn.Conv2d(1, rank, kernel_size=(kernel_size, 1), padding=(padding, 0), bias=False)
        conv_tail.weight.data = A[i].unsqueeze(1).unsqueeze(-1)
        out = conv_tail(out)

        # print(out.shape)
        # print(out)
        out = torch.sum(out, dim=1)
        # print(out.shape)
        # print(out)


        output[:, i] = out

    return output

In [115]:
my_tail(input, C, B, A)

tensor([[[[0.5819, 1.0000, 0.9144, 0.5501],
          [0.7100, 0.9770, 1.0217, 0.4662],
          [0.5993, 0.7335, 1.0207, 0.4576],
          [0.2700, 0.3422, 0.5531, 0.2910]],

         [[0.5954, 0.8540, 0.9476, 0.6796],
          [1.0901, 1.5810, 1.8300, 1.3763],
          [1.0307, 1.5595, 1.8045, 1.2518],
          [0.6913, 1.0648, 1.3082, 0.9633]],

         [[0.5250, 0.9631, 1.2850, 0.8748],
          [0.6709, 1.1914, 1.6084, 1.0341],
          [0.9628, 1.3107, 1.4677, 0.8554],
          [0.2912, 0.4048, 0.4640, 0.2711]]]], grad_fn=<CopySlices>)

In [116]:
def my_other(input, C, B, A):
    output = torch.zeros(batch_size, out_channels, height, width)
    for i in range(out_channels):

        conv_head = nn.Conv2d(in_channels, rank, 1, padding=0, bias=False)
        conv_head.weight.data = C[:, i, :].permute(1, 0).unsqueeze(-1).unsqueeze(-1)
        out = conv_head(input)
        out = torch.sum(out, dim=1)
        output[:, i] = out

    conv_body = nn.Conv2d(out_channels, rank, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
    conv_body.weight.data = B.permute(1, 0, 2).unsqueeze(2)
    output = conv_body(output)

    conv_tail = nn.Conv2d(rank, out_channels, kernel_size=(kernel_size, 1), padding=(padding, 0), bias=False)
    conv_tail.weight.data = A.unsqueeze(-1)
    output = conv_tail(output)

    # print(out.shape)
    # print(out)
    # out = torch.sum(out, dim=1)
    # print(out.shape)
    # print(out)

    return output

In [117]:
my_other(input, C, B, A)

tensor([[[[2.2111, 3.5042, 3.6479, 2.2126],
          [2.6232, 3.4829, 4.1524, 2.1182],
          [2.4911, 3.3037, 4.0373, 2.0771],
          [1.2614, 1.3922, 1.9186, 0.9906]],

         [[1.7439, 2.7175, 2.7961, 1.6800],
          [3.2875, 4.8101, 5.4602, 3.1393],
          [3.2956, 4.5396, 5.1281, 2.7538],
          [2.2895, 2.6967, 3.6302, 1.8064]],

         [[2.0020, 2.9269, 2.9264, 1.6439],
          [2.3810, 2.9948, 3.6546, 1.8197],
          [2.1374, 2.6827, 3.4025, 1.7800],
          [0.6455, 0.7104, 0.9502, 0.4761]]]], grad_fn=<ConvolutionBackward0>)

In [118]:
def my_another(input, C, B, A):
    output = torch.zeros(batch_size, out_channels, height, width)

    for i in range(out_channels):
        sum = torch.zeros(batch_size, 1, height, width)
        for j in range(rank):

            conv_head = nn.Conv2d(in_channels, 1, 1, padding=0, bias=False)
            conv_head.weight.data = C[:, i, j].unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            out = conv_head(input)

            conv_body = nn.Conv2d(1, 1, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
            conv_body.weight.data = B[i, j].unsqueeze(0).unsqueeze(0).unsqueeze(0)
            out = conv_body(out)

            conv_tail = nn.Conv2d(1, 1, kernel_size=(kernel_size, 1), padding=(padding, 0), bias=False)
            conv_tail.weight.data = A[i, j].unsqueeze(0).unsqueeze(0).unsqueeze(-1)
            out = conv_tail(out)                      

            sum += out

        output[:, i] = sum

    return output

In [119]:
my_another(input, C, B, A)

tensor([[[[0.2911, 0.5189, 0.4864, 0.3189],
          [0.3858, 0.5787, 0.5232, 0.2741],
          [0.3797, 0.4290, 0.5659, 0.2108],
          [0.1319, 0.2191, 0.3594, 0.2203]],

         [[0.2898, 0.4234, 0.5110, 0.3431],
          [0.5886, 0.8285, 0.9320, 0.6531],
          [0.6056, 0.8541, 0.8650, 0.5699],
          [0.3927, 0.5159, 0.6416, 0.4719]],

         [[0.1915, 0.3474, 0.4677, 0.3028],
          [0.2426, 0.4124, 0.5611, 0.3444],
          [0.3586, 0.4514, 0.4929, 0.2691],
          [0.0525, 0.0645, 0.0837, 0.0399]]]], grad_fn=<CopySlices>)

In [120]:
def test_reshape(input, C, B, A):
    output = torch.zeros(batch_size, out_channels, height, width)

    conv_head = nn.Conv2d(in_channels, out_channels*rank, 1, padding=0, bias=False)
    conv_head.weight.data = C.reshape(in_channels, out_channels*rank).permute(1, 0).unsqueeze(-1).unsqueeze(-1)
    out = conv_head(input)

    conv_body = nn.Conv2d(out_channels*rank, 1, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
    conv_body.weight.data = B.reshape(out_channels*rank, kernel_size).unsqueeze(0).unsqueeze(2)
    out = conv_body(out)

    conv_tail = nn.Conv2d(1, out_channels*rank, kernel_size=(kernel_size, 1), padding=(padding, 0), bias=False)
    conv_tail.weight.data = A.reshape(out_channels*rank, kernel_size).unsqueeze(1).unsqueeze(-1)
    out = conv_tail(out)

    # conv_body = nn.Conv2d(out_channels*rank, 1, kernel_size=(kernel_size, 1), padding=(padding, 0), bias=False)
    # conv_body.weight.data = B.reshape(out_channels*rank, kernel_size).unsqueeze(0).unsqueeze(-1)
    # out = conv_body(out)

    # conv_tail = nn.Conv2d(1, out_channels*rank, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
    # conv_tail.weight.data = A.reshape(out_channels*rank, kernel_size).unsqueeze(1).unsqueeze(2)
    # out = conv_tail(out)

    x = out.reshape(batch_size, out_channels, rank, height, width)


    output = torch.sum(x, dim=2)

    return output

In [121]:
test_reshape(input, C, B, A)

tensor([[[[1.9564, 3.2032, 3.5576, 2.4735],
          [2.2237, 3.3849, 4.0840, 2.5448],
          [2.2874, 3.3084, 4.1252, 2.4930],
          [1.3271, 1.8986, 2.2773, 1.3804]],

         [[1.5216, 2.4673, 2.7460, 1.8671],
          [2.8671, 4.5471, 5.3395, 3.6027],
          [2.9498, 4.4030, 5.2183, 3.1939],
          [2.0768, 2.9850, 3.8214, 2.2990]],

         [[1.7696, 2.8197, 3.1503, 2.0543],
          [2.0773, 3.1007, 3.9135, 2.4267],
          [2.1113, 3.0535, 3.6611, 2.2239],
          [0.6514, 0.9338, 1.1529, 0.6965]]]], grad_fn=<SumBackward1>)

In [122]:
def my_func1(input, C, B, A):
    output = torch.zeros(batch_size, out_channels, height, width)

    for j in range(rank):

        conv_head = nn.Conv2d(in_channels, out_channels, 1, padding=0, bias=False)
        conv_head.weight.data = C[:, :, j].permute(1, 0).unsqueeze(-1).unsqueeze(-1)
        out = conv_head(input)

        conv_body = nn.Conv2d(out_channels, 1, kernel_size=(1, kernel_size), padding=(0, padding), bias=False)
        conv_body.weight.data = B[:, j].unsqueeze(0).unsqueeze(2)
        out = conv_body(out)

        conv_tail = nn.Conv2d(1, out_channels, kernel_size=(kernel_size, 1), padding=(padding, 0), bias=False)
        conv_tail.weight.data = A[:, j].unsqueeze(1).unsqueeze(-1)
        out = conv_tail(out)

        output += out

    return output

In [123]:
my_func1(input, C, A, B)

tensor([[[[0.9362, 1.1645, 1.3074, 0.6042],
          [1.4696, 1.8392, 2.0636, 1.0655],
          [1.5086, 1.5645, 1.9697, 0.9215],
          [0.7116, 0.8367, 0.9452, 0.5844]],

         [[1.0254, 1.4558, 1.4889, 0.8978],
          [1.2941, 1.8541, 1.7673, 0.9926],
          [1.1364, 1.4900, 1.6468, 1.2225],
          [0.8168, 0.8580, 1.2344, 0.8054]],

         [[1.2643, 1.7145, 1.7769, 0.7809],
          [1.9359, 2.5930, 2.7964, 1.5856],
          [1.9326, 2.1265, 2.7178, 1.5271],
          [1.3022, 1.3232, 1.6107, 0.6862]]]], grad_fn=<AddBackward0>)