In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import h5py
import os
import sys
import scipy
import damselfly as df
import mayfly as mf
import scipy.signal
import scipy.stats
import scipy.interpolate

# Project root on the cluster filesystem; every other path below is derived from it.
# NOTE(review): absolute, user-specific path -- consider making it configurable.
PATH = '/storage/home/adz6/group/project'
RESULTPATH = os.path.join(PATH, 'results/damselfly')  # not referenced in the visible cells
PLOTPATH = os.path.join(PATH, 'plots/damselfly')  # not referenced in the visible cells
DATAPATH = os.path.join(PATH, 'datasets/data')  # input datasets; summed output is written under DATAPATH/bf
#SIMDATAPATH = os.path.join(PATH, 'damselfly/data/sim_data')

# NOTE(review): the description below reads like a leftover placeholder; this
# notebook shift-and-sums (beamforms) a signal dataset and saves the result.
"""
Date: 6/25/2021
Description: template
"""

def SaveSummedDataset(data, metadata, name):
    """Write summed signals and their metadata to a new HDF5 file.

    Parameters
    ----------
    data : array_like
        Summed signal array, stored as the top-level ``'data'`` dataset.
    metadata : pandas.DataFrame
        One dataset per column is created under the ``'meta'`` group. Only
        the column values are written; the DataFrame index is discarded.
    name : str
        Output file path. An existing file is overwritten (mode ``'w'``).
    """
    # The context manager guarantees the file is closed (and flushed) even if
    # writing one of the metadata columns raises part-way through.
    with h5py.File(name, 'w') as savefile:
        savefile.create_dataset('data', data=data)
        metagroup = savefile.create_group('meta')
        for key in metadata.keys():
            metagroup.create_dataset(key, data=np.array(metadata[key].array))
    
    

