In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import h5py
import os
import sys
import scipy
import damselfly as df
import mayfly as mf
import scipy.signal
import scipy.stats
import scipy.interpolate
import math
from pathlib import Path

# Project root on the cluster; the derived paths below point at subdirectories of it.
# NOTE(review): RESULTPATH and PLOTPATH are not referenced by any visible cell, and later
# cells build paths from Path.home()/'group'/'project' instead of PATH — confirm both
# resolve to the same directory.
PATH = '/storage/home/adz6/group/project'
RESULTPATH = os.path.join(PATH, 'results/damselfly')
PLOTPATH = os.path.join(PATH, 'plots/damselfly')
DATAPATH = os.path.join(PATH, 'datasets/data')
#SIMDATAPATH = os.path.join(PATH, 'damselfly/data/sim_data')

"""
Date: 6/25/2021
Description: template
"""


def ShiftAndSum(signal, x_range, y_range, freq):
    """Beamform a multichannel signal onto a Cartesian grid of rotating sources.

    Each grid point (x, y) is treated as the starting position of a source whose
    azimuth advances by 2*pi*freq/fsample per sample (fsample = 200 MHz). For every
    sample, the source-to-antenna distance is converted to a phase. Each channel is
    multiplied by exp(-1j * (distance phase + antenna azimuth)), and the channels
    are summed.

    Args:
        signal: complex array of shape (nch, nsample). Channel k is taken to sit at
            azimuth 360*k/nch degrees on a ring of radius 0.10 (m, given the 3e8
            speed of light used for the wavelength).
        x_range, y_range: coordinates combined with np.meshgrid to form the grid.
        freq: rotation frequency in Hz.

    Returns:
        Complex array of shape (n_grid_points, nsample). Grid points are in the
        flattened np.meshgrid(x_range, y_range) order.
    """
    nch = signal.shape[0]
    # Use the record length of the data rather than a hard-coded 8192, so the
    # per-sample phases always broadcast against ``signal``.
    nsample = signal.shape[-1]
    angles = np.radians(np.arange(0, nch, 1) * 360 / nch)
    r_array = 0.10
    wavelength_lo = 3e8 / 25.86e9
    fsample = 200e6

    # Azimuth swept by the source at each sample.
    grad_b_angles = 2 * np.pi * np.arange(0, nsample, 1) * freq / fsample

    x_antenna = r_array * np.cos(angles)
    y_antenna = r_array * np.sin(angles)

    x_grid, y_grid = np.meshgrid(x_range, y_range)
    x_grid = x_grid.flatten()
    y_grid = y_grid.flatten()

    r_grid = np.sqrt(x_grid ** 2 + y_grid ** 2)
    theta_grid = np.arctan2(y_grid, x_grid)
    # (n_points, nsample): source azimuth over time for each grid point.
    theta_grid_grad_b = theta_grid.reshape(theta_grid.size, 1) + grad_b_angles

    x_grad_b = r_grid.reshape((r_grid.size, 1)) * np.cos(theta_grid_grad_b)
    y_grad_b = r_grid.reshape((r_grid.size, 1)) * np.sin(theta_grid_grad_b)

    # (nch, n_points, nsample): antenna-to-source distance.
    d_grad_b = np.sqrt(
        (x_antenna.reshape((nch, 1, 1)) - x_grad_b[np.newaxis]) ** 2
        + (y_antenna.reshape((nch, 1, 1)) - y_grad_b[np.newaxis]) ** 2
    )

    phase_shift = 2 * np.pi * d_grad_b / wavelength_lo

    shifted_signal = (
        np.exp(-1j * (phase_shift + angles.reshape((nch, 1, 1))))
        * signal.reshape((nch, 1, nsample))
    )

    return shifted_signal.sum(axis=0)

