In [None]:
import numpy as np
import mayfly as mf
import h5py
import pandas as pd
import scipy
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import os 
import sys
import json
import scipy.signal
import scipy.stats
import scipy.interpolate
import pickle as pkl
import scipy.optimize

# Root of the project tree; every other location below is derived from it.
PATH = '/storage/home/adz6/group/project/'

# Results, plots and input datasets each live in their own sub-directory.
RESULTPATH, PLOTPATH, DATAPATH = (
    os.path.join(PATH, subdir)
    for subdir in ('results/mayfly', 'plots/mayfly', 'mayfly/data/datasets')
)


def linear_fit(x, a, b):
    """Evaluate the straight line ``a + b * x``.

    Parameters
    ----------
    x : scalar or array-like supporting ``*`` and ``+``
        Point(s) at which the line is evaluated.
    a : intercept (value of the line at ``x == 0``).
    b : slope.

    Returns
    -------
    ``a + b * x``, with whatever type the arithmetic on ``x`` produces.
    """
    intercept, slope = a, b
    return intercept + slope * x


In [None]:
# Show which dataset files are available in the data directory.
os.listdir(DATAPATH)

In [None]:
# Open the test-signal dataset and the template dataset (in that order).
test_data, template_data = (
    mf.data.MFDataset(os.path.join(DATAPATH, dataset_file))
    for dataset_file in ('210810_84_100_rng_test.h5', '211002_mf_84_100_slice8192.h5')
)

In [None]:
# Display the shape of the test dataset's data array.
test_data.data.shape

In [None]:
# Only these two metadata columns are used in the rest of the notebook.
param_list = ['theta_min', 'energy']

# One DataFrame per dataset, restricted to the columns in param_list.
test_metadata, template_metadata = (
    pd.DataFrame(dataset.metadata)[param_list]
    for dataset in (test_data, template_data)
)

# Plain NumPy copies of the test-signal parameters.
test_energies, test_angles = (
    np.array(test_metadata[column].array)
    for column in ('energy', 'theta_min')
)


In [None]:
# 2-D grid spanned by the distinct template energies (columns) and angles (rows).
# The unique values are used in their order of first appearance (not sorted).
energy_grid, angle_grid = np.meshgrid(
    *(template_metadata[column].unique() for column in ('energy', 'theta_min'))
)

# plot parameter distribution

In [None]:
# Overlay the template grid and the test-signal parameters in
# (energy, theta_min) space.
sns.set_theme(style='whitegrid', context='talk')
fig, ax = plt.subplots(figsize=(13, 8))

ax.plot(energy_grid.flatten(), angle_grid.flatten(), '.', label='Template grid')
ax.plot(test_energies, test_angles, '.', label='Test signals')

# Without labels and a legend the two point clouds cannot be told apart.
ax.set_xlabel('Energy (eV)')
ax.set_ylabel('Angle (deg)')
ax.legend(loc=0)

# compute scores

In [None]:
# Noise variance used to normalise every score in this notebook.
# NOTE(review): the factors look like k_B * T * R * bandwidth
# (1.38e-23, 10, 50, 200e6) — confirm.
var = 1.38e-23 * 10 * 50 * 200e6

# Summed squared magnitude of each test signal (one value per signal).
test_signals = test_data.data[:]
test_signal_power = np.sum(test_signals * test_signals.conjugate(), axis=-1)
del test_signals

# Score each test signal obtains against a template identical to itself.
ideal_test_scores = abs((1 / np.sqrt(var * test_signal_power)) * test_signal_power)

In [None]:
# One point per test signal: its ideal (self-matched) score.
plt.plot(ideal_test_scores, '.')

# template on template

In [None]:
# Score every template against every other template, one block at a time, so
# that only one chunk of normalised templates is held in memory at once.
# Only blocks on or below the block diagonal (j <= i) are filled; everything
# above stays zero.
print(template_data.data.shape)
var = 1.38e-23 * 10 * 50 * 200e6
n_split = 21
n_templates = template_data.data.shape[0]
ideal_template_scores = np.zeros((n_templates, n_templates))

# Chunk boundaries along the template axis (same partition np.array_split
# applies to the rows of the score matrix).
chunk_sizes = [chunk.size for chunk in np.array_split(np.arange(n_templates), n_split)]
split_inds = np.concatenate(([0], np.cumsum(chunk_sizes))).astype(np.int32)

print(split_inds)

