In [None]:
import numpy as np
import mayfly as mf
import h5py
import pandas as pf
import scipy
import matplotlib.pyplot as plt
import seaborn as sns
import os 
import sys
import json
import scipy.signal
import scipy.stats
import scipy.interpolate
import pickle as pkl

# Project root on the cluster filesystem; the result, plot and simulation-data
# directories below are all derived from it.
PATH = '/storage/home/adz6/group/project/'
RESULTPATH, PLOTPATH, SIMDATAPATH = (
    os.path.join(PATH, subdir) for subdir in ('results/mayfly', 'plots/mayfly', 'sim_data')
)

def f_c(time, fax, f0, deltaf):
    """Toy carrier frequency with cos^2 frequency modulation.

    Returns ``f0 + deltaf * cos(2*pi*time*fax)**2`` plus a constant offset equal to
    five times the mean of that modulation term taken over ``time``.

    NOTE(review): because the offset is a mean over the whole ``time`` argument, the
    value at a given instant depends on the time grid passed in; confirm intended.
    """
    modulation = deltaf * np.cos(2*np.pi*time*fax)**2
    offset = 5 * np.mean(modulation)
    return f0 + modulation + offset

def A(time, fax, a0, deltaa):
    """Toy amplitude envelope: ``a0`` minus a cos^2 dip of depth ``deltaa`` at rate ``fax``."""
    dip = deltaa * np.cos(2*np.pi*time*fax)**2
    return a0 - dip
    
def TemplateFrequenciesFromSpectrum(f_start, fspace, path2tritium, hz_per_ev=49.1e3, bw_ev=100, bins_tritium=201, power=0.7, min_rel_prob=None):
    """Place template frequencies across a band, packed densest where the tritium spectrum peaks.

    The tritium samples are histogrammed into ``bins_tritium`` bins and normalized so the
    most populated bin has relative probability 1. That profile is stretched linearly over
    ``[f_start, f_start + abs(hz_per_ev) * bw_ev]``. Starting at ``f_start``, each next
    frequency is placed ``fspace / p(f) ** power`` above the previous one, where ``p(f)`` is
    the linearly interpolated relative probability. The spacing is therefore ``fspace`` at
    the spectrum peak and widens where samples are rarer.

    Parameters
    ----------
    f_start : float
        First (lowest) template frequency.
    fspace : float
        Spacing used where the relative probability is 1. Must be positive.
    path2tritium : str or path-like
        Pickle file holding the tritium samples (anything ``np.histogram`` accepts).
    hz_per_ev, bw_ev : float
        The band width is ``abs(hz_per_ev) * bw_ev``.
    bins_tritium : int
        Number of histogram bins.
    power : float
        Exponent applied to the relative probability when computing the spacing.
    min_rel_prob : float, optional
        Lower bound applied to ``p(f)`` before computing the spacing. It defaults to the
        smallest non-zero relative bin probability. As a result, empty histogram bins widen
        the spacing instead of producing an infinite step that ends the bank early.

    Returns
    -------
    numpy.ndarray
        1-D float array of increasing frequencies, all in ``[f_start, f_start + band width)``.

    Raises
    ------
    ValueError
        If ``fspace`` or ``min_rel_prob`` is not positive, or the histogram is empty.
    """
    # A non-positive step would never advance past f_max (infinite loop).
    if fspace <= 0:
        raise ValueError(f'fspace must be positive, got {fspace}')

    # NOTE(review): pickle.load can execute arbitrary code -- only point this at trusted files.
    with open(path2tritium, 'rb') as infile:
        tritium_data = pkl.load(infile)

    vals, _ = np.histogram(tritium_data, bins_tritium)
    if vals.sum() == 0:
        raise ValueError('tritium spectrum histogram is empty')

    bin_prob = vals / vals.sum()
    rel_bin_prob = bin_prob / bin_prob.max()

    if min_rel_prob is None:
        # Between two populated bins the interpolant never drops below this value, so the
        # floor only takes effect next to empty bins.
        min_rel_prob = rel_bin_prob[rel_bin_prob > 0].min()
    if min_rel_prob <= 0:
        raise ValueError(f'min_rel_prob must be positive, got {min_rel_prob}')

    bw = abs(hz_per_ev) * bw_ev
    f_max = f_start + bw
    interpolation_frequencies = np.linspace(f_start, f_max, rel_bin_prob.size)
    rel_bin_prob_interp = scipy.interpolate.interp1d(interpolation_frequencies, rel_bin_prob, fill_value='extrapolate')

    f_last = f_start
    frequency_list = []
    while f_last < f_max:
        frequency_list.append(f_last)
        # Without the floor, p == 0 gave an infinite step and silently truncated the bank.
        rel_prob = max(float(rel_bin_prob_interp(f_last)), min_rel_prob)
        f_last += fspace / rel_prob ** power

    return np.array(frequency_list, dtype=float)

