In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import h5py
import os
import sys
import scipy
import damselfly as df
import mayfly as mf
import scipy.signal
import scipy.stats
import scipy.interpolate

# Project root. Defaults to the original cluster location; set the
# DAMSELFLY_PROJECT_PATH environment variable to run the notebook elsewhere.
PATH = os.environ.get('DAMSELFLY_PROJECT_PATH', '/storage/home/adz6/group/project')
RESULTPATH = os.path.join(PATH, 'results/damselfly')
PLOTPATH = os.path.join(PATH, 'plots/damselfly')
DATAPATH = os.path.join(PATH, 'datasets/data')
#SIMDATAPATH = os.path.join(PATH, 'damselfly/data/sim_data')

"""
Date: 6/25/2021
Description: matched-filter (MF) score distributions (noise-only Rayleigh vs.
signal + noise Rician) and idealized MF ROC curves vs. noise temperature.
"""


    


In [None]:
# List the files available under DATAPATH/bf.
bf_dir = os.path.join(DATAPATH, 'bf')
os.listdir(bf_dir)

In [None]:
# List the files available under PATH/datasets/kass.
kass_dir = os.path.join(PATH, 'datasets', 'kass')
os.listdir(kass_dir)

# load data

In [None]:
# Signal dataset (mayfly MFDataset) and its metadata as a DataFrame.
data = mf.data.MFDataset(os.path.join(DATAPATH, 'bf', '211203_84_25_2cm_sum.h5'))
metadata = pd.DataFrame(data.metadata)

# "kass" dataset. The HDF5 file is intentionally left open: `kass_data` is an
# h5py object that still refers to it.
h5kass_data = h5py.File(os.path.join(PATH, 'datasets', 'kass', '211116_grad_b_est_kass.h5'), 'r')
kass_data = h5kass_data['kass']

# Read every array in the 'meta' group into one column of a DataFrame.
meta_group = h5kass_data['meta']
kass_metadata = pd.DataFrame.from_dict({name: meta_group[name][:] for name in meta_group.keys()})

# compute MF scores

In [None]:
# Keep only the first `slicesize` samples (columns) of every row of the dataset.
slicesize = 8192
data_slice = data.data[:, :slicesize]



In [None]:
# Noise variance at a 3 K noise temperature. 1.38e-23 is Boltzmann's constant
# (J/K); the meaning/units of 50, 60 and 200e6 are not documented here — TODO confirm.
var = 1.38e-23 * 3 * 50 * 60 * 200e6

# Each template is its own waveform divided by sqrt(var * energy), so with complex
# noise of total variance `var` per sample the noise-only MF output has unit variance.
template_norm = np.sqrt(var * (abs(data_slice) ** 2).sum(axis=-1)).reshape(data_slice.shape[0], 1)
templates = data_slice / template_norm

# Noiseless MF score of every row against its own template.
scores = np.abs((data_slice * np.conj(templates)).sum(axis=-1))
print(np.mean(scores))

# Plot example MF score PDFs

In [None]:
# Monte Carlo noise-only MF scores using the template of row 10000.
# Complex white noise: independent real and imaginary parts, each with
# variance var / 2 (total variance var per sample). Seeded for reproducibility.
rng = np.random.default_rng(0)
signal = data_slice[10000, :]
template = templates[10000, :]

n_trials = 10000
sample_scores = np.zeros(n_trials)
for i in range(n_trials):
    iq = rng.normal(0.0, np.sqrt(var / 2), size=(signal.size, 2))
    # Use two different columns: reusing column 0 for the imaginary part would
    # make imag == real, which is not circular complex noise.
    noise = iq[:, 0] + 1j * iq[:, 1]
    sample_scores[i] = abs(np.vdot(template, noise))

In [None]:
# Histogram of the noise-only scores; `hist` holds (counts, bin_edges, patches).
hist = plt.hist(sample_scores, bins=31)

In [None]:

# Compare the sampled noise-only scores with the expected Rayleigh PDF
# (sigma = 1/sqrt(2) per quadrature for the unit-variance noise-only MF output).
counts, edges = hist[0], hist[1]
bin_centers = 0.5 * (edges[:-1] + edges[1:])
# Normalize the counts to a density rather than scaling the PDF by a hand-tuned factor.
sample_density = counts / (counts.sum() * np.diff(edges))

x_rice = np.linspace(0, 20, 301)
y_rayleigh_pdf = scipy.stats.rayleigh.pdf(x_rice, loc=0, scale=1/np.sqrt(2))

plt.plot(bin_centers, sample_density, label='Noise samples')
plt.plot(x_rice, y_rayleigh_pdf, label='Rayleigh')
plt.legend()
plt.xlim(0, 5)