def SumDataset(MFdata, radius, gradb_freq_grid):
    """Shift-and-sum every signal in a dataset, one pitch angle at a time.

    Signals are grouped by their ``theta_min`` metadata value. This lets one
    grad-B frequency, interpolated at ``radius`` and that pitch angle, be
    applied to the whole group through ``ShiftAndSum``.

    Parameters
    ----------
    MFdata : object
        Exposes ``.data``, indexed ``[signal, sample]``. ShiftAndSum splits its
        last axis into 60 channels, hence ``// nch`` below. It also exposes
        ``.metadata``, which is convertible to a DataFrame with a ``theta_min``
        column.
    radius : float
        Electron radius used for every signal. It is not read from the metadata.
    gradb_freq_grid : mapping
        Passed through unchanged to ``InterpolateGradB``.

    Returns
    -------
    summed_data : np.ndarray of complex64
        One summed signal per row, ordered by ascending pitch angle.
    resorted_metadata : pandas.DataFrame
        Metadata rows reordered to match ``summed_data`` row-for-row.
    """
    nch = 60  # channel count; must agree with ShiftAndSum
    metadata = pd.DataFrame(MFdata.metadata)
    data_shape = MFdata.data.shape
    theta = metadata['theta_min'].to_numpy()

    summed_data = np.zeros((data_shape[0], data_shape[-1] // nch), dtype=np.complex64)
    # Integer *positions* (not index labels, not floats) so they are valid both
    # for indexing MFdata.data and for metadata.iloc, whatever the index is.
    summed_indexes = np.zeros(data_shape[0], dtype=np.intp)

    # NaN angles can never satisfy the equality test below, so they would only
    # produce empty groups; skip them.
    pitch_angles = np.sort(metadata['theta_min'].dropna().unique())

    total_num_summed = 0
    for i, angle in enumerate(pitch_angles):
        inds = np.flatnonzero(theta == angle)  # ascending row positions
        nsignal = inds.size

        signal_subset = MFdata.data[inds, :]
        gradb_freq = InterpolateGradB(radius, angle, gradb_freq_grid)
        summed_signals = ShiftAndSum(signal_subset, radius, gradb_freq)

        summed_data[total_num_summed:total_num_summed + nsignal, :] = summed_signals
        summed_indexes[total_num_summed:total_num_summed + nsignal] = inds
        total_num_summed += nsignal

        if i % 5 == 4:
            print(f'{i+1}/{len(pitch_angles)}')

    # Drop trailing rows that were never filled. This is only possible when some
    # theta_min values are NaN. Otherwise those rows would be zero signals
    # paired with metadata row 0.
    summed_data = summed_data[:total_num_summed]
    resorted_metadata = metadata.iloc[summed_indexes[:total_num_summed]]

    return summed_data, resorted_metadata

def InterpolateGradB(radius, pitch_angle, gradb_freq_grid):
    """Bilinearly interpolate the grad-B frequency at one (radius, pitch angle).

    Replaces ``scipy.interpolate.interp2d`` (deprecated in SciPy 1.10, removed
    in 1.14) with ``RegularGridInterpolator``. It reproduces interp2d's
    behaviour for a rectangular grid with ``kind='linear'``:

    - grid axes are sorted if needed;
    - interpolation is bilinear;
    - queries outside the grid are clamped to the nearest edge, matching
      interp2d's default nearest-neighbour extrapolation.

    Parameters
    ----------
    radius : float
        Query radius (column axis of the grid).
    pitch_angle : float
        Query pitch angle (row axis of the grid).
    gradb_freq_grid : mapping
        Must provide ``'radii'`` (1-D), ``'angles'`` (1-D) and ``'freq'``.
        ``'freq'`` is either 2-D with shape ``(len(angles), len(radii))``
        (row = angle, column = radius), or the equivalent flat array of
        ``len(radii) * len(angles)`` values ordered radius-major. Both are the
        layouts interp2d accepted.

    Returns
    -------
    np.ndarray, shape (1,)
        The interpolated frequency, as a length-1 array like interp2d returned.

    Raises
    ------
    ValueError
        If ``'freq'`` does not match the grid dimensions.
    """
    radii = np.asarray(gradb_freq_grid['radii'], dtype=float)
    angles = np.asarray(gradb_freq_grid['angles'], dtype=float)
    gradb_data = np.asarray(gradb_freq_grid['freq'], dtype=float)

    if gradb_data.ndim == 1 and gradb_data.size == radii.size * angles.size:
        # Flat grid: interp2d/FITPACK read it as value(radii[i], angles[j]) at i*len(angles)+j.
        gradb_data = gradb_data.reshape(radii.size, angles.size).T
    if gradb_data.shape != (angles.size, radii.size):
        raise ValueError(
            f"'freq' has shape {gradb_data.shape}; expected (len(angles), len(radii)) = "
            f"{(angles.size, radii.size)}")

    # RegularGridInterpolator needs monotonic axes; interp2d sorted them itself.
    angle_order = np.argsort(angles)
    radius_order = np.argsort(radii)
    angles = angles[angle_order]
    radii = radii[radius_order]
    gradb_data = gradb_data[np.ix_(angle_order, radius_order)]

    interpolator = scipy.interpolate.RegularGridInterpolator(
        (angles, radii), gradb_data, method='linear', bounds_error=False, fill_value=None)

    # Clamp to the grid: equivalent to interp2d's nearest-edge extrapolation.
    query = np.array([[np.clip(pitch_angle, angles[0], angles[-1]),
                       np.clip(radius, radii[0], radii[-1])]])
    return interpolator(query)

def _PlotChannelTraces(signals, channels, xlim, title):
    """Diagnostic plot of the real part of a few channels of the first signal.

    ``signals`` is indexed [signal, channel, sample]; only ``signals[0]`` is drawn.
    Returns the newly created matplotlib Figure.
    """
    fig, ax = plt.subplots()
    for ch in channels:
        ax.plot(signals[0, ch, :].real, label=f'channel {ch}')
    ax.set_xlim(*xlim)
    ax.set_xlabel('sample index')
    ax.set_ylabel('real part')
    ax.set_title(title)
    ax.legend()
    return fig


def ShiftAndSum(signal_subset, radius, freq, plot=True):
    """Phase-align the channels of each signal and combine them into one trace.

    Each row of ``signal_subset`` is split into ``nch`` equal-length channels,
    one per antenna on a ring of radius ``r_array`` (antennas evenly spaced in
    angle). For every sample, the electron is placed at ``radius`` and rotated
    by a grad-B angle that grows linearly with time at frequency ``freq``.
    Each channel is phase-corrected for its distance to that position, plus
    the angular term from ``AntispiralCorrection``. The channels are then
    combined by ``WeightedChannelSum``.

    Parameters
    ----------
    signal_subset : np.ndarray, shape (nsignal, nch * nsample)
        Channels concatenated along the last axis.
    radius : float
        Electron radius, in the same length unit as ``r_array`` (presumably metres).
    freq : float or length-1 array
        Grad-B frequency, e.g. the value returned by ``InterpolateGradB``.
    plot : bool, default True
        Draw three diagnostic figures of the first signal (before and after the
        phase shift). Pass ``False`` when calling in a loop; otherwise each call
        creates three figures.

    Returns
    -------
    The result of ``WeightedChannelSum`` on the phase-shifted signals.

    Raises
    ------
    ValueError
        If the last axis length is not a multiple of the channel count.
    """
    nch = 60
    nsignal, ntotal = signal_subset.shape[0], signal_subset.shape[-1]
    if ntotal % nch:
        raise ValueError(
            f'last axis length {ntotal} is not divisible by the channel count {nch}')
    nsample = ntotal // nch
    signal_subset = signal_subset.reshape((nsignal, nch, nsample))

    angles = np.radians(np.arange(0, nch, 1) * 360 / nch)  # antenna positions on the ring
    r_array = 0.10
    wavelength_lo = 3e8 / 25.86e9  # c / 25.86 GHz
    fsample = 200e6  # sample rate; converts sample index to time

    # TODO: the origin of this factor of 12 is unknown; empirically this
    # grad-B correction gives the best results.
    grad_b_angles = 12 * 2 * np.pi * np.arange(0, nsample, 1) * freq / fsample

    x_antenna = r_array * np.cos(angles)
    y_antenna = r_array * np.sin(angles)

    theta_electron = 0 + grad_b_angles  # electron starts at angle 0
    x_electron = radius * np.cos(theta_electron)
    y_electron = radius * np.sin(theta_electron)

    # Antenna-to-electron distance: rows = antennas, columns = samples.
    d_grad_b = np.sqrt(
        (x_antenna.reshape((x_antenna.size, 1)) - x_electron.reshape((1, x_electron.size))) ** 2
        + (y_antenna.reshape((y_antenna.size, 1)) - y_electron.reshape((1, y_electron.size))) ** 2)
    phase_shift = (-2 * np.pi * (d_grad_b) / wavelength_lo
                   + AntispiralCorrection(x_electron, y_electron, x_antenna, y_antenna))

    shifted_signal_subset = signal_subset * np.exp(-1j * phase_shift).reshape((1, *phase_shift.shape))

    if plot and nsignal > 0:
        probe_channels = (0, nch // 3, 2 * nch // 3)  # channels 0, 20, 40 for 60 channels
        _PlotChannelTraces(signal_subset, probe_channels, (0, 100), 'before phase shift')
        _PlotChannelTraces(shifted_signal_subset, probe_channels, (0, 100), 'after phase shift')
        _PlotChannelTraces(shifted_signal_subset, probe_channels, (7000, 7100), 'after phase shift')

    return WeightedChannelSum(shifted_signal_subset, radius, d_grad_b)


def AntispiralCorrection(x_electron, y_electron, x_antenna, y_antenna):
    """Angle (radians) of the vector from the electron to each antenna.

    All four inputs are flattened. The result has shape
    ``(n_antenna, n_electron)``, and entry ``[a, s]`` equals
    ``arctan2(y_antenna[a] - y_electron[s], x_antenna[a] - x_electron[s])``.
    """
    delta_y = np.subtract.outer(y_antenna.reshape(-1), y_electron.reshape(-1))
    delta_x = np.subtract.outer(x_antenna.reshape(-1), x_electron.reshape(-1))
    return np.arctan2(delta_y, delta_x)

def WeightedChannelSum(shifted_signals, radius, d_grad_b):
    """Sum channels with inverse-distance weights normalised to the channel count.

    Parameters
    ----------
    shifted_signals : np.ndarray, shape (nsignal, nch, nsample)
        Phase-aligned signals; summed over axis 1.
    radius : float
        Electron radius. When it is exactly 0, every channel gets weight 1.
        The inverse-distance formula would otherwise give all-zero weights and
        a division by zero in the normalisation.
    d_grad_b : np.ndarray, shape (nch, nsample)
        Antenna-to-electron distance for every channel and sample.

    Returns
    -------
    np.ndarray, shape (nsignal, nsample)
        Weighted sum over the channel axis.

    Notes
    -----
    At each sample the weights ``radius / d`` are rescaled so that they sum to
    ``nch``, the value they would have if every weight were 1. ``nch`` is
    taken from ``d_grad_b`` instead of being hard-coded to 60, so other antenna
    counts are normalised correctly. Results for 60 channels are unchanged.
    """
    nch = d_grad_b.shape[0]
    if radius == 0:
        normed_weights = np.ones(d_grad_b.shape)
    else:
        weights = radius / d_grad_b
        weights_norm = nch / weights.sum(axis=0)
        normed_weights = weights_norm.reshape((1, weights_norm.size)) * weights

    return (shifted_signals * normed_weights.reshape(1, *normed_weights.shape)).sum(axis=1)
    
    
    
    


In [None]:
os.listdir(DATAPATH)  # available input datasets

In [None]:
# NOTE(review): repeats the directory listing from the previous cell; one of the two can be removed.
os.listdir(os.path.join(DATAPATH))

# Load signal data and grad-B frequency grid

In [None]:
# signal data
data = mf.data.MFDataset(os.path.join(DATAPATH, '211116_grad_b_est.h5'))
metadata = pd.DataFrame(data.metadata)

# grad-b correction data: copy every array out of the .npz and close the archive.
# A bare np.load(...) on an .npz keeps the file open, and every item access
# re-reads that array from the zip. InterpolateGradB performs three such
# accesses per call.
with np.load(os.path.join(PATH, 'results/mayfly', '211129_grad_b_frequency_grid_radius_angle.npz')) as gradb_npz:
    gradb_freq_grid = {key: gradb_npz[key] for key in gradb_npz.files}

In [None]:
# Shape of the raw signal array; ShiftAndSum splits its last axis into 60 channels.
data.data.shape

In [None]:
metadata.loc[metadata['x_min'] == 0.0, 'theta_min']  # pitch angles available at x_min == 0

In [None]:
# Baseline: matched-filter score of one raw (unsummed) signal against itself.
angle = 84.5  # theta_min of the signal to inspect
rad = 0.00    # x_min of the signal to inspect

selection = (metadata['x_min'] == rad) & (metadata['theta_min'] == angle)
ind = metadata.index[selection.to_numpy()][0]
signal = data.data[ind, :]

# Noise variance. Presumably k_B * T * R * B with T = 10 K, R = 50 ohm and
# B = 200 MHz -- TODO confirm. var_sum scales it by the 60 summed channels.
var_no_sum = 1.38e-23 * 10 * 50 * 200e6
var_sum = 60 * var_no_sum

# Template = signal normalised by sqrt(variance * energy); score = overlap with the signal.
signal_energy = np.vdot(signal, signal)
norm_no_sum = 1 / np.sqrt(var_no_sum * signal_energy)
template_no_sum = norm_no_sum * signal
score_no_sum = abs(np.vdot(template_no_sum, signal))

print(score_no_sum)

In [None]:
# Same score after shift-and-sum at the grad-B frequency for (rad, angle).
gradb_freq = InterpolateGradB(rad, angle, gradb_freq_grid)
print(gradb_freq)
sum_signal = ShiftAndSum(signal[np.newaxis, :], rad, gradb_freq)

sum_energy = np.vdot(sum_signal, sum_signal)
norm_sum = 1 / np.sqrt(var_sum * sum_energy)
template_sum = norm_sum * sum_signal
score_sum = abs(np.vdot(template_sum, sum_signal))

print(score_sum)

In [None]:
# Shift-and-sum every signal in the dataset.
# NOTE(review): a single radius (0.0) is used for every signal rather than each
# row's x_min -- confirm the dataset only contains x_min == 0 signals.
summed_data, summed_metadata = SumDataset(data, 0.0, gradb_freq_grid)

In [None]:
name = os.path.join(DATAPATH, 'bf', '211130_sens_est_dense_grid_84.5_0cm_sum.h5')
# h5py cannot create missing parent directories, so ensure DATAPATH/bf exists first.
os.makedirs(os.path.dirname(name), exist_ok=True)
SaveSummedDataset(summed_data, summed_metadata, name)


In [None]:
# Sanity check: pitch angles present in the summed metadata (order of appearance).
summed_metadata['theta_min'].unique()