def SumSignal(signal, freq, radius):
    """Phase-align and weight-sum a flattened 60-channel record for one orbit radius.

    The electron is placed at ``radius`` with initial azimuth 0, and its azimuth
    advances by 2*pi*freq/fsample per sample (fsample = 200 MHz). Each of the 60
    antennas on a 0.10 ring is multiplied by
        weight * exp(1j * (2*pi*d/lambda + antispiral_angle))
    and the channels are summed. Here d is the antenna-electron distance per
    sample, weight comes from CalculateChannelWeightsSingleRad, and
    antispiral_angle comes from AntispiralCorrectionSingle.

    Args:
        signal: the 60 channels stored back to back, reshapeable to (60, nsample).
            Assumed 1-D, because nsample is computed from ``signal.shape[-1]``.
        freq: rotation frequency in Hz.
        radius: electron orbit radius (same units as the 0.10 array radius).

    Returns:
        1-D complex array of length nsample.
    """
    nch = 60
    angles = np.radians(np.arange(0, nch, 1) * 360 / nch)
    r_array = 0.10
    nsample = signal.shape[-1] // nch
    wavelength_lo = 3e8 / 25.86e9
    fsample = 200e6

    grad_b_angles = 2 * np.pi * np.arange(0, nsample, 1) * freq / fsample

    x_antenna = r_array * np.cos(angles)
    y_antenna = r_array * np.sin(angles)

    r_electron = radius
    theta_electron = 0 + grad_b_angles

    x_electron = r_electron * np.cos(theta_electron)
    y_electron = r_electron * np.sin(theta_electron)

    # (nch, nsample) antenna-to-electron distance.
    d_grad_b = np.sqrt(
        (
            x_antenna.reshape((x_antenna.size, 1))
            - x_electron.reshape((1, x_electron.size))
        ) ** 2
        + (
            y_antenna.reshape((y_antenna.size, 1))
            - y_electron.reshape((1, y_electron.size))
        ) ** 2
    )

    antispiral = AntispiralCorrectionSingle(x_electron, y_electron, x_antenna, y_antenna)
    # The antispiral angle is a phase, so it belongs inside the imaginary exponent
    # (as in SumSignalGrid). Adding it to the real part scaled each channel's
    # amplitude by exp(angle) instead of rotating its phase.
    correction = CalculateChannelWeightsSingleRad(d_grad_b, radius) * np.exp(
        1j * (2 * np.pi * d_grad_b / wavelength_lo + antispiral)
    )

    return (signal.reshape(nch, signal.size // nch) * correction).sum(axis=0)

def SumSignalGrid(signal, freq, x_grid, y_grid):
    """Beamform a multichannel signal onto arbitrary grid points with rotating sources.

    Each grid point is a source whose azimuth, measured as arctan2(-y, x), advances
    by 2*pi*freq/fsample per sample (fsample = 200 MHz). Each channel is multiplied
    by exp(-1j * (2*pi*d/lambda + antispiral_angle)), where d is the antenna-source
    distance and antispiral_angle is arctan2 of the source-to-antenna offset. The
    channels are then summed.

    NOTE(review): y is negated in the azimuth, unlike ShiftAndSum — confirm the
    intended orientation.

    Args:
        signal: array of shape (nch, nsample). Channel k sits at 360*k/nch degrees
            on a 0.10 ring; nch is taken from the data instead of a hard-coded 60.
        freq: rotation frequency in Hz.
        x_grid, y_grid: grid coordinates of any (matching) shape; they are flattened.

    Returns:
        Complex array of shape (x_grid.size, nsample).
    """
    nch = signal.shape[0]
    nsample = signal.shape[-1]
    angles = np.radians(np.arange(0, nch, 1) * 360 / nch)
    r_array = 0.10
    wavelength_lo = 3e8 / 25.86e9
    fsample = 200e6

    grad_b_angles = 2 * np.pi * np.arange(0, nsample, 1) * freq / fsample

    x_antenna = r_array * np.cos(angles)
    y_antenna = r_array * np.sin(angles)

    x_flat, y_flat = x_grid.flatten(), y_grid.flatten()

    r_grid = np.sqrt(x_flat ** 2 + y_flat ** 2)
    theta_grid = np.arctan2(-y_flat, x_flat)
    theta_grid_grad_b = theta_grid[:, np.newaxis] + grad_b_angles
    # (n_points, nsample) source trajectory.
    x_grad_b = r_grid[:, np.newaxis] * np.cos(theta_grid_grad_b)
    y_grad_b = r_grid[:, np.newaxis] * np.sin(theta_grid_grad_b)

    # Accumulate one channel at a time. The previous version held several
    # (nch, n_points, nsample) arrays at once, tens of GB for an 81x81 grid with
    # 60 x 8192 samples.
    summed = np.zeros(x_grad_b.shape, dtype=np.result_type(signal.dtype, np.complex128))
    for ch in range(nch):
        dx = x_antenna[ch] - x_grad_b
        dy = y_antenna[ch] - y_grad_b
        phase_shift = 2 * np.pi * np.sqrt(dx ** 2 + dy ** 2) / wavelength_lo
        antispiral_angles = np.arctan2(dy, dx)
        summed += np.exp(1j * (-phase_shift - antispiral_angles)) * signal[ch]

    return summed


def AntispiralCorrection(x_electron, y_electron, x_antenna, y_antenna):
    """Angle (radians) of the electron-to-antenna offset for every antenna/position pair.

    Antenna coordinates are flattened onto a new leading axis (shape (n_ant, 1, 1)).
    The electron arrays get a leading singleton axis, and the two broadcast
    together. For 2-D electron arrays the result has shape
    (n_ant, *x_electron.shape).
    """
    ant_x = x_antenna.reshape(-1, 1, 1)
    ant_y = y_antenna.reshape(-1, 1, 1)
    offset_y = ant_y - y_electron[np.newaxis, ...]
    offset_x = ant_x - x_electron[np.newaxis, ...]
    return np.arctan2(offset_y, offset_x)

def AntispiralCorrectionSingle(x_electron, y_electron, x_antenna, y_antenna):
    """Angle (radians) of the electron-to-antenna offset as an (n_ant, n_positions) table.

    Rows index antennas and columns index the flattened electron positions.
    """
    delta_y = y_antenna.reshape(-1, 1) - y_electron.reshape(1, -1)
    delta_x = x_antenna.reshape(-1, 1) - x_electron.reshape(1, -1)
    return np.arctan2(delta_y, delta_x)

def WeightedChannelSum(shifted_signals, radius, d_grad_b):
    """Sum shifted signals over channels using radius/distance weights.

    For radius 0 every weight is 1. Otherwise the weights are radius / d_grad_b,
    rescaled so that along the channel axis (axis 0 of d_grad_b) they sum to the
    number of channels. That is the same total as the uniform weights, and it
    replaces the previous hard-coded 60.

    Args:
        shifted_signals: array broadcastable against (1, *d_grad_b.shape); its
            axis 1 is the channel axis.
        radius: source radius; 0 selects uniform weighting.
        d_grad_b: antenna-to-source distances with channels on axis 0.

    Returns:
        ``shifted_signals * weights`` summed over axis 1.
    """
    if radius == 0:
        normed_weights = np.ones(d_grad_b.shape)
    else:
        weights = radius / d_grad_b
        nch = d_grad_b.shape[0]
        normed_weights = weights * (nch / weights.sum(axis=0))

    return (shifted_signals * normed_weights[np.newaxis, ...]).sum(axis=1)
    
def CalculateChannelWeights(d_grad_b, r_grid):
    """Per-channel distance weights for a 2-D grid of source radii.

    At grid points with r != 0 the weight is r / d, rescaled so that the weights
    along the channel axis (axis 0) sum to the number of channels. Grid points with
    r == 0 keep uniform weights of 1.

    Args:
        d_grad_b: 4-D distances indexed [channel, i, j, sample], where (i, j) index
            ``r_grid``.
        r_grid: 2-D array of source radii.

    Returns:
        Float array with the same shape as ``d_grad_b``.
    """
    nch = d_grad_b.shape[0]
    channel_weights = np.ones(d_grad_b.shape)

    nonzero = r_grid != 0
    # Boolean-mask indexing over axes (1, 2) visits grid points in row-major order,
    # which is the same order as the nested i/j loop this replaces.
    weights = r_grid[nonzero][np.newaxis, :, np.newaxis] / d_grad_b[:, nonzero, :]
    channel_weights[:, nonzero, :] = weights * (nch / weights.sum(axis=0))

    return channel_weights

def CalculateChannelWeightsSingleRad(d_grad_b, radius):
    """Per-channel distance weights for a single source radius.

    Returns all ones for radius 0. Otherwise returns radius / d_grad_b, rescaled so
    that the weights along the channel axis (axis 0) sum to the number of channels
    (previously a hard-coded 60).

    Args:
        d_grad_b: antenna-to-source distances with channels on axis 0.
        radius: source radius.

    Returns:
        Float array with the same shape as ``d_grad_b``.
    """
    if radius == 0:
        return np.ones(d_grad_b.shape)

    weights = radius / d_grad_b
    nch = d_grad_b.shape[0]
    return weights * (nch / weights.sum(axis=0))
    
    

In [None]:
# List the available dataset files.
for entry in (Path.home() / 'group' / 'project' / 'datasets' / 'data').iterdir():
    print(entry)

# Load the signal dataset and the grad-B frequency grid

In [None]:
# Signal data; rows of data.data are selected via the metadata columns below.
#h5file = h5py.File(os.path.join(DATAPATH, 'dense_template_random', '220124_sens_est_dense_grid_87.0_3cm_random.h5'), 'r')
signal_file = os.path.join(DATAPATH, '211116_grad_b_est.h5')
data = mf.data.MFDataset(signal_file)
metadata = pd.DataFrame(data.metadata)

radial_position = metadata['x_min'].to_numpy(copy=True)
pitch_angle = metadata['theta_min'].to_numpy(copy=True)

# Grad-B correction data: 'radii', 'angles' and 'freq' arrays.
gradb_freq_grid = np.load(
    os.path.join(PATH, 'results/mayfly', '211129_grad_b_frequency_grid_radius_angle.npz')
)

In [None]:
pitch_plot = 90
rad_pos_plot = 0.015


r_grid, pa_grid = np.meshgrid(gradb_freq_grid['radii'], gradb_freq_grid['angles'])

# Compare floats with a tolerance. Values loaded from disk need not be bit-identical
# to the literals above, and an exact == that matches nothing used to fail later
# with an opaque IndexError (or a reshape error for the signal lookup).
freq_matches = np.argwhere(np.isclose(r_grid, rad_pos_plot) & np.isclose(pa_grid, pitch_plot))
assert len(freq_matches) == 1, (
    f'expected one grad-B grid point at r={rad_pos_plot}, pitch={pitch_plot}; found {len(freq_matches)}'
)
freq_ind = freq_matches[0]
gradb_freq = gradb_freq_grid['freq'][freq_ind[0], freq_ind[1]]
print(gradb_freq)

signal_matches = np.flatnonzero(
    np.isclose(radial_position, rad_pos_plot) & np.isclose(pitch_angle, pitch_plot)
)
assert signal_matches.size == 1, (
    f'expected one event at r={rad_pos_plot}, pitch={pitch_plot}; found {signal_matches.size}'
)
ind = signal_matches[0]

signal = data.data[ind, :].reshape((60, 8192))


In [None]:
# Beamforming grid: n_grid x n_grid points spanning [-grid_size, grid_size] in x and y.
grid_size = 0.04
n_grid = 81

coord_array = np.linspace(-grid_size, grid_size, n_grid)
x_grid, y_grid = np.meshgrid(coord_array, coord_array)

# bf_grid holds one channel-summed time series per flattened grid point.
# NOTE(review): the looked-up grad-B frequency is multiplied by 10 with no
# explanation — confirm this scaling is intentional.
bf_grid = SumSignalGrid(signal, gradb_freq * 10, x_grid, y_grid)

#sum_signal = ShiftAndSum(signal, coord_array, coord_array, 0)


In [None]:
# NOTE(review): this cell is not idempotent — running it twice flips bf_grid back.
# Axis 0 of bf_grid is the flattened grid, so reversing it reverses both image axes
# after reshaping to x_grid.shape (a 180-degree rotation), not just one — confirm
# this is the intended orientation.
bf_grid = np.flip(bf_grid, axis=0)

In [None]:
# NOTE(review): the event-selection cell reshapes a row to (60, 8192), while later cells
# reshape a row to (60, 3 * 8192). Both cannot hold for the same `data`, so the later
# cells were presumably run against a different dataset.
data.data.shape

In [None]:
# Quick look: time-averaged |beamformed signal|**2 on the grid.
mean_power = (np.abs(bf_grid.reshape((*x_grid.shape, 8192))) ** 2).mean(axis=-1)
plt.imshow(mean_power)
plt.colorbar()

In [None]:
# Expected (n_grid * n_grid, nsample): one channel-summed time series per grid point.
print(bf_grid.shape)

In [None]:
sns.set_theme(context = 'talk', style='ticks')
cmap = sns.color_palette('mako_r', as_cmap=True)
fig = plt.figure(figsize=(13, 8))
ax = fig.add_subplot(1,1,1)

nsample = bf_grid.shape[-1]
# Time-averaged beamformed power per grid point.
# NOTE(review): the 50 * 60 divisor is presumably a 50-ohm load and 60 channels — confirm.
power_map = (np.abs(bf_grid.reshape((*x_grid.shape, nsample))) ** 2).mean(axis=-1) / (50 * 60)

# Pixel edges of the grid that was actually beamformed (coord_array spans +/- grid_size).
# The previous hard-coded +/-0.05 extent stretched the axis labels by 25 %.
half_step = 0.5 * (coord_array[1] - coord_array[0])
lo_edge, hi_edge = coord_array[0] - half_step, coord_array[-1] + half_step

img = ax.imshow(
    power_map,
    cmap=cmap,
    extent=(lo_edge, hi_edge, hi_edge, lo_edge),
)
cbar = fig.colorbar(img, label='Power (W)')
ax.set_xlabel('X-Position (m)')
ax.set_ylabel('Y-Position (m)')

plt.tight_layout()
# Derive the name from the plotted parameters. Previously '89deg' was hard-coded while
# pitch_plot = 90, and int(100 * 0.015) truncated 1.5 cm to '1cm'.
name = f'220303_beamforming_map_{pitch_plot}deg_offaxis{100 * rad_pos_plot:g}cm_corrected.png'
save_path = Path.home()/'group'/'project'/'plots'/'analysis'/'beamforming'/'maps'

#plt.savefig(save_path/name)

In [None]:
sns.set_theme(context='talk', style='whitegrid')
cmap = sns.color_palette('mako_r', as_cmap=True)
fig = plt.figure(figsize=(13, 8))
ax = fig.add_subplot(1, 1, 1)

# Sample times for 8192 samples at 200 MHz.
time = np.arange(8192) / 200e6
beamformed_trace = bf_grid.reshape((*x_grid.shape, 8192))[10, 10, :]

for trace, label in ((beamformed_trace, 'Beamformed Signal'), (signal[0, :], 'Single Channel')):
    ax.plot(time, trace.real, label=label)

ax.set_xlim(time[0], time[256])
ax.set_xlabel('Time (s)')
ax.set_ylabel('V')
ax.legend(loc=1)

plt.tight_layout()
name = '220302_beamformed_signal_amplitude_comparison.png'
save_path = Path.home() / 'group' / 'project' / 'plots' / 'analysis' / 'beamforming' / 'time_series'
save_path.mkdir(parents=True, exist_ok=True)
#plt.savefig(save_path/name)




In [None]:
sns.set_theme(context = 'talk', style='darkgrid')
fig = plt.figure(figsize=(13, 8))
ax = fig.add_subplot(1,1,1)


def _power_spectrum(trace, load_ohms=50):
    """Zero-centred two-sided spectrum |FFT(trace) / N|**2 / load_ohms.

    The load divides the squared amplitude once, as in the cells that divide |V|**2
    by 50 outside the square. The previous version put 50 inside the square,
    which divided by 2500.
    """
    return np.abs(np.fft.fftshift(np.fft.fft(trace)) / trace.size) ** 2 / load_ohms


beamformed_trace = bf_grid.reshape((*x_grid.shape, bf_grid.shape[-1]))[10, 10, :]
ax.plot(_power_spectrum(beamformed_trace), label='Beamformed Signal')
ax.plot(_power_spectrum(signal[0, :]), label='Single Channel')
ax.set_yscale('log')
ax.set_xlabel('Frequency bin (zero frequency centred)')
ax.set_ylabel('Power')
ax.legend(loc=1);
ax.set_yscale('log')

In [None]:
# Arrays stored in the grad-B frequency grid file (the enumerate index was unused).
for key in gradb_freq_grid.keys():
    print(key)

In [None]:
# Shape of the grad-B frequency table. Lookup cells index it with (row, col) positions
# found in meshgrid(radii, angles), i.e. presumably [angle_index, radius_index] — confirm.
print(gradb_freq_grid['freq'].shape)

In [None]:
r_grid, pa_grid = np.meshgrid(gradb_freq_grid['radii'], gradb_freq_grid['angles'])

# Tolerant float comparison: an exact == on values loaded from disk can silently match
# nothing and then fail with an opaque IndexError below.
freq_matches = np.argwhere(np.isclose(r_grid, 0.03) & np.isclose(pa_grid, 87.0))
assert len(freq_matches) == 1, f'expected one grad-B grid point, found {len(freq_matches)}'
freq_ind = freq_matches[0]
gradb_freq = gradb_freq_grid['freq'][freq_ind[0], freq_ind[1]]
print(gradb_freq)

In [None]:
# FIXME(review): SumDatasetSingleFrequency is not defined in any cell of this notebook, so
# this raises NameError on a fresh kernel (presumably it came from a deleted cell).
# The radius 0.03 presumably should match the r = 0.03 used in the lookup cell above.
summed_data = SumDatasetSingleFrequency(data, gradb_freq, 0.03)

In [None]:
save_path = Path.home()/'group'/'project'/'datasets'/'data'/'bf'
# NOTE(review): this name refers to the dense_template_random 87.0 deg / 3 cm dataset (the
# commented-out h5 in the load cell), but `data` is loaded from 211116_grad_b_est.h5 —
# confirm which dataset summed_data was computed from.
name = '220301_dense_grid_87.0deg_3cm_random.npy'

# np.save does not create missing parent directories.
save_path.mkdir(parents=True, exist_ok=True)
np.save(save_path/name, summed_data)

In [None]:
# Confirm the saved file is present.
for saved_file in save_path.iterdir():
    print(saved_file)

In [None]:
# Shape of the summed output; later cells index it as [row, sample] (presumably one row per event).
summed_data.shape

In [None]:
# First 100 samples of the real part of summed row 0.
plt.plot(np.real(summed_data[0, :]))
plt.xlim(0, 100)

In [None]:
# Row 0 reshaped to 60 x (3 * 8192); plot the first 8192 samples of the first row of that reshape.
raw_channel0 = data.data[0, :].reshape(60, 3 * 8192)[0, :8192]
plt.plot(raw_channel0.real)
plt.xlim(0, 100)

In [None]:
signal1 = summed_data[5000, :]
signal2 = data.data[5000, :].reshape(60, (3 * 8192 * 60) // 60)[:, 0:8192].flatten()

# Amplitude ratio ||signal1|| / ||signal2||. np.vdot returns a complex scalar, so the
# previous sqrt(vdot / vdot) printed a complex number with a zero imaginary part.
print(np.linalg.norm(signal1) / np.linalg.norm(signal2))

In [None]:
# Spectrum of signal1 (unshifted FFT bins).
# NOTE(review): the 50 * 60 * sqrt(60) divisor is unexplained (50 presumably ohms,
# 60 channels) — confirm the intended normalization.
plt.plot((abs(np.fft.fft(signal1) / 8192) ** 2 ) / (50 * 60 * np.sqrt(60)))

In [None]:
# Spectrum of the first 8192-sample block of signal2 (unshifted FFT bins).
first_block = signal2.reshape(60, 8192)[0, :]
plt.plot(np.abs(np.fft.fft(first_block) / 8192) ** 2)