In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import h5py
import os
import sys
import scipy
import damselfly as df
import scipy.signal
import scipy.stats

# Root of the shared project tree. The default is the original cluster
# location; set the DAMSELFLY_PROJECT_PATH environment variable to run the
# notebook against a copy of the tree somewhere else.
PATH = os.environ.get('DAMSELFLY_PROJECT_PATH', '/storage/home/adz6/group/project')
# Every other location is derived from PATH.
RESULTPATH = os.path.join(PATH, 'results/damselfly')
PLOTPATH = os.path.join(PATH, 'plots/damselfly')
DATAPATH = os.path.join(PATH, 'damselfly/data/datasets')
SIMDATAPATH = os.path.join(PATH, 'damselfly/data/sim_data')

"""
Date: 6/25/2021
Description: template
"""


def MakeTemplates(signals, var =  1.38e-23 * 10 * 50 * 200e6):
    """Scale each signal by ``1 / sqrt(var * energy)``.

    ``energy`` is the sum of squared magnitudes of a signal along its last
    axis, so with ``var == 1`` every returned row has unit L2 norm.

    The energy is computed directly per row instead of taking the diagonal of
    ``signals @ signals^H``; that avoids building an N x N matrix for N
    signals and lets the function accept inputs of any rank >= 1.

    Parameters
    ----------
    signals : array_like
        Real or complex samples; the last axis indexes the samples of one
        signal.
    var : float, optional
        Variance factor folded into the normalization. The default is
        ``1.38e-23 * 10 * 50 * 200e6`` (presumably k_B * T * R * bandwidth --
        TODO confirm).

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``signals``. An all-zero signal yields
        non-finite values, because its energy is zero.
    """
    signals = np.asarray(signals)

    # keepdims leaves a trailing length-1 axis so the factor broadcasts over
    # the sample axis without an explicit repeat.
    energy = np.sum(np.abs(signals) ** 2, axis=-1, keepdims=True)
    norm = 1 / np.sqrt(var * energy)

    return norm * signals


In [None]:
# Show the entries of <PATH>/mayfly/data/datasets (the cell echoes the list).
os.listdir(os.path.join(PATH, 'mayfly/data/datasets'))

In [None]:
# Show the entries of <PATH>/damselfly/data (the cell echoes the list).
os.listdir(os.path.join(PATH, 'damselfly/data'))

# Load principal components

In [None]:
# Load the saved eigenvalues / eigenvectors from <PATH>/damselfly/data.
# (An earlier version read '210914_frequency_spectra_imag_principle_components.npz'.)
pc_dir = os.path.join(PATH, 'damselfly/data')

# Decomposition stored under the "all" file names.
pca_all_evals = np.load(os.path.join(pc_dir, '211007_principle_components_all_evals.npy'))
pca_all_evecs = np.load(os.path.join(pc_dir, '211007_principle_components_all_evecs.npy'))

# Decompositions stored under the "1deg_pitch_angle_range" file names.
pca_range_evals = np.load(os.path.join(pc_dir, '211007_principle_components_1deg_pitch_angle_range_evals.npy'))
pca_range_evecs = np.load(os.path.join(pc_dir, '211007_principle_components_1deg_pitch_angle_range_evecs.npy'))

In [None]:
# A notebook cell only echoes its final expression, so return both shapes as
# one tuple instead of silently discarding the first.
pca_all_evals.shape, pca_range_evals.shape

In [None]:
# Eigenvalue magnitudes on a log axis: the "all" decomposition first, then one
# curve per entry along the first axis of pca_range_evals.
sns.set_theme(context='talk', style='whitegrid')
fig, ax = plt.subplots(figsize=(13, 8))
ax.set_yscale('log')
ax.plot(np.abs(pca_all_evals))
ax.set_xlim(0, 256)
ax.set_ylim(1e-15, 1e-12)

for range_evals in pca_range_evals:
    ax.plot(np.abs(range_evals))

In [None]:
# Plot the real and imaginary parts of one eigenvector from pca_range_evecs.
m = 55  # index along the last axis (which component)
k = 5   # index along the first axis (which decomposition)

sns.set_theme(context='talk', style='whitegrid')
fig, ax = plt.subplots(figsize=(13, 8))

