import pyfftw
import numpy as np
import scipy
import time
from numba import njit


def naive(F):
    """Inverse real FFT of ``F`` using a freshly built pyfftw object.

    A new ``pyfftw.builders.irfft`` object is created on every call from an
    uninitialised complex128 array with the same length as ``F``; the object
    is then called on ``F``.

    Parameters
    ----------
    F : sequence of complex
        Spectrum to transform. Only its length is used to size the builder's
        template array.

    Returns
    -------
    The ``output_array`` attribute of the object built in this call.
    """
    # The template array only fixes length and dtype; its contents are never
    # read here because the object is invoked with F as its argument.
    template = np.empty(len(F), dtype=np.complex128)
    plan = pyfftw.builders.irfft(template)
    plan(F)
    return plan.output_array
            
def performant(F, irfft_dict):
    """Inverse real FFT of ``F`` using a prebuilt object from ``irfft_dict``.

    Parameters
    ----------
    F : sequence of complex
        Spectrum to transform; ``len(F)`` selects the object to use.
    irfft_dict : mapping
        Maps a spectrum length to a callable object exposing an
        ``output_array`` attribute.

    Returns
    -------
    The ``output_array`` attribute of the selected object, read after the
    object has been called with ``F``.

    Raises
    ------
    KeyError
        If ``irfft_dict`` has no entry for ``len(F)``.
    """
    plan = irfft_dict[len(F)]
    plan(F)
    # NOTE(review): the returned array belongs to the cached object, so a
    # later call with the same length presumably overwrites it - copy the
    # result if it must outlive the next call.
    return plan.output_array


def performant_optimized(F, irfft_dict, input_arr):
    """Inverse real FFT of ``F`` through a shared, pre-allocated input buffer.

    ``F`` is scaled by ``1 / (2 * (len(F) - 1))`` and written into the first
    ``len(F)`` elements of ``input_arr``; the object stored under ``len(F)``
    in ``irfft_dict`` is then run with ``execute()``.

    Parameters
    ----------
    F : array-like of complex
        Spectrum to transform.
    irfft_dict : mapping
        Maps a spectrum length to an object exposing ``execute()`` and
        ``output_array``.
    input_arr : numpy.ndarray
        Buffer that receives the scaled spectrum. ``F`` is never handed to
        the object directly, so each object is expected to have been built
        with ``input_arr[:len(F)]`` as its input array.

    Returns
    -------
    The ``output_array`` attribute of the selected object.
    """
    size = len(F)
    # Pre-scale while copying into the buffer.
    # NOTE(review): presumably this replaces the normalisation that calling
    # the object would apply and that execute() skips - confirm against the
    # pyfftw documentation.
    scale = 1.0 / (2 * (size - 1))
    np.multiply(F, scale, out=input_arr[:size])
    plan = irfft_dict[size]
    plan.execute()
    return plan.output_array



@njit(fastmath=True)
def update_arr_inplace(out, a, scale=1.0):
    """Store ``a[k] * scale`` in ``out[k]`` for every index ``k`` of ``out``.

    Parameters
    ----------
    out : array
        Destination; every element is overwritten.
    a : array
        Source; only its first ``len(out)`` elements are read.
    scale : float, optional
        Multiplier applied to each element (default ``1.0``).

    Returns
    -------
    None
        The result is written into ``out``.
    """
    n_items = len(out)
    for k in range(n_items):
        out[k] = a[k] * scale


def performant_optimized_njit(F, irfft_dict, input_arr):
    """Variant of ``performant_optimized`` that fills the buffer via numba.

    The scaled copy of ``F`` into ``input_arr[:len(F)]`` is done by
    ``update_arr_inplace`` instead of ``np.multiply``; the object stored
    under ``len(F)`` in ``irfft_dict`` is then run with ``execute()``.

    Parameters
    ----------
    F : array of complex
        Spectrum to transform.
    irfft_dict : mapping
        Maps a spectrum length to an object exposing ``execute()`` and
        ``output_array``.
    input_arr : numpy.ndarray
        Buffer that receives the scaled spectrum. ``F`` is never handed to
        the object directly, so each object is expected to have been built
        with ``input_arr[:len(F)]`` as its input array.

    Returns
    -------
    The ``output_array`` attribute of the selected object.
    """
    size = len(F)
    # Same factor as performant_optimized: 1 / (2 * (size - 1)).
    factor = 1.0 / (2 * (size - 1))
    update_arr_inplace(input_arr[:size], F, scale=factor)
    plan = irfft_dict[size]
    plan.execute()
    return plan.output_array


