import pyfftw
import numpy as np
import scipy
import time
from numba import njit


def naive(T):
    """Baseline: build a brand-new pyfftw rfft object on every call.

    A fresh ``pyfftw.builders.rfft`` object is created for a scratch array
    of ``len(T)`` elements and then invoked on ``T``.  The transform result
    is discarded; the function always returns ``None``.

    Parameters
    ----------
    T : array-like
        Input signal; only its length and contents are used.
    """
    # The builder object is deliberately not cached or reused here.
    plan = pyfftw.builders.rfft(np.empty(len(T)))
    plan(T)
            
def performant(T, rfft_dict):
    """Run a pre-built rfft object selected by the length of ``T``.

    Parameters
    ----------
    T : array-like
        Input signal, passed directly to the selected object.
    rfft_dict : dict
        Maps an input length to a callable that exposes an
        ``output_array`` attribute (pyfftw FFT objects in this file).

    Returns
    -------
    The ``output_array`` attribute of the selected object, read after the
    call.  NOTE(review): this is the object's own attribute, not a copy —
    presumably a later call with the same length overwrites it; confirm
    before holding on to the result.

    Raises
    ------
    KeyError
        If ``rfft_dict`` has no entry for ``len(T)``.
    """
    plan = rfft_dict[len(T)]
    plan(T)
    return plan.output_array


def performant_optimized(T, rfft_dict, input_arr):
    """Copy ``T`` into a shared buffer and execute the matching rfft object.

    Parameters
    ----------
    T : array-like
        Input signal of length ``n``.
    rfft_dict : dict
        Maps an input length to an object providing ``execute()`` and an
        ``output_array`` attribute.
    input_arr : mutable sequence / array
        Buffer whose first ``n`` elements are overwritten with ``T``.
        Assumes each object in ``rfft_dict`` was built over the leading
        slice of this same buffer (as the timing driver in this file
        does); otherwise ``execute()`` would not see ``T``.

    Returns
    -------
    The ``output_array`` attribute of the selected object after
    ``execute()`` has run.

    Raises
    ------
    KeyError
        If ``rfft_dict`` has no entry for ``len(T)``.  The buffer has
        already been written at that point.
    """
    size = len(T)
    # Stage the data first; the lookup happens afterwards, as before.
    input_arr[:size] = T
    plan = rfft_dict[size]
    plan.execute()
    return plan.output_array



def pyfftw_timing_func(p_min=2, p_max=23, n_iter=500):
    """Time ``performant_optimized`` on power-of-two input lengths.

    For every exponent ``p`` in ``[p_min, p_max]`` one single-threaded
    ``pyfftw.builders.rfft`` object of length ``2 ** p`` is built over the
    leading slice of a shared aligned buffer and stored by length.  Each
    length is then timed ``n_iter`` times on fresh random data; only the
    ``performant_optimized`` call (copy into the buffer plus execute) is
    inside the timed window, not the random-data generation.

    Parameters
    ----------
    p_min : int, optional
        Smallest exponent; the shortest input has ``2 ** p_min`` samples.
    p_max : int, optional
        Largest exponent; the shared buffer holds ``2 ** p_max`` samples.
    n_iter : int, optional
        Number of timed repetitions per length.

    Returns
    -------
    numpy.ndarray
        ``float64`` array of shape ``(p_max - p_min + 1, n_iter)``.
        ``timing[i, j]`` is the ``time.perf_counter`` difference (seconds)
        of repetition ``j`` for length ``2 ** (p_min + i)``.

    Raises
    ------
    ValueError
        If ``p_min`` is negative or ``p_max`` is smaller than ``p_min``.
    """
    if p_min < 0 or p_max < p_min:
        raise ValueError(
            f"expected 0 <= p_min <= p_max, got p_min={p_min}, p_max={p_max}"
        )

    sizes = [2 ** p for p in range(p_min, p_max + 1)]

    # One aligned buffer sized for the longest transform; every rfft object
    # below is built over a leading slice of it.
    input_arr = pyfftw.empty_aligned(2 ** p_max, dtype=np.float64)

    rfft_dict = {}
    for n in sizes:
        rfft_obj = pyfftw.builders.rfft(
            input_arr[:n],
            overwrite_input=True,
            avoid_copy=True,
            n=n,
            threads=1,
        )
        # Run once up front (output discarded) so that the first timed
        # repetition is not the object's very first execution.
        rfft_obj.execute()
        rfft_dict[n] = rfft_obj

    # -1.0 marks cells that were never measured; every cell is filled below.
    timing = np.full((len(sizes), n_iter), -1.0, dtype=np.float64)

    # Repetitions form the outer loop so the lengths are interleaved rather
    # than measured back to back.
    for j in range(n_iter):
        for i, n in enumerate(sizes):
            # np.random.rand already yields float64; no cast/copy needed.
            T = np.random.rand(n)

            start = time.perf_counter()
            performant_optimized(T, rfft_dict, input_arr)
            stop = time.perf_counter()

            timing[i, j] = stop - start

    return timing



# Report the library versions the measurements were taken with.
print(pyfftw.__version__, np.__version__, sep="\n")

# Run the benchmark and store the raw per-repetition timings on disk.
timing = pyfftw_timing_func()
np.save('pyfftw_performant_optimized.npy', timing)
print('Done!')