# Alternatives that were tried and left disabled:
#   ax.set_yscale('log')
#   ax.plot(pca_all_evecs[:, m].real); ax.plot(pca_all_evecs[:, m].imag)
#   ax.plot(abs(pca_range_evecs[k, :, :].max(axis=0)))
#   ax.set_xlim(0, 256); ax.set_ylim(1e-15, 1e-12)
component = pca_range_evecs[k, :, m]
ax.plot(component.real)
ax.plot(component.imag)

# load data

In [None]:
# Read every record from the HDF5 file and store its centred, length-normalised
# spectrum in `dataset`.
data = os.path.join(PATH, 'mayfly/data/datasets', '211002_mf_84_100_slice8192.h5')
# NOTE(review): the handle stays open on purpose so later cells can still use
# `h5datafile`; close it once nothing else needs it.
h5datafile = h5py.File(data, 'r')
raw_records = h5datafile['data']

Nsignal = raw_records.shape[0]
Nsample = 8192

dataset = np.zeros((Nsignal, Nsample), dtype=np.complex64)

# Records are processed one at a time rather than loading the whole table.
for i in range(Nsignal):
    # Each record is viewed as 60 rows of Nsample samples and the rows are
    # summed (presumably one row per channel -- TODO confirm).
    row_sum = raw_records[i, :].reshape(60, Nsample).sum(axis=0)
    spectrum = np.fft.fft(row_sum)
    dataset[i, :] = np.fft.fftshift(spectrum) / Nsample

# Select principal components

In [None]:
# Peak magnitude of every eigenvector: reduce |pca_range_evecs| over axis 1.
pc_max_vals = np.max(np.abs(pca_range_evecs), axis=1)

In [None]:
# Keep the eigenvectors whose peak magnitude reaches the 0.15 threshold.
selected_pc_inds = np.argwhere(pc_max_vals >= 0.15)
first_axis_inds = selected_pc_inds[:, 0]
last_axis_inds = selected_pc_inds[:, 1]

# Two index arrays separated by a slice: NumPy puts the broadcast index
# dimension first, so each row of select_pc is one selected eigenvector.
select_pc = pca_range_evecs[first_axis_inds, :, last_axis_inds]

print(select_pc.shape)


In [None]:
# Compare scipy's correlate with "convolve against the flipped signal" for one
# selected component (n) and one signal (m).
n = 2
m = 3

var = 1.38e-23 * 10 * 200e6 * 50
# Complex Gaussian noise: independent real/imag parts with variance var / 2
# each, transformed with the same 1/8192 scaling as the dataset spectra.
iq_noise = np.random.multivariate_normal([0,0], np.eye(2) * var / 2, 8192)
noise = np.fft.fft(iq_noise[:, 0] + 1j * iq_noise[:, 1]) / 8192

pc = select_pc[n, :]
signal = dataset[m, :]

# The component is scaled by the L2 norm of the signal before filtering.
signal_l2 = np.sqrt((np.abs(dataset[m, :]) ** 2).sum())
scaled_pc = pc * signal_l2

# Figure 1: cross-correlation of the scaled component with the signal.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(np.abs(scipy.signal.correlate(scaled_pc, signal, mode='full')))

# Figure 2: convolution with the flipped signal. correlate() additionally
# conjugates its second argument, so the two curves can differ for complex data.
fig, ax = plt.subplots(figsize=(8, 5))

# Unit-norm copies; not plotted in this cell.
norm_noise = noise / np.sqrt((np.abs(noise) ** 2).sum())
norm_signal = dataset[m, :] / signal_l2

ax.plot(np.abs(scipy.signal.convolve(scaled_pc, np.flip(signal), mode='full')))

In [None]:
# For signal `m`, record one peak filter response per selected component:
#   projection_signal : (noise + signal) convolved with the flipped signal
#   projection_n      : scaled component convolved with (noise + flipped signal)
#   projection_noise  : noise correlated with the scaled component
#N_comp = 256
m = 67

sig_m = dataset[m, :]
sig_m_flipped = np.flip(sig_m, axis=-1)
# L2 norm of the signal; every component is scaled by it before filtering.
sig_m_l2 = np.sqrt((np.abs(sig_m) ** 2).sum())

projection_n, projection_noise, projection_signal = [], [], []

