In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import h5py
import os
import sys
import scipy
import damselfly as df
import mayfly as mf
import scipy.signal
import scipy.stats
import scipy.interpolate
import math
from pathlib import Path

# Absolute root of the project directory; the result/plot/dataset directories below are built from it.
# NOTE(review): some cells further down build paths from Path.home() / 'group' / 'project'
# instead of PATH — the two only agree when the home directory is /storage/home/adz6.
PATH = '/storage/home/adz6/group/project'
RESULTPATH = os.path.join(PATH, 'results/damselfly')  # output directory for damselfly results
PLOTPATH = os.path.join(PATH, 'plots/damselfly')      # output directory for damselfly plots
DATAPATH = os.path.join(PATH, 'datasets/data')        # input datasets (HDF5 signal files)
#SIMDATAPATH = os.path.join(PATH, 'damselfly/data/sim_data')

# Header note kept as a bare string expression (it follows code, so it is not a module docstring).
"""
Date: 6/25/2021
Description: template
"""

def SumSignal(signal, freq, radius):
    """Beamform one 60-channel signal along a circular grad-B electron track.

    The electron is modelled on a circle of radius ``radius`` centred on the array
    axis, with its azimuth advancing by ``2*pi*freq/fsample`` radians per sample
    from 0. Each channel is phase-corrected for:

    * the antenna-to-electron path length, in units of the LO wavelength;
    * the antispiral angle.

    The channels are then summed with unit weights.

    Parameters
    ----------
    signal : ndarray
        Channel-concatenated samples. It is reshaped to ``(nch, signal.size // nch)``,
        so it is presumably 1-D of length ``nch * nsample``. TODO: confirm for batched input.
    freq : float
        Rotation frequency of the electron about the array axis, in the same units
        as ``fsample`` (200e6).
    radius : float
        Radius of the electron track, in the same length units as the 0.10 array
        radius (presumably metres).

    Returns
    -------
    ndarray, shape (nsample,)
        Phase-corrected sum over channels.
    """
    nch = 60
    angles = np.radians(np.arange(nch) * 360 / nch)
    r_array = 0.10
    nsample = signal.shape[-1] // nch
    wavelength_lo = 3e8 / 25.86e9
    fsample = 200e6

    # Electron azimuth per sample.
    theta_electron = 2 * np.pi * np.arange(nsample) * freq / fsample

    x_antenna = r_array * np.cos(angles)
    y_antenna = r_array * np.sin(angles)

    x_electron = radius * np.cos(theta_electron)
    y_electron = radius * np.sin(theta_electron)

    # (nch, nsample) antenna-to-electron distances.
    d_grad_b = np.sqrt(
        (x_antenna[:, np.newaxis] - x_electron[np.newaxis, :]) ** 2
        + (y_antenna[:, np.newaxis] - y_electron[np.newaxis, :]) ** 2
    )

    # Both terms are phases in radians, so BOTH must sit inside the 1j factor.
    # Adding the arctan2 antispiral angle outside it would scale each channel
    # by a real exp(angle) instead of rotating its phase.
    phase = 2 * np.pi * d_grad_b / wavelength_lo + AntispiralCorrectionSingle(
        x_electron, y_electron, x_antenna, y_antenna
    )
    correction = np.exp(1j * phase)

    return (signal.reshape(nch, signal.size // nch) * correction).sum(axis=0)
    
def SaveSummedDataset(data, metadata, name):
    """Write beamformed data and its per-column metadata to a new HDF5 file.

    Parameters
    ----------
    data : array_like
        Stored as the top-level dataset ``'data'``.
    metadata : pandas.DataFrame
        Each column ``key`` is stored as dataset ``'meta/<key>'``.
    name : str or path-like
        Output file path; an existing file is overwritten (mode ``'w'``).
    """
    # The context manager closes the file even if a create_dataset call raises.
    with h5py.File(name, 'w') as savefile:
        savefile.create_dataset('data', data=data)

        metagroup = savefile.create_group('meta')
        for key in metadata.keys():
            metagroup.create_dataset(key, data=np.array(metadata[key].array))

def SumDatasetSingleFrequency(MFdata, frequency, radius, samples=8192):
    """Beamform every row of ``MFdata.data`` along one circular grad-B electron track.

    Each row is treated as 60 equal-length channels concatenated end to end. For
    the first ``samples`` samples of each channel, this function:

    * corrects the phase for the antenna-to-electron path length and the antispiral angle;
    * weights each channel with ``CalculateChannelWeightsSingleRad``;
    * sums over channels.

    Parameters
    ----------
    MFdata : mayfly ``MFDataset``-like object
        Must expose ``.data`` with shape ``(n_rows, nch * n_per_channel)``.
        ``.data[:]`` is read fully into memory.
    frequency : float
        Rotation frequency of the electron about the array axis. The track angle
        advances by ``2*pi*frequency/fsample`` per sample (``fsample = 200e6``).
    radius : float
        Radius of the electron track, in the same length units as the 0.10 array radius.
    samples : int, optional
        Number of samples per channel to beamform (default 8192).

    Returns
    -------
    ndarray, shape (n_rows, samples)
        Weighted, phase-corrected sum over channels.

    Raises
    ------
    ValueError
        If ``samples`` exceeds the number of samples available per channel.
    """
    data_shape = MFdata.data.shape
    nch = 60
    r_array = 0.10
    wavelength_lo = 3e8 / 25.86e9
    fsample = 200e6

    samples_per_channel = data_shape[-1] // nch
    if samples > samples_per_channel:
        # Check before loading the data; otherwise this only shows up later as a broadcast error.
        raise ValueError(
            f'samples={samples} exceeds the {samples_per_channel} samples available per channel'
        )
    nsample = samples

    angles = np.radians(np.arange(nch) * 360 / nch)
    x_antenna = r_array * np.cos(angles)
    y_antenna = r_array * np.sin(angles)

    # Electron azimuth per sample, starting at 0.
    theta_electron = 2 * np.pi * np.arange(nsample) * frequency / fsample
    x_electron = radius * np.cos(theta_electron)
    y_electron = radius * np.sin(theta_electron)

    # (nch, nsample) antenna-to-electron distances.
    d_grad_b = np.sqrt(
        (x_antenna[:, np.newaxis] - x_electron[np.newaxis, :]) ** 2
        + (y_antenna[:, np.newaxis] - y_electron[np.newaxis, :]) ** 2
    )

    # Both terms are phases in radians, so both go inside the 1j factor.
    # Otherwise the antispiral angle becomes a real exp(angle) amplitude factor.
    phase = 2 * np.pi * d_grad_b / wavelength_lo + AntispiralCorrectionSingle(
        x_electron, y_electron, x_antenna, y_antenna
    )
    correction = CalculateChannelWeightsSingleRad(d_grad_b, radius) * np.exp(1j * phase)

    # (n_rows, nch, nsample): split each row into channels and keep the first nsample samples.
    data = MFdata.data[:].reshape(data_shape[0], nch, samples_per_channel)[:, :, :nsample]

    return np.sum(correction * data, axis=1)

def InterpolateGradB(radius, pitch_angle, gradb_freq_grid):
    """Bilinearly interpolate the tabulated grad-B frequency at (radius, pitch_angle).

    ``scipy.interpolate.interp2d`` was deprecated in SciPy 1.10 and removed in 1.14.
    This uses ``RegularGridInterpolator`` instead and keeps interp2d's conventions:

    * grid axes and query coordinates are sorted ascending;
    * queries outside the grid are clamped to the nearest edge;
    * the result has shape ``(len(pitch_angle), len(radius))``, with the leading
      axis dropped when it has length 1. A scalar query returns shape ``(1,)``.

    Parameters
    ----------
    radius : float or 1-D array_like
        Query radii.
    pitch_angle : float or 1-D array_like
        Query pitch angles.
    gradb_freq_grid : mapping
        Holds ``'radii'`` (1-D), ``'angles'`` (1-D) and ``'freq'``, where ``'freq'``
        has shape ``(len(angles), len(radii))``.

    Returns
    -------
    ndarray
        Interpolated frequencies.
    """
    radii = np.asarray(gradb_freq_grid['radii'], dtype=float)
    angles = np.asarray(gradb_freq_grid['angles'], dtype=float)
    gradb_data = np.asarray(gradb_freq_grid['freq'])

    # RegularGridInterpolator needs ascending axes (interp2d sorted them internally).
    r_order = np.argsort(radii)
    a_order = np.argsort(angles)
    radii = radii[r_order]
    angles = angles[a_order]
    gradb_data = gradb_data[np.ix_(a_order, r_order)]

    interpolator = scipy.interpolate.RegularGridInterpolator(
        (angles, radii), gradb_data, method='linear'
    )

    # Sort the queries and clamp them to the grid, matching interp2d's nearest-edge extrapolation.
    query_r = np.clip(np.sort(np.atleast_1d(np.asarray(radius, dtype=float)).ravel()), radii[0], radii[-1])
    query_a = np.clip(np.sort(np.atleast_1d(np.asarray(pitch_angle, dtype=float)).ravel()), angles[0], angles[-1])

    mesh_a, mesh_r = np.meshgrid(query_a, query_r, indexing='ij')
    interpolated_freq = interpolator(np.stack([mesh_a, mesh_r], axis=-1))

    if interpolated_freq.shape[0] == 1:
        interpolated_freq = interpolated_freq[0]
    return interpolated_freq

def ShiftAndSum(signal_subset, radius, freq):
    """Phase-align a batch of 60-channel signals to a circular grad-B track and sum the channels.

    Parameters
    ----------
    signal_subset : ndarray
        Shape ``(batch, nch * nsample)``: each row holds the channels concatenated
        end to end. It is reshaped to ``(batch, nch, nsample)``.
    radius : float
        Radius of the electron track, in the same length units as the 0.10 array radius.
    freq : float
        Rotation frequency of the electron about the array axis. The track angle
        advances by ``2*pi*freq/fsample`` per sample. Assumed scalar — TODO confirm.

    Returns
    -------
    ndarray, shape (batch, nsample)
        Distance-weighted channel sum from ``WeightedChannelSum``.
    """
    nch = 60
    nsample = signal_subset.shape[-1] // nch
    signal_subset = signal_subset.reshape((signal_subset.shape[0], nch, nsample))
    angles = np.radians(np.arange(nch) * 360 / nch)
    r_array = 0.10
    wavelength_lo = 3e8 / 25.86e9
    fsample = 200e6

    # Electron azimuth per sample, starting at 0.
    theta_electron = 2 * np.pi * np.arange(nsample) * freq / fsample

    x_antenna = r_array * np.cos(angles)
    y_antenna = r_array * np.sin(angles)
    x_electron = radius * np.cos(theta_electron)
    y_electron = radius * np.sin(theta_electron)

    # (nch, nsample) antenna-to-electron distances.
    d_grad_b = np.sqrt(
        (x_antenna[:, np.newaxis] - x_electron[np.newaxis, :]) ** 2
        + (y_antenna[:, np.newaxis] - y_electron[np.newaxis, :]) ** 2
    )

    # The electron track is 1-D here, so use AntispiralCorrectionSingle.
    # It returns the (nch, nsample) layout that matches d_grad_b.
    # NOTE(review): the net correction is exp(+i*2*pi*d/lambda - i*antispiral).
    # SumSignal and SumDatasetSingleFrequency add the antispiral angle with a + sign;
    # confirm which sign convention is intended.
    phase_shift = -2 * np.pi * d_grad_b / wavelength_lo + AntispiralCorrectionSingle(
        x_electron, y_electron, x_antenna, y_antenna
    )

    shifted_signal_subset = signal_subset * np.exp(-1j * phase_shift)[np.newaxis, :, :]

    return WeightedChannelSum(shifted_signal_subset, radius, d_grad_b)


def AntispiralCorrection(x_electron, y_electron, x_antenna, y_antenna):
    """Angle of each antenna as seen from each electron position, for gridded electron arrays.

    Parameters
    ----------
    x_electron, y_electron : array_like
        Electron coordinates of any shape ``S``. Callers have used 3-D grids; any
        rank is accepted and ``S`` is preserved.
    x_antenna, y_antenna : array_like
        Antenna coordinates; flattened to length ``nch``.

    Returns
    -------
    ndarray, shape (nch, *S)
        ``arctan2(y_antenna - y_electron, x_antenna - x_electron)`` broadcast with
        antennas on axis 0.
    """
    x_electron = np.asarray(x_electron)
    y_electron = np.asarray(y_electron)
    x_antenna = np.asarray(x_antenna)
    y_antenna = np.asarray(y_antenna)

    # Give the antenna vectors one trailing singleton axis per electron axis, so
    # antennas broadcast along axis 0 against the electron grid.
    x_ant = x_antenna.reshape((x_antenna.size,) + (1,) * x_electron.ndim)
    y_ant = y_antenna.reshape((y_antenna.size,) + (1,) * y_electron.ndim)

    angles = np.arctan2(
        y_ant - y_electron[np.newaxis, ...],
        x_ant - x_electron[np.newaxis, ...],
    )

    return angles

def AntispiralCorrectionSingle(x_electron, y_electron, x_antenna, y_antenna):
    """Angle of each antenna as seen from each (flattened) electron position.

    Parameters
    ----------
    x_electron, y_electron : ndarray
        Electron coordinates; flattened to length ``m``.
    x_antenna, y_antenna : ndarray
        Antenna coordinates; flattened to length ``nch``.

    Returns
    -------
    ndarray, shape (nch, m)
        ``arctan2(y_antenna - y_electron, x_antenna - x_electron)``, with antennas
        down the rows and electron positions across the columns.
    """
    dx = x_antenna.reshape(-1, 1) - x_electron.reshape(1, -1)
    dy = y_antenna.reshape(-1, 1) - y_electron.reshape(1, -1)
    return np.arctan2(dy, dx)

def WeightedChannelSum(shifted_signals, radius, d_grad_b):
    """Sum phase-shifted channel signals with distance-based weights.

    When ``radius`` is nonzero, each channel gets weight ``radius / d``. The weights
    are rescaled so that, for every sample, they sum to the number of channels
    ``nch = d_grad_b.shape[0]``, which is the sum that unit weights would give.
    When ``radius == 0`` every channel gets weight 1.

    Parameters
    ----------
    shifted_signals : ndarray
        Shape ``(batch, nch, ...)``; axis 1 (channels) is summed out.
    radius : float
        Electron track radius.
    d_grad_b : ndarray
        Shape ``(nch, ...)``: antenna-to-electron distance per channel. Its trailing
        axes must broadcast against ``shifted_signals[:, i]``.

    Returns
    -------
    ndarray
        ``shifted_signals`` weighted and summed over axis 1.
    """
    if radius == 0:
        normed_weights = np.ones(d_grad_b.shape)
    else:
        nch = d_grad_b.shape[0]
        weights = radius / d_grad_b
        normed_weights = weights * (nch / weights.sum(axis=0))[np.newaxis, ...]

    return (shifted_signals * normed_weights[np.newaxis, ...]).sum(axis=1)
    
def CalculateChannelWeights(d_grad_b, r_grid):
    """Distance-based channel weights for every node of a 2-D radius grid.

    Parameters
    ----------
    d_grad_b : ndarray
        Antenna-to-electron distances, shape ``(nch, *r_grid.shape, ...)``. Axis 0
        is the channel axis; axes 1-2 match ``r_grid``. Callers pass 4-D arrays,
        with a trailing sample axis.
    r_grid : ndarray, 2-D
        Electron track radius at each grid node.

    Returns
    -------
    ndarray, same shape as ``d_grad_b``
        At nodes with nonzero radius the weights are ``r / d``, rescaled so that for
        each node and sample they sum to ``nch = d_grad_b.shape[0]`` over the
        channel axis. Nodes with radius 0 get weight 1.
    """
    channel_weights = np.ones(d_grad_b.shape)
    nch = d_grad_b.shape[0]

    nonzero = r_grid != 0
    n_nonzero = int(nonzero.sum())

    # Boolean indexing picks the nonzero nodes in row-major order -> (nch, n_nonzero, ...).
    d_nonzero = d_grad_b[:, nonzero, ...]
    r_nonzero = r_grid[nonzero].reshape((1, n_nonzero) + (1,) * (d_nonzero.ndim - 2))

    weights = r_nonzero / d_nonzero
    normed_weights = weights * (nch / weights.sum(axis=0))[np.newaxis, ...]

    # Scatter back with the same mask. This reproduces the node order of a nested
    # i/j loop over r_grid.
    channel_weights[:, nonzero, ...] = normed_weights

    return channel_weights

def CalculateChannelWeightsSingleRad(d_grad_b, radius):
    """Distance-based channel weights for a single electron track radius.

    Parameters
    ----------
    d_grad_b : ndarray
        Shape ``(nch, ...)``: antenna-to-electron distance per channel (axis 0).
    radius : float
        Electron track radius.

    Returns
    -------
    ndarray, same shape as ``d_grad_b``
        All ones if ``radius == 0``. Otherwise ``radius / d_grad_b``, rescaled so it
        sums to ``nch = d_grad_b.shape[0]`` along axis 0, which is the sum unit
        weights would give.
    """
    if radius == 0:
        return np.ones(d_grad_b.shape)

    nch = d_grad_b.shape[0]
    weights = radius / d_grad_b
    return weights * (nch / weights.sum(axis=0))[np.newaxis, ...]
    
    

In [None]:
# List the files in the dense_template_random dataset directory.
dense_random_dir = Path.home() / 'group' / 'project' / 'datasets' / 'data' / 'dense_template_random'
for entry in dense_random_dir.iterdir():
    print(entry)

In [None]:
# Show the contents of the dense_template_grid dataset directory.
dense_grid_dir = os.path.join(DATAPATH, 'dense_template_grid')
os.listdir(dense_grid_dir)

# Load data

In [None]:
# Signal data: the same HDF5 file is opened directly with h5py and through mayfly's MFDataset.
signal_file = os.path.join(DATAPATH, 'dense_template_random', '220124_sens_est_dense_grid_87.0_3cm_random.h5')
h5file = h5py.File(signal_file, 'r')  # NOTE(review): not used by any cell shown here and never closed
data = mf.data.MFDataset(signal_file)
metadata = pd.DataFrame(data.metadata)

# Grad-B frequency lookup table (keys 'radii', 'angles', 'freq').
gradb_freq_grid = np.load(os.path.join(PATH, 'results/mayfly', '211129_grad_b_frequency_grid_radius_angle.npz'))

In [None]:
# Shape of the raw signal array.
np.shape(data.data)

In [None]:
# Print the array names stored in the grad-B frequency .npz file
# (the enumerate index was never used).
for key in gradb_freq_grid.keys():
    print(key)

In [None]:
# Shape of the tabulated grad-B frequency grid.
print(np.shape(gradb_freq_grid['freq']))

In [None]:
r_grid, pa_grid = np.meshgrid(gradb_freq_grid['radii'], gradb_freq_grid['angles'])

# Locate the (r = 0.03, pitch = 87.0) grid node. np.isclose instead of == lets grid values
# produced by floating-point arithmetic (e.g. 0.030000000000000002) still match.
matches = np.argwhere(np.isclose(r_grid, 0.03) & np.isclose(pa_grid, 87.0))
assert len(matches) == 1, f'expected exactly one grid node at r=0.03, pitch=87.0, found {len(matches)}'
freq_ind = matches.squeeze()
gradb_freq = gradb_freq_grid['freq'][freq_ind[0], freq_ind[1]]
print(gradb_freq)

In [None]:
# Beamform every row of data.data along the r = 0.03 grad-B track.
summed_data = SumDatasetSingleFrequency(data, frequency=gradb_freq, radius=0.03)

In [None]:
save_path = Path.home() / 'group' / 'project' / 'datasets' / 'data' / 'bf'
name = '220301_dense_grid_87.0deg_3cm_random.npy'

# np.save does not create missing directories.
save_path.mkdir(parents=True, exist_ok=True)
np.save(save_path / name, summed_data)

In [None]:
# Confirm what is now in the output directory.
for saved_file in save_path.iterdir():
    print(saved_file)

In [None]:
# Shape of the beamformed output.
np.shape(summed_data)

In [None]:
# Real part of beamformed row 0, zoomed to the first 100 samples.
fig, ax = plt.subplots()
ax.plot(summed_data[0, :].real)
ax.set_xlim(0, 100)

In [None]:
# Real part of raw channel 0 of row 0 (first 8192 samples), zoomed to the first 100 samples.
# reshape(60, -1) infers the per-channel length instead of hard-coding (3 * 8192 * 60) // 60.
plt.plot(data.data[0, :].reshape(60, -1)[0, 0:8192].real)
plt.xlim(0, 100)

In [None]:
# Row 5000: beamformed output vs. the first 8192 samples of all 60 raw channels, flattened.
signal1 = summed_data[5000, :]
signal2 = data.data[5000, :].reshape(60, -1)[:, 0:8192].flatten()

# Ratio of L2 norms. np.linalg.norm is real-valued, whereas sqrt(vdot / vdot)
# produced a complex number.
print(np.linalg.norm(signal1) / np.linalg.norm(signal2))

In [None]:
# Power spectrum of the beamformed row-5000 signal.
# NOTE(review): the 50 * 60 * sqrt(60) normalisation is not explained here (50 is presumably
# an impedance in ohms, 60 the channel count) — confirm.
spectrum1 = np.fft.fft(signal1)
plt.plot(np.abs(spectrum1 / 8192) ** 2 / (50 * 60 * np.sqrt(60)))

In [None]:
# Power spectrum of raw channel 0 for the same row (first 8192 samples).
channel0 = signal2.reshape(60, 8192)[0, :]
plt.plot(np.abs(np.fft.fft(channel0) / 8192) ** 2)