In [1]:
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Flat input of N = 16 integer indices (0..15); every view/permute below
# rearranges these indices so the reordering can be read off directly.
A = torch.arange(16, dtype=torch.long)

In [3]:
# Show the flat 16-element input.
A

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

In [4]:
# One axis of size 2 per radix-2 stage, i.e. log2(len(A)) axes.
# Integer log2 via bit_length avoids floating-point np.log2, and the assert
# stops a non-power-of-two length here: int(np.log2(20)) would silently give
# 4 axes (16 elements) and the later A.view(*shapes) would fail.
assert len(A) > 0 and (len(A) & (len(A) - 1)) == 0, f"len(A) = {len(A)} is not a power of 2"
shapes = [2] * (len(A).bit_length() - 1)

In [5]:
# Expect one size-2 axis per stage: [2, 2, 2, 2] for N = 16.
shapes

[2, 2, 2, 2]

In [6]:
# Split the 16 indices into four binary axes; axis k has stride 16 / 2**(k+1),
# so the last axis pairs neighbours (0, 1), (2, 3), ...
A0 = A.view(tuple(shapes))
A0

tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]]],


        [[[ 8,  9],
          [10, 11]],

         [[12, 13],
          [14, 15]]]])

In [7]:
# Equivalent alternative to the next cell: for a 4-D tensor, transpose(-1, -2)
# swaps the last two axes, which is the same as permute(0, 1, 3, 2).
# A1 = A0.transpose(-1, -2)
# A1

In [8]:
# Stage 1 (stride 2): swap the last two axes so each innermost pair becomes
# (i, i + 2), e.g. [0, 2], [1, 3].
A1 = A0.permute((0, 1, 3, 2))
A1

tensor([[[[ 0,  2],
          [ 1,  3]],

         [[ 4,  6],
          [ 5,  7]]],


        [[[ 8, 10],
          [ 9, 11]],

         [[12, 14],
          [13, 15]]]])

In [9]:
# Stage 2 (stride 4): move axis 1 to the end so innermost pairs are (i, i + 4).
A2 = A0.permute((0, 2, 3, 1))
A2

tensor([[[[ 0,  4],
          [ 1,  5]],

         [[ 2,  6],
          [ 3,  7]]],


        [[[ 8, 12],
          [ 9, 13]],

         [[10, 14],
          [11, 15]]]])

In [10]:
# Stage 3 (stride 8): move axis 0 to the end so innermost pairs are (i, i + 8).
A3 = A0.permute((1, 2, 3, 0))
A3

tensor([[[[ 0,  8],
          [ 1,  9]],

         [[ 2, 10],
          [ 3, 11]]],


        [[[ 4, 12],
          [ 5, 13]],

         [[ 6, 14],
          [ 7, 15]]]])

### Chained permutes (each applied to the previous result)

In [11]:
# Rebuild A0 from the flat input so the chain below starts from the same
# four-axis view as the stage cells above.
A0 = A.view(tuple(shapes))
A0

tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]]],


        [[[ 8,  9],
          [10, 11]],

         [[12, 13],
          [14, 15]]]])

In [12]:
# Chain step 1: identical to the stage-1 permute A0.permute(0, 1, 3, 2).
A1 = A0.permute((0, 1, 3, 2))
A1

tensor([[[[ 0,  2],
          [ 1,  3]],

         [[ 4,  6],
          [ 5,  7]]],


        [[[ 8, 10],
          [ 9, 11]],

         [[12, 14],
          [13, 15]]]])

In [13]:
# Chain step 2, applied to A1: composes to A0.permute(0, 2, 3, 1)
# (compare the stage-2 output above).
A2 = A1.permute((0, 3, 2, 1))
A2

tensor([[[[ 0,  4],
          [ 1,  5]],

         [[ 2,  6],
          [ 3,  7]]],


        [[[ 8, 12],
          [ 9, 13]],

         [[10, 14],
          [11, 15]]]])

In [14]:
# Chain step 3, applied to A2: composes to A0.permute(1, 2, 3, 0)
# (compare the stage-3 output above).
A3 = A2.permute((3, 1, 2, 0))
A3

tensor([[[[ 0,  8],
          [ 1,  9]],

         [[ 2, 10],
          [ 3, 11]]],


        [[[ 4, 12],
          [ 5, 13]],

         [[ 6, 14],
          [ 7, 15]]]])

### Larger input (N = 64)

In [15]:
# N = 64 indices -> six binary axes.
A = torch.arange(64, dtype=torch.long)

In [16]:
# One axis of size 2 per radix-2 stage, i.e. log2(len(A)) axes.
# Integer log2 via bit_length avoids floating-point np.log2, and the assert
# stops a non-power-of-two length here instead of silently truncating the
# axis count and failing later in A.view(*shapes).
assert len(A) > 0 and (len(A) & (len(A) - 1)) == 0, f"len(A) = {len(A)} is not a power of 2"
shapes = [2] * (len(A).bit_length() - 1)

In [17]:
A0 = A.view(*shapes)  # contiguous view: axis k has stride N / 2**(k+1)
A0  # last axis separates even/odd indices: A0[..., 0] -> 0, 2, 4, ...; A0[..., 1] -> 1, 3, 5, ...
# axis 0 splits 0..31 from 32..63 (N/2), axis 1 splits at N/4, ..., the last axis at stride 1

tensor([[[[[[ 0,  1],
            [ 2,  3]],

           [[ 4,  5],
            [ 6,  7]]],


          [[[ 8,  9],
            [10, 11]],

           [[12, 13],
            [14, 15]]]],



         [[[[16, 17],
            [18, 19]],

           [[20, 21],
            [22, 23]]],


          [[[24, 25],
            [26, 27]],

           [[28, 29],
            [30, 31]]]]],




        [[[[[32, 33],
            [34, 35]],

           [[36, 37],
            [38, 39]]],


          [[[40, 41],
            [42, 43]],

           [[44, 45],
            [46, 47]]]],



         [[[[48, 49],
            [50, 51]],

           [[52, 53],
            [54, 55]]],


          [[[56, 57],
            [58, 59]],

           [[60, 61],
            [62, 63]]]]]])

In [18]:
# Number of axes (expected 6 = log2(64)).
A0.ndim

6

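## How to generalize the permutation

Stage $k$ pairs each index $i$ with $i + 2^k$: `A.reshape(-1, 2, 2**k)` followed by swapping the last two axes. This only works while $2^{k+1}$ divides $N$.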

In [19]:
# N = 32 for the generalized stage-by-stage reshape.
A = torch.arange(32, dtype=torch.long)

In [20]:
# Stage 0 (stride 1): pairs (i, i + 1); result has shape (16, 1, 2).
A0 = A.reshape(-1, 2, 1).permute((0, 2, 1))

In [21]:
# Stage 1 (stride 2): pairs (i, i + 2) within each block of 4.
A1 = A.reshape(-1, 2, 2).permute((0, 2, 1))
A1

tensor([[[ 0,  2],
         [ 1,  3]],

        [[ 4,  6],
         [ 5,  7]],

        [[ 8, 10],
         [ 9, 11]],

        [[12, 14],
         [13, 15]],

        [[16, 18],
         [17, 19]],

        [[20, 22],
         [21, 23]],

        [[24, 26],
         [25, 27]],

        [[28, 30],
         [29, 31]]])

In [22]:
# Stage 2 (stride 4): each block of 8 is viewed as 2 x 4 and transposed,
# so every row is a butterfly pair (i, i + 4).
A2 = A.reshape(-1, 2, 4).permute((0, 2, 1))
A2

tensor([[[ 0,  4],
         [ 1,  5],
         [ 2,  6],
         [ 3,  7]],

        [[ 8, 12],
         [ 9, 13],
         [10, 14],
         [11, 15]],

        [[16, 20],
         [17, 21],
         [18, 22],
         [19, 23]],

        [[24, 28],
         [25, 29],
         [26, 30],
         [27, 31]]])

In [23]:
# Stage 3 (stride 8): pairs (i, i + 8) within each block of 16.
A3 = A.reshape(-1, 2, 8).permute((0, 2, 1))
A3

tensor([[[ 0,  8],
         [ 1,  9],
         [ 2, 10],
         [ 3, 11],
         [ 4, 12],
         [ 5, 13],
         [ 6, 14],
         [ 7, 15]],

        [[16, 24],
         [17, 25],
         [18, 26],
         [19, 27],
         [20, 28],
         [21, 29],
         [22, 30],
         [23, 31]]])

In [24]:
# Stage 4 (stride 16): last valid stage for N = 32, pairs (i, i + 16).
A4 = A.reshape(-1, 2, 16).permute((0, 2, 1))
A4

tensor([[[ 0, 16],
         [ 1, 17],
         [ 2, 18],
         [ 3, 19],
         [ 4, 20],
         [ 5, 21],
         [ 6, 22],
         [ 7, 23],
         [ 8, 24],
         [ 9, 25],
         [10, 26],
         [11, 27],
         [12, 28],
         [13, 29],
         [14, 30],
         [15, 31]]])

In [25]:
# Stride 32 would need blocks of 2 * 32 = 64 elements, more than N = 32, so the
# reshape raises: stride 16 (previous cell) is the final stage. Catch and show
# the error so Restart & Run All still completes; A5 is reset to None so no
# stale value from an earlier cell is displayed.
try:
    A5 = A.reshape(-1, 2, 32).permute(0, 2, 1)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")
    A5 = None
A5

RuntimeError: shape '[-1, 2, 32]' is invalid for input of size 32

## For N != power of 2

In [None]:
# N = 20 is not a power of two: check which reshape stages still apply.
A = torch.arange(20, dtype=torch.long)

In [None]:
# Show the 20-element input.
A

In [None]:
# Disabled on purpose: int(np.log2(20)) == 4 gives shapes [2, 2, 2, 2] (16 elements),
# so A.view(*shapes) would fail for the 20-element A.
# shapes = [2 for i in range(int(np.log2(len(A))))]

In [None]:
# Length of the 1-D input (20 here).
s = A.shape[0]
s

In [None]:
# Stage 1 (stride 2) still works: 20 is divisible by 2 * 2 = 4, giving shape (5, 2, 2).
A1 = A.reshape(-1, 2, 2).permute((0, 2, 1))
A1

In [None]:
# Stage 2 (stride 4) needs blocks of 2 * 4 = 8 elements, and 20 is not divisible
# by 8, so the reshape raises. Catch and show the error so Restart & Run All
# completes; reset A2 to None so the stale A2 from the N = 32 section is not shown.
try:
    A2 = A.reshape(-1, 2, 4).permute(0, 2, 1)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")
    A2 = None
A2

In [None]:
# Stage 3 (stride 8) needs blocks of 16 elements; 20 is not divisible by 16,
# so the reshape raises. Catch and show it, and reset A3 to avoid a stale value.
try:
    A3 = A.reshape(-1, 2, 8).permute(0, 2, 1)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")
    A3 = None
A3

In [None]:
# Stage 4 (stride 16) needs blocks of 32 elements, more than N = 20, so the
# reshape raises. Catch and show it, and reset A4 to avoid a stale value.
try:
    A4 = A.reshape(-1, 2, 16).permute(0, 2, 1)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")
    A4 = None
A4

In [None]:
# Stage 5 (stride 32) needs blocks of 64 elements, more than N = 20, so the
# reshape raises. Catch and show it, and reset A5 to avoid a stale value.
try:
    A5 = A.reshape(-1, 2, 32).permute(0, 2, 1)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")
    A5 = None
A5

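## Closed-form formula for the butterfly indices

Instead of reshaping, thread $t$ at layer $l$ computes its pair directly: $i_0 = (t \bmod 2^l) + \lfloor t / 2^l \rfloor \cdot 2^{l+1}$ and $i_1 = i_0 + 2^l$.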

In [None]:
# Source / inspiration for the thread-index formula below:
# https://github.com/roguh/cuda-fft/blob/main/main.cu

In [None]:
# One thread per butterfly: N = 16 points -> 8 thread indices 0..7.
ti = torch.arange(8, dtype=torch.long)
ti

In [None]:
# For each radix-2 layer, print the pair of element indices each thread in
# `ti` handles. The layer count is derived from the thread count
# (N = 2 * len(ti) points -> log2(N) layers; assumes len(ti) is a power of 2)
# instead of being hard-coded to 4.
n_layers = (2 * len(ti)).bit_length() - 1
for layer in range(n_layers):
    print(f"layer: {layer}")
    gap = 1 << layer  # distance between the two elements of a butterfly
    print(f"gap: {gap}")

    # offset of the thread inside its group of `gap` threads
    index = ti % gap
    print(f"index0: {index}")

    # start of the thread's block of 2 * gap elements
    pindex = (ti // gap) * (1 << (layer + 1))
    print(f"index1: {pindex}")

    print()
    print(index + pindex)        # first element of each pair
    print(index + pindex + gap)  # second element, `gap` further on

    print()

## For Radix-4 FFT

In [26]:
# Radix-4 FFT permutation: start again from N = 32 indices.
A = torch.arange(32, dtype=torch.long)

In [27]:
# Radix-4 stage 0 (stride 1): groups of 4 consecutive indices, shape (8, 1, 4).
A0 = A.reshape(-1, 4, 1).permute((0, 2, 1))
A0

tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]],

        [[12, 13, 14, 15]],

        [[16, 17, 18, 19]],

        [[20, 21, 22, 23]],

        [[24, 25, 26, 27]],

        [[28, 29, 30, 31]]])

In [28]:
# Radix-4 stage 1 (stride 4): each block of 16 is viewed as 4 x 4 and transposed,
# so each row is a radix-4 group (i, i + 4, i + 8, i + 12).
A1 = A.reshape(-1, 4, 4).permute((0, 2, 1))
A1

tensor([[[ 0,  4,  8, 12],
         [ 1,  5,  9, 13],
         [ 2,  6, 10, 14],
         [ 3,  7, 11, 15]],

        [[16, 20, 24, 28],
         [17, 21, 25, 29],
         [18, 22, 26, 30],
         [19, 23, 27, 31]]])

In [29]:
# Radix-4 stage 2 (stride 16) would need blocks of 4 * 16 = 64 elements, more
# than N = 32, so the reshape raises. Catch and show the error so Restart &
# Run All completes; reset A2 to None so a stale earlier A2 is not displayed.
try:
    A2 = A.reshape(-1, 4, 16).permute(0, 2, 1)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")
    A2 = None
A2

RuntimeError: shape '[-1, 4, 16]' is invalid for input of size 32