for n, pc in enumerate(select_pc):

    # The covariance is multiplied by 0, so the drawn samples are expected to
    # be all zero; the draw is kept so the random stream still advances.
    noise = np.random.multivariate_normal([0,0], np.eye(2) * 0 / 2, 8192)
    noise = np.fft.fft(noise[:, 0] + 1j * noise[:, 1]) / 8192

    scaled_pc = pc * sig_m_l2

    signal_response = scipy.signal.convolve(noise + sig_m, sig_m_flipped)
    projection_signal.append(np.abs(signal_response).max())

    pc_response = scipy.signal.fftconvolve(scaled_pc, noise + sig_m_flipped, mode='full')
    projection_n.append(np.abs(pc_response).max())
    # alternative: scipy.signal.correlate(scaled_pc, noise + sig_m, mode='full')

    noise_response = scipy.signal.correlate(noise, scaled_pc, mode='full')
    projection_noise.append(np.abs(noise_response).max())

projection_n = np.array(projection_n)
projection_noise = np.array(projection_noise)
projection_signal = np.array(projection_signal)

print(projection_n.sum(), projection_signal.sum(), projection_noise.sum())

# Figure 1: component-based responses.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(projection_n, label='signal+noise')
ax.plot(projection_noise, label='noise')
ax.legend(loc=1)
ax.set(xlabel='Principle Component', ylabel='Correlation Max')
ax.set_title('Correlation Using Principle Components',pad=20)

plt.tight_layout()
#plt.savefig(os.path.join(PLOTPATH, '210914_correlation_using_pc_hard'))

# Figure 2: signal-template responses against the same noise responses.
fig, ax = plt.subplots(figsize=(8, 5))
ax.set(xlabel='Trial', ylabel='Correlation Max')
ax.set_title('Correlation Using Signal, N-Trials',pad=20)
ax.plot(projection_signal, label='signal+noise')
ax.plot(projection_noise,label='noise')
ax.legend(loc=1)

plt.tight_layout()
#plt.savefig(os.path.join(PLOTPATH, '210914_correlation_using_signal_hard'))

In [None]:
# Repeat the per-component filter experiment N_trial times with fresh noise
# and keep, per trial, the mean peak response over the first N_comp components.
#
# NOTE(review): `evecs` is not defined in any cell visible above (presumably
# pca_all_evecs -- confirm), and `var` comes from an earlier cell.
N_comp = 256
N_trial = 50
m = 1

# Loop-invariant pieces: the signal under test and its L2 norm.
sig_m = dataset[m, :]
sig_m_l2 = np.sqrt((abs(sig_m) ** 2).sum())

pc_max, signal_max, noise_max = [], [], []

for k in range(N_trial):

    projection_n, projection_noise, projection_signal = [], [], []

    for n in range(N_comp):

        # Fresh complex Gaussian noise (variance var / 2 per quadrature),
        # transformed with the same 1/8192 scaling as the dataset spectra.
        noise = np.random.multivariate_normal([0,0], np.eye(2) * var / 2, 8192)

        noise = np.fft.fft(noise[:, 0] + 1j * noise[:, 1]) / 8192

        pc = evecs[:, n]
        scaled_pc = pc * sig_m_l2

        projection_signal.append(abs(scipy.signal.correlate(noise + sig_m, sig_m, mode='full')).max())
        projection_n.append(abs(scipy.signal.correlate(scaled_pc, noise + sig_m, mode='full')).max())
        # Fixed: this called the non-existent `scipy.signal.correlatet`.
        projection_noise.append(abs(scipy.signal.correlate(noise, scaled_pc, mode='full')).max())

    projection_n = np.array(projection_n)
    projection_noise = np.array(projection_noise)
    projection_signal = np.array(projection_signal)

    pc_max.append(projection_n.mean())
    signal_max.append(projection_signal.mean())
    noise_max.append(projection_noise.mean())
    
    

    
    



In [None]:
# Histogram the per-trial means collected by the trial loop.
pc_max, signal_max, noise_max = (
    np.array(values) for values in (pc_max, signal_max, noise_max)
)

fig, ax = plt.subplots(figsize=(8, 5))

# Outline-only histograms so the three distributions can overlap.
hist1 = ax.hist(pc_max, histtype='step')
hist2 = ax.hist(noise_max, histtype='step')
hist3 = ax.hist(signal_max, histtype='step')


In [None]:
# Peak correlation of one component with one unit-norm signal, relative to the
# signal's own peak autocorrelation.
# NOTE(review): `evecs` is not defined in any cell visible above -- confirm
# where it comes from.
n = 5
m = 6