def RNGFrequenciesFromSpectrum(f_start, f_end, N, path2tritium, seed=None):
    """Draw N random frequencies distributed like the stored tritium samples.

    Samples are drawn uniformly with replacement from the tritium data. They are then
    min-max rescaled so that the smallest sample maps to ``f_start`` and the largest
    to ``f_end``.

    Parameters
    ----------
    f_start, f_end : float
        Frequencies assigned to the minimum and maximum tritium sample.
    N : int
        Number of frequencies to draw.
    path2tritium : str or path-like
        Pickle file holding the tritium samples (array-like of numbers).
    seed : None, int or numpy.random.Generator, optional
        Passed to ``np.random.default_rng``; fix it for reproducible draws. ``None``
        (default) keeps the previous behavior of fresh entropy on every call.

    Returns
    -------
    numpy.ndarray
        Shape ``(N,)`` array of frequencies in ``[f_start, f_end]``.

    Raises
    ------
    ValueError
        If the samples are all identical (the rescaling would divide by zero).
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only point this at trusted files.
    with open(path2tritium, 'rb') as infile:
        tritium_data = np.asarray(pkl.load(infile)).ravel()

    data_min = np.min(tritium_data)
    data_range = np.max(tritium_data) - data_min
    if data_range == 0:
        raise ValueError('tritium samples have zero spread; cannot map them onto [f_start, f_end]')

    rng = np.random.default_rng(seed)
    # Resample first, then rescale only the N draws instead of the whole data set.
    draws = tritium_data[rng.integers(0, tritium_data.size, N)]
    return f_start + (f_end - f_start) * ((draws - data_min) / data_range)

def GenTemplates(pseudoangles, f_0, f_space, path2tritium, time, hz_per_ev=49.1e3, bw_ev=100, bins_tritium=201, power=0.7, var=1, normalize=False):
    """Build the toy AM/FM template bank: one template per (frequency, pseudoangle) pair.

    Frequencies come from ``TemplateFrequenciesFromSpectrum``. For each pseudoangle ``p``
    the modulation parameters scale linearly with ``p``, with base amplitude ``a0 = 1``:

    - amplitude dip ``0.1 * p``
    - frequency-modulation scale ``5e3 * p``
    - modulation rate ``2e6 * p``

    Parameters
    ----------
    pseudoangles : array-like of float
        Pseudoangle values; a scalar is treated as a single value.
    f_0 : float
        Start of the template band.
    f_space : float
        Template spacing at the spectrum peak.
    path2tritium : str or path-like
        Tritium sample pickle forwarded to ``TemplateFrequenciesFromSpectrum``.
    time : numpy.ndarray
        1-D array of sample times.
    hz_per_ev, bw_ev, bins_tritium, power
        Forwarded to ``TemplateFrequenciesFromSpectrum``.
    var : float
        Variance factor, used only when ``normalize`` is True.
    normalize : bool
        If True, each template ``t`` is scaled by ``1 / sqrt(var * sum(|t|**2))``. This is
        the normalization that was previously commented out. The default False returns the
        raw templates, as before.

    Returns
    -------
    numpy.ndarray
        complex128 array of shape ``(n_frequencies * len(pseudoangles), time.size)``.
        Row ``i * len(pseudoangles) + j`` holds frequency ``i`` with pseudoangle ``j``.
    """
    pseudoangles = np.atleast_1d(np.asarray(pseudoangles, dtype=float))

    frequencies = TemplateFrequenciesFromSpectrum(
        f_0,
        f_space,
        path2tritium,
        hz_per_ev=hz_per_ev,
        bw_ev=bw_ev,
        bins_tritium=bins_tritium,
        power=power,
    )

    a0 = 1
    deltaa = 0.1 * pseudoangles
    deltaf = 5e3 * pseudoangles
    f_ax = 2e6 * pseudoangles

    # The amplitude envelope depends only on the pseudoangle, so compute it once per angle
    # rather than once per (frequency, angle) pair.
    envelopes = [A(time, f_ax[j], a0, deltaa[j]) for j in range(pseudoangles.size)]

    templates = np.zeros((frequencies.size * pseudoangles.size, time.size), dtype=np.complex128)
    row = 0
    for frequency in frequencies:
        for j in range(pseudoangles.size):
            # NOTE(review): the phase is 2*pi*f_c(t)*t, not 2*pi*integral(f_c dt), so the
            # instantaneous frequency is not exactly f_c(t). GenRNGSignals uses the same
            # convention, so templates and test signals stay mutually consistent.
            signal_with_FM = np.exp(1j * 2 * np.pi * f_c(time, f_ax[j], frequency, deltaf[j]) * time)
            templates[row, :] = envelopes[j] * signal_with_FM
            row += 1

    if normalize:
        energy = np.sum(np.abs(templates) ** 2, axis=1, keepdims=True)
        templates /= np.sqrt(var * energy)

    return templates

def GenRNGSignals(f_0, f_end, path2tritium, time, N):
    """Simulate N random toy AM/FM test signals.

    Each signal gets:

    - a carrier frequency drawn by ``RNGFrequenciesFromSpectrum`` over ``[f_0, f_end]``;
    - a pseudoangle ``p`` drawn uniformly from ``[0, 6)``.

    ``p`` sets the modulation as in ``GenTemplates``: amplitude dip ``0.1 * p``,
    frequency-modulation scale ``5e3 * p`` and modulation rate ``2e6 * p``, with base
    amplitude 1.

    Returns
    -------
    numpy.ndarray
        complex64 array of shape ``(N, time.size)``, one signal per row.
    """
    carrier_freqs = RNGFrequenciesFromSpectrum(f_0, f_end, N, path2tritium)

    # Unseeded generator: the pseudoangles differ on every call.
    angles = np.random.default_rng().uniform(0., 6., N)

    base_amp = 1
    amp_dip = 0.1 * angles
    freq_scale = 5e3 * angles
    mod_rate = 2e6 * angles

    out = np.zeros((N, time.size), dtype=np.complex64)
    for k in range(len(carrier_freqs)):
        # Same (non-integrated) phase convention as GenTemplates.
        phase = 2 * np.pi * f_c(time, mod_rate[k], carrier_freqs[k], freq_scale[k]) * time
        out[k, :] = A(time, mod_rate[k], base_amp, amp_dip[k]) * np.exp(1j * phase)
    return out



In [None]:
tritium_samples = os.path.join(SIMDATAPATH, '210615_tritium_energy_spectrum.pkl')
ratio = 0.3  # template spacing at the spectrum peak, as a fraction of the FFT bin width fs / N
power = 0.5  # exponent on the relative spectrum density used to widen template spacing

# Band definition shared by the template bank and the random test signals, so both always
# cover the same [f0, fend] range. Previously 49.1e3 * 100 was written here and separately
# relied on through the GenTemplates defaults.
hz_per_ev = 49.1e3
bw_ev = 100
f0 = 50e6
fend = f0 + hz_per_ev * bw_ev

N = 8192
fs = 200e6
t = np.arange(0, N, 1) * 1 / fs

f_space_0 = ratio * fs / N

# NOTE(review): np.linspace(0, 6, 1) is the single value [0.], i.e. only the unmodulated
# template is generated; raise the count if modulated templates are wanted.
template_pseudoangles = np.linspace(0, 6, 1)

N_test = 1000

templates = GenTemplates(template_pseudoangles, f0, f_space_0, tritium_samples, t,
                         hz_per_ev=hz_per_ev, bw_ev=bw_ev, power=power)
rng_signals = GenRNGSignals(f0, fend, tritium_samples, t, N_test)


In [None]:
# Sanity check: (n_templates, n_samples) and (n_test_signals, n_samples).
print(templates.shape, rng_signals.shape, sep='\n')

In [None]:
# scores[k, i] = |<template_k, signal_i>| and ideal_scores[i, m] = |<signal_i, signal_m>|.
# NOTE(review): only the diagonal of ideal_scores appears to be used downstream, yet the full
# N_test x N_test Gram matrix is computed here (O(N_test**2 * n_samples)).
signals_conj_T = rng_signals.conj().T
scores = np.abs(templates @ signals_conj_T)
ideal_scores = np.abs(rng_signals @ signals_conj_T)

In [None]:
# Normalized match of each test signal: the best |<t, s>| / (||t|| * ||s||) over the bank.
# By Cauchy-Schwarz this lies in [0, 1]. Dividing by ||s||**2 alone (the previous formula)
# ignores template norms, so it exceeds 1 whenever a template carries more energy than the
# signal, and it favors high-energy templates when picking the max.
template_norms = np.linalg.norm(templates, axis=1)
signal_norms = np.sqrt(np.diag(ideal_scores))
match = (scores / template_norms[:, np.newaxis]).max(axis=0) / signal_norms

In [None]:
sns.set_theme(context='talk', style='darkgrid')
fig, ax = plt.subplots(figsize=(8, 5))

mean_match = np.mean(match)
ax.hist(match)
# Mean marker spans the full axis height, and the label is placed in axes coordinates so it
# stays visible whatever the histogram's range. Previously the line stopped at a fixed
# y=100 and the label sat at fixed data coordinates (0.8, 150).
ax.axvline(mean_match, color='r')
ax.text(0.02, 0.95, f'mean={np.round(mean_match, 4)}', transform=ax.transAxes, va='top')
ax.set_xlabel('match')
ax.set_ylabel('number of test signals')
ax.set_title('Template-bank match over test signals');

# Prototyping a toy pitch-angle simulation

In [None]:
# Toy pitch-angle model: a carrier with cos^2 amplitude (AM) and frequency (FM) modulation.
N = 8192
fs = 200e6
t = np.arange(N) / fs   # sample times

f0 = 100e6      # carrier frequency
deltaf = 15e3   # scale of the cos^2 frequency-modulation term
f_AM = 5e6      # amplitude-modulation rate
f_FM = 5e6      # frequency-modulation rate

a0 = 1          # unmodulated amplitude
deltaa = 0.8    # depth of the cos^2 amplitude dip

signal = np.exp(1j * 2 * np.pi * f0 * t)
signal_with_AM = A(t, f_AM, a0, deltaa) * signal
signal_with_FM = np.exp(1j * 2 * np.pi * f_c(t, f_FM, f0, deltaf) * t)
signal_with_AMFM = A(t, f_AM, a0, deltaa) * signal_with_FM

# Swap in signal_with_FM or signal_with_AM here to compare their spectra.
plt.plot(np.abs(np.fft.fft(signal_with_AMFM)))
# FFT bin index of the FM-only spectral peak.
print(np.abs(np.fft.fft(signal_with_FM)).argmax())

In [None]:
# Toy signals for 21 pseudoangles evenly spaced over [0, 6], one per row of `signals`.
# Relies on the sample times `t` and length `N` defined in the toy-simulation cell above.
pseudoangle = np.linspace(0, 6, 21)
f_0 = 100e6
deltaf = 5e3 * pseudoangle
f_ax = 2e6 * pseudoangle

a0 = 1
deltaa = 0.1 * pseudoangle

signals = np.zeros((pseudoangle.size, N), dtype=np.complex64)

for i in range(pseudoangle.size):
    # Use this cell's own carrier f_0. The loop previously read the global f0, which is
    # 50e6 or 100e6 depending on which earlier cell ran last. A dedicated name also avoids
    # clobbering the earlier cell's signal_with_FM.
    fm_carrier = np.exp(1j * 2 * np.pi * f_c(t, f_ax[i], f_0, deltaf[i]) * t)
    signals[i, :] = A(t, f_ax[i], a0, deltaa[i]) * fm_carrier


In [None]:
# Overlay magnitude spectra for rows 0, 1 and 10 of `signals`. With the 21-point grid
# above, these are pseudoangles 0, 0.3 and 3.0.
for row_idx in (0, 1, 10):
    plt.plot(np.abs(np.fft.fft(signals[row_idx])))
# To zoom in on the carrier: plt.xlim(4040, 4120)