for i in range(n_split):
    # Read chunk i once and reuse it for both the norm and the product.
    templates_i = template_data.data[split_inds[i]:split_inds[i+1], :]

    # keepdims=True yields an (n_chunk, 1) column that broadcasts across the
    # sample axis, so the norm never has to be repeated to full chunk size.
    norm = 1 / np.sqrt(var * np.sum(templates_i * templates_i.conjugate(), axis=-1, keepdims=True))
    norm_templates = norm * templates_i
    print('norm done')

    # Lower block triangle only.
    for j in range(i + 1):
        templates_j = template_data.data[split_inds[j]:split_inds[j+1], :]
        scores_ij = abs(np.matmul(norm_templates, templates_j.conjugate().T))

        ideal_template_scores[split_inds[i]:split_inds[i+1], split_inds[j]:split_inds[j+1]] = scores_ij
        print(i+1, j+1)


In [None]:
# Write the template-on-template score matrix to the results directory
# (np.save appends the ".npy" extension).
template_scores_file = os.path.join(RESULTPATH, '211005_mf_84_100_template_ideal_scores_bottom_tri')
np.save(template_scores_file, ideal_template_scores)

# random signal on template

In [None]:
# Score every test signal against every template, processing the templates in
# n_split chunks so only one chunk of normalised templates is in memory.
var = 1.38e-23 * 10 * 50 * 200e6
n_split = 5
n_templates = template_data.data.shape[0]
signal_template_scores = np.zeros((test_data.data.shape[0], n_templates))

# Chunk boundaries along the template axis.
chunk_sizes = [chunk.size for chunk in np.array_split(np.arange(n_templates), n_split)]
split_inds = np.concatenate(([0], np.cumsum(chunk_sizes))).astype(np.int32)

print(split_inds)

# The test signals do not change between chunks: read them once.
test_signals = test_data.data[:]

for i in range(n_split):
    # Read the chunk once and reuse it for both the norm and the product.
    template_chunk = template_data.data[split_inds[i]:split_inds[i+1], :]

    # (n_chunk, 1) column of normalisation factors; broadcasting applies it
    # to every sample without materialising a full-size copy.
    norm = 1 / np.sqrt(var * np.sum(template_chunk * template_chunk.conjugate(), axis=-1, keepdims=True))
    norm_templates = norm * template_chunk
    print('norm done')

    scores_ij = abs(np.matmul(test_signals, norm_templates.conjugate().T))

    signal_template_scores[:, split_inds[i]:split_inds[i+1]] = scores_ij


In [None]:
# Write the signal-on-template score matrix to the results directory
# (np.save appends the ".npy" extension).
test_scores_file = os.path.join(RESULTPATH, '211007_mf_84_100_test_scores')
np.save(test_scores_file, signal_template_scores)

In [None]:
# Image of the score matrix: rows are test signals, columns are templates.
fig, ax = plt.subplots(figsize=(13, 8))
ax.imshow(signal_template_scores, aspect='auto', interpolation='none')

In [None]:
# Display the shape of the template-on-template score matrix.
ideal_template_scores.shape

In [None]:
# Image of the lower triangle of the template-on-template scores.
# np.tril already returns an array with the input's 2-D shape, so no reshape
# to a hard-coded template count is needed.
fig, ax = plt.subplots(figsize=(13, 13))

ax.imshow(np.tril(ideal_template_scores), interpolation='none')

In [None]:
# Save only the lower triangle (diagonal included) of the template scores.
# NOTE: this uses the same file name as the earlier save of the full
# block-triangular matrix, so that file is overwritten.
# np.tril preserves the 2-D shape, so no reshape to a hard-coded size is needed.
np.save(
    os.path.join(RESULTPATH, '211005_mf_84_100_template_ideal_scores_bottom_tri'),
    np.tril(ideal_template_scores),
)

# compute templates

In [None]:
# Normalise every template by 1 / sqrt(var * sum(|template|^2)).
# `var` is the noise variance defined in the score cells above.
template_signals = template_data.data[:]

# keepdims=True gives one factor per template as an (n_templates, 1) column;
# broadcasting applies it across the sample axis without building a full-size
# copy of the normalisation array.
norm = 1 / np.sqrt(var * np.sum(template_signals * template_signals.conjugate(), axis=-1, keepdims=True))

templates = template_signals * norm

# Drop the extra reference to the raw template data.
del template_signals

In [None]:
# Quick look at the normalised template array.
print(templates)

In [None]:
# The matrix product below needs both arrays to share their last dimension.
print(templates.shape)
print(test_data.data.shape)

In [None]:
# Magnitude of the inner product between each test signal and each normalised
# template: scores[i, j] belongs to test signal i and template j.
scores = np.abs(test_data.data[:] @ templates.conj().T)