pc = evecs[:, n]
signal = dataset[m, :]

unit_signal = signal / np.sqrt((np.abs(dataset[m, :]) ** 2).sum())

match_pc = np.abs(scipy.signal.correlate(pc, unit_signal, mode='same')).max()
match_sig = np.abs(scipy.signal.correlate(unit_signal, unit_signal, mode='same')).max()

print(match_pc, match_sig, match_pc / match_sig)

In [None]:
# Overlay the magnitude of the current component and of the unit-norm signal.
# NOTE(review): `freqs` is not defined in any cell visible above -- confirm
# where the frequency axis comes from.
sns.set_theme(context='talk', style='whitegrid')

fig, ax = plt.subplots(figsize=(8, 5))

ax.plot(freqs, np.abs(pc), label='Principle Component')
# The signal is divided by the L2 norm of dataset[m, :].
ax.plot(freqs, np.abs(signal / np.sqrt((np.abs(dataset[m, :]) ** 2).sum())), label='Signal')

ax.set(ylabel='Amplitude', xlabel='Frequency (Hz)')
ax.legend(loc=2)

plt.tight_layout()
plt.savefig(os.path.join(PLOTPATH, '210914_normalized_signal_and_pc'))

# Sanity checks left disabled:
#   print(np.vdot(pc, pc))
#   print(np.vdot(unit_signal, unit_signal)) with unit_signal as plotted above

In [None]:
# Convolve the unit-norm signal with itself, print the energy of the result
# and plot its magnitude.
signal_unit = signal / np.sqrt((np.abs(signal) ** 2).sum())
x = scipy.signal.convolve(signal_unit, signal_unit, mode='same')
print(np.vdot(x, x))
plt.plot(np.abs(x))

In [None]:
# Convolve the current component with the unit-norm signal, print the energy
# of the result and plot its magnitude.
signal_unit = signal / np.sqrt((np.abs(signal) ** 2).sum())
x = scipy.signal.convolve(pc, signal_unit, mode='same')
print(np.vdot(x,x))
plt.plot(np.abs(x))

In [None]:
# pc_match_matrix[k, i] = peak |correlation| between component i and the
# unit-norm version of signal k.
#
# NOTE(review): `evecs` is not defined in any cell visible above (presumably
# pca_all_evecs -- confirm).
pc_match_matrix = np.zeros((dataset.shape[0], 256))

# Unit-norm rows, computed once. keepdims lets the norms broadcast, so no
# full-size repeated array is materialised.
norm_dataset = dataset / np.sqrt((abs(dataset) ** 2).sum(axis=-1, keepdims=True))

for k in range(dataset.shape[0]):
    # Progress report every 100 signals.
    if k % 100 == 99:
        print(k+1)

    # Hoisted: the normalised signal is the same for every component.
    unit_signal = norm_dataset[k, :]

    for i in range(pc_match_matrix.shape[1]):
        pc = evecs[:, i]

        x = abs(scipy.signal.correlate(pc, unit_signal, mode='same'))

        pc_match_matrix[k, i] = x.max()



In [None]:
# Write pc_match_matrix under <PATH>/damselfly/data; np.save appends ".npy"
# because the given file name has no extension.
np.save(os.path.join(PATH, 'damselfly/data', '210915_pc_match_matrix_normalized'), pc_match_matrix)

In [None]:
# Energy of each signal (sum of squared magnitudes over the last axis), tiled
# into 256 identical columns.
signal_energy = np.sum(np.abs(dataset) ** 2, axis=-1)
sig_match_matrix = np.repeat(signal_energy.reshape((dataset.shape[0], 1)), 256, axis=-1)

In [None]:
# Heat map of the match matrix: signals along x, components along y.
sns.set_theme(style='ticks', context='talk')
cmap = sns.color_palette('mako', as_cmap=True)

fig, ax = plt.subplots(figsize=(8, 6))

# Transposed so that rows of the image are components.
img = ax.imshow(pc_match_matrix.T, aspect='auto', cmap=cmap, interpolation='none')
fig.colorbar(img, label='Match Ratio')

ax.set(xlabel='Signal Index', ylabel='Principle Component', title='Match Ratio')

plt.tight_layout()
plt.savefig(os.path.join(PLOTPATH, '210914_principle_components_match_ratio_matrix'))