def pyfftw_timing_func(p_min=2, p_max=20, n_iter=500, method="naive"):
    """Time one irfft variant over power-of-two signal lengths.

    For every exponent ``p`` in ``[p_min, p_max]`` a random real signal of
    length ``2 ** p`` is generated and transformed with ``np.fft.rfft``; the
    selected inverse-transform function is then timed on that spectrum. Only
    the inverse-transform call sits between the two ``time.perf_counter``
    reads, so input generation is excluded from the measurement.

    Called without arguments, this benchmarks ``naive`` for exponents 2..20
    with 500 repetitions.

    Parameters
    ----------
    p_min, p_max : int, optional
        Inclusive range of exponents; the signal lengths are ``2 ** p``.
        The optimized variants divide by ``2 * (len(F) - 1)``, so they need
        ``p_min >= 1``.
    n_iter : int, optional
        Number of timed repetitions per signal length.
    method : str, optional
        Which function to benchmark: ``"naive"``, ``"performant"``,
        ``"performant_optimized"`` or ``"performant_optimized_njit"``.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(p_max - p_min + 1, n_iter)``;
        ``timing[i, j]`` is the elapsed time in seconds for exponent
        ``p_min + i`` in repetition ``j``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the names listed above.
    """
    benchmarks = {
        "naive": naive,
        "performant": performant,
        "performant_optimized": performant_optimized,
        "performant_optimized_njit": performant_optimized_njit,
    }
    if method not in benchmarks:
        raise ValueError(
            f"unknown method {method!r}; expected one of {sorted(benchmarks)}"
        )
    func = benchmarks[method]
    p_range = range(p_min, p_max + 1)

    # Build whatever the chosen variant needs besides F. `naive` plans inside
    # the timed call, so it needs nothing.
    extra_args = ()
    if method == "performant":
        # One object per spectrum length, each with its own input array.
        irfft_dict = {}
        for p in p_range:
            F = np.fft.rfft(np.random.rand(2 ** p))
            n = len(F)
            irfft_obj = pyfftw.builders.irfft(
                np.empty(n, dtype=np.complex128),
                overwrite_input=True,
                n=2 * (n - 1),
                threads=1,
            )
            irfft_obj(F)  # run once on real data before any timing
            irfft_dict[n] = irfft_obj
        extra_args = (irfft_dict,)
    elif method != "naive":
        # Optimized variants: every object is built on a leading slice of one
        # shared aligned buffer, which the benchmarked function fills itself.
        irfft_dict = {}
        input_arr = pyfftw.empty_aligned(2 ** p_max, dtype=np.complex128)
        for p in p_range:
            F = np.fft.rfft(np.random.rand(2 ** p))
            n = len(F)
            irfft_obj = pyfftw.builders.irfft(
                input_arr[:n],
                overwrite_input=True,
                avoid_copy=True,
                n=2 * (n - 1),
                threads=1,
            )
            # Load the data only after the object has been built.
            input_arr[:n] = F
            irfft_obj.execute()  # run once on real data before any timing
            irfft_dict[n] = irfft_obj
        extra_args = (irfft_dict, input_arr)
        # NOTE(review): this setup never calls update_arr_inplace, so for
        # "performant_optimized_njit" the first timed sample presumably
        # includes numba's compilation - discard it when analysing.

    # -1.0 marks slots that have not been measured.
    timing = np.full((len(p_range), n_iter), -1.0, dtype=np.float64)

    for j in range(n_iter):
        for i, p in enumerate(p_range):
            # Fresh random input for every sample, generated outside the
            # timed region.
            F = np.fft.rfft(np.random.rand(2 ** p))

            start = time.perf_counter()
            func(F, *extra_args)
            stop = time.perf_counter()

            timing[i, j] = stop - start

    return timing



# Report the library versions in use: pyfftw first, then numpy, one per line.
print(pyfftw.__version__, np.__version__, sep='\n')

# Run the benchmark with its defaults; the result stays available in the
# module-level name `timing`.
timing = pyfftw_timing_func()
# np.save('pyfftw_X_irfft.npy', timing)
print('Done!')