In [None]:
# Quick look at the score matrix.
print(scores)


In [None]:
# Fraction of the ideal score recovered by the best-scoring template,
# one value per test signal.
best_score_fraction = scores.max(axis=-1) / ideal_test_scores
hist = plt.hist(best_score_fraction, bins=20)


In [None]:
# Average, over all test signals, of (best template score / ideal score).
mean_best_score_fraction = np.mean(scores.max(axis=-1) / ideal_test_scores)
print(mean_best_score_fraction)

# plot maps for a specific test signal

In [None]:
# Fixed seed so the randomly selected example signals are the same on every
# run of the notebook.
# NOTE: the scipy.stats.rice.rvs calls in this notebook are not given this
# generator, so their draws are not controlled by this seed.
RNG_SEED = 0
rng = np.random.default_rng(RNG_SEED)

In [None]:
# For ten randomly chosen test signals, draw
#   1. the map of scores over the template (energy, angle) grid, marking the
#      true, nearest-template and best-template parameters, and
#   2. magnitude spectra of the test signal, the best-scoring template and the
#      nearest template, followed by the two residuals.
cmap = sns.color_palette('mako', as_cmap=True)

# ---- set-up that does not depend on the chosen test signal ----------------
template_energy = template_metadata['energy']
unique_energies = np.sort(template_energy.unique())

template_angle = template_metadata['theta_min']
unique_angles = np.sort(template_angle.unique())

energy_grid, angle_grid = np.meshgrid(unique_energies, unique_angles)

# Row label of the first template found at each (energy, angle) pair, built
# in one pass over the metadata instead of one boolean-mask scan per grid point.
first_template_at = {}
for row_label, energy, angle in zip(template_metadata.index, template_energy, template_angle):
    first_template_at.setdefault((energy, angle), row_label)

# Template row for every grid point, in flattened grid order.  A grid point
# without a template raises KeyError here.
grid_template_index = np.array(
    [first_template_at[(energy, angle)]
     for energy, angle in zip(energy_grid.flatten(), angle_grid.flatten())],
    dtype=np.intp,
)

# Image extent (left, right, bottom, top) for the score map.
# NOTE(review): hard-coded; presumably the energy/angle range of the template
# grid — confirm against template_metadata.
score_map_extent = (18595, 18596, 87.1, 87)

# Length of one segment of a signal and the sample rate used for the
# frequency axis.
# NOTE(review): each signal is presumably 60 segments of 8192 samples — confirm.
n_samples = 8192
sample_rate = 200e6
freqs = np.fft.fftshift(np.fft.fftfreq(n_samples, 1 / sample_rate))


def _summed_spectrum(signal):
    """Magnitude spectrum of `signal` after summing its n_samples-long segments."""
    summed = signal.reshape(-1, n_samples).sum(axis=0)
    return abs(np.fft.fftshift(np.fft.fft(summed))) / n_samples


for k in range(10):

    sns.set_theme(context='talk', style='ticks')
    fig, ax = plt.subplots(figsize=(13, 8))

    # Length-1 index array: keeps the pandas selections below as Series.
    itest_signal = rng.integers(0, scores.shape[0], 1)
    print(itest_signal)

    score_array = scores[itest_signal, :].squeeze()
    imax = score_array.argmax()

    # Scores rearranged onto the (angle, energy) grid.
    test_signal_scores = score_array[grid_template_index].reshape(energy_grid.shape)

    # log color scale
    img = ax.imshow(test_signal_scores, interpolation='none', aspect='auto', cmap=cmap,
                    extent=score_map_extent, norm=matplotlib.colors.LogNorm())
    cbar = fig.colorbar(img)

    test_energy = test_metadata['energy'][itest_signal]
    test_angle = test_metadata['theta_min'][itest_signal]
    ax.plot(test_energy, test_angle, 'r*', markersize=20, label='Signal Parameters')

    # Grid values closest to the true parameters, and the template at that point.
    near_energy = unique_energies[np.argmin(abs(unique_energies - test_energy.iloc[0]))]
    near_angle = unique_angles[np.argmin(abs(unique_angles - test_angle.iloc[0]))]
    inearest_template = first_template_at[(near_energy, near_angle)]

    ax.plot(template_energy[inearest_template], template_angle[inearest_template], 'y*', markersize=20, label='Nearest Template Parameters')
    ax.plot(template_energy[imax], template_angle[imax], 'g*', markersize=20, label='Best Template Parameters')
    ax.set_xlabel('Template Energy (eV)')
    ax.set_ylabel('Template Angle (deg)')

    ax.legend(loc=0)

    test_spectrum = _summed_spectrum(test_data.data[itest_signal, :])
    best_template_spectrum = _summed_spectrum(template_data.data[imax, :])
    near_template_spectrum = _summed_spectrum(template_data.data[inearest_template, :])

    # One figure per entry: the overlaid spectra, then the two residuals.
    spectrum_figures = [
        [(best_template_spectrum, 'best template'),
         (near_template_spectrum, 'nearest template'),
         (test_spectrum, 'test signal')],
        [(best_template_spectrum - test_spectrum, 'best template residual')],
        [(near_template_spectrum - test_spectrum, 'nearest template residual')],
    ]

    for curves in spectrum_figures:
        sns.set_theme(context='talk', style='whitegrid')
        fig, ax = plt.subplots(figsize=(13, 8))

        for spectrum, curve_label in curves:
            ax.plot(freqs, spectrum, label=curve_label)

        ax.set_xlabel('Frequency')
        ax.set_ylabel('Mag')

        ax.legend(loc=2)
    
    
    

