## Run this code in SageMath: the FFT class uses Sage globals (`exp`, `I`, `pi`, `factor`, `zero_matrix`, `ZZ`).

In [1]:
class FFT:
    """Exact cyclotomic FFT of length n = 2^a * 3^b (a >= 1), evaluated in SageMath.

    All twiddle factors are powers of ``w = exp(-2*pi*i/(3n))``, a primitive
    3n-th root of unity.  ``fft`` runs one split layer on ``x^(n/2)`` (roots
    ``z`` and ``z^5`` with ``z = w^(n/2)``; these satisfy ``z + z^5 = 1``),
    then ``Radix3`` radix-3 layers, then ``Radix2 - 1`` radix-2 layers, all
    in place with no reordering pass.  ``invfft`` undoes the layers in
    reverse order and applies the overall scaling in its last layer.

    NOTE(review): the first layer factors x^n - x^(n/2) + 1 as
    (x^(n/2) - z)(x^(n/2) - z^5), which suggests the transform works modulo
    that polynomial -- confirm with the author.

    Needs SageMath globals: ``exp``, ``I``, ``pi``, ``factor``,
    ``zero_matrix`` and ``ZZ``.
    """

    def generator(self, n):
        """Return ``exp(-2*pi*i/(3n))``, a primitive 3n-th root of unity (symbolic)."""
        # Author's alternative, kept for reference (Sage's CyclotomicField
        # generator is presumably embedded as exp(+2*pi*i/(3n)) -- verify the
        # sign before switching):
        # return CyclotomicField(3*n).gen()
        return exp(-2*I*pi / (3*n))

    def __init__(self, n):
        """Validate ``n`` and precompute every constant the transforms use.

        Args:
            n: transform length; must be 2^a * 3^b with a >= 1.

        Raises:
            ValueError: for any other ``n`` (including n < 2 and lengths with
                a prime factor other than 2 or 3).
        """
        self.n = n

        # Prime exponents of n, e.g. 36 -> {2: 2, 3: 2}.  factor() is skipped
        # for n < 2: factor(0) is undefined and a negative n only adds a unit.
        exponents = {p: e for p, e in factor(n)} if n >= 2 else {}
        # Every unsupported n must be rejected here; otherwise Radix2/Radix3
        # would stay unset and construction would fail later, obscurely.
        if 2 not in exponents or not set(exponents) <= {2, 3}:
            raise ValueError(
                f"unsupported n: {n} (expected n = 2^a * 3^b with a >= 1)")
        self.Radix2 = exponents[2]
        self.Radix3 = exponents.get(3, 0)

        self.level = self.Radix2 + self.Radix3

        self.w = self.generator(n)
        # w^n = exp(-2*pi*i/3): primitive cube root of unity for radix-3 layers.
        self.omega = self.w ** self.n

        self._build_tree()

        # z = w^(n/2) = exp(-pi*i/3), a primitive 6th root of unity; its
        # conjugate z^5 = 1 - z is the other root of the first split layer.
        self.z = self.w ** (self.n // 2)
        self.zminusz5inv = 1 / (self.z - self.z**5)
        # Final inverse scaling factors.  ZZ(...) keeps them exact rationals
        # even when n is a plain Python int (no Sage preparser involved).
        self.level1 = ZZ(1) / self.n
        self.level2 = ZZ(2) / self.n

    def _build_tree(self):
        """Fill ``self.tree`` with root exponents and ``self.zetas`` with twiddles.

        ``self.tree[l, j]`` holds the exponent of ``w`` for node ``j`` of
        layer ``l`` (``tree[0, 0] = 3n``).  ``self.zetas`` lists the twiddles
        in the exact order ``fft`` consumes them (``invfft`` walks it
        backwards); it ends up with n entries, the first being ``w^0``.
        """
        self.tree = zero_matrix(ZZ, self.level + 1, self.n)
        self.tree[0, 0] = 3 * self.n

        self.zetas = [self.w**0]

        # First split layer: exponents 3n/6 and 5 * 3n/6.
        self.tree[1, 0] = self.tree[0, 0] // 6
        self.tree[1, 1] = 5 * self.tree[0, 0] // 6
        self.zetas.append(self.w ** self.tree[1, 0])

        # Radix-3 layers: node exponent e splits into e/3, e/3 + n, e/3 + 2n;
        # the butterfly needs w^(e/3) and its square.
        for ll in range(1, self.Radix3 + 1):
            for ii in range(2 * 3**(ll - 1)):
                self.tree[ll+1, 3*ii] = self.tree[ll, ii] // 3
                self.tree[ll+1, 3*ii+1] = self.tree[ll+1, 3*ii] + self.tree[0, 0] // 3
                self.tree[ll+1, 3*ii+2] = self.tree[ll+1, 3*ii] + 2 * self.tree[0, 0] // 3

                self.zetas.append(self.w ** self.tree[ll+1, 3*ii])
                self.zetas.append(self.w ** (2 * self.tree[ll+1, 3*ii]))

        # Radix-2 layers: node exponent e splits into e/2 and e/2 + 3n/2.
        for ll in range(self.Radix3 + 1, self.level):
            for ii in range(2 * 3**self.Radix3 * 2**(ll - (self.Radix3 + 1))):
                self.tree[ll+1, 2*ii] = self.tree[ll, ii] // 2
                self.tree[ll+1, 2*ii+1] = self.tree[ll, ii] // 2 + self.tree[0, 0] // 2

                self.zetas.append(self.w ** self.tree[ll+1, 2*ii])

    def get_zetas(self):
        """Return the twiddle list built by ``_build_tree`` (the live list, not a copy)."""
        return self.zetas

    def fft(self, a):
        """Return the forward transform of ``a`` as a new sequence.

        Args:
            a: n input values supporting ``.copy()`` (e.g. a list); not modified.

        Returns:
            The n transformed values, left in the in-place (tree) order the
            layers below produce.

        Raises:
            ValueError: if ``len(a) != n``.
        """
        if len(a) != self.n:
            raise ValueError(f"expected {self.n} values, got {len(a)}")

        b = a.copy()
        half = self.n // 2
        k = 1  # next unused entry of self.zetas (entry 0 is w^0)

        # Layer 1: reduce modulo x^(n/2) - z and x^(n/2) - z^5; the second
        # uses z^5 = 1 - z, so one multiplication serves both outputs.
        zeta1 = self.zetas[k]
        k += 1
        for ii in range(half):
            t1 = zeta1 * b[ii + half]
            b[ii + half] = b[ii] + b[ii + half] - t1
            b[ii] = b[ii] + t1

        # Radix-3 layers: strides n/6, n/18, ..., 2^(Radix2-1)
        # (none when Radix3 == 0).
        step = self.n // 6
        while step >= 2**(self.Radix2 - 1):
            for start in range(0, self.n, 3 * step):
                zeta1 = self.zetas[k]
                zeta2 = self.zetas[k + 1]
                k += 2
                for ii in range(start, start + step):
                    t1 = zeta1 * b[ii + step]
                    t2 = zeta2 * b[ii + 2*step]
                    t3 = self.omega * (t1 - t2)
                    b[ii + 2*step] = b[ii] - t1 - t3
                    b[ii + step] = b[ii] - t2 + t3
                    b[ii] = b[ii] + t1 + t2
            step //= 3

        # Radix-2 layers: strides 2^(Radix2-2), ..., 1 (none when Radix2 == 1).
        # Uses ** rather than ^, which is XOR outside the Sage preparser.
        step = 2**(self.Radix2 - 2) if self.Radix2 >= 2 else 0
        while step >= 1:
            for start in range(0, self.n, 2 * step):
                zeta1 = self.zetas[k]
                k += 1
                for ii in range(start, start + step):
                    t1 = zeta1 * b[ii + step]
                    b[ii + step] = b[ii] - t1
                    b[ii] = b[ii] + t1
            step //= 2

        return b

    def invfft(self, a):
        """Return the inverse transform of ``a`` (values as produced by ``fft``).

        Undoes the layers of ``fft`` in reverse order while walking
        ``self.zetas`` backwards.  The per-layer factors of 2 and 3 are not
        divided out along the way; the last layer applies them all through
        ``level1`` (1/n) and ``level2`` (2/n).

        Raises:
            ValueError: if ``len(a) != n``.
        """
        if len(a) != self.n:
            raise ValueError(f"expected {self.n} values, got {len(a)}")

        b = a.copy()
        half = self.n // 2
        k = self.n - 1  # last entry of self.zetas

        # Inverse radix-2 layers: strides 1, 2, ..., 2^(Radix2-2).
        step = 1
        while step < 2**(self.Radix2 - 1):
            for start in range(0, self.n, 2 * step):
                zeta1 = self.zetas[k]
                k -= 1
                for ii in range(start, start + step):
                    t1 = b[ii + step]
                    b[ii + step] = (t1 - b[ii]) * zeta1
                    b[ii] = t1 + b[ii]
            step *= 2

        # Inverse radix-3 layers: strides 2^(Radix2-1), ..., n/6.
        while step <= self.n // 6:
            for start in range(0, self.n, 3 * step):
                zeta2 = self.zetas[k]
                zeta1 = self.zetas[k - 1]
                k -= 2
                for ii in range(start, start + step):
                    t1 = self.omega * (b[ii + step] - b[ii])
                    t2 = zeta1 * (b[ii + 2*step] - b[ii] + t1)
                    t3 = zeta2 * (b[ii + 2*step] - b[ii + step] - t1)
                    b[ii] = b[ii] + b[ii + step] + b[ii + 2*step]
                    b[ii + step] = t2
                    b[ii + 2*step] = t3
            step *= 3

        # Inverse of layer 1: recover the low/high halves from the values
        # modulo x^(n/2) - z and x^(n/2) - z^5, applying the overall scaling.
        for ii in range(half):
            t1 = b[ii] + b[ii + half]
            t2 = self.zminusz5inv * (b[ii] - b[ii + half])
            b[ii] = self.level1 * (t1 - t2)
            b[ii + half] = self.level2 * t2

        return b



In [25]:
n = 2**2 * 3**2  # = 36; FFT supports lengths n = 2^a * 3^b with a >= 1

In [26]:
# Precompute the twiddle table for length n (construction also validates n).
fft = FFT(n=n)

In [28]:
# Random test input: n Gaussian integers; each entry draws its real part and
# then its imaginary part with ZZ.random_element(-10, 10).
# Fixed 6-entry alternative (only fits n = 6):
# a = [1 + 4*I, -2-3*I,  5 + I,  2 + 5*I, -1 + 5*I,  1 + 2*I]
a = [
    re_part + im_part * I
    for re_part, im_part in (
        (ZZ.random_element(-10, 10), ZZ.random_element(-10, 10))
        for _ in range(n)
    )
]
a[:5]

[I - 4, -2*I + 8, 5*I + 2, 5*I - 3, 7*I + 7]

In [30]:
# Forward transform of the test input (a itself is left unchanged).
b = fft.fft(a=a)

In [36]:
# Show the first five transformed values at 50 bits of precision.
# Iterating over the values directly also avoids rebinding `i`, which Sage
# predefines as the imaginary unit.
for value in b[:5]:
    print(value.n(prec=50))

-15.560225576235 - 52.795554650809*I
-0.83198514711611 + 34.111459286328*I
-48.587730720853 + 24.271432803538*I
-33.283953462789 + 9.5235283995068*I
-25.816848428923 - 17.326376626200*I


In [37]:
# Inverse transform; should reproduce the input a.
c = fft.invfft(a=b)

In [38]:
# c = [x.n(20) for x in c]

In [40]:
# Show the first five recovered values at 50 bits of precision.
# Iterating over the values directly also avoids rebinding `i`, which Sage
# predefines as the imaginary unit.
for value in c[:5]:
    print(value.n(prec=50))

-4.0000000000000 + 1.0000000000000*I
8.0000000000000 - 2.0000000000000*I
2.0000000000000 + 5.0000000000000*I
-3.0000000000000 + 5.0000000000000*I
7.0000000000000 + 7.0000000000000*I


In [43]:
# Mean squared round-trip error between input and recovered values, with both
# sides rounded to 20 bits.  No module-level loop variable, so Sage's
# predefined imaginary unit `i` is not rebound.
errs = sum(
    abs(orig.n(prec=20) - back.n(prec=20))**2 for orig, back in zip(a, c)
) / n
errs

1.2863e-9

In [10]:
n = 2**2 * 3  # = 12; FFT supports lengths n = 2^a * 3^b with a >= 1

In [11]:
# Rebuild the transform for the new length n.
fft = FFT(n=n)

In [12]:
# Fixed 12-entry test input, written as (real, imaginary) pairs.
a = [
    re_part + im_part * I
    for re_part, im_part in [
        (1, -1), (-2, -2), (5, 2), (2, 2), (-1, -3), (1, 0),
        (4, -1), (-3, -4), (1, 2), (5, 0), (5, -4), (2, -1),
    ]
]

In [13]:
# Forward transform of the fixed input (a itself is left unchanged).
b = fft.fft(a=a)

In [14]:
# Show every transformed value at 100 bits of precision.
# Iterating over the values directly also avoids rebinding `i`, which Sage
# predefines as the imaginary unit.
for value in b:
    print(value.n(prec=100))

-3.1927747667350385707935270026 - 18.059800857368096504325053802*I
7.5084997625864529933622210201 - 4.5744396665045885487586764590*I
1.8859607012163976486668818778 - 6.4705661033857159295402334236*I
11.739866201100431674178616486 - 23.133265297779972672569458405*I
6.0509847769879228369060800288 + 13.496913732441437296312173183*I
-11.188689097862798462902611435 + 8.9565485017704088365518928091*I
-5.3641328528438424912265732808 - 9.7645507422521037537513163157*I
10.530933013620096113166926852 + 9.5832811709911882835802522511*I
-9.7953754357861855142017310331 + 5.7263435401927713912957073968*I
7.1891488929966082367672615882 - 12.887268714843872474324378949*I
7.5769077484939183141681421803 + 5.3479718798026081086580342897*I
13.058671056226037221908312718 + 13.778832556935935966871057426*I


In [15]:
# Inverse transform; should reproduce the fixed input a.
c = fft.invfft(a=b)

In [16]:
# Round the recovered values to 20 bits for a compact comparison with a.
c = [value.n(prec=20) for value in c]

In [17]:
c  # round-trip result; compare entry by entry with the input a above

[1.0000 - 1.0000*I,
 -2.0000 - 2.0000*I,
 5.0000 + 2.0000*I,
 2.0000 + 2.0000*I,
 -1.0000 - 3.0000*I,
 1.0000 + 3.8147e-6*I,
 4.0000 - 1.0000*I,
 -3.0000 - 4.0000*I,
 1.0000 + 2.0000*I,
 5.0000 - 8.2591e-6*I,
 5.0000 - 4.0000*I,
 2.0000 - 1.0000*I]