In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import h5py
import os
import sys
import scipy
import damselfly as df
import mayfly as mf
import scipy.signal
import scipy.stats
import scipy.interpolate

# Root of the project tree. Defaults to the original cluster location; set the
# PROJECT_PATH environment variable to run the notebook from another machine.
PATH = os.environ.get('PROJECT_PATH', '/storage/home/adz6/group/project')
RESULTPATH = os.path.join(PATH, 'results/damselfly')
PLOTPATH = os.path.join(PATH, 'plots/damselfly')
DATAPATH = os.path.join(PATH, 'datasets/data')

"""
Date: 6/25/2021
Description: template
"""

def SumSignal(signal, freq, radius):
    """
    Phase-align and sum the 60 antenna channels of a single flattened signal.

    Parameters
    ----------
    signal : numpy.ndarray
        Flattened multichannel time series of length ``60 * nsample``; it is
        reshaped to ``(60, nsample)`` with channels along the first axis.
    freq : float
        Grad-B drift frequency (presumably Hz, given the 200 MHz sample rate)
        used to model the electron's azimuthal motion.
        NOTE(review): the phase advance uses a factor of 12 here but 1 in
        ShiftAndSum - confirm which is intended.
    radius : float
        Electron radius, in the same units as the 0.10 antenna-array radius.

    Returns
    -------
    numpy.ndarray
        Complex summed signal of length ``nsample``.
    """
    nch = 60
    angles = np.radians(np.arange(0, nch, 1) * 360 / nch)
    r_array = 0.10
    nsample = signal.shape[-1] // nch
    wavelength_lo = 3e8 / 25.86e9
    fsample = 200e6

    # Azimuthal position of the electron at every sample due to grad-B drift.
    grad_b_angles = 12 * 2 * np.pi * np.arange(0, nsample, 1) * freq / fsample

    x_antenna = r_array * np.cos(angles)
    y_antenna = r_array * np.sin(angles)

    x_electron = radius * np.cos(grad_b_angles)
    y_electron = radius * np.sin(grad_b_angles)

    # (nch, nsample) antenna-to-electron distances.
    dx = x_antenna.reshape((x_antenna.size, 1)) - x_electron.reshape((1, x_electron.size))
    dy = y_antenna.reshape((y_antenna.size, 1)) - y_electron.reshape((1, y_electron.size))
    d_grad_b = np.sqrt(dx ** 2 + dy ** 2)

    # Both the path-length term and the antispiral angle are phases, so both
    # are multiplied by 1j. Previously the antispiral angle was added to the
    # exponent as a real number, which scaled each channel's amplitude by
    # exp(angle) (between e^-pi and e^pi) instead of rotating its phase.
    # NOTE(review): ShiftAndSum applies the antispiral phase with the opposite
    # sign relative to the path phase; the sign kept here is this function's
    # original one - confirm the convention.
    phase = 2 * np.pi * d_grad_b / wavelength_lo + AntispiralCorrectionSingle(
        x_electron, y_electron, x_antenna, y_antenna
    )
    correction = np.exp(1j * phase)

    return (signal.reshape(nch, signal.size // nch) * correction).sum(axis=0)
    
def SaveSummedDataset(data, metadata, name):
    """
    Write a summed dataset and its metadata to a new HDF5 file.

    Parameters
    ----------
    data : array-like
        Stored as the top-level dataset ``'data'``.
    metadata : pandas.DataFrame
        Each column is stored as a dataset inside the ``'meta'`` group
        (columns are read through ``.array``).
    name : str
        Path of the HDF5 file; opened with mode ``'w'``, so an existing file
        is overwritten.
    """
    # ``with`` guarantees the file is closed even if a write fails part-way.
    with h5py.File(name, 'w') as savefile:
        savefile.create_dataset('data', data=data)
        metagroup = savefile.create_group('meta')
        for key in metadata.keys():
            metagroup.create_dataset(key, data=np.array(metadata[key].array))
    
def SumDataset2(MFdata, gradb_freq_grid, r_grid, theta_grid):
    """
    Beamform every signal of a dataset with a per-(radius, pitch) grad-B correction.

    Parameters
    ----------
    MFdata : object
        Dataset exposing ``.data`` (assumed shape ``(nsignal, 60 * nsample)``)
        and ``.metadata`` (convertible to a DataFrame with ``x_min`` and
        ``theta_min`` columns).
    gradb_freq_grid : numpy.ndarray
        Grad-B frequencies with the same 2-D shape as ``r_grid``.
    r_grid, theta_grid : numpy.ndarray
        2-D grids of radius and pitch angle (e.g. from ``np.meshgrid``).

    Returns
    -------
    numpy.ndarray
        Complex array of shape ``(nsignal, nsample)``. Row ``i`` is signal
        ``i`` summed over channels using the correction of the grid point
        whose (radius, pitch) exactly equals its (x_min, theta_min).

    Raises
    ------
    KeyError
        If a signal's (x_min, theta_min) pair is not on the grid.
    """
    metadata = pd.DataFrame(MFdata.metadata)
    data_shape = MFdata.data.shape
    nch = 60
    angles = np.radians(np.arange(0, nch, 1) * 360 / nch)
    r_array = 0.10
    nsample = data_shape[-1] // nch
    wavelength_lo = 3e8 / 25.86e9
    fsample = 200e6

    # (*grid_shape, nsample) azimuthal electron positions for every grid point.
    grad_b_angles = (
        12 * 2 * np.pi * np.arange(0, nsample, 1).reshape((1, 1, nsample))
        * gradb_freq_grid.reshape((*gradb_freq_grid.shape, 1)) / fsample
    )

    x_antenna = r_array * np.cos(angles)
    y_antenna = r_array * np.sin(angles)

    x_electron = r_grid.reshape((*r_grid.shape, 1)) * np.cos(grad_b_angles)
    y_electron = r_grid.reshape((*r_grid.shape, 1)) * np.sin(grad_b_angles)

    # (nch, *grid_shape, nsample) antenna-to-electron distances.
    d_grad_b = np.sqrt(
        (x_antenna.reshape((x_antenna.size, 1, 1, 1)) - x_electron[np.newaxis]) ** 2
        + (y_antenna.reshape((y_antenna.size, 1, 1, 1)) - y_electron[np.newaxis]) ** 2
    )

    # The antispiral angle is a phase and must be multiplied by 1j like the
    # path-length term. It was previously subtracted as a real number, which
    # scaled channel amplitudes by exp(-angle) instead of rotating phases.
    phase = 2 * np.pi * d_grad_b / wavelength_lo + AntispiralCorrection(
        x_electron, y_electron, x_antenna, y_antenna
    )
    correction = CalculateChannelWeights(d_grad_b, r_grid) * np.exp(-1j * phase)

    # (nch, nsignal, nsample): channels first so they line up with correction.
    data = np.swapaxes(MFdata.data[:].reshape(data_shape[0], nch, nsample), 0, 1)

    # Map each (radius, pitch) grid value to its 2-D grid index once instead of
    # scanning the whole grid for every signal (first occurrence wins).
    grid_index = {}
    for ij in np.ndindex(r_grid.shape):
        grid_index.setdefault((r_grid[ij], theta_grid[ij]), ij)

    # Accumulate one signal at a time so the full (nch, nsignal, nsample)
    # per-signal correction array is never materialised.
    summed = np.zeros((data_shape[0], nsample), dtype=np.result_type(np.complex64, data.dtype))
    pairs = zip(metadata['x_min'].array, metadata['theta_min'].array)
    for i, pair in enumerate(pairs):
        if pair not in grid_index:
            raise KeyError(f'signal {i}: (x_min, theta_min) = {pair} is not on the grid')
        row, col = grid_index[pair]
        # Cast to complex64 as the original per-signal correction buffer did.
        channel_correction = correction[:, row, col, :].astype(np.complex64)
        summed[i] = np.sum(channel_correction * data[:, i, :], axis=0)

    return summed
    
    
def SumDataset(MFdata, radius, gradb_freq_grid):
    """
    Beamform a dataset at a fixed radius, interpolating the grad-B frequency per pitch angle.

    Signals are processed in groups sharing a ``theta_min`` value (ascending),
    so result rows are reordered relative to ``MFdata``; the returned metadata
    is reordered identically.

    Parameters
    ----------
    MFdata : object
        Dataset exposing ``.data`` (assumed shape ``(nsignal, 60 * nsample)``)
        and ``.metadata`` (convertible to a DataFrame with ``theta_min``).
    radius : float
        Electron radius used for every signal.
    gradb_freq_grid : mapping
        Grid with ``'radii'``, ``'angles'`` and ``'freq'`` entries, passed to
        ``InterpolateGradB``.

    Returns
    -------
    summed_data : numpy.ndarray
        complex64 array of shape ``(nsignal, nsample)``.
    resorted_metadata : pandas.DataFrame
        Metadata rows in the same order as ``summed_data``.
    """
    metadata = pd.DataFrame(MFdata.metadata)
    data_shape = MFdata.data.shape

    summed_data = np.zeros((data_shape[0], data_shape[-1] // 60), dtype=np.complex64)
    # Integer dtype is required: DataFrame.iloc rejects float indexers.
    summed_indexes = np.zeros(data_shape[0], dtype=int)

    pitch_angles = np.sort(metadata['theta_min'].unique())

    total_num_summed = 0
    for i, angle in enumerate(pitch_angles):
        # Row labels of all signals at this pitch angle; used as positions,
        # which assumes the default RangeIndex from pd.DataFrame(...).
        inds = np.array(metadata[metadata['theta_min'] == angle].index.array)
        signal_subset = MFdata.data[inds, :]
        nsignal = signal_subset.shape[0]

        gradb_freq = InterpolateGradB(radius, angle, gradb_freq_grid)
        summed_signals = ShiftAndSum(signal_subset, radius, gradb_freq)

        block = slice(total_num_summed, total_num_summed + nsignal)
        summed_data[block, :] = summed_signals
        summed_indexes[block] = inds
        total_num_summed += nsignal

        if i % 5 == 4:
            print(f'{i+1}/{len(pitch_angles)}')

    resorted_metadata = metadata.iloc[summed_indexes]

    return summed_data, resorted_metadata

def InterpolateGradB(radius, pitch_angle, gradb_freq_grid):
    """
    Bilinearly interpolate the grad-B frequency at one (radius, pitch angle).

    Replaces the deprecated ``scipy.interpolate.interp2d`` (removed in SciPy
    1.14) with ``RegularGridInterpolator`` while keeping its behaviour:
    linear interpolation on the grid, the nearest boundary value for points
    outside it, and a length-1 array as the result.

    Parameters
    ----------
    radius : float
    pitch_angle : float
    gradb_freq_grid : mapping
        ``'radii'`` (1-D), ``'angles'`` (1-D) and ``'freq'`` with shape
        ``(len(angles), len(radii))`` - rows indexed by angle, the layout
        interp2d expected.

    Returns
    -------
    numpy.ndarray
        Shape ``(1,)`` interpolated frequency.
    """
    radii = np.asarray(gradb_freq_grid['radii'], dtype=float)
    angles = np.asarray(gradb_freq_grid['angles'], dtype=float)
    gradb_data = np.asarray(gradb_freq_grid['freq'], dtype=float)

    # interp2d sorted its axes internally; do the same explicitly.
    r_order = np.argsort(radii)
    a_order = np.argsort(angles)
    radii = radii[r_order]
    angles = angles[a_order]
    gradb_data = gradb_data[np.ix_(a_order, r_order)]

    interpolator = scipy.interpolate.RegularGridInterpolator(
        (angles, radii), gradb_data, method='linear'
    )

    # Clamping to the grid edges reproduces interp2d's nearest-boundary extrapolation.
    point = [[
        np.clip(pitch_angle, angles[0], angles[-1]),
        np.clip(radius, radii[0], radii[-1]),
    ]]
    return interpolator(point)

def ShiftAndSum(signal_subset, radius, freq):
    """
    Phase-align and weight-sum the channels of a batch of signals at one radius.

    Parameters
    ----------
    signal_subset : numpy.ndarray
        Shape ``(nsignal, 60 * nsample)``; reshaped to ``(nsignal, 60, nsample)``.
    radius : float
        Electron radius shared by every signal in the batch.
    freq : float or numpy.ndarray
        Grad-B frequency; a scalar or a length-1 array (as returned by
        InterpolateGradB).
        NOTE(review): the phase advance uses a factor of 1 here but 12 in
        SumSignal / SumDataset2 - confirm which is intended.

    Returns
    -------
    numpy.ndarray
        Shape ``(nsignal, nsample)`` channel-summed signals (see WeightedChannelSum).
    """
    nch = 60
    signal_subset = signal_subset.reshape((signal_subset.shape[0], nch, signal_subset.shape[-1] // nch))
    nsample = signal_subset.shape[-1]
    angles = np.radians(np.arange(0, nch, 1) * 360 / nch)
    r_array = 0.10
    wavelength_lo = 3e8 / 25.86e9
    fsample = 200e6

    grad_b_angles = 1 * 2 * np.pi * np.arange(0, nsample, 1) * freq / fsample

    x_antenna = r_array * np.cos(angles)
    y_antenna = r_array * np.sin(angles)

    x_electron = radius * np.cos(grad_b_angles)
    y_electron = radius * np.sin(grad_b_angles)

    # (nch, nsample) antenna-to-electron distances.
    d_grad_b = np.sqrt(
        (x_antenna.reshape((x_antenna.size, 1)) - x_electron.reshape((1, x_electron.size))) ** 2
        + (y_antenna.reshape((y_antenna.size, 1)) - y_electron.reshape((1, y_electron.size))) ** 2
    )

    # The electron track is 1-D here, so use the single-track antispiral helper;
    # it returns the same (nch, nsample) shape as d_grad_b. The 3-D-grid helper
    # indexes x_electron.shape[1] and raised IndexError for this input.
    phase_shift = -2 * np.pi * d_grad_b / wavelength_lo + AntispiralCorrectionSingle(
        x_electron, y_electron, x_antenna, y_antenna
    )

    shifted_signal_subset = signal_subset * np.exp(-1j * phase_shift)[np.newaxis]

    return WeightedChannelSum(shifted_signal_subset, radius, d_grad_b)


def AntispiralCorrection(x_electron, y_electron, x_antenna, y_antenna):
    """
    Angle from every electron position to every antenna, for an N-D electron grid.

    Parameters
    ----------
    x_electron, y_electron : numpy.ndarray
        Electron coordinates with a common shape ``S`` of any rank
        (SumDataset2 passes 3-D arrays).
    x_antenna, y_antenna : array-like
        Antenna coordinates; flattened to length ``nch``.

    Returns
    -------
    numpy.ndarray
        ``arctan2(y_antenna - y_electron, x_antenna - x_electron)`` in radians,
        with shape ``(nch, *S)``.
    """
    x_electron = np.asarray(x_electron)
    y_electron = np.asarray(y_electron)
    # Antenna index on a new leading axis, one singleton axis per electron axis.
    x_ant = np.reshape(x_antenna, (-1,) + (1,) * x_electron.ndim)
    y_ant = np.reshape(y_antenna, (-1,) + (1,) * y_electron.ndim)

    return np.arctan2(y_ant - y_electron[np.newaxis], x_ant - x_electron[np.newaxis])

def AntispiralCorrectionSingle(x_electron, y_electron, x_antenna, y_antenna):
    """
    Angle from every point of a single electron track to every antenna.

    Returns an array of shape ``(antenna count, track length)`` whose entry
    ``[i, j]`` is ``arctan2(y_antenna[i] - y_electron[j], x_antenna[i] - x_electron[j])``
    in radians. Inputs are flattened first.
    """
    delta_y = np.subtract.outer(y_antenna.ravel(), y_electron.ravel())
    delta_x = np.subtract.outer(x_antenna.ravel(), x_electron.ravel())
    return np.arctan2(delta_y, delta_x)

def WeightedChannelSum(shifted_signals, radius, d_grad_b):
    """
    Sum phase-aligned channels with distance-based weights.

    Each channel is weighted by ``radius / d`` (its distance to the electron)
    and the weights at every sample are normalised to sum to the number of
    channels - the same total as an unweighted all-ones sum. For
    ``radius == 0`` all weights are 1.

    Parameters
    ----------
    shifted_signals : numpy.ndarray
        Shape ``(nsignal, nch, nsample)``.
    radius : float
    d_grad_b : numpy.ndarray
        Shape ``(nch, nsample)`` antenna-to-electron distances.

    Returns
    -------
    numpy.ndarray
        Shape ``(nsignal, nsample)``.
    """
    nch = d_grad_b.shape[0]
    if radius == 0:
        normed_weights = np.ones(d_grad_b.shape)
    else:
        weights = radius / d_grad_b
        # Normalise per sample so the weights sum to nch (previously a
        # hard-coded 60, which equals nch for the 60-channel array).
        normed_weights = (nch / weights.sum(axis=0))[np.newaxis, :] * weights

    return (shifted_signals * normed_weights[np.newaxis]).sum(axis=1)
    
def CalculateChannelWeights(d_grad_b, r_grid):
    """
    Distance-based channel weights for every point of a 2-D (pitch, radius) grid.

    For grid points with non-zero radius the channel weights are ``r / d``,
    normalised so they sum to the number of channels at every sample; grid
    points with zero radius get unit weights.

    Parameters
    ----------
    d_grad_b : numpy.ndarray
        Shape ``(nch, *r_grid.shape, nsample)`` antenna-to-electron distances.
    r_grid : numpy.ndarray
        2-D grid of radii.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``d_grad_b``.
    """
    nch = d_grad_b.shape[0]
    channel_weights = np.ones(d_grad_b.shape)

    nonzero = r_grid != 0
    # Boolean-mask indexing walks the grid in row-major order on both sides,
    # so each radius stays paired with its own distances.
    weights = r_grid[nonzero][np.newaxis, :, np.newaxis] / d_grad_b[:, nonzero, :]
    normed_weights = (nch / weights.sum(axis=0))[np.newaxis] * weights
    channel_weights[:, nonzero, :] = normed_weights

    return channel_weights
    
    
    

In [None]:
# Files available in the data directory.
os.listdir(DATAPATH)

# load data

In [None]:
# Signal data: the same HDF5 file is opened both as a raw h5py handle (its
# 'meta' group is read directly in later cells) and through MFDataset.
signal_path = os.path.join(DATAPATH, '211116_grad_b_est.h5')
h5file = h5py.File(signal_path, 'r')
data = mf.data.MFDataset(signal_path)
metadata = pd.DataFrame(data.metadata)

# Grad-B frequency grid over radius and pitch angle.
gradb_freq_grid = np.load(
    os.path.join(PATH, 'results/mayfly', '211129_grad_b_frequency_grid_radius_angle.npz')
)

In [None]:
# Shape of the raw signal array; expected (nsignal, 60 * nsample) - confirm.
data.data.shape

In [None]:
# Sanity check: row positions of signals with pitch angle exactly 88 and with
# radius exactly 0.0 (exact float comparison, so values must be stored exactly).
print(np.argwhere(metadata['theta_min'].array == 88))
print(np.argwhere(metadata['x_min'].array == 0.00))

In [None]:
# Shape of the grad-B frequency table. It is flattened alongside
# np.meshgrid(radii, angles) grids below, so presumably (n_angles, n_radii) - confirm.
gradb_freq_grid['freq'].shape

In [None]:
# Example signal (all channels, flattened) used by the single-signal checks.
signal = data.data[1450]

In [None]:
# Grids with rows = pitch angle, columns = radius; the frequency table is
# assumed to use the same layout (it is flattened alongside them below).
r_grid, theta_grid = np.meshgrid(gradb_freq_grid['radii'], gradb_freq_grid['angles'])

theta_signal = metadata['theta_min'].iloc[1450]
r_signal = metadata['x_min'].iloc[1450]

# Grad-B frequency at the example signal's exact (radius, pitch) grid point.
# Fail loudly instead of passing an empty frequency array to SumSignal.
idx = np.flatnonzero(np.logical_and(r_grid.ravel() == r_signal, theta_grid.ravel() == theta_signal))
assert idx.size == 1, f'expected one grid point for (r, theta) = ({r_signal}, {theta_signal}), found {idx.size}'
freq = gradb_freq_grid['freq'].flatten()[idx[0]]

sum_signal = SumSignal(signal, freq, r_signal)

In [None]:
#no_sum_signal = no_sum_data[6039, :, :].flatten()
# Noise variance built from Boltzmann's constant (1.38e-23) times 10 * 50 * 200e6.
# NOTE(review): presumably T = 10 K, R = 50 ohm, bandwidth = 200 MHz - confirm.
var = 1.38e-23 * 10 * 50 * 200e6
# Matched-filter self-score of the unsummed example signal: |<s, s>| / sqrt(var * <s, s>).
norm = 1 / np.sqrt(var * np.vdot(signal, signal))

score = abs(np.vdot(signal, signal) * norm)
print(score)

In [None]:
# Same matched-filter self-score for the beamformed example signal. The noise
# variance is the single-channel value scaled by 60 * sqrt(60).
# NOTE(review): presumably accounts for summing 60 channels - confirm the sqrt(60).
var = 1.38e-23 * 10 * 50 * 200e6 * 60 * np.sqrt(60)

norm = 1 / np.sqrt(var * np.vdot(sum_signal, sum_signal))

score = abs(np.vdot(sum_signal, sum_signal) * norm)
print(score)

In [None]:
# FFT magnitude of the beamformed example signal, normalised by 60 * 8192.
fig, ax = plt.subplots(figsize=(13, 8))
ax.plot(abs(np.fft.fft(sum_signal) / (60 * 8192)))

In [None]:
# FFT magnitude of channel 0 of the example signal (before beamforming).
fig, ax = plt.subplots(figsize=(13, 8))
first_channel = signal.reshape(60, 8192)[0, :]
ax.plot(abs(np.fft.fft(first_channel) / 8192))

In [None]:


# Beamform the whole dataset with the per-grid-point grad-B correction.
summed_data = SumDataset2(
    MFdata=data,
    gradb_freq_grid=gradb_freq_grid['freq'],
    r_grid=r_grid,
    theta_grid=theta_grid,
)

In [None]:
# Expected (nsignal, nsample): one beamformed time series per signal.
summed_data.shape

In [None]:
# Per-signal matched-filter normalisation 1 / sqrt(var * <s, s>), using the same
# noise variance as the beamformed single-signal check (1.38e-23 * 10 * 50 * 200e6
# scaled by 60 * sqrt(60)).
norm_summed_data = 1 / np.sqrt(10 * np.sqrt(60) * 60 * 1.38e-23 * 50 * 200e6 * abs(summed_data.conjugate() * summed_data).sum(axis = -1))

In [None]:
# Matched-filter score per signal. The normalisation is broadcast per row, so
# this works for any number of signals (previously hard-coded to 6070).
summed_scores = abs(norm_summed_data[:, np.newaxis] * summed_data.conjugate() * summed_data).sum(axis=-1)

In [None]:
# Display the per-signal matched-filter scores.
summed_scores

In [None]:
# Arrange the per-signal scores on a grid: rows = pitch angle, columns = radius.
# Grid points with no matching signal, or more than one, stay at 0 (the
# previous bare `except BaseException` produced the same result but also
# swallowed KeyboardInterrupt and real errors).
radial_positions = np.unique(metadata['x_min'][:])
pitch_angles = np.unique(metadata['theta_min'][:])

rad_grid, angle_grid = np.meshgrid(radial_positions, pitch_angles)

# Read the metadata columns once instead of on every iteration.
meta_x_min = h5file['meta']['x_min'][:]
meta_theta_min = h5file['meta']['theta_min'][:]

sorted_summed_scores = np.zeros(rad_grid.size)

for i, (x_val, theta_val) in enumerate(zip(rad_grid.flatten(), angle_grid.flatten())):
    index = np.flatnonzero(np.logical_and(meta_x_min == x_val, meta_theta_min == theta_val))
    if index.size == 1:
        sorted_summed_scores[i] = summed_scores[index[0]]

In [None]:
# Rows are pitch-angle indices and columns radius indices
# (grid from np.meshgrid(radial_positions, pitch_angles)).
plt.imshow(sorted_summed_scores.reshape(rad_grid.shape), aspect='auto', interpolation='none')
plt.xlabel('radius index')
plt.ylabel('pitch angle index')
plt.title('Beamformed matched-filter score')
cbar = plt.colorbar()

# save the summed mf scores

In [None]:
# Save the matched-filter score grid (np.savez appends '.npz'). The scores are
# stored under the key 'power' for compatibility with existing readers.
save_path = os.path.join(PATH, 'results/beamforming/time_dependent')
name = '220218_time_dependent_bf_exact_correction_mf_scores'
os.makedirs(save_path, exist_ok=True)
np.savez(
    os.path.join(save_path, name),
    power=sorted_summed_scores.reshape(rad_grid.shape),
    radius=rad_grid,
    pitch=angle_grid,
)

In [None]:
# Mean power per beamformed signal, divided by 50 * 60 * sqrt(60).
# NOTE(review): 50 presumably the 50-ohm load and 60 * sqrt(60) the channel-sum
# scaling used for the noise variance above - confirm.
summed_power = (np.mean(abs(summed_data) ** 2, axis = -1) / (50 * 60 * np.sqrt(60))) 
#summed_power = (np.mean(abs(summed_data) ** 2, axis = -1) / (1)) 

# sort the summed power

In [None]:
# Arrange the per-signal mean power on a grid: rows = pitch angle, columns = radius.
# Grid points with no matching signal, or more than one, stay at 0 (the
# previous bare `except BaseException` produced the same result but also
# swallowed KeyboardInterrupt and real errors).
radial_positions = np.unique(metadata['x_min'][:])
pitch_angles = np.unique(metadata['theta_min'][:])

rad_grid, angle_grid = np.meshgrid(radial_positions, pitch_angles)

# Read the metadata columns once instead of on every iteration.
meta_x_min = h5file['meta']['x_min'][:]
meta_theta_min = h5file['meta']['theta_min'][:]

sorted_summed_pow = np.zeros(rad_grid.size)

for i, (x_val, theta_val) in enumerate(zip(rad_grid.flatten(), angle_grid.flatten())):
    index = np.flatnonzero(np.logical_and(meta_x_min == x_val, meta_theta_min == theta_val))
    if index.size == 1:
        sorted_summed_pow[i] = summed_power[index[0]]

In [None]:
# Rows are pitch-angle indices and columns radius indices
# (grid from np.meshgrid(radial_positions, pitch_angles)).
plt.imshow(sorted_summed_pow.reshape(rad_grid.shape), aspect='auto')
plt.xlabel('radius index')
plt.ylabel('pitch angle index')
plt.title('Beamformed mean signal power')
cbar = plt.colorbar()

# save the summed signal powers

In [None]:
# Save the mean signal-power grid (np.savez appends '.npz').
save_path = os.path.join(PATH, 'results/beamforming/time_dependent')
name = '220218_time_dependent_bf_ex_total_signal_power'
os.makedirs(save_path, exist_ok=True)
np.savez(
    os.path.join(save_path, name),
    power=sorted_summed_pow.reshape(rad_grid.shape),
    radius=rad_grid,
    pitch=angle_grid,
)

In [None]:
sum_data_fft = np.fft.fft(summed_data, axis=-1)

# Squared FFT magnitude normalised by the FFT length (previously hard-coded as
# 8192), then the peak bin of each signal's spectrum. An earlier variant also
# divided by 50 * 60 * sqrt(60).
n_fft = summed_data.shape[-1]
pow_spectrum = abs(sum_data_fft / n_fft) ** 2

pow_spectrum_max = np.max(pow_spectrum, axis=-1)

print(pow_spectrum_max.shape)

# sort the power spectrum max

In [None]:
# Arrange the per-signal spectral peaks on a grid: rows = pitch angle, columns = radius.
# Grid points with no matching signal, or more than one, stay at 0 (the
# previous bare `except BaseException` produced the same result but also
# swallowed KeyboardInterrupt and real errors).
radial_positions = np.unique(metadata['x_min'][:])
pitch_angles = np.unique(metadata['theta_min'][:])

rad_grid, angle_grid = np.meshgrid(radial_positions, pitch_angles)

# Read the metadata columns once instead of on every iteration.
meta_x_min = h5file['meta']['x_min'][:]
meta_theta_min = h5file['meta']['theta_min'][:]

sorted_summed_maxima = np.zeros(rad_grid.size)

for i, (x_val, theta_val) in enumerate(zip(rad_grid.flatten(), angle_grid.flatten())):
    index = np.flatnonzero(np.logical_and(meta_x_min == x_val, meta_theta_min == theta_val))
    if index.size == 1:
        sorted_summed_maxima[i] = pow_spectrum_max[index[0]]

In [None]:
# Rows are pitch-angle indices and columns radius indices
# (grid from np.meshgrid(radial_positions, pitch_angles)).
plt.imshow(sorted_summed_maxima.reshape(rad_grid.shape), aspect='auto', interpolation='none')
plt.xlabel('radius index')
plt.ylabel('pitch angle index')
plt.title('Peak FFT power per signal')
cbar = plt.colorbar()

# save summed maxima grid

In [None]:
# Save the spectral-peak grid (np.savez appends '.npz').
# NOTE: 'spetrum' in the file name is a typo kept so existing readers still find it.
save_path = os.path.join(PATH, 'results/beamforming/time_dependent')
name = '220218_time_dependent_bf_pitch_average_fft_spetrum_max_voltage_square'
os.makedirs(save_path, exist_ok=True)
np.savez(
    os.path.join(save_path, name),
    power=sorted_summed_maxima.reshape(rad_grid.shape),
    radius=rad_grid,
    pitch=angle_grid,
)

In [None]:
# Peak spectral power for each point of the grad-B (pitch angle x radius) grid.
# Points with no matching signal stay at 0; if several signals match, the
# first one is used.
carrier_power_grid = np.zeros(r_grid.shape).flatten()

x_min_values = metadata['x_min'].to_numpy()
theta_min_values = metadata['theta_min'].to_numpy()

for i, (r_val, theta_val) in enumerate(zip(r_grid.flatten(), theta_grid.flatten())):
    # Match on radius AND pitch angle. Previously theta_min was compared twice,
    # so the first signal with that pitch angle was used regardless of radius.
    matches = np.flatnonzero(np.logical_and(x_min_values == r_val, theta_min_values == theta_val))
    if matches.size:
        carrier_power_grid[i] = pow_spectrum_max[matches[0]]


In [None]:
# This looks different from the plots I've shown before;
# I should check this using just the average grad-B frequency for a specific radius.

In [None]:
# Transposed, so rows are radius indices and columns pitch-angle indices
# (r_grid comes from np.meshgrid(radii, angles)).
plt.imshow(carrier_power_grid.reshape(r_grid.shape).T, aspect='auto', interpolation='none')
plt.xlabel('pitch angle index')
plt.ylabel('radius index')
plt.title('Peak FFT power on the grad-B grid')
plt.colorbar()


In [None]:
# Largest peak spectral power on the grid.
print(carrier_power_grid.max())

In [None]:
# Output path for the beamformed dataset, derived from the configured data directory.
path_save = os.path.join(DATAPATH, 'bf', '220216_gradb_est_summed_data_needs_check')

In [None]:
# Persist the beamformed dataset; np.save appends '.npy' to path_save.
np.save(path_save, summed_data)

In [None]:
# Centred (fftshift) spectrum of beamformed signal 135: FFT magnitude
# normalised by 60 * 8192 * 50, then squared.
fig, ax = plt.subplots(figsize=(13, 8))
spectrum = np.fft.fftshift(abs(np.fft.fft(summed_data[135, :]) / (60 * 8192 * 50)))
ax.plot(spectrum ** 2)

In [None]:
# Unsummed data as (signal, channel, sample); sizes come from the dataset
# instead of the hard-coded 6070 signals and 8192 samples.
no_sum_data = data.data[:].reshape(data.data.shape[0], 60, -1)

In [None]:
# Centred power spectrum of channel 0 of unsummed signal 135: FFT magnitude
# normalised by 8192 * 50, squared, then fftshift-ed.
fig, ax = plt.subplots(figsize=(13, 8))
channel_spectrum = np.fft.fftshift(abs(np.fft.fft(no_sum_data[135, 0, :]) / (8192*50)) ** 2)
ax.plot(channel_spectrum)

In [None]:
# Total energy of unsummed signal 94 over all channels; abs(x ** 2) equals |x| ** 2.
np.sum(abs(no_sum_data[94, :, :] ** 2))

In [None]:
# Total energy sum(|s|^2) of beamformed signal 94.
energy = np.sum(abs(summed_data[94, :]) ** 2)

In [None]:
# Energy divided by 50 - NOTE(review): presumably the 50-ohm load, confirm.
energy / (50)

In [None]:
# Matched-filter self-score of unsummed signal 94 (all channels flattened),
# using the single-channel noise variance 1.38e-23 * 10 * 50 * 200e6.
no_sum_signal = no_sum_data[94, :, :].flatten()
var = 1.38e-23 * 10 * 50 * 200e6
norm = 1 / np.sqrt(var * np.vdot(no_sum_signal, no_sum_signal))

score = abs(np.vdot(no_sum_signal, no_sum_signal) * norm)
print(score)

In [None]:
# Matched-filter self-score of beamformed signal 94, with the noise variance
# scaled by 60 * sqrt(60) as in the earlier beamformed check.
sum_signal = summed_data[94, :]
var = 1.38e-23 * 10 * 50 * 200e6 * 60 * np.sqrt(60)

norm = 1 / np.sqrt(var * np.vdot(sum_signal, sum_signal))

score = abs(np.vdot(sum_signal, sum_signal) * norm)
print(score)