# plot histogram of the energy difference between the mf max and the true parameters

In [None]:
# Energy of the best-scoring template minus the true energy, for every test
# signal.  The number of signals is taken from `scores` instead of a fixed 100,
# and the template energies are read from template_metadata directly, so this
# cell does not rely on variables left behind by the plotting loop.
template_energies = template_metadata['energy'].to_numpy()

# argmax returns a position along the template axis, so index positionally.
best_template_index = np.argmax(scores, axis=-1)

energy_diff = template_energies[best_template_index] - test_metadata['energy'].to_numpy()


In [None]:
# Histogram (32 bins) of best-template energy minus true energy.
sns.set_theme(context='talk', style='whitegrid')

fig, ax = plt.subplots(figsize=(13, 8))
hist = ax.hist(energy_diff, bins=32)

# rician distributions of the template grid for a specific signal, compute the expectation value for maximum template score

In [None]:
# Rice CDF and PDF for every template score of one test signal, evaluated on a
# common grid of values.  Both results have shape (x_rice.size, n_templates):
# row = value, column = template.
itest_signal = 0
x_rice = np.linspace(0, 10, 101)
score_array = scores[itest_signal, :].squeeze()
imax = score_array.argmax()

# Column of evaluation points against a row of shape parameters; scipy
# broadcasts the two to the full 2-D grid.
x_column = x_rice[:, np.newaxis]
score_row = score_array[np.newaxis, :]

rice_cdf = scipy.stats.rice.cdf(x_column, score_row)

rice_pdf = scipy.stats.rice.pdf(x_column, score_row)

In [None]:
# Rice PDF of the first 861 templates (fewer if the template set is smaller).
# The per-iteration metadata lookup was never used and has been dropped.
n_pdf_curves = min(861, rice_pdf.shape[1])

for i in range(n_pdf_curves):

    plt.plot(x_rice, rice_pdf[:, i])

In [None]:
# 1 - CDF for the first 861 templates, plotted against the threshold value.
exceedance_prob = 1 - rice_cdf

for template_column in range(861):
    plt.plot(x_rice, exceedance_prob[:, template_column])

plt.ylabel('Prob. False Alarm')
plt.xlabel('Threshold')

# random sample from rician pdf of scores to create a noisy score image

In [None]:
# Draw one Rice-distributed sample per template (shape parameter = that
# template's score) and show the samples on the (angle, energy) grid.
itest_signal = 0
score_array = scores[itest_signal, :].squeeze()
random_samples = scipy.stats.rice.rvs(score_array)

cmap = sns.color_palette('mako', as_cmap=True)

# Flatten the grids once; flattening inside the loop copies the whole grid on
# every iteration.
flat_energies = energy_grid.flatten()
flat_angles = angle_grid.flatten()
rand_scores = np.zeros(flat_energies.size)

for i in range(flat_energies.size):

    # First template whose parameters equal this grid point.
    at_grid_point = (template_metadata['energy'] == flat_energies[i]) & (template_metadata['theta_min'] == flat_angles[i])
    rand_scores[i] = random_samples[template_metadata[at_grid_point].index[0]]

rand_scores = rand_scores.reshape(energy_grid.shape)

sns.set_theme(context='talk', style='ticks')
fig, ax = plt.subplots(figsize=(13, 8))

# NOTE(review): extent is hard-coded; presumably the template grid range — confirm.
img = ax.imshow(rand_scores, interpolation='none', aspect='auto', cmap=cmap, extent=(18595, 18596, 87.1, 87),)
cbar = fig.colorbar(img)

