In [38]:
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.stats import bernoulli
from IO.config import parse_config
from core import signal_generator
from core.freq_transform import transform_all

In [39]:
# Load the experiment configuration; the cells below override selected fields of it.
# NOTE(review): relative path — presumably resolved against the kernel's working directory.
config_path = '../yaml/example.yaml'
usr_config = parse_config(config_path)

In [None]:
# --- Sweep parameters --------------------------------------------------------
N = 16              # block size; also the number of bins stored per method in `ret`
L = 20              # written to signal.num_blocks_avg (and hop_size = N * L below)
phi = np.pi / 4     # phase of the single test tone (written to signal.phases)
num_freqs = 2001    # number of points in the frequency sweep
# Optional reproducibility: set to an int (e.g. 55635) to reseed numpy's global RNG
# once per test frequency. None keeps the original unseeded behaviour.
seed = None

# Transform methods compared; row i of `ret` holds the response of methods[i].
methods = ('fft', 'fht', 'fht_jitter', 'fht_ditter')

usr_config.signal.block_size = N
usr_config.signal.num_blocks_avg = L
usr_config.signal.hop_size = N * L
usr_config.signal.num_pos_decision = 2

usr_config.noise.init_args.top = 0.1
usr_config.noise.init_args.steady_state = 0.1
npos = usr_config.signal.num_pos_decision
fs = usr_config.signal.fs

# Sweep the tone over fractional bin positions k in [0, N], i.e. frequencies in [0, fs].
test_k = np.linspace(0, N, num_freqs)
test_f = test_k / N * fs

# ret[method, bin, frequency]. Sized from `methods` and `test_f` so the
# array shape cannot drift out of sync with the sweep definition.
ret = np.zeros((len(methods), N, len(test_f)))

for idx_f, f in enumerate(test_f):
    if seed is not None:
        # NOTE(review): assumes the signal generator draws its noise from numpy's
        # global RNG — confirm in core.signal_generator.
        np.random.seed(seed)
    for idx_m, method in enumerate(methods):
        usr_config.freq_transform_method.name = method
        usr_config.signal.phases = [phi]
        usr_config.signal.freqs = [f]

        input_signal_generator = signal_generator.InputSignalGenerator(usr_config.signal, usr_config.noise)
        input_signal, _ = input_signal_generator.get()
        # Discard the first `npos` entries along axis 0
        # (presumably the warm-up decision positions — confirm).
        input_signal = input_signal[npos:, :, :]
        sqm, _ = transform_all(input_signal, usr_config.freq_transform_method, usr_config.signal)
        # Average over axis 1, then keep the first entry along axis 0 as the per-bin response.
        # NOTE(review): assumes sqm[0] has length N after the mean — confirm against transform_all.
        sqm = sqm.mean(1)

        ret[idx_m, :, idx_f] = sqm[0]

In [None]:
# Bins to plot: every bin except k = 0 (DC) and k = N // 2 (Nyquist when N is even).
bins = np.array([k for k in range(1, N) if k != N // 2])

# Only the first half of the frequency sweep is plotted (below fs / 2).
half = len(test_f) // 2
f_half = test_f[:half]

# (row of `ret`, legend label), kept in the original plotting order so colours match.
curves = [(0, 'DFT'), (2, 'J-DHT'), (1, 'DHT'), (3, 'D-DHT')]

# Output directory is loop-invariant; exist_ok avoids the check-then-create race.
dirname = '../plots/jitter_response/phi_{}/N_{}'.format(phi, N)
os.makedirs(dirname, exist_ok=True)

for k in bins:
    fig, ax = plt.subplots(figsize=(10, 5))
    for row, label in curves:
        ax.plot(f_half, ret[row, k, :half], label=label)
    ax.legend(fontsize=15)
    ax.grid()
    ax.set_title('k={}, N={}'.format(k, N), fontsize=15)
    ax.tick_params(labelsize=15)
    ax.set_xlabel('$f_0$(Hz)', fontsize=15)
    ax.set_ylabel('Normalized Squared Magnitude', fontsize=15)

    fig.savefig(os.path.join(dirname, 'k_{}.png'.format(k)))
    # Display inline now, then release the figure so open figures do not
    # accumulate (and leak memory) across loop iterations.
    plt.show()
    plt.close(fig)