In [None]:
# Monte Carlo MF scores for row 6000: signal + noise, and noise only.
# Each trial uses the same noise realization for both.
# Complex white noise: independent real/imaginary parts, each of variance var / 2.
rng = np.random.default_rng(1)
signal = data_slice[6000, :]
template = templates[6000, :]

n_trials = 10000
sample_scores = np.zeros(n_trials)
sample_noise = np.zeros(n_trials)
for i in range(n_trials):
    iq = rng.normal(0.0, np.sqrt(var / 2), size=(signal.size, 2))
    # Independent quadratures; reusing column 0 would make imag == real.
    noise = iq[:, 0] + 1j * iq[:, 1]
    sample_scores[i] = abs(np.vdot(template, signal + noise))
    sample_noise[i] = abs(np.vdot(template, noise))

In [None]:
sns.set_theme(context='poster')
fig = plt.figure(figsize=(13,8))

ax = fig.add_subplot(1,1,1)

# Compare the simulated histograms for row 6000 with the analytic distributions.
# The noise-only MF output has unit total variance, i.e. sigma = 1/sqrt(2) per quadrature.
event_idx = 6000  # must match the row simulated in the previous Monte Carlo cell
noise_sigma = 1 / np.sqrt(2)

# |nu + noise| is Rician with nu = noiseless score. scipy's `rice` takes the shape
# b = nu / sigma together with scale = sigma (loc = 0: the support starts at 0).
x_rice = np.linspace(0, 20, 301)
y_rice = scipy.stats.rice.pdf(x_rice, scores[event_idx] / noise_sigma, scale=noise_sigma)

x_ray = np.linspace(0, 20, 301)
y_ray = scipy.stats.rayleigh.pdf(x_ray, scale=noise_sigma)

ax.hist(sample_scores, 31, density=True, alpha=0.6, label='Signal Samples')
ax.hist(sample_noise, 31, density=True, alpha=0.6, label='Noise Samples')
ax.plot(x_rice, y_rice, label='Rician')
ax.plot(x_ray, y_ray, label='Rayleigh')
ax.set_xlim(0, 7)

ax.set_xlabel('MF Score')
ax.legend(loc=1)
ax.set_ylabel('Probability Density')
ax.set_title('Example MF Score Distributions')

#plt.savefig(os.path.join(PATH, 'plots/mayfly', '220107_example_rician_and_rayleigh_mf_score_dist'))

In [None]:

# Analytic MF score PDFs: noise only (Rayleigh) vs. signal + noise for row 0 (Rician).
noise_sigma = 1 / np.sqrt(2)

x_rice = np.linspace(0, 20, 301)
# Rician with nu = scores[0], sigma = noise_sigma -> scipy shape b = nu / sigma.
y_rice = scipy.stats.rice.pdf(x_rice, scores[0] / noise_sigma, scale=noise_sigma)

x_rayleigh = np.linspace(0, 20, 301)
# Pass `scale` by keyword: the third positional argument of rice/rayleigh is `loc`.
y_rayleigh = scipy.stats.rayleigh.pdf(x_rayleigh, scale=noise_sigma)


plt.figure()

plt.plot(x_rayleigh, y_rayleigh, label='Rayleigh')
plt.plot(x_rice, y_rice, label='Rician')
plt.legend()
plt.xlim(0, 8)

# Pick a detection threshold

In [None]:
# Noise-only score distribution on its own (y_rayleigh as left by the previous cell).
fig, ax = plt.subplots()
ax.plot(x_rayleigh, y_rayleigh, label='Rayleigh')
ax.set_xlim(0, 8)

In [None]:
# Analytic CDFs: noise only (Rayleigh) vs. signal + noise for row 4000 (Rician).
noise_sigma = 1 / np.sqrt(2)
# Keyword `scale`: a third positional argument would be taken as `loc`.
y_rayleigh = scipy.stats.rayleigh.cdf(x_rayleigh, scale=noise_sigma)
# Rician with nu = scores[4000], sigma = noise_sigma -> scipy shape b = nu / sigma.
y_rice = scipy.stats.rice.cdf(x_rice, scores[4000] / noise_sigma, scale=noise_sigma)


plt.figure()

plt.plot(x_rayleigh, y_rayleigh, label='Rayleigh')
plt.plot(x_rice, y_rice, label='Rician')
plt.legend()
plt.xlim(0, 8)
# Noise-only CDF at grid index 81 (x = 5.4 on the 301-point 0-20 grid).
print(x_rayleigh[81], y_rayleigh[81])



In [None]:
# Survival functions P[score > x]:
# - noise only, which gives the false positive rate per threshold;
# - signal + noise for row 5000, which gives the true positive rate.
noise_sigma = 1 / np.sqrt(2)
y_rayleigh = scipy.stats.rayleigh.sf(x_rayleigh, scale=noise_sigma)
# Rician with nu = scores[5000], sigma = noise_sigma -> scipy shape b = nu / sigma.
y_rice = scipy.stats.rice.sf(x_rice, scores[5000] / noise_sigma, scale=noise_sigma)