ax.plot(test_metadata['energy'][itest_signal], test_metadata['theta_min'][itest_signal], 'r*', markersize=20, label='Signal Parameters')

ax.set_xlabel('Template Energy (eV)')
ax.set_ylabel('Template Angle (deg)')

# The marker carries a label, so draw the legend that displays it.
ax.legend(loc=0)


# Threshold the noisy image

In [None]:
# Same noisy score image as above for another test signal, but samples that do
# not exceed `thresh` are replaced by a small constant.
itest_signal = 2
score_array = scores[itest_signal, :].squeeze()
random_samples = scipy.stats.rice.rvs(score_array)

thresh = 2
# Value shown for grid points whose sample does not exceed the threshold.
below_thresh_value = 0.01

cmap = sns.color_palette('mako', as_cmap=True)

# Flatten the grids once instead of on every loop iteration.
flat_energies = energy_grid.flatten()
flat_angles = angle_grid.flatten()
rand_scores = np.zeros(flat_energies.size)

for i in range(flat_energies.size):

    # First template whose parameters equal this grid point (looked up once;
    # the unused noise-free score lookup has been removed).
    at_grid_point = (template_metadata['energy'] == flat_energies[i]) & (template_metadata['theta_min'] == flat_angles[i])
    sample = random_samples[template_metadata[at_grid_point].index[0]]

    if sample > thresh:
        rand_scores[i] = sample
    else:
        rand_scores[i] = below_thresh_value

rand_scores = rand_scores.reshape(energy_grid.shape)

sns.set_theme(context='talk', style='ticks')
fig, ax = plt.subplots(figsize=(13, 8))

# NOTE(review): extent is hard-coded; presumably the template grid range — confirm.
img = ax.imshow(rand_scores, interpolation='none', aspect='auto', cmap=cmap, extent=(18595, 18596, 87.1, 87))
cbar = fig.colorbar(img)

ax.plot(test_metadata['energy'][itest_signal], test_metadata['theta_min'][itest_signal], 'r*', markersize=20, label='Signal Parameters')

ax.set_xlabel('Template Energy (eV)')
ax.set_ylabel('Template Angle (deg)')

# The marker carries a label, so draw the legend that displays it.
ax.legend(loc=0)


# simplest possible estimation, take maximum

In [None]:
# Simplest estimator: for each noisy realisation of the score matrix, take the
# energy of the template with the largest sample.
# random_sample_energy_maxima[i, k] is the estimate for test signal i in trial k.
ntrial = 1024

# Size the result from the inputs so it always matches `scores` and `ntrial`.
random_sample_energy_maxima = np.zeros((scores.shape[0], ntrial))

# Template energies as a plain array; argmax positions index it directly.
template_energies = template_metadata['energy'].to_numpy()

for k in range(ntrial):

    # One Rice-distributed sample per (test signal, template) pair.
    random_samples = scipy.stats.rice.rvs(scores)
    random_sample_energy_maxima[:, k] = template_energies[np.argmax(random_samples, axis=-1)]
        

In [None]:
# Estimated energies for test signal 10 across all trials, shown relative to
# 18595, followed by that signal's true energy on the same scale.
signal_index = 10
energy_offset = 18595

hist = plt.hist(random_sample_energy_maxima[signal_index, :] - energy_offset, bins=16)
print(test_metadata['energy'].iloc[signal_index] - energy_offset)

In [None]:
# For seven randomly chosen test signals, histogram the estimation error
# (estimated minus true energy) over all trials.
# The number of test signals comes from the results array, not a fixed 100.
n_test_signals = random_sample_energy_maxima.shape[0]

for i in range(7):

    itest = rng.integers(0, n_test_signals, 1)[0]

    fig, ax = plt.subplots(figsize=(13, 8))

    hist = ax.hist(random_sample_energy_maxima[itest, :] - test_metadata['energy'].iloc[itest], 16)

    ax.set_xlim(-1, 1)
    ax.set_xlabel('Estimated - true energy (eV)')
    ax.set_ylabel('Trials')

In [None]:
# Error of the trial-averaged energy estimate, one value per test signal.
# Each signal's mean estimate is compared with that signal's own true energy.
# Subtracting the energy of a single leftover index instead would only shift
# all the means by one constant.
mean_energy_error = random_sample_energy_maxima.mean(axis=-1) - test_metadata['energy'].to_numpy()

fig, ax = plt.subplots(figsize=(13, 8))

hist = ax.hist(mean_energy_error, 20)

print(np.std(mean_energy_error))