In [None]:
# Match averaged over all signals, one value per component.
sns.set_theme(style='whitegrid', context='talk')
cmap = sns.color_palette('mako', as_cmap=True)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(np.mean(pc_match_matrix, axis=0))

ax.set(
    ylabel='Mean Match',
    xlabel='Principle Component',
    title='Mean Match for Principle Component',
)

plt.tight_layout()
plt.savefig(os.path.join(PLOTPATH, '210914_mean_match_for_principle_component'))

In [None]:
# Same per-component mean match, sorted from largest to smallest.
sns.set_theme(style='whitegrid', context='talk')
cmap = sns.color_palette('mako', as_cmap=True)

fig, ax = plt.subplots(figsize=(8, 5))

# np.sort is ascending; the reversed view gives descending order.
ax.plot(np.sort(np.mean(pc_match_matrix, axis=0))[::-1])

ax.set(
    ylabel='Mean Match',
    xlabel='Principle Component',
    title='Mean Match for Principle Component, Sorted',
)

plt.tight_layout()
plt.savefig(os.path.join(PLOTPATH, '210914_mean_match_for_principle_component_sorted'))

In [None]:
# Distribution of the best (maximum over components) match of every signal.
sns.set_theme(style='whitegrid', context='talk')
cmap = sns.color_palette('mako', as_cmap=True)

fig, ax = plt.subplots(figsize=(8, 5))

best_match = np.max(pc_match_matrix, axis=-1)
hist = ax.hist(best_match, bins=32)
# Disabled overlay of the mean best match:
#   ax.plot(np.mean(best_match) * np.ones(5492), '--', color='tab:red',
#           label=f'Mean Match = {np.round(np.mean(best_match), 3)}')
#   ax.legend(loc=(0.3, 0.6))

ax.set(
    ylabel='N',
    xlabel='Match Ratio',
    title='Best Principle Component Match Ratio per Signal',
)

plt.tight_layout()
plt.savefig(os.path.join(PLOTPATH, '210914_best_match_for_principle_component'))
#print(np.mean(best_match))

In [None]:
# Form the Gram matrices A^T A of two arrays and show each as an image in its
# own figure.
# NOTE(review): signal_real_norm and signal_imag_norm are not defined in any
# cell visible above (presumably normalised real/imaginary parts of the
# signals) -- confirm where they come from.
covariance_real = np.matmul(signal_real_norm.T, signal_real_norm)
covariance_imag = np.matmul(signal_imag_norm.T, signal_imag_norm)
plt.figure()
plt.imshow(covariance_real, interpolation = 'none')
# Optional zoom on the first 20 x 20 entries:
#plt.xlim(0, 20)
#plt.ylim(0, 20)
plt.figure()
plt.imshow(covariance_imag, interpolation = 'none')
#plt.xlim(0, 20)
#plt.ylim(0, 20)

In [None]:
# Eigen-decompose both matrices; the prints mark which decomposition is about
# to run. np.linalg.eig returns (eigenvalues, eigenvectors) with eigenvectors
# stored as columns and no guaranteed ordering.
# NOTE(review): both inputs are built as A.T @ A; if A is real they are
# symmetric and np.linalg.eigh would presumably be the better fit -- confirm.
print('real')
real_evals, real_evecs = np.linalg.eig(covariance_real)
print('imag')
imag_evals, imag_evecs = np.linalg.eig(covariance_imag)

In [None]:
# Shape of the eigenvector matrix returned by np.linalg.eig.
print(real_evecs.shape)

In [None]:
# Overlay both eigenvalue spectra and restrict the x-axis to indices 0..10.
plt.plot(real_evals)
plt.plot(imag_evals)
plt.xlim(0, 10)

In [None]:
# Plot row 1 of the eigenvector matrix, with the x-axis limited to 8000..8192.
# NOTE(review): np.linalg.eig stores eigenvectors as columns, so row 1 is the
# second coordinate of every eigenvector rather than one eigenvector;
# real_evecs[:, 1] may have been intended -- confirm.
plt.plot(real_evecs[1, :])
# Other rows that were inspected:
#plt.plot(real_evecs[1, :])
#plt.plot(real_evecs[2, :])
#plt.plot(real_evecs[3, :])
#plt.plot(real_evecs[4, :])
#plt.plot(real_evecs[5, :])
#plt.plot(real_evecs[6, :])
plt.xlim(8000, 8192)

# try 100 signals