plt.figure()

plt.plot(x_rayleigh, y_rayleigh, label='Rayleigh')
plt.plot(x_rice, y_rice, label='Rician')
plt.legend()
plt.xlim(0, 8)
# Noise-only survival at grid index 81 (x = 5.4 on the 301-point 0-20 grid).
print(x_rayleigh[81], y_rayleigh[81])
plt.yscale('log')

In [None]:
# Single-row ROC curve: noise-only survival (x) vs. signal survival (y) per threshold.
fig, ax = plt.subplots()
ax.plot(y_rayleigh, y_rice)
ax.set_xscale('log')

In [None]:
# Signal + noise survival for every row at every threshold in x_rice;
# broadcasting (n_rows, 1) against x_rice gives shape (n_rows, len(x_rice)).
# Rician with nu = scores, sigma = 1/sqrt(2) -> scipy shape b = nu / sigma, scale = sigma.
noise_sigma = 1 / np.sqrt(2)
y_rice = scipy.stats.rice.sf(x_rice, (scores / noise_sigma)[:, np.newaxis], scale=noise_sigma)

In [None]:
# Average survival over rows, i.e. one value per threshold.
mean_sf = y_rice.mean(axis=0)

In [None]:
# Row-averaged ROC curve: noise-only survival (x) vs. mean signal survival (y).
fig, ax = plt.subplots()
ax.plot(y_rayleigh, mean_sf)

ax.set_xscale('log')
#ax.set_yscale('log')
#ax.set_xlim(1e-3, 1)
#ax.set_ylim(0.9, 1)

In [None]:
# Draw one noisy MF score per row from its Rician distribution
# (nu = scores, sigma = 1/sqrt(2), matching the noise model used elsewhere).
# The default scale of 1 would double the noise variance. Seeded for reproducibility.
rng = np.random.default_rng(2)
noise_sigma = 1 / np.sqrt(2)
score_sample = scipy.stats.rice.rvs(scores / noise_sigma, scale=noise_sigma, size=scores.size, random_state=rng)

In [None]:
# Fraction of sampled scores at or above the 3.4 threshold.
np.count_nonzero(score_sample >= 3.4) / len(scores)

In [None]:
# Mean noiseless MF score across rows.
print(np.mean(scores))

# Plot MF ROC curves vs. noise temperature

In [None]:
sns.set_theme(context='poster')

fig = plt.figure(figsize=(13, 8))
ax = fig.add_subplot(1,1,1)

# Per-quadrature std of the unit-variance noise-only MF output.
noise_sigma = 1 / np.sqrt(2)
# Shared threshold grid for both rates.
x = np.linspace(0, 20, 301)
# Noise-only survival per threshold, i.e. the false positive rate.
y_rayleigh = scipy.stats.rayleigh.sf(x, scale=noise_sigma)

# Per-row signal energy does not depend on temperature, so compute it once.
signal_energy = (abs(data_slice) ** 2).sum(axis=-1)

# Loop-local names keep the global `var`, `templates` and `scores`
# (used by the other cells) intact.
for temp in [1, 3, 5, 10, 13]:
    # 1.38e-23 is Boltzmann's constant; the other factors are as in the 3 K score cell.
    var_t = 1.38e-23 * temp * 50 * 60 * 200e6

    # Noiseless score |<d, d / sqrt(var * E)>| simplifies to sqrt(E / var).
    # NOTE(review): the 0.3 factor is unexplained in the original — TODO document it.
    scores_t = np.sqrt(signal_energy / var_t) * 0.3

    # Rician survival: shape (n_rows, len(x)), with b = nu / sigma and scale = sigma.
    y_rice_t = scipy.stats.rice.sf(x, (scores_t / noise_sigma)[:, np.newaxis], scale=noise_sigma)

    ax.plot(y_rayleigh, y_rice_t.mean(axis=0), label=f'{temp} K')

ax.legend(title='Noise Temp.', loc=4)
ax.set_xscale('log')
ax.set_xlim(1e-4, 1)
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Ideal MF ROC Curves')

#plt.savefig(os.path.join(PATH, 'plots/mayfly', '220107_idealized_mf_roc_curve'))

In [None]:
# FIXME: `file` is never defined in this notebook, so this raises NameError on a
# fresh kernel. Presumably it was an h5py.File opened in a deleted cell — define
# it or remove these cells.
file['train']['data'].shape

In [None]:
# Plot channel 0 of the first 'train' example.
# FIXME: `file` is not defined anywhere in this notebook (NameError on a fresh kernel).
plt.plot(file['train']['data'][0, 0, :])

In [None]:
# Plot channel 0 of the last 'train' example.
# FIXME: `file` is not defined anywhere in this notebook (NameError on a fresh kernel).
plt.plot(file['train']['data'][-1, 0, :])